# Using wandb to track experiments.

Demo task: multi-class image classification using CIFAR10 dataset.

In [1]:
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader
from torchvision import datasets, models
from torchvision import transforms as T
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
!pip install wandb

!wandb login
# Put your API key here

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.15.3-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.24.0-py2.py3-none-any.whl (206 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.5/206.5 kB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools (from wandb)
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting

# The next cell includes-
- Collecting the CIFAR10 dataset and defining data loaders.
- Methods to load model, criterion, optimizer and schedulers.
- Definition of AverageMeter

In [3]:
# Downloading CIFAR10 dataset
inp_transforms = T.Compose([T.ToTensor(),
                            T.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])])
tgt_transforms = T.Lambda(lambda y: torch.zeros(10, dtype=torch.long).scatter_(0, torch.tensor(y), value=1))
cifar10 = datasets.CIFAR10(root = "/.",
                           transform = inp_transforms,
                           target_transform = tgt_transforms,
                           download = True)

# Defining dataset split (80-20)
train_dataset, val_dataset = torch.utils.data.random_split(cifar10,
                                                           [int(len(cifar10)*0.80), int(len(cifar10)*0.20)])

# Defining the dataloaders
train_dataloader = DataLoader(train_dataset,
                              batch_size=200,
                              shuffle=True)
val_dataloader = DataLoader(val_dataset,
                            batch_size=200,
                            shuffle=False)


# Method to get model based on config param model_type
def get_model(model_type):
    model = None
    if model_type == "pretrained": # Loading pretrained ResNet18 and with updated to final fc layer. 
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(512, 10)
        model = model.to(device)
    elif model_type == "scratch": # Loading a blank ResNet18 which generated 10 outputs.
        model = models.resnet18(num_classes=10)
        model = model.to(device)
    else:
        raise NotImplemented
    return model


# Method to get criterion, optimizer and scheduler based on config params.
def get_criterion_optimizer_scheduler(config, model):
    optim_dct = {
        "adam": optim.Adam,
        "SGD": optim.SGD,
        "RMSprop": optim.RMSprop
    }
    optimizer = optim_dct[config["optimizer"]](model.parameters(), lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           factor=0.1,
                                                           patience=config["scheduler_patience"],
                                                           threshold=config["scheduler_thresh"])
    criterion = nn.CrossEntropyLoss()
    return criterion, optimizer, scheduler



# Remainder of this cell includes definition of AverageMeter (can be ignored)
"""
Code taken from Pytorch ImageNet examples
https://github.com/pytorch/examples/blob/main/imagenet/main.py#L375
"""
class Summary():
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.val_history = list()
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.val_history = list()

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.val_history.append(val)

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)        
        return fmtstr.format(**self.__dict__)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 30752035.15it/s]


Extracting /./cifar-10-python.tar.gz to /.


# Following cell includes-
- Defining the train and eval loops.
- Method to trigger training loops based on config parameters.

In [4]:
# The train function without wandb logging

def train(model, criterion, optimizer, scheduler, epochs, train_dataloader, val_dataloader, device):
    for epoch in range(epochs):
        model.train()
        loss_meter = AverageMeter("train_loss", ":.5f")
        epoch_outs, epoch_tgt = list(), list()
        for data, tgt_vec in tqdm(train_dataloader):
            data, tgt_vec = data.to(device), tgt_vec.to(device)
            targets = torch.argmax(tgt_vec, axis=1)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, targets)
            loss_meter.update(loss.item(), data.shape[0])
            loss.backward()
            optimizer.step()
            epoch_outs.append(out)
            epoch_tgt.append(tgt_vec)
        predictions = torch.vstack([torch.softmax(out, axis=1) for out in epoch_outs]).detach().cpu().numpy()
        targets = torch.cat([tgt for tgt in epoch_tgt], dim=0).detach().cpu().numpy()
        ap_score = average_precision_score(targets, predictions)
        eval_loss_meter, eval_ap_score = evaluate(model, criterion, val_dataloader, device)
        data_to_log = {
            "epoch": epoch+1,
            "train_loss": loss_meter.avg,
            "eval_loss": eval_loss_meter.avg,
            "train_ap_score": ap_score,
            "eval_ap_score": eval_ap_score,
            "lr": optimizer.state_dict()["param_groups"][0]["lr"],
        }
        scheduler.step(eval_loss_meter.avg)
        print(data_to_log)
        wandb.log(data_to_log)


@torch.no_grad()
def evaluate(model, criterion, val_dataloader, device):
    model.eval()
    loss_meter = AverageMeter("eval_loss", ":.5f")
    epoch_outs, epoch_tgt = list(), list()
    for data, tgt_vec in val_dataloader:
        data, tgt_vec = data.to(device), tgt_vec.to(device)
        targets = torch.argmax(tgt_vec, axis=1)
        out = model(data)
        loss = criterion(out, targets)
        loss_meter.update(loss.item(), data.shape[0])
        epoch_outs.append(out)
        epoch_tgt.append(tgt_vec)
    predictions = torch.vstack([torch.softmax(out, axis=1) for out in epoch_outs]).detach().cpu().numpy()
    targets = torch.cat([tgt for tgt in epoch_tgt], dim=0).detach().cpu().numpy()
    ap_score = average_precision_score(targets, predictions)
    return loss_meter, ap_score


def trigger_training(config):
    model = get_model(config["model_type"])
    criterion, optimizer, scheduler = get_criterion_optimizer_scheduler(config, model)
    epochs = config["num_epochs"]

    train(model, criterion, optimizer, scheduler, epochs, train_dataloader, val_dataloader, device)


# Complete the config file, edit the cells in this notebook to log data to wandb and trigger training loops!

In [5]:
# Fill the Config file below and log the experiment at wandb
config = {
    "lr": 0.0, 
    "model_type": "scratch", # pretrained/scratch
    "optimizer": "adam", # adam/SGD/RMSprop
    "criterion": "ce",
    "scheduler_patience": 3,
    "scheduler_thresh": 0.001,
    "num_epochs": 40, # CHANGE
    "gpu_id": 0,
    "wandb_run_name": "bhav" ### FILL YOUR NAME HERE
}


In [6]:
import wandb

In [7]:
wandb.init(entity = "dhruv_sri",   # wandb username. (NOT REQUIRED ARG. ANYMORE, it fetches from initial login)
           project = "wandb_demo", # wandb project name. New project will be created if given project is missing.
           config = config         # Config dict
          )
wandb.run.name = config["wandb_run_name"]

[34m[1mwandb[0m: Currently logged in as: [33mbhavberi[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668697049999535, max=1.0…

In [8]:
trigger_training(config)


100%|██████████| 200/200 [00:34<00:00,  5.85it/s]


{'epoch': 1, 'train_loss': 2.5776724207401274, 'eval_loss': 2.580969338417053, 'train_ap_score': 0.09983939892637754, 'eval_ap_score': 0.0995316337350098, 'lr': 0.0}


100%|██████████| 200/200 [00:31<00:00,  6.37it/s]


{'epoch': 2, 'train_loss': 2.578222749233246, 'eval_loss': 2.583479585647583, 'train_ap_score': 0.09989179590697285, 'eval_ap_score': 0.09960804065908385, 'lr': 0.0}


100%|██████████| 200/200 [00:34<00:00,  5.85it/s]


{'epoch': 3, 'train_loss': 2.5787194848060606, 'eval_loss': 2.587963171005249, 'train_ap_score': 0.09973366782172097, 'eval_ap_score': 0.09943244731674321, 'lr': 0.0}


100%|██████████| 200/200 [00:28<00:00,  7.12it/s]


{'epoch': 4, 'train_loss': 2.578044408559799, 'eval_loss': 2.5824219608306884, 'train_ap_score': 0.09976094621563625, 'eval_ap_score': 0.09936654332955074, 'lr': 0.0}


100%|██████████| 200/200 [00:32<00:00,  6.17it/s]


{'epoch': 5, 'train_loss': 2.578254451751709, 'eval_loss': 2.582160472869873, 'train_ap_score': 0.09968076999592583, 'eval_ap_score': 0.09937048681722294, 'lr': 0.0}


100%|██████████| 200/200 [00:31<00:00,  6.45it/s]


{'epoch': 6, 'train_loss': 2.5775618994235994, 'eval_loss': 2.583077459335327, 'train_ap_score': 0.09981655835170591, 'eval_ap_score': 0.09953907440033177, 'lr': 0.0}


100%|██████████| 200/200 [00:31<00:00,  6.36it/s]


{'epoch': 7, 'train_loss': 2.578031997680664, 'eval_loss': 2.584074549674988, 'train_ap_score': 0.0997643496266566, 'eval_ap_score': 0.09917765404401999, 'lr': 0.0}


100%|██████████| 200/200 [00:31<00:00,  6.29it/s]


{'epoch': 8, 'train_loss': 2.5784185945987703, 'eval_loss': 2.582074136734009, 'train_ap_score': 0.09980409546571213, 'eval_ap_score': 0.09942697325621201, 'lr': 0.0}


100%|██████████| 200/200 [00:30<00:00,  6.66it/s]


{'epoch': 9, 'train_loss': 2.577744176387787, 'eval_loss': 2.5877378129959108, 'train_ap_score': 0.09990603343561665, 'eval_ap_score': 0.09949768610457384, 'lr': 0.0}


100%|██████████| 200/200 [00:32<00:00,  6.21it/s]


{'epoch': 10, 'train_loss': 2.5781187176704408, 'eval_loss': 2.586380596160889, 'train_ap_score': 0.09975790743784127, 'eval_ap_score': 0.0993462950119057, 'lr': 0.0}


100%|██████████| 200/200 [00:30<00:00,  6.65it/s]


{'epoch': 11, 'train_loss': 2.5789030587673185, 'eval_loss': 2.582804951667786, 'train_ap_score': 0.09968117656151748, 'eval_ap_score': 0.09938760041131, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.81it/s]


{'epoch': 12, 'train_loss': 2.5773394775390623, 'eval_loss': 2.5851004409790037, 'train_ap_score': 0.10003359427168021, 'eval_ap_score': 0.09953250726562575, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.05it/s]


{'epoch': 13, 'train_loss': 2.579117681980133, 'eval_loss': 2.58171236038208, 'train_ap_score': 0.09965929905071697, 'eval_ap_score': 0.09923617544598974, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.94it/s]


{'epoch': 14, 'train_loss': 2.5772964322566985, 'eval_loss': 2.583012375831604, 'train_ap_score': 0.09979227475978662, 'eval_ap_score': 0.09941445150218611, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.04it/s]


{'epoch': 15, 'train_loss': 2.578106826543808, 'eval_loss': 2.5862761735916138, 'train_ap_score': 0.09968462454666613, 'eval_ap_score': 0.09936452666139381, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.80it/s]


{'epoch': 16, 'train_loss': 2.5782781052589416, 'eval_loss': 2.586283416748047, 'train_ap_score': 0.09976767966293944, 'eval_ap_score': 0.09962116908450873, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.22it/s]


{'epoch': 17, 'train_loss': 2.579158067703247, 'eval_loss': 2.5843823766708374, 'train_ap_score': 0.09978601437520443, 'eval_ap_score': 0.09963335669426968, 'lr': 0.0}


100%|██████████| 200/200 [00:21<00:00,  9.44it/s]


{'epoch': 18, 'train_loss': 2.578124680519104, 'eval_loss': 2.57983588218689, 'train_ap_score': 0.09992805904825239, 'eval_ap_score': 0.09952184112868408, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.95it/s]


{'epoch': 19, 'train_loss': 2.5794551515579225, 'eval_loss': 2.582182822227478, 'train_ap_score': 0.09947215748431809, 'eval_ap_score': 0.09949204686587002, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.19it/s]


{'epoch': 20, 'train_loss': 2.5770978331565857, 'eval_loss': 2.5848310232162475, 'train_ap_score': 0.09980902781452677, 'eval_ap_score': 0.09951943430444563, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.16it/s]


{'epoch': 21, 'train_loss': 2.5789215397834777, 'eval_loss': 2.5853428411483765, 'train_ap_score': 0.09973415133565941, 'eval_ap_score': 0.09921596741497271, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.93it/s]


{'epoch': 22, 'train_loss': 2.578049658536911, 'eval_loss': 2.58352032661438, 'train_ap_score': 0.09980458367221405, 'eval_ap_score': 0.09926259808704847, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.27it/s]


{'epoch': 23, 'train_loss': 2.578490762710571, 'eval_loss': 2.5818199491500855, 'train_ap_score': 0.09964781151029536, 'eval_ap_score': 0.09963847284006774, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.26it/s]


{'epoch': 24, 'train_loss': 2.5796469795703887, 'eval_loss': 2.586447024345398, 'train_ap_score': 0.09957247962420017, 'eval_ap_score': 0.09914065335209174, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.03it/s]


{'epoch': 25, 'train_loss': 2.5783488368988037, 'eval_loss': 2.583089442253113, 'train_ap_score': 0.0996673563826614, 'eval_ap_score': 0.09947601472920845, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.80it/s]


{'epoch': 26, 'train_loss': 2.5793893325328825, 'eval_loss': 2.583512001037598, 'train_ap_score': 0.09959281962594052, 'eval_ap_score': 0.09950244424362213, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.11it/s]


{'epoch': 27, 'train_loss': 2.5779205429553986, 'eval_loss': 2.585082378387451, 'train_ap_score': 0.09989376567913924, 'eval_ap_score': 0.09939614040611264, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.03it/s]


{'epoch': 28, 'train_loss': 2.5785559570789336, 'eval_loss': 2.580647702217102, 'train_ap_score': 0.09969688551361508, 'eval_ap_score': 0.09937959389997222, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.93it/s]


{'epoch': 29, 'train_loss': 2.5799320161342623, 'eval_loss': 2.582861843109131, 'train_ap_score': 0.09947095141027287, 'eval_ap_score': 0.09914436765169293, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.79it/s]


{'epoch': 30, 'train_loss': 2.5763994240760804, 'eval_loss': 2.58592191696167, 'train_ap_score': 0.10000311977704394, 'eval_ap_score': 0.09957888741974283, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.76it/s]


{'epoch': 31, 'train_loss': 2.5788033640384675, 'eval_loss': 2.5858178663253786, 'train_ap_score': 0.0999247814251232, 'eval_ap_score': 0.09948656322194863, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.96it/s]


{'epoch': 32, 'train_loss': 2.5801198399066925, 'eval_loss': 2.5834839963912963, 'train_ap_score': 0.099508394219294, 'eval_ap_score': 0.09924129662767753, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.83it/s]


{'epoch': 33, 'train_loss': 2.578018282651901, 'eval_loss': 2.5859339094161986, 'train_ap_score': 0.09972181051879589, 'eval_ap_score': 0.09930774909026661, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.78it/s]


{'epoch': 34, 'train_loss': 2.5795436692237854, 'eval_loss': 2.582578363418579, 'train_ap_score': 0.09956394089697389, 'eval_ap_score': 0.099468979097531, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.82it/s]


{'epoch': 35, 'train_loss': 2.5783085346221926, 'eval_loss': 2.5880533266067505, 'train_ap_score': 0.09989043958944081, 'eval_ap_score': 0.09963289728699967, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.88it/s]


{'epoch': 36, 'train_loss': 2.577661551237106, 'eval_loss': 2.585993070602417, 'train_ap_score': 0.09987838455751782, 'eval_ap_score': 0.09928147055228109, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.98it/s]


{'epoch': 37, 'train_loss': 2.578665007352829, 'eval_loss': 2.587017683982849, 'train_ap_score': 0.09973778637392414, 'eval_ap_score': 0.09958233159403881, 'lr': 0.0}


100%|██████████| 200/200 [00:20<00:00,  9.84it/s]


{'epoch': 38, 'train_loss': 2.5794949460029604, 'eval_loss': 2.581411848068237, 'train_ap_score': 0.09960253253520883, 'eval_ap_score': 0.0992376173896013, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.04it/s]


{'epoch': 39, 'train_loss': 2.5790510535240174, 'eval_loss': 2.585963978767395, 'train_ap_score': 0.09981071900645279, 'eval_ap_score': 0.09952128490844507, 'lr': 0.0}


100%|██████████| 200/200 [00:19<00:00, 10.03it/s]


{'epoch': 40, 'train_loss': 2.5782238912582396, 'eval_loss': 2.58040078163147, 'train_ap_score': 0.09960598690149781, 'eval_ap_score': 0.09964847941845847, 'lr': 0.0}


In [9]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eval_ap_score,▆▇▅▄▄▆▂▅▆▄▄▆▂▅▄██▆▆▆▂▃█▁▆▆▅▄▁▇▆▂▃▆█▃▇▂▆█
eval_loss,▂▄█▃▃▄▅▃█▇▄▅▃▄▆▆▅▁▃▅▆▄▃▇▄▄▅▂▄▆▆▄▆▃█▆▇▂▆▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_ap_score,▆▆▄▅▄▅▅▅▆▅▄█▃▅▄▅▅▇▁▅▄▅▃▂▃▃▆▄▁█▇▁▄▂▆▆▄▃▅▃
train_loss,▃▄▅▄▄▃▄▅▄▄▆▃▆▃▄▅▆▄▇▂▆▄▅▇▅▇▄▅█▁▆█▄▇▅▃▅▇▆▄

0,1
epoch,40.0
eval_ap_score,0.09965
eval_loss,2.5804
lr,0.0
train_ap_score,0.09961
train_loss,2.57822


# WandB Steps

In [8]:
### Step 1: Import WandB in your code

import wandb

### Step 1 ends

In [None]:
### Step 2:
# Initiate wandb in your script. The moment we trigger wandb.init(), an active
# socket connection is established between your machine and wandb server.
# We specify the entity (wandb username) and project (which wandb project to use for logging)

wandb.init(entity = "dhruv_sri",   # wandb username. (NOT REQUIRED ARG. ANYMORE, it fetches from initial login)
           project = "wandb_demo", # wandb project name. New project will be created if given project is missing.
           config = config         # Config dict
          )
wandb.run.name = config["wandb_run_name"]

### Step 2 ends.


In [None]:
### Step 3: Trigger wandb log
# This step is responsible for sending the logs to wandb

wandb.log(data_to_log)

### Step 3 ends.


In [None]:
### Step 4 (Optional)
# This closes the active socket connection to wandb server. Optional since wandb destructor does the same.

wandb.finish()

### Step 4 ends.


# WandB sweeps related steps

In [None]:
### Step 1:
# Create a WandB sweep config file.
# This config file will be used at the WandB website to initialize a sweep server
program: "demo.py"
method: "grid"
metric:
  name: "eval_ap_score"
  goal: "maximize"
parameters:
    criterion:
      value: "ce"
    gpu_id:
      value: 0
    lr:
      values: [0.1, 0.001, 0.0001]
    model_type:
      values: ["scratch", "pretrained"]
    num_epochs:
      value: 25
    optimizer:
      values: ["adam", "SGD", "RMSprop"]
    scheduler_patience:
      value: 3
    scheduler_thresh:
      value: 0.01

        
### A sample sweep config file if bayes method is used-
# program: wandb_demo.py
# method: bayes
# metric:
#   name: "eval_ap_score"
#   goal: maximize
# parameters:
#   lr:
#     distribution: uniform
#     min: 0.00001
#     max: 0.1
#   criterion:
#     distribution: categorical
#     value:
#       - ce
#   optimizer:
#     distribution: categorical
#     values:
#       - adam
#       - SGD
#       - RMSprop
#   model_type:
#     distribution: categorical
#     values:
#       - pretrained
#       - scratch
#   num_epochs:
#     value:
#       - 30
#   scheduler_thresh:
#     distribution: uniform
#     min: 0.001
#     max: 0.01
#   scheduler_patience:
#     distribution: int_uniform
#     min: 2
#     max: 10


In [None]:
### Step 2
# After using the above config on wandb website, you will get a sweep id in return.
# E.g. sweep id- dhruv_sri/wandb_demo/hbyp0tl8
#
# Add the following agent line in your code-
# Use the generated sweep id in the below code

wandb.agent(sweep_id="### FILL SWEEP ID HERE ###", function=sweep_agent_manager, count=100)


In [None]:
### Step 3
# Notice in above command we mentioned an argument named "function"
# Wandb agents must trigger a function where they can initiate a socket to wandb and get a config.
# So, we will use the following sweep_agent_manager function here-

def sweep_agent_manager():
    wandb.init()
    config = dict(wandb.config)
    run_name = f"{config['model_type']}_{config['optimizer']}_{config['lr']}"
    wandb.run.name = run_name
    trigger_training(config)


In [None]:
### Done.
# Now execute your training script on multiple machines.
# Each run will request the config file from wandb and related experiments will be logged.
# 
# NOTE!! wandb.log(data_to_log) must be present inside the code!! Else there is no meaning to sweep.


# ------------------------------ Ends ------------------------------