<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install EasyFSL

In [None]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/65.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m61.4/65.2 kB[0m [31m1.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Identification dataset

In [None]:
from google.colab import drive

# Mount Google Drive; the next cell copies the Kaggle API token (kaggle.json) from it.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification
!unzip humpback-whale-identification.zip

### Clone GitHub repo

In [None]:
# !git clone https://github.com/jwells52/creating-ai-enabled-systems.git
%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Imports

In [None]:
import os
from typing import Callable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from tqdm import tqdm

from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, SimpleShot
from easyfsl.modules import resnet12
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet152

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loaders
from modules.train import train_fsl, device, transform
from modules.plotting import fsl_plots


### Load dataset

In [None]:
# Load the pre-split training and validation CSVs, passing each through remove_new_whale_class
# (modules.data_utils) — presumably drops rows with the catch-all "new_whale" Id; TODO confirm.
train_df, valid_df = (
    remove_new_whale_class(pd.read_csv(split_csv))
    for split_csv in (
        '/content/creating-ai-enabled-systems/Research Project/data/training_10samples.csv',
        '/content/creating-ai-enabled-systems/Research Project/data/validation_10samples.csv',
    )
)


In [None]:
# Per-class sample-count range and number of distinct whale Ids in the training split.
train_class_counts = train_df['class_count']
n_train_classes = len(train_df['Id'].unique())
print(f"Min # of samples for a class in training set = {train_class_counts.min()}")
print(f"Max # of samples for a class in training set = {train_class_counts.max()}")
print(f"# of classes in training set = {n_train_classes}")


Min # of samples for a class in training set = 11
Max # of samples for a class in training set = 73
# of classes in training set = 181


In [None]:
# Per-class sample-count range and number of distinct whale Ids in the validation split.
valid_class_counts = valid_df['class_count']
n_valid_classes = len(valid_df['Id'].unique())
print(f"Min # of samples for a class in validation set = {valid_class_counts.min()}")
print(f"Max # of samples for a class in validation set = {valid_class_counts.max()}")
print(f"# of classes in validation set = {n_valid_classes}")

Min # of samples for a class in validation set = 11
Max # of samples for a class in validation set = 48
# of classes in validation set = 46


### Set up PyTorch Dataset

In [None]:
# Wrap both splits as PyTorch datasets. Both read images from the extracted Kaggle
# /content/train folder; only the label table differs between them.
train_set, valid_set = (
    HumpbackWhaleDataset(image_dir='/content/train', labels=split_labels, transform=transform)
    for split_labels in (train_df, valid_df)
)

### Train Prototypical Network with ResNet18, ResNet34, and ResNet152 as the feature extractor

In [None]:
# Compare ResNet18, ResNet34 and ResNet152 (ImageNet-pretrained, loaded via torch.hub) as the
# feature extractor of a Prototypical Network, using the same episodic setup for each backbone.

# Define learning rate and epochs
n_epochs = 25
learning_rate = 1e-2

# Number of training tasks per epoch.
# A training task is a random sample of n_shot support + n_query query images for each of n_way classes.
n_task_per_epoch = 100

# Number of validation tasks for evaluation during training
n_validation_tasks = 45

n_way   = 5
n_shot  = 5  # named `n_shot` because that is the name passed to create_loaders below
n_query = 5

backbones = ['resnet18', 'resnet34', 'resnet152']

# Per-backbone learning curves (lists returned by train_fsl), keyed by backbone name.
train_loss_dict = dict()
valid_acc_dict  = dict()
for backbone in backbones:
  print(f'Training and Evaluating {backbone} as feature extractor')
  cnn = torch.hub.load('pytorch/vision:v0.10.0', backbone, pretrained=True)
  # Replace the ImageNet classification head so the backbone outputs feature vectors.
  cnn.fc = torch.nn.Flatten()
  cnn = cnn.to(device)

  few_shot_classifier = PrototypicalNetworks(cnn).to(device)
  loss_fn = torch.nn.CrossEntropyLoss()
  optimizer = SGD(few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

  train_loader, valid_loader = create_loaders(
      train_set, valid_set,
      n_way,
      n_shot,
      n_query,
      n_task_per_epoch,
      n_validation_tasks
  )

  train_losses, valid_accs = train_fsl(
      few_shot_classifier,
      train_loader,
      valid_loader,
      optimizer,
      loss_fn,
      n_epochs=n_epochs,
      use_tqdm=True,
      save_model=False
  )

  train_loss_dict[backbone] = train_losses
  valid_acc_dict[backbone]  = valid_accs

Training and Evaluating 1-shot learning
Epoch 1


Training: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s, loss=1.86]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.34it/s, accuracy=0.287]

Epoch 2



Training: 100%|██████████| 100/100 [00:29<00:00,  3.37it/s, loss=1.64]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s, accuracy=0.335]

Epoch 3



Training: 100%|██████████| 100/100 [00:29<00:00,  3.38it/s, loss=1.52]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s, accuracy=0.342]

Epoch 4



Training: 100%|██████████| 100/100 [00:29<00:00,  3.36it/s, loss=1.5]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.27it/s, accuracy=0.388]

Epoch 5



Training: 100%|██████████| 100/100 [00:30<00:00,  3.30it/s, loss=1.44]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s, accuracy=0.376]

Epoch 6



Training: 100%|██████████| 100/100 [00:29<00:00,  3.42it/s, loss=1.36]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.34it/s, accuracy=0.504]

Epoch 7



Training: 100%|██████████| 100/100 [00:29<00:00,  3.42it/s, loss=1.31]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.34it/s, accuracy=0.414]

Epoch 8



Training: 100%|██████████| 100/100 [00:29<00:00,  3.38it/s, loss=1.21]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s, accuracy=0.509]

Epoch 9



Training: 100%|██████████| 100/100 [00:29<00:00,  3.39it/s, loss=1.16]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s, accuracy=0.518]

Epoch 10



Training: 100%|██████████| 100/100 [00:29<00:00,  3.35it/s, loss=1.17]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.31it/s, accuracy=0.515]

Epoch 11



Training: 100%|██████████| 100/100 [00:29<00:00,  3.38it/s, loss=1.08]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s, accuracy=0.519]

Epoch 12



Training: 100%|██████████| 100/100 [00:29<00:00,  3.35it/s, loss=1.09]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.32it/s, accuracy=0.555]

Epoch 13



Training: 100%|██████████| 100/100 [00:29<00:00,  3.40it/s, loss=1.01]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.38it/s, accuracy=0.617]

Epoch 14



Training: 100%|██████████| 100/100 [00:29<00:00,  3.40it/s, loss=0.964]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.28it/s, accuracy=0.562]

Epoch 15



Training: 100%|██████████| 100/100 [00:29<00:00,  3.43it/s, loss=0.922]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.27it/s, accuracy=0.574]

Epoch 16



Training: 100%|██████████| 100/100 [00:29<00:00,  3.36it/s, loss=0.781]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.28it/s, accuracy=0.603]

Epoch 17



Training: 100%|██████████| 100/100 [00:29<00:00,  3.34it/s, loss=0.83]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.33it/s, accuracy=0.577]

Epoch 18



Training: 100%|██████████| 100/100 [00:29<00:00,  3.34it/s, loss=0.844]
Validation: 100%|██████████| 45/45 [00:14<00:00,  3.20it/s, accuracy=0.628]

Epoch 19



Training: 100%|██████████| 100/100 [00:29<00:00,  3.41it/s, loss=0.772]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.29it/s, accuracy=0.62]

Epoch 20



Training: 100%|██████████| 100/100 [00:29<00:00,  3.35it/s, loss=0.774]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.45it/s, accuracy=0.676]

Epoch 21



Training: 100%|██████████| 100/100 [00:29<00:00,  3.35it/s, loss=0.704]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.27it/s, accuracy=0.627]

Epoch 22



Training: 100%|██████████| 100/100 [00:29<00:00,  3.42it/s, loss=0.694]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.27it/s, accuracy=0.626]

Epoch 23



Training: 100%|██████████| 100/100 [00:29<00:00,  3.38it/s, loss=0.629]
Validation: 100%|██████████| 45/45 [00:13<00:00,  3.35it/s, accuracy=0.684]

Epoch 24



Training: 100%|██████████| 100/100 [00:29<00:00,  3.37it/s, loss=0.651]
Validation: 100%|██████████| 45/45 [00:14<00:00,  3.21it/s, accuracy=0.65]

Epoch 25



Training: 100%|██████████| 100/100 [00:29<00:00,  3.36it/s, loss=0.581]
Validation: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s, accuracy=0.66]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/prototypical_network_resnet12_last_epoch
Training and Evaluating 3-shot learning
Epoch 1


Training: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s, loss=1.69]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.45it/s, accuracy=0.359]

Epoch 2



Training: 100%|██████████| 100/100 [00:40<00:00,  2.49it/s, loss=1.49]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.41it/s, accuracy=0.434]

Epoch 3



Training: 100%|██████████| 100/100 [00:39<00:00,  2.53it/s, loss=1.38]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.42it/s, accuracy=0.463]

Epoch 4



Training: 100%|██████████| 100/100 [00:40<00:00,  2.49it/s, loss=1.22]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.39it/s, accuracy=0.482]

Epoch 5



Training: 100%|██████████| 100/100 [00:39<00:00,  2.50it/s, loss=1.18]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.49it/s, accuracy=0.515]

Epoch 6



Training: 100%|██████████| 100/100 [00:39<00:00,  2.53it/s, loss=1.08]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.48it/s, accuracy=0.588]

Epoch 7



Training: 100%|██████████| 100/100 [00:39<00:00,  2.53it/s, loss=1.02]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.42it/s, accuracy=0.556]

Epoch 8



Training: 100%|██████████| 100/100 [00:39<00:00,  2.53it/s, loss=0.976]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.41it/s, accuracy=0.547]

Epoch 9



Training: 100%|██████████| 100/100 [00:39<00:00,  2.50it/s, loss=0.883]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.48it/s, accuracy=0.63]

Epoch 10



Training: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s, loss=0.827]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.46it/s, accuracy=0.672]

Epoch 11



Training: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s, loss=0.762]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.46it/s, accuracy=0.662]

Epoch 12



Training: 100%|██████████| 100/100 [00:40<00:00,  2.48it/s, loss=0.765]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.46it/s, accuracy=0.678]

Epoch 13



Training: 100%|██████████| 100/100 [00:40<00:00,  2.50it/s, loss=0.726]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.43it/s, accuracy=0.691]

Epoch 14



Training: 100%|██████████| 100/100 [00:40<00:00,  2.49it/s, loss=0.591]
Validation: 100%|██████████| 45/45 [00:17<00:00,  2.52it/s, accuracy=0.715]

Epoch 15



Training: 100%|██████████| 100/100 [00:39<00:00,  2.55it/s, loss=0.629]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.45it/s, accuracy=0.706]

Epoch 16



Training: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s, loss=0.518]
Validation: 100%|██████████| 45/45 [00:17<00:00,  2.52it/s, accuracy=0.735]

Epoch 17



Training: 100%|██████████| 100/100 [00:40<00:00,  2.49it/s, loss=0.488]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.46it/s, accuracy=0.725]

Epoch 18



Training: 100%|██████████| 100/100 [00:39<00:00,  2.52it/s, loss=0.453]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.45it/s, accuracy=0.768]

Epoch 19



Training: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s, loss=0.426]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.37it/s, accuracy=0.751]

Epoch 20



Training: 100%|██████████| 100/100 [00:39<00:00,  2.53it/s, loss=0.406]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.44it/s, accuracy=0.754]

Epoch 21



Training: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s, loss=0.374]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.44it/s, accuracy=0.784]

Epoch 22



Training: 100%|██████████| 100/100 [00:39<00:00,  2.55it/s, loss=0.406]
Validation: 100%|██████████| 45/45 [00:17<00:00,  2.51it/s, accuracy=0.764]

Epoch 23



Training: 100%|██████████| 100/100 [00:39<00:00,  2.55it/s, loss=0.305]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.45it/s, accuracy=0.762]

Epoch 24



Training: 100%|██████████| 100/100 [00:40<00:00,  2.50it/s, loss=0.274]
Validation: 100%|██████████| 45/45 [00:17<00:00,  2.52it/s, accuracy=0.784]

Epoch 25



Training: 100%|██████████| 100/100 [00:39<00:00,  2.55it/s, loss=0.304]
Validation: 100%|██████████| 45/45 [00:18<00:00,  2.49it/s, accuracy=0.799]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/prototypical_network_resnet12_last_epoch
Training and Evaluating 5-shot learning
Epoch 1


Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=1.63]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.97it/s, accuracy=0.365]

Epoch 2



Training: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s, loss=1.46]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.92it/s, accuracy=0.43]

Epoch 3



Training: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s, loss=1.31]
Validation: 100%|██████████| 45/45 [00:22<00:00,  2.00it/s, accuracy=0.496]

Epoch 4



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=1.19]
Validation: 100%|██████████| 45/45 [00:22<00:00,  2.00it/s, accuracy=0.522]

Epoch 5



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=1.11]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.97it/s, accuracy=0.569]

Epoch 6



Training: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s, loss=0.998]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.99it/s, accuracy=0.6]

Epoch 7



Training: 100%|██████████| 100/100 [00:49<00:00,  2.00it/s, loss=0.921]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.95it/s, accuracy=0.633]

Epoch 8



Training: 100%|██████████| 100/100 [00:50<00:00,  2.00it/s, loss=0.772]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.96it/s, accuracy=0.648]

Epoch 9



Training: 100%|██████████| 100/100 [00:49<00:00,  2.00it/s, loss=0.798]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.92it/s, accuracy=0.704]

Epoch 10



Training: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s, loss=0.765]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.93it/s, accuracy=0.667]

Epoch 11



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=0.649]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.94it/s, accuracy=0.684]

Epoch 12



Training: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s, loss=0.602]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.94it/s, accuracy=0.746]

Epoch 13



Training: 100%|██████████| 100/100 [00:50<00:00,  1.98it/s, loss=0.56]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.88it/s, accuracy=0.732]

Epoch 14



Training: 100%|██████████| 100/100 [00:49<00:00,  2.00it/s, loss=0.497]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.92it/s, accuracy=0.749]

Epoch 15



Training: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s, loss=0.46]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.91it/s, accuracy=0.767]

Epoch 16



Training: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s, loss=0.472]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.95it/s, accuracy=0.779]

Epoch 17



Training: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s, loss=0.383]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.98it/s, accuracy=0.74]

Epoch 18



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=0.374]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.92it/s, accuracy=0.78]

Epoch 19



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=0.343]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.93it/s, accuracy=0.781]

Epoch 20



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=0.331]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.98it/s, accuracy=0.806]

Epoch 21



Training: 100%|██████████| 100/100 [00:50<00:00,  2.00it/s, loss=0.32]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.96it/s, accuracy=0.815]

Epoch 22



Training: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s, loss=0.314]
Validation: 100%|██████████| 45/45 [00:22<00:00,  1.99it/s, accuracy=0.778]

Epoch 23



Training: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s, loss=0.299]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.94it/s, accuracy=0.825]

Epoch 24



Training: 100%|██████████| 100/100 [00:49<00:00,  2.00it/s, loss=0.298]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.95it/s, accuracy=0.828]

Epoch 25



Training: 100%|██████████| 100/100 [00:49<00:00,  2.02it/s, loss=0.228]
Validation: 100%|██████████| 45/45 [00:23<00:00,  1.95it/s, accuracy=0.852]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/prototypical_network_resnet12_last_epoch


In [None]:
# For each backbone: training loss (left axis) and validation accuracy (right axis) per epoch,
# saved as a PNG under /content.
for backbone in backbones:
  fig, ax = plt.subplots()
  ax.set_title(f'Prototypical Network with {backbone} Learning Curves')

  # Per-backbone histories are stored in the dicts filled by the training cell;
  # `train_losses` / `valid_accs` only hold the last backbone's lists.
  ax.plot(train_loss_dict[backbone], color='blue')
  ax.set_xlabel("epochs")
  ax.set_ylabel('Training loss')

  ax2 = ax.twinx()
  ax2.plot(valid_acc_dict[backbone], color='orange')
  ax2.set_ylabel('Validation accuracy')

  # Save before showing; name the file after the actual number of training epochs.
  fig.savefig(f'/content/prototypical_network_{backbone}_learning_curves_{n_epochs}epochs.png', format='png')
  plt.show()
  # Close explicitly: figures created in a loop otherwise stay alive in pyplot's registry.
  plt.close(fig)

# 5-fold cross-validation over the following N-shot, 10-way settings

* 1-shot 10-way
* 5-shot 10-way
* 10-shot 10-way

In [None]:
# Episodic sampling for cross-validation: each epoch draws n_task_per_epoch training tasks,
# and n_validation_tasks tasks are evaluated for validation.
# A task samples n_shot support + n_query query images for each of n_way classes.
n_task_per_epoch, n_validation_tasks = 40, 20

# 10-way 1-shot tasks with 10 query images per class (11 images per class per task).
n_way, n_shot, n_query = 10, 1, 10

In [None]:
# Optimisation settings shared by every cross-validation fold; the loss function and SGD
# optimizer themselves are created fresh for each fold inside the cross-validation loop.
n_epochs = 100
learning_rate = 1e-2

In [None]:
# Load the image -> whale Id table (with a per-Id `class_count` column) and keep only whales with
# more than 10 images, i.e. at least n_shot + n_query = 11 for the 1-shot, 10-query tasks above.
images_and_ids = (
    pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/images_and_ids.csv')
    .loc[lambda frame: frame['class_count'] > 10]
)
images_and_ids

Unnamed: 0,Image,Id,class_count
0,0000e88ab.jpg,w_f48451c,14
3,000a6daec.jpg,w_dd88965,16
6,001cae55b.jpg,w_581ba42,14
11,004e8ad5b.jpg,w_3de579a,54
12,004f87702.jpg,w_1d0830e,11
...,...,...,...
15688,ffca5cb22.jpg,w_51e7506,15
15689,ffcd5efdc.jpg,w_f765256,34
15691,ffe52d320.jpg,w_bc285a6,21
15693,ffef89eed.jpg,w_9c506f6,62


In [None]:
# Sorted array of distinct whale Ids remaining after filtering; the CV folds are split over these.
whale_ids = images_and_ids['Id'].pipe(np.unique)
whale_ids.size

227

In [None]:
# Partition the sorted whale Ids into num_folds contiguous, class-disjoint folds of
# len(whale_ids) // num_folds Ids each; the last fold also absorbs any remainder.
num_folds = 5
fold_size = len(whale_ids) // num_folds
folds = []
for i in range(num_folds):
  start_fold = fold_size * i
  end_fold = len(whale_ids) if i == num_folds - 1 else fold_size * (i + 1)
  fold = whale_ids[start_fold:end_fold]
  folds.append(fold)

In [None]:
# Preprocessing for the cross-validation runs. This rebinds the `transform` imported from
# modules.train, so datasets built after this cell use it: replicate grayscale into 3 channels,
# resize to 256 x 512 (height x width), then convert to a tensor.
_cv_preprocessing_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 512)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_cv_preprocessing_steps)

In [None]:
# 5-fold cross-validation over whale identities: the whales in each validation fold are never
# seen during that fold's training, so validation measures few-shot performance on new classes.
train_losses_mean, train_losses_std = [], []
valid_accs_mean, valid_accs_std     = [], []
# Full per-fold learning curves (one list per fold), kept for later inspection/plotting.
fold_train_losses, fold_valid_accs = [], []

# Don't request more DataLoader workers than there are CPU cores (Colab VMs have few);
# oversubscribing only adds overhead and a PyTorch warning.
n_loader_workers = min(12, os.cpu_count() or 1)

for fold in range(num_folds):
  torch.cuda.empty_cache()
  print(f'Fold {fold+1}')

  # Creating training/validation folds: hold out fold `fold`, train on all the others.
  validation_fold = folds[fold]
  training_folds = []
  for other_fold in range(num_folds):
    if other_fold != fold:
      training_folds.extend(folds[other_fold])

  train_df = images_and_ids[images_and_ids['Id'].isin(training_folds)]
  valid_df = images_and_ids[images_and_ids['Id'].isin(validation_fold)]
  assert set(train_df['Id']).isdisjoint(valid_df['Id']), 'training and validation whales overlap'

  # Creating training/validation PyTorch datasets
  train_set = HumpbackWhaleDataset(
      image_dir='/content/train',
      labels=train_df,
      transform=transform
  )

  valid_set = HumpbackWhaleDataset(
      image_dir='/content/train',
      labels=valid_df,
      transform=transform
  )

  # Create TaskSamplers
  train_sampler = TaskSampler(
      train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_task_per_epoch
  )

  valid_sampler = TaskSampler(
      valid_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
  )

  # Create DataLoaders
  train_loader = DataLoader(
      train_set,
      batch_sampler=train_sampler,
      num_workers=n_loader_workers,
      pin_memory=True,
      collate_fn=train_sampler.episodic_collate_fn
  )

  valid_loader = DataLoader(
      valid_set,
      batch_sampler=valid_sampler,
      num_workers=n_loader_workers,
      pin_memory=True,
      collate_fn=valid_sampler.episodic_collate_fn
  )

  # Fresh model and optimizer per fold so folds never share weights.
  # NOTE(review): each task is n_way * (n_shot + n_query) = 110 images at the `transform` resize;
  # the recorded run hit CUDA OutOfMemoryError here with resnet12 at 256x512 — consider a smaller
  # Resize or fewer query images.
  cnn = resnet12().to(device)
  few_shot_classifier = PrototypicalNetworks(cnn).to(device)

  loss_fn = torch.nn.CrossEntropyLoss()
  optimizer = SGD(few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

  # Train and validate every epoch with the same helper used for the backbone comparison above;
  # it returns the per-epoch training losses and validation accuracies.
  train_losses, valid_accs = train_fsl(
      few_shot_classifier,
      train_loader,
      valid_loader,
      optimizer,
      loss_fn,
      n_epochs=n_epochs,
      use_tqdm=False,
      save_model=False
  )
  fold_train_losses += [train_losses]
  fold_valid_accs += [valid_accs]

  train_losses_mean += [np.mean(train_losses)]
  train_losses_std  += [np.std(train_losses)]

  valid_accs_mean += [np.mean(valid_accs)]
  valid_accs_std  += [np.std(valid_accs)]

  print(f'Average fold validation accuracy = {np.mean(valid_accs)} {chr(177)}{np.std(valid_accs)}')

  # Drop this fold's model/optimizer so the next iteration's torch.cuda.empty_cache()
  # can actually release their GPU memory.
  del cnn, few_shot_classifier, optimizer

Fold 1


  0%|          | 0/100 [00:18<?, ?it/s]


OutOfMemoryError: ignored