## Notebook for testing MoE with MNIST  
The main objective here is to figure out how to get gradients to propagate through the loss function using PyTorch.

In [None]:
import torch
from torch import nn
import numpy as np 
import random
import time
import sys
import os
import matplotlib.pyplot as plt

#for the disk_memoize function
import pickle
import hashlib
from functools import wraps
from tqdm import tqdm

# Add the repository root to sys.path so the project-local `scripts` package
# (which provides load_mnist) can be imported from this notebook.
# NOTE: ".." is resolved against the current working directory, so this only
# finds the repo root when the kernel's working directory is the notebook folder.
repo_root = os.path.abspath("..")  # one level up from /notebook
if repo_root not in sys.path:
    # insert at the front so the repo's packages take precedence on import
    sys.path.insert(0, repo_root)
# must come after the sys.path change above, otherwise `scripts` may not be importable
from scripts.MNIST import load_mnist
# show the resulting search path to confirm the repo root was added
print(sys.path)

#basic setup for reproducibility: seed every random number generator we import
#and pick the compute device once so all tensors/models end up in the same place
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
#torch.manual_seed seeds the default generator on all devices (CPU and CUDA)
torch.manual_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

['c:\\Users\\caleb\\OneDrive - Uppsala universitet\\Fall 2025\\Projects Course\\mixture-of-experts-organization\\mixture-of-experts-project', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0\\python313.zip', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0\\DLLs', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0\\Lib', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0', '', 'C:\\Users\\caleb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python313\\site-packages', 'C:\\Users\\caleb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python313\\site-packages\\win32', 'C:\\Users\\caleb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\\LocalCac

In [24]:
#want to get the data in a linear format because our simple MoE will be with linear layers
def get_data(linear = True):
    """Load MNIST and wrap the train and test splits in TensorDatasets on DEVICE.

    Args:
        linear: if True (default) the image arrays are used in whatever layout
            load_mnist returns (assumed to be one flattened row per image --
            TODO confirm against scripts.MNIST.load_mnist). If False the images
            are reshaped to (N, 1, 28, 28) so they can be fed to a CNN.

    Returns:
        (train_dataset, test_dataset): two torch.utils.data.TensorDataset
        objects, each pairing the image tensor with its label tensor.
    """
    #get the train and test data from the dataset and convert everything to
    #float tensors once, instead of repeating the conversion in each branch
    xtrain, ytrain, xtest, ytest = (torch.Tensor(array) for array in load_mnist.load_mnist())
    if not linear:
        #reshape for a CNN; -1 lets PyTorch infer the number of images so the
        #function is not tied to the 60000/10000 split sizes
        xtrain = xtrain.reshape(-1, 1, 28, 28)
        xtest = xtest.reshape(-1, 1, 28, 28)
    #move everything to the selected device
    xtrain, ytrain = xtrain.to(DEVICE), ytrain.to(DEVICE)
    xtest, ytest = xtest.to(DEVICE), ytest.to(DEVICE)
    #put the data in pytorch datasets so we can mini batch and enumerate through it later more easily
    train_dataset = torch.utils.data.TensorDataset(xtrain, ytrain)
    test_dataset = torch.utils.data.TensorDataset(xtest, ytest)

    return train_dataset, test_dataset

#build the train and test datasets once, using the default flattened (linear=True) layout
train_dataset, test_dataset = get_data()

Retrieving test images


100%|██████████| 980/980 [00:03<00:00, 287.81it/s]
100%|██████████| 1135/1135 [00:03<00:00, 307.95it/s]
100%|██████████| 1032/1032 [00:03<00:00, 282.57it/s]
100%|██████████| 1010/1010 [00:03<00:00, 291.87it/s]
100%|██████████| 982/982 [00:03<00:00, 318.64it/s]
100%|██████████| 892/892 [00:03<00:00, 277.85it/s]
100%|██████████| 958/958 [00:02<00:00, 325.49it/s]
100%|██████████| 1028/1028 [00:03<00:00, 295.71it/s]
100%|██████████| 974/974 [00:03<00:00, 310.53it/s]
100%|██████████| 1009/1009 [00:03<00:00, 309.01it/s]
100%|██████████| 10/10 [00:33<00:00,  3.34s/it]


Retrieving train images


100%|██████████| 5923/5923 [00:20<00:00, 290.92it/s]
100%|██████████| 6742/6742 [00:25<00:00, 267.41it/s]
100%|██████████| 5958/5958 [00:21<00:00, 278.58it/s]
100%|██████████| 6131/6131 [00:24<00:00, 249.93it/s]
100%|██████████| 5842/5842 [00:20<00:00, 282.54it/s]
100%|██████████| 5421/5421 [00:15<00:00, 342.80it/s]
100%|██████████| 5918/5918 [00:17<00:00, 329.50it/s]
100%|██████████| 6265/6265 [00:15<00:00, 403.19it/s]
100%|██████████| 5851/5851 [00:15<00:00, 370.79it/s]
100%|██████████| 5949/5949 [00:17<00:00, 341.42it/s]
100%|██████████| 10/10 [03:14<00:00, 19.48s/it]


In [33]:
#Wrap the datasets in DataLoaders for easy batching. The train batch size is the size of the
#whole training set, so every epoch is a single full-batch gradient step.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=True)
#the test loader also serves the whole dataset as one batch for evaluation; the order of the
#samples does not matter for the mean loss/accuracy, so there is no need to shuffle it
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)

Below I want to understand in depth how backpropagation works in PyTorch so that I can implement a SoftMoE model. We first look closely at a network with a single hidden layer, and then work out how the same ideas carry over to a model with a gating mechanism.

In [48]:
#accuracy from raw (not softmaxed) outputs and one hot encoded labels
def calculate_accuracy(outputs, labels):
    """Return the fraction of rows where the predicted class equals the label class.

    Args:
        outputs: raw scores with one row per sample; the index of the largest
            value in each row is taken as the prediction. Softmax is not needed
            because it does not change which entry is the largest.
        labels: one row per sample; the index of the largest value in each row
            is taken as the true class (one hot encoding).

    Returns:
        Number of matching rows divided by labels.size(0), as a Python float.
    """
    predicted_digit = outputs.argmax(dim=1)
    true_digit = labels.argmax(dim=1)
    n_correct = (predicted_digit == true_digit).sum().item()
    return n_correct / labels.size(0)


#helper for the training loop: dump every parameter together with its current gradient
def _print_parameter_state(model):
    """Print the name, shape, values and .grad of every parameter of the model."""
    for name, param in model.named_parameters():
        print("Parameter name: ", name)
        print("Parameter shape: ", param.shape)
        print("Parameter values: ", param)
        print("Parameter grad: ", param.grad)


#training loop function
def training_loop(train_loader, test_loader, num_epochs, model, loss_function, optimizer, verbose = True):
    """Train the model and record loss/accuracy on the train and test data per epoch.

    Args:
        train_loader: iterable of (data, label) batches used for optimisation.
        test_loader: iterable of (data, label) batches used for evaluation.
        num_epochs: number of passes over train_loader.
        model: the torch module to train (updated in place).
        loss_function: callable taking (outputs, label) and returning a scalar loss tensor.
        optimizer: torch optimizer built over the model's parameters.
        verbose: if True (default) print the loss and the full parameter/gradient
            state after zero_grad, after backward and after the optimizer step
            for every batch. Set to False to keep only the per-epoch summary.

    Returns:
        (training_loss, training_accuracy, test_loss, test_accuracy): four lists
        with exactly one entry per epoch. Train entries are the average over the
        batches of that epoch; test entries are the average over all test
        batches. test_loss entries are 0-dim CPU tensors, the rest are floats.

    Labels are passed to calculate_accuracy, which takes the arg-max of each
    label row, so they are expected to be one hot encoded.
    """
    #arrays for our plots
    training_loss = []
    training_accuracy = []
    test_loss = []
    test_accuracy = []
    #Setting up the training loop
    print("Starting the Training Loop")
    for epoch in range(num_epochs):
        #keep the loss and accuracies after each mini batch
        batch_loss = []
        batch_accuracy = []
        #loop through the mini-batches of the train loader
        for data, label in train_loader:
            # Forward pass
            outputs = model(data)
            #evaluate the loss
            loss = loss_function(outputs, label)
            if verbose:
                print("Loss: ", loss)
            #append the loss to the batch loss
            batch_loss.append(loss.item())
            #accuracy from the raw outputs; detach so no gradient info is passed along
            batch_accuracy.append(calculate_accuracy(outputs.detach(), label))
            # Backward pass: clear the old gradients first
            optimizer.zero_grad()
            if verbose:
                print("After zero grad")
                _print_parameter_state(model)
            #calculating gradients
            loss.backward()
            if verbose:
                print("After backwards")
                _print_parameter_state(model)

            #updating parameters
            optimizer.step()
            if verbose:
                print("After optimization step")
                _print_parameter_state(model)

        #add to the training epoch accuracies and losses
        training_accuracy.append(np.average(batch_accuracy))
        training_loss.append(np.average(batch_loss))
        #get the test loss and accuracy
        #change mode
        model.eval()
        epoch_test_loss = []
        epoch_test_accuracy = []
        #so we don't accidentally change anything
        with torch.no_grad():
            for data, label in test_loader:
                #get our test predictions
                test_predictions = model(data)
                #test loss, moved to cpu so it can be plotted
                epoch_test_loss.append(loss_function(test_predictions, label).to("cpu"))
                epoch_test_accuracy.append(calculate_accuracy(test_predictions.detach(), label))
        #back to training mode
        model.train()
        #store ONE value per epoch so the lists line up with the training lists even
        #when the test loader yields several batches (with one batch the value is unchanged)
        if epoch_test_loss:
            test_loss.append(torch.stack(epoch_test_loss).mean())
            test_accuracy.append(sum(epoch_test_accuracy) / len(epoch_test_accuracy))
        else:
            #empty test loader: record NaN instead of failing on the summary print
            test_loss.append(torch.tensor(float("nan")))
            test_accuracy.append(float("nan"))
        #printing
        print(f"Epoch: {epoch} done. Test loss {test_loss[-1]}. Test accuracy {test_accuracy[-1]}")
    return training_loss, training_accuracy, test_loss, test_accuracy



class OneLayer(torch.nn.Module):
    """Fully connected network with a single ReLU hidden layer.

    Computes output(relu(hidden(x))) and returns the result of the output
    layer as-is; no softmax is applied inside the network.
    """

    def __init__(self, input_size = 784, hidden_size = 50, output_size = 10):
        super().__init__()
        #input -> hidden projection
        self.hidden = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        #non-linearity applied to the hidden layer
        self.relu = torch.nn.ReLU()
        #hidden -> output projection
        self.output = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Run the forward pass.

        The input has shape [batch size, image size flattened]; the returned
        tensor has one row per sample and output_size columns.
        """
        hidden_activation = self.relu(self.hidden(x))
        return self.output(hidden_activation)

# hyperparameters for the one-hidden-layer baseline network
input_size_1 = 784
num_classes_1 = 10
learning_rate_1 = 0.001
num_epochs_1 = 3


#build the network and move it to the selected device
model_one_layer = OneLayer(
    input_size=input_size_1,
    hidden_size=50,
    output_size=num_classes_1,
).to(DEVICE)
#cross entropy loss, applied to the raw (not softmaxed) scores the network returns
loss_1 = torch.nn.CrossEntropyLoss()
#plain SGD: just a learning rate, no momentum or other extras
optimizer_1 = torch.optim.SGD(model_one_layer.parameters(), lr=learning_rate_1)

#time the complete training run
start_1 = time.time()
(training_loss_1, training_accuracy_1, test_loss_1, test_accuracy_1) = training_loop(
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=num_epochs_1,
    model=model_one_layer,
    loss_function=loss_1,
    optimizer=optimizer_1,
)
end_1 = time.time()
#wall-clock duration of the training run in seconds
total_time = end_1 - start_1

Starting the Training Loop
Loss:  tensor(2.3030, grad_fn=<DivBackward1>)
After zero grad
Parameter name:  hidden.weight
Parameter shape:  torch.Size([50, 784])
Parameter values:  Parameter containing:
tensor([[ 0.0248, -0.0255, -0.0022,  ...,  0.0128,  0.0226, -0.0343],
        [ 0.0022,  0.0083, -0.0131,  ..., -0.0083,  0.0130,  0.0020],
        [-0.0183,  0.0001, -0.0179,  ...,  0.0127, -0.0327, -0.0273],
        ...,
        [-0.0250,  0.0308, -0.0327,  ...,  0.0019,  0.0334, -0.0356],
        [ 0.0151,  0.0056,  0.0210,  ...,  0.0216, -0.0341, -0.0346],
        [-0.0199,  0.0284,  0.0199,  ...,  0.0332, -0.0357, -0.0035]],
       requires_grad=True)
Parameter grad:  None
Parameter name:  hidden.bias
Parameter shape:  torch.Size([50])
Parameter values:  Parameter containing:
tensor([ 0.0043,  0.0058,  0.0029, -0.0113,  0.0036, -0.0087,  0.0347, -0.0299,
        -0.0029, -0.0249,  0.0287,  0.0263, -0.0348,  0.0079, -0.0210,  0.0187,
         0.0164,  0.0307,  0.0025, -0.0235, -0.0357

In [45]:
#quick overview of the trainable tensors: one line per parameter with its shape
for param_name, tensor in model_one_layer.named_parameters():
    print(param_name, tensor.shape)

hidden.weight torch.Size([50, 784])
hidden.bias torch.Size([50])
output.weight torch.Size([10, 50])
output.bias torch.Size([10])


## SoftMoE implementation  
Below is a first attempt at a SoftMoE implementation. The input is passed through every expert and through a gating mechanism; the expert outputs are then aggregated, weighted by the probabilities produced by the gating mechanism. The output takes the form  
$$\sum_{i=1}^{N} G_i(x)\,E_i(x),$$  
where $G_i(x)$ is the output of the gating mechanism for the $i$-th expert and $E_i(x)$ is the output of the $i$-th expert. For our first implementation we will have one layer with only a few experts, whose outputs are aggregated to produce the final output as shown in the image below (without top-k), which is taken from [here](https://apxml.com/posts/how-to-implement-moe-pytorch).  
![Architecture](SoftMoe.png "SoftMoE architecture")  
