<a name="top"></a>
# **Meta Learning: Few-shot Classification**

In [1]:
!nvidia-smi

zsh:1: command not found: nvidia-smi


## **Step 1: Download Data**

Download the Omniglot dataset, a classic few-shot learning benchmark. The copy used here has already been pre-processed and augmented.

In [2]:
workspace_dir = '.'

# Download dataset
!wget https://github.com/xraychen/shiny-disco/releases/download/Latest/omniglot.tar.gz \
    -O "{workspace_dir}/Omniglot.tar.gz"
!wget https://github.com/xraychen/shiny-disco/releases/download/Latest/omniglot-test.tar.gz \
    -O "{workspace_dir}/Omniglot-test.tar.gz"

# Decompress the archives with `tar`
!tar -zxf "{workspace_dir}/Omniglot.tar.gz" -C "{workspace_dir}/"
!tar -zxf "{workspace_dir}/Omniglot-test.tar.gz" -C "{workspace_dir}/"

zsh:1: command not found: wget
zsh:1: command not found: wget
tar: Error opening archive: Failed to open './Omniglot.tar.gz'
tar: Error opening archive: Failed to open './Omniglot-test.tar.gz'
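The output above shows that `wget` is not available in this environment, so the archives were never downloaded. A minimal sketch of a fallback that uses only the Python standard library follows. The helper name `download_and_extract` is hypothetical and not part of the original notebook; the URLs in the usage comment are the ones from the cell above.

```python
import tarfile
import urllib.request
from pathlib import Path


def download_and_extract(url, archive_path, dest_dir="."):
    """Hypothetical helper: fetch a .tar.gz archive and unpack it into dest_dir."""
    archive_path = Path(archive_path)
    # Skip the download if the archive is already on disk
    if not archive_path.exists():
        urllib.request.urlretrieve(url, str(archive_path))
    # Unpack the gzip-compressed tarball
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)


# Usage (not executed here because it needs network access):
# base = "https://github.com/xraychen/shiny-disco/releases/download/Latest"
# download_and_extract(f"{base}/omniglot.tar.gz", "./Omniglot.tar.gz", ".")
# download_and_extract(f"{base}/omniglot-test.tar.gz", "./Omniglot-test.tar.gz", ".")
```

This should behave like the `wget` plus `tar -zxf` pair, assuming the release URLs are still reachable.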


## **Step 2: Build the model**

### Library importation

In [8]:
pip install numpy tqdm torch torchvision

Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting torch
  Using cached torch-2.8.0-cp311-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting torchvision
  Using cached torchvision-0.23.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting filelock (from torch)
  Using cached filelock-3.19.1-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2025.9.0-py3-none-any.whl.metadata (10 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl.m

In [9]:
# Import modules we need
import glob, random
from collections import OrderedDict

import numpy as np
from tqdm.auto import tqdm

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from PIL import Image
from IPython.display import display

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE = {device}")

# Fix random seeds
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)



DEVICE = cpu


### Model Construction Preliminaries

Since our task is image classification, we build a CNN-based model.  
To implement the MAML algorithm, however, the model needs more than the standard `nn.Module` forward pass.


Take a look at MAML pseudocode...

<img src="https://i.imgur.com/9aHlvfX.png" width="50%" />

On line 10 of the pseudocode, the gradient is taken with respect to $\theta$, <font color="#0CC">**the original model parameters**</font> (outer loop), not the adapted parameters $\theta'$ produced in the <font color="#0C0">**inner loop**</font>. The inner loop therefore has to compute the output logits with `functional_forward`, which takes the parameters as an explicit argument, instead of `forward` in `nn.Module`. This keeps the adapted parameters connected to $\theta$ in the autograd graph.

The following defines these functions.

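<!-- Because line 10 differentiates with respect to the original parameters θ rather than the inner-loop (lines 5 to 8) parameters θ', the inner loop has to compute the output logits of the input image with a functional forward pass instead of calling the forward method of nn.Module directly (which differentiates with respect to θ directly). Below we define both the functional forward and the forward functions. -->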
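To make the point about line 10 concrete, here is a minimal sketch with a single linear layer instead of the CNN. All names and the toy data (`theta`, `alpha`, `x_support`, and so on) are assumptions for illustration. It shows that when the adapted weight is built from `theta` with `create_graph=True` and passed explicitly to a functional call, the outer loss can be backpropagated all the way to `theta`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Original parameters (theta) of a toy linear model
theta = torch.randn(1, 3, requires_grad=True)
alpha = 0.1  # inner-loop learning rate (illustrative value)

# Toy support and query data
x_support, y_support = torch.randn(4, 3), torch.randn(4, 1)
x_query, y_query = torch.randn(4, 3), torch.randn(4, 1)

# Inner step: functional forward with explicit weights
inner_loss = F.mse_loss(F.linear(x_support, theta), y_support)
(grad,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
theta_prime = theta - alpha * grad  # adapted weight, still a function of theta

# Outer step: loss on the query set, evaluated with the adapted weight
outer_loss = F.mse_loss(F.linear(x_query, theta_prime), y_query)
outer_loss.backward()

print(theta.grad.shape)  # gradient with respect to the original parameters
```

Calling `forward` on an `nn.Module` would always use the module's own stored parameters, so there would be no way to evaluate the model at `theta_prime` without overwriting `theta`.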

### Model block definition

In [36]:
def ConvBlock(in_ch: int, out_ch: int):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


def ConvBlockFunction(x, w, b, w_bn, b_bn):
    x = F.conv2d(x, w, b, padding=1)
    x = F.batch_norm(
        x, running_mean=None, running_var=None, weight=w_bn, bias=b_bn, training=True
    )
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x
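As a sanity check, the module version and the functional version should produce the same output when given the same weights and when the module is in training mode (so that batch normalization uses batch statistics in both cases). The sketch below copies the two definitions from the cell above so that it runs on its own; the input shape and the tolerance are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def ConvBlock(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


def ConvBlockFunction(x, w, b, w_bn, b_bn):
    x = F.conv2d(x, w, b, padding=1)
    x = F.batch_norm(
        x, running_mean=None, running_var=None, weight=w_bn, bias=b_bn, training=True
    )
    x = F.relu(x)
    return F.max_pool2d(x, kernel_size=2, stride=2)


torch.manual_seed(0)
block = ConvBlock(1, 64).train()  # training mode: batch norm uses batch statistics
x = torch.randn(8, 1, 28, 28)

# Inside nn.Sequential the parameters are named "0.weight", "0.bias", "1.weight", "1.bias"
params = dict(block.named_parameters())
out_module = block(x)
out_functional = ConvBlockFunction(
    x, params["0.weight"], params["0.bias"], params["1.weight"], params["1.bias"]
)

print(out_module.shape)
print(torch.allclose(out_module, out_functional, atol=1e-5))
```

In evaluation mode the two would differ, because `nn.BatchNorm2d` would then use its running statistics while `ConvBlockFunction` always normalizes with the current batch.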

### Model definition

In [37]:
class Classifier(nn.Module):
    def __init__(self, in_ch, k_way):
        super(Classifier, self).__init__()
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        self.logits = nn.Linear(64, k_way)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.shape[0], -1)
        x = self.logits(x)
        return x

    def functional_forward(self, x, params):
        """
        Arguments:
        x: input images [batch, 1, 28, 28]
        params: The model parameters,
                i.e. weights and biases of convolution
                     and batch normalization layers
                It's an `OrderedDict`
        """
        for block in [1, 2, 3, 4]:
            x = ConvBlockFunction(
                x,
                params[f"conv{block}.0.weight"],
                params[f"conv{block}.0.bias"],
                params.get(f"conv{block}.1.weight"),
                params.get(f"conv{block}.1.bias"),
            )
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["logits.weight"], params["logits.bias"])
        return x
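The input size of the final `nn.Linear(64, k_way)` follows from the spatial size of the feature map after four blocks. A short worked calculation, assuming 28×28 inputs as stated in the docstring (the helper name `flattened_features` is hypothetical):

```python
def flattened_features(image_size=28, channels=64, num_blocks=4):
    """Number of features after flattening the output of the last conv block."""
    size = image_size
    for _ in range(num_blocks):
        # The 3x3 convolution with padding=1 keeps the spatial size;
        # the 2x2 max pool with stride 2 halves it, rounding down.
        size = size // 2
    return channels * size * size


print(flattened_features())  # 28 -> 14 -> 7 -> 3 -> 1, so 64 * 1 * 1 = 64
```

A different input resolution would change this number, and the linear layer would have to be resized accordingly.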

### Create Label

In an N-way K-shot few-shot classification problem, each task has `n_way` classes with `k_shot` images per class.  
`create_label` builds the matching label tensor: each class index from 0 to `n_way - 1` is repeated `k_shot` times.


In [38]:
def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()


# Try to create labels for 5-way 2-shot setting
create_label(5, 2)

tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])

### Accuracy calculation

In [39]:
def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc
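For reference, the same accuracy can be computed without the round trip through NumPy. This is only an equivalent sketch on made-up logits, not a replacement the notebook requires.

```python
import torch

# Three samples, two classes (toy values)
logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 4.0]])
labels = torch.tensor([0, 1, 1])

predictions = logits.argmax(dim=-1)  # predicted class per sample
accuracy = (predictions == labels).float().mean().item()

print(predictions.tolist())  # [0, 1, 0]
print(accuracy)              # two of three predictions are correct
```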

### Define Dataset

Define the dataset.  
Each item is one character: the dataset returns `k_shot + q_query` randomly chosen images of that character,  
so the returned tensor has shape `[k_shot + q_query, 1, 28, 28]`.  


In [40]:
# Dataset for train and val
class Omniglot(Dataset):
    def __init__(self, data_dir, k_shot, q_query, task_num=None):
        # One entry per character directory
        self.file_list = [
            f for f in glob.glob(data_dir + "**/character*", recursive=True)
        ]
        # limit task number if task_num is set
        if task_num is not None:
            self.file_list = self.file_list[: min(len(self.file_list), task_num)]
        self.transform = transforms.Compose([transforms.ToTensor()])
        # Number of images returned per character (support + query)
        self.n = k_shot + q_query

    def __getitem__(self, idx):
        sample = np.arange(20)

        # Shuffle the image indices to randomly pick images of this character
        # (the code assumes 20 images per character).
        np.random.shuffle(sample)
        img_path = self.file_list[idx]
        img_list = [f for f in glob.glob(img_path + "**/*.png", recursive=True)]
        img_list.sort()
        imgs = [self.transform(Image.open(img_file)) for img_file in img_list]
        # `k_shot + q_query` examples for each character
        imgs = torch.stack(imgs)[sample[: self.n]]
        return imgs

    def __len__(self):
        return len(self.file_list)
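The sampling step in `__getitem__` can be tried without the image files. The sketch below uses a stand-in tensor for the 20 images of one character and `torch.randperm` in place of `np.random.shuffle`; both the stand-in data and the torch-only sampling are assumptions made for illustration.

```python
import torch

k_shot, q_query = 1, 1

# Stand-in for the 20 images of one character: image i is filled with the value i
all_imgs = torch.stack([torch.full((1, 28, 28), float(i)) for i in range(20)])

# Pick k_shot + q_query distinct images at random
chosen = torch.randperm(20)[: k_shot + q_query]
imgs = all_imgs[chosen]

print(imgs.shape)          # [k_shot + q_query, 1, 28, 28]
print(imgs[:, 0, 0, 0])    # which images were drawn
```

Because the indices come from a permutation, the support and query images of a character never overlap.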

## **Step 3: Learning Algorithms**



### Meta Learning

Here is the main Meta Learning algorithm.

In [41]:
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner-loop gradient update
            # Calculate gradients; create_graph=True keeps the graph so the outer loop can differentiate through this step
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

            # Update fast_weights
            # θ' = θ - α * ∇loss
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)

            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # Backpropagate the meta-batch loss to the original parameters θ
        meta_batch_loss.backward()
        # Update parameters
        optimizer.step()

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc
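The updates that `MetaSolver` performs can be summarized in equations. The notation is assumed here rather than taken from the notebook: $\alpha$ stands for `inner_lr`, $\beta$ for `meta_lr`, $S$ for `inner_train_step`, $B$ for the number of tasks in the meta-batch, and $\mathcal{L}^{\text{support}}_b$ and $\mathcal{L}^{\text{query}}_b$ for the cross-entropy loss on the support and query set of task $b$.

$$
\theta'_{b,0} = \theta, \qquad
\theta'_{b,s+1} = \theta'_{b,s} - \alpha \nabla_{\theta'_{b,s}} \mathcal{L}^{\text{support}}_b(\theta'_{b,s}), \quad s = 0, \dots, S-1
$$

$$
\mathcal{L}_{\text{meta}}(\theta) = \frac{1}{B} \sum_{b=1}^{B} \mathcal{L}^{\text{query}}_b(\theta'_{b,S}), \qquad
\theta \leftarrow \theta - \beta \nabla_{\theta} \mathcal{L}_{\text{meta}}(\theta)
$$

The last step is written as plain gradient descent, as in the MAML pseudocode; the notebook passes the same gradient to the Adam optimizer instead. Because the inner gradients are computed with `create_graph=True`, $\nabla_{\theta} \mathcal{L}_{\text{meta}}$ includes the second-order terms that come from differentiating through the inner updates.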

## **Step 4: Initialization**

After defining all the components we need, the following cells initialize the model before training.

### Hyperparameters

In [42]:
n_way = 5
k_shot = 1
q_query = 1
train_inner_train_step = 5
val_inner_train_step = 5
inner_lr = 0.4
meta_lr = 0.001
meta_batch_size = 32
max_epoch = 30
eval_batches = 20
train_data_path = "./Omniglot/images_background/"

### Dataloader initialization

In [43]:
def dataloader_init(datasets, shuffle=True, num_workers=2):
    train_set, val_set = datasets
    train_loader = DataLoader(
        train_set,
        # The "batch_size" here is not the meta batch size, but how many different characters in a task, i.e. the "n_way" in few-shot classification.
        batch_size=n_way,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set, batch_size=n_way, num_workers=num_workers, shuffle=shuffle, drop_last=True
    )

    train_iter = iter(train_loader)
    val_iter = iter(val_loader)
    return (train_loader, val_loader), (train_iter, val_iter)

### Model & optimizer initialization

In [44]:
def model_init():
    meta_model = Classifier(1, n_way).to(device)
    optimizer = torch.optim.Adam(meta_model.parameters(), lr=meta_lr)
    loss_fn = nn.CrossEntropyLoss().to(device)
    return meta_model, optimizer, loss_fn

### Utility function to get a meta-batch

In [45]:
def get_meta_batch(meta_batch_size, k_shot, q_query, data_loader, iterator):
    data = []
    for _ in range(meta_batch_size):
        try:
            # a "task_data" tensor is representing the data of a task, with size of [n_way, k_shot+q_query, 1, 28, 28]
            task_data = next(iterator)
        except StopIteration:
            iterator = iter(data_loader)
            task_data = next(iterator)
        train_data = task_data[:, :k_shot].reshape(-1, 1, 28, 28)
        val_data = task_data[:, k_shot:].reshape(-1, 1, 28, 28)
        task_data = torch.cat((train_data, val_data), 0)
        data.append(task_data)
    return torch.stack(data).to(device), iterator
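The layout that `get_meta_batch` produces for a single task can be checked with a dummy dataset. In the sketch below, `DummyCharacters` is a hypothetical stand-in for `Omniglot` in which every image of character `idx` is filled with the value `idx`, so the ordering of the samples is visible in the tensor itself.

```python
import torch
from torch.utils.data import DataLoader, Dataset

n_way, k_shot, q_query = 5, 1, 1


class DummyCharacters(Dataset):
    """Hypothetical stand-in for Omniglot: item idx is a tensor filled with idx."""

    def __init__(self, num_characters=20):
        self.num_characters = num_characters

    def __getitem__(self, idx):
        return torch.full((k_shot + q_query, 1, 28, 28), float(idx))

    def __len__(self):
        return self.num_characters


# One batch of the loader is one task: n_way characters
loader = DataLoader(DummyCharacters(), batch_size=n_way, shuffle=False, drop_last=True)
task_data = next(iter(loader))  # [n_way, k_shot + q_query, 1, 28, 28]

# Same rearrangement as in get_meta_batch: all support images first, then all query images
support = task_data[:, :k_shot].reshape(-1, 1, 28, 28)
query = task_data[:, k_shot:].reshape(-1, 1, 28, 28)
task = torch.cat((support, query), 0)

# Labels in the same order as create_label(n_way, k_shot)
labels = torch.arange(n_way).repeat_interleave(k_shot)

print(task.shape)             # [n_way * (k_shot + q_query), 1, 28, 28]
print(support[:, 0, 0, 0])    # class order of the support images
print(labels)
```

The support images come out grouped by class in the same order as the labels, which is why `MetaSolver` can split a task at index `n_way * k_shot` and call `create_label` for each half.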

<a name="mainprog" id="mainprog"></a>
## **Step 5: Main program for training & testing**

### Start training!
With `solver = 'meta'`, training uses the `MetaSolver` (MAML) defined above. This notebook does not include the `'base'` transfer-learning solver, so any other value raises `NotImplementedError`.


In [46]:
solver = 'meta'
meta_model, optimizer, loss_fn = model_init()

# init solver and dataset according to solver type
if solver == 'meta':
    Solver = MetaSolver
    dataset = Omniglot(train_data_path, k_shot, q_query)
    train_split = int(0.8 * len(dataset))
    val_split = len(dataset) - train_split
    train_set, val_set = torch.utils.data.random_split(
        dataset, [train_split, val_split]
    )
    (train_loader, val_loader), (train_iter, val_iter) = dataloader_init((train_set, val_set))
else:
    raise NotImplementedError


# main training loop
for epoch in range(max_epoch):
    print("Epoch %d" % (epoch + 1))
    train_meta_loss = []
    train_acc = []
    # Each "step" here is one meta-gradient update step
    for step in tqdm(range(max(1, len(train_loader) // meta_batch_size))):
        x, train_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query, train_loader, train_iter
        )
        meta_loss, acc = Solver(
            meta_model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=train_inner_train_step
        )
        train_meta_loss.append(meta_loss.item())
        train_acc.append(acc)
    print("  Loss    : ", "%.3f" % (np.mean(train_meta_loss)), end="\t")
    print("  Accuracy: ", "%.3f %%" % (np.mean(train_acc) * 100))

    # See the validation accuracy after each epoch.
    # Implementing early stopping here is encouraged.
    val_acc = []
    for eval_step in tqdm(range(max(1, len(val_loader) // (eval_batches)))):
        x, val_iter = get_meta_batch(
            eval_batches, k_shot, q_query, val_loader, val_iter
        )
        # Run `val_inner_train_step` inner-loop updates during validation.
        _, acc = Solver(
            meta_model,
            optimizer,
            x,
            n_way,
            k_shot,
            q_query,
            loss_fn,
            inner_train_step=val_inner_train_step,
            train=False,
        )
        val_acc.append(acc)
    print("  Validation accuracy: ", "%.3f %%" % (np.mean(val_acc) * 100))

Epoch 1
  Loss    :  1.239	  Accuracy:  51.719 %
  Validation accuracy:  53.000 %
Epoch 2
  Loss    :  1.163	  Accuracy:  55.469 %
  Validation accuracy:  55.000 %
Epoch 3
  Loss    :  1.176	  Accuracy:  55.156 %
  Validation accuracy:  52.000 %
Epoch 4
  Loss    :  1.092	  Accuracy:  59.375 %
  Validation accuracy:  63.000 %
Epoch 5
  Loss    :  1.049	  Accuracy:  59.062 %
  Validation accuracy:  56.000 %
Epoch 6
  Loss    :  1.061	  Accuracy:  62.031 %
  Validation accuracy:  64.000 %
Epoch 7
  Loss    :  1.045	  Accuracy:  62.187 %
  Validation accuracy:  64.000 %
Epoch 8
  Loss    :  1.044	  Accuracy:  61.250 %
  Validation accuracy:  72.000 %
Epoch 9
  Loss    :  1.031	  Accuracy:  63.125 %
  Validation accuracy:  56.000 %
Epoch 10
  Loss    :  0.953	  Accuracy:  67.656 %
  Validation accuracy:  67.000 %
Epoch 11
  Loss    :  0.962	  Accuracy:  63.906 %
  Validation accuracy:  69.000 %
Epoch 12
  Loss    :  0.886	  Accuracy:  69.219 %
  Validation accuracy:  68.000 %
Epoch 13
  Loss    :  0.872	  Accuracy:  69.375 %
  Validation accuracy:  68.000 %
Epoch 14
  Loss    :  0.992	  Accuracy:  62.969 %
  Validation accuracy:  61.000 %
Epoch 15
  Loss    :  0.935	  Accuracy:  66.250 %
  Validation accuracy:  69.000 %
Epoch 16
  Loss    :  0.846	  Accuracy:  72.344 %
  Validation accuracy:  67.000 %
Epoch 17
  Loss    :  0.812	  Accuracy:  69.219 %
  Validation accuracy:  64.000 %
Epoch 18
  Loss    :  0.773	  Accuracy:  71.562 %
  Validation accuracy:  74.000 %
Epoch 19
  Loss    :  0.698	  Accuracy:  76.094 %
  Validation accuracy:  81.000 %
Epoch 20
  Loss    :  0.671	  Accuracy:  78.281 %
  Validation accuracy:  81.000 %
Epoch 21
  Loss    :  0.653	  Accuracy:  77.812 %
  Validation accuracy:  79.000 %
Epoch 22
  Loss    :  0.661	  Accuracy:  76.875 %
  Validation accuracy:  76.000 %
Epoch 23
  Loss    :  0.644	  Accuracy:  78.750 %
  Validation accuracy:  71.000 %
Epoch 24
  Loss    :  0.640	  Accuracy:  79.531 %
  Validation accuracy:  77.000 %
Epoch 25
  Loss    :  0.548	  Accuracy:  82.969 %
  Validation accuracy:  86.000 %
Epoch 26
  Loss    :  0.552	  Accuracy:  83.594 %
  Validation accuracy:  85.000 %
Epoch 27
  Loss    :  0.537	  Accuracy:  83.125 %
  Validation accuracy:  82.000 %
Epoch 28
  Loss    :  0.529	  Accuracy:  83.594 %
  Validation accuracy:  80.000 %
Epoch 29
  Loss    :  0.535	  Accuracy:  83.125 %
  Validation accuracy:  82.000 %
Epoch 30
  Loss    :  0.526	  Accuracy:  83.438 %
  Validation accuracy:  86.000 %
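The training loop leaves early stopping as an exercise. One possible sketch is a small patience-based helper; the class name `EarlyStopper`, its `patience` parameter, and the accuracy values in the usage example are all hypothetical and not part of the original notebook.

```python
class EarlyStopper:
    """Hypothetical helper: signal a stop when validation accuracy
    has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_accuracy):
        """Record one epoch; return True if training should stop."""
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage on a made-up sequence of validation accuracies
stopper = EarlyStopper(patience=3)
history = [0.53, 0.55, 0.52, 0.63, 0.56]
stopped_at = None
for epoch, acc in enumerate(history, start=1):
    if stopper.step(acc):
        stopped_at = epoch
        break

print(stopper.best, stopped_at)
```

In the training loop, `stopper.step(np.mean(val_acc))` could be called after the validation pass, with a `break` when it returns `True`. With validation accuracy as noisy as in the log above, a larger patience is probably safer than a small one.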


### Testing the result

In [47]:
import os

# test dataset
class OmniglotTest(Dataset):
    def __init__(self, test_dir):
        self.test_dir = test_dir
        self.n = 5

        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, idx):
        support_files = [
            os.path.join(self.test_dir, "support", f"{idx:>04}", f"image_{i}.png")
            for i in range(self.n)
        ]
        query_files = [
            os.path.join(self.test_dir, "query", f"{idx:>04}", f"image_{i}.png")
            for i in range(self.n)
        ]

        support_imgs = torch.stack(
            [self.transform(Image.open(e)) for e in support_files]
        )
        query_imgs = torch.stack([self.transform(Image.open(e)) for e in query_files])

        return support_imgs, query_imgs

    def __len__(self):
        return len(os.listdir(os.path.join(self.test_dir, "support")))

In [48]:
test_inner_train_step = 100 # we can change this

test_batches = 20
test_dataset = OmniglotTest("Omniglot-test")
test_loader = DataLoader(test_dataset, batch_size=test_batches, shuffle=False)

output = []
for _, batch in enumerate(tqdm(test_loader)):
    support_set, query_set = batch
    x = torch.cat([support_set, query_set], dim=1)
    x = x.to(device)

    labels = Solver(
        meta_model,
        optimizer,
        x,
        n_way,
        k_shot,
        q_query,
        loss_fn,
        inner_train_step=test_inner_train_step,
        train=False,
        return_labels=True,
    )

    output.extend(labels)

# write to csv
with open("output.csv", "w") as f:
    f.write(f"id,class\n")
    for i, label in enumerate(output):
        f.write(f"{i},{label}\n")


## **Reference**
1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1703.03400)
1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)
1. Machine Learning 2022 Spring by National Taiwan University instructed by 李宏毅(Hung-Yi Lee) [Meta Learning](https://www.youtube.com/watch?v=FmJ4T4k88jY)