<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment1_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install EasyFSL

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Identification dataset

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification

Downloading humpback-whale-identification.zip to /content
100% 5.50G/5.51G [04:16<00:00, 20.6MB/s]
100% 5.51G/5.51G [04:16<00:00, 23.1MB/s]


In [4]:
%%capture

!unzip humpback-whale-identification.zip

### Clone GitHub repo

In [5]:
import os

if os.path.exists('/content/creating-ai-enabled-systems/Research Project') == False:
  !git clone https://github.com/jwells52/creating-ai-enabled-systems.git

%cd creating-ai-enabled-systems/Research\ Project

Cloning into 'creating-ai-enabled-systems'...
remote: Enumerating objects: 473, done.[K
remote: Counting objects: 100% (286/286), done.[K
remote: Compressing objects: 100% (127/127), done.[K
remote: Total 473 (delta 186), reused 238 (delta 157), pack-reused 187[K
Receiving objects: 100% (473/473), 160.26 MiB | 16.20 MiB/s, done.
Resolving deltas: 100% (265/265), done.
Updating files: 100% (52/52), done.
/content/creating-ai-enabled-systems/Research Project


### Imports

In [7]:
import torch
import json


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from easyfsl.methods import FewShotClassifier, RelationNetworks, MatchingNetworks, PrototypicalNetworks, SimpleShot, TransductiveFinetuning
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet152

from typing import Callable

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loader, transform
from modules.train import train_fsl, device
from modules.plotting import fsl_plots


### Load dataset

In [8]:
train_df = remove_new_whale_class(
  pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/training_10samples.csv')
)

In [9]:
print(f"Min # of samples for a class in data set = {train_df['class_count'].min()}")
print(f"Max # of samples for a class in data set = {train_df['class_count'].max()}")
print(f"# of classes in data set = {len(train_df['Id'].unique())}")


Min # of samples for a class in data set = 11
Max # of samples for a class in data set = 73
# of classes in data set = 181


In [10]:
train_dataset = HumpbackWhaleDataset('/content/train', train_df, transform=transform)

# Train on training set for 1-shot, 5-shot, 5-way, 20-way

# The following networks are trained
1.  Relation Networks
2. Matching Networks
3. Prototypical Networks w/ Euclidean Distance
4. Prototypical Networks w/ Cosine Similarity
5. Transductive Finetuning


In [27]:
from torchvision.models.feature_extraction import create_feature_extractor

class FeatureExtractor(torch.nn.Module):
  '''
  Class for extracting feature maps from model and return a tensor and not a dictionary.
  '''
  def __init__(self, model, layer_name):
    super().__init__()
    self.model = model
    self.layer_name = layer_name

  def forward(self, x):
    return self.model(x)[self.layer_name]

In [13]:
n_epochs = 10
learning_rate = 1e-2

# Number of Training Task for each epoch
# A training task is a random sample of N shots (images) for M classes
n_task_per_epoch = 100

n_ways = [5, 20]
n_shots = [1, 5]
n_query = 5

In [30]:
cnn1 = resnet18(weights='DEFAULT')
resnet_extractor = create_feature_extractor(cnn1, return_nodes=['layer4.1.bn2'])
feature_extractor = FeatureExtractor(resnet_extractor, 'layer4.1.bn2')
fsl_network = MatchingNetworks(feature_extractor, feature_dimension=512).to(device)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7bafb8b21e10>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


In [44]:
def train_network(
    network,
    train_dataset,
    n_ways, n_shots, n_query, n_tasks_per_epoch,
    checkpoint_path,
    n_workers=12,
    feature_extractor=True,
    feature_dimension=None,
    return_layer='layer4.1.bn2',
    learning_rate=1e-2,
    n_epochs=10
  ):

  losses = dict()
  for n_way in n_ways:
    losses[n_way] = dict()
    for n_shot in n_shots:
      train_loader = create_loader(train_dataset, n_way, n_shot, n_query, n_tasks_per_epoch, num_workers=n_workers)

      resnet = resnet18(weights='DEFAULT')
      resnet.fc = torch.nn.Flatten()

      if feature_extractor:
        resnet_extractor = create_feature_extractor(resnet, return_nodes=[return_layer])
        feature_extractor = FeatureExtractor(resnet_extractor, return_layer)
      else:
        feature_extractor = resnet

      if feature_dimension is not None:
        fsl_network = network(feature_extractor, feature_dimension=feature_dimension)
      else:
        fsl_network = network(feature_extractor)

      fsl_network = fsl_network.to(device)
      loss_fn = torch.nn.CrossEntropyLoss()
      optimizer = SGD(fsl_network.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

      print(f'Training network under {n_way}-way {n_shot}-shot')
      train_losses, _ = train_fsl(
        fsl_network,
        train_loader, None,
        optimizer, loss_fn, n_epochs=n_epochs,
        save_model=True, save_path=f'{checkpoint_path}_{n_way}-way_{n_shot}-shot'
      )

      losses[n_way][n_way] = train_losses

### Relation Network

In [21]:
print('Training Relation Network')
rn_losses = train_network(
    RelationNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/relation_network',
    feature_extractor=True,
    feature_dimension=512
)

Training Relation Network
Training network under 5-way 1-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:10<00:00,  9.77it/s, loss=1.47]


Epoch 2 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.79it/s, loss=1.3]


Epoch 3 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.96it/s, loss=1.28]



Epoch 4 

Training: 100%|██████████| 100/100 [00:09<00:00, 10.54it/s, loss=1.27]



Epoch 5 

Training: 100%|██████████| 100/100 [00:09<00:00, 10.04it/s, loss=1.23]



Epoch 6 

Training: 100%|██████████| 100/100 [00:10<00:00,  9.84it/s, loss=1.23]


Epoch 7 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.94it/s, loss=1.17]


Epoch 8 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.19it/s, loss=1.15]


Epoch 9 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.20it/s, loss=1.17]



Epoch 10 

Training: 100%|██████████| 100/100 [00:09<00:00, 10.07it/s, loss=1.12]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_1-shot_last_epoch
Training network under 5-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:14<00:00,  6.97it/s, loss=1.46]


Epoch 2 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.99it/s, loss=1.31]


Epoch 3 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.00it/s, loss=1.22]


Epoch 4 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.93it/s, loss=1.19]


Epoch 5 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.92it/s, loss=1.12]


Epoch 6 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.98it/s, loss=1.09]


Epoch 7 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.03it/s, loss=1.07]


Epoch 8 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.02it/s, loss=1.05]


Epoch 9 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.97it/s, loss=1.04]


Epoch 10 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.07it/s, loss=1.02]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_5-shot_last_epoch
Training network under 20-way 1-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:38<00:00,  2.61it/s, loss=2.77]


Epoch 2 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.65it/s, loss=2.48]


Epoch 3 


Training: 100%|██████████| 100/100 [00:38<00:00,  2.61it/s, loss=2.37]


Epoch 4 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=2.31]


Epoch 5 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.63it/s, loss=2.26]


Epoch 6 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.64it/s, loss=2.23]


Epoch 7 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.65it/s, loss=2.2]


Epoch 8 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=2.17]


Epoch 9 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.65it/s, loss=2.16]


Epoch 10 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=2.15]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_20-way_1-shot_last_epoch
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=2.7]


Epoch 2 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.29]


Epoch 3 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.2]


Epoch 4 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.15]


Epoch 5 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.13]


Epoch 6 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.11]


Epoch 7 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.11]


Epoch 8 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.1]


Epoch 9 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.1]


Epoch 10 


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=2.1]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_20-way_5-shot_last_epoch


### Matching Network

In [42]:
print('Training Matching Network')
mn_losses = train_network(
    MatchingNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/matching_network',
    feature_extractor=False,
    feature_dimension=512
)

Training Matching Network
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.584]


Epoch 2 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=0.318]


Epoch 3 


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=0.185]


Epoch 4 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=0.0608]


Epoch 5 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=0.0378]


Epoch 6 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.0225]


Epoch 7 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.0174]


Epoch 8 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.70it/s, loss=0.012]


Epoch 9 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=0.00975]


Epoch 10 


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=0.00905]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/matching_network_20-way_5-shot_last_epoch


### Prototypical Network with Euclidean Distance

In [47]:
print('Training Prototypical Network')
pt_losses = train_network(
    PrototypicalNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/prototypical_network',
    feature_extractor=False,
    feature_dimension=None
)

Training Prototypical Network
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.397]



Epoch 2 

Training: 100%|██████████| 100/100 [01:01<00:00,  1.62it/s, loss=0.0615]


Epoch 3 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.0278]


Epoch 4 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.70it/s, loss=0.0162]


Epoch 5 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=0.00995]


Epoch 6 


Training: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s, loss=0.00837]


Epoch 7 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.00679]


Epoch 8 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.00591]


Epoch 9 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.00458]



Epoch 10 

Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=0.00404]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/prototypical_network_20-way_5-shot_last_epoch


### Prototypical Network with Cosine Similarity

In [51]:
print('Training Simple Shot Network')
ss_losses = train_network(
    SimpleShot,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/simple_shot_network',
    feature_extractor=False,
    feature_dimension=None
)

Training Simple Shot Network
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s, loss=2.84]


Epoch 2 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=2.61]


Epoch 3 


Training: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s, loss=2.52]


Epoch 4 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=2.46]


Epoch 5 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=2.42]


Epoch 6 


Training: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s, loss=2.38]


Epoch 7 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.70it/s, loss=2.37]


Epoch 8 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=2.34]


Epoch 9 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=2.33]


Epoch 10 


Training: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s, loss=2.32]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/simple_shot_network_20-way_5-shot_last_epoch


### Transductive Finetuning -- In Work

In [None]:
# Transductive Finetuning

# Train the model with classical training
train_loader_classical = DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=12,
    pin_memory=True,
    shuffle=True,
)

In [None]:
# from easyfsl.modules import resnet18

resnet = resnet18(weights='DEFAULT')
resnet.fc = torch.nn.Linear(in_features=512, out_features=len(set(train_dataset.get_labels())))

# model = resnet18(
#     use_fc=True,
#     num_classes=len(set(train_dataset.get_labels())),
# )

In [None]:
def training_epoch_classical(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer, loss_fn: Callable, device='cuda'):
    all_loss = []
    model_.train()
    with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
        for images, labels in tqdm_train:
            print(images.shape)
            print(labels.shape)
            optimizer.zero_grad()

            loss = loss_fn(model_(images.to(device)), labels.to(device))
            print(loss)
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

In [None]:
tf_network = TransductiveFinetuning(resnet).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = SGD(tf_network.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

In [None]:
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch_classical(
        tf_network,
        train_loader_classical,
        optimizer,
        loss_fn
    )
    # optimizer.step()

Epoch 0


Training:   0%|          | 0/112 [00:00<?, ?it/s]

torch.Size([32, 3, 256, 512])
torch.Size([32])


Training:   0%|          | 0/112 [00:02<?, ?it/s]


RuntimeError: ignored