In [None]:
#default_exp maml
#export
import logging
import warnings 

import higher
import kornia as K
import wandb
import pytorch_lightning as pl
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import numpy as np
import matplotlib.pyplot as plt

from copy import deepcopy
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.metrics.functional import accuracy
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from unsupervised_meta_learning.pl_dataloaders import OmniglotDataModule

In [None]:
# Notebook-only setting (this cell is not tagged `#export`): show matplotlib
# figures inline in the cell output.
%matplotlib inline

In [None]:
#export
# Module-level logger, named after the module this cell is exported into.
logger = logging.getLogger(__name__)

In [None]:
#export
class ConvolutionalNeuralNetwork(nn.Module):
    """Four stacked conv blocks followed by a single linear classifier.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input images.
    out_features : int
        Number of logits produced per example.
    hidden_size : int
        Channel width of every conv block, and the input width of the linear
        head. Because the head takes exactly `hidden_size` features, the
        feature map must be 1x1 spatially after the four 2x poolings
        (e.g. 28x28 inputs).
    """

    def __init__(self, in_channels, out_features, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        # Input width of each of the four blocks; every block outputs
        # `hidden_size` channels.
        block_inputs = (in_channels, hidden_size, hidden_size, hidden_size)
        self.features = nn.Sequential(
            *[self.conv3x3(width, hidden_size) for width in block_inputs]
        )

        self.classifier = nn.Linear(hidden_size, out_features)

    def conv3x3(self, in_channels, out_channels, **kwargs):
        """Build one block: 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x max-pool.

        Extra keyword arguments are forwarded to `nn.Conv2d`.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
            # No running statistics are tracked, so normalisation always uses
            # the statistics of the current batch.
            nn.BatchNorm2d(out_channels, momentum=1.0, track_running_stats=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    def forward(self, inputs, params=None):
        """Return the logits for `inputs`; `params` is accepted but not used."""
        embedded = self.features(inputs)
        # Flatten everything except the batch dimension.
        flat = embedded.view(embedded.size(0), -1)
        return self.classifier(flat)

In [None]:
#export
def get_accuracy(logits, targets):
    """Fraction of query points whose highest-scoring class equals the target.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Model outputs on the query points, with shape
        `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        Target classes of the query points, with shape `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Scalar tensor holding the mean accuracy on the query points.
    """
    # Index of the largest logit along the class dimension.
    predicted = logits.max(dim=-1).indices
    hits = (predicted == targets).float()
    return hits.mean()

In [None]:
#export
class MAML(pl.LightningModule):
    """Model-Agnostic Meta-Learning driven by Lightning manual optimisation.

    A batch holds several tasks under the keys `'train'` and `'test'`, each an
    `(inputs, targets)` pair whose first dimension indexes the tasks. For every
    task, `higher` builds a functional copy of `model`, adapts it on the task's
    `'train'` split for `inner_steps` SGD steps, and scores the adapted copy on
    the `'test'` split. The mean query loss over the tasks is the
    meta-objective.

    Parameters
    ----------
    model : `torch.nn.Module` instance
        Network whose parameters are meta-learned.
    inner_steps : int
        Number of adaptation steps taken per task.
    """

    def __init__(self, model, inner_steps=1):
        super().__init__()
        self.model = model
        self.accuracy = get_accuracy
        # Both optimizers are driven by hand inside `meta_learn`.
        self.automatic_optimization = False
        self.inner_steps = inner_steps

    def forward(self, x):
        """Run the un-adapted model on `x`."""
        return self.model(x)

    def inner_loop(self, fmodel, diffopt, train_input, train_target):
        """Take one adaptation step of `fmodel` on a task's support data.

        Returns the support loss of that step as a Python float.
        """
        train_logit = fmodel(train_input)
        inner_loss = F.cross_entropy(train_logit, train_target)
        diffopt.step(inner_loss)

        return inner_loss.item()

    # Adaptation needs gradients even when the caller (e.g. an evaluation
    # loop) has disabled them.
    @torch.enable_grad()
    def meta_learn(self, batch, batch_idx, optimizer_idx=None):
        """Adapt to every task in `batch` and compute the meta loss/accuracy.

        The meta-parameters are updated only while the module is in training
        mode (`self.training`). In eval mode the per-task adaptation is still
        performed, but no backward pass or optimizer step is taken, so
        validation/test tasks never leak into the meta-parameters.

        Returns
        -------
        outer_loss : `torch.FloatTensor` instance
            Mean query loss over the tasks of the batch.
        acc : `torch.FloatTensor` instance
            Mean query accuracy over the tasks of the batch.
        """
        meta_optimizer, inner_optimizer = self.optimizers()
        # Unwrap Lightning's optimizer wrappers to get the torch optimizers.
        meta_optimizer = meta_optimizer.optimizer
        inner_optimizer = inner_optimizer.optimizer

        train_inputs, train_targets = batch['train']
        test_inputs, test_targets = batch['test']

        # Only a training step may change the meta-parameters.
        update_meta = self.training

        batch_size = train_inputs.shape[0]
        outer_loss = torch.tensor(0., device=self.device)
        acc = torch.tensor(0., device=self.device)

        for train_input, train_target, test_input, test_target in zip(
            train_inputs, train_targets, test_inputs, test_targets
        ):
            # copy_initial_weights=False keeps the adapted weights connected to
            # self.model's parameters so the outer loss can reach them.
            # Higher-order tracking is only needed when we backprop through the
            # adaptation, i.e. during training.
            with higher.innerloop_ctx(
                self.model,
                inner_optimizer,
                copy_initial_weights=False,
                track_higher_grads=update_meta,
            ) as (fmodel, diffopt):
                for _ in range(self.inner_steps):
                    self.inner_loop(fmodel, diffopt, train_input, train_target)

                test_logit = fmodel(test_input)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    acc += self.accuracy(test_logit, test_target)

        outer_loss.div_(batch_size)
        acc.div_(batch_size)
        self.log_dict({
                    'outer_loss': outer_loss,
                    'accuracy': acc
                }, prog_bar=True)

        if update_meta:
            meta_optimizer.zero_grad()
            self.manual_backward(outer_loss, meta_optimizer)
            meta_optimizer.step()
        return outer_loss, acc

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one meta-update on `batch`; returns the meta loss as a float."""
        train_loss, acc = self.meta_learn(batch, batch_idx, optimizer_idx)

        self.log_dict({
            'train_loss': train_loss.item(),
            'train_accuracy': acc.item()
        }, prog_bar=True)

        return train_loss.item()

    def validation_step(self, batch, batch_idx):
        """Adapt and score on validation tasks; returns the loss as a float."""
        val_loss, val_acc = self.meta_learn(batch, batch_idx)

        self.log_dict({
            'val_loss': val_loss.item(),
            'val_accuracy': val_acc.item()
        })
        return val_loss.item()

    def test_step(self, batch, batch_idx):
        """Adapt and score on test tasks; returns the loss as a float."""
        test_loss, test_acc = self.meta_learn(batch, batch_idx)
        self.log_dict({
            'test_loss': test_loss.item(),
            'test_accuracy': test_acc.item()
        })
        return test_loss.item()

    def configure_optimizers(self):
        """Return `[meta_optimizer, inner_optimizer]`.

        The Adam optimizer performs the meta-update; the SGD optimizer is only
        handed to `higher.innerloop_ctx` for the per-task adaptation.
        """
        meta_optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        inner_optimizer = torch.optim.SGD(self.parameters(), lr=1e-1)

        return [meta_optimizer, inner_optimizer]
                
                

In [None]:
#export
class UMTRA(pl.LightningModule):
    """MAML-style meta-learner whose outer loss uses augmented support images.

    For every task the model is adapted with one step on the task's `'train'`
    images, then the outer loss is computed on an augmented copy of those same
    images (with the same targets). The task's `'test'` split is used only to
    report accuracy.

    Parameters
    ----------
    model : `torch.nn.Module` instance
        Network whose parameters are meta-learned.
    augmentation : callable
        Applied to each task's `'train'` images to produce the inputs of the
        outer loss.
    """

    def __init__(self, model, augmentation):
        super().__init__()
        self.model = model
        self.accuracy = get_accuracy
        self.augmentation = augmentation
        # Both optimizers are driven by hand inside `training_step`.
        self.automatic_optimization = False

    def forward(self, x):
        """Run the un-adapted model on `x`."""
        return self.model(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one meta-update on `batch` and return the mean outer loss."""
        meta_optimizer, inner_optimizer = self.optimizers()
        # Unwrap Lightning's optimizer wrappers to get the torch optimizers.
        meta_optimizer = meta_optimizer.optimizer
        inner_optimizer = inner_optimizer.optimizer

        train_inputs, train_targets = batch['train']
        test_inputs, test_targets = batch['test']

        batch_size = train_inputs.shape[0]
        outer_loss = torch.tensor(0., device=self.device)
        acc = torch.tensor(0., device=self.device)
        self.model.zero_grad()

        for train_input, train_target, test_input, test_target in zip(
            train_inputs, train_targets, test_inputs, test_targets
        ):
            val_input = self.augmentation(train_input).to(self.device)
            # The targets are only read, never modified, so the tensor can be
            # reused for the augmented images without copying it.
            val_target = train_target.to(self.device)
            with higher.innerloop_ctx(self.model, inner_optimizer, copy_initial_weights=False) as (fmodel, diffopt):
                train_logit = fmodel(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                diffopt.step(inner_loss)

                val_logits = fmodel(val_input)
                outer_loss += F.cross_entropy(val_logits, val_target)

                # The labelled query split only feeds the reported accuracy.
                with torch.no_grad():
                    test_logits = fmodel(test_input)
                    acc += self.accuracy(test_logits, test_target)

        outer_loss.div_(batch_size)
        acc.div_(batch_size)
        self.log_dict({
                    'outer_loss': outer_loss,
                    'accuracy': acc
                }, prog_bar=True)

        meta_optimizer.zero_grad()
        self.manual_backward(outer_loss, meta_optimizer)
        meta_optimizer.step()

        return outer_loss

    def configure_optimizers(self):
        """Return `[meta_optimizer, inner_optimizer]`.

        The Adam optimizer performs the meta-update; the SGD optimizer is only
        handed to `higher.innerloop_ctx` for the per-task adaptation.
        """
        meta_optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        inner_optimizer = torch.optim.SGD(self.parameters(), lr=1e-1)

        return [meta_optimizer, inner_optimizer]

In [None]:
# Data module shared by the MAML and UMTRA runs below.
# NOTE(review): argument meanings are inferred from their names (presumably
# 5-way / 1-shot episodes with 15 query shots, 16 tasks per batch) — confirm
# against OmniglotDataModule.
dm = OmniglotDataModule(
    "data",
    # episode layout
    ways=5,
    shots=1,
    test_shots=15,
    # dataset options
    meta_train=True,
    download=True,
    shuffle_ds=True,
    # loader options
    batch_size=16,
    shuffle=True,
    num_workers=8,
)

In [None]:
# MAML around a backbone with 1 input channel and 5 output logits.
model = MAML(
    model=ConvolutionalNeuralNetwork(in_channels=1, out_features=5, hidden_size=64),
    inner_steps=1,
)

In [None]:
# Step budget of this run; shared by the logged config and the Trainer so the
# two cannot drift apart.
MAX_STEPS = 100

# Named `wandb_logger` so it does not overwrite the module-level `logger`
# (the `logging.Logger` created near the top of the notebook).
wandb_logger = WandbLogger(
    project='maml',
    config={
        'batch_size': 16,
        'steps': MAX_STEPS,
        'dataset': "omniglot",
        'inner_steps': 1,
        'val/test': 'enabled'
    }
)
trainer = Trainer(
        profiler='simple',
        max_steps=MAX_STEPS,
        fast_dev_run=False,
        val_check_interval=25,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=2, gpus=1, logger=wandb_logger
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train with all warnings suppressed; the previous warning filters are
# restored when the `with` block exits.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | ConvolutionalNeuralNetwork | 112 K 
-----------------------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.449     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  37.089         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  30.104         	|1              	|  30.104         	|  81.167         	|
run_training_batch                 	|  0.22786        	|100            	|  22.786         	|  61.435         	|
model_forward                      	|  0.2277         	|100            	|  22.77          	|  61.393         	|
training_step                      	|  0.22747        	|100            	|  22.747         	|  61.33          	|
evaluation_step_and_end            

In [None]:
# End the Weights & Biases run started for the MAML training above.
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
outer_loss,0.35524
accuracy,0.88667
val_loss,0.35524
val_accuracy,0.88667
epoch,0.0
trainer/global_step,99.0
_runtime,35.0
_timestamp,1623086749.0
_step,5.0
train_loss,0.2952


0,1
outer_loss,█▄▄▁▁▂
accuracy,▁▅▄██▇
val_loss,█▄▁▁
val_accuracy,▁▄█▇
epoch,▁▁▁▁▁▁
trainer/global_step,▁▃▃▆██
_runtime,▁▃▄▆██
_timestamp,▁▃▄▆██
_step,▁▂▄▅▇█
train_loss,█▁


In [None]:
# Augmentation pipeline handed to UMTRA: a random affine transform (translation
# only, degrees=0) followed by Gaussian noise applied with p=0.3.
aug = nn.Sequential(
    K.augmentation.RandomAffine(
        degrees=0,
        translate=(0.4, 0.4),
        padding_mode='border',
    ),
    K.augmentation.RandomGaussianNoise(mean=0., std=.1, p=.3),
)
model = UMTRA(
    model=ConvolutionalNeuralNetwork(in_channels=1, out_features=5, hidden_size=64),
    augmentation=aug,
)

In [None]:
# Step budget of this run; shared by the logged config and the Trainer so the
# two cannot drift apart.
MAX_STEPS = 100

# Named `wandb_logger` so it does not overwrite the module-level `logger`
# (the `logging.Logger` created near the top of the notebook).
wandb_logger = WandbLogger(
    project='umtra',
    config={
        'batch_size': 16,
        'steps': MAX_STEPS,
        'dataset': "omniglot"
    }
)
trainer = Trainer(
        profiler='simple',
        max_epochs=100,
        max_steps=MAX_STEPS,
        fast_dev_run=False,
        num_sanity_val_steps=2, gpus=1, logger=wandb_logger
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train with all warnings suppressed; the previous warning filters are
# restored when the `with` block exits.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name         | Type                       | Params
------------------------------------------------------------
0 | model        | ConvolutionalNeuralNetwork | 112 K 
1 | augmentation | Sequential                 | 0     
------------------------------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.449     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  36.018         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  31.24          	|1              	|  31.24          	|  86.733         	|
run_training_batch                 	|  0.29617        	|100            	|  29.617         	|  82.227         	|
model_forward                      	|  0.29602        	|100            	|  29.602         	|  82.185         	|
training_step                      	|  0.29571        	|100            	|  29.571         	|  82.099         	|
get_train_batch                    

In [None]:
# End the Weights & Biases run started for the UMTRA training above.
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
outer_loss,0.26876
accuracy,0.83417
epoch,0.0
trainer/global_step,99.0
_runtime,35.0
_timestamp,1622905873.0
_step,1.0


0,1
outer_loss,█▁
accuracy,▁█
epoch,▁▁
trainer/global_step,▁█
_runtime,▁█
_timestamp,▁█
_step,▁█


In [None]:
# Run nbdev's notebook-to-script conversion; the import stays local to this
# final, non-exported cell.
from nbdev.export import notebook2script

notebook2script()

Converted 01_nn_utils.ipynb.
Converted 01b_data_loaders_pl.ipynb.
Converted 02_maml.ipynb.
Converted 02b_maml_pl.ipynb.
Converted 03_protonet_pl.ipynb.
Converted 04_cactus.ipynb.
Converted index.ipynb.
