In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import learn2learn as l2l

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
############################
# 1. Data Loading via learn2learn
############################
N_WAY = 5
K_SHOT = 1
Q_QUERY = 5

tasksets = l2l.vision.benchmarks.get_tasksets(
    "omniglot",
    train_ways=N_WAY,
    train_samples=K_SHOT+Q_QUERY,
    test_ways=N_WAY,
    test_samples=K_SHOT+Q_QUERY,
    num_tasks=20000,  
    root='data',
)


Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to data/omniglot-py/images_background.zip


100%|██████████| 9.46M/9.46M [00:00<00:00, 41.7MB/s]


Extracting data/omniglot-py/images_background.zip to data/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to data/omniglot-py/images_evaluation.zip


100%|██████████| 6.46M/6.46M [00:00<00:00, 33.2MB/s]


Extracting data/omniglot-py/images_evaluation.zip to data/omniglot-py


In [5]:
############################
# 2. Model Definition
############################
class SimpleCNN(nn.Module):
    def __init__(self, output_size=64):
        super(SimpleCNN, self).__init__()
        # Padding to maintain spatial dims before pooling
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(64*1*1, output_size)

    def forward(self, x):
        # (B,1,28,28)
        x = F.relu(self.pool(self.conv1(x))) # (B,64,14,14)
        x = F.relu(self.pool(self.conv2(x))) # (B,64,7,7)
        x = F.relu(self.pool(self.conv3(x))) # (B,64,3,3)
        x = F.relu(self.pool(self.conv4(x))) # (B,64,1,1)
        x = x.view(x.size(0), -1) # (B,64)
        x = self.fc(x) # (B,64)
        return x

feature_extractor = SimpleCNN().to(device)

def init_memory():
    return nn.Parameter(torch.zeros(N_WAY, 64, device=device))

def classify_with_memory(embeddings, memory):
    return torch.matmul(embeddings, memory.t())  # (B,N_WAY)


In [6]:
############################
# 3. Optimizer
############################
outer_optimizer = optim.Adam(feature_extractor.parameters(), lr=1e-3)


In [19]:
############################
# 4. Training Loop
############################
feature_extractor.train()
inner_steps = 5
inner_lr = 0.1
meta_batches = 100

for meta_step in range(meta_batches):
    task = tasksets.train.sample()
    X, Y = task  # X.shape: [N_WAY*(K_SHOT+Q_QUERY), 1,28,28], Y.shape: [N_WAY*(K_SHOT+Q_QUERY)]

    support_count = N_WAY * K_SHOT
    query_count = N_WAY * Q_QUERY

    support_x = X[:support_count]
    support_y = Y[:support_count]
    query_x = X[support_count:support_count+query_count]
    query_y = Y[support_count:support_count+query_count]

    memory = init_memory()
    memory.requires_grad = True

    inner_optimizer = torch.optim.SGD([memory], lr=inner_lr)
    for _ in range(inner_steps):
        sup_feat = feature_extractor(support_x)
        sup_logits = classify_with_memory(sup_feat, memory)
        sup_loss = F.cross_entropy(sup_logits, support_y)

        inner_optimizer.zero_grad()
        sup_loss.backward()
        inner_optimizer.step()

    que_feat = feature_extractor(query_x)
    que_logits = classify_with_memory(que_feat, memory)
    que_loss = F.cross_entropy(que_logits, query_y)

    outer_optimizer.zero_grad()
    que_loss.backward()
    outer_optimizer.step()

    print(f"Meta-step {meta_step}: Query Loss = {que_loss.item():.4f}")


Meta-step 0: Query Loss = 1.6482
Meta-step 1: Query Loss = 1.6379
Meta-step 2: Query Loss = 1.6319
Meta-step 3: Query Loss = 1.6287
Meta-step 4: Query Loss = 1.6264
Meta-step 5: Query Loss = 1.6228
Meta-step 6: Query Loss = 1.6217
Meta-step 7: Query Loss = 1.6203
Meta-step 8: Query Loss = 1.6171
Meta-step 9: Query Loss = 1.6160
Meta-step 10: Query Loss = 1.6150
Meta-step 11: Query Loss = 1.6142
Meta-step 12: Query Loss = 1.6133
Meta-step 13: Query Loss = 1.6123
Meta-step 14: Query Loss = 1.6119
Meta-step 15: Query Loss = 1.6112
Meta-step 16: Query Loss = 1.6112
Meta-step 17: Query Loss = 1.6107
Meta-step 18: Query Loss = 1.6104
Meta-step 19: Query Loss = 1.6102
Meta-step 20: Query Loss = 1.6099
Meta-step 21: Query Loss = 1.6100
Meta-step 22: Query Loss = 1.6101
Meta-step 23: Query Loss = 1.6099
Meta-step 24: Query Loss = 1.6099
Meta-step 25: Query Loss = 1.6096
Meta-step 26: Query Loss = 1.6094
Meta-step 27: Query Loss = 1.6111
Meta-step 28: Query Loss = 1.6095
Meta-step 29: Query Loss