In [1]:
import copy
import os

import torch
from torch.nn.functional import binary_cross_entropy
from torch.utils.data import Subset

from local_sgd import LocalSGD
from minibatch_sgd import MinibatchSGD
from dataset import create_dummy_dataset
from model import LinearModel
from lr_gridsearch import lr_gridsearch

# Reproducibility: seed torch's global RNG before any data or model is built.
# NOTE(review): this only seeds torch; if the local helper modules also draw
# from `random` or numpy, those need seeding too — confirm.
torch.manual_seed(42)

# Weights & Biases destination (entity / project) supplied via environment
# -----------------------------------------------------------------------------
os.environ.update(
    WANDB_ENTITY="RADFAN",
    WANDB_PROJECT="LocalSGD",
)

In [2]:
# Synthetic dataset and a linear model whose input width matches it.
# The feature count is declared once so the dataset and the model cannot
# silently drift apart (a mismatch would only surface as a shape error
# deep inside training).
NUM_SAMPLES = 50_000
NUM_FEATURES = 25

dataset = create_dummy_dataset(
    num_samples=NUM_SAMPLES,
    num_features=NUM_FEATURES,
)
model = LinearModel(input_size=NUM_FEATURES, bias=True)

# NOTE(review): `binary_cross_entropy` expects probabilities in [0, 1], so this
# presumably relies on LinearModel applying a sigmoid — confirm in model.py.
loss_fn = binary_cross_entropy

In [3]:
# Experiment grid: column i of the four lists below is one configuration
# (workers, K, epochs, number of leading dataset samples to use).
# NOTE(review): K is presumably the number of local steps between
# synchronisations for LocalSGD — confirm against local_sgd.py.
NUM_WORKERS = [50, 50, 50, 500, 500, 500]
K = [5, 40, 200, 5, 40, 200]
NUM_EPOCHS = [1, 4, 20, 1, 4, 20]
NUM_SAMPLES_TO_CHOOSE = [25000, 50000, 50000, 25000, 50000, 50000]

# The lists are consumed together with zip(), which silently stops at the
# shortest one — fail loudly instead of quietly dropping configurations.
assert (
    len(NUM_WORKERS) == len(K) == len(NUM_EPOCHS) == len(NUM_SAMPLES_TO_CHOOSE)
), "experiment config lists must all have the same length"

# Learning rates tried by the grid search for every configuration.
LR_GRID = [1e0, 1e-1, 1e-2, 1e-3, 1e-4]

In [None]:
# Run an LR grid search for LocalSGD and then MinibatchSGD on every
# configuration. Both algorithms get the same data prefix and an independent
# copy of the same initial model. This keeps their results comparable and makes
# re-running this cell independent of whether lr_gridsearch mutates the model
# object it receives.
for num_workers, k, num_epochs, num_samples in zip(
    NUM_WORKERS, K, NUM_EPOCHS, NUM_SAMPLES_TO_CHOOSE
):
    # First `num_samples` examples of the dataset, shared by both algorithms.
    train_subset = Subset(dataset, indices=range(num_samples))

    for algorithm in (LocalSGD, MinibatchSGD):
        lr_gridsearch(
            algorithm=algorithm,
            model=copy.deepcopy(model),
            dataset=train_subset,
            loss_fn=loss_fn,
            num_workers=num_workers,
            K=k,
            num_epochs=num_epochs,
            lr_grid=LR_GRID,
        )

[34m[1mwandb[0m: Currently logged in as: [33mevgurovv[0m ([33mRADFAN[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch 1: 100%|██████████| 100/100 [00:29<00:00,  3.36it/s]


0,1
loss,▇▆▂▄▁▃▂▃▃▃▁▂▁▅▂▂▄▂▂▁▃▁▂▂▂▃▄▂▃█▅▃▂▃▄▃▃▄▅▄

0,1
loss,0.54959


Epoch 1: 100%|██████████| 100/100 [00:28<00:00,  3.48it/s]


0,1
loss,█▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.41128


Epoch 1: 100%|██████████| 100/100 [00:29<00:00,  3.35it/s]


0,1
loss,█▇▇▇▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
loss,0.4875


Epoch 1: 100%|██████████| 100/100 [00:28<00:00,  3.51it/s]


0,1
loss,███▇▇▇▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁

0,1
loss,0.58533


Epoch 1: 100%|██████████| 100/100 [00:28<00:00,  3.49it/s]


0,1
loss,███▇▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁

0,1
loss,0.6411


Epoch 1: 100%|██████████| 200/200 [01:10<00:00,  2.82it/s]


0,1
loss,█▃▃▂▄▂▂▂▂▂▁▁▃▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▂▁

0,1
loss,0.40279


Epoch 1:  28%|██▊       | 55/200 [00:19<00:51,  2.83it/s]


KeyboardInterrupt: 