<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment1_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install EasyFSL

In [None]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Identification dataset

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification

Downloading humpback-whale-identification.zip to /content
100% 5.50G/5.51G [04:49<00:00, 17.1MB/s]
100% 5.51G/5.51G [04:49<00:00, 20.4MB/s]


In [None]:
%%capture

!unzip humpback-whale-identification.zip

### Clone GitHub repo

In [None]:
import os

if os.path.exists('/content/creating-ai-enabled-systems/Research Project') == False:
  !git clone https://github.com/jwells52/creating-ai-enabled-systems.git

%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Imports

In [34]:
import torch
import json


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from easyfsl.methods import FewShotClassifier, RelationNetworks, MatchingNetworks, PrototypicalNetworks, SimpleShot, TransductiveFinetuning
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet152

from typing import Callable

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loader
from modules.train import train_fsl, device, transform
from modules.plotting import fsl_plots


### Load dataset

In [None]:
train_df = remove_new_whale_class(
  pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/training_10samples.csv')
)

test_df = remove_new_whale_class(
  pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/validation_10samples.csv')
)

In [None]:
print(f"Min # of samples for a class in data set = {train_df['class_count'].min()}")
print(f"Max # of samples for a class in data set = {train_df['class_count'].max()}")
print(f"# of classes in data set = {len(train_df['Id'].unique())}")


Min # of samples for a class in data set = 11
Max # of samples for a class in data set = 73
# of classes in data set = 181


In [None]:
print(f"Min # of samples for a class in data set = {test_df['class_count'].min()}")
print(f"Max # of samples for a class in data set = {test_df['class_count'].max()}")
print(f"# of classes in data set = {len(test_df['Id'].unique())}")


Min # of samples for a class in data set = 11
Max # of samples for a class in data set = 48
# of classes in data set = 46


In [None]:
train_dataset = HumpbackWhaleDataset('/content/train', train_df, transform=transform)

### Train network with 1, 3, and 5 shot learning

# Train on training set for 1-shot, 5-shot, 5-way, 20-way

# The following networks are trained
1.  Relation Networks
2. Matching Networks
3. Prototypical Networks w/ Euclidean Distance
4. Prototypical Networks w/ Cosine Similarity
5. Transductive Finetuning


In [None]:
from torchvision.models.feature_extraction import create_feature_extractor

class FeatureExtractor(torch.nn.Module):
  '''
  Class for extracting feature maps from model and return a tensor and not a dictionary.
  '''
  def __init__(self, model, layer_name):
    super().__init__()
    self.model = model
    self.layer_name = layer_name

  def forward(self, x):
    return self.model(x)[self.layer_name]

In [73]:
n_epochs = 10
learning_rate = 1e-2

# Number of Training Task for each epoch
# A training task is a random sample of N shots (images) for M classes
n_task_per_epoch = 100

n_ways = [5, 20]
n_shots = [1, 5]
n_query = 5

In [78]:
def train_network(
    network,
    train_dataset,
    n_ways, n_shots, n_query, n_tasks_per_epoch,
    checkpoint_path,
    n_workers=12,
    feature_maps=False,
    return_layer='layer4.1.bn2',
    learning_rate=1e-2,
    n_epochs=10
  ):
  losses = dict()

  for n_way in n_ways:
    losses[n_way] = dict()
    for n_shot in n_shots:
      train_loader = create_loader(train_dataset, n_way, n_shot, n_query, n_tasks_per_epoch, num_workers=n_workers)

      resnet = resnet18(weights='DEFAULT')
      resnet.fc = torch.nn.Flatten()

      if feature_maps:
        resnet_extractor = create_feature_extractor(resnet, return_nodes=[return_layer])
        feature_extractor = FeatureExtractor(resnet_extractor, return_layer).to(device)
        fsl_network = network(feature_extractor, feature_dimension=512).to(device)

      else:
        feature_extractor = resnet
        fsl_network = network(feature_extractor).to(device)

      loss_fn = torch.nn.CrossEntropyLoss()
      optimizer = Adam(fsl_network.parameters(), lr=learning_rate)
      # optimizer = SGD(fsl_network.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
      print(f'Training network under {n_way}-way {n_shot}-shot')
      train_losses, _ = train_fsl(
        relation_network,
        train_loader, None,
        optimizer, loss_fn, n_epochs=n_epochs,
        save_model=True, save_path=f'{checkpoint_path}_{n_way}-way_{n_shot}-shot'
      )

      losses[n_way][n_way] = train_losses

### Relation Network

In [None]:
print('Training Relation Network')
rn_losses = train_network(
    RelationNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/relation_network',
    feature_maps=True
)

Training Relation Network
Training network under 5-way 1-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:10<00:00,  9.78it/s, loss=1.14]


Epoch 2 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.97it/s, loss=1.17]


Epoch 3 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.21it/s, loss=1.17]


Epoch 4 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.25it/s, loss=1.18]


Epoch 5 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.11it/s, loss=1.18]


Epoch 6 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.05it/s, loss=1.19]


Epoch 7 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.93it/s, loss=1.18]


Epoch 8 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.06it/s, loss=1.17]


Epoch 9 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.01it/s, loss=1.18]


Epoch 10 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.30it/s, loss=1.18]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_1-shot_last_epoch
Training network under 5-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:13<00:00,  7.16it/s, loss=0.99]


Epoch 2 


Training: 100%|██████████| 100/100 [00:13<00:00,  7.20it/s, loss=0.994]


Epoch 3 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.94it/s, loss=1]


Epoch 4 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.07it/s, loss=1]


Epoch 5 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.12it/s, loss=1]


Epoch 6 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.09it/s, loss=1.01]


Epoch 7 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.14it/s, loss=0.985]


Epoch 8 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.10it/s, loss=0.99]


Epoch 9 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.94it/s, loss=1]


Epoch 10 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.08it/s, loss=1]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_5-shot_last_epoch
Training network under 20-way 1-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:37<00:00,  2.67it/s, loss=2.33]


Epoch 2 


Training:  94%|█████████▍| 94/100 [00:36<00:01,  4.22it/s, loss=2.32]

### Matching Network

In [None]:
print('Training Matching Network')
mn_losses = train_network(
    MatchingNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/matching_network',
    feature_maps=True
)

### Prototypical Network with Euclidean Distance

In [None]:
print('Training Prototypical Network')
pt_losses = train_network(
    PrototypicalNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/prototypical_network',
    feature_maps=False
)

### Prototypical Network with Cosine Similarity

In [None]:
print('Training Simple Shot Network')
ss_losses = train_network(
    SimpleShot,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/simple_shot_network',
    feature_maps=False
)

### Transductive Finetuning -- In Work

In [26]:
# Transductive Finetuning

# Train the model with classical training
train_loader_classical = DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=12,
    pin_memory=True,
    shuffle=True,
)

In [58]:
# from easyfsl.modules import resnet18

resnet = resnet18(weights='DEFAULT')
resnet.fc = torch.nn.Linear(in_features=512, out_features=len(set(train_dataset.get_labels())))

# model = resnet18(
#     use_fc=True,
#     num_classes=len(set(train_dataset.get_labels())),
# )

In [70]:
def training_epoch_classical(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer, loss_fn: Callable, device='cuda'):
    all_loss = []
    model_.train()
    with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
        for images, labels in tqdm_train:
            print(images.shape)
            print(labels.shape)
            optimizer.zero_grad()

            loss = loss_fn(model_(images.to(device)), labels.to(device))
            print(loss)
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

In [71]:
tf_network = TransductiveFinetuning(resnet).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = SGD(tf_network.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

In [72]:
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch_classical(
        tf_network,
        train_loader_classical,
        optimizer,
        loss_fn
    )
    # optimizer.step()

Epoch 0


Training:   0%|          | 0/112 [00:00<?, ?it/s]

torch.Size([32, 3, 256, 512])
torch.Size([32])


Training:   0%|          | 0/112 [00:02<?, ?it/s]


RuntimeError: ignored