In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import copy

# 1. CONFIGURATION & DATA GENERATION

# Hyperparameters
K_SHOT = 10              # Support set size: labelled points per task used for adaptation
Q_QUERY = 10             # Query set size: points per task used to compute the meta-loss
INNER_LR = 0.1           # Step size of the manual gradient update (inner loop / fine-tuning)
META_LR = 0.001          # Learning rate of the Adam meta-optimizer (outer loop)
EPOCHS = 2000            # Number of training iterations (used by both training loops)
TASKS_PER_BATCH = 32     # Number of tasks whose query losses are averaged per meta-update

# Dataset Parameters
RADIUS = 2.0             # Circle radius: points closer than this to the center get label 1
INPUT_RANGE = [-5, 5]    # [low, high] bounds for each input coordinate
CENTER_RANGE = [-3, 3]   # [low, high] bounds for each coordinate of a task's circle center

def get_circle_task():
    """Draw a new task, i.e. the center of a circle.

    Each coordinate is sampled independently and uniformly between
    ``CENTER_RANGE[0]`` and ``CENTER_RANGE[1]``.

    Returns:
        NumPy array ``[cx, cy]`` of length 2.
    """
    low = CENTER_RANGE[0]
    high = CENTER_RANGE[1]
    # Two scalar draws, x-coordinate first and y-coordinate second.
    center = [np.random.uniform(low, high) for _ in range(2)]
    return np.array(center)

def sample_points(task_center, num_points):
    """Sample labelled 2-D points for the circle task centered at ``task_center``.

    Inputs are drawn uniformly between ``INPUT_RANGE[0]`` and ``INPUT_RANGE[1]``
    in each coordinate. A point is labelled 1.0 when its Euclidean distance to
    ``task_center`` is strictly less than ``RADIUS`` and 0.0 otherwise.

    Args:
        task_center: Circle center ``[cx, cy]``.
        num_points: Number of points to draw.

    Returns:
        Tuple ``(x, y)`` of float32 tensors with shapes ``(num_points, 2)`` and
        ``(num_points, 1)``.
    """
    low, high = INPUT_RANGE[0], INPUT_RANGE[1]
    points = np.random.uniform(low, high, (num_points, 2))

    offsets = points - task_center
    distances = np.sqrt((offsets ** 2).sum(axis=1))
    inside = distances < RADIUS
    labels = inside.astype(np.float32).reshape(-1, 1)

    inputs = torch.tensor(points, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.float32)
    return inputs, targets

# 2. MODEL DEFINITION (FUNCTIONAL MLP)

# We use a functional approach to allow manual gradient updates (MAML requirement)
class SimpleMLP(nn.Module):
    """Three-layer perceptron (2 -> 40 -> 40 -> 1) with explicitly stored weights.

    Every weight and bias is its own ``nn.Parameter`` attribute so that
    ``forward`` can run either with the module's own parameters or with an
    externally supplied list of tensors (used for manual gradient updates).
    """

    def __init__(self):
        super().__init__()
        # Parameters are registered in the order w1, b1, w2, b2, w3, b3, which
        # is the order ``parameters()`` yields and ``forward`` unpacks.
        layer_shapes = [(2, 40), (40, 40), (40, 1)]
        for idx, (fan_in, fan_out) in enumerate(layer_shapes, start=1):
            # Standard-normal weights scaled by 1 / sqrt(fan_in); zero biases.
            weight = torch.randn(fan_in, fan_out) / np.sqrt(fan_in)
            setattr(self, f"w{idx}", nn.Parameter(weight))
            setattr(self, f"b{idx}", nn.Parameter(torch.zeros(fan_out)))

    def forward(self, x, params=None):
        """Compute sigmoid outputs for ``x``.

        Args:
            x: Input tensor whose last dimension is 2.
            params: Optional sequence ``(w1, b1, w2, b2, w3, b3)`` to use
                instead of the module's own parameters.

        Returns:
            Tensor of probabilities with last dimension 1.
        """
        if params is None:
            params = self.get_params()
        w1, b1, w2, b2, w3, b3 = params

        hidden = torch.relu(torch.matmul(x, w1) + b1)
        hidden = torch.relu(torch.matmul(hidden, w2) + b2)
        logits = torch.matmul(hidden, w3) + b3
        return torch.sigmoid(logits)

    def get_params(self):
        """Return the module's parameters as the list ``forward`` expects."""
        names = ("w1", "b1", "w2", "b2", "w3", "b3")
        return [getattr(self, name) for name in names]


# 3. MAML TRAINING LOOP

maml_model = SimpleMLP()
meta_optimizer = optim.Adam(maml_model.parameters(), lr=META_LR)
loss_fn = nn.BCELoss()


def _adapted_query_loss(model, criterion):
    """Sample one task, adapt with a single inner step, and return the query loss.

    Args:
        model: Network exposing ``get_params()`` and ``forward(x, params)``.
        criterion: Loss function applied to ``(predictions, targets)``.

    Returns:
        Scalar loss on the task's query set, evaluated with the adapted
        parameters and still connected to the model's original parameters.
    """
    center = get_circle_task()
    support_x, support_y = sample_points(center, K_SHOT)
    query_x, query_y = sample_points(center, Q_QUERY)

    theta = model.get_params()
    support_loss = criterion(model(support_x, theta), support_y)
    # create_graph=True keeps the inner gradient differentiable, so the
    # meta-gradient can flow back through the update itself.
    inner_grads = torch.autograd.grad(support_loss, theta, create_graph=True)
    # theta' = theta - alpha * grad, built functionally (no in-place edits).
    theta_prime = [w - INNER_LR * g for w, g in zip(theta, inner_grads)]

    return criterion(model(query_x, theta_prime), query_y)


print("Starting MAML Training...")
for epoch in range(EPOCHS):
    # Accumulate the query losses of one meta-batch, task by task.
    meta_loss = 0.0
    for _ in range(TASKS_PER_BATCH):
        meta_loss = meta_loss + _adapted_query_loss(maml_model, loss_fn)
    meta_loss = meta_loss / TASKS_PER_BATCH

    # Meta-update on the averaged query loss.
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

    if epoch % 200 == 0:
        print(f"Epoch {epoch}: Loss = {meta_loss.item():.4f}")

# 4. BASELINE TRAINING (JOINT TRAINING)

baseline_model = SimpleMLP()
baseline_opt = optim.Adam(baseline_model.parameters(), lr=1e-3)

print("\nStarting Baseline Training...")
for epoch in range(EPOCHS):
    # Build a mixed-task batch: 32 samples, each a single point drawn from
    # its own freshly sampled circle.
    samples = [sample_points(get_circle_task(), 1) for _ in range(32)]
    xs, ys = zip(*samples)
    x_batch = torch.cat(xs)
    y_batch = torch.cat(ys)

    # Ordinary supervised step using the model's own parameters.
    baseline_opt.zero_grad()
    loss = loss_fn(baseline_model(x_batch), y_batch)
    loss.backward()
    baseline_opt.step()

    if epoch % 200 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

# 5. EVALUATION & VISUALIZATION


# Test Task Generation
# One freshly sampled circle center; both point sets below are drawn from this same task.
test_center = get_circle_task()
x_test_supp, y_test_supp = sample_points(test_center, K_SHOT) # Support set: K_SHOT points used for the gradient steps
x_test_eval, y_test_eval = sample_points(test_center, 100)    # Evaluation set: 100 points used to measure accuracy

# A. Quantitative: accuracy as a function of the number of gradient steps
steps = list(range(11)) # 0 to 10; x-axis values of the accuracy plot
maml_accs = []  # Accuracy of the MAML model per step, populated by the evaluation code below
base_accs = []  # Accuracy of the baseline model per step, populated by the evaluation code below

# Helper for accuracy
def get_acc(model, params, x, y):
    """Return the fraction of inputs that ``model`` classifies correctly.

    Args:
        model: Callable invoked as ``model(x, params)`` that returns
            probabilities.
        params: Parameters forwarded unchanged to ``model``.
        x: Input tensor.
        y: Tensor of 0/1 targets with one entry per prediction.

    Returns:
        Accuracy as a Python float; a prediction counts as class 1 when its
        probability is strictly greater than 0.5.

    Raises:
        RuntimeError: If ``y`` does not hold exactly one target per prediction.
    """
    # Pure evaluation: no autograd graph is needed even if ``params`` require grad.
    with torch.no_grad():
        preds = model(x, params)
        # Align targets with the predictions so that e.g. (N,) labels against
        # (N, 1) outputs are compared element-wise rather than broadcast to (N, N).
        targets = y.reshape(preds.shape)
        return ((preds > 0.5) == targets).float().mean().item()

def _accuracy_after_each_step(model, init_params, num_updates):
    """Fine-tune on the test support set with manual SGD, tracking eval accuracy.

    The same procedure is applied to both models so the comparison is fair:
    plain gradient steps of size ``INNER_LR`` on the support-set loss.

    Args:
        model: Network whose ``forward`` accepts an explicit parameter list.
        init_params: Starting parameters; they are copied and never modified.
        num_updates: Number of gradient steps to take.

    Returns:
        List of ``num_updates + 1`` accuracies on the evaluation set: one before
        any update, then one after each gradient step.
    """
    # Detached copies: fine-tuning must not touch the trained weights or
    # extend their autograd graph.
    params = [p.detach().clone().requires_grad_(True) for p in init_params]
    accs = [get_acc(model, params, x_test_eval, y_test_eval)]
    for _ in range(num_updates):
        loss = loss_fn(model(x_test_supp, params), y_test_supp)
        grads = torch.autograd.grad(loss, params)
        # Detach after every step so the history of earlier steps is released
        # instead of accumulating across the whole fine-tuning run.
        params = [
            (p - INNER_LR * g).detach().requires_grad_(True)
            for p, g in zip(params, grads)
        ]
        accs.append(get_acc(model, params, x_test_eval, y_test_eval))
    return accs


# One accuracy per entry of `steps`; the first entry is measured before any update.
num_updates = len(steps) - 1

# Evaluate MAML
maml_accs = _accuracy_after_each_step(
    maml_model, maml_model.get_params(), num_updates
)

# Evaluate Baseline (fine-tuning) with the same manual SGD updates as MAML.
# Updates are functional, so the trained baseline weights are left untouched.
base_accs = _accuracy_after_each_step(
    baseline_model, baseline_model.get_params(), num_updates
)

# Plot 1: Accuracy Curve
# Test accuracy of both models against the number of fine-tuning steps.
acc_fig, acc_ax = plt.subplots(figsize=(8, 5))
curves = [
    (maml_accs, 'o', 'MAML'),
    (base_accs, 'x', 'Baseline (Pre-trained)'),
]
for accuracies, marker_style, curve_label in curves:
    acc_ax.plot(steps, accuracies, marker=marker_style, label=curve_label)
acc_ax.set_title('Test Accuracy vs Gradient Steps (K=10)')
acc_ax.set_xlabel('Gradient Steps')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.grid(True)
plt.show()

# B. Qualitative: heatmaps of the predicted probability after one adaptation step

# Evaluation grid: 100 x 100 points covering the square [-5, 5] x [-5, 5]
axis_ticks = np.linspace(-5, 5, 100)
xx, yy = np.meshgrid(axis_ticks, axis_ticks)
grid_points = np.column_stack([xx.ravel(), yy.ravel()])
grid_tensor = torch.tensor(grid_points, dtype=torch.float32)

# Ground truth: grid cells lying strictly inside the test circle
dists = np.sqrt((xx - test_center[0]) ** 2 + (yy - test_center[1]) ** 2)
gt_mask = dists < RADIUS


def _grid_probs_after_one_step(model):
    """Adapt ``model`` with one support-set gradient step and score the grid.

    The step is taken functionally from the model's trained parameters, which
    are left unchanged.

    Returns:
        NumPy array of predicted probabilities with the same shape as ``xx``.
    """
    theta = model.get_params()
    support_loss = loss_fn(model(x_test_supp, theta), y_test_supp)
    step_grads = torch.autograd.grad(support_loss, theta)
    adapted = [w - INNER_LR * g for w, g in zip(theta, step_grads)]
    return model(grid_tensor, adapted).detach().numpy().reshape(xx.shape)


# MAML prediction (after 1 step from the meta-learned initialization)
maml_probs = _grid_probs_after_one_step(maml_model)

# Baseline prediction (after 1 step from the pre-trained weights)
base_probs = _grid_probs_after_one_step(baseline_model)

# Plot Heatmaps
# Three panels: ground truth only, MAML after one step, baseline after one step.
fig, ax = plt.subplots(1, 3, figsize=(18, 5))

panels = [
    ("Ground Truth Boundary", None),
    ("MAML (1 Step)", maml_probs),
    ("Baseline (1 Step)", base_probs),
]
for axis, (panel_title, panel_probs) in zip(ax, panels):
    # Filled probability map first (model panels only) so it sits underneath.
    if panel_probs is not None:
        axis.contourf(xx, yy, panel_probs, levels=20, cmap='RdBu_r', alpha=0.8)
    # Dashed true circle boundary as a reference on every panel.
    axis.contour(xx, yy, gt_mask, levels=[0.5], colors='k', linestyles='--')
    # Support points, colored by their label.
    axis.scatter(
        x_test_supp[:, 0],
        x_test_supp[:, 1],
        c=y_test_supp.flatten(),
        cmap='coolwarm',
        edgecolors='k',
    )
    axis.set_title(panel_title)

plt.show()

Starting MAML Training...
Epoch 0: Loss = 0.3441
Epoch 200: Loss = 0.2955
Epoch 400: Loss = 0.2968
Epoch 600: Loss = 0.3351
Epoch 800: Loss = 0.3127
