In [33]:
%reload_ext autoreload
%autoreload 2
from copy import deepcopy
import stable_baselines3 as sb3
from stable_baselines3.common.policies import ActorCriticPolicy
import torch
import torch.nn as nn
import numpy as np
import gym
from tqdm.autonotebook import tqdm, trange

from systems.cartpole import create_cartpole, CartpoleEnv

# Reptile

In [44]:
class Reptile:
    """Reptile-style meta-learning wrapper around a stable-baselines3 algorithm.

    An "outer" agent is trained on the first environment. In every meta
    iteration an "inner" agent is initialised from the outer policy weights,
    trained for a short burst on each remaining environment, and after each
    burst the outer policy parameters are moved a fraction ``beta`` towards
    the inner policy parameters.

    Parameters
    ----------
    envs : sequence of environments
        ``envs[0]`` is the environment of the outer agent; ``envs[1:]`` are
        the meta-training environments visited by the inner agent.
    pretrain_steps : int
        Timesteps the outer agent is trained for before the meta loop starts.
    beta : float
        Step size of the interpolation ``outer += beta * (inner - outer)``.
    algorithm : class
        Algorithm class, instantiated as
        ``algorithm(ActorCriticPolicy, env, **kwargs)``.
    inner_kwargs : dict, optional
        Overrides applied on top of ``kwargs`` when building the inner agent.
    **kwargs
        Keyword arguments for the outer agent; also the defaults for the
        inner agent.
    """

    def __init__(self, envs, pretrain_steps, beta=0.01, algorithm=sb3.PPO, inner_kwargs=None, **kwargs):
        self.envs = envs
        self.env = envs[0]
        self.meta_envs = envs[1:]
        self.beta = beta
        self.pretrain_steps = pretrain_steps
        self.kwargs = kwargs
        self.algorithm = algorithm
        # The inner agent inherits the outer settings, then applies its own
        # overrides. `None` replaces the former mutable `{}` default.
        self.inner_kwargs = deepcopy(kwargs)
        self.inner_kwargs.update(inner_kwargs or {})
        # The inner agent never writes tensorboard logs, whatever was requested.
        self.inner_kwargs['tensorboard_log'] = None
        self.outer = self.algorithm(ActorCriticPolicy, self.env, **self.kwargs)

    @staticmethod
    def _rollout_length(agent, kwargs):
        """Return ``n_steps`` from ``kwargs``, else the value stored on ``agent``."""
        n_steps = kwargs.get('n_steps')
        if n_steps is None:
            # The caller relied on the algorithm's default; read it back from
            # the constructed agent (on-policy sb3 algorithms expose `n_steps`).
            n_steps = agent.n_steps
        return n_steps

    def learn(self, total_timesteps, reset_timesteps=True):
        """Run the meta-training loop.

        Parameters
        ----------
        total_timesteps : int
            Budget of the meta loop; one meta iteration is counted as one
            outer rollout length (``n_steps``).
        reset_timesteps : bool
            When true, the outer agent is first pre-trained for
            ``pretrain_steps`` timesteps.
        """
        # Was `self.reset_timesteps`, an attribute that is never assigned.
        if reset_timesteps:
            self.outer.learn(total_timesteps=self.pretrain_steps)
        self.inner = self.algorithm(ActorCriticPolicy, self.env, **self.inner_kwargs)
        outer_steps = self._rollout_length(self.outer, self.kwargs)
        inner_steps = self._rollout_length(self.inner, self.inner_kwargs)
        for _ in trange(0, total_timesteps, outer_steps, leave=False):
            # NOTE(review): the inner policy is re-synchronised with the outer
            # policy once per meta iteration, not once per environment, so each
            # later environment starts from weights already adapted on the
            # previous ones. Canonical Reptile restarts from the meta-parameters
            # for every task -- confirm this is intended.
            self.inner.policy.load_state_dict(self.outer.policy.state_dict())
            for env in tqdm(self.meta_envs, leave=False):
                self.inner.set_env(env)
                self.inner.learn(total_timesteps=inner_steps, progress_bar=False)
                # Meta update: move the outer weights towards the adapted ones.
                with torch.no_grad():
                    param_pairs = zip(self.outer.policy.parameters(),
                                      self.inner.policy.parameters())
                    for outer, inner in param_pairs:
                        outer.add_(self.beta * (inner - outer))
            # Continue training the outer agent without resetting its counter.
            self.outer.learn(outer_steps, reset_num_timesteps=False)

In [31]:
# Nominal system parameters, forwarded to CartpoleEnv.
# (presumably cart mass, pole mass, pole length, gravity and a friction/damping
# term -- confirm against systems.cartpole)
sys_kwargs = {'mc': 0.5, 'mp': 0.1, 'l': 1, 'g': 10, 'df': 0.01}

# Keyword arguments forwarded to the learning algorithm.
learn_kwargs = {
    'seed': 0,
    'learning_rate': 2e-3,
    'n_steps': 2048,
    'batch_size': 64,
    'n_epochs': 10,
    'gamma': 0.99,
}
total_timesteps = 50_000

# 4x4 diagonal matrix handed to CartpoleEnv as `q`, and a 1x1 matrix `r`.
q = np.diag([1, 0.1, 1e-5, 1e-1])
r = np.asarray([[0.00001]])

# Two fixed, seeded random diagonal 4x4 matrices.
xformA = np.diagflat(np.random.RandomState(seed=0).randn(4))
xformB = np.diagflat(np.random.RandomState(seed=1).randn(4))

# Four-element vector whose first entry is -pi/45 (i.e. -4 degrees in radians).
x0 = np.asarray([-np.pi/45, 0, 0, 0])


def make_env():
    """Build a CartpoleEnv from the nominal ``sys_kwargs`` with ``q`` and seed 0."""
    return CartpoleEnv(**sys_kwargs, q=q, seed=0)
def sample_env(seed=None, n=1):
    """Create ``n`` CartpoleEnv instances with randomly perturbed parameters.

    For every environment, the ``mc``, ``mp`` and ``l`` entries of a copy of
    ``sys_kwargs`` are each multiplied by ``1 + 0.2 * z`` with ``z`` drawn
    from a standard normal distribution (in that order). All other entries
    are left at their nominal values.

    Parameters
    ----------
    seed : int or None
        Seed for the ``numpy.random.RandomState`` used for the perturbations.
    n : int
        Number of environments to create.

    Returns
    -------
    list
        The ``n`` constructed environments.
    """
    rng = np.random.RandomState(seed)
    perturbed_keys = ('mc', 'mp', 'l')
    sampled = []
    for _ in range(n):
        params = deepcopy(sys_kwargs)
        # One draw per key, in this fixed order, keeps seeded runs repeatable.
        for key in perturbed_keys:
            params[key] = params[key] * (1 + rng.randn() * 0.2)
        sampled.append(CartpoleEnv(**params, q=q, seed=0))
    return sampled
# Environment built by make_env(), i.e. from the unmodified sys_kwargs.
env = make_env()
# Object returned by create_cartpole for the same parameters
# (presumably a system model -- confirm against systems.cartpole).
# NOTE(review): the name `sys` shadows the standard-library module of the same
# name; a later `import sys`-style use in this notebook would clash with it.
sys = create_cartpole(**sys_kwargs)

In [45]:
# The first environment is the nominal one used by the outer agent; the three
# sampled environments are the meta-training tasks. The sampling is seeded so
# the same perturbed environments are produced on every run.
agent = Reptile([make_env()] + sample_env(seed=learn_kwargs['seed'], n=3),
                pretrain_steps=5_000,
                tensorboard_log='./tensorboard/Cartpole/reptile/',
                **learn_kwargs)
# Use the budget from the configuration cell, not a duplicated literal.
agent.learn(total_timesteps)

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]