In [2]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from moe import RoutedCNN, SimpleCNN, SparseMoEConvBlockWeighted

In [None]:
# Load the "cifar100" dataset through `datasets.load_dataset`; later cells
# read its "train" and "test" splits from the returned object.
dataset = load_dataset("cifar100")

In [None]:
# Bare expression: shows the loaded dataset object as the cell output so its
# contents can be inspected before building the loaders.
dataset

In [4]:
from torchvision import transforms
import torchvision

# Preprocessing pipeline applied to each image inside `collate_fn`:
#   1. ToTensor  - turns the PIL image into a tensor.
#   2. Normalize - per-channel normalization with the mean/std listed below.
# NOTE(review): these mean/std values look like the commonly quoted ImageNet
# channel statistics rather than values computed on CIFAR-100 - confirm this
# is intended.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [12]:
def collate_fn(batch):
    """Collate dataset records into one batch of transformed images and labels.

    Every record's "img" entry is converted to an RGB PIL image and passed
    through the module-level `transform`; the results are stacked into a
    single tensor. The "coarse_label" entries are gathered into one tensor.

    Args:
        batch: Sequence of records, each indexable by "img" and
            "coarse_label".

    Returns:
        dict with keys "img" (stacked transformed images) and
        "coarse_label" (tensor built from the collected labels).
    """
    to_pil = torchvision.transforms.ToPILImage()
    pictures = [transform(to_pil(sample["img"]).convert("RGB")) for sample in batch]
    targets = [sample["coarse_label"] for sample in batch]
    return {
        "img": torch.stack(pictures),
        "coarse_label": torch.tensor(targets),
    }

In [15]:
# Build the train and test loaders in one pass. Both wrap the torch-formatted
# split, use batches of 32 and the custom `collate_fn`; only the training
# split is shuffled.
train_loader, test_loader = (
    DataLoader(
        dataset[split].with_format("torch"),
        batch_size=32,
        shuffle=(split == "train"),
        collate_fn=collate_fn,
    )
    for split in ("train", "test")
)

In [None]:
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report


# ---- Training configuration ------------------------------------------------
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = RoutedCNN().to(device)

learning_rate = 0.001

# Weight of the router's auxiliary loss in the total objective.
# Adjust this hyperparameter to control the strength of load balancing.
lambda_balance = 10

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    # The evaluation pass at the end of each epoch switches the model to eval
    # mode, so training mode must be restored before every training pass.
    model.train()
    for i, batch in enumerate(tqdm(train_loader)):
        images = batch["img"].to(device)
        labels = batch["coarse_label"].to(device)

        outputs, router_loss = model(images, return_router_loss=True)
        loss = criterion(outputs, labels)

        # Combine the classification loss with the weighted router loss.
        total_loss = loss + lambda_balance * router_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Log once per epoch, on the final training step.
        if i + 1 == len(train_loader):
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {total_loss.item():.4f}, Router loss: {router_loss.item():.4f}"
            )

    # ---- Per-epoch evaluation on the test split ----------------------------
    model.eval()
    correct = 0
    total = 0
    # Predictions and targets are accumulated over *all* test batches so the
    # F1 score describes the whole test set, not just the last batch.
    epoch_preds = []
    epoch_targets = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            images = batch["img"].to(device)
            labels = batch["coarse_label"].to(device)
            outputs = model(images)

            # Predicted class = index of the largest score along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            epoch_preds.extend(predicted.cpu().tolist())
            epoch_targets.extend(labels.cpu().tolist())

    macro_f1 = f1_score(epoch_targets, epoch_preds, average="macro")
    print(f"F1 score of the model on the test images: {macro_f1}")
    print(f"Accuracy of the model on the test images: {100 * correct / total}%")


In [22]:
# Persist only the model's state_dict (not the module object itself) to the
# relative path "routed_cnn_100.pth".
torch.save(model.state_dict(), "routed_cnn_100.pth")

In [None]:
# Final evaluation on the test split: overall accuracy plus scikit-learn's
# classification report built from predictions gathered over every batch.
model.eval()
correct, total = 0, 0
preds, targets = [], []
with torch.no_grad():
    for batch in tqdm(test_loader):
        images, labels = (batch[key].to(device) for key in ("img", "coarse_label"))
        outputs = model(images)
        # The index of the largest score along dim 1 is taken as the prediction.
        predicted = outputs.data.max(dim=1).indices
        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()
        preds.extend(predicted.cpu().numpy())
        targets.extend(labels.cpu().numpy())

print(f"Accuracy of the model on the test images: {100 * correct / total}%")
print(classification_report(targets, preds))