In [2]:
!pip install easyfsl
!pip install git+https://github.com/mlfoundations/open_clip.git

Collecting easyfsl
  Downloading easyfsl-1.5.0-py3-none-any.whl.metadata (16 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.5.0->easyfsl)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.5.0->easyfsl)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.5.0->easyfsl)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.5.0->easyfsl)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.5.0->easyfsl)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.5.0->easyfsl)
  Download

In [3]:
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import category_encoders as ce
from datetime import datetime
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from torchvision.models import resnet18
from scipy.spatial.distance import cosine
import torch.optim as optim
import timm
import open_clip

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataset preparation

### Custom dataset class

In [6]:
class ImageDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame with image metadata.
            root_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to apply on a sample.
        """
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def augment_dataset(self, max_per_class: int = 5):
        """
        Augments the under-represented classes by creating and saving new image files,
        and updating the dataset's DataFrame with the new file paths.
        
        Args:
            max_per_class (int, optional): The target number of samples per class. Default value is 5.
        """
        # define the augmentation method
        aa = AutoAugment(policy=AutoAugmentPolicy.IMAGENET)

        new_rows = []
        grouped = self.dataframe.groupby('category_id')
        augmented_dir = '/kaggle/working/augmented'
        os.makedirs(augmented_dir, exist_ok=True)
        
        for cat_id, group in tqdm(grouped, desc="Augmenting classes"):
            num_samples = len(group)
            if num_samples < max_per_class:
                num_to_add = max_per_class - num_samples

                # retrieve the original images
                original_images = group['filename'].tolist()

                for i in range(num_to_add):
                    img_to_augment = random.choice(original_images)
                    
                    original_image_path = os.path.join(self.root_dir, img_to_augment)
                    
                    try:
                        image = Image.open(original_image_path).convert('RGB')

                        # augment the image
                        augmented_image = aa(image)
                        
                        original_filename = os.path.basename(original_image_path)
                        augmented_filename = f"aug_{i}_{original_filename}"
                        augmented_image_path = os.path.join(augmented_dir, augmented_filename)

                        # save augmented image on /working/augmented/ path
                        augmented_image.save(augmented_image_path)
                        
                        original_row = self.dataframe[self.dataframe['filename'] == original_filename].iloc[0]
                        new_row = original_row.copy()
                        new_row['filename'] = augmented_filename
                        new_rows.append(new_row)

                    except Exception as e:
                        print(f"Failed to augment image {original_image_path}: {e}")
                        continue

        if new_rows:
            new_df = pd.DataFrame(new_rows)
            self.dataframe = pd.concat([self.dataframe, new_df], ignore_index=True)

        print(f"Dataset augmentation complete. New dataset size: {len(self.dataframe)}")

    def get_labels(self):
        return self.dataframe['category_id'].tolist()
    
    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.dataframe.iloc[idx]
        img_name = row['filename']
        
        if 'aug' in img_name:
            img_path = os.path.join('/kaggle/working/augmented', img_name)
        else:
            img_path = os.path.join(self.root_dir, row['filename'])
            
        image = Image.open(img_path).convert('RGB')
        
        label = row['category_id']
        
        if self.transform:
            image = self.transform(image)

        return image, int(label)

### Define dataset

In [7]:
train_df = pd.read_csv('/kaggle/input/fungi-clef-2025/metadata/FungiTastic-FewShot/FungiTastic-FewShot-Train.csv')
test_df = pd.read_csv('/kaggle/input/fungi-clef-2025/metadata/FungiTastic-FewShot/FungiTastic-FewShot-Val.csv')

### Define transformations and create train and test datasets

In [8]:
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) # 🏳️‍🌈🏳️‍⚧️

train_dataset = ImageDataset(
    dataframe=train_df,
    root_dir='/kaggle/input/fungi-clef-2025/images/FungiTastic-FewShot/train/300p',
    transform=trans
)

test_dataset = ImageDataset(
    dataframe=test_df,
    root_dir='/kaggle/input/fungi-clef-2025/images/FungiTastic-FewShot/val/300p',
    transform=trans
)

### Augment datasets to have 10 samples per class

In [9]:
train_dataset.augment_dataset(max_per_class=10)
test_dataset.augment_dataset(max_per_class=10)

Augmenting classes: 100%|██████████| 2427/2427 [01:55<00:00, 21.00it/s]


Dataset augmentation complete. New dataset size: 24422


Augmenting classes: 100%|██████████| 570/570 [00:24<00:00, 23.39it/s]


Dataset augmentation complete. New dataset size: 5841


### Create a dataset not augmented

In [10]:
test_dataset_noaug = ImageDataset(
    dataframe=test_df,
    root_dir='/kaggle/input/fungi-clef-2025/images/FungiTastic-FewShot/val/300p',
    transform=trans
)

test_dataloader_noaug = DataLoader(
    test_dataset_noaug,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
)

### prototypes computation

In [11]:
def compute_prototypes(dataset, backbone: torch.nn.Module, device: str) -> dict:
    """
    Computes a single prototype (median embedding) for each class in the dataset.
    All embeddings are L2-normalized before median computation.

    Args:
        dataset: instance of ImageDataset class.
        backbone (torch.nn.Module): The trained ResNet backbone for embedding extraction.
        device (str): 'cpu' or 'cuda'.

    Returns:
        dict: A dictionary where keys are class IDs and values are the
              numpy arrays of the prototype embeddings.
    """
    # Use a DataLoader
    data_loader = DataLoader(dataset, batch_size=32, shuffle=False)
    
    backbone.eval()
    backbone.to(device)

    embeddings_dict = {class_id: [] for class_id in dataset.dataframe['category_id'].unique()}

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Extracting embeddings"):
            images = images.to(device)
            
            # Pass through the backbone
            embeddings = backbone(images)
            
            # L2 normalization
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            embeddings = embeddings.squeeze().cpu().numpy()
            
            for embedding, label in zip(embeddings, labels):
                embeddings_dict[label.item()].append(embedding)

    prototypes = {}
    for class_id, embeddings in tqdm(embeddings_dict.items(), desc="Computing prototypes"):
        if embeddings:
            # A class prototype is the median embedding
            prototypes[class_id] = np.median(embeddings, axis=0)
            
    return prototypes

### prediction and evaluation

In [12]:
def predict_and_evaluate(test_dataset, backbone: torch.nn.Module, prototypes: dict, device: str) -> tuple:
    """
    Predicts labels for a test dataset using our pre-computed prototypes and calculates accuracy.
    Uses L2 normalization for embeddings and cosine similarity for prediction.

    Args:
        test_dataset: Instance of ImageDataset for the test set.
        backbone (torch.nn.Module): The trained ResNet backbone for embedding extraction.
        prototypes (dict): The dictionary of pre-computed prototype embeddings for each class.
        device (str): 'cpu' or 'cuda'.

    Returns:
        tuple: A tuple containing the accuracy score and the predicted labels.
    """
    backbone.eval()
    backbone.to(device)

    # Convert prototypes to a PyTorch tensor for efficient computation
    class_ids = list(prototypes.keys())
    prototype_array = np.array(list(prototypes.values()))
    
    data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Predicting labels"):
            images = images.to(device)
            
            # Get embeddings for test images
            embeddings = backbone(images)
            
            # Embeddings L2 normalization
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            embeddings = embeddings.squeeze().cpu().numpy()
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            # Find the closest prototype
            for test_embedding in embeddings:
                distances = [cosine(test_embedding, p) for p in prototype_array]
                min_distance_index = np.argmin(distances)
                
                # Map index back to the class ID
                prediction = class_ids[min_distance_index]
                predicted_labels.append(prediction)
            
            true_labels.extend(labels.cpu().numpy())

    # Compute and return accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)
    return accuracy

# Classification

### Prototypical network definition

In [13]:
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the median of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].median(0).values
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # We tried also computing cosign similarity and normalization, but eucledian distance is the best choice
        #z_query_norm = F.normalize(z_query, p=2, dim=1)
        #z_proto_norm = F.normalize(z_proto, p=2, dim=1)
        #scores = torch.matmul(z_query_norm, z_proto_norm.T)
        #scores = torch.matmul(z_query, z_proto.T)

        # Distances into classification scores
        scores = -dists
        
        return scores

### Define few-shot parameters
- n_way: the number of classes in a task
- n_shot: the number of images per class in the support set
- n_query: the number of images per class in the query set

In [14]:
N_WAY = 5
N_SHOT = 5
N_QUERY = 5

N_EVALUATION_TASKS = 100
N_TRAINING_EPISODES = 1000
N_VALIDATION_TASKS = 100

#### Use the TaskSampler class to sample the necessary few-shot tasks, for both train and test

In [15]:
train_sampler = TaskSampler(
    train_dataset, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)

train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

In [16]:
test_sampler = TaskSampler(
    test_dataset, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    test_dataset,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

#### Utility functions for overall and single-task evaluation 

In [17]:
def evaluate_on_one_task(model,
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> [int, int]:
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    return (
        torch.max(
            model(support_images.cuda(), support_labels.cuda(), query_images.cuda()).detach().data, 1,
        )[1]
        == query_labels.cuda()
    ).sum().item(), len(query_labels)


def evaluate(model, data_loader: DataLoader):
    total_predictions = 0
    correct_predictions = 0

    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):

            correct, total = evaluate_on_one_task(model,
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )

In [18]:
def fit(
    model,
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    )

    loss = criterion(classification_scores, query_labels.cuda())
    loss.backward()
    optimizer.step()

    return loss.item()

In [19]:
def train_few_shot(model):
    log_update_frequency = 10

    all_loss = []
    model.train()
    with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            loss_value = fit(model, support_images, support_labels, query_images, query_labels)
            all_loss.append(loss_value)
    
            if episode_index % log_update_frequency == 0:
                tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))
    
    return model

### Define ResNet model

In [20]:
convolutional_network = resnet18( weights='DEFAULT')
convolutional_network.fc = nn.Flatten()
fs_model = PrototypicalNetworks(convolutional_network).cuda()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 152MB/s] 


In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(fs_model.parameters(), lr=0.0001,weight_decay=1e-4)

### Training function

#### A first evaluation to see the result without training

In [22]:
evaluate(fs_model, test_loader)

100%|██████████| 100/100 [00:11<00:00,  8.53it/s]

Model tested on 100 tasks. Accuracy: 81.80%





### Training

In [23]:
fs_model = train_few_shot(fs_model)

100%|██████████| 1000/1000 [01:49<00:00,  9.13it/s, loss=0.0754]


#### Evaluation after training

In [24]:
evaluate(fs_model, test_loader)

100%|██████████| 100/100 [00:10<00:00,  9.15it/s]

Model tested on 100 tasks. Accuracy: 87.24%





# Add a final layer for the classification

We take the backbone of PrototypeNetwork we just trained as `trained_backbone`. 

Optimizer must update only the parameters of the new Linear layer

#### Functions for the outer layer training and testing

In [25]:
def train_classifier_head(head_model, epochs=10):
    model = head_model
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        for images, labels in tqdm(full_train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
    
            optimizer.zero_grad()
    
            logits = model(images)
    
            loss = criterion(logits, labels)
    
            loss.backward()
            optimizer.step()
    
        print(f"Fine-tuning Epoch {epoch+1} Loss: {loss.item():.4f}")

    return model

In [26]:
def test_classifier_head(model):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_dataloader_noaug:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

Define `full_train_dataset` as complete training dataset, not just few-shot episodes

In [27]:
full_train_dataloader = DataLoader(
    train_dataset, 
    batch_size=64, 
    shuffle=True
)

#### First try with ResNet backbone

In [28]:
trained_backbone = fs_model.backbone
trained_backbone.to(device)
# Freeze
for param in trained_backbone.parameters():
    param.requires_grad = False

embedding_dim = 512
num_classes = len(train_df['category_id'].unique())

# We define the new model with Linear layer
# It includes the freezed backbone and the new classification layer
model = nn.Sequential(
    trained_backbone,
    nn.Linear(embedding_dim, num_classes)
).to(device)

In [29]:
model = train_classifier_head(model, 10)

Epoch 1/10: 100%|██████████| 382/382 [01:31<00:00,  4.16it/s]


Fine-tuning Epoch 1 Loss: 3.4883


Epoch 2/10: 100%|██████████| 382/382 [01:27<00:00,  4.35it/s]


Fine-tuning Epoch 2 Loss: 1.9309


Epoch 3/10: 100%|██████████| 382/382 [01:28<00:00,  4.30it/s]


Fine-tuning Epoch 3 Loss: 1.4754


Epoch 4/10: 100%|██████████| 382/382 [01:28<00:00,  4.30it/s]


Fine-tuning Epoch 4 Loss: 0.8864


Epoch 5/10: 100%|██████████| 382/382 [01:27<00:00,  4.36it/s]


Fine-tuning Epoch 5 Loss: 0.4566


Epoch 6/10: 100%|██████████| 382/382 [01:27<00:00,  4.36it/s]


Fine-tuning Epoch 6 Loss: 0.3828


Epoch 7/10: 100%|██████████| 382/382 [01:27<00:00,  4.36it/s]


Fine-tuning Epoch 7 Loss: 0.2782


Epoch 8/10: 100%|██████████| 382/382 [01:27<00:00,  4.34it/s]


Fine-tuning Epoch 8 Loss: 0.2988


Epoch 9/10: 100%|██████████| 382/382 [01:27<00:00,  4.36it/s]


Fine-tuning Epoch 9 Loss: 0.1720


Epoch 10/10: 100%|██████████| 382/382 [01:26<00:00,  4.42it/s]

Fine-tuning Epoch 10 Loss: 0.2719





#### Now, after training, we can use the new model for classification

In [30]:
test_classifier_head(model)

Test Accuracy: 5.08%


## Use DINOv2 as embedding model

In [31]:
dinov2_backbone = timm.create_model('vit_base_patch16_224.dino', pretrained=True)
dinov2_backbone.to(device)

fs_model = PrototypicalNetworks(dinov2_backbone).cuda()

In [32]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(fs_model.parameters(), lr=0.0001,weight_decay=1e-4)

#### Evaluation before training

In [33]:
evaluate(fs_model, test_loader)

100%|██████████| 100/100 [00:29<00:00,  3.39it/s]

Model tested on 100 tasks. Accuracy: 88.68%





#### Training

In [None]:
fs_model = train_few_shot(fs_model)

#### Evaluation after few-shot training

In [None]:
evaluate(fs_model, test_loader)

### Add classification head on DINOv2 backbone

In [None]:
trained_backbone = fs_model.backbone
trained_backbone.to(device)

for param in trained_backbone.parameters():
    param.requires_grad = False

embedding_dim = 768
num_classes = len(train_df['category_id'].unique())

model = nn.Sequential(
    trained_backbone,
    nn.Linear(embedding_dim, num_classes)
).to(device)

In [None]:
model = train_classifier_head(model, 10)

In [None]:
test_classifier_head(model)

## Use BioCLIP as embedding model

Define new PrototypicalNetworks class to adapt to BioCLIP

In [None]:
class PrototypicalNetworksCLIP(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        # BioCLIP needs to use .encode_image method 
        z_support = self.backbone.encode_image(support_images)
        z_query = self.backbone.encode_image(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].median(0).values
                for label in range(n_way)
            ]
        )
        
        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        scores = -dists
        
        return scores

In [None]:
bioclip_model, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
fs_model = PrototypicalNetworksCLIP(bioclip_model).to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(fs_model.parameters(), lr=0.0001,weight_decay=1e-4)

Preliminary evaluation

In [None]:
evaluate(fs_model, test_loader)

#### Training

In [None]:
fs_model = train_few_shot(fs_model)

Evaluation after training

In [None]:
evaluate(fs_model, test_loader)

Few-shot learning failed us. 

So we moved to another approach.

In [None]:
# Define the BioCLIP backbone
trained_backbone, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
trained_backbone.to(device)

# Freeze
for param in trained_backbone.parameters():
    param.requires_grad = False

embedding_dim = 512
num_classes = len(train_df['category_id'].unique())

# Define our new classification model
class MyModel(nn.Module):
    def __init__(self, trained_backbone, embedding_dim, num_classes):
        super(MyModel, self).__init__()
        self.backbone = trained_backbone
        
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.backbone.encode_image(x)
        x = self.fc(x)
        return x

In [55]:
model = MyModel(trained_backbone, embedding_dim, num_classes).to(device)

model = train_classifier_head(model, 30)

Epoch 1/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 1 Loss: 5.5355


Epoch 2/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 2 Loss: 2.6350


Epoch 3/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 3 Loss: 1.5096


Epoch 4/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 4 Loss: 0.9906


Epoch 5/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 5 Loss: 0.8780


Epoch 6/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 6 Loss: 0.6014


Epoch 7/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 7 Loss: 0.2227


Epoch 8/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 8 Loss: 0.3477


Epoch 9/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 9 Loss: 0.1581


Epoch 10/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 10 Loss: 0.1789


Epoch 11/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 11 Loss: 0.0833


Epoch 12/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 12 Loss: 0.0887


Epoch 13/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 13 Loss: 0.0837


Epoch 14/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 14 Loss: 0.0606


Epoch 15/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 15 Loss: 0.0433


Epoch 16/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 16 Loss: 0.0317


Epoch 17/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 17 Loss: 0.0323


Epoch 18/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 18 Loss: 0.0208


Epoch 19/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 19 Loss: 0.0198


Epoch 20/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 20 Loss: 0.0096


Epoch 21/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 21 Loss: 0.0121


Epoch 22/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 22 Loss: 0.0105


Epoch 23/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 23 Loss: 0.0078


Epoch 24/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 24 Loss: 0.0090


Epoch 25/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 25 Loss: 0.0064


Epoch 26/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 26 Loss: 0.0037


Epoch 27/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 27 Loss: 0.0038


Epoch 28/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 28 Loss: 0.0040


Epoch 29/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 29 Loss: 0.0028


Epoch 30/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 30 Loss: 0.0035


In [56]:
test_classifier_head(model)

Test Accuracy: 18.29%


#### Implement a new classifier head model

In [57]:
trained_backbone, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
trained_backbone.to(device)
for param in trained_backbone.parameters():
    param.requires_grad = False

embedding_dim = 512
num_classes = len(train_df['category_id'].unique())

class MyModel(nn.Module):
    def __init__(self, trained_backbone, embedding_dim, num_classes):
        super(MyModel, self).__init__()
        self.backbone = trained_backbone

        # define new model
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.backbone.encode_image(x)
        x = self.classifier(x)
        return x

In [58]:
model = MyModel(trained_backbone, embedding_dim, num_classes).to(device)

model = train_classifier_head(model, 30)

Epoch 1/30: 100%|██████████| 382/382 [02:23<00:00,  2.66it/s]


Fine-tuning Epoch 1 Loss: 5.9577


Epoch 2/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 2 Loss: 4.8573


Epoch 3/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 3 Loss: 3.5490


Epoch 4/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 4 Loss: 3.1687


Epoch 5/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 5 Loss: 2.6814


Epoch 6/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 6 Loss: 1.6341


Epoch 7/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 7 Loss: 1.8434


Epoch 8/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 8 Loss: 1.5550


Epoch 9/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 9 Loss: 1.5310


Epoch 10/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 10 Loss: 0.8258


Epoch 11/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 11 Loss: 1.6081


Epoch 12/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 12 Loss: 0.6701


Epoch 13/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 13 Loss: 0.6426


Epoch 14/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 14 Loss: 0.6655


Epoch 15/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 15 Loss: 1.2912


Epoch 16/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 16 Loss: 0.6345


Epoch 17/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 17 Loss: 0.4499


Epoch 18/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 18 Loss: 0.7664


Epoch 19/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 19 Loss: 0.7741


Epoch 20/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 20 Loss: 0.2342


Epoch 21/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 21 Loss: 0.8049


Epoch 22/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 22 Loss: 0.4596


Epoch 23/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 23 Loss: 0.2389


Epoch 24/30: 100%|██████████| 382/382 [02:22<00:00,  2.68it/s]


Fine-tuning Epoch 24 Loss: 0.8964


Epoch 25/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 25 Loss: 0.3325


Epoch 26/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 26 Loss: 0.8664


Epoch 27/30: 100%|██████████| 382/382 [02:22<00:00,  2.67it/s]


Fine-tuning Epoch 27 Loss: 0.9873


Epoch 28/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 28 Loss: 0.8055


Epoch 29/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 29 Loss: 0.5141


Epoch 30/30: 100%|██████████| 382/382 [02:23<00:00,  2.67it/s]


Fine-tuning Epoch 30 Loss: 0.5926


In [59]:
test_classifier_head(model)

Test Accuracy: 15.19%


# **Few-shot with frozen BioCLIP and trainable head**

Now we define another approach.
We use BioCLIP as backbone, then train a classifier head with a couple of layers with few-shot. Finally combine the backbone and the head to perform classification with prototypes. 

In [60]:
clip_model, _, _ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')

for param in clip_model.parameters():
    param.requires_grad = False

In [61]:
class Head(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Head, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.fc(x)

embedding_dim = 512
head_hidden_dim = 512
output_feature_dim = 256

head_network = Head(
    input_dim=embedding_dim,
    hidden_dim=head_hidden_dim,
    output_dim=output_feature_dim
)

In [62]:
class CombinedModel(nn.Module):
    def __init__(self, clip_backbone: nn.Module, head_model: nn.Module):
        super(CombinedModel, self).__init__()
        # Store the pre-trained CLIP backbone
        self.clip_backbone = clip_backbone
        # Store the custom head model
        self.head = head_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1: Pass input through CLIP to get features.
        with torch.no_grad():
            x = self.clip_backbone.encode_image(x)
        
        # 2: Pass features through our custom head network.
        x = self.head(x)
        
        return x

In [63]:
combined_model = CombinedModel(clip_model, head_network)

In [64]:
fs_model = PrototypicalNetworks(combined_model).to(device)

In [65]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(fs_model.parameters(), lr=0.0001,weight_decay=1e-4)

Evaluation before few-shot training

In [66]:
evaluate(fs_model, test_loader)

100%|██████████| 100/100 [00:29<00:00,  3.34it/s]

Model tested on 100 tasks. Accuracy: 81.68%





### Training

In [67]:
fs_model = train_few_shot(fs_model)

100%|██████████| 1000/1000 [06:12<00:00,  2.68it/s, loss=0.295]


Evaluation after training

In [68]:
evaluate(fs_model, test_loader)

100%|██████████| 100/100 [00:30<00:00,  3.33it/s]

Model tested on 100 tasks. Accuracy: 85.52%





### Compute prototypes and predict

In [69]:
protos = compute_prototypes(dataset=train_dataset, backbone=fs_model.backbone, device=device)

Extracting embeddings: 100%|██████████| 764/764 [03:36<00:00,  3.53it/s]
Computing prototypes: 100%|██████████| 2427/2427 [00:00<00:00, 9391.94it/s]


In [70]:
predict_and_evaluate(test_dataset=test_dataset_noaug, backbone=fs_model.backbone, prototypes=protos, device=device)

Predicting labels: 100%|██████████| 72/72 [01:24<00:00,  1.17s/it]


0.06520787746170678