<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment2_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Set up notebook

### Install easyfsl

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/65.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m41.0/65.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale identification dataset

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification


Downloading humpback-whale-identification.zip to /content
100% 5.50G/5.51G [04:55<00:00, 18.7MB/s]
100% 5.51G/5.51G [04:55<00:00, 20.0MB/s]


In [6]:
%%capture

!unzip humpback-whale-identification.zip

### Clone GitHub repo

In [1]:
import os

if os.path.exists('/content/creating-ai-enabled-systems/Research Project') == False:
  !git clone https://github.com/jwells52/creating-ai-enabled-systems.git

%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Experiment 2 - Calculating the similarity between a pretrained FSL and fine-tuned FSL

In [2]:
import torch
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from copy import deepcopy
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, SimpleShot
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler
from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from typing import Callable
from tqdm import tqdm

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loader
from modules.train import train_fsl, device, transform
from modules.plotting import fsl_plots

%load_ext autoreload
%autoreload 2

In [3]:
# Initialize pretained Prototypical Network
cnn1 = resnet18(weights='DEFAULT')
cnn1.fc = torch.nn.Flatten()
proto1 = PrototypicalNetworks(cnn1).to(device)

In [4]:
# Initialize fine-tuned Prototypical Network
cnn2 = resnet18()
cnn2.fc = torch.nn.Flatten()
proto2 = PrototypicalNetworks(cnn2).to(device)

proto2.load_state_dict(
  torch.load('/content/drive/MyDrive/prototypical_network_resnet12_last_epoch')()
)

<All keys matched successfully>

In [5]:
%%capture

proto1.eval()
proto2.eval()

In [28]:
# Load validation set
df = remove_new_whale_class(
    pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/images_and_ids.csv')
)


df = df[df['class_count'] >= 10]

In [30]:
def kl_divergence(p, q):
    return np.sum(np.where(p != 0, p * np.log(p / q), 0))

In [48]:
diffs = []
kls = []
accs = {'proto_pretrained': [], 'proto_finetuned': []}

_df = df.copy()
i = 0
while True:
  whale_ids = list(_df.Id.unique())
  n_way = 10 if len(whale_ids) > 10 else len(whale_ids)
  print(len(whale_ids))
  if len(whale_ids) == 0: break

  dataset = HumpbackWhaleDataset('/content/train', _df, transform)
  loader = create_loader(dataset, n_way=n_way, n_shot=5, n_query=5, n_tasks=10)

  task_acc1, task_acc2 = [], []
  task_diffs = []
  task_kls = []
  for (support_images, support_labels, query_images, query_labels, _) in loader:
    support_images = support_images.to(device)
    support_labels = support_labels.to(device)
    query_images = query_images.to(device)
    query_labels = query_labels.to(device)

    proto1.process_support_set(support_images, support_labels)
    proto2.process_support_set(support_images, support_labels)

    scores1 = proto1(query_images)
    scores2 = proto2(query_images)

    acc1 = (torch.max(scores1, 1)[1] == query_labels).sum().item() / len(query_labels)
    acc2 = (torch.max(scores2, 1)[1] == query_labels).sum().item() / len(query_labels)

    diff = (torch.max(scores1, 1)[1] != torch.max(scores2, 1)[1]).sum().item()
    kl = kl_divergence(scores1.cpu().detach().numpy(), scores2.cpu().detach().numpy())

    task_acc1 += [acc1]
    task_acc2 += [acc2]
    task_diffs += [diff]
    task_kls += [kl]

  accs['proto_pretrained'] += [np.mean(task_acc1)]
  accs['proto_finetuned'] += [np.mean(task_acc2)]
  diffs += [np.mean(task_diffs)]
  kls += [np.mean(task_kls)]

  query_ids = [dataset.id_to_label[id] for id in list(set(query_labels.cpu().detach().numpy()))]
  _df = _df[~_df.Id.isin(query_ids)]

273
263
253
243
233
223
213
203
193
183
173
163
153
143
133
123
113
103
93
83
73
63
53
43
33
23
13
3
0


In [53]:
print(f"Average 5-shot 5-way accuracy for pretrained model: {np.mean(accs['proto_pretrained']):.3f} {chr(177)} {np.std(accs['proto_pretrained']):.3f}")
print(f"Average 5-shot 5-way accuracy for fine-tuned model: {np.mean(accs['proto_finetuned']):.3f} {chr(177)} {np.std(accs['proto_finetuned']):.3f}")



Average 5-shot 5-way accuracy for pretrained model: 0.483 ± 0.062
Average 5-shot 5-way accuracy for fine-tuned model: 0.975 ± 0.020


In [50]:
print(f'Average difference = {np.mean(diffs)} {chr(177)} {np.std(diffs)}')

Average difference = 25.63214285714286 ± 4.442735126941878


In [51]:
print(f'Average kl divergence = {np.mean(kls)} {chr(177)} {np.std(kls)}')


Average kl divergence = 2291.8671875 ± 433.66961669921875


In [52]:
accs['proto_finetuned']

[0.998,
 1.0,
 0.992,
 0.992,
 0.9960000000000001,
 0.9940000000000001,
 0.986,
 0.992,
 0.9739999999999999,
 0.9899999999999999,
 0.9739999999999999,
 0.982,
 0.99,
 0.984,
 0.9879999999999999,
 0.9719999999999999,
 0.9620000000000001,
 0.9719999999999999,
 0.962,
 0.962,
 0.9540000000000001,
 0.9399999999999998,
 0.95,
 0.9339999999999999,
 0.9339999999999999,
 0.9640000000000001,
 0.962,
 1.0]

In [47]:
accs['proto_pretrained']


[0.488,
 0.472,
 0.44800000000000006,
 0.45199999999999996,
 0.46399999999999997,
 0.488,
 0.508,
 0.476,
 0.5,
 0.508,
 0.41600000000000004,
 0.54,
 0.516,
 0.524,
 0.528,
 0.44400000000000006,
 0.484,
 0.42800000000000005,
 0.388,
 0.46799999999999997,
 0.46799999999999997,
 0.484,
 0.45199999999999996,
 0.45999999999999996,
 0.44800000000000006,
 0.5,
 0.44800000000000006,
 0.7333333333333333]