<a href="https://colab.research.google.com/github/bpfrd/few-shot-learning/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms
from PIL import Image
import sys

try:
    import pytorch_lightning as pl
except ImportError:
    !{sys.executable} -m pip install --quiet pytorch-lightning
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
%load_ext tensorboard
%reload_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [16]:
# Custom Dataset class
class ImageDataset(Dataset):

    def __init__(self, imgs, targets, img_transform=None):
        super().__init__()
        self.img_transform = img_transform
        self.imgs = imgs
        self.targets = targets

    def __getitem__(self, idx):
        if isinstance(idx, list):  # Handle batch of indices
            imgs, targets = [], []
            for i in idx:
                img, target = self.imgs[i], self.targets[i]
                img = Image.fromarray(img)
                if self.img_transform is not None:
                    img = self.img_transform(img)
                imgs.append(img)
                targets.append(target)
            return torch.stack(imgs), torch.tensor(targets)
        else:  # Handle single index
            img, target = self.imgs[idx], self.targets[idx]
            img = Image.fromarray(img)
            if self.img_transform is not None:
                img = self.img_transform(img)
            return img, target

    def __len__(self):
        return self.imgs.shape[0]

def dataset_from_labels(imgs, targets, class_set, **kwargs):
    class_mask = (targets[:, None] == class_set[None, :]).any(dim=-1)
    return ImageDataset(imgs=imgs[class_mask.numpy()], targets=targets[class_mask], **kwargs)

# Loading CIFAR100 dataset
CIFAR_train_set = CIFAR100(root='./data', train=True, download=True, transform=transforms.ToTensor())
CIFAR_test_set = CIFAR100(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Merging original training and test set
CIFAR_all_images = np.concatenate([CIFAR_train_set.data, CIFAR_test_set.data], axis=0)
CIFAR_all_targets = torch.LongTensor(CIFAR_train_set.targets + CIFAR_test_set.targets)

# Splitting classes
classes = torch.randperm(100)
train_classes, val_classes, test_classes = classes[:80], classes[80:90], classes[90:]

# Resnet transform
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_set = dataset_from_labels(CIFAR_all_images, CIFAR_all_targets, train_classes, img_transform=train_transform)
val_set = dataset_from_labels(CIFAR_all_images, CIFAR_all_targets, val_classes, img_transform=test_transform)
test_set = dataset_from_labels(CIFAR_all_images, CIFAR_all_targets, test_classes, img_transform=test_transform)

# Define TaskSampler for meta-learning
class TaskSampler(Sampler):
    def __init__(self, dataset, n_way, k_shot, q_query, num_tasks):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.num_tasks = num_tasks
        self.labels = np.array(dataset.targets)
        self.classes = np.unique(self.labels)

    def __len__(self):
        return self.num_tasks

    def __iter__(self):
        for _ in range(self.num_tasks):
            batch = []
            support_batch = []
            query_batch = []
            selected_classes = np.random.choice(self.classes, self.n_way, replace=False)
            # print(f'{selected_classes=}')
            for cls in selected_classes:
                cls_indices = np.where(self.labels == cls)[0]
                support_indices = np.random.choice(cls_indices, self.k_shot, replace=False)
                query_indices = np.random.choice(np.setdiff1d(cls_indices, support_indices), self.q_query, replace=False)
                # print(f'{cls=} {support_indices=} {query_indices=}')
                support_batch.extend(support_indices)
                query_batch.extend(query_indices)
            # print(f'{support_batch=} {query_batch=}')
            yield support_batch, query_batch

# Custom collate function to handle support and query batches
def meta_collate_fn(batch):
    support, query = batch
    support_images, support_labels = support
    query_images, query_labels = query

    return (support_images, support_labels), (query_images, query_labels)

# Set parameters for meta-learning
n_way = 5
k_shot = 4
q_query = 3
num_tasks = 100

# Create DataLoader for training, validation, and testing
train_sampler = TaskSampler(train_set, n_way, k_shot, q_query, num_tasks)
val_sampler = TaskSampler(val_set, n_way, k_shot, q_query, num_tasks)
test_sampler = TaskSampler(test_set, n_way, k_shot, q_query, num_tasks)

train_loader = DataLoader(train_set, batch_sampler=train_sampler, collate_fn=meta_collate_fn)
val_loader = DataLoader(val_set, batch_sampler=val_sampler, collate_fn=meta_collate_fn)
test_loader = DataLoader(test_set, batch_sampler=test_sampler, collate_fn=meta_collate_fn)

# Verify DataLoader
for batch in train_loader:
    (support_images, support_labels), (query_images, query_labels) = batch
    print(f'{support_images.shape=} {support_labels.shape=} {support_labels=}')
    print(f'{query_images.shape=} {query_labels.shape=} {query_labels=}')
    break

Files already downloaded and verified
Files already downloaded and verified
support_images.shape=torch.Size([20, 3, 224, 224]) support_labels.shape=torch.Size([20]) support_labels=tensor([ 9,  9,  9,  9, 50, 50, 50, 50, 96, 96, 96, 96, 57, 57, 57, 57, 79, 79,
        79, 79])
query_images.shape=torch.Size([15, 3, 224, 224]) query_labels.shape=torch.Size([15]) query_labels=tensor([ 9,  9,  9, 50, 50, 50, 96, 96, 96, 57, 57, 57, 79, 79, 79])


### Prototypical network

In [17]:
# Define the feature extractor using pretrained ResNet
class ResNetEmbedding(nn.Module):
    def __init__(self, embedding_dim=64):
        super(ResNetEmbedding, self).__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        return self.resnet(x)

# Define the feature extractor using pretrained ResNet
class PrototypicalNetwork(LightningModule):
    def __init__(self, embedding_dim=64, learning_rate=0.001):
        super(PrototypicalNetwork, self).__init__()
        self.save_hyperparameters()
        self.encoder = ResNetEmbedding(embedding_dim=embedding_dim)
        self.learning_rate = learning_rate
        self.validation_step_outputs = []
        self.test_step_outputs = []

    # def forward(self, x):
    #     return self.resnet(x)

    def find_prototypes(self, support_embeddings, support_labels):
        classes = support_labels.unique()
        prototypes = []
        for label in classes:
            prototypes.append(support_embeddings[support_labels == label].mean(dim=0))
        return torch.stack(prototypes, dim=0), classes

    def prototypical_loss(self, support_emb, query_emb, support_labels, query_labels):
        prototypes, classes = self.find_prototypes(support_emb, support_labels)
        targets = (classes[None, :] == query_labels[:, None]).long().argmax(dim=-1)
        dists = F.pairwise_distance(prototypes.unsqueeze(0), query_emb.unsqueeze(1))
        preds = F.log_softmax(-dists, dim=1)
        loss = F.cross_entropy(preds, targets)
        return loss, preds, targets

    def training_step(self, batch, batch_idx):
        (support_images, support_labels), (query_images, query_labels) = batch
        support_emb, query_emb = self.encoder(support_images), self.encoder(query_images)
        loss, preds, targets = self.prototypical_loss(support_emb, query_emb, support_labels, query_labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        (support_images, support_labels), (query_images, query_labels) = batch
        support_emb, query_emb = self.encoder(support_images), self.encoder(query_images)
        loss, preds, targets = self.prototypical_loss(support_emb, query_emb, support_labels, query_labels)
        self.validation_step_outputs.append({'preds': preds, 'targets': targets})

    def on_validation_epoch_end(self):
        preds = torch.concat([x['preds'] for x in self.validation_step_outputs], dim=0)
        targets = torch.concat([x['targets'] for x in self.validation_step_outputs], dim=0)
        loss = F.cross_entropy(preds, targets)
        acc = (preds.argmax(dim=-1) == targets).float().mean()
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.validation_step_outputs.clear()  # free memory

    def test_step(self, batch, batch_idx):
        (support_images, support_labels), (query_images, query_labels) = batch
        support_emb, query_emb = self.encoder(support_images), self.encoder(query_images)
        loss, preds, targets = self.prototypical_loss(support_emb, query_emb, support_labels, query_labels)
        self.test_step_outputs.append({"preds": preds, "targets": targets})

    def on_test_epoch_end(self):
        preds = torch.concat([x['preds'] for x in self.test_step_outputs])
        targets = torch.concat([x['targets'] for x in self.test_step_outputs])
        loss = F.cross_entropy(preds, targets)
        acc = (preds.argmax(dim=-1) == targets).float().mean()
        self.log('test_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.test_step_outputs.clear()  # free memory

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Set parameters
num_epochs = 10
learning_rate = 0.001
embedding_dim = 64
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize model
model = PrototypicalNetwork(embedding_dim=embedding_dim, learning_rate=learning_rate)

# Model checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='my_model/',
    filename='best_prototypical_network',
    save_top_k=1,
    mode='min'
)

# Initialize Trainer
trainer = Trainer(
    max_epochs=num_epochs,
    callbacks=[checkpoint_callback],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
)

# Train the model
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /content/my_model exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type            | Params | Mode 
----------------------------------------------------
0 | encoder | ResNetEmbedding | 11.2 M | train
----------------------------------------------------
32.8 K    Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.837    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 20. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_acc': 0.6306666731834412, 'test_loss': 0.9228059649467468}]

In [None]:
# %tensorboard --logdir lightning_logs/

### MAML

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from tqdm.auto import tqdm
# from torch.nn.utils.stateless import functional_call
from torch.func import functional_call


# Define the feature extractor using pretrained ResNet
class ResnetModel(nn.Module):
    def __init__(self, num_classes=5):
        super(ResnetModel, self).__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)

        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        # for param in self.resnet.fc.parameters():
        #     param.requires_grad = True

    def forward(self, x):
        return self.resnet(x)

def get_targets_from_labels(labels):
    classes = torch.unique(labels)
    targets = (classes[None, :] == labels[:, None]).long().argmax(dim=-1)
    return targets, classes

# Inner loop update function
def inner_update(model, loss_fn, support_images, support_targets, inner_lr):
    support_preds = model(support_images)
    support_loss = loss_fn(support_preds, support_targets)
    grads = torch.autograd.grad(support_loss, [p for p in model.parameters() if p.requires_grad], create_graph=True)

    updated_params = {}
    idx = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            updated_params[name] = param - (inner_lr * grads[idx])
            idx += 1
        else:
            updated_params[name] = param
    # updated_params = {name: param - inner_lr * grad for (name, param), grad in zip(model.named_parameters(), grads)}
    return updated_params

# Training function
def train_maml(model, train_loader, val_loader, meta_optimizer, inner_lr, meta_lr, inner_steps, num_epochs, device):
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        model.train()
        meta_loss = 0
        for (support_images, support_labels), (query_images, query_labels) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            support_images, support_labels = support_images.to(device), support_labels.to(device)
            query_images, query_labels = query_images.to(device), query_labels.to(device)
            support_targets, _ = get_targets_from_labels(support_labels)
            query_targets, _ = get_targets_from_labels(query_labels)

            task_model = deepcopy(model)
            for _ in range(inner_steps):
                updated_params = inner_update(task_model, loss_fn, support_images, support_targets, inner_lr)
                with torch.no_grad():
                  for name, param in task_model.named_parameters():
                      if name in updated_params:
                          param.copy_(updated_params[name])


            buffers = {k:v for k,v in model.named_buffers()}
            query_preds = functional_call(model, (updated_params, buffers), (query_images,))
            query_loss = loss_fn(query_preds, query_targets)
            meta_loss += query_loss

        # Update the meta-model parameters
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

        # Validate the model
        model.train()
        val_loss = 0
        num_correct = count = 0

        for (support_images, support_labels), (query_images, query_labels) in val_loader:
            support_images, support_labels = support_images.to(device), support_labels.to(device)
            query_images, query_labels = query_images.to(device), query_labels.to(device)
            support_targets, _ = get_targets_from_labels(support_labels)
            query_targets, _ = get_targets_from_labels(query_labels)

            task_model = deepcopy(model)
            for _ in range(inner_steps):
                updated_params = inner_update(task_model, loss_fn, support_images, support_targets, inner_lr)
                with torch.no_grad():
                  for name, param in task_model.named_parameters():
                      if name in updated_params:
                          param.copy_(updated_params[name])

            buffers = {k:v for k,v in model.named_buffers()}
            query_preds = functional_call(model, (updated_params,buffers), (query_images,))
            query_loss = loss_fn(query_preds, query_targets)
            val_loss += query_loss.item()*len(query_images)
            num_correct += (query_preds.argmax(dim=-1) == query_targets).sum().item()
            count += query_images.shape[0]

        val_loss /= count
        val_acc = num_correct / count
        print(f'Epoch {epoch+1}/{num_epochs}, Meta Loss: {meta_loss.item()/len(train_loader)}, Validation Loss: {val_loss} Validation Acc: {val_acc}')

# Hyperparameters
inner_lr = 0.01
meta_lr = 0.01
inner_steps = 20
num_epochs = 20

# Initialize model, optimizer, and device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
print(f'{device=}')
model = ResnetModel(num_classes=n_way)
meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)

# Train the MAML model
train_maml(model, train_loader, val_loader, meta_optimizer, inner_lr, meta_lr, inner_steps, num_epochs, device)


device=device(type='cuda')


Epoch 1/20:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/20, Meta Loss: 1.2397013092041016, Validation Loss: 1.2494804191589355 Validation Acc: 0.5453333333333333


Epoch 2/20:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

## functional version

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from tqdm.auto import tqdm
# from torch.nn.utils.stateless import functional_call
from torch.func import grad, functional_call


# Define the feature extractor using pretrained ResNet
class ResnetModel(nn.Module):
    def __init__(self, num_classes=5):
        super(ResnetModel, self).__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)

        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        # for param in self.resnet.fc.parameters():
        #     param.requires_grad = True

    def forward(self, x):
        return self.resnet(x)

def get_targets_from_labels(labels):
    classes = torch.unique(labels)
    targets = (classes[None, :] == labels[:, None]).long().argmax(dim=-1)
    return targets

# Training function
def train_maml(model, train_loader, val_loader, meta_optimizer, inner_lr, meta_lr, inner_steps, num_epochs, device):
    model.to(device)

    # Extract parameters from the model
    params = {name: param for name, param in model.named_parameters()}

    # Define the loss function
    loss_fn = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(num_epochs):

        meta_loss = 0
        for (support_images, support_labels), (query_images, query_labels) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            support_images, support_targets = support_images.to(device), get_targets_from_labels(support_labels).to(device)
            query_images, query_targets = query_images.to(device), get_targets_from_labels(query_labels).to(device)

            # Copy the initial parameters
            task_params = {name: param.clone() for name, param in model.named_parameters()}

            # Inner loop
            for _ in range(inner_steps):
                support_outputs = functional_call(model, task_params, support_images)
                support_loss = loss_fn(support_outputs, support_targets)

                # Compute gradients
                # grads = torch.autograd.grad(support_loss, task_params.values(), create_graph=True)
                grads = torch.autograd.grad(support_loss, [p for p in task_params.values() if p.requires_grad], create_graph=True)

                # Update parameters
                # task_params = {name: param - inner_lr * grad
                #               for (name, param), grad in zip(task_params.items(), grads)}
                idx = 0
                for name, param in task_params.items():
                    if param.requires_grad:
                        task_params[name] = param - (inner_lr * grads[idx])
                        idx += 1

            # Compute query loss
            query_outputs = functional_call(model, task_params, query_images)
            query_loss = loss_fn(query_outputs, query_targets)
            meta_loss += query_loss

        # Update the meta-model parameters
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

        # Validate the model
        val_loss = 0
        num_correct = count = 0

        for (support_images, support_labels), (query_images, query_labels) in tqdm(val_loader):
            support_images, support_targets = support_images.to(device), get_targets_from_labels(support_labels).to(device)
            query_images, query_targets = query_images.to(device), get_targets_from_labels(query_labels).to(device)

            # Copy the initial parameters
            task_params = {name: param.clone() for name,param in model.named_parameters()}

            # Inner loop
            for _ in range(inner_steps):
                support_outputs = functional_call(model, task_params, support_images)
                support_loss = loss_fn(support_outputs, support_targets)

                # Compute gradients
                # grads = torch.autograd.grad(support_loss, task_params.values(), create_graph=False)
                grads = torch.autograd.grad(support_loss, [p for p in task_params.values() if p.requires_grad], create_graph=False)

                # Update parameters
                # task_params = {name: param - inner_lr * grad
                #               for (name, param), grad in zip(task_params.items(), grads)}
                idx = 0
                for name, param in task_params.items():
                    if param.requires_grad:
                        task_params[name] = param - (inner_lr * grads[idx])
                        idx += 1

                # Compute query loss
            query_outputs = functional_call(model, task_params, query_images)
            query_loss = loss_fn(query_outputs, query_targets)
            val_loss += query_loss.item()*len(query_images)
            num_correct += (query_outputs.argmax(dim=-1) == query_targets).sum().item()
            count += len(query_images)

        val_loss /= count
        val_acc = num_correct / count
        print(f'Epoch {epoch+1}/{num_epochs}, Meta Loss: {meta_loss.item()/len(train_loader)}, Validation Loss: {val_loss} Validation Acc: {val_acc}')

# Hyperparameters
inner_lr = 0.01
meta_lr = 0.004
inner_steps = 10
num_epochs = 20

# Initialize model, optimizer, and device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
print(f'{device=}')
model = ResnetModel(num_classes=n_way)
meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)

# Train the MAML model
train_maml(model, train_loader, val_loader, meta_optimizer, inner_lr, meta_lr, inner_steps, num_epochs, device)


device=device(type='cuda')


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 

## higher

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from tqdm.auto import tqdm
import torch.nn.functional as F
# from torch.nn.utils.stateless import functional_call
from torch.func import grad, functional_call
try:
    import higher
except ModuleNotFoundError:
    !{sys.executable} -m pip install --quiet higher
import higher

# Define the feature extractor using pretrained ResNet
class ResnetModel(nn.Module):
    def __init__(self, num_classes=5):
        super(ResnetModel, self).__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)

        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

def get_targets_from_labels(labels):
    classes = torch.unique(labels)
    targets = (classes[None, :] == labels[:, None]).long().argmax(dim=-1)
    return targets

# Training function
def train_maml(model, train_loader, val_loader, meta_optimizer, inner_lr, meta_lr, inner_steps, num_epochs, device):
    model.to(device)
    model.train()

    # Training loop
    for epoch in range(num_epochs):
        meta_loss = 0
        count = 0
        inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
        for (support_images, support_labels), (query_images, query_labels) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            support_images, support_targets = support_images.to(device), get_targets_from_labels(support_labels).to(device)
            query_images, query_targets = query_images.to(device), get_targets_from_labels(query_labels).to(device)

            with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
                # Inner loop
                for _ in range(inner_steps):
                    support_outputs = fmodel(support_images)
                    loss = F.cross_entropy(support_outputs, support_targets)
                    diffopt.step(loss)

                # Compute query loss
                query_outputs = fmodel(query_images)
                query_loss = F.cross_entropy(query_outputs, query_targets)
                count += len(query_images)
                meta_loss += query_loss*len(query_images)

        meta_loss /= count
        # Update the meta-model parameters
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

        # Validate the model
        val_loss = 0
        num_correct = count = 0

        inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)

        for (support_images, support_labels), (query_images, query_labels) in tqdm(val_loader, desc=f"Validation"):
            support_images, support_targets = support_images.to(device), get_targets_from_labels(support_labels).to(device)
            query_images, query_targets = query_images.to(device), get_targets_from_labels(query_labels).to(device)

            with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
                # Inner loop
                for _ in range(inner_steps):
                    support_outputs = fmodel(support_images)
                    loss = F.cross_entropy(support_outputs, support_targets)
                    diffopt.step(loss)

                # Compute query loss
                query_outputs = fmodel(query_images)
                query_loss = F.cross_entropy(query_outputs, query_targets)
                val_loss += query_loss.item()*len(query_images)
                num_correct += (query_outputs.argmax(dim=-1) == query_targets).sum().item()
                count += query_images.shape[0]

        val_loss /= count
        val_acc = num_correct / count
        # print(f'Epoch {epoch+1}/{num_epochs}, Meta Loss: {meta_loss.item()/len(train_loader)}, Validation Loss: {val_loss} Validation Acc: {val_acc}')
        print(f'Epoch {epoch+1}/{num_epochs}, Meta Loss: {meta_loss.item()} Validation Loss: {val_loss} Validation Acc: {val_acc}')

# Hyperparameters
inner_lr = 0.01
meta_lr = 0.002
inner_steps = 10
num_epochs = 50

# Initialize model, optimizer, and device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
print(f'{device=}')
model = ResnetModel(num_classes=n_way)
meta_optimizer = optim.Adam(model.parameters(), lr=meta_lr)

# Train the MAML model
train_maml(model, train_loader, val_loader, meta_optimizer, inner_lr, meta_lr, inner_steps, num_epochs, device)


device=device(type='cuda')




Epoch 1/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/50, Meta Loss: 1.3326889276504517 Validation Loss: 1.4059023439884186 Validation Acc: 0.44066666666666665


Epoch 2/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2/50, Meta Loss: 1.3358490467071533 Validation Loss: 1.405471261739731 Validation Acc: 0.44133333333333336


Epoch 3/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3/50, Meta Loss: 1.329766869544983 Validation Loss: 1.3862472450733185 Validation Acc: 0.4613333333333333


Epoch 4/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 4/50, Meta Loss: 1.3239490985870361 Validation Loss: 1.3693978488445282 Validation Acc: 0.4846666666666667


Epoch 5/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 5/50, Meta Loss: 1.306422233581543 Validation Loss: 1.3993423092365265 Validation Acc: 0.4553333333333333


Epoch 6/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 6/50, Meta Loss: 1.3005605936050415 Validation Loss: 1.3818376851081848 Validation Acc: 0.47533333333333333


Epoch 7/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 7/50, Meta Loss: 1.2896486520767212 Validation Loss: 1.3684722292423248 Validation Acc: 0.48333333333333334


Epoch 8/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 8/50, Meta Loss: 1.335615634918213 Validation Loss: 1.3808688879013062 Validation Acc: 0.4533333333333333


Epoch 9/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 9/50, Meta Loss: 1.2858344316482544 Validation Loss: 1.3790464746952056 Validation Acc: 0.4513333333333333


Epoch 10/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 10/50, Meta Loss: 1.2940279245376587 Validation Loss: 1.3774996280670166 Validation Acc: 0.44866666666666666


Epoch 11/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 11/50, Meta Loss: 1.263412356376648 Validation Loss: 1.3824767339229584 Validation Acc: 0.46


Epoch 12/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 12/50, Meta Loss: 1.2925138473510742 Validation Loss: 1.381031218767166 Validation Acc: 0.4646666666666667


Epoch 13/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 13/50, Meta Loss: 1.2842986583709717 Validation Loss: 1.3817827677726746 Validation Acc: 0.46


Epoch 14/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 14/50, Meta Loss: 1.2981290817260742 Validation Loss: 1.3489337891340256 Validation Acc: 0.4746666666666667


Epoch 15/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 15/50, Meta Loss: 1.2927172183990479 Validation Loss: 1.3724109077453612 Validation Acc: 0.45666666666666667


Epoch 16/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 16/50, Meta Loss: 1.2781084775924683 Validation Loss: 1.3811816430091859 Validation Acc: 0.45066666666666666


Epoch 17/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 17/50, Meta Loss: 1.2772417068481445 Validation Loss: 1.3863217532634735 Validation Acc: 0.45266666666666666


Epoch 18/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 18/50, Meta Loss: 1.2805142402648926 Validation Loss: 1.3623942375183105 Validation Acc: 0.472


Epoch 19/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 19/50, Meta Loss: 1.2688361406326294 Validation Loss: 1.3824213099479676 Validation Acc: 0.468


Epoch 20/50:   0%|          | 0/100 [00:00<?, ?it/s]

Validation:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 20/50, Meta Loss: 1.2807763814926147 Validation Loss: 1.392653294801712 Validation Acc: 0.42866666666666664


Epoch 21/50:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
!pip install higher

Collecting higher
  Downloading higher-0.2.1-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->higher)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->higher)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->higher)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->higher)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->higher)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->higher)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64