<a href="https://colab.research.google.com/github/jwells52/creating-ai-enabled-systems/blob/main/Research%20Project/notebooks/fsl_experiment1_new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install EasyFSL

In [1]:
%pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.4.0-py3-none-any.whl (65 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/65.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.2/65.2 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.4.0


### Download Humpback Whale Identification dataset

In [2]:
# Mount Google Drive: later cells copy kaggle.json from it and read/write model
# checkpoints under /content/drive/MyDrive/Research-Project.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Recreate /root/.kaggle, copy the Kaggle API token from Drive into it (file mode 600),
# then download the humpback-whale-identification competition archive
# (the output below shows it landing in /content).
!rm -rf /root/.kaggle && mkdir /root/.kaggle && cp /content/drive/MyDrive/Research-Project/kaggle.json /root/.kaggle/kaggle.json && chmod 600 /root/.kaggle/kaggle.json && kaggle competitions download -c humpback-whale-identification

Downloading humpback-whale-identification.zip to /content
100% 5.50G/5.51G [03:18<00:00, 37.5MB/s]
100% 5.51G/5.51G [03:18<00:00, 29.8MB/s]


In [4]:
%%capture

# Extract the competition archive into the current directory.
# %%capture (which must stay on the first line of the cell) hides unzip's per-file output.
!unzip humpback-whale-identification.zip

### Clone GitHub repo

In [1]:
# Delete any previous checkout so that the next cell clones a fresh copy of the repo.
!rm -rf /content/creating-ai-enabled-systems

In [5]:
import os

if os.path.exists('/content/creating-ai-enabled-systems/Research Project') == False:
  !git clone https://github.com/jwells52/creating-ai-enabled-systems.git

%cd creating-ai-enabled-systems/Research\ Project

/content/creating-ai-enabled-systems/Research Project


### Imports

In [6]:
import time
import torch
import pandas as pd
import numpy as np

from easyfsl.methods import RelationNetworks, MatchingNetworks, PrototypicalNetworks
from easyfsl.utils import evaluate
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

from modules.data_utils import HumpbackWhaleDataset, remove_new_whale_class
from modules.train import train_network, transform, FeatureExtractor, device
from modules.data_utils import create_loader


### Load dataset

In [4]:
# Load the training split and pass it through remove_new_whale_class
# (presumably drops rows of the catch-all "new_whale" label - confirm in modules/data_utils).
train_df = remove_new_whale_class(
  pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/training_10samples.csv')
)

In [5]:
# Class-balance summary of the training split: smallest class, largest class, class count.
samples_per_class = train_df['class_count']
n_classes = len(train_df['Id'].unique())

print(f"Min # of samples for a class in data set = {samples_per_class.min()}")
print(f"Max # of samples for a class in data set = {samples_per_class.max()}")
print(f"# of classes in data set = {n_classes}")


Min # of samples for a class in data set = 11
Max # of samples for a class in data set = 73
# of classes in data set = 181


In [None]:
# Wrap the filtered training frame as a dataset over the images in /content/train,
# using the shared `transform` imported from modules.train.
train_dataset = HumpbackWhaleDataset('/content/train', train_df, transform=transform)

# Train on the training set under every combination of 5-way / 20-way and 1-shot / 5-shot

The following networks are trained:
1. Relation Networks
2. Matching Networks
3. Prototypical Networks

In [8]:
# Training hyperparameters.
# NOTE(review): n_epochs and learning_rate are not passed to any train_network call in
# this notebook - confirm that train_network uses matching defaults, or wire them through.
n_epochs = 10
learning_rate = 1e-2

# Number of training tasks (episodes) per epoch.
# A training task is a random sample of N shots (images) for each of M classes.
n_task_per_epoch = 100

# Task settings: every (n_way, n_shot) combination of these two lists is used below.
n_ways = [5, 20]
n_shots = [1, 5]
# Presumably the number of query images per class in each task - confirm in create_loader.
n_query = 5

### Relation Network

In [None]:
# Train Relation Networks. Per the log below, one run is made for each (n_way, n_shot)
# combination and a checkpoint is written to '<prefix>_<N>-way_<K>-shot_last_epoch',
# where <prefix> is the path passed here.
# NOTE(review): feature_extractor=True is interpreted inside modules.train; presumably it
# makes the backbone return feature maps rather than flat vectors - confirm.
print('Training Relation Network')
rn_losses = train_network(
    RelationNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/relation_network',
    feature_extractor=True,
    feature_dimension=512
)

Training Relation Network
Training network under 5-way 1-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:10<00:00,  9.77it/s, loss=1.47]


Epoch 2 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.79it/s, loss=1.3]


Epoch 3 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.96it/s, loss=1.28]



Epoch 4 

Training: 100%|██████████| 100/100 [00:09<00:00, 10.54it/s, loss=1.27]



Epoch 5 

Training: 100%|██████████| 100/100 [00:09<00:00, 10.04it/s, loss=1.23]



Epoch 6 

Training: 100%|██████████| 100/100 [00:10<00:00,  9.84it/s, loss=1.23]


Epoch 7 


Training: 100%|██████████| 100/100 [00:10<00:00,  9.94it/s, loss=1.17]


Epoch 8 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.19it/s, loss=1.15]


Epoch 9 


Training: 100%|██████████| 100/100 [00:09<00:00, 10.20it/s, loss=1.17]



Epoch 10 

Training: 100%|██████████| 100/100 [00:09<00:00, 10.07it/s, loss=1.12]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_1-shot_last_epoch
Training network under 5-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:14<00:00,  6.97it/s, loss=1.46]


Epoch 2 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.99it/s, loss=1.31]


Epoch 3 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.00it/s, loss=1.22]


Epoch 4 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.93it/s, loss=1.19]


Epoch 5 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.92it/s, loss=1.12]


Epoch 6 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.98it/s, loss=1.09]


Epoch 7 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.03it/s, loss=1.07]


Epoch 8 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.02it/s, loss=1.05]


Epoch 9 


Training: 100%|██████████| 100/100 [00:14<00:00,  6.97it/s, loss=1.04]


Epoch 10 


Training: 100%|██████████| 100/100 [00:14<00:00,  7.07it/s, loss=1.02]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_5-way_5-shot_last_epoch
Training network under 20-way 1-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:38<00:00,  2.61it/s, loss=2.77]


Epoch 2 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.65it/s, loss=2.48]


Epoch 3 


Training: 100%|██████████| 100/100 [00:38<00:00,  2.61it/s, loss=2.37]


Epoch 4 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=2.31]


Epoch 5 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.63it/s, loss=2.26]


Epoch 6 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.64it/s, loss=2.23]


Epoch 7 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.65it/s, loss=2.2]


Epoch 8 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=2.17]


Epoch 9 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.65it/s, loss=2.16]


Epoch 10 


Training: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s, loss=2.15]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_20-way_1-shot_last_epoch
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=2.7]


Epoch 2 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.29]


Epoch 3 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.2]


Epoch 4 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.15]


Epoch 5 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.13]


Epoch 6 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.11]


Epoch 7 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.11]


Epoch 8 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=2.1]


Epoch 9 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=2.1]


Epoch 10 


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=2.1]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/relation_network_20-way_5-shot_last_epoch


### Matching Network

In [None]:
# Train Matching Networks with the same task settings; checkpoints use the
# 'matching_network' path prefix.
# NOTE(review): the saved output below only shows the 20-way 5-shot run, although
# n_ways/n_shots list four settings - the log may come from a run with different settings.
print('Training Matching Network')
mn_losses = train_network(
    MatchingNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/matching_network',
    feature_extractor=False,
    feature_dimension=512
)

Training Matching Network
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.584]


Epoch 2 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=0.318]


Epoch 3 


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=0.185]


Epoch 4 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=0.0608]


Epoch 5 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.68it/s, loss=0.0378]


Epoch 6 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.0225]


Epoch 7 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.0174]


Epoch 8 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.70it/s, loss=0.012]


Epoch 9 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.67it/s, loss=0.00975]


Epoch 10 


Training: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s, loss=0.00905]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/matching_network_20-way_5-shot_last_epoch


### Prototypical Network

In [None]:
# Train Prototypical Networks with the same task settings; checkpoints use the
# 'prototypical_network' path prefix. feature_dimension is None here, and the evaluation
# cells likewise construct PrototypicalNetworks without a feature_dimension argument.
# NOTE(review): the saved output below only shows the 20-way 5-shot run.
print('Training Prototypical Network')
pt_losses = train_network(
    PrototypicalNetworks,
    train_dataset,
    n_ways, n_shots, n_query, n_task_per_epoch,
    '/content/drive/MyDrive/Research-Project/prototypical_network',
    feature_extractor=False,
    feature_dimension=None
)

Training Prototypical Network
Training network under 20-way 5-shot

Epoch 1


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.397]



Epoch 2 

Training: 100%|██████████| 100/100 [01:01<00:00,  1.62it/s, loss=0.0615]


Epoch 3 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.0278]


Epoch 4 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.70it/s, loss=0.0162]


Epoch 5 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=0.00995]


Epoch 6 


Training: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s, loss=0.00837]


Epoch 7 


Training: 100%|██████████| 100/100 [00:59<00:00,  1.69it/s, loss=0.00679]


Epoch 8 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.00591]


Epoch 9 


Training: 100%|██████████| 100/100 [00:58<00:00,  1.71it/s, loss=0.00458]



Epoch 10 

Training: 100%|██████████| 100/100 [00:58<00:00,  1.72it/s, loss=0.00404]


Saving state of model checkpoint at last epoch to /content/drive/MyDrive/Research-Project/prototypical_network_20-way_5-shot_last_epoch


# Evaluating networks on test dataset

In [20]:
def load_networks_from_checkpoint(key_name, n_ways, n_shots, backbone, network,
                                  checkpoint_dir='/content/drive/MyDrive/Research-Project', **kwargs):
    '''
    Build one few-shot network per (n_way, n_shot) setting and load its saved weights.

    Parameters
    ----------
    key_name : str
        Checkpoint file prefix, e.g. 'relation_network'.
    n_ways, n_shots : iterable of int
        Task settings; one checkpoint is loaded for every (n_way, n_shot) combination.
    backbone : torch.nn.Module
        Template backbone. It is deep-copied for every network and never modified.
    network : callable
        Few-shot classifier class, called as network(backbone, **kwargs).
    checkpoint_dir : str
        Directory containing files named '{key_name}_{n_way}-way_{n_shot}-shot_last_epoch'.
    **kwargs
        Extra keyword arguments forwarded to `network` (e.g. feature_dimension=512).

    Returns
    -------
    dict
        network_dict[n_way][n_shot] -> the network with its checkpoint loaded, moved to `device`.
    '''
    import copy  # local import so this cell does not depend on the notebook's import cell

    network_dict = dict()
    for n_way in n_ways:
        network_dict[n_way] = dict()
        for n_shot in n_shots:
            # Give every network its own backbone. If the same module object were passed to
            # all of them they would share parameters, and each load_state_dict call would
            # overwrite the backbone weights of every network loaded before it.
            fsl_network = network(copy.deepcopy(backbone), **kwargs)
            checkpoint_path = f'{checkpoint_dir}/{key_name}_{n_way}-way_{n_shot}-shot_last_epoch'
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
            fsl_network.load_state_dict(torch.load(checkpoint_path, map_location=device))
            network_dict[n_way][n_shot] = fsl_network.to(device)

    return network_dict

In [14]:
def get_accuracy_table(networks, test_dataset, ways, shots, n_query, verbose=True, device=device, pretrained=True, n_tasks=5):
    '''
    Evaluate every network under every (n_way, n_shot) setting and tabulate the accuracies.

    Parameters
    ----------
    networks : dict
        Maps a display name to either a single network (pretrained=True) or a nested
        dict networks[name][n_way][n_shot] of networks (pretrained=False).
    test_dataset
        Dataset handed to create_loader for sampling evaluation tasks.
    ways, shots : sequence of int
        Task settings; every (way, shot) combination is evaluated.
    n_query : int
        Forwarded to create_loader.
    verbose : bool
        Print a header line before each evaluation.
    device
        Device forwarded to easyfsl's evaluate.
    pretrained : bool
        Selects how `networks` is indexed (see above).
    n_tasks : int
        Forwarded as the fifth positional argument of create_loader. The saved evaluation
        logs show 5 iterations per setting, which is consistent with this being the number
        of evaluation tasks - confirm in modules/data_utils. Defaults to the previously
        hard-coded value of 5.

    Returns
    -------
    pandas.DataFrame
        One row per network; columns are a MultiIndex of ('{way}-way', '{shot}-shot').
    '''
    # Pre-fill with NaN so every (network, way, shot) cell exists before evaluation starts.
    accs = {
        network_name: {way: {shot: np.nan for shot in shots} for way in ways}
        for network_name in networks
    }

    for network_name in networks:
        for n_way in ways:
            for n_shot in shots:
                if verbose: print(f'Evaluating {network_name} for {n_way}-way {n_shot}-shot:', end=' ')
                if pretrained:
                    # One network object is reused for every task setting.
                    network = networks[network_name]
                else:
                    # A separate network exists for each task setting.
                    network = networks[network_name][n_way][n_shot]

                loader = create_loader(test_dataset, n_way, n_shot, n_query, n_tasks, num_workers=2)
                accs[network_name][n_way][n_shot] = evaluate(network, loader, device=device, use_tqdm=True)
                # Presumably gives the tqdm progress bar time to flush before the next print.
                time.sleep(0.1)
                if verbose: print()

    # Build the table directly: rows follow `networks`, columns are ordered way-major.
    columns = pd.MultiIndex.from_tuples(
        [(f'{n_way}-way', f'{n_shot}-shot') for n_way in ways for n_shot in shots]
    )
    rows = [
        [accs[network_name][n_way][n_shot] for n_way in ways for n_shot in shots]
        for network_name in networks
    ]
    return pd.DataFrame(rows, index=list(networks), columns=columns)

In [None]:
# Load the held-out test split, filtered with remove_new_whale_class like the training split.
df_test = remove_new_whale_class(
    pd.read_csv('/content/creating-ai-enabled-systems/Research Project/data/testing_10samples.csv')
)

# Test images are read from the same /content/train folder as the training images;
# only the CSV listing the rows differs.
test_dataset = HumpbackWhaleDataset('/content/train', df_test, transform=transform)

In [35]:
############################
# Load pretrained networks
#############################
# ResNet-18 with torchvision's default pretrained weights; the classification head is
# replaced by Flatten so the model outputs features instead of class logits.
cnn = resnet18(weights='DEFAULT')
cnn.fc = torch.nn.Flatten()
# Relation Networks is given a backbone that returns the 'layer4.1.bn2' node rather than
# the flattened output. FeatureExtractor is a project wrapper; presumably it unpacks the
# dict returned by create_feature_extractor - confirm in modules/train.
resnet_extractor = create_feature_extractor(cnn, return_nodes=['layer4.1.bn2'])
feature_extractor = FeatureExtractor(resnet_extractor, 'layer4.1.bn2')

# Baselines without any fine-tuning on the whale data. The Matching and Prototypical
# entries are both built on the same `cnn` object.
pretrained_networks = {
    'Relation Network': RelationNetworks(feature_extractor, feature_dimension=512).to(device),
    'Matching Network': MatchingNetworks(cnn, feature_dimension=512).to(device),
    'Prototypical Network': PrototypicalNetworks(cnn).to(device),
}

In [36]:
# get_accuracy_table has no defaults for the dataset or the task settings, so they must be
# passed explicitly. pretrained=True because each entry of pretrained_networks is a single
# network that is reused for every (n_way, n_shot) setting.
pretained_acc_table = get_accuracy_table(
    pretrained_networks, test_dataset, n_ways, n_shots, n_query, pretrained=True
)
pretained_acc_table

Evaluating Relation Network for 5-way 1-shot: 

100%|██████████| 5/5 [00:02<00:00,  1.83it/s, accuracy=0.208]


Evaluating Relation Network for 5-way 5-shot: 


100%|██████████| 5/5 [00:04<00:00,  1.14it/s, accuracy=0.208]


Evaluating Relation Network for 20-way 1-shot: 


100%|██████████| 5/5 [00:14<00:00,  2.81s/it, accuracy=0.036]



Evaluating Relation Network for 20-way 5-shot: 

100%|██████████| 5/5 [00:19<00:00,  3.95s/it, accuracy=0.06]



Evaluating Matching Network for 5-way 1-shot: 

100%|██████████| 5/5 [00:04<00:00,  1.09it/s, accuracy=0.4]


Evaluating Matching Network for 5-way 5-shot: 


100%|██████████| 5/5 [00:03<00:00,  1.26it/s, accuracy=0.656]


Evaluating Matching Network for 20-way 1-shot: 


100%|██████████| 5/5 [00:13<00:00,  2.67s/it, accuracy=0.182]


Evaluating Matching Network for 20-way 5-shot: 


100%|██████████| 5/5 [00:20<00:00,  4.02s/it, accuracy=0.39]



Evaluating Prototypical Network for 5-way 1-shot: 

100%|██████████| 5/5 [00:02<00:00,  1.91it/s, accuracy=0.52]


Evaluating Prototypical Network for 5-way 5-shot: 


100%|██████████| 5/5 [00:05<00:00,  1.08s/it, accuracy=0.648]


Evaluating Prototypical Network for 20-way 1-shot: 


100%|██████████| 5/5 [00:11<00:00,  2.30s/it, accuracy=0.202]


Evaluating Prototypical Network for 20-way 5-shot: 


100%|██████████| 5/5 [00:17<00:00,  3.59s/it, accuracy=0.348]





In [ ]:
# NOTE: this cell redefines load_networks_from_checkpoint, which is already defined in an
# earlier evaluation cell; the later definition is the one in effect. Keep only one copy.
def load_networks_from_checkpoint(key_name, n_ways, n_shots, backbone, network,
                                  checkpoint_dir='/content/drive/MyDrive/Research-Project', **kwargs):
    '''
    Build one few-shot network per (n_way, n_shot) setting and load its saved weights.

    Parameters
    ----------
    key_name : str
        Checkpoint file prefix, e.g. 'relation_network'.
    n_ways, n_shots : iterable of int
        Task settings; one checkpoint is loaded for every (n_way, n_shot) combination.
    backbone : torch.nn.Module
        Template backbone. It is deep-copied for every network and never modified.
    network : callable
        Few-shot classifier class, called as network(backbone, **kwargs).
    checkpoint_dir : str
        Directory containing files named '{key_name}_{n_way}-way_{n_shot}-shot_last_epoch'.
    **kwargs
        Extra keyword arguments forwarded to `network` (e.g. feature_dimension=512).

    Returns
    -------
    dict
        network_dict[n_way][n_shot] -> the network with its checkpoint loaded, moved to `device`.
    '''
    import copy  # local import so this cell does not depend on the notebook's import cell

    network_dict = dict()
    # Iterate over the parameters n_ways / n_shots. The previous version looped over the
    # undefined names `ways` / `shots`, which raises NameError on a fresh kernel.
    for n_way in n_ways:
        network_dict[n_way] = dict()
        for n_shot in n_shots:
            # Give every network its own backbone. If the same module object were passed to
            # all of them they would share parameters, and each load_state_dict call would
            # overwrite the backbone weights of every network loaded before it.
            fsl_network = network(copy.deepcopy(backbone), **kwargs)
            checkpoint_path = f'{checkpoint_dir}/{key_name}_{n_way}-way_{n_shot}-shot_last_epoch'
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
            fsl_network.load_state_dict(torch.load(checkpoint_path, map_location=device))
            network_dict[n_way][n_shot] = fsl_network.to(device)

    return network_dict

In [3]:
###########################
# Load finetuned networks
###########################
# Filled in by the following cells: finetuned_networks[name][n_way][n_shot] holds the
# network restored from the matching checkpoint.
finetuned_networks = dict()

In [9]:
# Load Relation Network checkpoints
# The backbone is built without pretrained weights (resnet18() with no `weights` argument);
# its parameters are expected to come from the checkpoint loaded below.
return_layer='layer4.1.bn2'
cnn = resnet18()
# Same backbone layout as the pretrained Relation Network cell: return the
# 'layer4.1.bn2' node, wrapped in the project's FeatureExtractor.
resnet_extractor = create_feature_extractor(cnn, return_nodes=[return_layer])
feature_extractor = FeatureExtractor(resnet_extractor, return_layer)

finetuned_networks['relation_network'] = load_networks_from_checkpoint('relation_network', n_ways, n_shots, feature_extractor, RelationNetworks, feature_dimension=512)

In [10]:
# Load Matching Network checkpoints
# NOTE(review): return_layer is assigned here but not used in this cell.
return_layer='layer4.1.bn2'
# Untrained ResNet-18 with the classification head replaced by Flatten; weights come
# from the checkpoints loaded below.
cnn = resnet18()
cnn.fc = torch.nn.Flatten()

finetuned_networks['matching_network'] = load_networks_from_checkpoint('matching_network', n_ways, n_shots, cnn, MatchingNetworks, feature_dimension=512)

In [12]:
# Load Prototypical Network checkpoints
# Untrained ResNet-18 with the classification head replaced by Flatten; weights come
# from the checkpoints loaded below.
cnn = resnet18()
cnn.fc = torch.nn.Flatten()
# Stored under the key 'prototypical' (the checkpoint prefix is 'prototypical_network');
# this key is the row label shown in the accuracy table.
finetuned_networks['prototypical'] = load_networks_from_checkpoint('prototypical_network', n_ways, n_shots, cnn, PrototypicalNetworks)

In [19]:
# get_accuracy_table has no defaults for the dataset or the task settings, so they must be
# passed explicitly. pretrained=False because finetuned_networks[name] is a nested dict
# indexed as [n_way][n_shot]; the default (True) would pass that dict to evaluate as if
# it were a network.
finetuned_acc_table = get_accuracy_table(
    finetuned_networks, test_dataset, n_ways, n_shots, n_query, pretrained=False
)
finetuned_acc_table

Evaluating relation_network for 5-way 1-shot: 

100%|██████████| 5/5 [00:04<00:00,  1.12it/s, accuracy=0.752]



Evaluating relation_network for 5-way 5-shot: 

100%|██████████| 5/5 [00:05<00:00,  1.03s/it, accuracy=0.544]


Evaluating relation_network for 20-way 1-shot: 


100%|██████████| 5/5 [00:12<00:00,  2.46s/it, accuracy=0.634]



Evaluating relation_network for 20-way 5-shot: 

100%|██████████| 5/5 [00:21<00:00,  4.25s/it, accuracy=0.888]


Evaluating matching_network for 5-way 1-shot: 


100%|██████████| 5/5 [00:03<00:00,  1.42it/s, accuracy=0.912]


Evaluating matching_network for 5-way 5-shot: 


100%|██████████| 5/5 [00:06<00:00,  1.28s/it, accuracy=0.976]



Evaluating matching_network for 20-way 1-shot: 

100%|██████████| 5/5 [00:11<00:00,  2.24s/it, accuracy=0.848]


Evaluating matching_network for 20-way 5-shot: 


100%|██████████| 5/5 [00:22<00:00,  4.51s/it, accuracy=0.922]



Evaluating prototypical for 5-way 1-shot: 

100%|██████████| 5/5 [00:02<00:00,  1.79it/s, accuracy=0.904]


Evaluating prototypical for 5-way 5-shot: 


100%|██████████| 5/5 [00:04<00:00,  1.07it/s, accuracy=0.936]


Evaluating prototypical for 20-way 1-shot: 


100%|██████████| 5/5 [00:14<00:00,  2.91s/it, accuracy=0.826]



Evaluating prototypical for 20-way 5-shot: 

100%|██████████| 5/5 [00:22<00:00,  4.56s/it, accuracy=0.908]







In [21]:
# Display the fine-tuned accuracy table again (rows: networks; columns: n-way / n-shot).
finetuned_acc_table

Unnamed: 0_level_0,5-way,5-way,20-way,20-way
Unnamed: 0_level_1,1-shot,5-shot,1-shot,5-shot
relation_network,0.752,0.544,0.634,0.888
matching_network,0.912,0.976,0.848,0.922
prototypical,0.904,0.936,0.826,0.908
