In [3]:
#from train_encoder_script.py

from lib.restartable_pendulum import RestartablePendulumEnv
from lib.state_rep import train_encoder
import numpy as np
from matplotlib import pyplot as plt
import itertools
import sys

def main():
    """Train one pendulum state-representation encoder and save its outputs.

    The configuration (encoder widths, trajectory length) is one entry of the
    grid ``archs x traj_lens`` defined below, selected by a 1-based
    ``--job=N`` command-line argument (as passed by a batch job scheduler).
    When no ``--job`` argument is present -- e.g. when called from a notebook,
    where ``sys.argv`` holds the kernel's own arguments -- the first
    configuration is used.

    Outputs are written using ``save_dir`` as a filename *prefix* (no path
    separator is inserted): ``projectors.npz``, ``weights.npz``,
    ``biases.npz``, ``train_params.txt`` and ``losses.png``.

    Raises:
        ValueError: if ``--job`` is outside ``1..len(grid)``.
    """
    # Default to the first configuration; an explicit --job=N overrides it.
    job_iter = 0
    for arg in sys.argv:
        if arg.startswith('--job='):
            job_iter = int(arg.split('--job=')[1]) - 1

    # specify training details to loop over
    archs = [[64], [64,64], [64,64,64], [128], [128, 128], [128,128,128], [300], [300,300]]
    traj_lens = [5,10,20]
    param_lists = [archs, traj_lens]
    traj_type = "drive"

    configs = list(itertools.product(*param_lists))
    # Reject out-of-range jobs explicitly: --job=0 would otherwise become
    # index -1 and silently train the *last* configuration.
    if not 0 <= job_iter < len(configs):
        raise ValueError("--job must be in 1..{}, got {}".format(len(configs), job_iter + 1))
    i = job_iter
    widths, traj_len = configs[i]

    # specify environment information
    env = RestartablePendulumEnv()
    state_dim = 3
    act_dim = 1

    parameters = {
        "n_episodes" : 3*20000,
        "n_passes" : 1,
        "batch_size" : 100,
        "learning_rate" : 1e-3,
        "widths" : widths,
        "traj_len" : traj_len
    }
    n_episodes = parameters["n_episodes"]
    n_passes = parameters["n_passes"]
    batch_size = parameters["batch_size"]
    learning_rate = parameters["learning_rate"]
    # NOTE(review): the +32 offset presumably keeps these runs from colliding
    # with earlier numbered runs -- confirm.
    save_dir = "./experiments/state_rep_params/pendulum/{}".format(i+32)

    init_projectors = None
    init_weights = None
    init_biases = None

    # Seeds for the training trajectories: first state component uniform in
    # [-pi, pi), second (presumably angular velocity) uniform in [-8, 8),
    # action uniform in [-2, 2).
    start_states = [np.array([(np.random.rand(1)[0]*2 - 1)*np.pi, (np.random.rand(1)[0]*2 - 1)*8])
                    for _ in range(n_episodes)]
    start_actions = [np.random.rand(1)*4-2 for _ in range(n_episodes)]

    # Log the loss ~200 times per run; clamp to >= 1 so small runs never get 0.
    track_loss_every = max(1, n_episodes // (batch_size*200))

    projectors, weights, biases, losses = train_encoder(env, start_states, start_actions, traj_len, n_passes,
                                                        state_dim, act_dim, widths,
                                                        traj_type=traj_type,
                                                        learning_rate=learning_rate,
                                                        init_projectors=init_projectors,
                                                        init_weights=init_weights,
                                                        init_biases=init_biases,
                                                        batch_size=batch_size,
                                                        save_dir=save_dir,
                                                        show_progress=False,
                                                        track_loss_every=track_loss_every)

    # save the representation weights
    np.savez(save_dir + "projectors.npz", *projectors)
    np.savez(save_dir + "weights.npz", *weights)
    np.savez(save_dir + "biases.npz", *biases)

    # save the training params, one "key value" pair per line
    with open(save_dir + "train_params.txt", "w") as f:
        for key, value in parameters.items():
            f.write(" ".join([str(key), str(value)]))
            f.write("\n")

    # Close the figure afterwards so repeated calls do not accumulate open figures.
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set(xlabel="logged step", ylabel="training loss")
    fig.savefig(save_dir + "losses.png")
    plt.close(fig)
        

In [2]:
# Train one encoder configuration and save its weights, parameters and loss plot.
main()



UnboundLocalError: local variable 'job_iter' referenced before assignment

In [7]:
from lib.restartable_pendulum import RestartablePendulumEnv
from lib.state_rep_torch import train_encoder
import gym
import numpy as np
from matplotlib import pyplot as plt
import itertools
import sys

def main2():
    """Train one encoder from a hand-picked subset of an (arch, traj_len, lr) grid.

    ``--job=N`` (1-based) selects ``jobs[N-1]``, which is itself an index into
    the grid ``archs x traj_lens x lrs`` defined below. With no ``--job``
    argument (e.g. when called from a notebook) the first entry of ``jobs``
    is used. Outputs are written using ``save_dir`` as a filename *prefix*.

    NOTE(review): in the recorded notebook run the ``train_encoder`` call below
    raised ``TypeError: train_encoder() got multiple values for argument 'lr'``.
    Other cells call ``lib.state_rep_torch.train_encoder(net, sampler,
    n_episodes, ...)``, so this positional call appears to target an older
    signature and must be ported before main2 can run -- TODO.

    Raises:
        ValueError: if ``--job`` is outside ``1..len(jobs)``.
    """
    # Default to the first entry of ``jobs``; an explicit --job=N overrides it.
    i = 0
    for arg in sys.argv:
        if arg.startswith('--job='):
            i = int(arg.split('--job=')[1]) - 1

    state_dim = 3
    act_dim = 1

    # specify training details; ``jobs`` holds hand-picked indices into the grid
    jobs = [2, 5, 6, 8, 9, 10, 13, 20, 24, 28, 32]
    archs = [[state_dim]+arch for arch in [[128],
                                           [256],
                                           [512],
                                           [1024],
                                           [128,128],
                                           [256,256],
                                           [512,512],
                                           [512,256],
                                           [512,128]
                                          ]]
    traj_lens = [20]
    lrs = [.0001, .0005, .001, .005]
    param_lists = [archs, traj_lens, lrs]

    # Reject out-of-range jobs explicitly: --job=0 would otherwise become
    # index -1 and silently train the last selected configuration.
    if not 0 <= i < len(jobs):
        raise ValueError("--job must be in 1..{}, got {}".format(len(jobs), i + 1))
    widths, traj_len, learning_rate = list(itertools.product(*param_lists))[jobs[i]]

    # specify environment information
    env = RestartablePendulumEnv()

    parameters = {
        "n_episodes" : 30000,
        "batch_size" : 50,
        "learning_rate" : learning_rate,
        "widths" : widths,
        "traj_len" : traj_len
    }
    n_episodes = parameters["n_episodes"]
    batch_size = parameters["batch_size"]
    save_dir = "./experiments/extra_train_exps/{}".format(i)

    params, losses = train_encoder(env, traj_len, state_dim, act_dim, widths, n_episodes,
                                   lr=learning_rate,
                                   batch_size=batch_size,
                                   show_progress=False,
                                   track_loss_every=10,
                                   drift=True
                                  )

    # ``params`` is sliced as [W0, b0, W1, b1, ..., projector_0, ..., projector_traj_len]:
    # even entries of the first 2*(len(widths)-1) are weights, odd ones biases.
    n_layer_params = 2*(len(widths)-1)
    weights = list(params[:n_layer_params:2])
    biases = [b.flatten() for b in params[1:n_layer_params:2]]
    projectors = params[n_layer_params:n_layer_params+traj_len+1]

    # save the representation weights
    np.savez(save_dir + "projectors.npz", *projectors)
    np.savez(save_dir + "weights.npz", *weights)
    np.savez(save_dir + "biases.npz", *biases)

    # save the training params, one "key value" pair per line
    with open(save_dir + "train_params.txt", "w") as f:
        for key, value in parameters.items():
            f.write(" ".join([str(key), str(value)]))
            f.write("\n")

    # Close the figure afterwards so repeated calls do not accumulate open figures.
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set(xlabel="logged step", ylabel="training loss")
    fig.savefig(save_dir + "losses.png")
    plt.close(fig)

In [8]:
# NOTE(review): in the recorded run this call failed with
# "TypeError: train_encoder() got multiple values for argument 'lr'" -- main2 passes
# positional arguments that the imported lib.state_rep_torch.train_encoder does not accept.
main2()



TypeError: train_encoder() got multiple values for argument 'lr'

In [None]:
import sys
from lib.restartable_pendulum import RestartablePendulumEnv
from lib import state_rep_torch as srt
import gym
import numpy as np
from matplotlib import pyplot as plt
import torch
from lib import utils
from lib import encoder_wrappers as ew

# specify environment information
n_repeats = 3 # step the environment this many times for each action, concatenate the pixel observations
env = RestartablePendulumEnv(repeats=n_repeats,pixels=True)


#nonlin = torch.nn.ELU()
nonlin = torch.nn.functional.relu
layers = [50, 10, 5] # architecture of encoder after the 2 conv layers
save_dir = "./"
n_episodes = 100000 # total batches to draw
batch_size = 25
learning_rate = .001
save_every = int(n_episodes/4) # save the model every so often
# log the loss ~100 times per run; clamp to >= 1 so short runs never get 0
track_loss_every = max(1, n_episodes // 100)

# shape[1:] drops the leading observation axis (presumably the n_repeats stack -- confirm)
encnet = srt.ConvEncoderNet(layers,env.observation_space.shape[1:],sigma=nonlin)

# Alternative: PredictorNet (its deterministic sampling has since changed, so these
# lines would need updating before use):
#prednet = srt.PredictorNet(encnet,T,layers[-1],1)
#deterministic_args = (samples[i], batch_size, 35, method, n_repeats,T)

#prednet = srt.ForwardNet(encnet,layers[-1],1)
prednet = srt.PiecewiseForwardNet(encnet,layers[-1],1,2)
deterministic_args = None

traj_sampler = srt.SimpleTrajectorySampler(env,
                                           srt.sample_pendulum_action_batch,
                                           srt.sample_pendulum_state_batch_old,
                                           device=torch.device("cpu"),
                                           deterministic=False,
                                           deterministic_args=deterministic_args)

net, losses = srt.train_encoder(prednet,traj_sampler,n_episodes,
                                batch_size=batch_size,
                                track_loss_every=track_loss_every,
                                lr=learning_rate,
                                save_every=save_every,
                                save_path=save_dir)

torch.save(net,save_dir+"net")



# Visualize the learned representation: encode states at angles sweeping 0..2*pi
# (second state component, presumably angular velocity, held at 0).
d = layers[-1] # encoder output dimension; derived so it always matches layers
n_samps = 500
env = ew.TorchEncoderWrapper(env,net.encoder,np.eye(d))
X = np.empty((n_samps,d))
for i,ang in enumerate(np.linspace(0,2*np.pi,n_samps)): # go through the angles from 0 to 2pi
    X[i,:] = env.reset(state=[ang,0])
utils.visualize_trajectory(X)


  allow_unreachable=True)  # allow_unreachable flag


Epoch Completion: 88.000%, Loss: 0.066

In [1]:

import sys
from lib.restartable_pendulum import RestartablePendulumEnv
from lib import state_rep_torch as srt
import gym
import numpy as np
from matplotlib import pyplot as plt
import torch
from lib import utils
from lib import encoder_wrappers as ew

# specify environment information
n_repeats = 3 # step the environment this many times for each action, concatenate the pixel observations
env = RestartablePendulumEnv(repeats=n_repeats,pixels=True)


#nonlin = torch.nn.ELU()
nonlin = torch.nn.functional.relu
layers = [50, 10, 5] # architecture of encoder after the 2 conv layers
save_dir = "./"
n_episodes = 100000 # total batches to draw
batch_size = 25
learning_rate = .001
save_every = int(n_episodes/4) # save the model every so often

# NOTE(review): encnet / prednet / traj_sampler are built here but not used by this
# cell -- only env and net3 are needed for the visualization that follows.
encnet = srt.ConvEncoderNet(layers,env.observation_space.shape[1:],sigma=nonlin)

# Alternative: PredictorNet (its deterministic sampling has since changed, so these
# lines would need updating before use):
#prednet = srt.PredictorNet(encnet,T,layers[-1],1)
#deterministic_args = (samples[i], batch_size, 35, method, n_repeats,T)

#prednet = srt.ForwardNet(encnet,layers[-1],1)
prednet = srt.PiecewiseForwardNet(encnet,layers[-1],1,2)
deterministic_args = None

traj_sampler = srt.SimpleTrajectorySampler(env,
                                           srt.sample_pendulum_action_batch,
                                           srt.sample_pendulum_state_batch_old,
                                           device=torch.device("cpu"),
                                           deterministic=False,
                                           deterministic_args=deterministic_args)

# Load a model checkpoint saved by an earlier training run. The whole module is
# unpickled, so only load trusted files. map_location keeps GPU-saved checkpoints
# loadable here, since this notebook runs on the CPU.
net3 = torch.load("_3.0net", map_location=torch.device("cpu"))

In [2]:
# Encode states at angles sweeping 0..2*pi (second state component, presumably
# angular velocity, held at 0) with the loaded net3, then visualize the embedding.
d = layers[-1] # encoder output dimension; derived so it always matches layers
n_samps = 500
# Wrap into a new name so re-running this cell never wraps an already-wrapped env.
enc_env = ew.TorchEncoderWrapper(env,net3.encoder,np.eye(d))
X = np.empty((n_samps,d))
for i,ang in enumerate(np.linspace(0,2*np.pi,n_samps)): # go through the angles from 0 to 2pi
    X[i,:] = enc_env.reset(state=[ang,0])
utils.visualize_trajectory(X)

  self.explained_variance_ratio_ = exp_var / full_var
  self.explained_variance_ratio_ = exp_var / full_var


In [3]:
# Inspect the encoded angle sweep. In the recorded output every displayed row is
# identical, which suggests the encoder (or the reset) collapses all sampled angles
# to one point -- investigate before trusting the visualization above.
X

array([[ 1.69577461e-07, -2.51884460e-02,  4.55234947e-11,
         9.04095410e-10,  9.53646606e-10],
       [ 1.69577461e-07, -2.51884460e-02,  4.55234947e-11,
         9.04095410e-10,  9.53646606e-10],
       [ 1.69577461e-07, -2.51884460e-02,  4.55234947e-11,
         9.04095410e-10,  9.53646606e-10],
       ...,
       [ 1.69577461e-07, -2.51884460e-02,  4.55234947e-11,
         9.04095410e-10,  9.53646606e-10],
       [ 1.69577461e-07, -2.51884460e-02,  4.55234947e-11,
         9.04095410e-10,  9.53646606e-10],
       [ 1.69577461e-07, -2.51884460e-02,  4.55234947e-11,
         9.04095410e-10,  9.53646606e-10]])

In [6]:
import sys
from lib.restartable_pendulum import RestartablePendulumEnv
from lib import state_rep_torch as srt
import gym
import numpy as np
from matplotlib import pyplot as plt
import torch
from lib import utils
from lib import encoder_wrappers as ew

# specify environment information
n_repeats = 3 # step the environment this many times for each action, concatenate the pixel observations
env = RestartablePendulumEnv(repeats=n_repeats,pixels=True)


#nonlin = torch.nn.ELU()
nonlin = torch.nn.functional.relu
layers = [50, 10, 5] # architecture of encoder after the 2 conv layers
save_dir = "./"
n_episodes = 1000 # total batches to draw
batch_size = 25
learning_rate = .001
save_every = int(n_episodes/4) # save the model every so often
# log the loss ~100 times per run; clamp to >= 1 so short runs never get 0
track_loss_every = max(1, n_episodes // 100)

# shape[1:] drops the leading observation axis (presumably the n_repeats stack -- confirm)
encnet = srt.ConvEncoderNet(layers,env.observation_space.shape[1:],sigma=nonlin)

# Alternative: PredictorNet (its deterministic sampling has since changed, so these
# lines would need updating before use):
#prednet = srt.PredictorNet(encnet,T,layers[-1],1)
#deterministic_args = (samples[i], batch_size, 35, method, n_repeats,T)

#prednet = srt.ForwardNet(encnet,layers[-1],1)
prednet = srt.PiecewiseForwardNet(encnet,layers[-1],1,2)
deterministic_args = None

traj_sampler = srt.SimpleTrajectorySampler(env,
                                           srt.sample_pendulum_action_batch,
                                           srt.sample_pendulum_state_batch_old,
                                           device=torch.device("cpu"),
                                           deterministic=False,
                                           deterministic_args=deterministic_args)

net, losses = srt.train_encoder(prednet,traj_sampler,n_episodes,
                                batch_size=batch_size,
                                track_loss_every=track_loss_every,
                                lr=learning_rate,
                                save_every=save_every,
                                save_path=save_dir)

torch.save(net,save_dir+"net")



# Visualize the learned representation: encode states at angles sweeping 0..2*pi
# (second state component, presumably angular velocity, held at 0).
d = layers[-1] # encoder output dimension; derived so it always matches layers
n_samps = 500
env = ew.TorchEncoderWrapper(env,net.encoder,np.eye(d))
X = np.empty((n_samps,d))
for i,ang in enumerate(np.linspace(0,2*np.pi,n_samps)): # go through the angles from 0 to 2pi
    X[i,:] = env.reset(state=[ang,0])
utils.visualize_trajectory(X)

Epoch Completion: 100.000%, Loss: 0.071

In [1]:
import sys
from lib.restartable_pendulum import RestartablePendulumEnv
from lib import state_rep_torch as srt
import gym
import numpy as np
from matplotlib import pyplot as plt
import torch
from lib import utils
from lib import encoder_wrappers as ew

# specify environment information
n_repeats = 3 # step the environment this many times for each action, concatenate the pixel observations
env = RestartablePendulumEnv(repeats=n_repeats,pixels=True)


#nonlin = torch.nn.ELU()
nonlin = torch.nn.functional.relu
layers = [50, 10, 5] # architecture of encoder after the 2 conv layers
act_dim = 1 # pendulum action dimension
save_dir = "./"
n_episodes = 500 # total batches to draw
batch_size = 25
learning_rate = .001
save_every = int(n_episodes/2) # save the model every so often
# log the loss ~100 times per run; clamp to >= 1 so short runs never get 0
track_loss_every = max(1, n_episodes // 100)

# shape[1:] drops the leading observation axis (presumably the n_repeats stack -- confirm)
encnet = srt.ConvEncoderNet(layers,env.observation_space.shape[1:],sigma=nonlin)

# Alternative: PredictorNet (its deterministic sampling has since changed, so these
# lines would need updating before use):
#prednet = srt.PredictorNet(encnet,T,layers[-1],1)
#deterministic_args = (samples[i], batch_size, 35, method, n_repeats,T)

# Reward network input is 2*(encoded state dim) + action dim; derive it from layers
# so it stays in sync if the encoder architecture changes.
rnet = srt.EncoderNet([2*layers[-1] + act_dim, 50, 10, 1])

#prednet = srt.ForwardNet(encnet,layers[-1],1)
prednet = srt.PiecewiseForwardNet(encnet,layers[-1],1,2,fit_reward=True,mu=1,r_encoder=rnet,alpha=1)
deterministic_args = None

traj_sampler = srt.SimpleTrajectorySampler(env,
                                           srt.sample_pendulum_action_batch,
                                           srt.sample_pendulum_state_batch_old,
                                           device=torch.device("cpu"),
                                           deterministic=False,
                                           deterministic_args=deterministic_args,
                                           output_rewards=True)

# NOTE(review): the recorded output of this cell contains large tensor dumps; this
# cell prints nothing itself, so they presumably come from a debug print inside
# lib.state_rep_torch -- consider removing it there.
net, losses = srt.train_encoder(prednet,traj_sampler,n_episodes,
                                batch_size=batch_size,
                                track_loss_every=track_loss_every,
                                lr=learning_rate,
                                save_every=save_every,
                                save_path=save_dir)

# NOTE(review): this writes the hidden file "./.net", whereas the other cells save to
# "./net" -- confirm the name is intended.
torch.save(net,save_dir+".net")



# Visualization of the learned representation (disabled for this run):
# d = layers[-1] # must match the final entry in layers
# n_samps = 500
# env = ew.TorchEncoderWrapper(env,net.encoder,np.eye(d))
# X = np.empty((n_samps,d))
# for i,ang in enumerate(np.linspace(0,2*np.pi,n_samps)): # go through the angles from 0 to 2pi
#     X[i,:] = env.reset(state=[ang,0])
# utils.visualize_trajectory(X)


[tensor([[-0.1054, -0.2070,  0.0709, -0.1421, -0.1375],
        [-0.1044, -0.2060,  0.0715, -0.1445, -0.1340],
        [-0.1025, -0.2048,  0.0740, -0.1428, -0.1361],
        [-0.1040, -0.2048,  0.0739, -0.1422, -0.1346],
        [-0.1073, -0.2075,  0.0714, -0.1391, -0.1334],
        [-0.1084, -0.2078,  0.0706, -0.1434, -0.1331],
        [-0.1083, -0.2075,  0.0696, -0.1447, -0.1354],
        [-0.1062, -0.2063,  0.0714, -0.1428, -0.1349],
        [-0.1041, -0.2062,  0.0719, -0.1438, -0.1370],
        [-0.1066, -0.2044,  0.0705, -0.1461, -0.1354],
        [-0.1100, -0.2079,  0.0665, -0.1399, -0.1331],
        [-0.1122, -0.2087,  0.0682, -0.1417, -0.1324],
        [-0.1100, -0.2081,  0.0700, -0.1418, -0.1320],
        [-0.1046, -0.2049,  0.0731, -0.1432, -0.1356],
        [-0.1101, -0.2080,  0.0686, -0.1426, -0.1314],
        [-0.1097, -0.2072,  0.0676, -0.1415, -0.1333],
        [-0.1044, -0.2052,  0.0714, -0.1427, -0.1369],
        [-0.1057, -0.2068,  0.0712, -0.1413, -0.1339],
        [

  allow_unreachable=True)  # allow_unreachable flag


[tensor([[-0.0819, -0.1786,  0.0884, -0.1459, -0.1152],
        [-0.0810, -0.1788,  0.0909, -0.1429, -0.1147],
        [-0.0763, -0.1737,  0.0892, -0.1463, -0.1160],
        [-0.0817, -0.1783,  0.0877, -0.1459, -0.1154],
        [-0.0805, -0.1759,  0.0890, -0.1479, -0.1138],
        [-0.0777, -0.1755,  0.0896, -0.1443, -0.1174],
        [-0.0862, -0.1828,  0.0848, -0.1428, -0.1137],
        [-0.0816, -0.1789,  0.0863, -0.1447, -0.1164],
        [-0.0795, -0.1760,  0.0885, -0.1460, -0.1146],
        [-0.0823, -0.1775,  0.0851, -0.1471, -0.1152],
        [-0.0769, -0.1759,  0.0890, -0.1423, -0.1180],
        [-0.0804, -0.1764,  0.0877, -0.1484, -0.1158],
        [-0.0798, -0.1767,  0.0875, -0.1456, -0.1147],
        [-0.0804, -0.1773,  0.0890, -0.1448, -0.1153],
        [-0.0777, -0.1762,  0.0909, -0.1433, -0.1159],
        [-0.0784, -0.1758,  0.0887, -0.1456, -0.1155],
        [-0.0829, -0.1797,  0.0874, -0.1442, -0.1140],
        [-0.0792, -0.1753,  0.0877, -0.1471, -0.1150],
        [

[tensor([[-0.0821, -0.1667,  0.0890, -0.1480, -0.0845],
        [-0.0805, -0.1646,  0.0899, -0.1499, -0.0855],
        [-0.0801, -0.1641,  0.0888, -0.1506, -0.0868],
        [-0.0817, -0.1661,  0.0878, -0.1488, -0.0861],
        [-0.0816, -0.1661,  0.0893, -0.1485, -0.0847],
        [-0.0817, -0.1662,  0.0888, -0.1486, -0.0852],
        [-0.0820, -0.1666,  0.0895, -0.1480, -0.0841],
        [-0.0795, -0.1633,  0.0895, -0.1512, -0.0869],
        [-0.0822, -0.1668,  0.0883, -0.1480, -0.0850],
        [-0.0809, -0.1652,  0.0899, -0.1494, -0.0851],
        [-0.0814, -0.1659,  0.0905, -0.1486, -0.0839],
        [-0.0804, -0.1645,  0.0889, -0.1501, -0.0864],
        [-0.0806, -0.1646,  0.0871, -0.1503, -0.0878],
        [-0.0805, -0.1646,  0.0888, -0.1501, -0.0864],
        [-0.0800, -0.1640,  0.0894, -0.1506, -0.0865],
        [-0.0780, -0.1614,  0.0894, -0.1531, -0.0886],
        [-0.0825, -0.1673,  0.0886, -0.1475, -0.0843],
        [-0.0810, -0.1651,  0.0866, -0.1499, -0.0879],
        [

[tensor([[-0.1069, -0.1891,  0.0699, -0.1172, -0.0661],
        [-0.1063, -0.1883,  0.0719, -0.1176, -0.0650],
        [-0.1080, -0.1905,  0.0703, -0.1158, -0.0645],
        [-0.1031, -0.1841,  0.0720, -0.1216, -0.0683],
        [-0.1037, -0.1848,  0.0717, -0.1210, -0.0680],
        [-0.1057, -0.1875,  0.0736, -0.1181, -0.0641],
        [-0.1064, -0.1885,  0.0721, -0.1174, -0.0647],
        [-0.1059, -0.1877,  0.0700, -0.1185, -0.0671],
        [-0.1058, -0.1875,  0.0693, -0.1188, -0.0680],
        [-0.1079, -0.1904,  0.0702, -0.1159, -0.0648],
        [-0.1065, -0.1886,  0.0713, -0.1175, -0.0653],
        [-0.1055, -0.1873,  0.0720, -0.1186, -0.0657],
        [-0.1058, -0.1877,  0.0723, -0.1182, -0.0651],
        [-0.1067, -0.1888,  0.0703, -0.1174, -0.0659],
        [-0.1042, -0.1856,  0.0730, -0.1200, -0.0663],
        [-0.1063, -0.1882,  0.0708, -0.1179, -0.0660],
        [-0.1074, -0.1897,  0.0695, -0.1167, -0.0659],
        [-0.1066, -0.1887,  0.0720, -0.1172, -0.0645],
        [

[tensor([[-0.1089, -0.1852,  0.0615, -0.1156, -0.0675],
        [-0.1110, -0.1880,  0.0578, -0.1137, -0.0687],
        [-0.1108, -0.1877,  0.0592, -0.1136, -0.0676],
        [-0.1084, -0.1846,  0.0602, -0.1164, -0.0692],
        [-0.1090, -0.1853,  0.0593, -0.1159, -0.0695],
        [-0.1087, -0.1849,  0.0599, -0.1162, -0.0693],
        [-0.1103, -0.1870,  0.0595, -0.1143, -0.0680],
        [-0.1078, -0.1837,  0.0606, -0.1172, -0.0696],
        [-0.1090, -0.1853,  0.0608, -0.1157, -0.0682],
        [-0.1092, -0.1855,  0.0603, -0.1155, -0.0684],
        [-0.1104, -0.1871,  0.0609, -0.1139, -0.0665],
        [-0.1110, -0.1879,  0.0604, -0.1132, -0.0663],
        [-0.1103, -0.1871,  0.0601, -0.1141, -0.0673],
        [-0.1097, -0.1862,  0.0609, -0.1147, -0.0673],
        [-0.1101, -0.1867,  0.0593, -0.1145, -0.0683],
        [-0.1119, -0.1892,  0.0605, -0.1120, -0.0652],
        [-0.1120, -0.1893,  0.0602, -0.1120, -0.0654],
        [-0.1096, -0.1861,  0.0589, -0.1153, -0.0693],
        [

[tensor([[-0.1041, -0.1713,  0.0529, -0.1223, -0.0752],
        [-0.1001, -0.1659,  0.0549, -0.1269, -0.0777],
        [-0.1013, -0.1676,  0.0531, -0.1258, -0.0781],
        [-0.1016, -0.1680,  0.0526, -0.1255, -0.0782],
        [-0.0999, -0.1657,  0.0542, -0.1273, -0.0785],
        [-0.0996, -0.1653,  0.0553, -0.1274, -0.0778],
        [-0.1002, -0.1661,  0.0543, -0.1269, -0.0781],
        [-0.1019, -0.1684,  0.0540, -0.1248, -0.0766],
        [-0.1029, -0.1697,  0.0531, -0.1238, -0.0764],
        [-0.0985, -0.1639,  0.0549, -0.1289, -0.0794],
        [-0.1032, -0.1701,  0.0529, -0.1234, -0.0762],
        [-0.1006, -0.1667,  0.0535, -0.1265, -0.0784],
        [-0.0992, -0.1648,  0.0538, -0.1283, -0.0797],
        [-0.1019, -0.1684,  0.0535, -0.1249, -0.0770],
        [-0.1022, -0.1688,  0.0532, -0.1246, -0.0770],
        [-0.0994, -0.1650,  0.0543, -0.1279, -0.0790],
        [-0.1005, -0.1665,  0.0547, -0.1264, -0.0774],
        [-0.1007, -0.1668,  0.0541, -0.1263, -0.0777],
        [

[tensor([[-0.1036, -0.1643,  0.0545, -0.1219, -0.0703],
        [-0.1008, -0.1605,  0.0548, -0.1255, -0.0732],
        [-0.0999, -0.1593,  0.0555, -0.1264, -0.0735],
        [-0.0991, -0.1582,  0.0572, -0.1271, -0.0727],
        [-0.1010, -0.1608,  0.0555, -0.1251, -0.0723],
        [-0.0993, -0.1585,  0.0566, -0.1270, -0.0731],
        [-0.1005, -0.1601,  0.0562, -0.1255, -0.0721],
        [-0.0993, -0.1587,  0.0547, -0.1273, -0.0748],
        [-0.1010, -0.1609,  0.0556, -0.1250, -0.0721],
        [-0.0986, -0.1575,  0.0566, -0.1279, -0.0739],
        [-0.1041, -0.1650,  0.0537, -0.1215, -0.0706],
        [-0.0978, -0.1565,  0.0571, -0.1288, -0.0742],
        [-0.0998, -0.1592,  0.0560, -0.1265, -0.0731],
        [-0.1005, -0.1602,  0.0553, -0.1257, -0.0730],
        [-0.1022, -0.1624,  0.0549, -0.1236, -0.0715],
        [-0.0980, -0.1568,  0.0565, -0.1286, -0.0745],
        [-0.1002, -0.1598,  0.0565, -0.1258, -0.0721],
        [-0.1008, -0.1604,  0.0573, -0.1250, -0.0708],
        [

[tensor([[-0.0984, -0.1463,  0.0656, -0.1247, -0.0585],
        [-0.1003, -0.1490,  0.0645, -0.1225, -0.0575],
        [-0.0998, -0.1483,  0.0655, -0.1229, -0.0570],
        [-0.0987, -0.1468,  0.0660, -0.1242, -0.0578],
        [-0.0986, -0.1467,  0.0654, -0.1245, -0.0585],
        [-0.1008, -0.1497,  0.0644, -0.1219, -0.0571],
        [-0.1025, -0.1518,  0.0661, -0.1194, -0.0536],
        [-0.1004, -0.1490,  0.0651, -0.1224, -0.0569],
        [-0.1008, -0.1496,  0.0647, -0.1219, -0.0567],
        [-0.0986, -0.1465,  0.0669, -0.1242, -0.0570],
        [-0.0980, -0.1457,  0.0663, -0.1251, -0.0582],
        [-0.0986, -0.1466,  0.0648, -0.1247, -0.0591],
        [-0.0994, -0.1479,  0.0643, -0.1237, -0.0587],
        [-0.1002, -0.1487,  0.0658, -0.1224, -0.0564],
        [-0.1042, -0.1541,  0.0659, -0.1173, -0.0518],
        [-0.1000, -0.1484,  0.0673, -0.1223, -0.0550],
        [-0.0995, -0.1478,  0.0662, -0.1232, -0.0567],
        [-0.1015, -0.1506,  0.0655, -0.1208, -0.0552],
        [

[tensor([[-0.0802, -0.1148,  0.0779, -0.1404, -0.0560],
        [-0.0763, -0.1092,  0.0810, -0.1446, -0.0570],
        [-0.0784, -0.1125,  0.0757, -0.1432, -0.0602],
        [-0.0794, -0.1137,  0.0778, -0.1414, -0.0570],
        [-0.0787, -0.1129,  0.0768, -0.1425, -0.0587],
        [-0.0781, -0.1121,  0.0770, -0.1432, -0.0592],
        [-0.0801, -0.1148,  0.0764, -0.1409, -0.0576],
        [-0.0747, -0.1075,  0.0775, -0.1474, -0.0623],
        [-0.0811, -0.1163,  0.0752, -0.1399, -0.0578],
        [-0.0785, -0.1126,  0.0769, -0.1427, -0.0589],
        [-0.0781, -0.1121,  0.0770, -0.1432, -0.0591],
        [-0.0783, -0.1124,  0.0768, -0.1430, -0.0592],
        [-0.0794, -0.1135,  0.0795, -0.1411, -0.0553],
        [-0.0783, -0.1122,  0.0785, -0.1426, -0.0574],
        [-0.0804, -0.1153,  0.0756, -0.1407, -0.0581],
        [-0.0752, -0.1079,  0.0792, -0.1464, -0.0601],
        [-0.0765, -0.1099,  0.0772, -0.1452, -0.0607],
        [-0.0779, -0.1116,  0.0780, -0.1433, -0.0584],
        [

[tensor([[-0.1075, -0.0870,  0.0527, -0.1152,  0.0107],
        [-0.1044, -0.0822,  0.0538, -0.1189,  0.0088],
        [-0.1062, -0.0833,  0.0561, -0.1162,  0.0141],
        [-0.1050, -0.0839,  0.0550, -0.1177,  0.0098],
        [-0.1024, -0.0744,  0.0525, -0.1225,  0.0097],
        [-0.1071, -0.0857,  0.0546, -0.1153,  0.0125],
        [-0.1049, -0.0806,  0.0555, -0.1181,  0.0130],
        [-0.1028, -0.0782,  0.0525, -0.1215,  0.0074],
        [-0.1037, -0.0813,  0.0581, -0.1186,  0.0122],
        [-0.1075, -0.0825,  0.0511, -0.1163,  0.0130],
        [-0.1063, -0.0835,  0.0535, -0.1167,  0.0116],
        [-0.0988, -0.0690,  0.0570, -0.1258,  0.0108],
        [-0.1043, -0.0841,  0.0549, -0.1184,  0.0079],
        [-0.1101, -0.0896,  0.0530, -0.1119,  0.0145],
        [-0.1053, -0.0776,  0.0531, -0.1187,  0.0142],
        [-0.1030, -0.0801,  0.0568, -0.1198,  0.0103],
        [-0.1049, -0.0840,  0.0569, -0.1172,  0.0114],
        [-0.1027, -0.0809,  0.0586, -0.1196,  0.0106],
        [

[tensor([[-0.1254, -0.0696,  0.0279, -0.0994,  0.0440],
        [-0.1175, -0.0647,  0.0288, -0.1083,  0.0313],
        [-0.1241, -0.0761,  0.0296, -0.0995,  0.0370],
        [-0.1206, -0.0691,  0.0295, -0.1042,  0.0351],
        [-0.1200, -0.0679,  0.0287, -0.1053,  0.0340],
        [-0.1199, -0.0693,  0.0303, -0.1048,  0.0340],
        [-0.1182, -0.0633,  0.0279, -0.1080,  0.0332],
        [-0.1178, -0.0610,  0.0280, -0.1087,  0.0345],
        [-0.1192, -0.0681,  0.0313, -0.1053,  0.0347],
        [-0.1189, -0.0708,  0.0302, -0.1057,  0.0304],
        [-0.1235, -0.0654,  0.0280, -0.1020,  0.0434],
        [-0.1218, -0.0702,  0.0272, -0.1035,  0.0347],
        [-0.1239, -0.0733,  0.0284, -0.1004,  0.0378],
        [-0.1218, -0.0692,  0.0279, -0.1034,  0.0362],
        [-0.1180, -0.0668,  0.0316, -0.1067,  0.0333],
        [-0.1215, -0.0723,  0.0313, -0.1023,  0.0361],
        [-0.1205, -0.0719,  0.0301, -0.1039,  0.0330],
        [-0.1187, -0.0640,  0.0270, -0.1076,  0.0329],
        [

[tensor([[-0.0844, -0.0793,  0.0366, -0.1344, -0.0442],
        [-0.0873, -0.0781,  0.0356, -0.1316, -0.0376],
        [-0.0819, -0.0709,  0.0346, -0.1385, -0.0444],
        [-0.0849, -0.0787,  0.0355, -0.1342, -0.0437],
        [-0.0873, -0.0821,  0.0352, -0.1313, -0.0415],
        [-0.0867, -0.0784,  0.0340, -0.1326, -0.0406],
        [-0.0869, -0.0793,  0.0351, -0.1320, -0.0400],
        [-0.0858, -0.0785,  0.0359, -0.1330, -0.0409],
        [-0.0878, -0.0842,  0.0369, -0.1301, -0.0406],
        [-0.0841, -0.0773,  0.0360, -0.1350, -0.0436],
        [-0.0853, -0.0771,  0.0354, -0.1339, -0.0413],
        [-0.0847, -0.0789,  0.0373, -0.1338, -0.0425],
        [-0.0865, -0.0784,  0.0367, -0.1321, -0.0385],
        [-0.0862, -0.0826,  0.0371, -0.1318, -0.0424],
        [-0.0843, -0.0787,  0.0365, -0.1346, -0.0440],
        [-0.0885, -0.0822,  0.0347, -0.1300, -0.0391],
        [-0.0876, -0.0795,  0.0338, -0.1316, -0.0399],
        [-0.0878, -0.0801,  0.0367, -0.1304, -0.0370],
        [

[tensor([[-0.0780, -0.0846,  0.0346, -0.1363, -0.0628],
        [-0.0800, -0.0844,  0.0338, -0.1342, -0.0589],
        [-0.0790, -0.0848,  0.0339, -0.1353, -0.0613],
        [-0.0775, -0.0803,  0.0352, -0.1368, -0.0593],
        [-0.0812, -0.0881,  0.0341, -0.1324, -0.0589],
        [-0.0769, -0.0824,  0.0345, -0.1376, -0.0633],
        [-0.0771, -0.0836,  0.0348, -0.1372, -0.0636],
        [-0.0786, -0.0851,  0.0346, -0.1355, -0.0618],
        [-0.0794, -0.0833,  0.0338, -0.1349, -0.0590],
        [-0.0787, -0.0845,  0.0349, -0.1354, -0.0608],
        [-0.0778, -0.0813,  0.0359, -0.1362, -0.0589],
        [-0.0800, -0.0829,  0.0345, -0.1340, -0.0566],
        [-0.0786, -0.0835,  0.0348, -0.1356, -0.0602],
        [-0.0789, -0.0862,  0.0340, -0.1352, -0.0626],
        [-0.0770, -0.0806,  0.0340, -0.1378, -0.0620],
        [-0.0785, -0.0832,  0.0346, -0.1358, -0.0604],
        [-0.0767, -0.0831,  0.0355, -0.1375, -0.0633],
        [-0.0765, -0.0794,  0.0342, -0.1384, -0.0619],
        [

[tensor([[-0.0803, -0.0842,  0.0317, -0.1271, -0.0559],
        [-0.0801, -0.0856,  0.0326, -0.1271, -0.0568],
        [-0.0802, -0.0851,  0.0324, -0.1270, -0.0562],
        [-0.0810, -0.0847,  0.0324, -0.1260, -0.0539],
        [-0.0814, -0.0845,  0.0323, -0.1256, -0.0530],
        [-0.0799, -0.0845,  0.0324, -0.1273, -0.0563],
        [-0.0810, -0.0870,  0.0329, -0.1258, -0.0556],
        [-0.0810, -0.0830,  0.0310, -0.1266, -0.0538],
        [-0.0821, -0.0836,  0.0317, -0.1249, -0.0511],
        [-0.0772, -0.0802,  0.0323, -0.1307, -0.0587],
        [-0.0817, -0.0858,  0.0318, -0.1253, -0.0539],
        [-0.0787, -0.0840,  0.0323, -0.1288, -0.0587],
        [-0.0835, -0.0827,  0.0299, -0.1240, -0.0489],
        [-0.0810, -0.0813,  0.0310, -0.1265, -0.0521],
        [-0.0820, -0.0849,  0.0318, -0.1250, -0.0525],
        [-0.0802, -0.0827,  0.0313, -0.1275, -0.0552],
        [-0.0821, -0.0854,  0.0311, -0.1251, -0.0534],
        [-0.0762, -0.0771,  0.0333, -0.1316, -0.0572],
        [

[tensor([[-0.0910, -0.0713,  0.0252, -0.1092, -0.0222],
        [-0.0922, -0.0682,  0.0238, -0.1082, -0.0179],
        [-0.0948, -0.0761,  0.0241, -0.1050, -0.0192],
        [-0.0949, -0.0758,  0.0242, -0.1048, -0.0185],
        [-0.0951, -0.0753,  0.0243, -0.1045, -0.0174],
        [-0.0955, -0.0731,  0.0231, -0.1045, -0.0158],
        [-0.0925, -0.0771,  0.0267, -0.1069, -0.0228],
        [-0.0923, -0.0774,  0.0260, -0.1074, -0.0243],
        [-0.0951, -0.0733,  0.0241, -0.1046, -0.0157],
        [-0.0905, -0.0725,  0.0262, -0.1095, -0.0235],
        [-0.0947, -0.0753,  0.0245, -0.1050, -0.0182],
        [-0.0926, -0.0752,  0.0254, -0.1072, -0.0220],
        [-0.0940, -0.0772,  0.0256, -0.1054, -0.0204],
        [-0.0907, -0.0699,  0.0242, -0.1099, -0.0227],
        [-0.0921, -0.0761,  0.0258, -0.1078, -0.0237],
        [-0.0918, -0.0736,  0.0259, -0.1080, -0.0219],
        [-0.0927, -0.0747,  0.0255, -0.1071, -0.0214],
        [-0.0956, -0.0768,  0.0239, -0.1041, -0.0181],
        [

Epoch Completion: 11.000%, Loss: 87.675[tensor([[-0.1061, -0.0629,  0.0179, -0.0871,  0.0155],
        [-0.1035, -0.0638,  0.0187, -0.0901,  0.0094],
        [-0.1039, -0.0649,  0.0189, -0.0896,  0.0094],
        [-0.1061, -0.0660,  0.0198, -0.0865,  0.0145],
        [-0.1039, -0.0644,  0.0190, -0.0895,  0.0102],
        [-0.1041, -0.0670,  0.0198, -0.0891,  0.0090],
        [-0.1040, -0.0639,  0.0202, -0.0889,  0.0120],
        [-0.1062, -0.0638,  0.0181, -0.0870,  0.0150],
        [-0.1057, -0.0652,  0.0181, -0.0876,  0.0126],
        [-0.1064, -0.0613,  0.0162, -0.0872,  0.0160],
        [-0.1045, -0.0653,  0.0188, -0.0888,  0.0105],
        [-0.1033, -0.0651,  0.0205, -0.0897,  0.0096],
        [-0.1019, -0.0620,  0.0192, -0.0919,  0.0077],
        [-0.1043, -0.0650,  0.0196, -0.0888,  0.0110],
        [-0.1045, -0.0633,  0.0191, -0.0887,  0.0127],
        [-0.1037, -0.0619,  0.0189, -0.0896,  0.0120],
        [-0.1038, -0.0677,  0.0199, -0.0895,  0.0076],
        [-0.1040, -0.064

[tensor([[-0.1055, -0.0549,  0.0158, -0.0825,  0.0210],
        [-0.1078, -0.0595,  0.0176, -0.0790,  0.0241],
        [-0.1046, -0.0534,  0.0175, -0.0828,  0.0223],
        [-0.1084, -0.0560,  0.0170, -0.0783,  0.0281],
        [-0.1092, -0.0585,  0.0172, -0.0773,  0.0279],
        [-0.1053, -0.0546,  0.0174, -0.0820,  0.0227],
        [-0.1092, -0.0588,  0.0170, -0.0774,  0.0274],
        [-0.1073, -0.0578,  0.0159, -0.0802,  0.0228],
        [-0.1048, -0.0584,  0.0186, -0.0824,  0.0191],
        [-0.1079, -0.0592,  0.0182, -0.0787,  0.0250],
        [-0.1058, -0.0589,  0.0170, -0.0818,  0.0193],
        [-0.1074, -0.0572,  0.0171, -0.0796,  0.0246],
        [-0.1048, -0.0568,  0.0185, -0.0824,  0.0205],
        [-0.1061, -0.0515,  0.0167, -0.0810,  0.0267],
        [-0.1084, -0.0582,  0.0168, -0.0784,  0.0260],
        [-0.1056, -0.0575,  0.0178, -0.0816,  0.0211],
        [-0.1061, -0.0547,  0.0177, -0.0809,  0.0246],
        [-0.1055, -0.0556,  0.0174, -0.0819,  0.0221],
        [

[tensor([[-0.0948, -0.0523,  0.0212, -0.0885,  0.0055],
        [-0.1003, -0.0574,  0.0199, -0.0820,  0.0125],
        [-0.1002, -0.0590,  0.0197, -0.0824,  0.0104],
        [-0.0968, -0.0542,  0.0207, -0.0862,  0.0079],
        [-0.0984, -0.0600,  0.0206, -0.0846,  0.0061],
        [-0.0948, -0.0548,  0.0225, -0.0881,  0.0047],
        [-0.0985, -0.0586,  0.0220, -0.0838,  0.0090],
        [-0.0983, -0.0556,  0.0204, -0.0844,  0.0099],
        [-0.0959, -0.0568,  0.0219, -0.0871,  0.0045],
        [-0.0942, -0.0544,  0.0234, -0.0886,  0.0044],
        [-0.0994, -0.0525,  0.0177, -0.0837,  0.0126],
        [-0.0990, -0.0546,  0.0189, -0.0839,  0.0110],
        [-0.0997, -0.0558,  0.0194, -0.0829,  0.0120],
        [-0.1006, -0.0559,  0.0201, -0.0815,  0.0147],
        [-0.0997, -0.0558,  0.0195, -0.0829,  0.0121],
        [-0.1003, -0.0533,  0.0179, -0.0825,  0.0141],
        [-0.0959, -0.0504,  0.0202, -0.0871,  0.0090],
        [-0.0965, -0.0541,  0.0205, -0.0866,  0.0071],
        [

[tensor([[-0.0794, -0.0610,  0.0295, -0.0991, -0.0263],
        [-0.0735, -0.0603,  0.0319, -0.1061, -0.0369],
        [-0.0777, -0.0595,  0.0306, -0.1007, -0.0275],
        [-0.0751, -0.0609,  0.0307, -0.1044, -0.0349],
        [-0.0780, -0.0626,  0.0306, -0.1007, -0.0299],
        [-0.0763, -0.0615,  0.0318, -0.1024, -0.0316],
        [-0.0750, -0.0638,  0.0326, -0.1042, -0.0361],
        [-0.0748, -0.0599,  0.0303, -0.1050, -0.0352],
        [-0.0760, -0.0656,  0.0320, -0.1033, -0.0361],
        [-0.0743, -0.0638,  0.0338, -0.1047, -0.0367],
        [-0.0774, -0.0621,  0.0308, -0.1015, -0.0307],
        [-0.0759, -0.0572,  0.0301, -0.1031, -0.0301],
        [-0.0740, -0.0599,  0.0314, -0.1056, -0.0360],
        [-0.0738, -0.0575,  0.0323, -0.1052, -0.0333],
        [-0.0773, -0.0636,  0.0307, -0.1019, -0.0325],
        [-0.0745, -0.0593,  0.0315, -0.1047, -0.0340],
        [-0.0742, -0.0636,  0.0334, -0.1050, -0.0370],
        [-0.0755, -0.0614,  0.0312, -0.1037, -0.0339],
        [

[tensor([[-0.0671, -0.0641,  0.0358, -0.1052, -0.0455],
        [-0.0641, -0.0616,  0.0376, -0.1081, -0.0481],
        [-0.0656, -0.0640,  0.0364, -0.1070, -0.0482],
        [-0.0647, -0.0660,  0.0384, -0.1077, -0.0502],
        [-0.0686, -0.0670,  0.0365, -0.1033, -0.0441],
        [-0.0626, -0.0661,  0.0394, -0.1104, -0.0545],
        [-0.0653, -0.0641,  0.0379, -0.1067, -0.0473],
        [-0.0645, -0.0659,  0.0380, -0.1082, -0.0510],
        [-0.0648, -0.0629,  0.0369, -0.1077, -0.0486],
        [-0.0694, -0.0658,  0.0357, -0.1022, -0.0417],
        [-0.0659, -0.0663,  0.0374, -0.1065, -0.0486],
        [-0.0655, -0.0653,  0.0378, -0.1067, -0.0483],
        [-0.0639, -0.0607,  0.0372, -0.1085, -0.0482],
        [-0.0663, -0.0628,  0.0357, -0.1061, -0.0462],
        [-0.0665, -0.0652,  0.0357, -0.1062, -0.0480],
        [-0.0603, -0.0615,  0.0395, -0.1127, -0.0552],
        [-0.0654, -0.0662,  0.0378, -0.1070, -0.0493],
        [-0.0669, -0.0631,  0.0368, -0.1049, -0.0439],
        [

[tensor([[-0.0721, -0.0542,  0.0319, -0.0918, -0.0243],
        [-0.0758, -0.0529,  0.0296, -0.0872, -0.0166],
        [-0.0740, -0.0542,  0.0317, -0.0891, -0.0200],
        [-0.0753, -0.0532,  0.0317, -0.0871, -0.0159],
        [-0.0699, -0.0497,  0.0320, -0.0940, -0.0250],
        [-0.0715, -0.0474,  0.0313, -0.0916, -0.0196],
        [-0.0749, -0.0537,  0.0321, -0.0876, -0.0170],
        [-0.0704, -0.0445,  0.0314, -0.0924, -0.0191],
        [-0.0687, -0.0522,  0.0338, -0.0955, -0.0287],
        [-0.0731, -0.0528,  0.0325, -0.0898, -0.0200],
        [-0.0721, -0.0551,  0.0345, -0.0908, -0.0224],
        [-0.0752, -0.0503,  0.0287, -0.0879, -0.0163],
        [-0.0681, -0.0499,  0.0348, -0.0954, -0.0266],
        [-0.0740, -0.0546,  0.0316, -0.0892, -0.0204],
        [-0.0731, -0.0587,  0.0336, -0.0905, -0.0247],
        [-0.0742, -0.0541,  0.0312, -0.0891, -0.0200],
        [-0.0702, -0.0540,  0.0338, -0.0936, -0.0267],
        [-0.0754, -0.0546,  0.0316, -0.0872, -0.0171],
        [

[tensor([[-0.0855, -0.0444,  0.0227, -0.0698,  0.0103],
        [-0.0837, -0.0466,  0.0238, -0.0724,  0.0047],
        [-0.0849, -0.0425,  0.0228, -0.0701,  0.0109],
        [-0.0883, -0.0441,  0.0219, -0.0658,  0.0167],
        [-0.0834, -0.0471,  0.0262, -0.0719,  0.0060],
        [-0.0852, -0.0468,  0.0234, -0.0705,  0.0078],
        [-0.0887, -0.0453,  0.0222, -0.0655,  0.0166],
        [-0.0818, -0.0454,  0.0260, -0.0739,  0.0038],
        [-0.0860, -0.0446,  0.0234, -0.0687,  0.0120],
        [-0.0881, -0.0444,  0.0209, -0.0668,  0.0145],
        [-0.0877, -0.0407,  0.0220, -0.0659,  0.0187],
        [-0.0829, -0.0423,  0.0248, -0.0722,  0.0082],
        [-0.0894, -0.0417,  0.0212, -0.0641,  0.0208],
        [-0.0858, -0.0459,  0.0231, -0.0694,  0.0100],
        [-0.0812, -0.0427,  0.0259, -0.0743,  0.0049],
        [-0.0854, -0.0432,  0.0231, -0.0694,  0.0117],
        [-0.0875, -0.0442,  0.0223, -0.0669,  0.0149],
        [-0.0843, -0.0380,  0.0227, -0.0701,  0.0139],
        [

[tensor([[-8.3283e-02, -4.4341e-02,  2.3054e-02, -6.7158e-02,  1.0771e-02],
        [-8.0321e-02, -4.8474e-02,  2.5049e-02, -7.1785e-02,  1.2340e-03],
        [-7.7102e-02, -4.5605e-02,  2.6062e-02, -7.5426e-02, -2.5674e-03],
        [-7.7606e-02, -3.8369e-02,  2.4349e-02, -7.3578e-02,  4.5035e-03],
        [-8.2930e-02, -4.4265e-02,  2.2624e-02, -6.7884e-02,  9.5160e-03],
        [-7.7877e-02, -4.1492e-02,  2.5453e-02, -7.3453e-02,  3.0537e-03],
        [-7.8151e-02, -4.4692e-02,  2.5732e-02, -7.3749e-02,  6.3349e-04],
        [-7.9345e-02, -5.0180e-02,  2.7794e-02, -7.2380e-02, -1.2778e-05],
        [-7.9659e-02, -4.6817e-02,  2.4762e-02, -7.2494e-02,  1.0467e-03],
        [-7.8193e-02, -4.3516e-02,  2.3888e-02, -7.4275e-02, -1.1854e-05],
        [-7.9021e-02, -4.4941e-02,  2.4289e-02, -7.3198e-02,  9.5188e-04],
        [-8.1896e-02, -4.5470e-02,  2.2494e-02, -6.9842e-02,  5.5409e-03],
        [-7.6705e-02, -5.0864e-02,  2.8052e-02, -7.6453e-02, -6.9773e-03],
        [-7.7288e-02, -4

[tensor([[-0.0689, -0.0572,  0.0293, -0.0822, -0.0235],
        [-0.0702, -0.0555,  0.0278, -0.0804, -0.0200],
        [-0.0690, -0.0582,  0.0302, -0.0819, -0.0234],
        [-0.0687, -0.0564,  0.0299, -0.0819, -0.0224],
        [-0.0680, -0.0571,  0.0307, -0.0828, -0.0241],
        [-0.0677, -0.0574,  0.0302, -0.0837, -0.0257],
        [-0.0666, -0.0567,  0.0302, -0.0852, -0.0279],
        [-0.0667, -0.0570,  0.0310, -0.0847, -0.0269],
        [-0.0673, -0.0561,  0.0304, -0.0837, -0.0250],
        [-0.0673, -0.0574,  0.0304, -0.0842, -0.0265],
        [-0.0693, -0.0609,  0.0296, -0.0826, -0.0264],
        [-0.0690, -0.0597,  0.0299, -0.0824, -0.0253],
        [-0.0679, -0.0566,  0.0295, -0.0835, -0.0251],
        [-0.0680, -0.0560,  0.0302, -0.0828, -0.0234],
        [-0.0682, -0.0573,  0.0304, -0.0827, -0.0241],
        [-0.0706, -0.0587,  0.0294, -0.0800, -0.0208],
        [-0.0687, -0.0593,  0.0299, -0.0827, -0.0255],
        [-0.0683, -0.0587,  0.0302, -0.0831, -0.0256],
        [

[tensor([[-0.0521, -0.0622,  0.0377, -0.0967, -0.0502],
        [-0.0520, -0.0620,  0.0376, -0.0969, -0.0503],
        [-0.0530, -0.0634,  0.0366, -0.0962, -0.0504],
        [-0.0515, -0.0605,  0.0376, -0.0971, -0.0498],
        [-0.0526, -0.0633,  0.0367, -0.0968, -0.0512],
        [-0.0524, -0.0607,  0.0376, -0.0958, -0.0479],
        [-0.0508, -0.0599,  0.0373, -0.0983, -0.0514],
        [-0.0513, -0.0624,  0.0381, -0.0979, -0.0520],
        [-0.0530, -0.0626,  0.0370, -0.0958, -0.0491],
        [-0.0508, -0.0609,  0.0387, -0.0979, -0.0510],
        [-0.0530, -0.0616,  0.0371, -0.0954, -0.0478],
        [-0.0528, -0.0629,  0.0365, -0.0965, -0.0505],
        [-0.0529, -0.0637,  0.0368, -0.0964, -0.0508],
        [-0.0515, -0.0648,  0.0387, -0.0981, -0.0537],
        [-0.0538, -0.0636,  0.0366, -0.0950, -0.0486],
        [-0.0529, -0.0638,  0.0374, -0.0961, -0.0503],
        [-0.0533, -0.0616,  0.0370, -0.0948, -0.0470],
        [-0.0512, -0.0623,  0.0383, -0.0980, -0.0521],
        [

[tensor([[-0.0421, -0.0506,  0.0406, -0.0959, -0.0435],
        [-0.0410, -0.0484,  0.0414, -0.0964, -0.0429],
        [-0.0410, -0.0484,  0.0414, -0.0963, -0.0427],
        [-0.0418, -0.0505,  0.0407, -0.0963, -0.0441],
        [-0.0416, -0.0494,  0.0405, -0.0963, -0.0435],
        [-0.0426, -0.0464,  0.0403, -0.0936, -0.0374],
        [-0.0437, -0.0511,  0.0398, -0.0938, -0.0408],
        [-0.0418, -0.0491,  0.0405, -0.0959, -0.0427],
        [-0.0413, -0.0519,  0.0419, -0.0971, -0.0459],
        [-0.0420, -0.0487,  0.0410, -0.0951, -0.0411],
        [-0.0436, -0.0507,  0.0402, -0.0936, -0.0400],
        [-0.0418, -0.0495,  0.0408, -0.0958, -0.0426],
        [-0.0423, -0.0512,  0.0409, -0.0957, -0.0434],
        [-0.0427, -0.0517,  0.0404, -0.0954, -0.0435],
        [-0.0391, -0.0469,  0.0416, -0.0988, -0.0456],
        [-0.0402, -0.0492,  0.0419, -0.0978, -0.0454],
        [-0.0402, -0.0474,  0.0406, -0.0978, -0.0447],
        [-0.0409, -0.0519,  0.0422, -0.0976, -0.0465],
        [

[tensor([[-0.0387, -0.0418,  0.0416, -0.0892, -0.0297],
        [-0.0395, -0.0391,  0.0409, -0.0870, -0.0247],
        [-0.0374, -0.0395,  0.0424, -0.0899, -0.0292],
        [-0.0398, -0.0407,  0.0403, -0.0875, -0.0266],
        [-0.0386, -0.0318,  0.0390, -0.0861, -0.0195],
        [-0.0391, -0.0372,  0.0400, -0.0874, -0.0244],
        [-0.0396, -0.0395,  0.0410, -0.0869, -0.0248],
        [-0.0385, -0.0403,  0.0420, -0.0885, -0.0275],
        [-0.0411, -0.0372,  0.0393, -0.0842, -0.0196],
        [-0.0379, -0.0407,  0.0429, -0.0893, -0.0289],
        [-0.0397, -0.0393,  0.0404, -0.0870, -0.0250],
        [-0.0391, -0.0373,  0.0399, -0.0874, -0.0246],
        [-0.0393, -0.0379,  0.0407, -0.0868, -0.0238],
        [-0.0405, -0.0403,  0.0403, -0.0861, -0.0241],
        [-0.0383, -0.0400,  0.0419, -0.0887, -0.0277],
        [-0.0401, -0.0402,  0.0404, -0.0866, -0.0249],
        [-0.0397, -0.0385,  0.0415, -0.0860, -0.0226],
        [-0.0384, -0.0380,  0.0408, -0.0884, -0.0263],
        [

[tensor([[-0.0405, -0.0328,  0.0379, -0.0745, -0.0052],
        [-0.0400, -0.0332,  0.0379, -0.0756, -0.0072],
        [-0.0388, -0.0325,  0.0378, -0.0775, -0.0098],
        [-0.0380, -0.0308,  0.0381, -0.0778, -0.0093],
        [-0.0407, -0.0332,  0.0376, -0.0746, -0.0057],
        [-0.0389, -0.0324,  0.0380, -0.0772, -0.0092],
        [-0.0397, -0.0301,  0.0373, -0.0750, -0.0047],
        [-0.0393, -0.0299,  0.0377, -0.0754, -0.0050],
        [-0.0413, -0.0324,  0.0371, -0.0735, -0.0036],
        [-0.0413, -0.0359,  0.0385, -0.0744, -0.0067],
        [-0.0399, -0.0321,  0.0376, -0.0754, -0.0063],
        [-0.0383, -0.0320,  0.0382, -0.0779, -0.0100],
        [-0.0409, -0.0348,  0.0372, -0.0752, -0.0077],
        [-0.0395, -0.0329,  0.0385, -0.0760, -0.0076],
        [-0.0402, -0.0326,  0.0394, -0.0741, -0.0042],
        [-0.0391, -0.0348,  0.0402, -0.0766, -0.0092],
        [-0.0379, -0.0284,  0.0389, -0.0764, -0.0054],
        [-0.0388, -0.0346,  0.0396, -0.0774, -0.0105],
        [

[tensor([[-0.0465, -0.0305,  0.0309, -0.0609,  0.0114],
        [-0.0475, -0.0309,  0.0296, -0.0601,  0.0121],
        [-0.0480, -0.0298,  0.0292, -0.0589,  0.0144],
        [-0.0518, -0.0258,  0.0255, -0.0522,  0.0262],
        [-0.0488, -0.0294,  0.0306, -0.0562,  0.0192],
        [-0.0539, -0.0299,  0.0257, -0.0504,  0.0267],
        [-0.0517, -0.0255,  0.0254, -0.0523,  0.0262],
        [-0.0475, -0.0297,  0.0300, -0.0592,  0.0142],
        [-0.0483, -0.0299,  0.0292, -0.0583,  0.0153],
        [-0.0505, -0.0327,  0.0296, -0.0557,  0.0179],
        [-0.0467, -0.0304,  0.0313, -0.0603,  0.0125],
        [-0.0495, -0.0314,  0.0291, -0.0571,  0.0163],
        [-0.0492, -0.0260,  0.0274, -0.0558,  0.0209],
        [-0.0487, -0.0286,  0.0292, -0.0569,  0.0183],
        [-0.0497, -0.0282,  0.0280, -0.0557,  0.0200],
        [-0.0484, -0.0292,  0.0292, -0.0579,  0.0163],
        [-0.0528, -0.0266,  0.0258, -0.0506,  0.0283],
        [-0.0493, -0.0295,  0.0282, -0.0569,  0.0174],
        [

[tensor([[-0.0647, -0.0280,  0.0126, -0.0337,  0.0437],
        [-0.0653, -0.0266,  0.0114, -0.0325,  0.0460],
        [-0.0643, -0.0285,  0.0126, -0.0346,  0.0420],
        [-0.0655, -0.0308,  0.0127, -0.0338,  0.0421],
        [-0.0688, -0.0248,  0.0092, -0.0265,  0.0556],
        [-0.0669, -0.0288,  0.0116, -0.0308,  0.0475],
        [-0.0635, -0.0285,  0.0138, -0.0353,  0.0413],
        [-0.0626, -0.0304,  0.0151, -0.0374,  0.0374],
        [-0.0638, -0.0274,  0.0128, -0.0348,  0.0424],
        [-0.0649, -0.0273,  0.0114, -0.0338,  0.0436],
        [-0.0644, -0.0274,  0.0127, -0.0339,  0.0437],
        [-0.0639, -0.0273,  0.0131, -0.0345,  0.0431],
        [-0.0672, -0.0303,  0.0119, -0.0309,  0.0467],
        [-0.0651, -0.0304,  0.0142, -0.0333,  0.0434],
        [-0.0641, -0.0282,  0.0138, -0.0340,  0.0435],
        [-0.0648, -0.0289,  0.0127, -0.0340,  0.0428],
        [-0.0619, -0.0265,  0.0132, -0.0375,  0.0390],
        [-0.0662, -0.0280,  0.0125, -0.0311,  0.0478],
        [

[tensor([[-0.0680, -0.0356,  0.0041, -0.0333,  0.0318],
        [-0.0689, -0.0372,  0.0042, -0.0325,  0.0321],
        [-0.0717, -0.0383,  0.0027, -0.0290,  0.0365],
        [-0.0663, -0.0349,  0.0048, -0.0356,  0.0287],
        [-0.0687, -0.0345,  0.0042, -0.0312,  0.0355],
        [-0.0696, -0.0367,  0.0036, -0.0314,  0.0339],
        [-0.0689, -0.0367,  0.0037, -0.0325,  0.0322],
        [-0.0708, -0.0393,  0.0031, -0.0309,  0.0332],
        [-0.0677, -0.0381,  0.0048, -0.0349,  0.0282],
        [-0.0685, -0.0346,  0.0036, -0.0320,  0.0341],
        [-0.0704, -0.0370,  0.0031, -0.0303,  0.0353],
        [-0.0684, -0.0367,  0.0048, -0.0328,  0.0321],
        [-0.0690, -0.0349,  0.0031, -0.0316,  0.0344],
        [-0.0689, -0.0331,  0.0027, -0.0310,  0.0361],
        [-0.0685, -0.0375,  0.0052, -0.0328,  0.0318],
        [-0.0675, -0.0357,  0.0046, -0.0340,  0.0308],
        [-0.0688, -0.0360,  0.0038, -0.0322,  0.0331],
        [-0.0667, -0.0367,  0.0060, -0.0352,  0.0287],
        [

[tensor([[-0.0645, -0.0499,  0.0041, -0.0451,  0.0013],
        [-0.0641, -0.0493,  0.0040, -0.0454,  0.0011],
        [-0.0670, -0.0495,  0.0021, -0.0413,  0.0067],
        [-0.0640, -0.0486,  0.0038, -0.0454,  0.0014],
        [-0.0647, -0.0497,  0.0041, -0.0445,  0.0023],
        [-0.0652, -0.0518,  0.0042, -0.0447,  0.0009],
        [-0.0650, -0.0493,  0.0036, -0.0441,  0.0030],
        [-0.0666, -0.0522,  0.0037, -0.0427,  0.0037],
        [-0.0657, -0.0518,  0.0043, -0.0437,  0.0025],
        [-0.0651, -0.0514,  0.0042, -0.0448,  0.0010],
        [-0.0629, -0.0504,  0.0061, -0.0470, -0.0013],
        [-0.0650, -0.0517,  0.0045, -0.0449,  0.0008],
        [-0.0660, -0.0506,  0.0036, -0.0429,  0.0042],
        [-0.0630, -0.0463,  0.0039, -0.0458,  0.0020],
        [-0.0636, -0.0497,  0.0050, -0.0459,  0.0004],
        [-0.0640, -0.0509,  0.0050, -0.0460, -0.0002],
        [-0.0634, -0.0531,  0.0057, -0.0480, -0.0042],
        [-0.0626, -0.0497,  0.0058, -0.0473, -0.0015],
        [

[tensor([[-0.0561, -0.0602,  0.0091, -0.0611, -0.0300],
        [-0.0558, -0.0609,  0.0098, -0.0617, -0.0310],
        [-0.0560, -0.0636,  0.0110, -0.0622, -0.0328],
        [-0.0570, -0.0639,  0.0106, -0.0608, -0.0309],
        [-0.0585, -0.0585,  0.0077, -0.0566, -0.0229],
        [-0.0563, -0.0623,  0.0103, -0.0612, -0.0308],
        [-0.0577, -0.0611,  0.0089, -0.0587, -0.0269],
        [-0.0600, -0.0582,  0.0058, -0.0548, -0.0205],
        [-0.0567, -0.0612,  0.0100, -0.0600, -0.0287],
        [-0.0576, -0.0613,  0.0088, -0.0592, -0.0277],
        [-0.0568, -0.0612,  0.0091, -0.0605, -0.0296],
        [-0.0563, -0.0600,  0.0096, -0.0603, -0.0287],
        [-0.0568, -0.0613,  0.0097, -0.0602, -0.0290],
        [-0.0567, -0.0619,  0.0095, -0.0607, -0.0302],
        [-0.0596, -0.0591,  0.0068, -0.0554, -0.0216],
        [-0.0566, -0.0611,  0.0102, -0.0600, -0.0285],
        [-0.0570, -0.0626,  0.0097, -0.0604, -0.0301],
        [-0.0577, -0.0607,  0.0083, -0.0591, -0.0275],
        [

[tensor([[-0.0527, -0.0602,  0.0104, -0.0652, -0.0385],
        [-0.0512, -0.0584,  0.0116, -0.0661, -0.0387],
        [-0.0532, -0.0554,  0.0085, -0.0625, -0.0328],
        [-0.0524, -0.0616,  0.0111, -0.0662, -0.0406],
        [-0.0524, -0.0599,  0.0109, -0.0652, -0.0383],
        [-0.0518, -0.0594,  0.0113, -0.0659, -0.0390],
        [-0.0516, -0.0597,  0.0118, -0.0661, -0.0393],
        [-0.0518, -0.0604,  0.0113, -0.0665, -0.0404],
        [-0.0512, -0.0585,  0.0112, -0.0665, -0.0394],
        [-0.0523, -0.0603,  0.0114, -0.0653, -0.0386],
        [-0.0524, -0.0597,  0.0104, -0.0654, -0.0386],
        [-0.0518, -0.0619,  0.0120, -0.0669, -0.0415],
        [-0.0521, -0.0597,  0.0113, -0.0654, -0.0384],
        [-0.0525, -0.0615,  0.0119, -0.0653, -0.0390],
        [-0.0522, -0.0575,  0.0097, -0.0649, -0.0370],
        [-0.0510, -0.0594,  0.0121, -0.0668, -0.0401],
        [-0.0520, -0.0615,  0.0118, -0.0664, -0.0406],
        [-0.0512, -0.0587,  0.0114, -0.0664, -0.0394],
        [

        [ 0.6806]])]
[tensor([[-0.0475, -0.0501,  0.0119, -0.0663, -0.0370],
        [-0.0472, -0.0512,  0.0124, -0.0670, -0.0385],
        [-0.0457, -0.0500,  0.0132, -0.0687, -0.0401],
        [-0.0460, -0.0502,  0.0127, -0.0686, -0.0403],
        [-0.0467, -0.0513,  0.0129, -0.0678, -0.0395],
        [-0.0471, -0.0506,  0.0117, -0.0673, -0.0388],
        [-0.0461, -0.0525,  0.0134, -0.0693, -0.0422],
        [-0.0472, -0.0526,  0.0130, -0.0675, -0.0397],
        [-0.0470, -0.0491,  0.0111, -0.0670, -0.0378],
        [-0.0452, -0.0493,  0.0134, -0.0691, -0.0403],
        [-0.0459, -0.0503,  0.0132, -0.0685, -0.0401],
        [-0.0456, -0.0485,  0.0123, -0.0685, -0.0394],
        [-0.0476, -0.0477,  0.0099, -0.0658, -0.0357],
        [-0.0469, -0.0486,  0.0115, -0.0666, -0.0369],
        [-0.0460, -0.0502,  0.0133, -0.0681, -0.0394],
        [-0.0458, -0.0488,  0.0119, -0.0686, -0.0398],
        [-0.0470, -0.0512,  0.0117, -0.0679, -0.0400],
        [-0.0454, -0.0501,  0.0135, -0.0691

[tensor([[-0.0376, -0.0255,  0.0143, -0.0646, -0.0254],
        [-0.0375, -0.0263,  0.0145, -0.0650, -0.0263],
        [-0.0366, -0.0292,  0.0157, -0.0677, -0.0313],
        [-0.0368, -0.0257,  0.0150, -0.0658, -0.0270],
        [-0.0387, -0.0272,  0.0141, -0.0637, -0.0250],
        [-0.0385, -0.0286,  0.0141, -0.0649, -0.0272],
        [-0.0365, -0.0261,  0.0153, -0.0663, -0.0279],
        [-0.0378, -0.0286,  0.0142, -0.0662, -0.0292],
        [-0.0366, -0.0268,  0.0158, -0.0662, -0.0279],
        [-0.0372, -0.0297,  0.0165, -0.0665, -0.0296],
        [-0.0366, -0.0265,  0.0151, -0.0667, -0.0287],
        [-0.0376, -0.0270,  0.0139, -0.0657, -0.0278],
        [-0.0368, -0.0281,  0.0162, -0.0665, -0.0289],
        [-0.0372, -0.0282,  0.0150, -0.0665, -0.0292],
        [-0.0362, -0.0288,  0.0161, -0.0680, -0.0314],
        [-0.0366, -0.0293,  0.0161, -0.0677, -0.0311],
        [-0.0369, -0.0263,  0.0147, -0.0662, -0.0280],
        [-0.0358, -0.0266,  0.0165, -0.0672, -0.0291],
        [

[tensor([[-3.9382e-02, -1.2569e-02,  1.4115e-02, -5.0502e-02, -3.0036e-03],
        [-3.8858e-02, -8.2762e-03,  1.2707e-02, -4.9683e-02, -2.7722e-04],
        [-3.9600e-02, -1.0304e-02,  1.3476e-02, -4.9029e-02, -4.6045e-05],
        [-4.0715e-02, -1.1867e-02,  1.2804e-02, -4.8356e-02,  6.9082e-05],
        [-3.9746e-02, -9.9642e-03,  1.3680e-02, -4.8376e-02,  1.0905e-03],
        [-4.0720e-02, -7.5096e-03,  1.1334e-02, -4.6505e-02,  4.2899e-03],
        [-4.0266e-02, -1.2843e-02,  1.3511e-02, -4.9380e-02, -1.6600e-03],
        [-4.0261e-02, -1.0069e-02,  1.2533e-02, -4.8247e-02,  9.5192e-04],
        [-3.8886e-02, -1.0056e-02,  1.3736e-02, -5.0078e-02, -1.3788e-03],
        [-3.9578e-02, -1.0649e-02,  1.3892e-02, -4.9007e-02, -6.6109e-05],
        [-4.1534e-02, -1.3931e-02,  1.2644e-02, -4.8231e-02, -7.0085e-04],
        [-3.9342e-02, -1.1246e-02,  1.3215e-02, -5.0340e-02, -2.4058e-03],
        [-3.8388e-02, -9.3472e-03,  1.4090e-02, -5.0335e-02, -1.3504e-03],
        [-3.9425e-02, -9

[tensor([[-0.0385, -0.0117,  0.0167, -0.0456,  0.0025],
        [-0.0397, -0.0126,  0.0172, -0.0435,  0.0052],
        [-0.0404, -0.0111,  0.0165, -0.0416,  0.0084],
        [-0.0399, -0.0116,  0.0162, -0.0431,  0.0061],
        [-0.0394, -0.0104,  0.0171, -0.0426,  0.0075],
        [-0.0396, -0.0130,  0.0172, -0.0440,  0.0044],
        [-0.0385, -0.0114,  0.0176, -0.0448,  0.0041],
        [-0.0380, -0.0092,  0.0165, -0.0450,  0.0044],
        [-0.0392, -0.0109,  0.0169, -0.0434,  0.0060],
        [-0.0376, -0.0100,  0.0173, -0.0458,  0.0032],
        [-0.0396, -0.0107,  0.0161, -0.0433,  0.0062],
        [-0.0390, -0.0122,  0.0169, -0.0449,  0.0033],
        [-0.0393, -0.0134,  0.0172, -0.0449,  0.0030],
        [-0.0389, -0.0124,  0.0177, -0.0446,  0.0040],
        [-0.0385, -0.0127,  0.0173, -0.0459,  0.0019],
        [-0.0402, -0.0085,  0.0148, -0.0414,  0.0094],
        [-0.0396, -0.0122,  0.0172, -0.0434,  0.0056],
        [-0.0406, -0.0131,  0.0165, -0.0425,  0.0064],
        [

[tensor([[-0.0374, -0.0206,  0.0219, -0.0446, -0.0008],
        [-0.0383, -0.0238,  0.0222, -0.0448, -0.0023],
        [-0.0379, -0.0221,  0.0220, -0.0447, -0.0015],
        [-0.0379, -0.0192,  0.0220, -0.0426,  0.0027],
        [-0.0375, -0.0215,  0.0226, -0.0446, -0.0010],
        [-0.0377, -0.0206,  0.0222, -0.0439,  0.0003],
        [-0.0366, -0.0208,  0.0231, -0.0454, -0.0017],
        [-0.0378, -0.0218,  0.0219, -0.0448, -0.0015],
        [-0.0388, -0.0188,  0.0203, -0.0418,  0.0036],
        [-0.0378, -0.0185,  0.0213, -0.0429,  0.0024],
        [-0.0386, -0.0201,  0.0206, -0.0428,  0.0018],
        [-0.0372, -0.0215,  0.0222, -0.0456, -0.0024],
        [-0.0390, -0.0212,  0.0211, -0.0425,  0.0018],
        [-0.0376, -0.0228,  0.0228, -0.0453, -0.0024],
        [-0.0383, -0.0252,  0.0232, -0.0452, -0.0032],
        [-0.0376, -0.0205,  0.0216, -0.0444, -0.0004],
        [-0.0380, -0.0191,  0.0213, -0.0428,  0.0024],
        [-0.0383, -0.0192,  0.0204, -0.0429,  0.0019],
        [

[tensor([[-0.0345, -0.0312,  0.0260, -0.0490, -0.0121],
        [-0.0344, -0.0264,  0.0243, -0.0469, -0.0074],
        [-0.0352, -0.0323,  0.0262, -0.0482, -0.0113],
        [-0.0336, -0.0302,  0.0268, -0.0496, -0.0123],
        [-0.0342, -0.0327,  0.0275, -0.0495, -0.0131],
        [-0.0339, -0.0324,  0.0267, -0.0507, -0.0148],
        [-0.0349, -0.0323,  0.0266, -0.0486, -0.0118],
        [-0.0347, -0.0292,  0.0252, -0.0476, -0.0094],
        [-0.0336, -0.0299,  0.0261, -0.0498, -0.0126],
        [-0.0353, -0.0339,  0.0261, -0.0492, -0.0135],
        [-0.0333, -0.0297,  0.0263, -0.0501, -0.0129],
        [-0.0353, -0.0322,  0.0262, -0.0480, -0.0109],
        [-0.0343, -0.0310,  0.0259, -0.0494, -0.0125],
        [-0.0326, -0.0299,  0.0272, -0.0510, -0.0139],
        [-0.0349, -0.0326,  0.0262, -0.0490, -0.0126],
        [-0.0345, -0.0291,  0.0253, -0.0480, -0.0099],
        [-0.0340, -0.0323,  0.0271, -0.0499, -0.0135],
        [-0.0349, -0.0311,  0.0259, -0.0482, -0.0109],
        [

Epoch Completion: 29.000%, Loss: 90.302[tensor([[-0.0301, -0.0367,  0.0288, -0.0541, -0.0218],
        [-0.0307, -0.0365,  0.0288, -0.0526, -0.0197],
        [-0.0305, -0.0352,  0.0286, -0.0522, -0.0187],
        [-0.0309, -0.0385,  0.0294, -0.0535, -0.0217],
        [-0.0300, -0.0357,  0.0291, -0.0533, -0.0203],
        [-0.0303, -0.0314,  0.0278, -0.0504, -0.0146],
        [-0.0305, -0.0348,  0.0282, -0.0523, -0.0187],
        [-0.0298, -0.0349,  0.0296, -0.0527, -0.0190],
        [-0.0299, -0.0349,  0.0290, -0.0531, -0.0196],
        [-0.0299, -0.0346,  0.0282, -0.0534, -0.0202],
        [-0.0302, -0.0333,  0.0282, -0.0516, -0.0171],
        [-0.0297, -0.0349,  0.0287, -0.0537, -0.0206],
        [-0.0308, -0.0372,  0.0289, -0.0530, -0.0205],
        [-0.0292, -0.0345,  0.0286, -0.0545, -0.0216],
        [-0.0305, -0.0346,  0.0283, -0.0520, -0.0182],
        [-0.0312, -0.0361,  0.0288, -0.0514, -0.0178],
        [-0.0310, -0.0375,  0.0285, -0.0530, -0.0207],
        [-0.0324, -0.032

KeyboardInterrupt: 

In [2]:
# Broadcasting sanity check: a shape-(2,) row minus a shape-(2, 1) column broadcasts to a (2, 2) result.
np.subtract(np.asarray([1, 1]), np.asarray([[2], [2]]))

array([[-1, -1],
       [-1, -1]])

In [26]:
# Visualize the learned representation: sweep the pendulum angle over [0, 2*pi]
# (zero angular velocity), encode each resulting observation, and plot the
# resulting trajectory in representation space.
n_samps = 500
angles = np.linspace(0, 2*np.pi, n_samps)  # includes both endpoints, so the sweep closes on itself
rows = []
with torch.no_grad():  # inference only; no autograd graph needed
    for ang in angles:
        obs = env.reset(state=[ang, 0])
        obs_batch = torch.from_numpy(np.expand_dims(obs, 0)).float()  # add a batch axis of size 1
        # Call the module itself rather than .forward() so any registered hooks run.
        rows.append(net.encoder(obs_batch).squeeze(0).numpy())
X = np.array(rows, dtype=np.float64)  # one row per angle
d = X.shape[1]  # representation dimension, inferred from the encoder instead of hard-coded
utils.visualize_trajectory(X)

In [27]:
# Inspect the encoded angle sweep (one row per sampled angle, one column per representation dimension).
# NOTE(review): in the recorded output the rows are nearly identical across angles, i.e. the encoder
# output barely varies with the pendulum angle.
X

array([[ 0.09636454, -0.12076874, -0.11199814,  0.209089  ,  0.06439343],
       [ 0.09653349, -0.11903687, -0.11283582,  0.20988123,  0.06624708],
       [ 0.09653349, -0.11903687, -0.11283582,  0.20988123,  0.06624708],
       ...,
       [ 0.09636454, -0.12076874, -0.11199814,  0.209089  ,  0.06439343],
       [ 0.09636454, -0.12076874, -0.11199814,  0.209089  ,  0.06439343],
       [ 0.09636454, -0.12076874, -0.11199814,  0.209089  ,  0.06439343]])

In [17]:
# Inspect the trained network's encoder; the bound-method repr also prints the module structure.
net.encoder.forward

<bound method ConvEncoderNet.forward of ConvEncoderNet(
  (conv1): Conv2d(1, 16, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(16, 16, kernel_size=(4, 4), stride=(2, 2))
  (layers): ModuleList(
    (0): Linear(in_features=2112, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=5, bias=True)
  )
)>

In [2]:
#what are weights
# Build the pixel-based model (conv encoder + piecewise forward model) so that its
# weights can be inspected in the following cells.

import sys
from lib.restartable_pendulum import RestartablePendulumEnv
from lib import state_rep_torch as srt
import gym
import numpy as np
from matplotlib import pyplot as plt
import torch
from lib import utils
from lib import encoder_wrappers as ew

# specify environment information
n_repeats = 3 # step the environment this many times for each action, concatenate the pixel observations
env = RestartablePendulumEnv(repeats=n_repeats,pixels=True)

nonlin = torch.nn.functional.relu  # activation passed to the encoder (torch.nn.ELU() was an alternative)
layers = [50, 10, 5] # architecture of encoder after the 2 conv layers
enc_dim = layers[-1]  # dimension of the learned representation
act_dim = 1  # pendulum action dimension
save_dir = "./"
n_episodes = 500 # total batches to draw
batch_size = 25
learning_rate = .001
save_every = n_episodes // 2 # save the model every so often

encnet = srt.ConvEncoderNet(layers,env.observation_space.shape[1:],sigma=nonlin)

# Reward network input size is 2*state dim + action dim. It is derived from `layers`
# so it cannot drift out of sync with the encoder's output size (was hard-coded 2*5+1).
rnet = srt.EncoderNet([2*enc_dim + act_dim, 50, 10, 1])

# NOTE(review): the positional `2` is presumably the number of pieces — confirm against srt.PiecewiseForwardNet
prednet = srt.PiecewiseForwardNet(encnet, enc_dim, act_dim, 2, fit_reward=True, mu=1, r_encoder=rnet, alpha=1)

In [15]:
# Show the conv encoder's architecture (two conv layers followed by the Linear stack from `layers`).
encnet

ConvEncoderNet(
  (conv1): Conv2d(1, 16, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(16, 16, kernel_size=(4, 4), stride=(2, 2))
  (layers): ModuleList(
    (0): Linear(in_features=2112, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=5, bias=True)
  )
)

In [14]:
# Weight matrix of the third (final) Linear layer, 10 -> 5 features; stored with shape (out, in) = (5, 10).
encnet.layers[2].weight

Parameter containing:
tensor([[-0.1487, -0.1759,  0.1770, -0.2456,  0.3060,  0.0667, -0.1194,  0.0529,
         -0.1871, -0.1306],
        [ 0.1675,  0.2610,  0.0371, -0.1310,  0.0397, -0.1577,  0.0479, -0.1478,
         -0.1023,  0.0937],
        [ 0.0280,  0.0736, -0.2789,  0.1069,  0.1072,  0.0639,  0.2972, -0.0238,
          0.0280,  0.2696],
        [ 0.0513, -0.3018,  0.2311,  0.0414,  0.2740, -0.0295, -0.2415, -0.2992,
         -0.2916,  0.0901],
        [ 0.1452,  0.2717,  0.2936, -0.2310, -0.0724, -0.2903,  0.2604,  0.1252,
          0.1103, -0.2748]], requires_grad=True)

In [19]:
#simple identity with actual rep
# Baseline: replace the conv encoder with a single Linear layer on the raw
# (non-pixel) observation and train it with the same piecewise forward model.

import sys
from lib.restartable_pendulum import RestartablePendulumEnv
from lib import state_rep_torch as srt
import gym
import numpy as np
from matplotlib import pyplot as plt
import torch
from lib import utils
from lib import encoder_wrappers as ew
import torch.nn as nn

# specify environment information
n_repeats = 3 # step the environment this many times for each action, concatenate the observations
env = RestartablePendulumEnv(repeats=n_repeats,pixels=False)
# Observation size: 3 values per pendulum state times n_repeats (= 9, the size previously hard-coded).
# NOTE(review): assumes the non-pixel env concatenates the 3-dim state n_repeats times — confirm
obs_dim = 3 * n_repeats
act_dim = 1  # pendulum action dimension

save_dir = "./"
n_episodes = 500 # total batches to draw
batch_size = 25
learning_rate = .001
save_every = n_episodes // 2 # save the model every so often

# Linear "encoder" that keeps the full observation dimension.
encnet = nn.Linear(obs_dim, obs_dim)

# Reward network input size is 2*state dim + action dim.
rnet = srt.EncoderNet([2*obs_dim + act_dim, 50, 10, 1])

# NOTE(review): the positional `2` is presumably the number of pieces — confirm against srt.PiecewiseForwardNet.
# NOTE(review): the last recorded run of this cell failed with
# "mat1 and mat2 shapes cannot be multiplied (25x9 and 5x2)": some layer still expects a 5-dim input
# (the conv-encoder size from the pixel cell). The cause is not visible here — check how
# PiecewiseForwardNet / train_encoder size their internal layers, and rule out stale kernel state.
prednet = srt.PiecewiseForwardNet(encnet, obs_dim, act_dim, 2, fit_reward=True, mu=1, r_encoder=rnet, alpha=1)

deterministic_args = None

traj_sampler = srt.SimpleTrajectorySampler(env,
                                           srt.sample_pendulum_action_batch,
                                           srt.sample_pendulum_state_batch_old,
                                           device=torch.device("cpu"),
                                           deterministic=False,
                                           deterministic_args=deterministic_args,
                                           output_rewards=True)

net, losses = srt.train_encoder(prednet,traj_sampler,n_episodes,
                                batch_size=batch_size,
                                # clamp to >= 1 so fewer than 100 episodes never yields a tracking interval of 0
                                track_loss_every=max(1, n_episodes // 100),
                                lr=learning_rate,
                                save_every=save_every,
                                save_path=save_dir)

# NOTE(review): with save_dir = "./" this writes the hidden file "./.net"; kept unchanged in case
# other code loads from that path, but a descriptive filename is probably intended.
torch.save(net,save_dir+".net")




RuntimeError: mat1 and mat2 shapes cannot be multiplied (25x9 and 5x2)