In [None]:
import os
import copy
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models

import matplotlib.pyplot as plt
from PIL import Image

from easyFSL_helper import *

from tqdm.notebook import tqdm

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device", device)

from torch.utils.tensorboard import SummaryWriter

In [None]:
# https://www.kaggle.com/code/wenewone/transfer-learning-example-on-cub-200-2011-dataset
class CUB():
    """Map-style dataset over a CUB-200-2011 directory.

    Reads ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` from ``root`` and serves ``(image, label)`` pairs.
    Labels are shifted by -1 so that they start at 0.

    Args:
        root: Directory holding the three index files and an ``images`` folder.
        dataset_type: ``'train'``, ``'valid'`` or ``'test'``. ``'train'`` and
            ``'valid'`` are both taken from the rows flagged ``Train == 1``.
        train_ratio: Share of the training rows kept for ``'train'``; the
            remainder forms ``'valid'``. With the default ``1`` the whole
            training split is returned and ``'valid'`` is unavailable.
        valid_seed: Seed of the shuffle that decides the train/valid partition.
        transform: Optional callable applied to the opened PIL image.
        target_transform: Optional callable applied to the integer label.

    Raises:
        ValueError: If ``dataset_type`` is not supported, or ``'valid'`` is
            requested while ``train_ratio == 1``.
    """

    def __init__(self, root, dataset_type='train', train_ratio=1, valid_seed=123, transform=None, target_transform=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        df_img = pd.read_csv(os.path.join(root, 'images.txt'), sep=' ', header=None, names=['ID', 'Image'], index_col=0)
        df_label = pd.read_csv(os.path.join(root, 'image_class_labels.txt'), sep=' ', header=None, names=['ID', 'Label'], index_col=0)
        df_split = pd.read_csv(os.path.join(root, 'train_test_split.txt'), sep=' ', header=None, names=['ID', 'Train'], index_col=0)
        df = pd.concat([df_img, df_label, df_split], axis=1)

        # Make the class ids zero-based.
        df['Label'] = df['Label'] - 1

        if dataset_type == 'test':
            df = df[df['Train'] == 0]
        elif dataset_type in ('train', 'valid'):
            df = df[df['Train'] == 1]
            if train_ratio != 1:
                # A private RandomState produces the same permutation as seeding
                # the global generator, but leaves NumPy's global RNG untouched.
                indices = list(range(len(df)))
                np.random.RandomState(valid_seed).shuffle(indices)
                split_idx = int(len(indices) * train_ratio) + 1
                if dataset_type == 'train':
                    df = df.iloc[indices[:split_idx]]
                else:       # dataset_type == 'valid'
                    df = df.iloc[indices[split_idx:]]
            elif dataset_type == 'valid':
                raise ValueError('train_ratio should be less than 1!')
            # Otherwise: 'train' with train_ratio == 1 keeps the whole training split.
        else:
            raise ValueError('Unsupported dataset_type!')

        self.img_name_list = df['Image'].tolist()
        self.label_list = df['Label'].tolist()
        # Redirect non-RGB images to their RGB counterparts.
        self._convert2rgb()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.label_list)

    def get_labels(self):
        """Return the list of zero-based labels, aligned with the sample order."""
        return self.label_list

    def __getitem__(self, idx):
        """Open sample ``idx`` and return ``(image, target)`` after the optional transforms."""
        img_path = os.path.join(self.root, 'images', self.img_name_list[idx])
        image = Image.open(img_path)
        target = self.label_list[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        return image, target

    def _convert2rgb(self):
        """Point every non-RGB entry at a ``*_rgb.jpg`` file name.

        Only the stored file name is rewritten; no image is converted or saved
        here. NOTE(review): this presumably relies on the ``_rgb.jpg`` files
        already existing on disk -- confirm.
        """
        for i, img_name in enumerate(self.img_name_list):
            img_path = os.path.join(self.root, 'images', img_name)
            # The context manager closes the file handle right after the mode is read.
            with Image.open(img_path) as image:
                color_mode = image.mode
            if color_mode != 'RGB':
                self.img_name_list[i] = img_name.replace('.jpg', '_rgb.jpg')

In [None]:
# Settings shared by both pipelines: target size and per-channel normalisation
# statistics (these values match the commonly used ImageNet mean/std).
_INPUT_SIZE = (180, 180)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random flip/rotation augmentation, then tensor + normalise.
train_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# Validation pipeline: same preprocessing without the random augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

In [None]:
batch_size = 64
# Built with os.path.join: the previous literal contained the invalid escape
# sequence '\C' and a Windows-only separator.
path = os.path.join('archive', 'CUB_200_2011')

NUM_WORKERS = 4
SPLIT_RATIO = 0.9     # share of the official training split kept for training
RANDOM_SEED = 123     # seed of the train/validation shuffle
CLASS_NUM = 200

# Episode layout: passed to the task samplers as n_way / n_shot / n_query / n_tasks.
N = 5
K = 5
n_query = 10
n_tasks_per_epoch = 100
n_validation_tasks = 50

# Build the train / validation datasets. Both use the same ratio and seed, so
# they are complementary slices of the same shuffled training split.
train_set = CUB(
    path,
    dataset_type='train',
    train_ratio=SPLIT_RATIO,
    valid_seed=RANDOM_SEED,
    transform=train_transform,
)
val_set = CUB(
    path,
    dataset_type='valid',
    train_ratio=SPLIT_RATIO,
    valid_seed=RANDOM_SEED,
    transform=val_transform,
)

print(f"Train: {len(train_set)}")
print(f"Valid: {len(val_set)}")

In [None]:
def _episodic_loader(dataset, sampler):
    """Wrap ``dataset`` in a DataLoader whose batches are drawn by ``sampler``.

    The sampler is used both as the batch sampler and as the provider of the
    collate function (``sampler.episodic_collate_fn``).
    """
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        pin_memory=True,
        collate_fn=sampler.episodic_collate_fn,
    )


# One sampler per split; they differ only in the dataset and the number of tasks.
# NOTE(review): TaskSampler / DataLoader are not imported by name in this file --
# presumably they come from the easyFSL_helper wildcard import; confirm.
train_sampler = TaskSampler(
    train_set, n_way=N, n_shot=K, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=N, n_shot=K, n_query=n_query, n_tasks=n_validation_tasks
)

train_loader = _episodic_loader(train_set, train_sampler)
val_loader = _episodic_loader(val_set, val_sampler)

In [None]:
class PrototypicalNetworks(nn.Module):
    """Few-shot classifier that scores queries by distance to class prototypes.

    Args:
        backbone: Module mapping a batch of images to a batch of feature vectors.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: Batch of labelled images fed to the backbone.
            support_labels: One label per support image. Prototypes are built
                for labels ``0 .. n_way - 1``, where ``n_way`` is the number of
                distinct values in this tensor.
            query_images: Batch of images to classify.

        Returns:
            Tensor of negated euclidean distances between each query feature
            and each prototype (one row per query, one column per class), so
            that a higher score means a closer prototype.
        """
        # Call the module itself (not .forward) so registered hooks are honoured.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))

        # Prototype i is the mean of all support features whose label equals i
        z_proto = torch.stack(
            [z_support[support_labels == label].mean(0) for label in range(n_way)]
        )

        # Closer prototype => larger score
        return -torch.cdist(z_query, z_proto)

In [None]:
# ResNet-34 backbone created with torchvision's default arguments.
# NOTE(review): no weights argument is passed, so this is presumably randomly
# initialised and keeps its final fully connected layer -- confirm this is intended.
resnet = models.resnet34()
few_shot_classifier = PrototypicalNetworks(backbone=resnet)
few_shot_classifier = few_shot_classifier.to(device)

In [None]:
# Cross-entropy over the per-query class scores.
loss_module = nn.CrossEntropyLoss()

# Adam over every parameter of the classifier, with a small weight decay.
optimizer = optim.Adam(
    params=few_shot_classifier.parameters(),
    lr=0.001,
    weight_decay=1e-5,
)

# Multiply the learning rate by 0.1 when the monitored value stops decreasing
# for more than 5 consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)

In [None]:
def train_model(model, optimizer, scheduler, train_data_loader, loss_module, num_epochs=100, logging_dir='runs/our_experiment', val_data_loader=None):
    """Train ``model`` episodically, validate after every epoch and log to TensorBoard.

    Args:
        model: Module called as ``model(support_images, support_labels, query_images)``
            and returning classification scores for the query images.
        optimizer: Optimizer updating the parameters of ``model``.
        scheduler: Scheduler stepped once per epoch with the mean training loss.
        train_data_loader: Iterable yielding 5-tuples
            ``(support_images, support_labels, query_images, query_labels, _)``.
        loss_module: Callable ``loss_module(scores, query_labels)`` returning the loss.
        num_epochs: Number of passes over ``train_data_loader``.
        logging_dir: TensorBoard log directory. Its last path component also
            names the checkpoint folder ``models/<name>``.
        val_data_loader: Loader used for validation. When ``None`` the
            notebook-level ``val_loader`` is used, which keeps existing calls working.

    Side effects:
        Writes ``training_loss`` and ``Val_acc`` scalars, saves
        ``models/<name>/model<epoch>.pt`` every 10 epochs and
        ``models/<name>/best_model.pt`` whenever validation accuracy improves.
    """
    if val_data_loader is None:
        # Backward compatibility: fall back to the loader defined at notebook level.
        val_data_loader = val_loader

    os.makedirs(logging_dir, exist_ok=True)

    # Loop-invariant: the checkpoint folder is derived once from the log directory.
    # normpath/basename also copes with a trailing separator.
    model_dir = os.path.basename(os.path.normpath(logging_dir))
    os.makedirs(f'models/{model_dir}', exist_ok=True)

    # Must be initialised before the loop; it is compared before it is ever updated.
    best_validation_accuracy = 0.0

    writer = SummaryWriter(logging_dir)
    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            n_tasks = 0
            for support_images, support_labels, query_images, query_labels, _ in tqdm(train_data_loader, 'Epoch %d'%(epoch + 1)):
                ## Step 2: Run the model on the input data
                classification_scores = model(support_images.to(device), support_labels.to(device), query_images.to(device))

                ## Step 3: Calculate the loss
                loss = loss_module(classification_scores, query_labels.to(device))

                ## Step 4: Perform backpropagation
                optimizer.zero_grad()
                loss.backward()

                ## Step 5: Update the parameters
                optimizer.step()

                running_loss += loss.item()
                n_tasks += 1

            # Mean loss over the episodes actually processed in this epoch.
            epoch_loss = running_loss / max(n_tasks, 1)

            # Calling scheduler
            scheduler.step(epoch_loss)

            # Validate the model that was passed in, not a notebook-level global.
            # NOTE(review): evaluate is not defined in this file -- presumably it
            # comes from easyFSL_helper and returns an accuracy; confirm.
            validation_accuracy = evaluate(model, val_data_loader, device=device, tqdm_prefix="Validation")

            if validation_accuracy > best_validation_accuracy:
                best_validation_accuracy = validation_accuracy
                # Persist the best weights right away so they are not lost as
                # training keeps mutating the live parameters.
                torch.save(model.state_dict(), f'models/{model_dir}/best_model.pt')
                print("Found a new best model.")

            # Both scalars share the same 1-based epoch index.
            writer.add_scalar('training_loss', epoch_loss, global_step=epoch + 1)
            writer.add_scalar("Val_acc", validation_accuracy, global_step=epoch + 1)

            # Periodic checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                torch.save(model.state_dict(), f'models/{model_dir}/model{epoch + 1}.pt')
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
      

In [None]:
# Launch a 10-epoch training run with the default logging directory.
train_model(
    model=few_shot_classifier,
    optimizer=optimizer,
    scheduler=scheduler,
    train_data_loader=train_loader,
    loss_module=loss_module,
    num_epochs=10,
)