In [1]:
import os
import glob
import math
import pickle
import random
import datetime
from collections import defaultdict
import copy
from enum import Enum
import numpy as np 
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import Module
import tqdm
import gym
from gym import Env
from gym.spaces import Box, Discrete
from gym.utils.env_checker import check_env
from typing import (
    Type,
    OrderedDict,
    List,
    Tuple,
    Callable,
)
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import stable_baselines3
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from torch.utils.tensorboard import SummaryWriter


In [2]:
print(stable_baselines3.__version__)
print(gym.__version__) # should be 0.21.0 to be compatible with stable_baselines3 1.8

1.8.0
0.21.0


## config

In [9]:
config = {
    'SEED' :123,
    'DEVICE' : 'cpu',
    'EPOCHS' : 2,
    'TIMESTEPS' : 5000,
    'N_X' : 100,
    'N_TASKS' : 5,
    'POOL_N_LAYERS' : 100,
    'N_NODES_PER_LAYER' : 32,
    'POOL_LAYER_TYPE' : torch.nn.Linear,
    'ACTION_SPACE_SHAPE' : (3,),
    'EPSILON' : 0.1,
    'BATCH_SIZE' : 1,
    'LEARNING_RATE' : 0.05,
    'ACTION_CACHE_SIZE' : 5,
    'NUM_WORKERS' : 0,
    'LOSS_FN' : torch.nn.MSELoss(),
    'SB3_MODEL' : PPO,
    'SB3_POLICY' : 'MlpPolicy',
    }
config

{'SEED': 123,
 'DEVICE': 'cpu',
 'EPOCHS': 2,
 'TIMESTEPS': 5000,
 'N_X': 100,
 'N_TASKS': 5,
 'POOL_N_LAYERS': 100,
 'N_NODES_PER_LAYER': 32,
 'POOL_LAYER_TYPE': torch.nn.modules.linear.Linear,
 'ACTION_SPACE_SHAPE': (3,),
 'EPSILON': 0.1,
 'BATCH_SIZE': 1,
 'LEARNING_RATE': 0.05,
 'ACTION_CACHE_SIZE': 5,
 'NUM_WORKERS': 0,
 'LOSS_FN': MSELoss(),
 'SB3_MODEL': stable_baselines3.ppo.ppo.PPO,
 'SB3_POLICY': 'MlpPolicy'}

## Reinforcement Meta-Learning (REML) / "Learning to Learn by Gradient Descent as a Markov Deicision Process"

- layer pool
- inner network -- composed of layers from layer pool
- outer network (meta learner) -- responsible for parameters and hyperparameters of inner network

In [10]:
class Layer:
    def __init__(self, 
                type: Type[torch.nn.Module]=config['POOL_LAYER_TYPE']):
        self.type = type
        self.params = type
        self.used = False
        self.times_used = 0

class LayerPool:
    # pool of uniform Layer objects each with the same type and shape
    def __init__(self, 
                size: int=config['POOL_N_LAYERS'], 
                layer_type: Type[torch.nn.Module]=config['POOL_LAYER_TYPE'],
                num_nodes_per_layer: int=config['N_NODES_PER_LAYER']):
        self.size = size
        self.layer_type = layer_type
        self.num_nodes_per_layer = num_nodes_per_layer

        # each layer that is used gets updated (i.e., their parameters change and the copy in 
        # this layer pool is updated), except for the first and last layers which are unique
        # for each task
        self.layers = {
            i : Layer(
                type=self.layer_type(in_features=num_nodes_per_layer, out_features=num_nodes_per_layer)
                )
            for i in range(size)}

        [torch.nn.init.xavier_uniform_(layer.params.weight) for layer in self.layers.values()]
        
    def __str__(self):
        return f"LayerPool(size={self.size}, layer_type={config['POOL_LAYER_TYPE']}, num_nodes_per_layer={config['N_NODES_PER_LAYER']}"



In [11]:
class InnerNetworkAction(Enum):
    UNAVAILABLE = 0
    ADD = 1
    DELETE = 2
    TRAIN = 3

In [12]:
class InnerNetworkTask(Dataset):
    def __init__(self, data, targets, info):
        self.data = data 
        self.targets = targets
        self.info = info

    def __len__(self):
        assert len(self.data) == config['N_X'], '[ERROR] Length should be the same as N_X.'
        return len(self.data)

    def __getitem__(self, index):
        assert self.data[index].dtype == torch.float32, f'[ERROR] Expected type torch.float32, got type: {self.data[index].dtype}'
        assert self.targets[index].dtype == torch.float32, f'[ERROR] Expected type torch.float32, got type: {self.targets[index].dtype}'
        sample = {
            'x' : self.data[index],
            'y' : self.targets[index],
            'info' : self.info[index]
        }
        return sample
    
    def __str__(self):
        return f'[INFO] InnerNetworkTask(data={self.data, self.targets}, info={self.info})'

In [13]:
class InnerNetwork(gym.Env, Module):
    def __init__(self, 
                task: InnerNetworkTask,
                layer_pool: LayerPool,
                learning_rate: float=config['LEARNING_RATE'],
                batch_size: int=config['BATCH_SIZE'],
                epsilon: float=config['EPSILON'],
                action_cache_size: float=config['ACTION_CACHE_SIZE'],
                num_workers: int=config['NUM_WORKERS'],
                shuffle: bool=True,
                log_dir: str='runs',
                ):
        super(InnerNetwork, self).__init__()
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.layer_pool = layer_pool
        self.task = task
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.data_loader = DataLoader(task, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        self.data_iter = iter(self.data_loader)
        self.prev = defaultdict(lambda: None)
        self.curr = defaultdict(lambda: None)
        self.initial_layer = torch.nn.Linear(1, self.layer_pool.num_nodes_per_layer) # TODO is to have param with input_shape
        self.final_layer = torch.nn.Linear(self.layer_pool.num_nodes_per_layer, 1) # TODO is to have param with output_shape
        torch.nn.init.xavier_uniform_(self.final_layer.weight)
        torch.nn.init.xavier_uniform_(self.initial_layer.weight)
        self.layers = torch.nn.ModuleList([self.initial_layer, self.final_layer])
        self.pool_indices = [] 
        self.loss_fn = torch.nn.MSELoss()
        self.opt = torch.optim.Adam(self.layers.parameters(), lr=learning_rate)
        self.action_cache_size = action_cache_size
        self.actions_taken = [InnerNetworkAction.UNAVAILABLE] * config['ACTION_CACHE_SIZE']

        # logging variables
        self.writer = SummaryWriter(log_dir=log_dir)
        self.timestep = 0
        self.reset_num = 0

        # state and action spaces
        self.state = self.reset()
        state_shape = self.build_state().shape
        self.observation_space = Box(low=float('-inf'), high=float('inf'), shape=state_shape) # TODO is to normalize
        self.action_space = Discrete(self.layer_pool.size * 2 + 1)

        # print(f'[INFO] Number of available actions={self.action_space.n}')
        # print(f'[INFO] Range of action indices={}')


    def step(self, action: np.ndarray) -> Tuple[torch.Tensor, float, bool, dict]: 
        assert action.shape == (), f'[ERROR] Expected action shape () for scalar {self.action_space.n}, got: {action.shape}'
        assert action.dtype == np.int64, f'[ERROR] Expected np.int64 dtype, got: {action.dtype}'
        self.prev = self.curr
        self.curr = defaultdict(lambda: None)
        action_index = self.epsilon_greedy(action)
        self.update_inner_network(action_index)
        self.forward_inner_network()

        # log
        task_str = str(self.curr['info']['i'].item())
        self.writer.add_scalar(f'loss/timestep_task_{task_str}', self.curr['loss'], global_step=self.timestep) 
        self.writer.add_scalar(f'num_layers/timestep_task_{task_str}', len(self.layers), global_step=self.timestep) 
        if (len(self.pool_indices)!=0):
            self.writer.add_histogram(f'pool_indices/timestep_task_{task_str}', torch.tensor(self.pool_indices).long(), global_step=self.timestep) 
        self.writer.add_histogram(f'action_types/timestep_task{task_str}', torch.tensor([e.value for e in self.actions_taken]).long(), global_step=self.timestep) 

        self.timestep += 1
        s_prime = self.build_state()
        reward = self.reward()
        if (self.curr['action_taken']==InnerNetworkAction.TRAIN):
            print(f'[INFO] TRAINED')
        # print(f'[INFO] Timestep={self.timestep}')
        return (
            s_prime,
            reward, 
            False, 
            {}
        )
    
    def epsilon_greedy(self, action: np.int64) -> int:
        if random.random() <= self.epsilon: action_index = random.randint(0, self.action_space.n - 1)
        else: action_index = action
        # print(f'[INFO] Action index {action_index} (range: 0 - {self.action_space.n - 1})')
        return action_index
    
    def update_inner_network(self, action_index: int) -> None:
        if (action_index == 0): 
            action_type = InnerNetworkAction.TRAIN
        elif (action_index > 0 and action_index < self.layer_pool.size): 
            action_type = InnerNetworkAction.ADD
        else:
            action_type = InnerNetworkAction.DELETE

        if (action_type == InnerNetworkAction.TRAIN):
            # print('[INFO] "Train" action taken by inner network.')
            self.actions_taken.append(InnerNetworkAction.TRAIN)
            self.curr['action_type'] = InnerNetworkAction.TRAIN

        if (action_type == InnerNetworkAction.ADD):
            # print('[INFO] "Add" action taken by inner network.')
            self.pool_indices.append(action_index)
            next_layer = self.layer_pool.layers[action_index].params
            final_layer = self.layers.pop(-1)
            self.layers.append(next_layer)  
            self.layers.append(final_layer) 
            self.actions_taken.append(InnerNetworkAction.ADD)
            self.curr['action_type'] = InnerNetworkAction.ADD

        if (action_type == InnerNetworkAction.DELETE):
            # print('[INFO] "Delete" action taken by inner network.')
            if (action_index not in self.pool_indices):
                self.actions_taken.append(InnerNetworkAction.UNAVAILABLE)
                return
            adjusted_pool_index = action_index = self.layer_pool.size
            self.pool_indices.remove(adjusted_pool_index)
            layer_to_delete = self.layer_pool.layers[adjusted_pool_index].params
            network_index = self.layers.index(layer_to_delete)
            assert layer_to_delete == self.layers[network_index], '[ERROR] Wrong layer would be deleted from inner network params.'
            self.layers.pop(network_index)
            self.actions_taken.append(InnerNetworkAction.DELETE)
            self.curr['action_type'] = InnerNetworkAction.DELETE

    def next_batch(self, throw_exception=False) -> dict:
        if (throw_exception):
            return next(self.data_iter)
        else: 
            try:
                batch = next(self.data_iter)
                return batch
            except StopIteration:
                self.data_loader = DataLoader(self.task, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers)
                self.data_iter = iter(self.data_loader)
                return next(self.data_iter)

    def forward_inner_network(self) -> None: 
        batch = self.next_batch()
        self.curr['x'] = batch['x']
        self.curr['y'] = batch['y'] 
        self.curr['info'] = batch['info']
        # set model to train or eval
        if self.curr['action_type'] == InnerNetworkAction.TRAIN:
            self.train() # needs to be called before forward pass for gradient information to be saved
            self.opt = torch.optim.Adam(self.layers.parameters(), lr=self.learning_rate) 
            self.opt.zero_grad()
        else:
            self.eval()
        # forward pass
        x = self.curr['x']
        for i in range(len(self.layers) - 1): x = torch.nn.functional.relu(self.layers[i](x))
        self.curr['latent_space'] = x
        self.curr['y_hat'] = self.layers[-1](x) 
        self.curr['loss'] = self.loss_fn(self.curr['y'], self.curr['y_hat'])
        assert self.curr['latent_space'].dtype == torch.float32
        assert self.curr['y_hat'].dtype == torch.float32
        # update params if 'train'
        if (self.curr['action_type'] == InnerNetworkAction.TRAIN):
            print(f'[INFO] Trained network.')
            loss = self.curr['loss']
            loss.backward()
            self.opt.step()

    def build_state(self) -> np.ndarray:
        num_add_actions = torch.tensor(len(list(filter(lambda e : e == InnerNetworkAction.ADD, self.actions_taken)))).unsqueeze(0)
        num_delete_actions = torch.tensor(len(list(filter(lambda e : e == InnerNetworkAction.DELETE, self.actions_taken)))).unsqueeze(0)
        num_train_actions = torch.tensor(len(list(filter(lambda e : e == InnerNetworkAction.TRAIN, self.actions_taken)))).unsqueeze(0)
        num_layers = torch.tensor(len(self.layers)).unsqueeze(0)
        h = torch.tensor([action_enum.value for action_enum in self.actions_taken[-self.action_cache_size:]])
        task_info = torch.tensor([float(value) for value in self.curr['info'].values()])
        return torch.concat((
            task_info,
            self.curr['x'],
            self.curr['latent_space'],
            self.curr['y'],
            self.curr['y_hat'],
            num_add_actions,
            num_delete_actions,
            num_train_actions,
            num_layers,
            h
        ), dim=0).detach().numpy()
    
    def reward(self) -> torch.Tensor:
        prev_loss = self.prev['loss']
        curr_loss = self.curr['loss']
        # assert prev_loss != curr_loss, '[ERROR] Loss values from previous and current run are equal.'
        delta_loss = prev_loss - curr_loss
        curr_action = self.curr['action_type']
        if (curr_action == InnerNetworkAction.ADD):
            reward = delta_loss / math.sqrt(len(self.layers))
        if (curr_action == InnerNetworkAction.TRAIN or InnerNetworkAction.DELETE):
            reward = delta_loss
        if (curr_action == InnerNetworkAction.UNAVAILABLE):
            reward = -1000
        return reward

    def reset(self) -> np.ndarray:
        self.forward_inner_network()
        print(f'[INFO] Reset called.')
        self.reset_num += 1
        return self.build_state()
    
    def close(self):
        print(f'[INFO] Closed writer')
        self.writer.close()

### baby example with just 2 epochs
However, 2 epoch is still going through all 20 tasks, generating 20 different networks, using the same 1 meta policy. Also, on each task we loop through the data 1000 times (timesteps / len(X) == 10000 / 100).
<br>
The benefit of more epochs is we get the meta policy to gnerate 20 different networks again for the 20 tasks, having hopefully learned something from the first run.

### sinusoidal curves regression as in 2018 MAML paper

In [15]:
# create tasks
X = np.linspace(lower_bound, upper_bound, config['N_X'])
lower_bound = torch.tensor(-5).float()
upper_bound = torch.tensor(5).float()
amplitude_range = torch.tensor([0.1, 5.0]).float()
phase_range = torch.tensor([0, math.pi]).float()
amps = torch.from_numpy(np.linspace(amplitude_range[0], amplitude_range[1], config['N_TASKS'])).float()
phases = torch.from_numpy(np.linspace(phase_range[0], phase_range[1], config['N_TASKS'])).float()

# (20, 100)
tasks_data = torch.tensor([ 
        X
        for _ in range(config['N_TASKS'])
        ]).float()
tasks_targets = torch.tensor([
        [((a * np.sin(x)) + p).float()
        for x in X] 
        for a, p in zip(amps, phases)
        ]).float()
tasks_info = [
        [{'i' : i, 
         'amp' : a, 
         'phase_shift' : p, 
         'lower_bound' : lower_bound, 
         'upper_bound' : upper_bound, 
         'amplitude_range_lower_bound' : amplitude_range[0], 
         'amplitude_range_upper_bound' : amplitude_range[1], 
         'phase_range_lower_bound' : phase_range[0],
         'phase_range_lower_bound' : phase_range[1]}
         for _ in X]
        for i, (a, p) in enumerate(zip(amps, phases))
]

  tasks_data = torch.tensor([


In [16]:
print(tasks_data.shape)
print(tasks_data.dtype)
print(tasks_targets.shape)
print(tasks_targets.dtype)
print(len(tasks_info))
print(len(tasks_info[0]))

torch.Size([5, 100])
torch.float32
torch.Size([5, 100])
torch.float32
5
100


In [20]:
# create tasks
tasks = [InnerNetworkTask(data=tasks_data[i], targets=tasks_targets[i], info=tasks_info[i]) for i in range(config['N_TASKS'])]
# create pool
pool = LayerPool()
# create REML
log_dir = f'./runs/ppo_{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
model = REML(layer_pool=pool, tasks=tasks, log_dir=log_dir)

In [21]:
config

{'SEED': 123,
 'DEVICE': 'cpu',
 'EPOCHS': 2,
 'TIMESTEPS': 5000,
 'N_X': 100,
 'N_TASKS': 5,
 'POOL_N_LAYERS': 100,
 'N_NODES_PER_LAYER': 32,
 'POOL_LAYER_TYPE': torch.nn.modules.linear.Linear,
 'ACTION_SPACE_SHAPE': (3,),
 'EPSILON': 0.1,
 'BATCH_SIZE': 1,
 'LEARNING_RATE': 0.05,
 'ACTION_CACHE_SIZE': 5,
 'NUM_WORKERS': 0,
 'LOSS_FN': MSELoss(),
 'SB3_MODEL': stable_baselines3.ppo.ppo.PPO,
 'SB3_POLICY': 'MlpPolicy'}

In [22]:
# train
model.train()

[INFO] Epoch 1/2
[INFO] Task num=1/5
[INFO] Reset called.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
[INFO] Reset called.
Logging to ./runs/ppo_2023-10-10_12-44-57\task_0_0


We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=100 and n_envs=1)


----------------------------
| time/              |     |
|    fps             | 189 |
|    iterations      | 1   |
|    time_elapsed    | 0   |
|    total_timesteps | 100 |
----------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 152        |
|    iterations           | 2          |
|    time_elapsed         | 1          |
|    total_timesteps      | 200        |
| train/                  |            |
|    approx_kl            | 0.01574725 |
|    clip_fraction        | 0.0852     |
|    clip_range           | 0.2        |
|    entropy_loss         | -5.3       |
|    explained_variance   | -7.37      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.11      |
|    n_updates            | 10         |
|    policy_gradient_loss | -0.0637    |
|    value_loss           | 0.0232     |
----------------------------------------
[INFO] Trained network.
-----------------------------------