A toy example demonstrating Model-Agnostic Meta-Learning (MAML) in reinforcement learning (RL).

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch
from torch import nn
from torch import optim
from torch.autograd import grad
from higher import innerloop_ctx
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

import notebook_setup
import ppo, utils
from ppo import DEVICE
from systems import CartPoleEnv, plot_cartpole

### Standard MAML training

In [None]:
def train(agent: ppo.PPO, tasks, n=10, losses=None, seed=None, lr_meta=None, lr_inner=None):
    """Meta-train `agent.policy` on `tasks` using MAML-style higher-order gradients.

    For each of the `n` meta-iterations and for each task, the policy is
    adapted once inside a `higher` inner loop on freshly sampled experience
    and then evaluated on a second batch of experience from the same task.
    After all tasks have been processed, the agent's own optimizer takes one
    step and the cosine learning-rate schedule advances.

    Args:
        agent: PPO agent whose `policy` and `optimizer` are trained in place.
        tasks: Environments to meta-train on. The collection is iterated once
            per meta-iteration, so pass a re-iterable container (e.g. a list).
        n: Number of meta-iterations; also used as the cosine-annealing period.
        losses: Optional indexable of lists. The post-adaptation loss of task
            `i` is appended to `losses[i]`.
        seed: Optional seed forwarded to `torch.manual_seed`.
        lr_meta: Optional learning rate temporarily applied to every parameter
            group of `agent.optimizer`. The previous rates are restored when
            the function exits, including on error or interruption.
        lr_inner: Optional learning rate override for the inner-loop optimizer.

    Note:
        Reads the module-level `k` (timesteps sampled per rollout).
        When `lr_meta` is None, the learning rate left behind by the scheduler
        is not restored.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model, opt = agent.policy, agent.optimizer

    # Temporarily swap in the meta learning rate, remembering the originals.
    og_lrs = []
    if lr_meta is not None:
        for pgroup in opt.param_groups:
            og_lrs.append(pgroup['lr'])
            pgroup['lr'] = lr_meta

    # `higher` expects each override as a list (one entry shared by all groups).
    inner_override = None if lr_inner is None else dict(lr=[lr_inner])

    try:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=n)
        memory = ppo.Memory()
        for _ in trange(n, leave=False, desc='MAML Initialization'):
            # A fresh inner optimizer per meta-iteration, configured from the
            # defaults of the agent's optimizer.
            opti = optim.Adam(model.parameters(), **opt.defaults)
            for i, env in enumerate(tasks):
                # copy_initial_weights=False keeps the module's own weights as
                # the starting point of the unrolled graph, so gradients of the
                # adapted model can reach `model`'s parameters.
                with innerloop_ctx(model, opti, copy_initial_weights=False,
                                   override=inner_override) as (fmodel, diffopt):

                    # Inner step: adapt the functional copy on task experience.
                    agent.experience(memory, timesteps=k, env=env, policy=fmodel)
                    agent.update(fmodel, memory, epochs=1, optimizer=diffopt, higher_optim=True)
                    memory.clear()

                    # Outer objective: loss of the adapted policy on new samples.
                    # NOTE(review): with optimizer=None, `agent.update` presumably
                    # only back-propagates without stepping — confirm in ppo.PPO.update.
                    agent.experience(memory, timesteps=k, env=env, policy=fmodel)
                    loss = agent.update(fmodel, memory, optimizer=None)
                    memory.clear()
                    if losses is not None:
                        losses[i].append(loss.item())

            # One meta-update per pass over all tasks.
            opt.step()
            scheduler.step()
            opt.zero_grad()
    finally:
        # Always undo the temporary override so a failed or interrupted run
        # does not leave the agent's optimizer with the meta learning rate.
        if lr_meta is not None:
            for pgroup, lr in zip(opt.param_groups, og_lrs):
                pgroup['lr'] = lr

def benchmark_train(agent, tasks, n=10, losses=None, seed=None):
    """Pre-train `agent.policy` on `tasks` without any inner-loop adaptation.

    In each of the `n` iterations, experience is sampled from every task in
    turn and a single-epoch update is run on the shared policy with the
    agent's own optimizer. After the pass over all tasks the optimizer steps,
    the cosine learning-rate schedule advances, and gradients are cleared.

    Args:
        agent: PPO agent whose `policy` and `optimizer` are trained in place.
        tasks: Environments to train on; iterated once per iteration.
        n: Number of iterations; also used as the cosine-annealing period.
        losses: Optional indexable of lists. The loss of task `i` is appended
            to `losses[i]`.
        seed: Optional seed forwarded to `torch.manual_seed`.

    Note:
        Reads the module-level `k` (timesteps sampled per rollout).
    """
    if seed is not None:
        torch.manual_seed(seed)

    policy, optimizer = agent.policy, agent.optimizer
    lr_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n)
    buffer = ppo.Memory()

    for _ in trange(n, leave=False, desc='Bench. Initialization'):
        for task_idx, env in enumerate(tasks):
            agent.experience(buffer, timesteps=k, env=env, policy=policy)
            task_loss = agent.update(policy, buffer, epochs=1,
                                     optimizer=optimizer, higher_optim=False)
            buffer.clear()
            if losses is None:
                continue
            losses[task_idx].append(task_loss.item())

        optimizer.step()
        lr_schedule.step()
        optimizer.zero_grad()

In [None]:
def test(agent, n, env, losses=None, rewards=None, callback=None, track_higher_grads=False, seed=None):
    """Fine-tune `agent.policy` on a single environment and record progress.

    Runs `n` iterations; each one samples `k` timesteps from `env` with the
    functional policy and runs `n_adapt` update epochs through the
    differentiable optimizer. `k` and `n_adapt` are module-level globals.

    Args:
        agent: PPO agent providing `policy`, `optimizer`, `experience` and `update`.
        n: Number of fine-tuning iterations.
        env: Environment to fine-tune on.
        losses: Optional list; the loss of each iteration is appended.
        rewards: Optional list; the NaN-ignoring mean of the episodic rewards
            of each iteration is appended.
        callback: Optional callable invoked once per iteration with this
            function's `locals()` dict, so local variable names here are part
            of the callback contract.
        track_higher_grads: Forwarded to `innerloop_ctx`.
        seed: Optional seed forwarded to `torch.manual_seed`.

    NOTE(review): updates are applied to the functional copy `fmodel`;
    `agent.policy` itself does not appear to be modified — confirm against
    the `higher` documentation.
    """
    if seed is not None: torch.manual_seed(seed)
    model = agent.policy
    memory = ppo.Memory()
    # Fresh Adam optimizer configured from the defaults of the agent's optimizer.
    opti = optim.Adam(model.parameters(), **agent.optimizer.defaults)
    with innerloop_ctx(model, opti, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
        for e in trange(n, leave=False, desc='Testing'):

            # Sample fresh experience with the current adapted policy, then adapt.
            agent.experience(memory, timesteps=k, env=env, policy=fmodel)
            l = agent.update(fmodel, memory, epochs=n_adapt, optimizer=diffopt, higher_optim=True)

            if losses is not None or rewards is not None:
                # Computed before `memory.clear()` so this iteration's rollout is still available.
                episodic_rewards = utils.cache_to_episodic_rewards([memory.rewards], [memory.is_terminals])
                if losses is not None: losses.append(l.item())
                if rewards is not None: rewards.append(np.nanmean(episodic_rewards))

            if callback is not None:
                # The callback sees every local by name (e.g. 'agent', 'env', 'fmodel', 'e').
                callback(locals())

            memory.clear()

# Experiments

In [None]:
# Task identifiers; they are passed to `get_task` as environment seeds.
training_tasks = list(range(5))
evaluation_task = 6

# Number of timesteps sampled per rollout (`timesteps=k`).
k = 500

# Learning rates.
alpha = 0.02        # learning rate each agent is constructed with
alpha_meta = 0.02   # outer-loop (meta) learning rate passed to `train`
alpha_inner = 0.1   # inner-loop learning rate passed to `train`

# Loop lengths.
n_adapt = 5         # update epochs per fine-tuning iteration in `test`
n_train = 10        # pre-training iterations
n_test = 30         # fine-tuning iterations

In [None]:
def get_task(seed=None, randomize=True):
    """Build a `CartPoleEnv` from `seed`.

    Returns the result of `env.randomize()` when `randomize` is truthy,
    otherwise the freshly constructed environment itself.
    """
    cartpole = CartPoleEnv(seed)
    if not randomize:
        return cartpole
    return cartpole.randomize()

def make_model(seed=None):
    """Create a PPO agent with a discrete actor-critic policy.

    The agent is built without an environment (`env=None`), with 4 state
    dimensions, 2 actions, 32 latent units, and the module-level learning
    rate `alpha`. `seed` is forwarded to the `ppo.PPO` constructor.
    """
    settings = dict(
        env=None,
        policy=ppo.ActorCriticDiscrete,
        state_dim=4,
        action_dim=2,
        n_latent_var=32,
        lr=alpha,
        seed=seed,
    )
    return ppo.PPO(**settings)

In [None]:
# Train a single PPO agent on the non-randomized CartPole task.
a = make_model()
e = get_task(randomize=False)
# `make_model` builds the agent with env=None, so the task is attached here.
a.env = e
# NOTE(review): the arguments are presumably the total timestep budget and the
# update interval, and `r` the per-episode rewards — confirm against ppo.PPO.learn.
r = a.learn(50000, 500)

In [None]:
# Scatter the values returned by `a.learn` against their index.
plt.scatter(np.arange(len(r)), r)
# Cap only the upper y-limit at 1000; the lower limit is left to matplotlib.
plt.ylim(top=1000)

### Standard MAML testing

In [None]:
# Two agents built from the same seed; copying the state dict makes the
# shared starting point explicit.
agent, bench = make_model(0), make_model(0)
bench.policy.load_state_dict(agent.policy.state_dict())

# Per-iteration mean episodic rewards collected during fine-tuning.
r_test, r_bench_test = [], []

# Pre-training: MAML for `agent`, plain per-task updates for `bench`.
# Each run gets its own freshly built list of training environments.
train(
    agent,
    [get_task(task_seed) for task_seed in training_tasks],
    n=n_train,
    seed=0,
    lr_meta=alpha_meta,
    lr_inner=alpha_inner,
)
benchmark_train(
    bench,
    [get_task(task_seed) for task_seed in training_tasks],
    n=n_train,
    seed=0,
)

# Fine-tune both agents on a fresh instance of the held-out task.
for learner, curve in ((agent, r_test), (bench, r_bench_test)):
    env = get_task(evaluation_task)
    test(learner, n=n_test, env=env, rewards=curve, seed=0)

In [None]:
# Reward performance: mean episodic reward per fine-tuning iteration,
# as collected by `test` through its `rewards` argument.
plt.plot(r_test, label='MAML')
plt.plot(r_bench_test, label='Bench')
plt.grid(True)
plt.legend()

In [None]:
# Plotting actual outputs
plt.figure(figsize=(12, 6))

# Left panel: the pre-trained agent on the held-out evaluation task.
plt.subplot(1,2,1)
plot_cartpole(get_task(evaluation_task), agent, legend=False)

# Right panel: fine-tune on the evaluation task, plotting from the callback.
# `test` requires both the iteration count `n` and the environment `env`, and
# hands the callback its `locals()` dict, from which 'env' and 'agent' are read.
# The callback runs once per fine-tuning iteration.
# NOTE(review): `test` adapts a functional copy of the policy, so plotting
# `agent` may not reflect the adapted weights — confirm what `plot_cartpole`
# should receive.
plt.subplot(1,2,2)
test(agent, n=n_test, env=get_task(evaluation_task), losses=None,
     callback=lambda args: plot_cartpole(args.get('env'), args.get('agent')))
plt.tight_layout()