In [1]:
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
import argparse

import torch
from torch import nn
import torch.nn.functional as F

from ResNet import ResNet, Block
from dataset import load_data, extract_sample

In [2]:
task_params = {'k_shot': 5,
               'n_way': 5, 
               'n_query': 5}

In [3]:
X_train_dataset, y_train_dataset = load_data("/home/logiceat3r/IMAML-cancer/mammogram-data/MIAS/test")

  0%|          | 0/3 [00:00<?, ?it/s]

Loading Data


100%|██████████| 3/3 [00:08<00:00,  2.95s/it]


In [5]:
model = ResNet(101, Block, 3, 1000)

In [6]:
inner_train_steps = 1
alpha = 0.4 # Inner LR
beta = 0.001 # Meta LR
epochs = 5
batch_size = 32
num_episodes = 100
device = 'cpu'

In [7]:
criterion = nn.CrossEntropyLoss()

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=beta)

In [9]:
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Block(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (identity_downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=Tr

In [10]:
for epoch in range(1, epochs+1):
    
    pbar = tqdm(total=num_episodes, desc='Epoch {}'.format(epoch))
    
    # Meta Episode
    for episode in range(num_episodes):
        
        task_losses = []
        task_accuracies = []
        
        # Task Fine-tuning
        for task_idx in range(batch_size):
            # Get the train and val splits
            train_sample, test_sample = extract_sample(X_train_dataset, y_train_dataset, task_params)
            X_train = train_sample[0].to(device)
            y_train = train_sample[1].to(device)
            X_val = test_sample[0].to(device)
            y_val = test_sample[1].to(device)
            
            # Create a fast model using current meta model weights
            fast_weights = OrderedDict(model.named_parameters())

            # Fine-tune
            for step in range(inner_train_steps):
                # Forward pass
                logits = model.functional_forward(X_train, fast_weights)
                # Loss
                loss = criterion(logits, y_train)
                # Compute Gradients
                gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Manual Gradient Descent on the fast weights
                fast_weights = OrderedDict(
                                    (name, param - alpha * grad)
                                    for ((name, param), grad) in zip(fast_weights.items(), gradients)
                                )
            
            # Testing on the Query Set (Val)
            val_logits = model.functional_forward(X_val, fast_weights)
            val_loss = criterion(val_logits, y_val)
            
            # Calculating accuracy
            y_pred = val_logits.softmax(dim=1)
            accuracy = torch.eq(y_pred.argmax(dim=-1), y_val).sum().item() / y_pred.shape[0]
            
            task_accuracies.append(accuracy)
            task_losses.append(val_loss)
        
        # Meta Update
        model.train()
        optimizer.zero_grad()
        # Meta Loss
        meta_batch_loss = torch.stack(task_losses).mean()
        # Meta backpropagation
        meta_batch_loss.backward()
        # Meta Optimization
        optimizer.step()
        
        meta_batch_accuracy = torch.Tensor(task_accuracies).mean()
        
        # Progress Bar Logging
        pbar.update(1)
        pbar.set_postfix({'Loss': meta_batch_loss.item(), 
                          'Accuracy': meta_batch_accuracy.item()})
        
    pbar.close()

# Save Model
torch.save({'weights': model.state_dict(),
            'task_params': task_params}, "finally.pth")

Epoch 1:   0%|          | 0/100 [00:00<?, ?it/s]

TypeError: functional_forward() takes 2 positional arguments but 3 were given