In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to library code (e.g. the
# fewshot package imported below) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.insert(0, '../')
import fewshot.proto.sampler
import fewshot.proto.trainer
import fewshot.maml.maml
import fewshot.maml.model
import fewshot.data

import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
from torchvision import models
import torch.optim as optim


import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# Exploring MAML

A quick sanity check that our `fewshot.proto.sampler` works with @oscarknagg's implementation of MAML, and that the model trains properly:


In [3]:
# Shared pre-processing settings, so the train and test pipelines cannot drift apart.
_RESIZE_TO = 100   # passed to transforms.Resize
_CROP_TO = 80      # side length of the crop fed to the network
# Per-channel mean/std commonly used with ImageNet-pretrained torchvision models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.Resize(_RESIZE_TO),
    transforms.RandomCrop(_CROP_TO),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: deterministic centre crop, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(_RESIZE_TO),
    transforms.CenterCrop(_CROP_TO),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

In [4]:
# Build the few-shot fashion dataset from the styles CSV and the image folder.
# NOTE(review): train_transform is defined earlier but never used; this dataset
# is built with the deterministic test_transform -- confirm that is intended
# for the training run further down.
styles_csv = '../data/fashion-dataset/styles_quoted.csv'
images_dir = '../data/fashion-dataset/images/'
ds = fewshot.proto.sampler.NShotFashionDataset(
    styles_csv,
    images_dir,
    classlist=None,
    transform=test_transform,
)

In [5]:
# Episode configuration handed to both the sampler and meta_gradient_step.
# Presumably n = shots per class, k = classes per task, q = query samples per
# class -- TODO confirm against NShotTaskSampler.
n = 5
k = 10
q = 20

# Each batch produced by the sampler bundles bs=4 tasks; 30 such batches per epoch.
sampler = fewshot.proto.sampler.NShotTaskSampler(
    ds,
    episodes_per_epoch=30,
    n=n,
    k=k,
    q=q,
    bs=4,
)

# batch_sampler delegates all index selection for a batch to the task sampler.
dl = torch.utils.data.DataLoader(
    ds,
    batch_sampler=sampler,
    num_workers=4,
)

In [6]:
# Instantiate the classifier, then move its parameters to the GPU.
# Presumably the arguments are (input channels, number of classes, flattened
# feature size) -- TODO confirm against FewShotClassifier.
model = fewshot.maml.model.FewShotClassifier(3, k, 1600)
model = model.cuda()

In [7]:
# Cross-entropy objective, and Adam (lr = 1e-3) over all of the model's parameters.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
# Device handle passed to meta_gradient_step below; the model and inputs are
# moved to the GPU separately via .cuda().
device = torch.device('cuda')

In [9]:
# Accumulates the per-step loss values recorded by the single-batch overfit loop below.
losses = []

In [10]:
# Pull one batch from the loader and keep it fixed; the overfit check below
# reuses this same (x, y) pair on every step.
x, y = next(iter(dl))

Checking that we can overfit to a single batch:

In [11]:
# Run 70 meta-gradient steps on the one fixed batch (x, y). A steadily falling
# loss shows that the sampler, model and MAML step are wired together correctly.
# The positional literals 1, 2, .01, True are presumably (order, inner train
# steps, inner learning rate, train flag) -- TODO confirm against
# fewshot.maml.maml.meta_gradient_step.
for step in range(70):
    loss, _ = fewshot.maml.maml.meta_gradient_step(
        model, optimizer, loss_func,
        x.cuda().float(), y,
        n, k, q,
        1, 2, .01, True, device,
    )
    loss_value = loss.item()
    print('loss: ', loss_value)
    losses.append(loss_value)


loss:  2.105039119720459
loss:  1.8905718326568604
loss:  1.7168676853179932
loss:  1.568798542022705
loss:  1.4654457569122314
loss:  1.3561930656433105
loss:  1.2642076015472412
loss:  1.1862894296646118
loss:  1.0871490240097046
loss:  1.0120195150375366
loss:  0.9447699189186096
loss:  0.8714093565940857
loss:  0.8043456673622131
loss:  0.7511523962020874
loss:  0.7011047601699829
loss:  0.6724287271499634
loss:  0.6075797080993652
loss:  0.554768443107605
loss:  0.5373630523681641
loss:  0.5047076940536499
loss:  0.4716227054595947
loss:  0.4530414044857025
loss:  0.41742974519729614
loss:  0.3990890383720398
loss:  0.3854110836982727
loss:  0.3472350537776947
loss:  0.31305018067359924
loss:  0.2959599792957306
loss:  0.2934921681880951
loss:  0.27284759283065796
loss:  0.2613319456577301
loss:  0.24893787503242493
loss:  0.2390328347682953
loss:  0.22823569178581238
loss:  0.21382534503936768
loss:  0.2012040764093399
loss:  0.1854984015226364
loss:  0.1697147786617279
loss:  0.

Ok, so we can overfit to a single batch - a good sign that we're training!

In [None]:
# Meta-train for 10 epochs, recording the mean loss and mean query accuracy of
# each epoch in `losses` and `accuracies`.
losses = []
accuracies = []

# Ground-truth labels for the predictions returned by meta_gradient_step:
# class index i repeated q times for each of the k classes, tiled once per task
# in the meta-batch. Built with integer ops only. The earlier
# torch.arange(0, k, 1/q).long() used a float step and then truncated, which
# gives wrong labels whenever i * (1/q) rounds to just below an integer
# (e.g. q = 49).
# Assumes probs is ordered task-by-task, then class-by-class, as the original
# label construction did -- TODO confirm against meta_gradient_step.
meta_batch_size = 4  # must match the bs=4 given to NShotTaskSampler
labels = torch.arange(k).repeat_interleave(q).repeat(meta_batch_size).cuda()

for epoch in range(10):
    print('------------------------ epoch: ', epoch)
    epoch_losses = []
    epoch_accuracies = []
    for x, y in dl:
        # The positional literals 1, 2, .01, True are presumably (order, inner
        # train steps, inner learning rate, train flag) -- TODO confirm.
        loss, probs = fewshot.maml.maml.meta_gradient_step(model, optimizer, loss_func, x.cuda().float(), y, n, k, q, 1, 2, .01, True, device)
        # Predicted class = arg-max over dimension 1 of probs.
        _, predicted = torch.max(probs, 1)
        acc = (predicted == labels).float().mean().item()
        epoch_accuracies.append(acc)
        epoch_losses.append(loss.item())
        # '\r' keeps the running status on one line; the epoch summary overwrites it.
        print('curr loss: {:10.1f} curr accuracy: {:10.1f}'.format(loss.item()*100, acc*100), end='\r')

    epoch_loss = np.mean(epoch_losses)
    epoch_accuracy = np.mean(epoch_accuracies)
    losses.append(epoch_loss)
    accuracies.append(epoch_accuracy)
    print('epoch loss: {:10.1f} epoch accuracy: {:10.1f}'.format(epoch_loss*100, epoch_accuracy*100))
        
        
    

------------------------ epoch:  0
epoch loss:      204.2 epoch accuracy:       30.8
------------------------ epoch:  1
epoch loss:      193.9 epoch accuracy:       32.4
------------------------ epoch:  2
epoch loss:      190.0 epoch accuracy:       33.9
------------------------ epoch:  3
curr loss:      182.3 curr accuracy:       38.2