In [1]:
# import sys
# sys.path.append('./algorithms/')
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from typing import Any, Dict, List, Tuple

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class ActorCritic(nn.Module):
    """Hold a policy and an encoder together in a single ``nn.Module``.

    Every method defined here forwards its arguments, positionally, to the
    method of the same purpose on ``self.policy``. The encoder is only stored;
    callers reach it through ``self.encoder``.
    """

    def __init__(self, policy, encoder):
        super().__init__()
        # Keep this assignment order: policy first, encoder second.
        self.policy = policy
        self.encoder = encoder

    def get_actor_params(self):
        """Return whatever ``policy.get_actor_params()`` returns."""
        actor_params = self.policy.get_actor_params()
        return actor_params

    def get_critic_params(self):
        """Return whatever ``policy.get_critic_params()`` returns."""
        critic_params = self.policy.get_critic_params()
        return critic_params

    def forward_actor(self, inputs):
        """Run ``inputs`` through ``policy.forward_actor``."""
        actor_out = self.policy.forward_actor(inputs)
        return actor_out

    def forward_critic(self, inputs):
        """Run ``inputs`` through ``policy.forward_critic``."""
        critic_out = self.policy.forward_critic(inputs)
        return critic_out

    def act(self, state, latent, belief=None, task=None, deterministic=False):
        """Forward all five arguments to ``policy.act`` and return its result."""
        policy_args = (state, latent, belief, task, deterministic)
        return self.policy.act(*policy_args)

    def get_value(self, state, latent, belief=None, task=None):
        """Return the first of the two items produced by ``policy.forward``."""
        forward_out = self.policy.forward(state, latent, belief, task)
        # policy.forward is expected to yield exactly two items; the second is discarded
        value, _discarded = forward_out
        return value

    def evaluate_actions(self, state, latent, belief, task, action):
        """Forward all five arguments, unchanged, to ``policy.evaluate_actions``.

        ``belief`` and ``task`` are passed through as given; callers that do not
        use them pass ``None``.
        """
        policy_args = (state, latent, belief, task, action)
        return self.policy.evaluate_actions(*policy_args)
    
        ## TODO: what to do about 'sample'? check what this arg is?
    # def forward(self, actions, states, rewards, hidden_state, return_prior=False, sample=True, detach_every=None):
    #     # really want this to take the inputs for the encoder and then output the outputs of the policy
    #     # we only want to get the prior when there are no previous rewards, actions or hidden states
    #     # should only occur at the very start of the continual learning process
    #     if hidden_state is None:
    #         # print('Hidden state is None!!:', hidden_state)
    #         _, latent_mean, latent_logvar, hidden_state = self.encoder.prior(states.shape[1]) # check that this gets the batch size?
    #     else:
    #         _, latent_mean, latent_logvar, hidden_state = self.encoder(actions, states, rewards, hidden_state, return_prior, sample, detach_every)
        
    #     latent_mean = F.relu(latent_mean)
    #     latent_logvar = F.relu(latent_logvar)
    #     latent = torch.cat((latent_mean, latent_logvar), dim=-1).reshape(1, -1)
    #     # none for belief and task
    #     return self.policy(states, latent, None, None), hidden_state, latent
    
    # def prior(self, num_processes):
    #     return self.encoder.prior(num_processes)
        

In [3]:
# Load the policy and encoder saved by an RL2 training run (used here as the example).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only point this at
# checkpoints from a source you trust.
RUN_FOLDER = './logs/logs_ML10-v2/rl2_73__25:10_21:13:08'
# map_location remaps the stored tensors onto `device`, so a checkpoint written on a
# GPU machine still loads when CUDA is unavailable, and the loaded networks end up on
# the same device the rest of the notebook moves its tensors to.
policy_net = torch.load(RUN_FOLDER + '/models/policy.pt', map_location=device)
encoder_net = torch.load(RUN_FOLDER + '/models/encoder.pt', map_location=device)

In [4]:
#1. get prior at start for base latent
# (does this reset the hidden state? I think so)
#2. feed policy observation + latent -> gets action, obs, reward, done
#3. feed encoder action, obs, reward, done and hidden state to get next action
import metaworld
import random

ml10 = metaworld.ML10()  # Construct the benchmark, sampling tasks

# Build one env per entry of ml10.test_classes and give each a randomly chosen task
# whose env_name matches that entry. Note the list is named `training_envs` but is
# filled from the benchmark's *test* classes and tasks.
training_envs = []
for name, env_cls in ml10.test_classes.items():
    env = env_cls()
    matching_tasks = [candidate for candidate in ml10.test_tasks
                      if candidate.env_name == name]
    task = random.choice(matching_tasks)
    env.set_task(task)
    training_envs.append(env)

In [5]:
from environments.env_utils.vec_env.subproc_vec_env import SubprocVecEnv
from environments.parallel_envs import VecPyTorch

def make_continual_env(env_id, **kwargs):
    """Return a zero-argument factory that builds ``gym.make(env_id, **kwargs)``.

    Nothing is constructed when this function is called; the env is created each
    time the returned callable is invoked. Task re-sampling, seeding,
    ``TimeLimitMask`` and ``VariBadWrapper`` are deliberately not applied here.
    """
    def _thunk():
        return gym.make(env_id, **kwargs)

    return _thunk

# Build a SubprocVecEnv from four env factories, each created over `training_envs`.
vec_envs = SubprocVecEnv(
    [make_continual_env('continualMW-v0', envs=training_envs) for _ in range(4)]
)
# pyt_vec = VecPyTorch(vec_envs, device)
# pyt_vec2 = VecPyTorch([make_continual_env('continualMW-v0', **{'envs' : training_envs}) for i in range(4)], device)

In [14]:

from algorithms.custom_storage import CustomOnlineStorage
from algorithms.custom_ppo import CustomPPO
from environments.metaworld_envs.test_continual_env import ContinualEnv

num_processes = 4

# Combined network: loaded policy + loaded encoder.
ac = ActorCritic(policy_net, encoder_net)

# PPO hyper-parameters, gathered in one place and handed to CustomPPO as keywords.
ppo_config = dict(
    value_loss_coef=0.4,
    entropy_coef=0.001,
    policy_optimiser='adam',
    policy_anneal_lr=False,
    train_steps=2,
    lr=1.0e-5,
    eps=1.0e-8,
    clip_param=0.2,
    ppo_epoch=3,
    use_huber_loss=True,
    use_clipped_value_loss=True,
    context_window=None,
)
agent = CustomPPO(actor_critic=ac, **ppo_config)

# One env factory per process.
# NOTE(review): `envs` is rebound in a later cell and no close() call is visible for
# this instance -- confirm whether it is still needed.
envs = SubprocVecEnv(
    [make_continual_env('continualMW-v0', envs=training_envs) for _ in range(num_processes)]
)


In [15]:
## TODO: make sure this can get all other properties of envs
from environments.env_utils.vec_env import VecEnvWrapper
class PyTorchVecEnvCont(VecEnvWrapper):
    """VecEnvWrapper that accepts tensor actions and returns tensor states/rewards.

    Actions are converted to numpy before being handed to the wrapped env
    (``self.venv``); states and rewards coming back are converted to float
    tensors on ``self.device``.
    """

    def __init__(self, vec_envs, device):
        super().__init__(vec_envs)
        self.device = device

    def step_async(self, actions):
        """Convert ``actions`` to a numpy array and dispatch it to the wrapped env."""
        # Swap the first two axes, drop every size-1 axis, then copy to host memory.
        # NOTE(review): the bare squeeze() would also drop the process axis if it had
        # size 1 -- confirm this is only used with more than one process.
        numpy_actions = actions.permute(1, 0, 2).squeeze().cpu().numpy()
        self.venv.step_async(numpy_actions)

    def step_wait(self):
        """Fetch the step result and convert state and reward to float tensors.

        ``done`` and ``info`` are returned exactly as the wrapped env produced
        them. A list-valued state or reward (raw + normalised) is converted
        element by element.
        """
        state, reward, done, info = self.venv.step_wait()

        def _reward_to_tensor(array):
            # reshape rewards to (1, -1, 1), i.e. T x B x D
            return torch.from_numpy(array).unsqueeze(dim=1).reshape(1, -1, 1).float().to(self.device)

        if not isinstance(state, list):
            # a single array has its first two axes swapped before conversion
            state = torch.from_numpy(state).permute(1, 0, 2).float().to(self.device)
        else:
            # list entries are converted as they are, without the axis swap
            state = [torch.from_numpy(part).float().to(self.device) for part in state]

        if not isinstance(reward, list):
            reward = _reward_to_tensor(reward)
        else:
            reward = [_reward_to_tensor(part) for part in reward]

        return state, reward, done, info

    def reset(self):
        """Reset the wrapped env and return its state as float tensor(s) on the device.

        NOTE(review): unlike ``step_wait``, no axis swap is applied here, although
        the original comment mentions permuting to T x B x D -- confirm which
        layout callers expect.
        """
        state = self.venv.reset()
        if isinstance(state, list):
            return [torch.from_numpy(part).float().to(self.device) for part in state]
        return torch.from_numpy(state).float().to(self.device)

    def __getattr__(self, attr):
        """ If env does not have the attribute then call the attribute in the wrapped_env """

        # these names are always fetched through get_env_attr on the unwrapped env
        if attr in ('_max_episode_steps', 'task_dim', 'belief_dim', 'num_states'):
            return self.unwrapped.get_env_attr(attr)

        try:
            target = self.__getattribute__(attr)
        except AttributeError:
            target = self.unwrapped.__getattribute__(attr)

        if not callable(target):
            return target

        # callables are handed back behind a thin pass-through wrapper
        def hooked(*args, **kwargs):
            return target(*args, **kwargs)

        return hooked


In [16]:
# Continual rollout / update loop.
# Each pass of the outer loop resets the envs, starts from the encoder prior, steps the
# envs until every process reports done, then runs one agent.update on the storage.
# num_episodes_per_update = 4

# Passed to the env as 'steps_per_env' and used as the first CustomOnlineStorage
# argument; kept in one variable so the two cannot drift apart.
steps_per_env = 500
envs = SubprocVecEnv([make_continual_env('continualMW-v0', **{'envs' : training_envs, 'steps_per_env': steps_per_env}) for i in range(num_processes)])
env = PyTorchVecEnvCont(envs, device)
storage = CustomOnlineStorage(
    steps_per_env, num_processes, env.observation_space.shape[0]+1, 0, 0,
    env.action_space, ac.encoder.hidden_size, ac.encoder.latent_dim, False)
res = dict()
# Update counter, deliberately initialised OUTSIDE the loop: resetting it on every pass
# would make each update overwrite res[0], leaving only the last result in `res`.
eps = 0
while env.get_env_attr('cur_step') < env.get_env_attr('steps_limit'):
    print(f"current step: {env.get_env_attr('cur_step')}; limit: {env.get_env_attr('steps_limit')}")
    step = 0  # index into the storage for the current rollout

    # if I do this, I need to make sure my returns are calculated correctly / the done flags work
    # for i in range(num_episodes_per_update):
    obs = env.reset()
    done = [False for _ in range(num_processes)]
    ## get prior??? how frequent?
    # do at start of each episode for now
    with torch.no_grad():
        _, latent_mean, latent_logvar, hidden_state = agent.actor_critic.encoder.prior(num_processes)
        print(step)
        assert len(storage.latent) == 0  # make sure we emptied buffers
        storage.hidden_states[:1].copy_(hidden_state)
        # the latent fed to the policy is the concatenation [mean, logvar]
        latent = torch.cat((latent_mean.clone(), latent_logvar.clone()), dim=-1)#.reshape(1, -1)
        storage.latent.append(latent)

    # NOTE(review): agent.act and the encoder call below run with autograd enabled, so
    # the tensors written to storage may carry graph history -- confirm this is intended.
    while not all(done):
        value, action = agent.act(obs, latent, None, None)
        next_obs, reward, done, info = env.step(action)
        assert all(done) == any(done), "Metaworld envs should all end simultaneously"

        obs = next_obs

        ## TODO: do I even need masks?
        # create mask for episode ends
        masks_done = torch.FloatTensor([[0.0] if _done else [1.0] for _done in done]).to(device)
        # bad_mask is true if episode ended because time limit was reached
        # don't care for metaworld
        bad_masks = torch.FloatTensor([[0.0] for _done in done]).to(device)

        # update the latent from the transition that was just observed
        _, latent_mean, latent_logvar, hidden_state = agent.actor_critic.encoder(action, obs, reward, hidden_state, return_prior = False)
        latent = torch.cat((latent_mean.clone(), latent_logvar.clone()), dim = -1)[None,:]#.reshape(1, -1)

        # `obs` already holds the post-step observation at this point
        storage.next_state[step] = obs.clone()
        storage.insert(
            state=obs.squeeze(),
            belief=None,
            task=None,
            actions=action,
            rewards_raw=reward.squeeze(0),
            rewards_normalised=reward.squeeze(0),#rew_normalised,
            value_preds=value.squeeze(0),
            masks=masks_done.squeeze(0), # do I even need these?
            bad_masks=bad_masks.squeeze(0), 
            done=torch.from_numpy(done)[:,None].float(),
            hidden_states = hidden_state.squeeze(),
            latent = latent#.unsqueeze(1),
        )

        step += 1

    # update at the end of each episode?
    res[eps] = agent.update(storage)
    # should clear out old data
    storage.after_update()
    eps+=1
        




current step: 0; limit: 2500
0













































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































current step: 500; limit: 2500
0









































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































In [17]:
# Show the values collected from agent.update(storage), keyed by the `eps` counter.
res

{0: (nan, nan, nan, nan)}