# **Model-Agnostic Meta-Learning (MAML)**

The MAML algorithm can be summarized as follows:

Given a set of tasks T = {T1, T2, …, TN}, where each task Ti has a training set Di, MAML aims to find a set of parameters θ that can be quickly adapted to new tasks.

1. Initialization: Initialize the model parameters θ randomly or with pre-trained weights.

2. Inner loop: For each task Ti, compute the adapted parameters θi by taking a few gradient steps on the loss function L(Di, θ) using the training data Di. With a single step and inner-loop learning rate α, this is θi = θ − α ∇θ L(Di, θ).

3. Outer loop: Update the initial parameters θ by taking a gradient descent step on the meta-objective J(T, θ) over all tasks. This objective measures the performance of the adapted parameters θi on a held-out validation (query) set for each task. Different meta-objectives can be used, such as minimizing the average loss or maximizing the accuracy across tasks.

4. Repeat steps 2 and 3 for a few iterations to refine the initial parameters.


In [None]:
!pip install higher

Collecting higher
  Downloading higher-0.2.1-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->higher)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->higher)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->higher)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->higher)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->higher)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->higher)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import higher  # for differentiable inner loop updates

# ----- Define the Model -----
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for single-channel images.

    The feature extractor applies two 3x3 convolutions (1 -> 32 -> 64
    channels), each followed by ReLU and 2x2 max pooling. The classifier
    flattens the result and maps it through a 128-unit hidden layer to
    ``num_classes`` logits.

    The first linear layer expects ``64 * 7 * 7`` inputs, which corresponds
    to 28x28 input images after the two pooling stages.

    Args:
        num_classes: Number of output logits produced per sample.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layers are created in forward order so that parameter
        # initialization consumes the RNG in a fixed sequence.
        conv_stack = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.features = nn.Sequential(*conv_stack)

        flattened_width = 64 * 7 * 7
        head = [
            nn.Flatten(),
            nn.Linear(flattened_width, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.classifier(self.features(x))

# ----- Prepare the MNIST Dataset -----
# Images are only converted to tensors; no other preprocessing is applied.
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# For demonstration, the "meta-training" pool is a random 10,000-image subset
# of MNIST, drawn without replacement.
# In a real few-shot setup, you would create tasks based on classes and sample
# a support & query set per task.
meta_train_indices = np.random.choice(len(mnist_train), 10000, replace=False)
meta_train_dataset = Subset(mnist_train, meta_train_indices)

# Shuffled mini-batch loader over the pool.
# NOTE(review): nothing in the visible code reads this loader; tasks are drawn
# from `meta_train_dataset` directly.
meta_train_loader = DataLoader(
    meta_train_dataset,
    batch_size=64,
    shuffle=True,
)

# ----- Helper Function to Sample a Task -----
def sample_task(dataset, support_size=32, query_size=32):
    """Draw one task: disjoint support and query batches from ``dataset``.

    ``support_size + query_size`` distinct indices are sampled with
    ``np.random.choice``; the first ``support_size`` form the support set and
    the remainder form the query set.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        support_size: Number of examples in the support batch.
        query_size: Number of examples in the query batch.

    Returns:
        A tuple ``(support_images, support_labels, query_images,
        query_labels)`` of batched tensors.
    """
    total = support_size + query_size
    chosen = np.random.choice(len(dataset), total, replace=False)
    support_idx, query_idx = chosen[:support_size], chosen[support_size:]

    def _as_single_batch(idx, batch_size):
        # A loader whose batch size equals the subset size yields the whole
        # subset as its first (and only) batch.
        loader = DataLoader(Subset(dataset, idx), batch_size=batch_size)
        return next(iter(loader))

    support_images, support_labels = _as_single_batch(support_idx, support_size)
    query_images, query_labels = _as_single_batch(query_idx, query_size)

    return support_images, support_labels, query_images, query_labels

# ----- MAML Inner Loop Step Using higher -----
def maml_inner_loop(model, inner_optimizer, support_images, support_labels, inner_steps=2, inner_lr=0.01):
    """Adapt ``model`` to one task's support set with differentiable updates.

    Args:
        model: The network whose current parameters are the starting point.
        inner_optimizer: Optimizer over ``model``'s parameters; ``higher``
            wraps it into a differentiable optimizer.
        support_images: Support-set inputs.
        support_labels: Support-set targets for cross-entropy.
        inner_steps: Number of adaptation steps to take.
        inner_lr: Accepted for interface compatibility but not read here; the
            step size comes from ``inner_optimizer``.

    Returns:
        The functional (patched) model holding the adapted parameters.
    """
    criterion = nn.CrossEntropyLoss()

    # copy_initial_weights=False keeps the adapted parameters connected to
    # ``model``'s own parameters, so a later backward pass reaches them.
    with higher.innerloop_ctx(model, inner_optimizer, copy_initial_weights=False) as (functional_model, diff_optimizer):
        step = 0
        while step < inner_steps:
            adaptation_loss = criterion(functional_model(support_images), support_labels)
            diff_optimizer.step(adaptation_loss)
            step += 1
        return functional_model

# ----- MAML Outer Loop (Meta-Training) -----
def meta_train(model, meta_dataset, meta_optimizer, epochs=5, tasks_per_epoch=100,
               support_size=32, query_size=32, inner_steps=1, inner_lr=0.01):
    """Run the MAML outer loop, updating ``model`` in place.

    Each epoch samples ``tasks_per_epoch`` tasks, adapts the model to each
    task's support set, evaluates the adapted model on the query set, and
    takes one meta-optimizer step on the query loss averaged over tasks.

    Args:
        model: Network to meta-train; moved to CUDA when available.
        meta_dataset: Dataset that ``sample_task`` draws tasks from.
        meta_optimizer: Optimizer over ``model``'s parameters (outer loop).
        epochs: Number of meta-updates to perform.
        tasks_per_epoch: Tasks averaged into each meta-update.
        support_size: Examples per support set.
        query_size: Examples per query set.
        inner_steps: Adaptation steps per task.
        inner_lr: Learning rate of the per-task SGD inner optimizer.

    Raises:
        ValueError: If ``tasks_per_epoch`` is less than 1.
    """
    if tasks_per_epoch < 1:
        raise ValueError("tasks_per_epoch must be at least 1")

    loss_fn = nn.CrossEntropyLoss()

    # Device choice and model placement do not change between tasks, so they
    # are done once instead of on every inner iteration.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(epochs):
        # Gradients from all tasks of this epoch are accumulated into .grad,
        # so it must start out cleared.
        meta_optimizer.zero_grad()
        meta_loss = 0.0

        for _ in range(tasks_per_epoch):
            # Sample a task: get support and query sets
            support_images, support_labels, query_images, query_labels = sample_task(
                meta_dataset, support_size, query_size)

            support_images = support_images.to(device)
            support_labels = support_labels.to(device)
            query_images = query_images.to(device)
            query_labels = query_labels.to(device)

            # A fresh inner optimizer per task: adaptation always starts from
            # the current meta-parameters.
            inner_optimizer = optim.SGD(model.parameters(), lr=inner_lr)
            fmodel = maml_inner_loop(model, inner_optimizer, support_images, support_labels, inner_steps, inner_lr)

            # Evaluate the adapted model on the query set
            query_loss = loss_fn(fmodel(query_images), query_labels)

            # Back-propagate this task's share of the averaged meta-loss right
            # away. The accumulated gradient equals that of the mean loss, but
            # each task's graph can be freed instead of being kept alive
            # until the end of the epoch.
            (query_loss / tasks_per_epoch).backward()
            meta_loss += query_loss.item()

        # Meta-optimization step: update the original model parameters
        meta_optimizer.step()

        # Average loss over tasks (reporting only)
        meta_loss /= tasks_per_epoch
        print(f"Epoch {epoch+1}/{epochs}, Meta Loss: {meta_loss:.4f}")

# ----- Main Script -----
if __name__ == "__main__":
    # Build the network and the optimizer that drives the outer (meta) loop.
    model = SimpleCNN(num_classes=10)
    meta_optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Meta-train the initialization.
    meta_train(
        model, meta_train_dataset, meta_optimizer,
        epochs=5, tasks_per_epoch=100,
        support_size=32, query_size=32,
        inner_steps=1, inner_lr=0.01,
    )

    # Demonstrate fast adaptation on a freshly sampled task.
    # NOTE: the task is drawn from the same pool used for meta-training, so
    # this is an illustration rather than a held-out evaluation.
    support_images, support_labels, query_images, query_labels = sample_task(meta_train_dataset)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    support_images = support_images.to(device)
    support_labels = support_labels.to(device)
    query_images = query_images.to(device)
    query_labels = query_labels.to(device)

    # Adapt with five inner steps using a new SGD inner optimizer.
    inner_optimizer = optim.SGD(model.parameters(), lr=0.01)
    adapted_model = maml_inner_loop(
        model, inner_optimizer, support_images, support_labels,
        inner_steps=5, inner_lr=0.01,
    )

    # Predict the query set with the adapted parameters.
    adapted_preds = adapted_model(query_images)
    predicted_labels = adapted_preds.argmax(dim=1)
    print("Predicted labels after adaptation:", predicted_labels.cpu().numpy())


Epoch 1/5, Meta Loss: 2.3021
Epoch 2/5, Meta Loss: 2.2693
Epoch 3/5, Meta Loss: 2.2148
Epoch 4/5, Meta Loss: 2.1454
Epoch 5/5, Meta Loss: 2.0572
Predicted labels after adaptation: [6 0 2 9 9 2 2 6 9 6 9 7 1 9 6 1 7 3 1 9 1 3 3 0 9 2 9 9 9 9 0 3]
