In [1]:
import copy
import os

import torch
from torch.nn.functional import cross_entropy
from torch.utils.data import Subset
from torchvision.datasets import MNIST
from torchvision import transforms

from local_sgd import LocalSGD
from minibatch_sgd import MinibatchSGD
from models import SimpleFFN
from lr_gridsearch import lr_gridsearch

# Reproducibility: seed torch's global RNG before the model is built in the
# next cell, so its initial weights are the same on every run.
SEED = 42
torch.manual_seed(SEED)

# WandB constants
# -----------------------------------------------------------------------------
# Destination for run logging; presumably picked up by wandb inside the
# training code (lr_gridsearch) — confirm.
os.environ.update(
    WANDB_ENTITY="RADFAN",
    WANDB_PROJECT="LocalSGD",
)

In [2]:
# Preprocessing: ToTensor scales pixel values to [0, 1]; Normalize with
# mean=0.5 and std=0.5 then maps them to [-1, 1].
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocess_steps)

# MNIST training split, stored under ./mnist and downloaded if missing.
dataset = MNIST('./mnist', train=True, transform=transform, download=True)

# Classifier over flattened 28x28 images with 10 output classes (digits 0-9).
# Architecture details live in models.SimpleFFN.
IMAGE_PIXELS = 28 * 28
model = SimpleFFN(input_size=IMAGE_PIXELS, hidden_size=100, output_size=10)
loss_fn = cross_entropy

In [3]:
# Sweep configuration: the i-th entries of the four lists below are zipped
# together into one experiment, so keep them the same length.
NUM_WORKERS = [5]
K = [20]                          # passed to lr_gridsearch as K (presumably local steps per sync — confirm)
NUM_EPOCHS = [1]
NUM_SAMPLES_TO_CHOOSE = [10_000]  # size of the training prefix used per experiment

# Learning rates tried by the grid search for each algorithm.
LR_GRID_LSGD = [0.1, 0.01, 0.001]  # LocalSGD
LR_GRID_MSGD = [1.0, 0.1, 0.01]    # MinibatchSGD

In [None]:
# Run a learning-rate grid search for LocalSGD and MinibatchSGD at every
# sweep point defined in the config cell.

# zip() stops at the shortest input, so a config added to only some of the
# sweep lists would be skipped without warning — fail loudly instead.
sweep_lists = (NUM_WORKERS, K, NUM_EPOCHS, NUM_SAMPLES_TO_CHOOSE)
assert len({len(values) for values in sweep_lists}) == 1, (
    "NUM_WORKERS, K, NUM_EPOCHS and NUM_SAMPLES_TO_CHOOSE must have equal length"
)

# (algorithm, learning-rate grid) pairs compared at each sweep point.
algorithms_to_compare = (
    (LocalSGD, LR_GRID_LSGD),
    (MinibatchSGD, LR_GRID_MSGD),
)

for num_workers, k, num_epochs, num_samples in zip(*sweep_lists):
    # Both algorithms train on the same first `num_samples` training examples.
    train_subset = Subset(dataset, indices=range(num_samples))

    for algorithm, lr_grid in algorithms_to_compare:
        # NOTE(review): lr_gridsearch is defined outside this notebook. If it
        # trains the model it receives in place, sharing one `model` object
        # would start later searches from earlier searches' trained weights.
        # A deep copy gives every search the same initial weights either way.
        lr_gridsearch(
            algorithm=algorithm,
            model=copy.deepcopy(model),
            dataset=train_subset,
            loss_fn=loss_fn,
            num_workers=num_workers,
            K=k,
            num_epochs=num_epochs,
            lr_grid=lr_grid,
        )