<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install easyfsl

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/65.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Dataset

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification
!unzip humpback-whale-identification.zip

### Clone GitHub Repo

In [3]:
# !git clone https://github.com/jwells52/creating-ai-enabled-systems.git
%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Imports

In [4]:
import os

import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, SimpleShot
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet152

from typing import Callable

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loaders
from modules.train import train_fsl, device, transform

### Look at classes that have low performance

In [5]:
# Load into dataset
# For each class, calculate accuracy
# Then plot barplot of accuracy for each class
# Look at the characteristics of images for the classes that have low performance

In [6]:
df = remove_new_whale_class(
  pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/images_and_ids.csv')
)


df = df[df['class_count'] > 8]

In [7]:
cnn = resnet34().to(device)
few_shot_classifier = PrototypicalNetworks(cnn).to(device)
few_shot_classifier.load_state_dict(
  torch.load('/content/creating-ai-enabled-systems/Research Project/models/prototypical_network_resnet34_prod_last_epoch')()
)

<All keys matched successfully>

In [None]:
few_shot_classifier.eval()

In [11]:
n_query = 5
n_shot  = 3
n_way   = 1

In [None]:
class_performances = dict()
for id, df_id in df.groupby('Id'):
  dataset_id = HumpbackWhaleDataset('/content/train', df_id, transform=transform)

  sampler = TaskSampler(
      dataset_id, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=10
  )

  loader = DataLoader(
      dataset_id,
      batch_sampler=sampler,
      num_workers=12,
      pin_memory=True,
      collate_fn=sampler.episodic_collate_fn
  )

  acc = evaluate(
    few_shot_classifier, loader, device=device
  )

  class_performances[id] = acc

100%|██████████| 10/10 [00:10<00:00,  1.03s/it, accuracy=1]
100%|██████████| 10/10 [00:03<00:00,  3.24it/s, accuracy=1]
100%|██████████| 10/10 [00:03<00:00,  3.19it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.74it/s, accuracy=1]
100%|██████████| 10/10 [00:01<00:00,  5.55it/s, accuracy=1]
100%|██████████| 10/10 [00:01<00:00,  5.83it/s, accuracy=1]
100%|██████████| 10/10 [00:04<00:00,  2.13it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.42it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.54it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.96it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.00it/s, accuracy=1]
100%|██████████| 10/10 [00:03<00:00,  3.20it/s, accuracy=1]
100%|██████████| 10/10 [00:01<00:00,  5.26it/s, accuracy=1]
100%|██████████| 10/10 [00:01<00:00,  5.12it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.26it/s, accuracy=1]
100%|██████████| 10/10 [00:02<00:00,  4.42it/s, accuracy=1]
100%|██████████| 10/10 [00:03<00:00,  2.