# **Prototypical Learning**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets, models
import os
import random
from PIL import Image

In [None]:
# Define the embedding network for prototypical learning
class ProtypicalNetwork(nn.Module):
    """Embedding network for prototypical few-shot learning.

    An ImageNet-pretrained VGG16 convolutional trunk (frozen) is followed by
    two trainable fully connected layers that map each image to an embedding.
    The class name keeps its original spelling so existing references work.

    Args:
        input_dim: Unused; the pretrained VGG16 trunk expects 3-channel (RGB)
            input. Kept for interface compatibility.
        hidden_dim: Width of the hidden fully connected layer.
        output_dim: Dimensionality of the output embedding.
    """

    # Number of channels produced by the last conv block of VGG16 ``features``.
    _VGG16_FEATURE_CHANNELS = 512

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ProtypicalNetwork, self).__init__()
        vgg16 = models.vgg16(pretrained=True)
        # Keep only the convolutional trunk. The classifier head is never used,
        # so it is deliberately NOT registered as a submodule: otherwise its
        # (trainable, unused) weights would be moved to the device and handed
        # to the optimizer via model.parameters().
        self.feature_extractor = nn.Sequential(*list(vgg16.features.children()))

        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Pool to 1x1 so the flattened size is always 512 channels. For 32x32
        # inputs the trunk already outputs 1x1 maps, so this is a no-op there.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(self._VGG16_FEATURE_CHANNELS, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape (batch, 3, H, W).

        Returns:
            Tensor of shape (batch, output_dim). Entries are non-negative
            because of the final ReLU.
        """
        x = self.feature_extractor(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return x


In [None]:
# Calculate distance between samples
def euclidean_distance(x1, x2):
    """Pairwise Euclidean (L2) distances between two sets of row vectors.

    Args:
        x1: Tensor of shape (N, D).
        x2: Tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M) whose entry [i, j] is ||x1[i] - x2[j]||_2.
    """
    # (N, 1, D) - (M, D) broadcasts to (N, M, D); reduce over the feature axis.
    pairwise_diff = x1.unsqueeze(1) - x2
    return pairwise_diff.norm(dim=2)

In [None]:
# Get Omniglot data from internet
def get_omniglot_data(path='./data'):
    """Download (if needed) and load the Omniglot background and evaluation sets.

    Args:
        path: Root directory the dataset is downloaded to / read from.
            Created if it does not exist.

    Returns:
        Tuple ``(train_dataset, test_dataset)``: the background split and the
        evaluation split as ``torchvision.datasets.Omniglot`` objects.
    """
    os.makedirs(path, exist_ok=True)
    # Define the transform to apply to the images
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert PIL image to PyTorch tensor
        transforms.Resize((32, 32)),  # Resize to 32x32 for consistency
        transforms.Normalize((0.5,), (0.5,)),  # Normalize pixel values
    ])
    # Background set (used for training)
    train_dataset = datasets.Omniglot(
        root=path,
        background=True,
        download=True,
        transform=transform
    )
    # Evaluation set. It must be loaded because it is part of the return
    # value; leaving it commented out raised NameError on a fresh kernel.
    test_dataset = datasets.Omniglot(
        root=path,
        background=False,
        download=True,
        transform=transform
    )
    return train_dataset, test_dataset


get_omniglot_data()

100%|██████████| 9.46M/9.46M [00:00<00:00, 354MB/s]
100%|██████████| 6.46M/6.46M [00:00<00:00, 360MB/s]


(Dataset Omniglot
     Number of datapoints: 19280
     Root location: ./data/omniglot-py
     StandardTransform
 Transform: Compose(
                ToTensor()
                Resize(size=(32, 32), interpolation=bilinear, max_size=None, antialias=True)
                Normalize(mean=(0.5,), std=(0.5,))
            ),
 Dataset Omniglot
     Number of datapoints: 13180
     Root location: ./data/omniglot-py
     StandardTransform
 Transform: Compose(
                ToTensor()
                Resize(size=(32, 32), interpolation=bilinear, max_size=None, antialias=True)
                Normalize(mean=(0.5,), std=(0.5,))
            ))

In [None]:
# Load data for train step
def load_data(path, num_classes = 1, train_samples = 5, test_samples = 3):
    """Sample one few-shot episode (support and query image paths).

    Expected layout: ``path/<alphabet>/<character>/*.png``. Each
    ``<alphabet>/<character>`` directory is one class.

    Classes are visited in random order. A class is accepted only if it has at
    least ``train_samples + test_samples`` PNG images. Sampling stops once
    ``num_classes`` classes are accepted; fewer are returned only when not
    enough eligible classes exist.

    Args:
        path: Root directory containing alphabet sub-directories.
        num_classes: Number of classes (ways) in the episode.
        train_samples: Support images per class.
        test_samples: Query images per class.

    Returns:
        ``(support_set, query_set, support_labels, query_labels)``.
        ``support_set`` / ``query_set`` are flat lists of image paths grouped
        class by class (``train_samples`` / ``test_samples`` consecutive paths
        per class, same class order in both). ``support_labels`` /
        ``query_labels`` contain one list per class of the repeated
        ``"<alphabet>_<character>"`` name.
    """
    # Collect every class directory. Listings are sorted so the order before
    # shuffling does not depend on the filesystem (reproducible with a seed).
    class_dirs = []
    for lang in sorted(os.listdir(path)):
        lang_dir = os.path.join(path, lang)
        if not os.path.isdir(lang_dir):
            continue
        for char in sorted(os.listdir(lang_dir)):
            char_dir = os.path.join(lang_dir, char)
            if os.path.isdir(char_dir):
                class_dirs.append(char_dir)
    random.shuffle(class_dirs)

    support_set, query_set = [], []
    support_labels, query_labels = [], []
    samples_per_class = train_samples + test_samples
    accepted = 0
    # Walk the whole shuffled list (not only the first num_classes entries) so
    # a class with too few images is replaced instead of shrinking the episode.
    for cls_dir in class_dirs:
        if accepted >= num_classes:
            break
        # 'language_character{num}' as class label
        cls_name = "_".join(cls_dir.split(os.sep)[-2:])
        images = [os.path.join(cls_dir, img)
                  for img in sorted(os.listdir(cls_dir)) if img.endswith('.png')]
        if len(images) < samples_per_class:
            continue
        sampled_images = random.sample(images, samples_per_class)
        # Split at train_samples (not at -test_samples): with test_samples=0,
        # [:-0] would be empty and the two splits would be swapped.
        support_set.extend(sampled_images[:train_samples])
        query_set.extend(sampled_images[train_samples:])
        support_labels.append([cls_name] * train_samples)
        query_labels.append([cls_name] * test_samples)
        accepted += 1
    return support_set, query_set, support_labels, query_labels


# s, q, l_s, l_q = load_data(path='./data/omniglot-py/images_background')

In [None]:
# Train the prototypical network episodically (one sampled episode per epoch)
def _load_image_batch(paths, transform, device):
    """Open each image path as RGB, apply ``transform`` and stack onto ``device``."""
    tensors = []
    for img_path in paths:
        # The context manager closes the file handle; convert() returns a copy.
        with Image.open(img_path) as img:
            tensors.append(transform(img.convert('RGB')))
    return torch.stack(tensors).to(device)


def train(model, optimizer, epochs=10, num_classes=1, train_samples=5, test_samples=3,
          data_path='./data/omniglot-py/images_background', device=None,
          max_empty_episodes=100):
    """Episodically train a prototypical network.

    Each epoch is one episode sampled by ``load_data``. Class prototypes are the
    mean support embedding per class. Query logits are the negative Euclidean
    distances to the prototypes, and the loss is cross-entropy over them.

    Args:
        model: Embedding network mapping (N, 3, 32, 32) images to (N, D).
        optimizer: Optimizer over ``model``'s trainable parameters.
        epochs: Number of episodes to train on.
        num_classes: Classes per episode.
        train_samples: Support images per class.
        test_samples: Query images per class.
        data_path: Directory passed to ``load_data``.
        device: Device to train on. Defaults to CUDA when available, else CPU.
        max_empty_episodes: Raise after this many consecutive empty episodes
            (e.g. a wrong ``data_path``) instead of looping forever.

    Returns:
        ``(history, prototypes)``.
        ``history[epoch]`` holds that epoch's ``"loss"`` and ``"accuracy"``.
        ``history["all_losses"]`` is the sum of all epoch losses.
        ``history["predictions"]`` / ``history["labels"]`` accumulate the query
        predictions / targets of every epoch.
        ``prototypes`` is the detached (classes, D) prototype tensor of the last
        episode, or ``None`` if no epoch ran.

    Raises:
        RuntimeError: If ``max_empty_episodes`` consecutive episodes are empty.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Same preprocessing as in get_omniglot_data.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Cross-epoch accumulators: initialised once, outside the loop.
    history = {"all_losses": 0.0, "predictions": [], "labels": []}
    correct_predictions = 0
    total_queries = 0
    prototypes = None
    epoch = 0
    empty_episodes = 0

    while epoch < epochs:
        model.train()
        support_paths, query_paths, _, _ = load_data(
            path=data_path,
            num_classes=num_classes,
            train_samples=train_samples,
            test_samples=test_samples,
        )
        if not support_paths or not query_paths:
            empty_episodes += 1
            if empty_episodes >= max_empty_episodes:
                raise RuntimeError(
                    f"load_data returned an empty episode {empty_episodes} times in a row; "
                    f"check data_path={data_path!r} and the sampling parameters."
                )
            continue
        empty_episodes = 0
        epoch += 1
        print(f"Epoch {epoch}/{epochs}:")
        history[epoch] = {}

        print(f"Loading support set images: {len(support_paths)} Images")
        support_set = _load_image_batch(support_paths, transform, device)

        print(f"Loading query set images: {len(query_paths)} Images")
        query_set = _load_image_batch(query_paths, transform, device)

        support_preds = model(support_set)
        query_preds = model(query_set)

        # load_data groups support paths class by class, train_samples each, so
        # (classes, train_samples, D) -> mean over shots gives one prototype per
        # class. Take D and the class count from the data, not hard-coded.
        embed_dim = support_preds.size(-1)
        prototypes = support_preds.view(-1, train_samples, embed_dim).mean(dim=1)
        n_classes = prototypes.size(0)

        distances = euclidean_distance(query_preds, prototypes)  # (num_queries, n_classes)

        # Query paths follow the same class order, test_samples per class.
        query_labels = torch.arange(n_classes, device=device).repeat_interleave(test_samples)

        # Negative distances act as logits; CrossEntropyLoss applies log-softmax.
        logits = -distances
        loss = criterion(logits, query_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        history[epoch]["loss"] = loss_value
        history["all_losses"] += loss_value

        # Nearest prototype is the prediction.
        predicted = distances.argmin(dim=1)
        epoch_correct = (predicted == query_labels).sum().item()
        epoch_queries = query_labels.numel()
        correct_predictions += epoch_correct
        total_queries += epoch_queries
        accuracy = epoch_correct / epoch_queries  # this episode only

        history[epoch]["accuracy"] = accuracy
        history["predictions"].extend(predicted.cpu().numpy())
        history["labels"].extend(query_labels.cpu().numpy())

        print(f"Epoch {epoch}: Loss = {loss_value}, Accuracy = {accuracy}")

    if total_queries:
        # Accuracy pooled over all queries of all episodes.
        print(f"Final accuracy: {correct_predictions/total_queries}")
    if prototypes is not None:
        prototypes = prototypes.detach()
    return history, prototypes

# Seed the RNGs used for episode sampling (random) and weight init (torch) so
# runs are repeatable; some cuDNN kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

model = ProtypicalNetwork(input_dim=3, hidden_dim=64, output_dim=64)
# Give the optimizer only trainable parameters (the VGG16 feature layers are
# frozen with requires_grad=False).
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)
history, proto = train(model=model, optimizer=optimizer, epochs=10, num_classes=75, train_samples=15, test_samples=3)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 178MB/s]


Epoch 1/10:
Loading support set images: 1125 Images
Loading query set images: 225 Images
Epoch 1: Loss = 3.96459698677063, Accuracy = 0.5288888888888889
Epoch 2/10:
Loading support set images: 1125 Images
Loading query set images: 225 Images
Epoch 2: Loss = 3.860280990600586, Accuracy = 0.5333333333333333
Epoch 3/10:
Loading support set images: 1125 Images
Loading query set images: 225 Images
Epoch 3: Loss = 3.733813762664795, Accuracy = 0.5318518518518518
Epoch 4/10:
Loading support set images: 1125 Images
Loading query set images: 225 Images
Epoch 4: Loss = 3.648183822631836, Accuracy = 0.5366666666666666
Epoch 5/10:
Loading support set images: 1125 Images
Loading query set images: 225 Images
Epoch 5: Loss = 3.59014630317688, Accuracy = 0.5306666666666666
Epoch 6/10:
Loading support set images: 1125 Images
Loading query set images: 225 Images
Epoch 6: Loss = 3.4132184982299805, Accuracy = 0.5259259259259259
Epoch 7/10:
Loading support set images: 1125 Images
Loading query set images:

In [None]:
# Shape of the last episode's prototype tensor: (classes, embedding_dim)
proto.size()

torch.Size([75, 64])