In [None]:
import logging
import warnings 

import higher
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.metrics.functional import accuracy
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

In [None]:
logger = logging.getLogger(__name__)

# Seed Python's `random`, NumPy and torch (via Lightning) so task sampling and
# weight initialisation are reproducible on Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

In [None]:
class ConvolutionalNeuralNetwork(nn.Module):
    """Four conv blocks followed by a linear classifier.

    Every block is Conv3x3(padding=1) -> BatchNorm -> ReLU -> MaxPool(2). The
    final feature map is flattened and passed to
    ``nn.Linear(hidden_size, out_features)``, so the input's spatial size must
    shrink to 1x1 after the four poolings (28x28 -> 14 -> 7 -> 3 -> 1).

    Parameters
    ----------
    in_channels : int
        Number of channels of the input images.
    out_features : int
        Number of output logits.
    hidden_size : int, default 64
        Channel width of every conv block.
    """

    def __init__(self, in_channels, out_features, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        # Channel widths at each block boundary: in -> h -> h -> h -> h.
        widths = [in_channels] + [hidden_size] * 4
        blocks = [
            self.conv3x3(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.features = nn.Sequential(*blocks)

        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        # NOTE(review): `params` is accepted but ignored (MetaModule-style signature).
        feature_map = self.features(inputs)
        n_examples = feature_map.size(0)
        flat = feature_map.view((n_examples, -1))
        return self.classifier(flat)

    def conv3x3(self, in_channels, out_channels, **kwargs):
        """Return one Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2) block.

        Extra keyword arguments are forwarded to ``nn.Conv2d``.
        """
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs)
        # track_running_stats=False: BatchNorm always normalises with the
        # statistics of the current batch, in both train and eval mode.
        norm = nn.BatchNorm2d(out_channels, momentum=1.0, track_running_stats=False)
        return nn.Sequential(conv, norm, nn.ReLU(), nn.MaxPool2d(2))

In [None]:
def get_accuracy(logits, targets):
    """Fraction of query points whose arg-max prediction equals the target.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Model outputs on the query points, shape `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        Ground-truth class indices of the query points, shape `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Scalar mean accuracy over the query points.
    """
    predicted_labels = logits.argmax(dim=-1)
    hits = (predicted_labels == targets).float()
    return hits.mean()

In [None]:
class MAML(pl.LightningModule):
    """Second-order MAML as a LightningModule with manual optimization.

    Each training batch is a batch of tasks with ``'train'`` (support) and
    ``'test'`` (query) splits. For every task, a functional copy of ``model``
    is adapted on the support split with a differentiable inner optimizer
    (``higher.innerloop_ctx``). The adapted model's query loss, averaged over
    tasks, is back-propagated into ``model``'s initial weights and applied
    with Adam.

    Parameters
    ----------
    model : `torch.nn.Module` instance
        Network whose initial weights are meta-learned.
    meta_lr : float, default 1e-3
        Learning rate of the outer (Adam) optimizer.
    inner_lr : float, default 1e-1
        Learning rate of the inner, differentiable SGD adaptation.
    inner_steps : int, default 1
        Number of inner adaptation steps per task.

    Raises
    ------
    ValueError
        If ``inner_steps`` is smaller than 1.
    """

    def __init__(self, model, meta_lr=1e-3, inner_lr=1e-1, inner_steps=1):
        super().__init__()
        if inner_steps < 1:
            raise ValueError(f"inner_steps must be >= 1, got {inner_steps}")
        self.model = model
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        self.inner_steps = inner_steps
        self.accuracy = get_accuracy
        # Both the meta-update and the inner loop are driven by training_step.
        self.automatic_optimization = False

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Run one meta-update over a batch of tasks.

        ``optimizer_idx`` is optional: under manual optimization Lightning
        only passes it for backward compatibility, and it is unused here.

        Returns
        -------
        outer_loss : `torch.Tensor`
            Mean query-set loss over the tasks in the batch.
        """
        meta_optimizer, inner_optimizer = self.optimizers()
        # NOTE(review): the inner optimizer is unwrapped from Lightning's
        # wrapper before handing it to `higher`, presumably because `higher`
        # dispatches on the concrete torch optimizer type. The meta optimizer
        # stays wrapped so Lightning's own step/zero_grad path is used.
        inner_optimizer = inner_optimizer.optimizer

        train_inputs, train_targets = batch['train']
        test_inputs, test_targets = batch['test']

        # Clear meta-gradients from the previous batch; otherwise they accumulate.
        meta_optimizer.zero_grad()

        task_losses = []
        task_accuracies = []
        for train_input, train_target, test_input, test_target in zip(
            train_inputs, train_targets, test_inputs, test_targets
        ):
            # copy_initial_weights=False keeps the fast weights attached to
            # self.model's parameters, so the query loss reaches them.
            with higher.innerloop_ctx(
                self.model, inner_optimizer, copy_initial_weights=False
            ) as (fmodel, diffopt):
                for _ in range(self.inner_steps):
                    inner_loss = F.cross_entropy(fmodel(train_input), train_target)
                    diffopt.step(inner_loss)

                test_logit = fmodel(test_input)
                task_losses.append(F.cross_entropy(test_logit, test_target))
                task_accuracies.append(self.accuracy(test_logit.detach(), test_target))

        # Mean over tasks (equivalent to summing and dividing by the batch size).
        outer_loss = torch.stack(task_losses).mean()
        acc = torch.stack(task_accuracies).mean()
        self.log_dict({
            'outer_loss': outer_loss,
            'accuracy': acc,
        }, prog_bar=True)

        self.manual_backward(outer_loss)
        # Actually apply the meta-gradient; without this the model never trains.
        meta_optimizer.step()

        return outer_loss

    def configure_optimizers(self):
        meta_optimizer = torch.optim.Adam(self.parameters(), lr=self.meta_lr)
        # Never stepped directly: it is only passed to `higher.innerloop_ctx`
        # in training_step to build the differentiable inner-loop optimizer.
        inner_optimizer = torch.optim.SGD(self.parameters(), lr=self.inner_lr)

        return [meta_optimizer, inner_optimizer]
                
                

In [None]:
class OmniglotDataModule(pl.LightningDataModule):
    """LightningDataModule serving batches of few-shot Omniglot tasks.

    ``setup`` builds the task dataset with torchmeta's ``omniglot`` helper.
    ``train_dataloader`` wraps that dataset in a ``BatchMetaDataLoader`` that
    yields batches of tasks. Only a training loader is provided.

    Parameters
    ----------
    data_dir : str
        Root folder for the Omniglot data.
    shots : int
        Forwarded to ``omniglot`` as ``shots`` (support examples per class).
    ways : int
        Forwarded to ``omniglot`` as ``ways`` (classes per task).
    shuffle_ds : bool
        Forwarded to ``omniglot`` as ``shuffle``.
    test_shots : int
        Forwarded to ``omniglot`` as ``test_shots`` (query examples per class).
    meta_train : bool
        Forwarded to ``omniglot`` as ``meta_train``.
    download : bool
        Forwarded to ``omniglot`` as ``download``.
    batch_size : int
        Number of tasks per batch (``BatchMetaDataLoader`` ``batch_size``).
    shuffle : bool
        Whether the loader shuffles tasks.
    num_workers : int
        Number of loader worker processes.
    """

    def __init__(
        self,
        data_dir: str,
        shots: int,
        ways: int,
        shuffle_ds: bool,
        test_shots: int,
        meta_train: bool,
        download: bool,
        batch_size: int,
        shuffle: bool,
        num_workers: int):
        super().__init__()
        self.data_dir = data_dir
        self.shots = shots
        self.ways = ways
        self.shuffle_ds = shuffle_ds
        self.test_shots = test_shots
        self.meta_train = meta_train
        self.download = download
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the task dataset; ``stage`` is accepted but not used."""
        self.task_dataset = omniglot(
            self.data_dir,
            shots=self.shots,
            ways=self.ways,
            shuffle=self.shuffle_ds,
            test_shots=self.test_shots,
            meta_train=self.meta_train,
            download=self.download
        )

    def train_dataloader(self):
        """Return a loader over batches of tasks; requires ``setup`` to have run."""
        return BatchMetaDataLoader(
            self.task_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers
        )

In [None]:
# 5-way, 5-shot Omniglot tasks with 15 query examples per class, 16 tasks per batch.
# `ways` must equal the classifier's `out_features` (5). With ways=1 every task
# has a single class, so all targets are 0 and the task is trivial.
dm = OmniglotDataModule(
    "data",
    shots=5,
    ways=5,
    shuffle_ds=True,
    test_shots=15,
    meta_train=True,
    download=True,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Build the task dataset and a loader for inspecting one batch before training.
# NOTE(review): `batch` holds the DataLoader here, not a batch of tasks.
dm.setup(stage=None)
batch = dm.train_dataloader()

In [None]:
# Draw one batch of tasks from a freshly built loader. This keeps the cell
# idempotent: re-running it does not depend on what `batch` currently holds.
batch = next(iter(dm.train_dataloader()))



In [None]:
# Shape of the support-set images: (tasks, examples per task, channels, H, W).
train_images, _train_labels = batch['train']
train_images.shape

torch.Size([16, 5, 1, 28, 28])

In [None]:
# in_channels=1: the Omniglot batches are single-channel images.
# NOTE(review): out_features must equal the datamodule's `ways`.
backbone = ConvolutionalNeuralNetwork(in_channels=1, out_features=5, hidden_size=64)
model = MAML(model=backbone)

In [None]:
# Set to True to log to Weights & Biases (needs the `wandb` package and a login).
# Otherwise Lightning's default logger (`logger=True`) is used.
USE_WANDB = False
# Named `experiment_logger` so it does not shadow the module-level `logger`.
experiment_logger = WandbLogger(project='maml') if USE_WANDB else True

trainer = Trainer(
    profiler='simple',
    max_epochs=100,
    fast_dev_run=False,
    num_sanity_val_steps=2,
    # Use the GPU when present instead of failing on CPU-only machines.
    gpus=1 if torch.cuda.is_available() else 0,
    logger=experiment_logger,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train with every warning silenced for the duration of fit().
# NOTE(review): this also hides Lightning deprecation notices. Consider
# narrowing the filter to specific categories or modules.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    trainer.fit(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | ConvolutionalNeuralNetwork | 112 K 
-----------------------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.449     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

16
tensor([ 7.9036, -0.8018, -1.9563, -1.6999, -1.8623], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 8.0959, -0.9833, -1.8878, -2.1419, -1.5133], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 6.8777, -1.0457, -1.9666, -1.6909, -1.3560], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 8.9689, -1.2400, -2.2301, -2.8620, -1.6561], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 7.1171, -1.5409, -1.8214, -1.1864, -1.1540], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 7.6583, -1.2810, -2.2873, -1.6743, -1.4706], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 7.6766, -1.3880, -2.1270, -2.6105, -1.4791], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0')
tensor([ 8.9662, -1.4132, -2.1438, -2.9371, -1.2142], device='cuda:0',
       requires_grad=True)
tensor(0, device='cuda:0