<a href="https://colab.research.google.com/github/dlfelps/ml_portfolio/blob/main/few_shot_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook provides an introduction to few shot learning.

It has a companion [blog post](https://dlfelps.github.io/2024/06/03/few-shot.html).

It is part of Daniel Felps' [ML portfolio](https://github.com/dlfelps/ml_portfolio/tree/main)

# SETUP ENVIRONMENT

In [1]:
# install few shot learning package
!git clone https://github.com/sicara/easy-few-shot-learning
%cd easy-few-shot-learning
!pip install .

Cloning into 'easy-few-shot-learning'...
remote: Enumerating objects: 1188, done.[K
remote: Counting objects: 100% (451/451), done.[K
remote: Compressing objects: 100% (245/245), done.[K
remote: Total 1188 (delta 285), reused 259 (delta 204), pack-reused 737[K
Receiving objects: 100% (1188/1188), 2.33 MiB | 10.70 MiB/s, done.
Resolving deltas: 100% (689/689), done.
/content/easy-few-shot-learning
Processing /content/easy-few-shot-learning
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.5.0->easyfsl==1.5.0)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.5.0->easyfsl==1.5.0)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105

In [8]:
# download pretrained models
!wget https://github.com/dlfelps/ml_portfolio/raw/main/pretrained_models/classical_model_18_acc_744.pt
!wget https://github.com/dlfelps/ml_portfolio/raw/main/pretrained_models/episodic_model_18_acc_779.pt

--2024-05-08 01:49:57--  https://github.com/dlfelps/ml_portfolio/raw/main/pretrained_models/classical_model_18_acc_744.pt
Resolving github.com (github.com)... 140.82.116.3
Connecting to github.com (github.com)|140.82.116.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/dlfelps/ml_portfolio/main/pretrained_models/classical_model_18_acc_744.pt [following]
--2024-05-08 01:49:57--  https://raw.githubusercontent.com/dlfelps/ml_portfolio/main/pretrained_models/classical_model_18_acc_744.pt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 45073134 (43M) [application/octet-stream]
Saving to: ‘classical_model_18_acc_744.pt’


2024-05-08 01:49:59 (299 MB/s) - ‘classical_model_18_acc_744.pt’ saved [450731

In [2]:
# Download the CUB dataset
!wget -O data/CUB/CUB_200_2011.tgz https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1
!tar -xzf data/CUB/CUB_200_2011.tgz --strip-components 1 --directory ./data/CUB/ CUB_200_2011/images/

--2024-05-08 01:37:25--  https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1
Resolving data.caltech.edu (data.caltech.edu)... 35.155.11.48
Connecting to data.caltech.edu (data.caltech.edu)|35.155.11.48|:443... connected.
HTTP request sent, awaiting response... 302 FOUND
Location: https://s3.us-west-2.amazonaws.com/caltechdata/96/97/8384-3670-482e-a3dd-97ac171e8a10/data?response-content-type=application%2Foctet-stream&response-content-disposition=attachment%3B%20filename%3DCUB_200_2011.tgz&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIARCVIVNNAP7NNDVEA%2F20240508%2Fus-west-2%2Fs3%2Faws4_request&X-Amz-Date=20240508T013725Z&X-Amz-Expires=60&X-Amz-SignedHeaders=host&X-Amz-Signature=d80555dbba1c6d2858ab7410aa03efe90f448493a1865196a8747784e1701d12 [following]
--2024-05-08 01:37:25--  https://s3.us-west-2.amazonaws.com/caltechdata/96/97/8384-3670-482e-a3dd-97ac171e8a10/data?response-content-type=application%2Foctet-stream&response-content-disposition=attachme

# IMPORTS

In [3]:
from pathlib import Path
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torchvision.models import resnet18, ResNet18_Weights
from easyfsl.datasets import CUB
from torch.utils.data import DataLoader
from easyfsl.samplers import TaskSampler
from easyfsl.utils import evaluate
from easyfsl.methods import PrototypicalNetworks

# EVAL SETUP

In [4]:
n_test_tasks = 1000
n_way = 5
n_shot = 5
n_query = 10
DEVICE = 'cuda'

test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=10,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)



# PRETRAINED

In [6]:
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Flatten()
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 130MB/s]


In [7]:
evaluate(few_shot_classifier, test_loader, device=DEVICE) #.684

  self.pid = os.fork()
  self.pid = os.fork()
100%|██████████| 1000/1000 [04:58<00:00,  3.35it/s, accuracy=0.684]


0.68412

# CLASSICAL

In [24]:
model = resnet18()
model.fc = nn.Linear(512, 140)
model.load_state_dict(torch.load('/content/easy-few-shot-learning/classical_model_18_acc_744.pt')) # pretrained just for you!
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)

In [25]:
evaluate(few_shot_classifier, test_loader, device=DEVICE) #.773

100%|██████████| 1000/1000 [04:47<00:00,  3.48it/s, accuracy=0.773]


0.77268

# EPISODIC (META-LEARNING)

In [6]:
model = resnet18()
model.fc = nn.Flatten()
model.load_state_dict(torch.load('/content/easy-few-shot-learning/episodic_model_18_acc_779.pt')) # pretrained just for you!
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)

In [7]:
evaluate(few_shot_classifier, test_loader, device=DEVICE) #.779

  self.pid = os.fork()
100%|██████████| 1000/1000 [04:55<00:00,  3.38it/s, accuracy=0.779]


0.77896

# BONUS: DINO

In [35]:
model  = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
model.to(DEVICE)

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-23): 24 x NestedTensorBlock(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )


In [36]:
from easyfsl.utils import predict_embeddings
from easyfsl.datasets import FeaturesDataset

dataloader = DataLoader(
    test_set,
    batch_size=128,
    num_workers=1,
    shuffle=False,
)

embeddings_df = predict_embeddings(dataloader, model, device=DEVICE)
features_dataset = FeaturesDataset.from_dataframe(embeddings_df)

task_sampler = TaskSampler(
    features_dataset,
    n_way=n_way,
    n_shot=n_shot,
    n_query=n_query,
    n_tasks=n_test_tasks,
)

features_loader = DataLoader(
    features_dataset,
    batch_sampler=task_sampler,
    num_workers=1,
    pin_memory=True,
    collate_fn=task_sampler.episodic_collate_fn,
)

Predicting embeddings: 100%|██████████| 14/14 [00:13<00:00,  1.01batch/s]


In [37]:
from easyfsl.methods import PrototypicalNetworks

few_shot_classifier = PrototypicalNetworks()

In [38]:
evaluate(few_shot_classifier, features_loader, device='cpu') #.964

100%|██████████| 1000/1000 [00:08<00:00, 123.77it/s, accuracy=0.964]


0.96448