<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install easyfsl

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/65.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m61.4/65.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Dataset

In [2]:
# Mount Google Drive at /content/drive. The Kaggle API token used by the
# download cell is copied from a folder on this drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification
!unzip humpback-whale-identification.zip

### Clone GitHub Repo

In [1]:
# !git clone https://github.com/jwells52/creating-ai-enabled-systems.git
%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Imports

In [2]:
import os
import cv2
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, SimpleShot
from easyfsl.utils import evaluate, evaluate_on_one_task
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet152

from typing import Callable

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loaders
from modules.train import train_fsl, device, transform


%load_ext autoreload
%autoreload 2

### Look at classes that have low performance

In [3]:
# Plan for this section:
#   1. Load the data into a dataset.
#   2. Calculate the accuracy for each class.
#   3. Draw a bar plot of the accuracy of each class.
#   4. Look at the characteristics of the images in the classes that perform poorly.

In [4]:
# Read the validation split and pass it through `remove_new_whale_class`
# (imported from modules.data_utils). Presumably the helper drops the rows
# labelled as the catch-all "new_whale" class -- confirm against the helper.
df = remove_new_whale_class(
  pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/validation_10samples.csv')
)


# Disabled filter: would keep only the rows whose `class_count` exceeds 8.
# df = df[df['class_count'] > 8]

In [61]:
class PrototypicalNetworksLocal(torch.nn.Module):
    """
    Prototypical-network classifier that scores queries against class prototypes.

    The returned score matrix has one column per *support sample*, not one per
    class: column ``j`` is the score against the prototype of the class that
    support sample ``j`` belongs to. Columns of samples sharing a class are
    therefore identical, and a predicted label is recovered with
    ``support_labels[scores.argmax(dim=1)]``.
    """

    def __init__(self, backbone: nn.Module):
        """
        Args:
            backbone: feature extractor applied to the support and query
                images. It is expected to return one flat feature vector per
                image, since the result is passed to ``torch.cdist``.
        """
        super(PrototypicalNetworksLocal, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: batch of support images fed to the backbone.
            support_labels: 1-D tensor with one label per support image.
            query_images: batch of query images fed to the backbone.

        Returns:
            Tensor of shape (n_query, n_support) holding the negative
            Euclidean distance between each query embedding and the prototype
            of each support sample's class (higher means closer).
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the backbone are executed.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Compute each prototype once per distinct class: the mean embedding of
        # all support samples carrying that label.
        class_labels, class_of_sample = torch.unique(
            support_labels, return_inverse=True
        )
        class_prototypes = torch.stack(
            [z_support[support_labels == label].mean(0) for label in class_labels]
        )

        # Expand back to one prototype row per support sample so the column
        # order of the scores lines up with ``support_labels``.
        z_proto = class_prototypes[class_of_sample]

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # Closer prototypes must score higher, hence the sign flip.
        scores = -dists
        return scores

In [62]:
# Feature extractor: ResNet-34 with ImageNet weights. The weights are named
# explicitly because passing a boolean to `weights` is deprecated in
# torchvision; IMAGENET1K_V1 is what the legacy boolean resolved to.
cnn = resnet34(weights="IMAGENET1K_V1").to(device)
# Swap the classification head for a Flatten so the network returns its pooled
# features instead of ImageNet logits.
cnn.fc = torch.nn.Flatten()
few_shot_classifier = PrototypicalNetworksLocal(cnn).to(device)
# NOTE(review): disabled checkpoint load. The trailing `()` after
# torch.load(...) looks like a typo -- fix it before re-enabling.
# few_shot_classifier.load_state_dict(
#   torch.load('/content/creating-ai-enabled-systems/Research Project/models/prototypical_network_resnet34_prod_last_epoch')()
# )



In [63]:
# Put the classifier (and its backbone) in evaluation mode before running the
# per-class evaluation; as the last expression of the cell, the returned module
# is also displayed.
few_shot_classifier.eval()

PrototypicalNetworksLocal(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, aff

In [64]:
# Episode configuration read by the per-class evaluation loop.
# n_way:   number of randomly drawn classes placed in the support set; the loop
#          adds the class under evaluation on top, so an episode has n_way + 1 classes.
# n_shot:  number of support images drawn for each class.
# n_query: number of query images drawn for the class under evaluation.
n_way   = 4
n_shot  = 3
n_query = 5

In [65]:
####################################
# Evaluate accuracy for each class
####################################

# For each class:
#  create a query set containing examples of that class
#  create a support set containing N classes
#  pass the query set and the support set to the evaluate function from easyfsl

In [66]:
def image_reader_single(path):
  """
  Read one image from disk and resize it to 512x256 (width x height).

  Args:
    path: filesystem path of the image to load.

  Returns:
    The resized image as returned by ``cv2.resize``; the evaluation loop
    stores it in a (256, 512, 3) slot.

  Raises:
    FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread`` signals
      this by returning None instead of raising.
  """
  # NOTE(review): cv2.imread yields BGR channel order by default -- confirm
  # that this matches what `transform` and the trained backbone expect.
  image = cv2.imread(path)
  if image is None:
    raise FileNotFoundError(f'Could not read image: {path}')
  # cv2.resize takes the target size as (width, height).
  image = cv2.resize(image, (512, 256))
  return image

In [67]:
# Two-way lookup between whale ids and integer class indices; indices follow
# the order in which the ids first appear in `df`.
idx2class = dict(enumerate(df['Id'].unique()))
class2idx = {whale_id: index for index, whale_id in idx2class.items()}

In [85]:
# Per-class evaluation: for every whale id, build one episode whose query set
# contains only that id and whose support set contains that id plus `n_way`
# randomly drawn other ids, then record the fraction of correct predictions.
class_accs = dict()
all_class_ids = df['Id'].unique()

for q_id, df_id in df.groupby('Id'):
  # Create query set. np.random.choice samples with replacement by default,
  # so the same image may appear more than once in the query set.
  query_set_labels = np.array([class2idx[q_id] for _ in range(n_query)])
  query_set_images_paths = np.random.choice(df_id['Image'], size=n_query)

  # Remove query images from dataset so they cannot be reused as support
  # images (boolean indexing already returns a new frame; no copy needed).
  _df = df[~df['Image'].isin(query_set_images_paths)]

  # A small class can lose all of its images to the query set; without any
  # support image for the query class the episode cannot be evaluated.
  if not (_df['Id'] == q_id).any():
    print(f'Whale Id: {q_id} | skipped: no images left for the support set')
    continue

  query_set_images = np.array(
      [
          image_reader_single(os.path.join('/content/train', image))
            for image in query_set_images_paths
      ]
    )

  # Create support set: n_way distractor classes, never the query class
  # itself (it is appended explicitly, so drawing it would duplicate it).
  distractor_ids = np.random.choice(all_class_ids[all_class_ids != q_id], size=n_way, replace=False)
  support_classes = np.array([class2idx[s_id] for s_id in distractor_ids] + [class2idx[q_id]])

  support_set_labels = np.zeros((n_way+1)*n_shot, dtype=int)
  support_set_images = np.zeros(((n_way+1)*n_shot, 256, 512, 3))

  cnt = 0
  for s_id in support_classes:
    class_support_set_image_paths = np.random.choice(_df[_df['Id'] == idx2class[s_id]]['Image'], size=n_shot)
    for image_path in class_support_set_image_paths:
      support_set_labels[cnt] = s_id
      support_set_images[cnt] = image_reader_single(os.path.join('/content/train', image_path))
      cnt += 1

  # Convert sets to Tensors; images go from (N, H, W, C) to (N, C, H, W)
  # before `transform` is applied.
  query_set_labels_tensor = torch.Tensor(query_set_labels).to(device)
  query_set_images_tensor = transform(torch.Tensor(query_set_images).to(device).permute(0, 3, 1, 2))
  support_set_labels_tensor = torch.Tensor(support_set_labels).to(device)
  support_set_images_tensor = transform(torch.Tensor(support_set_images).to(device).permute(0, 3, 1, 2))

  # Pass support set and query set through Prototypical Network. Gradients
  # are not needed for evaluation, so skip building the autograd graph.
  with torch.no_grad():
    predictions = few_shot_classifier(support_set_images_tensor, support_set_labels_tensor, query_set_images_tensor)

  # The score columns are aligned with the support samples, so the best
  # column index is mapped back to a class through the support labels.
  predicted_labels = support_set_labels_tensor[predictions.argmax(dim=1)]

  num_correct = (
    (predicted_labels == query_set_labels_tensor).sum().item()
  )

  # Store the per-class accuracy so it can be plotted / inspected afterwards.
  class_accs[q_id] = num_correct / n_query

  print(f'Whale Id: {q_id} | num_correct={num_correct} | len_query={n_query}')

z_query
torch.Size([5, 512])
tensor([[ 2.9121,  0.0000, 20.9017,  ...,  2.9795,  3.0193, 42.2480],
        [ 5.0512,  0.0000, 27.0125,  ...,  3.3605,  3.6738, 49.0319],
        [ 4.1363,  0.0000, 22.9538,  ...,  2.1520,  4.9734, 40.2128],
        [ 2.7724,  0.0000, 17.2667,  ...,  2.1580,  1.7648, 40.7929],
        [ 2.0499,  0.0000, 12.3304,  ...,  2.3795,  1.1136, 30.0161]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)
z_support
torch.Size([15, 512])
tensor([[ 6.9990,  0.0000, 21.8969,  ...,  1.5230, 12.9097, 29.8208],
        [ 5.3496,  0.0000, 21.0384,  ...,  3.3035,  4.3508, 44.9626],
        [ 5.6909,  0.0000, 17.9765,  ...,  3.1184,  3.9325, 38.7819],
        ...,
        [ 3.5204,  0.0000, 19.2850,  ...,  0.8368,  4.1439, 38.8600],
        [ 4.4542,  0.0000, 23.4535,  ...,  2.1361,  5.2636, 42.2278],
        [ 3.0750,  0.0000, 26.3267,  ...,  2.5041,  2.2215, 46.6997]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)
support_labels
tensor([26., 26., 26.,  6

KeyboardInterrupt: ignored