In [5]:
import copy
import os

import torch
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.utils.data import Subset

from local_sgd import LocalSGD
from minibatch_sgd import MinibatchSGD
from synthetic_dataset import create_dummy_dataset
from models import LinearModel
from lr_gridsearch import lr_gridsearch

# Reproducibility: fix torch's global RNG before the dataset and model below
# are created (presumably both draw from torch's RNG — confirm in
# create_dummy_dataset / LinearModel).
torch.manual_seed(42)

# WandB target, exported through environment variables so that any wandb run
# started later in this process (presumably inside lr_gridsearch) picks it up.
# -----------------------------------------------------------------------------
os.environ.update(
    WANDB_ENTITY="RADFAN",
    WANDB_PROJECT="LocalSGD",
)

In [6]:
# Synthetic dataset plus a linear model over it. The feature count is shared
# through one constant so the model's input width can never drift out of sync
# with the dataset's feature dimension when either is edited.
NUM_FEATURES = 25
NUM_DATASET_SAMPLES = 50000  # total pool; experiments use a prefix of it

dataset = create_dummy_dataset(
    num_samples=NUM_DATASET_SAMPLES,
    num_features=NUM_FEATURES,
)
model = LinearModel(input_size=NUM_FEATURES, bias=True)

# binary_cross_entropy_with_logits applies the sigmoid internally, so the
# model should emit raw logits — TODO confirm LinearModel has no final sigmoid.
loss_fn = binary_cross_entropy_with_logits

In [7]:
# Sweep settings. The grid-search cell zips these four lists position-wise,
# so index i across all of them describes one experiment configuration.
NUM_WORKERS = [5]                # num_workers passed to lr_gridsearch
K = [40]                         # K passed to lr_gridsearch (presumably local steps between syncs — confirm in LocalSGD)
NUM_EPOCHS = [1]                 # num_epochs passed to lr_gridsearch
NUM_SAMPLES_TO_CHOOSE = [20000]  # experiments use the first N samples of `dataset`

# Learning rates tried by the grid search, one list per algorithm.
LR_GRID_LSGD = [0.1, 0.05, 0.01]      # LocalSGD
LR_GRID_MSGD = [1.0, 0.5, 0.1, 0.05]  # MinibatchSGD

In [None]:
# Learning-rate grid search for every algorithm under each sweep setting.
#
# zip() stops at the shortest input, so a mismatched edit to the sweep lists
# would silently drop experiments; fail loudly instead.
sweep_lengths = {
    len(NUM_WORKERS), len(K), len(NUM_EPOCHS), len(NUM_SAMPLES_TO_CHOOSE)
}
assert len(sweep_lengths) == 1, (
    "NUM_WORKERS, K, NUM_EPOCHS and NUM_SAMPLES_TO_CHOOSE must have equal lengths"
)

# Each algorithm is searched over its own learning-rate grid.
algorithm_grids = (
    (LocalSGD, LR_GRID_LSGD),
    (MinibatchSGD, LR_GRID_MSGD),
)

for num_workers, k, num_epochs, num_samples in zip(
    NUM_WORKERS, K, NUM_EPOCHS, NUM_SAMPLES_TO_CHOOSE
):
    # First `num_samples` examples (a fixed prefix, not a random sample);
    # built once and shared by both algorithms.
    data_subset = Subset(dataset, indices=range(num_samples))

    for algorithm, lr_grid in algorithm_grids:
        lr_gridsearch(
            algorithm=algorithm,
            # Fresh copy per algorithm so both searches start from the same
            # initial weights even if lr_gridsearch trains `model` in place
            # (harmless if it already copies internally — not visible here).
            model=copy.deepcopy(model),
            dataset=data_subset,
            loss_fn=loss_fn,
            num_workers=num_workers,
            K=k,
            num_epochs=num_epochs,
            lr_grid=lr_grid,
        )