In [None]:
#default_exp imaml
#export
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import higher
import wandb

import pytorch_lightning as pl
from itertools import repeat

from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from unsupervised_meta_learning.pl_dataloaders import OmniglotDataModule
from unsupervised_meta_learning.hessian_free import HessianFree
from unsupervised_meta_learning.nn_utils import get_accuracy
from unsupervised_meta_learning.maml import ConvolutionalNeuralNetwork

import unsupervised_meta_learning.hypergrad as hg

In [None]:
#export
class iMAML(pl.LightningModule):
    """Implicit-gradient MAML trained with manual optimisation.

    For every task in a meta-batch the module
      1. clones the meta-parameters and adapts the clone with ``self.T`` steps
         of ``hg.GradientDescent`` on the regularised support-set loss
         (``train_loss_f``);
      2. during training, asks ``hg.CG`` / ``hg.fixed_point`` to compute the
         hypergradient of the query-set loss (``val_loss_f``) with respect to
         the meta-parameters (presumably accumulated into their ``.grad``;
         verify against the ``hypergrad`` module);
      3. steps the meta-optimiser once per meta-batch.

    Args:
        model: network whose parameters are the meta-parameters.
        meta_lr: learning rate of the outer (Adam) optimiser.
        inner_lr: step size of the inner gradient-descent optimiser.
        inner_steps: stored on the instance; see NOTE in ``__init__``.
        cg_steps: stored on the instance; see NOTE in ``__init__``.
        reg_param: weight of the proximal term ``||params - hparams||^2``.
        hg_mode: ``'CG'`` or ``'fixed_point'``; selects the hypergradient
            routine used while training.
    """

    def __init__(self, model, meta_lr, inner_lr, inner_steps, cg_steps, reg_param, hg_mode='CG'):
        super().__init__()
        # The outer optimiser is stepped by hand inside `meta_learn`.
        self.automatic_optimization = False
        self.model = model
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        self.inner_steps = inner_steps
        self.cg_steps = cg_steps
        # Number of parameter tensors (not scalar weights).
        self.n_params = len(list(model.parameters()))
        self.reg_param = reg_param
        self.hg_mode = hg_mode
        # NOTE(review): the loops below use T / K, not inner_steps / cg_steps.
        # Left hard-coded so existing runs keep their behaviour; confirm which
        # pair is meant to drive the algorithm.
        self.T = 16  # inner adaptation steps per task
        self.K = 5   # iterations handed to the hypergradient routine

        # Functional view of `model`; always called with explicit `params=`.
        self.fmodel = higher.monkeypatch(model, device=self.device, copy_initial_weights=True)
        self.inner_opt_cls = hg.GradientDescent

    def bias_reg_f(self, bias, params):
        """Return the summed squared distance between `bias` and `params`."""
        return sum(
            [((b - p) ** 2).sum() for b, p in zip(bias, params)]
        )

    def train_loss_f(self, params, hparams):
        """Support-set loss: cross-entropy plus the proximal regulariser.

        Uses ``self.tr_x`` / ``self.tr_y``, which `meta_learn` sets per task.
        """
        o = self.fmodel(self.tr_x, params=params)
        # Fix: the regulariser is 0.5 * reg_param * ||params - hparams||^2.
        # The previous `+ .5 + reg_param +` only shifted the loss by a
        # constant, so reg_param had no effect on the gradients.
        return F.cross_entropy(o, self.tr_y) + 0.5 * self.reg_param * self.bias_reg_f(hparams, params)

    def val_loss_f(self, params, hparams):
        """Query-set loss for the current task.

        Side effects: stores the scalar loss in ``self.val_loss`` and the
        task accuracy in ``self.val_acc``. The loss is divided by
        ``self.batch_size`` (set in `meta_learn`) so that summing over the
        tasks of a meta-batch yields their mean.
        """
        o = self.fmodel(self.tst_x, params=params)
        val_loss = F.cross_entropy(o, self.tst_y) / self.batch_size
        self.val_loss = val_loss.item()
        pred = o.argmax(dim=1, keepdim=True)
        self.val_acc = pred.eq(self.tst_y.view_as(pred)).sum().item() / len(self.tst_y)
        return val_loss

    def get_inner_opt(self, train_loss, kwargs):
        """Build the inner optimiser for `train_loss` from a kwargs dict."""
        return self.inner_opt_cls(train_loss, **kwargs)

    def inner_loop(self, hparams, params, optim, n_steps, log_interval, create_graph=False):
        """Run `n_steps` inner updates and return every iterate.

        Returns a list of ``n_steps + 1`` parameter states; the last entry is
        the adapted parameters. `log_interval` is accepted for interface
        compatibility and is not used.
        """
        # `hparams` is reused on every step. Materialise it so a generator
        # such as `module.parameters()` is not exhausted after the first use.
        hparams = list(hparams)
        params_history = [optim.get_opt_params(params)]

        for t in range(n_steps):
            params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))
            self.log('loss', optim.curr_loss.item(), on_step=True, prog_bar=True, logger=True)

        return params_history

    def configure_optimizers(self):
        """Adam over the meta-parameters, using the configured `meta_lr`."""
        # Fix: previously hard-coded lr=1e-3 and ignored `meta_lr`.
        outer_opt = torch.optim.Adam(params=self.model.parameters(), lr=self.meta_lr)
        return outer_opt

    @torch.enable_grad()
    def meta_learn(self, batch, batch_idx, train=True):
        """Process one meta-batch and return ``(loss, accuracy)`` tensors.

        Args:
            batch: dict with ``"train"`` and ``"test"`` entries, each an
                ``(inputs, targets)`` pair whose first axis indexes tasks.
            batch_idx: unused; kept for the Lightning step signature.
            train: when True (default) compute hypergradients and step the
                meta-optimiser. When False only adapt and score, leaving the
                meta-parameters untouched.

        Raises:
            ValueError: if `train` is True and ``hg_mode`` is not supported.

        Gradients are force-enabled because the inner loop needs autograd
        even when Lightning runs validation/test under ``no_grad``.
        """
        if train and self.hg_mode not in ('CG', 'fixed_point'):
            raise ValueError(
                f"Unsupported hg_mode {self.hg_mode!r}; expected 'CG' or 'fixed_point'"
            )

        tr_xs, tr_ys = batch["train"][0].to(self.device), batch["train"][1].to(self.device)
        tst_xs, tst_ys = batch["test"][0].to(self.device), batch["test"][1].to(self.device)

        self.batch_size = tr_xs.shape[0]
        val_loss, val_acc = torch.tensor(0., device=self.device), torch.tensor(0., device=self.device)

        inner_opt_kwargs = {'step_size': self.inner_lr}

        # One concrete list, shared by the inner loop and hypergradient call.
        hparams = list(self.model.parameters())

        meta_optimizer = None
        if train:
            meta_optimizer = self.optimizers()
            # Unwrap Lightning's optimiser wrapper to step the raw optimiser.
            meta_optimizer = meta_optimizer.optimizer
            meta_optimizer.zero_grad()

        for t_idx, (tr_x, tr_y, tst_x, tst_y) in enumerate(zip(tr_xs, tr_ys, tst_xs, tst_ys)):
            # The loss closures read the current task from these attributes.
            self.tr_x = tr_x
            self.tr_y = tr_y

            self.tst_x = tst_x
            self.tst_y = tst_y
            inner_opt = self.get_inner_opt(self.train_loss_f, inner_opt_kwargs)

            # Adapt a detached copy so the meta-parameters stay fixed here.
            params = [p.detach().clone().requires_grad_(True) for p in hparams]
            last_param = self.inner_loop(hparams, params, inner_opt, self.T, log_interval=None)[-1]

            if not train:
                # Evaluation: score the adapted parameters only. No
                # hypergradient and no optimiser step, so validation/test
                # data never update the model.
                with torch.no_grad():
                    self.val_loss_f(last_param, hparams)
            elif self.hg_mode == 'CG':
                # This is the approximation used in the paper CG stands for conjugate gradient
                cg_fp_map = hg.GradientDescent(loss_f=self.train_loss_f, step_size=1.)
                hg.CG(last_param, hparams, K=self.K, fp_map=cg_fp_map, outer_loss=self.val_loss_f)
            else:  # 'fixed_point'
                # Fix: `seld` typo raised NameError on this path.
                hg.fixed_point(last_param, hparams, K=self.K, fp_map=inner_opt,
                               outer_loss=self.val_loss_f)

            # `val_loss_f` already divided the loss by batch_size.
            val_loss += self.val_loss
            val_acc += self.val_acc / self.batch_size

        if train:
            meta_optimizer.step()
        return val_loss, val_acc

    def training_step(self, batch, batch_idx):
        """Run one meta-update and log the query-set loss and accuracy."""
        train_loss, train_acc = self.meta_learn(batch, batch_idx)
        self.log_dict({
            'tr_accuracy': train_acc.item(),
            'tr_loss': train_loss.item()
        }, prog_bar=True, logger=True)
        return {'tr_loss': train_loss.item(), 'tr_acc': train_acc.item()}

    def validation_step(self, batch, batch_idx):
        """Adapt and score on a validation meta-batch without a meta-update."""
        val_loss, val_acc = self.meta_learn(batch, batch_idx, train=False)

        self.log_dict({
            'val_loss': val_loss.item(),
            'val_accuracy': val_acc.item()
        })
        return val_loss.item()

    def test_step(self, batch, batch_idx):
        """Adapt and score on a test meta-batch without a meta-update."""
        test_loss, test_acc = self.meta_learn(batch, batch_idx, train=False)
        self.log_dict({
            'test_loss': test_loss.item(),
            'test_accuracy': test_acc.item()
        })
        return test_loss.item()

In [None]:
# Data module feeding the meta-learner. Keyword semantics are defined by
# OmniglotDataModule; arguments are grouped here by apparent purpose.
dm = OmniglotDataModule(
    "data",
    # task / episode arguments
    ways=5,
    shots=1,
    test_shots=16,
    meta_train=True,
    shuffle_ds=True,
    # loading arguments
    batch_size=16,
    shuffle=True,
    num_workers=8,
    download=True,
)

In [None]:
# Meta-learner wrapping the convolutional backbone.
model = iMAML(
    model=ConvolutionalNeuralNetwork(1, 5, hidden_size=64),
    meta_lr=1e-3,
    inner_lr=1e-2,
    inner_steps=1,
    cg_steps=5,
    reg_param=2,
)

In [None]:
# Run metadata sent to Weights & Biases. These values are typed by hand and
# are not derived from `dm`, `model` or the Trainer arguments below.
# NOTE(review): 'steps': 100 does not match max_steps=28 - confirm which is intended.
logger = WandbLogger(
    project='iMAML',
    config={
        'batch_size': 16,
        'steps': 100,
        'dataset': "omniglot",
        'T': 16,
        'K': 5,
        'val/test': 'enabled'
    }
)
profiler = pl.profiler.PyTorchProfiler()
trainer = pl.Trainer(
        profiler=profiler,
        max_steps=28,
        val_check_interval=25,
        limit_train_batches=26,
        limit_val_batches=2,
        limit_test_batches=2,
        fast_dev_run=False,
        # Use one GPU when present, otherwise fall back to CPU so the
        # notebook still runs on machines without CUDA.
        gpus=1 if torch.cuda.is_available() else 0,
        logger=logger,
        log_every_n_steps=1,
        flush_logs_every_n_steps=1
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Fit with every warning silenced; the previous warning filters are
# restored when the context manager exits.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mp0int[0m (use `wandb login --relogin` to force relogin)



  | Name   | Type                                 | Params
----------------------------------------------------------------
0 | model  | ConvolutionalNeuralNetwork           | 112 K 
1 | fmodel | FunctionalConvolutionalNeuralNetwork | 112 K 
----------------------------------------------------------------
224 K     Trainable params
0         Non-trainable params
224 K     Total params
0.898     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

AssertionError: 

In [None]:
# Run the test loop on the trainer configured above (no explicit model or
# datamodule is passed here).
trainer.test()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]



In [None]:
# Close the active Weights & Biases run.
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,2.53536
tr_accuracy,0.83984
tr_loss,0.63014
epoch,1.0
trainer/global_step,27.0
_runtime,61.0
_timestamp,1623160931.0
_step,30.0
loss_step/epoch_0,2.53691
loss_epoch,2.53589


0,1
loss,▅▅▆▅▆▇▃▅▆▅█▄▂▃█▃▃▃▂▇▃▃▄▂▃▁▁▂
tr_accuracy,▁▂▄▁▁▄▅▅▅▇▇▃▆▇▆▇▆▅▆▆▇█▇▇▇▇██
tr_loss,██▇▇▇▆▅▅▄▄▃▆▄▃▃▃▃▄▃▂▃▂▂▂▂▂▁▁
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██
trainer/global_step,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▁▁▇▇██
_runtime,▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇██
_timestamp,▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇██
_step,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
loss_step/epoch_0,▁█
loss_epoch,▁


In [None]:
# Export the notebooks' `#export` cells to the library modules.
from nbdev.export import notebook2script

notebook2script()

Converted 01_nn_utils.ipynb.
Converted 01b_data_loaders_pl.ipynb.
Converted 01c_grad_utils.ipynb.
Converted 01d_hessian_free.ipynb.
Converted 02_maml_pl.ipynb.
Converted 02b_iMAML.ipynb.
Converted 03_protonet_pl.ipynb.
Converted 04_cactus.ipynb.
Converted index.ipynb.
