In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist

import gymnasium as gym

In [None]:
PARAMS = {
    "continuous":False,
    "gravity":-10.0,
    "enable_wind": False,
    "wind_power":15.0,
    "turbulence_power":1.5
}

## create lunar lander
env = gym.make_vec("LunarLander-v3", num_envs= 3, **PARAMS)

step, _ = env.reset()
num_steps = 200
for _ in range(num_steps):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, info = env.step(action)
    done = [any(i) for i in zip(terminated, truncated)] 

    state = next_state

In [15]:
10%2

0

In [61]:
class ContinualEnv(gym.Env):

    def __init__(self, params: list[dict], steps_per_env: int, num_envs: int = 1):
        self.params = params
        self.num_envs = num_envs
        self.steps_per_env = steps_per_env
        assert self.steps_per_env%self.num_envs == 0, "steps_per_env must be divisible by num_envs"

        self.envs = [gym.make_vec(num_envs = self.num_envs, **param) for param in self.params]
        self.current_env = 0
        self.current_step = 0
        self.current_env_instance = self.envs[self.current_env]
        self.action_space = self.current_env_instance.action_space
        self.observation_space = self.current_env_instance.observation_space

        ## TODO: do i need action/observation spaces?
    
    def reset(self):
        return self.current_env_instance.reset()
    
    def step(self, action):
        next_state, reward, terminated, truncated, info = self.current_env_instance.step(action)
        self.current_step += 1*self.num_envs
        
        if self.current_step >= self.steps_per_env:
            print(f"Switching to next environment: {self.current_env} -> {self.current_env + 1} at step {self.current_step}")
            self.current_env += 1
            if self.current_env < len(self.envs):
                self.current_env_instance = self.envs[self.current_env]
                self.current_step = 0
                next_state = self.current_env_instance.reset()
            else:
                print("No more environments to switch to. Continuing with the last environment.")
                ## TODO: handle end of environments gracefully
        
        return next_state, reward, terminated, truncated, info

In [62]:
PARAMS = [ 
        {
            "id":"LunarLander-v3",
            "continuous":False,
            "gravity":-10.0,
            "enable_wind": False,
            "wind_power":15.0,
            "turbulence_power":1.5
        },
        {
            "id":"LunarLander-v3",
            "continuous":False,
            "gravity":-10.0,
            "enable_wind": True,
            "wind_power":15.0,
            "turbulence_power":1.5
        }
    ]



In [63]:
num_steps = 1000 * len(PARAMS) // 4
c_env = ContinualEnv(PARAMS, steps_per_env=1000, num_envs=4)
step, _ = c_env.reset()
for _ in range(num_steps):
    action = c_env.action_space.sample()
    next_state, reward, terminated, truncated, info = c_env.step(action)
    # print(f"Step: {c_env.current_step}, Env: {c_env.current_env}, Action: {action}, Reward: {reward}")
    # print(f"Terminated: {terminated}, Truncated: {truncated}, Info: {info}")
    done = [any(i) for i in zip(terminated, truncated)] 
    
    # if done:
    #     print(f"Environment {c_env.current_env} finished after {c_env.current_step} steps.")
    #     step, _ = c_env.reset()
    # else:
    state = next_state

Switching to next environment: 0 -> 1 at step 1000
Switching to next environment: 1 -> 2 at step 1000
No more environments to switch to. Continuing with the last environment.


In [None]:
class ContinualTrainer:

    def __init__(self, env: ContinualEnv, agent, logger):
        self.env = env
        self.agent = agent
        self.logger = logger

    def train(self, num_epochs: int, steps_per_epoch: int = 1000):
        
        for _ in range(num_epochs):
            state, _ = self.env.reset(self.seed)
            for _ in range(steps_per_epoch):

                action = self.agent.act(state)
                next_state, reward, terminated, truncated, info = self.env.step(action)
                done = [any(i) for i in zip(terminated, truncated)]

                self.agent.record(state, action, reward, next_state, done)
                self.logger.log(state, action, reward, next_state, done, info)
                
                state = next_state
            
            self.agent.update()
            ## TODO: log agent performance, maybe save model
            ## TODO: evaluate agent performance on all environments