In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [23]:
import torch 
import copy 

import numpy as np 
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt 

from torch.autograd import Variable 
from pathlib import Path
from tqdm import tqdm
from tensorboardX import SummaryWriter

from dataloader import OmniClassDataset, OmniLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
base = Path('data/omniglot/')
data_dir = base/'data'
split_dir = base/'splits'/'vinyals'

In [4]:
def get_dataloader(split, k_shot, n_way, n_test):
    dataset = OmniClassDataset(split=split,
                           data_dir=data_dir, 
                           splits_dir=split_dir,
                           shuffle=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               lambda x: 1 - x
                           ]))
    
    dataloader = OmniLoader(k_shot=k_shot, 
                            n_way=n_way,
                            n_test=n_test,
                            dataset=dataset,
                            shuffle=True,
                            pin_memory=True,
                            num_workers=8)
    return dataloader

In [5]:
# ran with 3 diff seeds 
# linearly annealed 

way5_params = {
    'n_way': 5,
    'k_shots': 1,
    'n_test': 1,
    # inner loop parameters
    'inner_lr': 0.001,
    'inner_batchsize': 10,
    'inner_iterations': 5,
    # outter loop parameters
    'outer_lr': 1.0,
    'outer_iterations': 100000,
    'meta_batchsize': 5,
    # evaluation params
    'eval_inner_iterations': 50,
    'eval_inner_batch': 5,
    # other...
    'validation_rate': 10
}

way20_params = {
    'n_way': 20,
    'k_shots': 1,
    'n_test': 1,
    # inner loop parameters
    'inner_lr': 0.0005,
    'inner_batchsize': 20,
    'inner_iterations': 10,
    # outter loop parameters
    'outer_lr': 1.0,
    'outer_iterations': 200000,
    'meta_batchsize': 5,
    # evaluation params
    'eval_inner_iterations': 50,
    'eval_inner_batch': 10,
    # other...
    'validation_rate': 10
}

In [6]:
class Model(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        
        conv_block = lambda in_dim:(nn.Conv2d(in_dim, 64, 3, stride=2, padding=1),
                                    nn.BatchNorm2d(64),
                                    nn.ReLU())
        
        self.cnn = nn.Sequential(
            *conv_block(1),
            *conv_block(64),
            *conv_block(64), 
            *conv_block(64)
        )
        self.linear = nn.Linear(256, n_classes)
        
    def forward(self, x):
        x = self.cnn(x)
        x = x.reshape(x.size(0), -1)
        x = self.linear(x)
        return x 

In [7]:
params = way5_params
model = Model(n_classes=params['n_way'])

In [8]:
outter_loop_optim = torch.optim.SGD(model.parameters(), lr=params['outer_lr'])
inner_loop_optim = torch.optim.Adam(model.parameters(), lr=params['inner_lr'], betas=(0, 0))
loss_fcn = nn.CrossEntropyLoss()

In [9]:
train_dataloader = get_dataloader('train', params['k_shots'], params['n_way'], params['n_test'])
val_dataloader = get_dataloader('val', params['k_shots'], params['n_way'], params['n_test'])

In [10]:
def take_n_steps(loss_fcn, optim, model, x, y, n_steps):
    losses = []
    for _ in range(n_steps):
        optim.zero_grad()
        loss = loss_fcn(model(x), y)
        loss.backward()
        optim.step()
        
        losses.append(loss)
    return losses

In [11]:
def meta_train(model, optim, loss_fcn, n_iterations, train_x, train_y, test_x, test_y):
    task_model = copy.deepcopy(model)
    inner_loop_opitm = torch.optim.Adam(task_model.parameters(), lr=params['inner_lr'], betas=(0, 0))

    take_n_steps(loss_fcn, 
                 optim,
                 task_model,
                 train_x, train_y,
                 n_iterations)

    y_preds = task_model(test_x)
    loss = loss_fcn(y_preds, test_y)
    n_correct = (F.softmax(y_preds, dim=-1).argmax(-1) == test_y).sum()
    accuracy = n_correct / test_y.size(0)
    
    return task_model, loss.item(), accuracy.item()

In [15]:
n_way = 2
batch_size = 4
k_shot = 2
n_test = 3

In [16]:
# https://github.com/gabrielhuang/reptile-pytorch/blob/master/train_omniglot.py
def make_inf(D):
    while True:
        for x in D:
            yield x

In [17]:
dataset = OmniClassDataset(split='train',
                   data_dir=data_dir, 
                   splits_dir=split_dir,
                   shuffle=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       lambda x: 1 - x
                   ]))

train_loader = make_inf(OmniLoader(k_shot=k_shot, 
                                    n_way=n_way,
                                    n_test=n_test,
                                    dataset=dataset,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=8))

In [18]:
def batch(x, y, batch_size):
    x_batch = torch.stack(list(torch.chunk(x, batch_size)) if x.size(0) % batch_size == 0 else list(torch.chunk(x, batch_size+1))[:-1])
    y_batch = torch.stack(list(torch.chunk(y, batch_size)) if x.size(0) % batch_size == 0 else list(torch.chunk(y, batch_size+1))[:-1])
    return x_batch, y_batch

In [41]:
def train_test_split(x, y, n_test):
    test_idxs, train_idxs = [], []
    
    for class_i in range(int(y.max())):
        class_ex = (y == class_i).nonzero().flatten()
        class_ex = class_ex[torch.randperm(class_ex.size(0))]
        ctest_idx = class_ex[:n_test]
        ctrain_idx = class_ex[n_test:]

        test_idxs.append(ctest_idx)
        train_idxs.append(ctrain_idx)
    
    train_idxs = torch.cat(train_idxs)
    x_train, y_train = x[train_idxs], y[train_idxs]
    
    test_idxs = torch.cat(test_idxs)
    x_test, y_test = x[test_idxs], y[test_idxs]
    
    return (x_train, y_train), (x_test, y_test)

In [39]:
for x, y in val_loader:

torch.Size([20, 1, 28, 28])

In [45]:
model_name = '5w1s - 1test - full_run'
writer = SummaryWriter(comment=model_name)
debug = False

params = way5_params
model = Model(n_classes=params['n_way']).to(device)
outter_loop_optim = torch.optim.SGD(model.parameters(), lr=params['outer_lr'])
loss_fcn = nn.CrossEntropyLoss()

# debugging parameters
if debug:
    params['outer_iterations'] = 100
    params['validation_rate'] = 3
    params['meta_batchsize'] = 1
    params['inner_batchsize'] = 15

# params['outer_iterations'] = 500

train_loader = get_dataloader('train', params['k_shots'], params['n_way'], params['n_test'])
val_loader = get_dataloader('val', params['k_shots'], params['n_way'], params['n_test'])
    
for outer_i in tqdm(range(params['outer_iterations'])):
    outter_loop_optim.zero_grad()
    
    # lr annealing 
    frac_done = outer_i / params['outer_iterations']
    cur_meta_step_size = frac_done * 1e-9 + (1 - frac_done) * params['outer_lr']
    for param_group in outter_loop_optim.param_groups:
        param_group['lr'] = cur_meta_step_size
    
    # inner loop
    for task_i, (x, y) in enumerate(train_loader):
        new_model = model.clone()
        inner_loop_opitm = torch.optim.Adam(new_model.parameters(), lr=params['inner_lr'], betas=(0, 0))
        
        # train on batches of the dataset
        batch_x, batch_y = batch(x, y, batch_size=params['inner_batchsize'])
        for i, (train_x, train_y) in enumerate(zip(batch_x, batch_y)):
            take_n_steps(loss_fcn, 
                         inner_loop_opitm,
                         new_model,
                         train_x, train_y, 1)
            if i == params['inner_iterations'] -1: 
              break
        
        # record weights
        for w, w_t in zip(model.parameters(), new_model.parameters()):
            if w.grad is None:
                w.grad = Variable(torch.zeros_like(w)).to(device)
            # invert loss eqn. to use descent optimization
            w.grad.data.add_(w.data - w_t.data)
        
        if task_i == params['meta_batchsize'] - 1:
            break

    # eval 
    y_preds = new_model(x)
    loss = loss_fcn(y_preds, y)
    accuracy = (y_preds.argmax(-1) == y).float().mean()
                        
    writer.add_scalar('meta_train_loss', loss, outer_i)
    writer.add_scalar('meta_train_acc', accuracy, outer_i)
    
    # update model with avg over mini batches 
    for w in model.parameters():
        w.grad.data.div_(params['meta_batchsize'])
    outter_loop_optim.step()
    
    # validation
    if outer_i % params['validation_rate'] == 0:
        
        for (x, y) in val_loader:
            new_model = model.clone()
            inner_loop_opitm = torch.optim.Adam(new_model.parameters(), lr=params['inner_lr'], betas=(0, 0))
            
            (x_train, y_train), (x_test, y_test) = train_test_split(x, y, params['n_test'])
            
            # train
            xb_train, yb_train = batch(x_train, y_train, batch_size=params['eval_inner_batch'])
            for (train_x, train_y) in zip(xb_train, yb_train):
                take_n_steps(loss_fcn, 
                             inner_loop_opitm,
                             new_model,
                             train_x, train_y,
                             params['eval_inner_iterations'])
            
            # eval 
            y_preds = new_model(x_test)
            loss = loss_fcn(y_preds, y_test)
            accuracy = (y_preds.argmax(-1) == y_test).float().mean()
                                
            writer.add_scalar('meta_val_loss', loss, outer_i)
            writer.add_scalar('meta_val_acc', accuracy, outer_i)
            break
        
writer.close()
print('Summary writer closed...')

Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_byt

  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:13,  1.35s/it][ATraceback (most recent call last):
Traceback (most recent call last):
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/Users/brennangebotys/miniconda2/envs/meta/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/Users/brennangebotys/m

KeyboardInterrupt: 

In [41]:
def view_losses(losses, split_interval=100):
    d = torch.split(torch.from_numpy(np.array(losses)), split_interval)
    mean_losses = [x.mean() for x in d]
    intervals = [split_interval * i for i in range(len(mean_losses))]
    
    plt.plot(intervals, mean_losses)
    plt.title('Model Test Loss (averaged every {} iterations)'.format(split_interval))
    plt.show()

In [None]:
view_losses(val_loss, split_interval=params['eval_inner_batch'])