In [1]:
import tensorflow as tf
import torch
import numpy as np
import pickle as pkl
import gym
import torch.optim as optim
from torch.utils import data
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.autograd import Variable

In [2]:
def render_test(policy_fn):
    env = gym.make('Humanoid-v2')
    max_steps = env.spec.timestep_limit
    obs = env.reset()
    done = False
    steps = 0
    totalr = 0
    while not done:
        action = np.asarray(policy_fn(torch.from_numpy(np.asarray(obs,dtype=np.float32))).data)
        
        obs, r, done, _ = env.step(action)
        totalr += r
        steps += 1
#         env.render()
        if steps >= max_steps:
            break
    print("Total reward = ",totalr)
    #env.close()

In [3]:
humanoid_rollouts = pkl.load(open("expert_data/Humanoid-v2.pkl",'rb'))

In [4]:
obs_mean = humanoid_rollouts['observations'].mean(axis=0)

In [5]:
obs_sqmean = np.square(humanoid_rollouts['observations']).mean(axis=0)

In [6]:
obs_std = np.sqrt(np.maximum(0, obs_sqmean - np.square(obs_mean)))

In [7]:
X = pkl.load(open("expert_data/Humanoid-v2.pkl",'rb'))
Y = X['actions']
Y = np.reshape(Y,(Y.shape[0],-1))
Y = Y.astype(np.float32)
Y = list(Y)
X = X['observations']
X = (X - obs_mean)/(obs_std + 1e-6)
X = X.astype(np.float32)
X = list(X)

In [8]:
# X,Y = shuffle(X,Y)

In [8]:
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=2000, shuffle=True, random_state=42)

In [9]:
len(X_train)

97065

In [None]:
# obs_mean = np.asarray(X_train).mean(axis=0)
# obs_sqmean = np.square(np.asarray(X_train)).mean(axis=0)
# obs_std = np.sqrt(np.maximum(0,obs_sqmean - np.square(obs_mean)))

In [10]:
class Dataset(data.Dataset):
    def __init__(self, observations, actions):
        'Initialization'
        self.observations = observations
        self.actions = actions
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.observations)
    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        # Load data and get label
        obs = self.observations[index]
        act = self.actions[index]
        return obs, act

In [11]:
params = {'batch_size': 100,
          'shuffle': True,
          'num_workers': 4}
training_set = Dataset(X_train, Y_train)
training_generator = data.DataLoader(training_set, **params)

validation_set = Dataset(X_test, Y_test)
validation_generator = data.DataLoader(validation_set, **params)

In [12]:
class policy(torch.nn.Module):
    def __init__(self, D_in, H1, H2, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(policy, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2, D_out)
    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h1_relu = nn.functional.relu(self.linear1(x))
        h2_relu = nn.functional.relu(self.linear2(h1_relu))
        y_pred = self.linear3(h2_relu)
        return y_pred

In [13]:
obs_size, h1_size, h2_size , act_size = X_train[0].shape[0], 100, 50, Y_train[0].shape[0]
model = policy(obs_size, h1_size, h2_size, act_size)

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss

In [15]:
def test(model):
    with torch.no_grad():
        test_loss = 0
        correct = 0
        for sample, target in validation_generator:
            output = model(sample)
            # sum up batch loss
            test_loss += loss_func(output, target).item()
 
        test_loss /= len(validation_generator.dataset)
        print('\nTest set: Average loss: {:.4f}'.format(test_loss))

In [16]:
for epoch in range(50):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, sample in enumerate(training_generator, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = sample
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 10 == 0:    # print every 10 mini-batches
            print('[%d, %d] loss: %f' %(epoch + 1, i + 1, running_loss/2000))
            running_loss = 0.0

    torch.save(model,'epoch_%d_humanoid.model'%(epoch))
    # testing
    print("running on val set")
    test(model)
    
            
print('Finished Training')

[1, 1] loss: 0.000638
[1, 11] loss: 0.005084
[1, 21] loss: 0.003015
[1, 31] loss: 0.002040
[1, 41] loss: 0.001529
[1, 51] loss: 0.001310
[1, 61] loss: 0.001227
[1, 71] loss: 0.001087
[1, 81] loss: 0.000987
[1, 91] loss: 0.000925
[1, 101] loss: 0.000876


  "type " + obj.__name__ + ". It won't be checked "


[1, 111] loss: 0.000775
[1, 121] loss: 0.000751
[1, 131] loss: 0.000693
[1, 141] loss: 0.000676
[1, 151] loss: 0.000644
[1, 161] loss: 0.000617
[1, 171] loss: 0.000554
[1, 181] loss: 0.000502
[1, 191] loss: 0.000483
[1, 201] loss: 0.000490
[1, 211] loss: 0.000496
[1, 221] loss: 0.000437
[1, 231] loss: 0.000531
[1, 241] loss: 0.000418
[1, 251] loss: 0.000392
[1, 261] loss: 0.000399
[1, 271] loss: 0.000402
[1, 281] loss: 0.000376
[1, 291] loss: 0.000343
[1, 301] loss: 0.000363
[1, 311] loss: 0.000347
[1, 321] loss: 0.000330
[1, 331] loss: 0.000322
[1, 341] loss: 0.000309
[1, 351] loss: 0.000310
[1, 361] loss: 0.000295
[1, 371] loss: 0.000302
[1, 381] loss: 0.000297
[1, 391] loss: 0.000283
[1, 401] loss: 0.000287
[1, 411] loss: 0.000304
[1, 421] loss: 0.000279
[1, 431] loss: 0.000290
[1, 441] loss: 0.000275
[1, 451] loss: 0.000262
[1, 461] loss: 0.000273
[1, 471] loss: 0.000237
[1, 481] loss: 0.000264
[1, 491] loss: 0.000244
[1, 501] loss: 0.000255
[1, 511] loss: 0.000250
[1, 521] loss: 0

[4, 551] loss: 0.000096
[4, 561] loss: 0.000102
[4, 571] loss: 0.000098
[4, 581] loss: 0.000099
[4, 591] loss: 0.000107
[4, 601] loss: 0.000100
[4, 611] loss: 0.000101
[4, 621] loss: 0.000104
[4, 631] loss: 0.000095
[4, 641] loss: 0.000104
[4, 651] loss: 0.000096
[4, 661] loss: 0.000103
[4, 671] loss: 0.000108
[4, 681] loss: 0.000098
[4, 691] loss: 0.000101
[4, 701] loss: 0.000104
[4, 711] loss: 0.000095
[4, 721] loss: 0.000148
[4, 731] loss: 0.000103
[4, 741] loss: 0.000103
[4, 751] loss: 0.000145
[4, 761] loss: 0.000094
[4, 771] loss: 0.000101
[4, 781] loss: 0.000102
[4, 791] loss: 0.000102
[4, 801] loss: 0.000100
[4, 811] loss: 0.000100
[4, 821] loss: 0.000099
[4, 831] loss: 0.000103
[4, 841] loss: 0.000101
[4, 851] loss: 0.000111
[4, 861] loss: 0.000106
[4, 871] loss: 0.000109
[4, 881] loss: 0.000101
[4, 891] loss: 0.000110
[4, 901] loss: 0.000108
[4, 911] loss: 0.000096
[4, 921] loss: 0.000096
[4, 931] loss: 0.000106
[4, 941] loss: 0.000096
[4, 951] loss: 0.000126
[4, 961] loss: 0


Test set: Average loss: 0.0002
[8, 1] loss: 0.000009
[8, 11] loss: 0.000097
[8, 21] loss: 0.000090
[8, 31] loss: 0.000082
[8, 41] loss: 0.000084
[8, 51] loss: 0.000080
[8, 61] loss: 0.000084
[8, 71] loss: 0.000083
[8, 81] loss: 0.000090
[8, 91] loss: 0.000082
[8, 101] loss: 0.000088
[8, 111] loss: 0.000086
[8, 121] loss: 0.000089
[8, 131] loss: 0.000084
[8, 141] loss: 0.000076
[8, 151] loss: 0.000085
[8, 161] loss: 0.000083
[8, 171] loss: 0.000084
[8, 181] loss: 0.000074
[8, 191] loss: 0.000096
[8, 201] loss: 0.000083
[8, 211] loss: 0.000088
[8, 221] loss: 0.000084
[8, 231] loss: 0.000083
[8, 241] loss: 0.000088
[8, 251] loss: 0.000080
[8, 261] loss: 0.000079
[8, 271] loss: 0.000091
[8, 281] loss: 0.000085
[8, 291] loss: 0.000081
[8, 301] loss: 0.000081
[8, 311] loss: 0.000081
[8, 321] loss: 0.000085
[8, 331] loss: 0.000084
[8, 341] loss: 0.000080
[8, 351] loss: 0.000080
[8, 361] loss: 0.000077
[8, 371] loss: 0.000082
[8, 381] loss: 0.000085
[8, 391] loss: 0.000085
[8, 401] loss: 0.00

[11, 361] loss: 0.000077
[11, 371] loss: 0.000077
[11, 381] loss: 0.000084
[11, 391] loss: 0.000077
[11, 401] loss: 0.000078
[11, 411] loss: 0.000078
[11, 421] loss: 0.000074
[11, 431] loss: 0.000086
[11, 441] loss: 0.000072
[11, 451] loss: 0.000074
[11, 461] loss: 0.000079
[11, 471] loss: 0.000077
[11, 481] loss: 0.000080
[11, 491] loss: 0.000073
[11, 501] loss: 0.000078
[11, 511] loss: 0.000087
[11, 521] loss: 0.000079
[11, 531] loss: 0.000087
[11, 541] loss: 0.000076
[11, 551] loss: 0.000072
[11, 561] loss: 0.000076
[11, 571] loss: 0.000081
[11, 581] loss: 0.000072
[11, 591] loss: 0.000074
[11, 601] loss: 0.000082
[11, 611] loss: 0.000083
[11, 621] loss: 0.000085
[11, 631] loss: 0.000080
[11, 641] loss: 0.000078
[11, 651] loss: 0.000077
[11, 661] loss: 0.000075
[11, 671] loss: 0.000076
[11, 681] loss: 0.000081
[11, 691] loss: 0.000078
[11, 701] loss: 0.000081
[11, 711] loss: 0.000075
[11, 721] loss: 0.000079
[11, 731] loss: 0.000079
[11, 741] loss: 0.000077
[11, 751] loss: 0.000075


[14, 651] loss: 0.000074
[14, 661] loss: 0.000075
[14, 671] loss: 0.000072
[14, 681] loss: 0.000076
[14, 691] loss: 0.000075
[14, 701] loss: 0.000078
[14, 711] loss: 0.000072
[14, 721] loss: 0.000076
[14, 731] loss: 0.000074
[14, 741] loss: 0.000079
[14, 751] loss: 0.000073
[14, 761] loss: 0.000076
[14, 771] loss: 0.000099
[14, 781] loss: 0.000081
[14, 791] loss: 0.000090
[14, 801] loss: 0.000072
[14, 811] loss: 0.000083
[14, 821] loss: 0.000071
[14, 831] loss: 0.000073
[14, 841] loss: 0.000072
[14, 851] loss: 0.000070
[14, 861] loss: 0.000077
[14, 871] loss: 0.000076
[14, 881] loss: 0.000072
[14, 891] loss: 0.000076
[14, 901] loss: 0.000073
[14, 911] loss: 0.000069
[14, 921] loss: 0.000078
[14, 931] loss: 0.000075
[14, 941] loss: 0.000073
[14, 951] loss: 0.000075
[14, 961] loss: 0.000074
[14, 971] loss: 0.000070
running on val set

Test set: Average loss: 0.0002
[15, 1] loss: 0.000007
[15, 11] loss: 0.000072
[15, 21] loss: 0.000074
[15, 31] loss: 0.000072
[15, 41] loss: 0.000073
[15, 

[17, 941] loss: 0.000069
[17, 951] loss: 0.000074
[17, 961] loss: 0.000066
[17, 971] loss: 0.000066
running on val set

Test set: Average loss: 0.0002
[18, 1] loss: 0.000009
[18, 11] loss: 0.000068
[18, 21] loss: 0.000064
[18, 31] loss: 0.000070
[18, 41] loss: 0.000068
[18, 51] loss: 0.000066
[18, 61] loss: 0.000069
[18, 71] loss: 0.000073
[18, 81] loss: 0.000068
[18, 91] loss: 0.000073
[18, 101] loss: 0.000072
[18, 111] loss: 0.000069
[18, 121] loss: 0.000066
[18, 131] loss: 0.000071
[18, 141] loss: 0.000072
[18, 151] loss: 0.000069
[18, 161] loss: 0.000068
[18, 171] loss: 0.000068
[18, 181] loss: 0.000067
[18, 191] loss: 0.000069
[18, 201] loss: 0.000071
[18, 211] loss: 0.000068
[18, 221] loss: 0.000068
[18, 231] loss: 0.000068
[18, 241] loss: 0.000066
[18, 251] loss: 0.000067
[18, 261] loss: 0.000067
[18, 271] loss: 0.000071
[18, 281] loss: 0.000063
[18, 291] loss: 0.000068
[18, 301] loss: 0.000070
[18, 311] loss: 0.000068
[18, 321] loss: 0.000066
[18, 331] loss: 0.000061
[18, 341] 

[21, 241] loss: 0.000071
[21, 251] loss: 0.000063
[21, 261] loss: 0.000064
[21, 271] loss: 0.000062
[21, 281] loss: 0.000065
[21, 291] loss: 0.000059
[21, 301] loss: 0.000065
[21, 311] loss: 0.000064
[21, 321] loss: 0.000069
[21, 331] loss: 0.000062
[21, 341] loss: 0.000063
[21, 351] loss: 0.000065
[21, 361] loss: 0.000065
[21, 371] loss: 0.000064
[21, 381] loss: 0.000062
[21, 391] loss: 0.000071
[21, 401] loss: 0.000074
[21, 411] loss: 0.000064
[21, 421] loss: 0.000066
[21, 431] loss: 0.000069
[21, 441] loss: 0.000061
[21, 451] loss: 0.000069
[21, 461] loss: 0.000078
[21, 471] loss: 0.000070
[21, 481] loss: 0.000064
[21, 491] loss: 0.000068
[21, 501] loss: 0.000075
[21, 511] loss: 0.000073
[21, 521] loss: 0.000070
[21, 531] loss: 0.000072
[21, 541] loss: 0.000064
[21, 551] loss: 0.000067
[21, 561] loss: 0.000068
[21, 571] loss: 0.000066
[21, 581] loss: 0.000064
[21, 591] loss: 0.000078
[21, 601] loss: 0.000070
[21, 611] loss: 0.000063
[21, 621] loss: 0.000068
[21, 631] loss: 0.000071


[24, 541] loss: 0.000071
[24, 551] loss: 0.000071
[24, 561] loss: 0.000074
[24, 571] loss: 0.000066
[24, 581] loss: 0.000064
[24, 591] loss: 0.000064
[24, 601] loss: 0.000063
[24, 611] loss: 0.000062
[24, 621] loss: 0.000068
[24, 631] loss: 0.000066
[24, 641] loss: 0.000066
[24, 651] loss: 0.000065
[24, 661] loss: 0.000066
[24, 671] loss: 0.000064
[24, 681] loss: 0.000062
[24, 691] loss: 0.000061
[24, 701] loss: 0.000065
[24, 711] loss: 0.000061
[24, 721] loss: 0.000060
[24, 731] loss: 0.000068
[24, 741] loss: 0.000059
[24, 751] loss: 0.000070
[24, 761] loss: 0.000068
[24, 771] loss: 0.000065
[24, 781] loss: 0.000064
[24, 791] loss: 0.000071
[24, 801] loss: 0.000064
[24, 811] loss: 0.000062
[24, 821] loss: 0.000063
[24, 831] loss: 0.000062
[24, 841] loss: 0.000059
[24, 851] loss: 0.000064
[24, 861] loss: 0.000064
[24, 871] loss: 0.000069
[24, 881] loss: 0.000070
[24, 891] loss: 0.000065
[24, 901] loss: 0.000064
[24, 911] loss: 0.000069
[24, 921] loss: 0.000067
[24, 931] loss: 0.000067


[27, 831] loss: 0.000064
[27, 841] loss: 0.000063
[27, 851] loss: 0.000066
[27, 861] loss: 0.000070
[27, 871] loss: 0.000062
[27, 881] loss: 0.000063
[27, 891] loss: 0.000061
[27, 901] loss: 0.000062
[27, 911] loss: 0.000058
[27, 921] loss: 0.000066
[27, 931] loss: 0.000061
[27, 941] loss: 0.000065
[27, 951] loss: 0.000063
[27, 961] loss: 0.000061
[27, 971] loss: 0.000064
running on val set

Test set: Average loss: 0.0001
[28, 1] loss: 0.000005
[28, 11] loss: 0.000067
[28, 21] loss: 0.000061
[28, 31] loss: 0.000059
[28, 41] loss: 0.000057
[28, 51] loss: 0.000062
[28, 61] loss: 0.000057
[28, 71] loss: 0.000062
[28, 81] loss: 0.000059
[28, 91] loss: 0.000062
[28, 101] loss: 0.000066
[28, 111] loss: 0.000057
[28, 121] loss: 0.000064
[28, 131] loss: 0.000065
[28, 141] loss: 0.000063
[28, 151] loss: 0.000061
[28, 161] loss: 0.000069
[28, 171] loss: 0.000061
[28, 181] loss: 0.000065
[28, 191] loss: 0.000065
[28, 201] loss: 0.000062
[28, 211] loss: 0.000064
[28, 221] loss: 0.000063
[28, 231] 

[31, 131] loss: 0.000059
[31, 141] loss: 0.000064
[31, 151] loss: 0.000061
[31, 161] loss: 0.000061
[31, 171] loss: 0.000058
[31, 181] loss: 0.000057
[31, 191] loss: 0.000061
[31, 201] loss: 0.000057
[31, 211] loss: 0.000060
[31, 221] loss: 0.000057
[31, 231] loss: 0.000061
[31, 241] loss: 0.000060
[31, 251] loss: 0.000061
[31, 261] loss: 0.000061
[31, 271] loss: 0.000059
[31, 281] loss: 0.000061
[31, 291] loss: 0.000063
[31, 301] loss: 0.000061
[31, 311] loss: 0.000061
[31, 321] loss: 0.000058
[31, 331] loss: 0.000061
[31, 341] loss: 0.000055
[31, 351] loss: 0.000061
[31, 361] loss: 0.000059
[31, 371] loss: 0.000063
[31, 381] loss: 0.000059
[31, 391] loss: 0.000061
[31, 401] loss: 0.000062
[31, 411] loss: 0.000058
[31, 421] loss: 0.000060
[31, 431] loss: 0.000062
[31, 441] loss: 0.000059
[31, 451] loss: 0.000064
[31, 461] loss: 0.000059
[31, 471] loss: 0.000060
[31, 481] loss: 0.000061
[31, 491] loss: 0.000063
[31, 501] loss: 0.000060
[31, 511] loss: 0.000062
[31, 521] loss: 0.000066


[34, 421] loss: 0.000063
[34, 431] loss: 0.000063
[34, 441] loss: 0.000056
[34, 451] loss: 0.000057
[34, 461] loss: 0.000059
[34, 471] loss: 0.000059
[34, 481] loss: 0.000064
[34, 491] loss: 0.000061
[34, 501] loss: 0.000057
[34, 511] loss: 0.000058
[34, 521] loss: 0.000055
[34, 531] loss: 0.000062
[34, 541] loss: 0.000061
[34, 551] loss: 0.000065
[34, 561] loss: 0.000061
[34, 571] loss: 0.000062
[34, 581] loss: 0.000063
[34, 591] loss: 0.000064
[34, 601] loss: 0.000062
[34, 611] loss: 0.000058
[34, 621] loss: 0.000063
[34, 631] loss: 0.000060
[34, 641] loss: 0.000069
[34, 651] loss: 0.000063
[34, 661] loss: 0.000066
[34, 671] loss: 0.000064
[34, 681] loss: 0.000062
[34, 691] loss: 0.000061
[34, 701] loss: 0.000061
[34, 711] loss: 0.000057
[34, 721] loss: 0.000057
[34, 731] loss: 0.000057
[34, 741] loss: 0.000056
[34, 751] loss: 0.000061
[34, 761] loss: 0.000058
[34, 771] loss: 0.000059
[34, 781] loss: 0.000056
[34, 791] loss: 0.000062
[34, 801] loss: 0.000059
[34, 811] loss: 0.000056


[37, 741] loss: 0.000062
[37, 751] loss: 0.000064
[37, 761] loss: 0.000066
[37, 771] loss: 0.000061
[37, 781] loss: 0.000055
[37, 791] loss: 0.000060
[37, 801] loss: 0.000056
[37, 811] loss: 0.000060
[37, 821] loss: 0.000057
[37, 831] loss: 0.000060
[37, 841] loss: 0.000059
[37, 851] loss: 0.000058
[37, 861] loss: 0.000059
[37, 871] loss: 0.000063
[37, 881] loss: 0.000062
[37, 891] loss: 0.000059
[37, 901] loss: 0.000059
[37, 911] loss: 0.000061
[37, 921] loss: 0.000056
[37, 931] loss: 0.000059
[37, 941] loss: 0.000058
[37, 951] loss: 0.000067
[37, 961] loss: 0.000058
[37, 971] loss: 0.000053
running on val set

Test set: Average loss: 0.0001
[38, 1] loss: 0.000006
[38, 11] loss: 0.000056
[38, 21] loss: 0.000063
[38, 31] loss: 0.000061
[38, 41] loss: 0.000061
[38, 51] loss: 0.000055
[38, 61] loss: 0.000062
[38, 71] loss: 0.000054
[38, 81] loss: 0.000071
[38, 91] loss: 0.000054
[38, 101] loss: 0.000059
[38, 111] loss: 0.000056
[38, 121] loss: 0.000058
[38, 131] loss: 0.000059
[38, 141] 

[41, 51] loss: 0.000057
[41, 61] loss: 0.000061
[41, 71] loss: 0.000057
[41, 81] loss: 0.000054
[41, 91] loss: 0.000052
[41, 101] loss: 0.000053
[41, 111] loss: 0.000054
[41, 121] loss: 0.000057
[41, 131] loss: 0.000062
[41, 141] loss: 0.000064
[41, 151] loss: 0.000053
[41, 161] loss: 0.000060
[41, 171] loss: 0.000055
[41, 181] loss: 0.000054
[41, 191] loss: 0.000052
[41, 201] loss: 0.000057
[41, 211] loss: 0.000056
[41, 221] loss: 0.000057
[41, 231] loss: 0.000061
[41, 241] loss: 0.000062
[41, 251] loss: 0.000056
[41, 261] loss: 0.000057
[41, 271] loss: 0.000060
[41, 281] loss: 0.000059
[41, 291] loss: 0.000058
[41, 301] loss: 0.000054
[41, 311] loss: 0.000054
[41, 321] loss: 0.000058
[41, 331] loss: 0.000057
[41, 341] loss: 0.000057
[41, 351] loss: 0.000055
[41, 361] loss: 0.000058
[41, 371] loss: 0.000056
[41, 381] loss: 0.000060
[41, 391] loss: 0.000060
[41, 401] loss: 0.000061
[41, 411] loss: 0.000058
[41, 421] loss: 0.000056
[41, 431] loss: 0.000058
[41, 441] loss: 0.000068
[41, 

[44, 351] loss: 0.000059
[44, 361] loss: 0.000060
[44, 371] loss: 0.000057
[44, 381] loss: 0.000055
[44, 391] loss: 0.000059
[44, 401] loss: 0.000056
[44, 411] loss: 0.000063
[44, 421] loss: 0.000053
[44, 431] loss: 0.000054
[44, 441] loss: 0.000057
[44, 451] loss: 0.000056
[44, 461] loss: 0.000055
[44, 471] loss: 0.000057
[44, 481] loss: 0.000059
[44, 491] loss: 0.000060
[44, 501] loss: 0.000057
[44, 511] loss: 0.000055
[44, 521] loss: 0.000060
[44, 531] loss: 0.000053
[44, 541] loss: 0.000068
[44, 551] loss: 0.000059
[44, 561] loss: 0.000058
[44, 571] loss: 0.000058
[44, 581] loss: 0.000058
[44, 591] loss: 0.000066
[44, 601] loss: 0.000058
[44, 611] loss: 0.000091
[44, 621] loss: 0.000061
[44, 631] loss: 0.000059
[44, 641] loss: 0.000058
[44, 651] loss: 0.000060
[44, 661] loss: 0.000059
[44, 671] loss: 0.000061
[44, 681] loss: 0.000057
[44, 691] loss: 0.000058
[44, 701] loss: 0.000062
[44, 711] loss: 0.000056
[44, 721] loss: 0.000052
[44, 731] loss: 0.000057
[44, 741] loss: 0.000057


[47, 661] loss: 0.000055
[47, 671] loss: 0.000056
[47, 681] loss: 0.000056
[47, 691] loss: 0.000054
[47, 701] loss: 0.000059
[47, 711] loss: 0.000053
[47, 721] loss: 0.000051
[47, 731] loss: 0.000057
[47, 741] loss: 0.000056
[47, 751] loss: 0.000052
[47, 761] loss: 0.000053
[47, 771] loss: 0.000060
[47, 781] loss: 0.000056
[47, 791] loss: 0.000053
[47, 801] loss: 0.000059
[47, 811] loss: 0.000057
[47, 821] loss: 0.000053
[47, 831] loss: 0.000054
[47, 841] loss: 0.000056
[47, 851] loss: 0.000054
[47, 861] loss: 0.000052
[47, 871] loss: 0.000056
[47, 881] loss: 0.000054
[47, 891] loss: 0.000056
[47, 901] loss: 0.000057
[47, 911] loss: 0.000056
[47, 921] loss: 0.000056
[47, 931] loss: 0.000058
[47, 941] loss: 0.000057
[47, 951] loss: 0.000058
[47, 961] loss: 0.000055
[47, 971] loss: 0.000054
running on val set

Test set: Average loss: 0.0001
[48, 1] loss: 0.000004
[48, 11] loss: 0.000056
[48, 21] loss: 0.000055
[48, 31] loss: 0.000054
[48, 41] loss: 0.000054
[48, 51] loss: 0.000057
[48, 6

[50, 951] loss: 0.000056
[50, 961] loss: 0.000055
[50, 971] loss: 0.000057
running on val set

Test set: Average loss: 0.0001
Finished Training


In [17]:
torch.save(model.state_dict(), 'humanoid_policy.model')

In [19]:
model.state_dict()

OrderedDict([('linear1.weight',
              tensor([[ 0.0554, -0.0433, -0.0404,  ..., -0.0181,  0.0126,  0.0498],
                      [-0.0085,  0.0212, -0.0469,  ..., -0.0155,  0.0342, -0.0259],
                      [-0.0270, -0.0223, -0.0371,  ...,  0.0343,  0.0255,  0.0073],
                      ...,
                      [-0.0456, -0.0264, -0.0058,  ...,  0.0351, -0.0218,  0.0394],
                      [ 0.0310,  0.0266, -0.0273,  ..., -0.0125,  0.0337, -0.0293],
                      [-0.0137,  0.0523, -0.0085,  ..., -0.0119, -0.0054,  0.0054]])),
             ('linear1.bias',
              tensor([-0.0316, -0.0318,  0.0182, -0.0282, -0.0055, -0.0232,  0.0747, -0.0293,
                       0.0496, -0.0448,  0.0114,  0.0463,  0.0391, -0.0436, -0.0146,  0.0086,
                      -0.0145, -0.0082,  0.0252, -0.0085, -0.0402,  0.0634, -0.0283, -0.0004,
                      -0.0153, -0.0132, -0.0265,  0.0077, -0.0320, -0.0074,  0.0462,  0.0095,
                       0.003

In [18]:
obs_stats = dict()
obs_stats['mean'] = obs_mean
obs_stats['std'] = obs_std
with open('obs_stats.pkl','wb') as fp:
    pkl.dump(obs_stats, fp)