In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from tensorboardX import SummaryWriter

import gym
# import roboschool
import sys

#%%

import metaworld
import wandb
from tqdm import tqdm



class MLP(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(in_dim, hid_dim)
        self.l2 = nn.Linear(hid_dim, hid_dim)
        self.l3 = nn.Linear(hid_dim, out_dim)
    
    def forward(self, x):
#         print(x)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x
    



In [2]:
class ReplayBuffer(object):
    """Buffer to store tuples of experience replay"""

    def __init__(self, max_size=1000000):
        """
        Args:
            max_size (int): total amount of tuples to store
        """

        self.storage = []
        self.max_size = max_size
        self.ptr = 0

    def add(self, data):
        """Add experience tuples to buffer

        Args:
            data (tuple): experience replay tuple
        """

        if len(self.storage) == self.max_size:
            self.storage[int(self.ptr)] = data
            self.ptr = (self.ptr + 1) % self.max_size
        else:
            self.storage.append(data)

    def sample(self, batch_size):
        """Samples a random amount of experiences from buffer of batch size

        Args:
            batch_size (int): size of sample
        """

        ind = np.random.randint(0, len(self.storage), size=batch_size)
        states, actions, next_states, rewards, dones = [], [], [], [], []

        for i in ind:
            s, a, s_, r, d = self.storage[i]
            states.append(np.array(s, copy=False))
            actions.append(np.array(a, copy=False))
            next_states.append(np.array(s_, copy=False))
            rewards.append(np.array(r, copy=False))
            dones.append(np.array(d, copy=False))

        return np.array(states), np.array(actions), np.array(next_states), np.array(rewards).reshape(-1, 1), np.array(
            dones).reshape(-1, 1)

    def save(self):
        np_buffer = np.asarray(self.storage)
        with open('replaybuffer' + str(env._last_rand_vec[0]) + '_' + str(env._last_rand_vec[1]) + '.npy', 'wb') as f:
            print("Saving replay buffer in ", f)
            np.save(f, np_buffer)
            
    def load(self, filename):
        with open(filename, 'rb') as f:
            self.storage = np.load(f, allow_pickle=True)

In [3]:
class Actor(nn.Module):
    """Initialize parameters and build model.
        Args:
            state_size (int): Dimension of each state
            action_size (int): Dimension of each action
            max_action (float): highest action to take
            seed (int): Random seed
            h1_units (int): Number of nodes in first hidden layer
            h2_units (int): Number of nodes in second hidden layer

        Return:
            action output of network with tanh activation
    """

    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 100)
        self.l2 = nn.Linear(100, 100)
        self.l3 = nn.Linear(100, action_dim)

        self.max_action = max_action

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.max_action * torch.tanh(self.l3(x))
        return x
    
    def load(self, filename="best_avg"):
        state_dict = torch.load(filename)
        for el in state_dict.keys():
            e = state_dict[el]
            print(e.shape)
        self.load_state_dict(state_dict)

    

In [8]:
state_dim = 39
target_actor1,target_actor2,target_actor3,target_actor4 = Actor(state_dim, 4, 1.0),Actor(state_dim, 4, 1.0),Actor(state_dim, 4, 1.0),Actor(state_dim, 4, 1.0)
target_actor1.load('models/best_avg-0.09_0.86_actor.pth')
target_actor2.load('models/best_avg-0.09_0.89_actor.pth')
target_actor3.load('models/best_avg0.07_0.86_actor.pth')
target_actor4.load('models/best_avg0.07_0.89_actor.pth')
target_actors = [target_actor1,target_actor2,target_actor3,target_actor4 ]
for a in target_actors:
    a.eval()

torch.Size([100, 39])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])
torch.Size([4, 100])
torch.Size([4])
torch.Size([100, 39])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])
torch.Size([4, 100])
torch.Size([4])
torch.Size([100, 39])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])
torch.Size([4, 100])
torch.Size([4])
torch.Size([100, 39])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])
torch.Size([4, 100])
torch.Size([4])


In [9]:
r1, r2, r3, r4 = ReplayBuffer(), ReplayBuffer(), ReplayBuffer(), ReplayBuffer()

r1.load('replaybuffers/replaybuffer-0.09_0.86.npy')
r2.load('replaybuffers/replaybuffer-0.09_0.89.npy')
r3.load('replaybuffers/replaybuffer0.07_0.86.npy')
r4.load('replaybuffers/replaybuffer0.07_0.89.npy')

replay_buffers = [r1,r2,r3,r4]


In [10]:
goal_MLP = MLP(2, 100, 4)
state_MLP = MLP(39, 100, 4)
goal_optimizer = torch.optim.Adam(goal_MLP.parameters(), lr=1e-3)
state_optimizer = torch.optim.Adam(state_MLP.parameters(), lr=1e-3)

In [11]:
import wandb
wandb.init(project="RL-transfer-learning", entity="frl", settings=wandb.Settings(start_method="fork"))
wandb.watch(goal_MLP, log_freq=10000)
wandb.watch(state_MLP, log_freq=10000)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfgossi[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[]

In [None]:
iterations = 1000000
batch_size = 100
goals = [[-0.09, 0.86], [-0.09, 0.89], [0.07, 0.86], [0.07, 0.89]]
for it in range(iterations):
    for goal_idx in range(4):
#         goal_idx = 3
        goal = torch.Tensor(goals[goal_idx])
        s, a, next_s, r, d = replay_buffers[goal_idx].sample(batch_size)
#         print(s)
        loss = 0
        for state in s:
            ss = torch.Tensor(state)
#             print(ss, ss.shape)
            target_action = target_actors[goal_idx](ss)
    #         print(target_action)
            pred_action = torch.mul(goal_MLP(goal), state_MLP(ss))
    
            loss += F.mse_loss(pred_action, target_action)
        print("loss:",loss.item())
        wandb.log({"loss": loss.item(), "step": it})
        
        goal_MLP.zero_grad()
        state_MLP.zero_grad()
        
        loss.backward()
        
        goal_optimizer.step()
        state_optimizer.step()
        

loss: 91.23941802978516
loss: 77.10416412353516
loss: 86.05088806152344
loss: 81.11884307861328
loss: 88.03341674804688
loss: 69.64188385009766
loss: 80.38008117675781
loss: 79.09566497802734
loss: 84.76258850097656
loss: 64.92466735839844
loss: 70.6837158203125
loss: 70.5404052734375
loss: 81.7104263305664
loss: 51.95756149291992
loss: 64.40206146240234
loss: 62.23408126831055
loss: 73.98877716064453
loss: 41.77582931518555
loss: 52.174903869628906
loss: 66.32875061035156
loss: 86.42039489746094
loss: 39.15113067626953
loss: 57.56932067871094
loss: 66.44692993164062
loss: 81.00995635986328
loss: 34.931949615478516
loss: 55.200660705566406
loss: 72.35468292236328
loss: 80.92044830322266
loss: 34.985679626464844
loss: 54.57392501831055
loss: 60.51991271972656
loss: 74.8302230834961
loss: 40.13710021972656
loss: 51.880104064941406
loss: 65.5450439453125
loss: 72.92865753173828
loss: 36.32811737060547
loss: 51.54486083984375
loss: 57.11216354370117
loss: 79.06961822509766
loss: 32.7033576

loss: 25.975696563720703
loss: 44.558563232421875
loss: 59.83476638793945
loss: 51.28065872192383
loss: 27.452959060668945
loss: 46.50086975097656
loss: 66.97631072998047
loss: 54.78048324584961
loss: 27.355195999145508
loss: 47.852474212646484
loss: 60.76836013793945
loss: 47.53154373168945
loss: 29.991390228271484
loss: 45.766170501708984
loss: 67.91533660888672
loss: 40.74586486816406
loss: 26.925477981567383
loss: 49.71335983276367
loss: 59.70710372924805
loss: 49.15348434448242
loss: 26.039594650268555
loss: 50.29188537597656
loss: 56.05305862426758
loss: 53.989742279052734
loss: 31.971651077270508
loss: 46.550174713134766
loss: 63.11445617675781
loss: 53.197330474853516
loss: 25.98990821838379
loss: 44.39219284057617
loss: 57.434593200683594
loss: 46.016597747802734
loss: 27.191776275634766
loss: 45.054039001464844
loss: 60.39162826538086
loss: 55.426841735839844
loss: 24.82500648498535
loss: 42.421043395996094
loss: 54.96552658081055
loss: 48.24217224121094
loss: 31.318059921264

loss: 41.75934982299805
loss: 43.973384857177734
loss: 24.65531349182129
loss: 35.26808166503906
loss: 50.02027130126953
loss: 43.56224822998047
loss: 31.22261619567871
loss: 34.82826232910156
loss: 47.66535568237305
loss: 35.68125534057617
loss: 21.564210891723633
loss: 33.73727035522461
loss: 39.25083923339844
loss: 33.14562225341797
loss: 20.801406860351562
loss: 33.1852912902832
loss: 43.35874938964844
loss: 42.76700973510742
loss: 22.89787483215332
loss: 30.387516021728516
loss: 56.898658752441406
loss: 39.246910095214844
loss: 18.21648597717285
loss: 30.53596305847168
loss: 42.892127990722656
loss: 40.55101013183594
loss: 18.3188419342041
loss: 33.61196517944336
loss: 53.78633499145508
loss: 38.814693450927734
loss: 23.877548217773438
loss: 32.77232360839844
loss: 51.19873046875
loss: 34.91233825683594
loss: 20.26171112060547
loss: 32.92099380493164
loss: 41.304752349853516
loss: 40.18426513671875
loss: 18.778078079223633
loss: 32.636474609375
loss: 54.72296142578125
loss: 37.964

loss: 18.265182495117188
loss: 37.327659606933594
loss: 42.7440299987793
loss: 35.03252029418945
loss: 15.861538887023926
loss: 29.39775276184082
loss: 39.8787841796875
loss: 34.89787673950195
loss: 21.64488410949707
loss: 27.571138381958008
loss: 44.19996643066406
loss: 32.20074462890625
loss: 17.017173767089844
loss: 27.908605575561523
loss: 39.701602935791016
loss: 39.860408782958984
loss: 17.251171112060547
loss: 28.165021896362305
loss: 38.12971115112305
loss: 29.663135528564453
loss: 19.512147903442383
loss: 33.08592987060547
loss: 38.21183395385742
loss: 33.36843490600586
loss: 26.843341827392578
loss: 30.860517501831055
loss: 35.80459213256836
loss: 39.02946853637695
loss: 13.852197647094727
loss: 27.070894241333008
loss: 37.496910095214844
loss: 34.72197341918945
loss: 16.126737594604492
loss: 29.805252075195312
loss: 38.215641021728516
loss: 32.134220123291016
loss: 21.427255630493164
loss: 27.60041046142578
loss: 37.61056900024414
loss: 31.958492279052734
loss: 21.2613048553