In [1]:
import copy
import os

import torch
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.utils.data import Subset

from local_sgd import LocalSGD
from minibatch_sgd import MinibatchSGD
from synthetic_dataset import create_dummy_dataset
from models import LinearModel
from lr_gridsearch import lr_gridsearch

# Set seed manually for reproducibility: torch's global RNG is seeded once,
# before any dataset or model is created further down.
# NOTE(review): only torch is seeded; if create_dummy_dataset draws from Python's
# `random` or NumPy, those generators need seeding too — confirm.
SEED = 42
torch.manual_seed(SEED)

# WandB constants
# -----------------------------------------------------------------------------
# setdefault (instead of plain assignment) keeps any WANDB_ENTITY / WANDB_PROJECT
# already exported in the environment, so other users can log to their own
# workspace without editing the notebook. The trailing `;` suppresses the
# returned string from being displayed as cell output.
os.environ.setdefault("WANDB_ENTITY", "RADFAN")
os.environ.setdefault("WANDB_PROJECT", "LocalSGD");

In [2]:
# Dummy dataset plus a linear model whose input width matches it.
# One shared constant keeps `num_features` and `input_size` in sync; editing
# only one of two duplicated literals would silently mismatch model and data.
NUM_FEATURES = 25
DATASET_SIZE = 50_000

dataset = create_dummy_dataset(
    num_samples=DATASET_SIZE,
    num_features=NUM_FEATURES,
)
model = LinearModel(input_size=NUM_FEATURES, bias=True)
# Loss applied to raw model outputs (logits).
# NOTE(review): presumably the dataset yields binary {0, 1} targets, as this
# loss expects — confirm against create_dummy_dataset.
loss_fn = binary_cross_entropy_with_logits

In [3]:
# One experiment configuration per index: the four lists below are zipped
# together by the grid-search loop, so they must all have the same length.
NUM_WORKERS = [50]                # passed to lr_gridsearch as num_workers
K = [10]                          # passed to lr_gridsearch as K
NUM_EPOCHS = [1]                  # passed to lr_gridsearch as num_epochs
NUM_SAMPLES_TO_CHOOSE = [50_000]  # length of the dataset prefix used per run

# Learning rates evaluated for every configuration and every algorithm.
LR_GRID = [1.0, 0.1, 0.01, 0.001]

In [4]:
# Learning-rate grid search for each configuration, running LocalSGD and the
# MinibatchSGD baseline on the same data subset.

# zip() silently drops trailing configurations when the lists differ in length.
if not (len(NUM_WORKERS) == len(K) == len(NUM_EPOCHS) == len(NUM_SAMPLES_TO_CHOOSE)):
    raise ValueError(
        "NUM_WORKERS, K, NUM_EPOCHS and NUM_SAMPLES_TO_CHOOSE must have the same length"
    )

for num_workers, k, num_epochs, num_samples in zip(
    NUM_WORKERS, K, NUM_EPOCHS, NUM_SAMPLES_TO_CHOOSE
):
    # Subset does not validate indices up front; fail early with a clear message
    # instead of an IndexError in the middle of training.
    if num_samples > len(dataset):
        raise ValueError(
            f"num_samples={num_samples} exceeds dataset size {len(dataset)}"
        )
    # Built once and shared, so both algorithms see exactly the same samples.
    subset = Subset(dataset, indices=range(num_samples))

    for algorithm in (LocalSGD, MinibatchSGD):
        lr_gridsearch(
            algorithm=algorithm,
            # A fresh copy per algorithm guarantees both start from the same
            # initial weights, even if training mutates the model in place.
            # NOTE(review): lr_gridsearch is not visible here; if it also reuses
            # one model across LR_GRID entries, it should copy per LR as well.
            model=copy.deepcopy(model),
            dataset=subset,
            loss_fn=loss_fn,
            num_workers=num_workers,
            K=k,
            num_epochs=num_epochs,
            lr_grid=LR_GRID,
        )

[34m[1mwandb[0m: Currently logged in as: [33mevgurovv[0m ([33mRADFAN[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch 1: 100%|██████████| 100/100 [01:23<00:00,  1.19it/s]


0,1
loss,▃█▃▁▆▂▁▅▁▄▂▂▅▂▃▂▂▂▄▆▃▂▁▄▂▄▂▁▂▂▁▂▁▂▄▄▃▆▃▅

0,1
loss,0.73207


Epoch 1: 100%|██████████| 100/100 [01:24<00:00,  1.19it/s]


0,1
loss,█▅▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.4082


Epoch 1: 100%|██████████| 100/100 [01:23<00:00,  1.20it/s]


0,1
loss,█▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
loss,0.45602


Epoch 1: 100%|██████████| 100/100 [01:24<00:00,  1.19it/s]


0,1
loss,███▇▇▇▇▆▆▆▅▅▅▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
loss,0.56148


Epoch 1: 100%|██████████| 100/100 [01:04<00:00,  1.55it/s]


0,1
loss,█▄▃▃▃▃▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.40581


Epoch 1: 100%|██████████| 100/100 [01:05<00:00,  1.53it/s]


0,1
loss,█▇▇▇▆▆▆▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.45585


Epoch 1: 100%|██████████| 100/100 [01:04<00:00,  1.56it/s]


0,1
loss,███▇▆▆▆▆▅▅▅▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁

0,1
loss,0.56141


Epoch 1: 100%|██████████| 100/100 [01:06<00:00,  1.51it/s]


0,1
loss,████▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
loss,0.63001
