<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/fsl_experiment2_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Set up notebook

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification


Downloading humpback-whale-identification.zip to /content
100% 5.50G/5.51G [02:27<00:00, 35.3MB/s]
100% 5.51G/5.51G [02:27<00:00, 40.0MB/s]


In [4]:
%%capture

!unzip humpback-whale-identification.zip

In [1]:
import os

if os.path.exists('/content/creating-ai-enabled-systems/Research Project') == False:
  !git clone https://github.com/jwells52/creating-ai-enabled-systems.git

%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Experiment 2 - Calculating the similarity between a pretrained FSL and fine-tuned FSL

In [2]:
import torch
import json


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, SimpleShot
from easyfsl.utils import evaluate
from easyfsl.samplers import TaskSampler

from torch import Tensor, nn
from torch.optim import SGD, Optimizer, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18

from typing import Callable

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class, create_loader
from modules.train import train_fsl, device, transform
from modules.plotting import fsl_plots

%load_ext autoreload
%autoreload 2

In [3]:
# Initialize pretained Prototypical Network

cnn1 = resnet18(weights='DEFAULT')
cnn1.fc = torch.nn.Flatten()
proto1 = PrototypicalNetworks(cnn1).to(device)

In [4]:
# Initialize fine-tuned Prototypical Network
cnn2 = resnet18()
cnn2.fc = torch.nn.Flatten()
proto2 = PrototypicalNetworks(cnn2).to(device)

proto2.load_state_dict(
  torch.load('/content/drive/MyDrive/prototypical_network_resnet12_last_epoch')()
)

<All keys matched successfully>

In [5]:
%%capture

proto1.eval()
proto2.eval()


In [6]:
# Load validation set
df = remove_new_whale_class(
    pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/validation_10samples.csv')
)

In [11]:
# Create loader

dataset = HumpbackWhaleDataset('/content/train', df, transform)
loader = create_loader(dataset, n_way=5, n_shot=5, n_query=1, n_tasks=50)

In [12]:
def kl_divergence(p, q):
    return np.sum(np.where(p != 0, p * np.log(p / q), 0))

In [13]:
diffs = []
kls = []

In [17]:
for (
    support_images,
    support_labels,
    query_images,
    query_labels,
    _
) in loader:
  torch.cuda.empty_cache()
  support_images = support_images.to(device)
  support_labels = support_labels.to(device)
  query_images = query_images.to(device)
  query_labels = query_labels.to(device)


  proto1.process_support_set(support_images, support_labels)
  proto2.process_support_set(support_images, support_labels)

  scores1 = proto1(query_images)
  scores2 = proto2(query_images)

  diff = (torch.max(scores1, 1)[1] != torch.max(scores2, 1)[1]).sum().item()
  kl = kl_divergence(scores1.cpu().detach().numpy(), scores2.cpu().detach().numpy())

  diffs += [diff]
  kls += [kl]

In [18]:
kls

[125.789314,
 116.20288,
 90.99115,
 107.0229,
 112.06029,
 127.60464,
 136.59718,
 138.09453,
 76.452446,
 129.30714,
 87.82487,
 104.73794,
 76.93223,
 113.379074,
 129.64378,
 134.48238,
 121.19931,
 109.34143,
 103.62476,
 112.43025,
 116.36514,
 130.55768,
 106.71642,
 107.70276,
 95.32008,
 112.26634,
 92.59087,
 138.4801,
 123.23056,
 115.55577,
 142.79984,
 105.60985,
 119.551506,
 120.42352,
 111.61211,
 111.38084,
 93.464066,
 103.507645,
 119.15788,
 72.69692,
 98.27077,
 75.99947,
 113.04985,
 78.43051,
 115.28421,
 109.86903,
 116.17953,
 80.48638,
 123.556564,
 110.373634,
 113.38274,
 127.46838,
 126.681755,
 119.60081,
 122.4711,
 113.734406,
 110.14253,
 104.08693,
 112.21687,
 131.0983,
 129.60652,
 112.3003,
 112.683525,
 94.834854,
 114.95829,
 121.26494,
 132.1151,
 107.18276,
 106.568306,
 99.407364,
 96.286194,
 90.21943,
 138.65332,
 124.30611,
 117.32862,
 141.28682,
 100.40926,
 80.17983,
 135.83781,
 117.19852,
 109.58702,
 108.456635,
 130.48163,
 145.16086,

In [22]:
print(f'Average difference = {np.mean(diffs)} {chr(177)} {np.std(diffs)}')

Average difference = 2.54 ± 1.1438531374263043


In [23]:
print(f'Average kl divergence = {np.mean(kls)} {chr(177)} {np.std(kls)}')


Average kl divergence = 112.58384704589844 ± 16.937910079956055
