<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment1_cv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install EasyFSL

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Identification dataset

In [2]:
# Mount Google Drive at /content/drive. Later cells read the Kaggle API token
# from it and write model checkpoints to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Recreate /root/.kaggle, copy the Kaggle API token from Drive into it with
# owner-only permissions (600), then download the competition archive.
# The steps are chained with && so a failure stops the remaining steps.
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification

Downloading humpback-whale-identification.zip to /content
100% 5.50G/5.51G [04:49<00:00, 17.1MB/s]
100% 5.51G/5.51G [04:49<00:00, 20.4MB/s]


In [4]:
%%capture

# Unpack the downloaded competition archive; the %%capture cell magic above
# (which has to stay on the first line of the cell) hides the command's output.
!unzip humpback-whale-identification.zip

### Clone GitHub repo

In [1]:
import os

if os.path.exists('/content/creating-ai-enabled-systems/Research Project') == False:
  !git clone https://github.com/jwells52/creating-ai-enabled-systems.git

%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Imports

In [2]:
import torch
import json


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from easyfsl.methods import FewShotClassifier, RelationNetworks, MatchingNetworks, PrototypicalNetworks, SimpleShot, TransductiveFinetuning
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet152

from typing import Callable

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loader
from modules.train import train_fsl, device, transform
from modules.plotting import fsl_plots


### Load dataset

In [3]:
# Read each 10-sample CSV split and pass it through remove_new_whale_class
# (imported from modules.data_utils) before it is used anywhere else.
train_csv = pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/training_10samples.csv')
train_df = remove_new_whale_class(train_csv)

test_csv = pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/validation_10samples.csv')
test_df = remove_new_whale_class(test_csv)

In [4]:
# Class-balance summary of the training frame: smallest / largest per-class
# sample count and the number of distinct whale Ids.
train_class_counts = train_df['class_count']
train_class_ids = train_df['Id'].unique()

print(f"Min # of samples for a class in data set = {train_class_counts.min()}")
print(f"Max # of samples for a class in data set = {train_class_counts.max()}")
print(f"# of classes in data set = {len(train_class_ids)}")


Min # of samples for a class in data set = 11
Max # of samples for a class in data set = 73
# of classes in data set = 181


In [11]:
# Same class-balance summary as for the training frame, computed on test_df.
test_class_counts = test_df['class_count']
test_class_ids = test_df['Id'].unique()

print(f"Min # of samples for a class in data set = {test_class_counts.min()}")
print(f"Max # of samples for a class in data set = {test_class_counts.max()}")
print(f"# of classes in data set = {len(test_class_ids)}")


Min # of samples for a class in data set = 11
Max # of samples for a class in data set = 48
# of classes in data set = 46


In [5]:
# Build the training dataset from train_df with the `transform` imported from
# modules.train. The first argument is presumably the image directory (the
# unzipped Kaggle `train` folder) -- confirm against HumpbackWhaleDataset.
train_dataset = HumpbackWhaleDataset('/content/train', train_df, transform=transform)

### Train networks with 1-shot and 5-shot learning

# Train on the training set for every combination of 5-way / 20-way tasks and 1-shot / 5-shot support sets

# The following networks are trained
1.  Relation Networks
2. Matching Networks
3. Prototypical Networks w/ Euclidean Distance
4. Prototypical Networks w/ Cosine Similarity
5. Transductive Finetuning


In [6]:
from torchvision.models.feature_extraction import create_feature_extractor

class FeatureExtractor(torch.nn.Module):
  '''
  Adapter that exposes one entry of a dict-returning model as a plain output.

  The wrapped `model` is expected to return a mapping when called; `forward`
  looks up `layer_name` in that mapping and returns the selected value
  directly, so callers receive a tensor instead of a dictionary.
  '''
  def __init__(self, model, layer_name):
    '''
    Args:
      model: module whose call result can be indexed by `layer_name`.
      layer_name: key selecting which output of `model` to return.
    '''
    super(FeatureExtractor, self).__init__()
    self.layer_name = layer_name
    self.model = model

  def forward(self, x):
    '''Run `model` on `x` and return only the `layer_name` entry of its output.'''
    outputs = self.model(x)
    return outputs[self.layer_name]

In [7]:
# Optimisation settings shared by the training runs below.
n_epochs = 10
learning_rate = 1e-2

# Number of training tasks drawn in each epoch.
# A training task is a random sample of N shots (images) for each of M classes.
n_task_per_epoch = 25

# Task configurations: the training cell below loops over every
# (n_way, n_shot) combination of these two lists.
n_ways = [5, 20]
n_shots = [1, 5]
# Passed to create_loader alongside n_way / n_shot; presumably the number of
# query images per class in a task -- confirm against create_loader.
n_query = 5

In [10]:
# Train Relation Network
# A fresh model is built and trained for every (n_way, n_shot) combination;
# the training losses returned by train_fsl are kept in
# rn_losses[n_way][n_shot].
# NOTE: the loop variables n_way / n_shot and the objects created in the last
# iteration stay defined in the kernel after this cell finishes.
rn_losses = dict()
for n_way in n_ways:
  rn_losses[n_way] = dict()
  for n_shot in n_shots:
    # Task loader over the training dataset for this configuration.
    # NOTE(review): the recorded output of this cell contains
    # "can only test a child process" errors raised while DataLoader workers
    # shut down; 12 workers may exceed the CPUs of the runtime -- consider
    # lowering num_workers and confirm.
    train_loader = create_loader(
        train_dataset,
        n_way,
        n_shot,
        n_query,
        n_task_per_epoch,
        num_workers=12
    )
    # ResNet-18 with torchvision's default pretrained weights; the final
    # fully-connected layer is replaced by a Flatten.
    pretrained_cnn = resnet18(weights='DEFAULT')
    pretrained_cnn.fc = torch.nn.Flatten()

    # Use the output of node "layer4.1.bn2" as the feature map.
    # create_feature_extractor returns a dict keyed by node name, so
    # FeatureExtractor unwraps the single requested entry.
    # NOTE(review): this node sits before `fc`, so the Flatten replacement
    # above presumably does not affect the extracted features -- confirm.
    layer = "layer4.1.bn2"
    resnet_extractor = create_feature_extractor(pretrained_cnn, return_nodes=[layer])
    feature_extractor = FeatureExtractor(resnet_extractor, layer)

    relation_network = RelationNetworks(feature_extractor, feature_dimension=512).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = SGD(relation_network.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    print(f'Training network under {n_way}-way {n_shot}-shot')
    # The second positional argument is None -- presumably the validation
    # loader, i.e. no validation is run here; confirm against train_fsl.
    # With save_model=True the recorded output shows a checkpoint being
    # written to save_path with a "_last_epoch" suffix.
    train_losses_rn, _ = train_fsl(
        relation_network,
        train_loader, None,
        optimizer, loss_fn, n_epochs=n_epochs,
        save_model=True, save_path=f'/content/drive/MyDrive/Research-Project/relation_network_{n_way}-way_{n_shot}-shot'
    )

    rn_losses[n_way][n_shot] = train_losses_rn
    print('='*100)

Training network under 5-way 1-shot

Epoch 1


Training: 100%|██████████| 25/25 [00:05<00:00,  4.55it/s, loss=1.61]



Epoch 2 

Training: 100%|██████████| 25/25 [00:03<00:00,  6.44it/s, loss=1.54]


Epoch 3 


Training:   0%|          | 0/25 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x78febb739a20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
Exception ignored in:     assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError<function _MultiProcessingDataLoaderIter.__del__ at 0x78febb739a20>
: can only test a child processTraceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdo


Epoch 4 


Training: 100%|██████████| 25/25 [00:03<00:00,  7.03it/s, loss=1.29]


Epoch 5 


Training: 100%|██████████| 25/25 [00:03<00:00,  6.71it/s, loss=1.31]


Epoch 6 


Training: 100%|██████████| 25/25 [00:03<00:00,  7.17it/s, loss=1.28]


Epoch 7 


Training: 100%|██████████| 25/25 [00:03<00:00,  7.13it/s, loss=1.27]



Epoch 8 

Training: 100%|██████████| 25/25 [00:03<00:00,  7.21it/s, loss=1.27]


Epoch 9 


Training: 100%|██████████| 25/25 [00:03<00:00,  7.48it/s, loss=1.28]


Epoch 10 


Training: 100%|██████████| 25/25 [00:03<00:00,  7.10it/s, loss=1.27]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_1-shot_last_epoch
Training network under 5-way 5-shot

Epoch 1


Training: 100%|██████████| 25/25 [00:04<00:00,  5.04it/s, loss=1.59]


Epoch 2 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.50it/s, loss=1.39]


Epoch 3 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.20it/s, loss=1.26]


Epoch 4 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.39it/s, loss=1.22]


Epoch 5 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.59it/s, loss=1.2]


Epoch 6 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.07it/s, loss=1.2]


Epoch 7 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.62it/s, loss=1.12]


Epoch 8 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.10it/s, loss=1.14]


Epoch 9 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.37it/s, loss=1.15]


Epoch 10 


Training: 100%|██████████| 25/25 [00:04<00:00,  5.55it/s, loss=1.1]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_5-shot_last_epoch
Training network under 20-way 1-shot

Epoch 1


Training: 100%|██████████| 25/25 [00:16<00:00,  1.50it/s, loss=2.99]


Epoch 2 


Training: 100%|██████████| 25/25 [00:13<00:00,  1.81it/s, loss=2.9]


Epoch 3 


Training: 100%|██████████| 25/25 [00:13<00:00,  1.85it/s, loss=2.67]


Epoch 4 


Training: 100%|██████████| 25/25 [00:14<00:00,  1.77it/s, loss=2.56]


Epoch 5 


Training: 100%|██████████| 25/25 [00:13<00:00,  1.81it/s, loss=2.53]


Epoch 6 


Training: 100%|██████████| 25/25 [00:13<00:00,  1.79it/s, loss=2.49]


Epoch 7 


Training: 100%|██████████| 25/25 [00:14<00:00,  1.78it/s, loss=2.47]


Epoch 8 


Training: 100%|██████████| 25/25 [00:13<00:00,  1.79it/s, loss=2.42]


Epoch 9 


Training: 100%|██████████| 25/25 [00:14<00:00,  1.78it/s, loss=2.41]


Epoch 10 


Training: 100%|██████████| 25/25 [00:13<00:00,  1.84it/s, loss=2.38]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_20-way_1-shot_last_epoch
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 25/25 [00:21<00:00,  1.16it/s, loss=2.99]


Epoch 2 


Training: 100%|██████████| 25/25 [00:21<00:00,  1.17it/s, loss=2.9]


Epoch 3 


Training: 100%|██████████| 25/25 [00:22<00:00,  1.13it/s, loss=2.57]


Epoch 4 


Training: 100%|██████████| 25/25 [00:21<00:00,  1.15it/s, loss=2.39]


Epoch 5 


Training: 100%|██████████| 25/25 [00:21<00:00,  1.15it/s, loss=2.34]


Epoch 6 


Training: 100%|██████████| 25/25 [00:22<00:00,  1.13it/s, loss=2.28]


Epoch 7 


Training: 100%|██████████| 25/25 [00:21<00:00,  1.15it/s, loss=2.24]


Epoch 8 


Training: 100%|██████████| 25/25 [00:20<00:00,  1.19it/s, loss=2.21]


Epoch 9 


Training: 100%|██████████| 25/25 [00:21<00:00,  1.16it/s, loss=2.19]


Epoch 10 


Training: 100%|██████████| 25/25 [00:21<00:00,  1.16it/s, loss=2.17]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_20-way_5-shot_last_epoch


In [None]:
# K-fold cross-validation of a Prototypical Network.
#
# Each fold holds out one group of whale Ids for validation and trains on the
# Ids of all remaining groups. The per-fold mean/std of the training loss and
# of the validation accuracy are collected and summarised after the last fold.
#
# NOTE(review): this cell reads names that are not defined or imported in the
# visible part of the notebook -- `num_folds`, `folds`, `images_and_ids`,
# `n_validation_tasks`, `create_loaders`, `resnet12` and `training_epoch` --
# and it uses `n_way` / `n_shot` as left behind by the loops of the Relation
# Network cell. They must be defined before this cell can run on a fresh kernel.
train_losses_mean, train_losses_std = [], []
valid_accs_mean, valid_accs_std     = [], []

for fold in range(num_folds):
  torch.cuda.empty_cache()
  print(f'Fold {fold+1}')

  # Creating training/validation folds.
  # The held-out group is selected with the outer loop index `fold`
  # (previously `folds[i]`, which read `i` before it was assigned and
  # afterwards always pointed at a stale index).
  validation_fold = folds[fold]
  training_folds = [
      whale_id
      for i in range(num_folds) if i != fold
      for whale_id in folds[i]
  ]

  train_df = images_and_ids[images_and_ids['Id'].isin(training_folds)]
  valid_df = images_and_ids[images_and_ids['Id'].isin(validation_fold)]

  # Creating training/validation PyTorch datasets
  train_set = HumpbackWhaleDataset(
      image_dir='/content/train',
      labels=train_df,
      transform=transform
  )

  valid_set = HumpbackWhaleDataset(
      image_dir='/content/train',
      labels=valid_df,
      transform=transform
  )

  # Create DataLoaders
  train_loader, valid_loader = create_loaders(
      train_set,
      valid_set,
      n_way,
      n_shot,
      n_query,
      n_task_per_epoch,
      n_validation_tasks,
      num_workers=12
  )

  # Train and Validate.
  # The backbone is moved to the device together with the classifier that
  # wraps it, so a separate `.to(device)` on the backbone is not needed.
  cnn = resnet12()
  cnn.fc = torch.nn.Flatten()
  few_shot_classifier = PrototypicalNetworks(cnn).to(device)

  loss_fn = torch.nn.CrossEntropyLoss()
  optimizer = SGD(few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

  train_losses = []
  valid_accs = []
  for epoch in tqdm(range(n_epochs)):
    epoch_loss = training_epoch(few_shot_classifier, train_loader, optimizer, loss_fn)
    train_losses.append(epoch_loss)

    valid_acc = evaluate(
        few_shot_classifier, valid_loader, device=device, use_tqdm=False
    )
    valid_accs.append(valid_acc)

    # NOTE(review): this steps the optimizer once more after validation. If
    # `training_epoch` already steps the optimizer for every task, this
    # re-applies the last gradient and was presumably meant to be a
    # learning-rate scheduler step -- confirm and remove if so.
    optimizer.step()

  train_losses_mean.append(np.mean(train_losses))
  train_losses_std.append(np.std(train_losses))

  valid_accs_mean.append(np.mean(valid_accs))
  valid_accs_std.append(np.std(valid_accs))

  print(f'Average fold validation accuracy = {np.mean(valid_accs)} {chr(177)}{np.std(valid_accs)}')


# Overall result: mean of the per-fold mean accuracies, with the spread taken
# across those per-fold means (not the standard deviation of the per-fold
# standard deviations).
print(f'{n_way}-way {n_shot}-shot average validation accuracy = {np.mean(valid_accs_mean)} {chr(177)}{np.std(valid_accs_mean)}')

Fold 1


  0%|          | 0/100 [00:18<?, ?it/s]


OutOfMemoryError: ignored