# In [None]:  (Jupyter cell marker left over from the notebook export)
# gating_experiment.py
# Standalone script to train a gating network on EMNIST without custom modules

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST
from torchvision import transforms
import numpy as np

# Mapping of EMNIST classes to expert indices.
# Each inner tuple lists the class ids handled by one expert; the tuple's
# position in the outer sequence is that expert's index.
_EXPERT_CLASS_GROUPS = (
    (0, 24),
    (1, 18, 21),
    (9, 44),
    (15, 40),
)
CLASS_TO_EXPERT = {
    class_id: expert_id
    for expert_id, class_ids in enumerate(_EXPERT_CLASS_GROUPS)
    for class_id in class_ids
}
# One expert per group above.
NUM_EXPERTS = len(_EXPERT_CLASS_GROUPS)

# Gating network: input_dim -> 32 -> 128 -> 32 -> num_experts (defaults)
class GatingNet(nn.Module):
    """Fully connected network producing one unnormalised logit per expert.

    The network is a stack of ``Linear`` + ``ReLU`` pairs, one per entry in
    ``hidden_dims``, followed by a final ``Linear`` projection to
    ``num_experts`` outputs. No softmax is applied here.

    Args:
        input_dim: Number of features in each (flattened) input vector.
        num_experts: Width of the output layer. Defaults to 4, the value
            that was previously hard-coded.
        hidden_dims: Output widths of the hidden layers, in order. The
            default ``(32, 128, 32)`` reproduces the original architecture,
            including the ``gate.<index>`` parameter names.
    """

    def __init__(self, input_dim, num_experts=4, hidden_dims=(32, 128, 32)):
        super().__init__()
        layers = []
        in_features = input_dim
        # Chain the widths so every layer's input matches the previous
        # layer's output; this rules out shape mismatches between layers.
        for out_features in hidden_dims:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
            in_features = out_features
        layers.append(nn.Linear(in_features, num_experts))
        self.gate = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape ``(..., num_experts)`` for input ``(..., input_dim)``."""
        return self.gate(x)


def set_seed(seed=42):
    """Seed the PyTorch and NumPy global random number generators.

    Args:
        seed: Integer seed handed to each generator (default 42).

    When CUDA is available the generators of all CUDA devices are seeded
    as well.
    """
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Prepare data
# Images are converted to tensors in [0, 1] and then normalised per channel.
# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST statistics --
# confirm they are the intended values for EMNIST.
_normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([transforms.ToTensor(), _normalize])

batch_size = 256

# Both splits share every argument except the train/test flag.
_dataset_kwargs = dict(split='balanced', download=True, transform=transform)
train_ds = EMNIST('data', train=True, **_dataset_kwargs)
test_ds = EMNIST('data', train=False, **_dataset_kwargs)

# Only the training loader shuffles; the test loader keeps dataset order.
_loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

# Seed the global RNGs before any weights are created. set_seed() was
# defined but never called, so weight initialisation (and anything else
# drawing from the global generators afterwards) differed on every run.
set_seed()

# Instantiate gating net with input_dim = 28*28 (images are flattened
# before being fed to the network).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gating_net = GatingNet(input_dim=28 * 28).to(device)
optimizer = optim.AdamW(gating_net.parameters(), lr=1e-3)
# KLDivLoss takes log-probabilities as input and probabilities as target;
# 'batchmean' divides the summed divergence by the batch size.
kl_loss = nn.KLDivLoss(reduction='batchmean')

# Uniform distribution for non-mapped classes
target_uniform = torch.full((NUM_EXPERTS,), 1.0 / NUM_EXPERTS, device=device)

# Training loop
# Every sample gets a target distribution over experts: one-hot on the mapped
# expert when its class appears in CLASS_TO_EXPERT, uniform otherwise. The
# network is trained with KL divergence against those targets, and accuracy
# is reported over the mapped ("supervised") samples only.
num_epochs = 50

# Lookup tensors built once: class mapped_classes[k] routes to expert
# mapped_experts[k]. This replaces a per-sample Python loop (with a
# device-to-host sync) on every batch.
mapped_classes = torch.tensor(list(CLASS_TO_EXPERT.keys()), device=device)
mapped_experts = torch.tensor(list(CLASS_TO_EXPERT.values()), device=device)

for epoch in range(1, num_epochs + 1):
    gating_net.train()
    total_loss = 0.0
    correct = 0
    sup_count = 0
    for imgs, labels in train_loader:
        B = imgs.size(0)
        x = imgs.view(B, -1).to(device)       # flatten each image to a vector
        y = labels.to(device)

        optimizer.zero_grad()
        logits = gating_net(x)                # (B, NUM_EXPERTS)
        logp = torch.log_softmax(logits, dim=1)

        # hits[i, k] is True when sample i carries the k-th mapped class.
        hits = y.unsqueeze(1) == mapped_classes.unsqueeze(0)
        mask = hits.any(dim=1)                # True for supervised samples
        # Dict keys are unique, so each row has at most one hit and the
        # masked sum selects that hit's expert index (0 for unmapped rows,
        # which are never read because they are filtered by `mask`).
        expert_ids = (hits.long() * mapped_experts).sum(dim=1)

        # Build target distribution: uniform everywhere, then overwrite the
        # supervised rows with a one-hot on their expert.
        target = target_uniform.unsqueeze(0).expand(B, -1).clone()
        target[mask] = nn.functional.one_hot(
            expert_ids[mask], NUM_EXPERTS
        ).to(target.dtype)

        loss = kl_loss(logp, target)
        loss.backward()
        optimizer.step()

        # The loss is a per-sample mean, so weight it by the batch size to
        # get an exact dataset-level average below.
        total_loss += loss.item() * B
        preds = torch.argmax(logits, dim=1)
        # Supervised gating accuracy
        correct += (preds[mask] == expert_ids[mask]).sum().item()
        sup_count += mask.sum().item()

    avg_loss = total_loss / len(train_ds)
    sup_acc = correct / sup_count if sup_count > 0 else 0.0
    print(f"Epoch {epoch:2d} | Avg KL Loss: {avg_loss:.4f} | Sup Gating Acc: {sup_acc:.4f}")

# Final evaluation on test set
# Accuracy is measured only over samples whose class has an entry in
# CLASS_TO_EXPERT; a prediction counts as correct when the arg-max expert
# equals the mapped expert.
print("\nEvaluating on test set...")
gating_net.eval()
correct = 0
sup_count = 0
with torch.no_grad():
    for batch_imgs, batch_labels in test_loader:
        flat = batch_imgs.view(batch_imgs.size(0), -1).to(device)
        guesses = gating_net(flat).argmax(dim=1).tolist()
        for label, guess in zip(batch_labels.tolist(), guesses):
            expected = CLASS_TO_EXPERT.get(label)
            if expected is None:
                continue  # unmapped class: excluded from the metric
            sup_count += 1
            if guess == expected:
                correct += 1

if sup_count == 0:
    print("No supervised samples in test set.")
else:
    print(f"Test Supervised Gating Acc: {correct/sup_count:.4f}")


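# NOTE: output captured from an earlier run of this cell:
#   RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x32 and 128x4)
# It means a (256 x 32) activation was passed to a layer expecting 128 input
# features. The layer widths in GatingNet as written in this file chain
# consistently (... -> 32 -> 4), so this trace appears to be stale.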