## Notebook for testing MoE with MNIST  
The main objective here is to figure out how to get gradients to flow through the loss function in PyTorch.

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np 
import random
import time
import sys
import os
import matplotlib.pyplot as plt

#for the disk_memoize function
import pickle
import hashlib
from functools import wraps
from tqdm import tqdm

# Put the directory above the current working directory on the import path
# so that `scripts.MNIST` (which provides load_mnist) can be imported.
repo_root = os.path.abspath(os.pardir)  # one level up from /notebook
if sys.path.count(repo_root) == 0:
    # prepend so this path is searched before the other entries
    sys.path.insert(0, repo_root)
from scripts.MNIST import load_mnist
print(sys.path)

# Basic run configuration (seed value kept here for reproducibility).
SEED = 42
EPOCHS = 50
BATCH_SIZE = 128
LEARNING_RATE = 1e-3

# Use the GPU when PyTorch reports CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

['c:\\Users\\caleb\\OneDrive - Uppsala universitet\\Fall 2025\\Projects Course\\mixture-of-experts-organization\\mixture-of-experts-project', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0\\python313.zip', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0\\DLLs', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0\\Lib', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.13_3.13.2288.0_x64__qbz5n2kfra8p0', '', 'C:\\Users\\caleb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python313\\site-packages', 'C:\\Users\\caleb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python313\\site-packages\\win32', 'C:\\Users\\caleb\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.13_qbz5n2kfra8p0\\LocalCac

In [None]:
#want to get the data in a linear format because our simple MoE will be with linear layers
def get_data(linear = True):
    """Load MNIST through ``load_mnist.load_mnist()`` and wrap it in DataLoaders.

    Args:
        linear: When True (default) the image arrays are used in whatever shape
            the loader returns them (assumed to be already flattened, one row
            per image -- TODO confirm against scripts.MNIST). When False the
            images are reshaped to ``(N, 1, 28, 28)`` for convolutional layers.

    Returns:
        ``(train_loader, test_loader)``: the train loader yields shuffled
        mini-batches of ``BATCH_SIZE``; the test loader yields the whole test
        set as a single (shuffled) batch.
    """
    #get the train and test data from the dataset
    xtrain, ytrain, xtest, ytest = load_mnist.load_mnist()

    def to_device_tensor(values):
        # torch.Tensor(...) builds a tensor of the default float type.
        # NOTE(review): this makes the labels float as well; a loss that expects
        # integer class indices would need a .long() cast downstream -- confirm.
        return torch.Tensor(values).to(DEVICE)

    xtrain, ytrain = to_device_tensor(xtrain), to_device_tensor(ytrain)
    xtest, ytest = to_device_tensor(xtest), to_device_tensor(ytest)

    if not linear:
        # Reshape for a CNN. The sample count is inferred (-1) instead of being
        # hard-coded to 60000 / 10000, so subsets of the dataset work too.
        xtrain = xtrain.reshape(-1, 1, 28, 28)
        xtest = xtest.reshape(-1, 1, 28, 28)

    #put the data in pytorch datasets so we can mini batch and enumerate through it more easily
    train_dataset = torch.utils.data.TensorDataset(xtrain, ytrain)
    test_dataset = torch.utils.data.TensorDataset(xtest, ytest)

    #DataLoaders wrap the datasets and handle batching/shuffling
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    #make the batch size for the test DataLoader the size of the dataset for evaluation.
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=ytest.shape[0], shuffle=True)
    return train_loader, test_loader

#build the train/test loaders using the default (linear) layout
train_loader, test_loader = get_data(linear=True)

Retrieving test images


100%|██████████| 980/980 [00:01<00:00, 521.80it/s]
100%|██████████| 1135/1135 [00:02<00:00, 524.07it/s]
100%|██████████| 1032/1032 [00:01<00:00, 520.58it/s]
100%|██████████| 1010/1010 [00:01<00:00, 523.60it/s]
100%|██████████| 982/982 [00:01<00:00, 527.58it/s]
100%|██████████| 892/892 [00:01<00:00, 524.74it/s]
100%|██████████| 958/958 [00:01<00:00, 528.52it/s]
100%|██████████| 1028/1028 [00:01<00:00, 526.77it/s]
100%|██████████| 974/974 [00:01<00:00, 524.24it/s]
100%|██████████| 1009/1009 [00:01<00:00, 523.25it/s]
100%|██████████| 10/10 [00:19<00:00,  1.91s/it]


Retrieving train images


100%|██████████| 5923/5923 [00:14<00:00, 405.26it/s]
100%|██████████| 6742/6742 [00:27<00:00, 246.78it/s]
100%|██████████| 5958/5958 [03:38<00:00, 27.21it/s] 
100%|██████████| 6131/6131 [00:34<00:00, 176.59it/s]
100%|██████████| 5842/5842 [00:29<00:00, 196.73it/s]
100%|██████████| 5421/5421 [00:27<00:00, 197.58it/s]
100%|██████████| 5918/5918 [00:28<00:00, 205.16it/s]
100%|██████████| 6265/6265 [00:28<00:00, 216.16it/s]
100%|██████████| 5851/5851 [00:28<00:00, 207.75it/s]
100%|██████████| 5949/5949 [00:29<00:00, 204.09it/s]
100%|██████████| 10/10 [07:48<00:00, 46.80s/it]


[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
