In [1]:
# import sys
# sys.path.append('./algorithms/')
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from typing import Any, Dict, List, Tuple

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class ActorCritic(nn.Module):

    def __init__(self, policy, encoder):
        super().__init__()
        self.policy = policy
        self.encoder = encoder
    
    def get_actor_params(self):
        return self.policy.get_actor_params()

    def get_critic_params(self):
        return self.policy.get_critic_params()

    def forward_actor(self, inputs):
        return self.policy.forward_actor(inputs)

    def forward_critic(self, inputs):
        return self.policy.forward_critic(inputs)
    
    def act(self, state, latent, belief=None, task=None, deterministic = False):
        return self.policy.act(state, latent, belief, task, deterministic)

    def get_value(self, state, latent, belief=None, task=None):
        value, _ = self.policy.forward(state, latent, belief, task)
        return value

    def evaluate_actions(self, state, latent, belief, task, action):
        """Call policy eval, set task, belief to None"""
        return self.policy.evaluate_actions(state, latent, belief, task, action)
    
        ## TODO: what to do about 'sample'? check what this arg is?
    # def forward(self, actions, states, rewards, hidden_state, return_prior=False, sample=True, detach_every=None):
    #     # really want this to take the inputs for the encoder and then output the outputs of the policy
    #     # we only want to get the prior when there are no previous rewards, actions or hidden states
    #     # should only occur at the very start of the continual learning process
    #     if hidden_state is None:
    #         # print('Hidden state is None!!:', hidden_state)
    #         _, latent_mean, latent_logvar, hidden_state = self.encoder.prior(states.shape[1]) # check that this gets the batch size?
    #     else:
    #         _, latent_mean, latent_logvar, hidden_state = self.encoder(actions, states, rewards, hidden_state, return_prior, sample, detach_every)
        
    #     latent_mean = F.relu(latent_mean)
    #     latent_logvar = F.relu(latent_logvar)
    #     latent = torch.cat((latent_mean, latent_logvar), dim=-1).reshape(1, -1)
    #     # none for belief and task
    #     return self.policy(states, latent, None, None), hidden_state, latent
    
    # def prior(self, num_processes):
    #     return self.encoder.prior(num_processes)
        

In [3]:
# get RL2 trained policy for example
RUN_FOLDER = './logs/logs_ML10-v2/rl2_73__25:10_21:13:08'
policy_net = torch.load(RUN_FOLDER + '/models/policy.pt')
encoder_net = torch.load(RUN_FOLDER + '/models/encoder.pt')

In [4]:
#1. get prior at start for base latent
# (does this reset the hidden state? I think so)
#2. feed policy observation + latent -> gets action, obs, reward, done
#3. feed encoder action, obs, reward, done and hidden state to get next action
import metaworld
import random

ml10 = metaworld.ML10() # Construct the benchmark, sampling tasks

training_envs = []
for name, env_cls in ml10.test_classes.items():
  env = env_cls()
  task = random.choice([task for task in ml10.test_tasks
                        if task.env_name == name])
  env.set_task(task)
  training_envs.append(env)

In [5]:
from environments.env_utils.vec_env.subproc_vec_env import SubprocVecEnv
from environments.parallel_envs import VecPyTorch

def make_continual_env(env_id, **kwargs):
    def _thunk():

        env = gym.make(env_id, **kwargs)
        # if tasks is not None:
        #     env.unwrapped.reset_task = lambda x: env.unwrapped.set_task(random.choice(tasks))
        # if seed is not None:
        #     env.seed(seed + rank)
        # if str(env.__class__.__name__).find('TimeLimit') >= 0:
        #     env = TimeLimitMask(env)
        # env = VariBadWrapper(env=env, episodes_per_task=episodes_per_task, add_done_info=add_done_info)
        return env
    return _thunk

vec_envs = SubprocVecEnv([make_continual_env('continualMW-v0', **{'envs' : training_envs}) for i in range(4)])
# pyt_vec = VecPyTorch(vec_envs, device)
# pyt_vec2 = VecPyTorch([make_continual_env('continualMW-v0', **{'envs' : training_envs}) for i in range(4)], device)

In [6]:

from algorithms.custom_storage import CustomOnlineStorage
from algorithms.custom_ppo import CustomPPO
from environments.metaworld_envs.test_continual_env import ContinualEnv

num_processes = 4
# combined network with encoder + policy
ac = ActorCritic(policy_net, encoder_net)
agent = CustomPPO(
    actor_critic=ac,
    value_loss_coef = 1,
    entropy_coef = 0.001,
    policy_optimiser='adam',
    policy_anneal_lr=False,
    train_steps = 2,
    lr = 0.001,
    eps=1.0e-8,
    clip_param = 0.2,
    ppo_epoch = 3,
    use_huber_loss = True,
    use_clipped_value_loss=True,
    context_window=None
)
# env = training_envs[0]
# env = ContinualEnv(training_envs, 500)
envs = SubprocVecEnv([make_continual_env('continualMW-v0', **{'envs' : training_envs}) for i in range(num_processes)])


In [7]:
## TODO: make sure this can get all other properties of envs
from environments.env_utils.vec_env import VecEnvWrapper
class PyTorchVecEnvCont(VecEnvWrapper):

    def __init__(self, vec_envs, device):
        super(PyTorchVecEnvCont, self).__init__(vec_envs)
        # self.vec_envs = vec_envs
        self.device = device

    # def step_wait(self, actions)

    # def step(self, actions):
    #     actions.cpu().detach()
    #     obs, reward, done, info = self.vec_envs.step(actions)
    #     obs = torch.from_numpy(obs).permute(1, 0, 2).to(self.device)
    #     reward = torch.from_numpy(reward).to(self.device)
    #     return obs, reward, done, info
    
    def step_async(self, actions):
        # actions = actions.squeeze(1).cpu().numpy()
        # convert actions for worker .permute(1, 0, 2)
        actions = actions.permute(1, 0, 2).squeeze().cpu().numpy()
        self.venv.step_async(actions)

    def step_wait(self):
        state, reward, done, info = self.venv.step_wait()
        if isinstance(state, list):  # raw + normalised .permute(1, 0, 2)
            state = [torch.from_numpy(s).float().to(self.device) for s in state]
        else:
            state = torch.from_numpy(state).permute(1, 0, 2).float().to(self.device)
        # reshape rewards to have dim T X B X D .reshape(1, -1, 1)
        if isinstance(reward, list):  # raw + normalised
            reward = [torch.from_numpy(r).unsqueeze(dim=1).reshape(1, -1, 1).float().to(self.device) for r in reward]
        else:
            reward = torch.from_numpy(reward).unsqueeze(dim=1).reshape(1, -1, 1).float().to(self.device)
        return state, reward, done, info
    
    def reset(self):
        # if task is not None:
        #     assert isinstance(task, list)
        state = self.venv.reset()
        ## permute state to have dimensions T X B X D .permute(1,0,2)
        if isinstance(state, list):
            state = [torch.from_numpy(s).float().to(self.device) for s in state]
        else:
            state = torch.from_numpy(state).float().to(self.device)
        return state
    
    # def reset(self):
    #     obs = self.vec_envs.reset()
    #     obs = torch.from_numpy(obs).permute(1, 0, 2).to(self.device)
    #     return obs

    def __getattr__(self, attr):
        """ If env does not have the attribute then call the attribute in the wrapped_env """

        if attr in ['_max_episode_steps', 'task_dim', 'belief_dim', 'num_states']:
            return self.unwrapped.get_env_attr(attr)

        try:
            orig_attr = self.__getattribute__(attr)
        except AttributeError:
            orig_attr = self.unwrapped.__getattribute__(attr)

        if callable(orig_attr):
            def hooked(*args, **kwargs):
                result = orig_attr(*args, **kwargs)
                return result

            return hooked
        else:
            return orig_attr


In [8]:
# num_episodes_per_update = 4
envs = SubprocVecEnv([make_continual_env('continualMW-v0', **{'envs' : training_envs}) for i in range(num_processes)])
env = PyTorchVecEnvCont(envs, device)
storage = CustomOnlineStorage(
    500, num_processes, env.observation_space.shape[0]+1, 0, 0,
    env.action_space, ac.encoder.hidden_size, ac.encoder.latent_dim, False)

while env.get_env_attr('cur_step') < env.get_env_attr('steps_limit'):
    print(f"current step: {env.get_env_attr('cur_step')}; limit: {env.get_env_attr('steps_limit')}")
    step = 0
    
    # if I do this, I need to make sure my returns are calculated correctly / the done flags work
    # for i in range(num_episodes_per_update):
    obs = env.reset()
    done = [False for _ in range(num_processes)]
    # print(f"running episode {i}")
    ## get prior??? how frequent?
    # do at start of each episode for now
    with torch.no_grad():
        _, latent_mean, latent_logvar, hidden_state = agent.actor_critic.encoder.prior(num_processes)
        print(step)
        ## TODO: set the 500 to some sort of variable (max episode len?)
        assert len(storage.latent) == 0  # make sure we emptied buffers
        # print(f"saving hidden state to {i * 500}")
        storage.hidden_states[:1].copy_(hidden_state)
        latent = torch.cat((latent_mean.clone(), latent_logvar.clone()), dim=-1)#.reshape(1, -1)
        storage.latent.append(latent)

    while not all(done):
        value, action = agent.act(obs, latent, None, None)
        next_obs, reward, done, info = env.step(action)
        assert all(done) == any(done), "Metaworld envs should all end simultaneously"

        obs = next_obs

        ## TODO: do I even need masks?
        # create mask for episode ends
        masks_done = torch.FloatTensor([[0.0] if _done else [1.0] for _done in done]).to(device)
        # bad_mask is true if episode ended because time limit was reached
        # don't care for metaworld
        bad_masks = torch.FloatTensor([[0.0] for _done in done]).to(device)

        # if done:
        #     print(f'{step}: done!')
        #     hidden_state = agent.actor_critic.encoder.reset_hidden(hidden_state, masks_done)
        # print(action.size(), obs.squeeze(0).size(), reward.squeeze(0).size(), hidden_state.size(), latent.size(), masks_done.size(), bad_masks.size())
        _, latent_mean, latent_logvar, hidden_state = agent.actor_critic.encoder(action, obs, reward, hidden_state, return_prior = False)
        latent = torch.cat((latent_mean.clone(), latent_logvar.clone()), dim = -1)[None,:]#.reshape(1, -1)

        
        storage.next_state[step] = obs.clone()
        # print(action.squeeze(0).size(), obs.squeeze(0).size(), reward.squeeze(0).size(), hidden_state.size(), latent.size(), masks_done.size(), bad_masks.size())
        storage.insert(
            state=obs.squeeze(),
            belief=None,
            task=None,
            actions=action,
            rewards_raw=reward.squeeze(0),
            rewards_normalised=reward.squeeze(0),#rew_normalised,
            value_preds=value.squeeze(0),
            masks=masks_done.squeeze(0), # do I even need these?
            bad_masks=bad_masks.squeeze(0), 
            done=torch.from_numpy(done)[:,None].float(),
            hidden_states = hidden_state.squeeze(),
            latent = latent#.unsqueeze(1),
        )

        step += 1
        ### update
        # if step % num_updates ==0:

    # update at the end of each episode?
    agent.update(storage)
    # should clear out old data
    storage.after_update()
        




current step: 0; limit: 50000
0


/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [21,0,0], thread: [32,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [21,0,0], thread: [33,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [21,0,0], thread: [34,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [21,0,0], thread: [35,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [21,0,0], thread: [36,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [21,0,0

RuntimeError: CUDA error: device-side assert triggered

In [10]:
torch.cat(storage.latent)

tensor([[[ 0.0037,  0.0568,  0.0397,  ..., -0.0980, -0.0445, -0.0333],
         [ 0.0037,  0.0568,  0.0397,  ..., -0.0980, -0.0445, -0.0333],
         [ 0.0037,  0.0568,  0.0397,  ..., -0.0980, -0.0445, -0.0333],
         [ 0.0037,  0.0568,  0.0397,  ..., -0.0980, -0.0445, -0.0333]],

        [[-0.0807,  0.1167,  0.0092,  ..., -0.1912, -0.2447,  0.0459],
         [-0.0998,  0.1476,  0.0306,  ..., -0.1913, -0.2540,  0.0545],
         [-0.0718,  0.0751, -0.0131,  ..., -0.2039, -0.2138,  0.0240],
         [-0.1134,  0.1231,  0.0050,  ..., -0.2011, -0.2479,  0.0196]],

        [[-0.1516,  0.1161, -0.0543,  ..., -0.2709, -0.3519,  0.0283],
         [-0.2124,  0.1815, -0.0032,  ..., -0.2696, -0.3644,  0.0336],
         [-0.0949,  0.0863, -0.0512,  ..., -0.2828, -0.3137,  0.0642],
         [-0.1658,  0.1080, -0.0734,  ..., -0.2641, -0.3739,  0.0025]],

        ...,

        [[ 0.6353, -0.4397, -0.9656,  ..., -0.5909, -0.4097, -0.1709],
         [ 0.6032, -0.4428, -0.9768,  ..., -0.5935, -0.41

In [33]:
latent = [storage.latent[0].detach().clone()]
latent[0].requires_grad = True

h = storage.hidden_states[0].detach()
for i in range(storage.actions.shape[0]):
    # reset hidden state of the GRU when we reset the task
    h = agent.actor_critic.encoder.reset_hidden(h, storage.done[i])
    # not sure why this is i + 1?
    # h = self.actor_critic.encoder.reset_hidden(h, policy_storage.done[i + 1])

    _, tm, tl, h = agent.actor_critic.encoder(
        storage.actions.float()[i:i + 1],
        storage.next_state[i:i + 1],
        storage.rewards_raw[i:i + 1],
        h,
        sample=False,
        return_prior=False,
        detach_every=None
    )
    # latent_sample.append(ts)
    latent.append(torch.cat((tm, tl), dim = -1)[None,:])

In [39]:
(torch.cat(storage.latent) - torch.cat(latent)).sum()

tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

In [38]:
storage.latent[0].size()

torch.Size([1, 4, 256])

In [37]:
latent

torch.Size([1, 4, 256])

In [35]:
i = 5
storage.latent[i] - latent[i]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       grad_fn=<SubBackward0>)

In [31]:
latent[1]

tensor([[-0.0807,  0.1167,  0.0092,  ..., -0.1912, -0.2447,  0.0459],
        [-0.0998,  0.1476,  0.0306,  ..., -0.1913, -0.2540,  0.0545],
        [-0.0718,  0.0751, -0.0131,  ..., -0.2039, -0.2138,  0.0240],
        [-0.1134,  0.1231,  0.0050,  ..., -0.2011, -0.2479,  0.0196]],
       device='cuda:0', grad_fn=<CatBackward>)

In [19]:
torch.cat(latent).size()

torch.Size([2004, 256])

In [82]:
reward.squeeze(0).size()

torch.Size([4, 1])

In [64]:
d

tensor([[[0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        [[1.2308],
         [1.2294],
         [1.2258],
         [1.2257]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        ...,

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         [0.0000]]])

In [46]:
storage.hidden_states[:1]

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0

In [27]:
vec_envs.reset()

Process Process-13:
Traceback (most recent call last):
  File "/home/grant/miniconda3/envs/varibad/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/grant/miniconda3/envs/varibad/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/grant/working_repos/varibad/environments/env_utils/vec_env/subproc_vec_env.py", line 21, in worker
    ob = env.reset()
  File "/home/grant/working_repos/varibad/environments/metaworld_envs/test_continual_env.py", line 58, in reset
    obs = torch.from_numpy(np.append(obs, 0).reshape(1, -1)).float().to(device)
  File "/home/grant/miniconda3/envs/varibad/lib/python3.9/site-packages/torch/cuda/__init__.py", line 163, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
Process Process-10:
Process Process-11:
Traceback (most recent c

EOFError: 

In [25]:
obs = pyt_vec.reset()

TypeError: reset() got an unexpected keyword argument 'index'

In [42]:
latent = [storage.latent[0].detach().clone()]
latent[0].requires_grad = True

h = storage.hidden_states[0].detach()
for i in range(storage.actions.shape[0]):
    # reset hidden state of the GRU when we reset the task
    if storage.done[i] == 1:
        print(i, 'Resetting', storage.done[i+1], storage.latent[i+1])
        h = agent.actor_critic.encoder.reset_hidden(h, storage.done[i])

        # also add the reset latent for the start of the next episode
        _latent = storage.latent[i].detach().clone()
        _latent.requires_grad = True
        latent.append(_latent)
    # not sure why this is i + 1?
    # h = self.actor_critic.encoder.reset_hidden(h, policy_storage.done[i + 1])

    _, tm, tl, h = agent.actor_critic.encoder(
        storage.actions.float()[i:i + 1],
        storage.next_state[i:i + 1],
        storage.rewards_raw[i:i + 1],
        h,
        sample=False,
        return_prior=False,
        detach_every=None
    )

    latent.append(torch.cat((tm, tl), dim = -1))




500 Resetting tensor([[0.]], device='cuda:0') tensor([[ 0.0037,  0.0568,  0.0397, -0.0055, -0.0818, -0.0101,  0.0708,  0.0079,
          0.0424,  0.1191, -0.0008, -0.0250, -0.0121, -0.0623, -0.0995, -0.0121,
         -0.0115, -0.0719,  0.0463, -0.1275,  0.0058,  0.0608,  0.0959, -0.0609,
          0.0416,  0.0727, -0.0562,  0.0961, -0.0383,  0.0316, -0.0229,  0.0256,
         -0.0508,  0.0397,  0.0675, -0.0524,  0.0099, -0.0123,  0.0097,  0.0050,
          0.0064, -0.0643,  0.0515, -0.0102,  0.0280,  0.0607,  0.0094, -0.0862,
          0.0151, -0.0556,  0.0345,  0.0694, -0.0842, -0.0672, -0.0334, -0.0391,
          0.0942,  0.0439,  0.0679,  0.0088, -0.0418,  0.0701, -0.0439,  0.0398,
         -0.0556, -0.0207,  0.0492,  0.0810,  0.0516, -0.0498,  0.0138, -0.0205,
          0.0046,  0.0644, -0.0265,  0.0170,  0.0024,  0.0657,  0.0330,  0.0506,
          0.0059,  0.0069, -0.0340,  0.0629, -0.0139,  0.0307,  0.0692, -0.0419,
         -0.0929,  0.0006, -0.0722,  0.0374,  0.0375,  0.0830, 

In [33]:
# storage.latent[501] - 
(torch.cat(latent) == storage.latent[0]).sum()

tensor(256, device='cuda:0')

In [43]:
(torch.cat(storage.latent) - torch.cat(latent)).sum() 

tensor(-1226.3718, device='cuda:0', grad_fn=<SumBackward0>)

In [53]:
np.where(storage.done.cpu().detach().numpy()==1)

(array([ 500, 1000, 1500, 2000]), array([0, 0, 0, 0]), array([0, 0, 0, 0]))

In [60]:
np.unique(np.where(torch.cat(storage.latent).cpu().detach().numpy()==storage.latent[0].cpu().detach().numpy()[0])[0])

array([   0,  501, 1002, 1503])

In [58]:
torch.cat(storage.latent).cpu().detach().numpy()[0]

array([ 0.0036862 ,  0.05677195,  0.03968116, -0.00554151, -0.08180048,
       -0.01013042,  0.07082925,  0.00790204,  0.0423823 ,  0.11912974,
       -0.00084364, -0.02503387, -0.01213983, -0.06234327, -0.09946632,
       -0.01211663, -0.01150087, -0.0718758 ,  0.04626398, -0.12751377,
        0.00579331,  0.06083824,  0.09590048, -0.06088419,  0.04160011,
        0.07273776, -0.05615995,  0.09611335, -0.03827709,  0.03161815,
       -0.02291975,  0.02557641, -0.05081708,  0.03965855,  0.0675004 ,
       -0.05241253,  0.00993486, -0.01231433,  0.00968155,  0.00498494,
        0.00644882, -0.06429251,  0.05148085, -0.01024368,  0.02796842,
        0.06069341,  0.00942255, -0.08616532,  0.01513379, -0.05556767,
        0.03445356,  0.06938666, -0.08415639, -0.06715772, -0.0333991 ,
       -0.03908315,  0.09423542,  0.04387035,  0.06786837,  0.00878241,
       -0.04184344,  0.0700689 , -0.0438891 ,  0.03982812, -0.05555326,
       -0.02070116,  0.04923742,  0.08097896,  0.05160142, -0.04

In [65]:
a = storage.hidden_states ==0
indices = a.nonzero()
print(indices)

tensor([[   0,    0,    0],
        [   0,    0,    1],
        [   0,    0,    2],
        ...,
        [1500,    0,  125],
        [1500,    0,  126],
        [1500,    0,  127]], device='cuda:0')


In [60]:
storage.hidden_states.size()[-1]*4

512

In [None]:
## training loop
# while the whole continual env is not done
# train on each env sequentially
# periodically evaluate on all envs

agent = ActorCritic(policy_net, encoder_net)
cont_env = ContinualEnv(training_envs, 10)
eval_freq = 10
num_steps = 0
eval_results = dict()
while cont_env.cur_step < cont_env.steps_limit:
    # do each env
    obs = cont_env.reset()
    done = False
    while not done:
        if agent is not None:
            obs = torch.from_numpy(np.append(obs, 0).reshape(1, -1))[None,:,:].float().to(device)
            if hidden_state is not None:
                reward = torch.from_numpy(np.array(reward).reshape(1, -1))[None,:,:].float().to(device)
                action = torch.from_numpy(action)[None, None,:].float().to(device)
            _, act, hidden_state = agent.act(action, obs, reward, hidden_state)
            action = act.cpu().detach().numpy()[0]
        else:
            action = env.action_space.sample()

        next_obs, reward, done, info = cont_env.step(action)
        obs = next_obs

        # periodically evaluate
        if  (cont_env.cur_step + 1) % eval_freq == 0:
            print(cont_env.cur_step, 'EVALUATING')
            all_envs = cont_env._get_envs()
            eval_results[cont_env.cur_step] = evaluate_all_envs(all_envs, num_episodes = 3)
            eval_results[cont_env.cur_step]['task'] = cont_env.cur_seq_idx
        
    print(cont_env.cur_step, cont_env.cur_seq_idx)


        