In [None]:
%pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.14.2-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.19.1-py2.py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone

In [None]:
# Authenticate this kernel with Weights & Biases. It may prompt for an API key
# interactively; the boolean it returns is left as the cell's displayed output.
import wandb

wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import wandb
import random

# Single source of truth for this run's hyperparameters. The simulated loop
# below reads its epoch count from here instead of repeating a literal.
run_config = {
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10,
}

# Start a new wandb run to track this script.
wandb.init(
    # The wandb project where this run will be logged.
    project="my-awesome-project",
    # Track hyperparameters and run metadata.
    config=run_config,
)

try:
    # Simulate training. The loop starts at epoch 2 because at epoch 0 the
    # `random.random() / epoch` terms would raise ZeroDivisionError.
    epochs = run_config["epochs"]
    offset = random.random() / 5
    for epoch in range(2, epochs):
        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
        loss = 2 ** -epoch + random.random() / epoch + offset

        # Log metrics to wandb.
        wandb.log({"acc": acc, "loss": loss})
finally:
    # Finishing the run is necessary in notebooks; doing it in `finally`
    # closes the run even if the loop above raises.
    wandb.finish()

In [None]:
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T

# Use the first GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def get_dataloader(is_train, batch_size, slice=5, root=".", num_workers=2):
    """Build a DataLoader over every `slice`-th example of an MNIST split.

    Used for both the training split (``is_train=True``) and the test split
    that serves as validation data (``is_train=False``).

    Args:
        is_train: Select the MNIST training split if true, the test split
            otherwise. Only the training split is shuffled.
        batch_size: Number of examples per batch.
        slice: Keep one example out of every ``slice`` (subsampling stride),
            shrinking the dataset for a quick demo. The name is kept for
            backward compatibility even though it shadows the builtin.
        root: Directory MNIST is downloaded to / read from.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)``
        batches, with images converted by ``T.ToTensor()``.
    """
    full_dataset = torchvision.datasets.MNIST(
        root=root, train=is_train, transform=T.ToTensor(), download=True
    )
    sub_dataset = torch.utils.data.Subset(
        full_dataset, indices=range(0, len(full_dataset), slice)
    )
    return torch.utils.data.DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=bool(is_train),
        # Pinned host memory only speeds up host->GPU copies; without CUDA,
        # PyTorch cannot use it, so only request it when a GPU is present.
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers,
    )

def get_model(dropout):
    """Return a one-hidden-layer MLP classifier moved to ``device``.

    Layers: Flatten -> Linear(28*28 -> 256) -> BatchNorm1d(256) -> ReLU
    -> Dropout(dropout) -> Linear(256 -> 10).

    Args:
        dropout: Drop probability passed to ``nn.Dropout``.
    """
    hidden = 256
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, 10),
    ]
    return nn.Sequential(*layers).to(device)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate ``model`` on ``valid_dl`` and optionally log one batch as a wandb.Table.

    Args:
        model: Classifier returning one row of class scores per example.
        valid_dl: DataLoader yielding ``(images, labels)`` batches.
        loss_func: Loss function; assumed to use mean reduction (as
            ``nn.CrossEntropyLoss()`` does by default). Each batch loss is
            re-weighted by the batch size, so the result is a per-example
            mean even when the last batch is smaller.
        log_images: If true, pass the batch at position ``batch_idx`` to
            ``log_image_table``.
        batch_idx: Index of the batch to log. A fixed index logs the same
            images every time, so they stay comparable.

    Returns:
        ``(val_loss, accuracy)`` as Python floats: mean loss per example and
        fraction of correctly classified examples.
    """
    model.eval()
    n_examples = len(valid_dl.dataset)
    val_loss = 0.0
    correct = 0
    with torch.inference_mode():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            # .item() keeps the running sum a Python float rather than a
            # growing expression on the device.
            val_loss += loss_func(outputs, labels).item() * labels.size(0)

            # Compute accuracy and accumulate
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if log_images and i == batch_idx:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / n_examples, correct / n_examples

def log_image_table(images, predicted, labels, probs):
    """Log a wandb.Table with one (image, pred, target, score_0..score_{C-1}) row per example.

    Args:
        images: Batch of images. Channel 0 of each image is logged, scaled by
            255 (presumably ``ToTensor`` output in [0, 1] — confirm if the
            transform changes).
        predicted: Predicted class index per example.
        labels: Ground-truth class index per example.
        probs: Per-class scores, one row per example. The number of
            ``score_*`` columns is taken from its second dimension, so the
            header always matches the row width.

    The table is logged with ``commit=False``, which adds it to the current
    wandb step without committing it. The caller's next ``wandb.log`` call
    commits the step.
    """
    n_classes = probs.shape[1]
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(
        columns=["image", "pred", "target"] + [f"score_{i}" for i in range(n_classes)]
    )
    for img, pred, targ, prob in zip(images.cpu(), predicted.cpu(), labels.cpu(), probs.cpu()):
        # Store plain Python ints in the table rather than 0-d tensors.
        table.add_data(wandb.Image(img[0].numpy() * 255), pred.item(), targ.item(), *prob.numpy())
    wandb.log({"predictions_table": table}, commit=False)

In [None]:
# Launch N_EXPERIMENTS runs, each trying a randomly sampled dropout rate.
N_EXPERIMENTS = 5

for _ in range(N_EXPERIMENTS):
    # 🐝 initialise a wandb run
    wandb.init(
        project="pytorch-intro",
        config={
            "epochs": 10,
            "batch_size": 128,
            "lr": 1e-3,
            "dropout": random.uniform(0.01, 0.80),
        })
    try:
        # Copy your config
        config = wandb.config

        # Get the data
        train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
        valid_dl = get_dataloader(is_train=False, batch_size=2*config.batch_size)
        # Number of batches per epoch, including a final partial batch.
        n_steps_per_epoch = len(train_dl)

        # A simple MLP model
        model = get_model(config.dropout)

        # Make the loss and optimizer
        loss_func = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

        # Training
        example_ct = 0
        for epoch in range(config.epochs):
            model.train()
            for step, (images, labels) in enumerate(train_dl):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                train_loss = loss_func(outputs, labels)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                example_ct += len(images)
                metrics = {"train/train_loss": train_loss,
                           "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                           "train/example_ct": example_ct}

                # The last step of each epoch is not logged here; it is logged
                # below together with the validation metrics.
                if step + 1 < n_steps_per_epoch:
                    # 🐝 Log train metrics to wandb
                    wandb.log(metrics)

            val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

            # 🐝 Log train and validation metrics to wandb
            val_metrics = {"val/val_loss": val_loss,
                           "val/val_accuracy": accuracy}
            wandb.log({**metrics, **val_metrics})

            print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

        # If you had a test set, this is how you could log it as a Summary metric.
        # NOTE: 0.8 is a placeholder, not a measured accuracy.
        wandb.summary['test_accuracy'] = 0.8
    finally:
        # 🐝 Close your wandb run, even if training raised.
        wandb.finish()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 61819125.83it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 13530179.14it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 37828344.34it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6145331.86it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Train Loss: 0.597, Valid Loss: 0.371554, Accuracy: 0.90
Train Loss: 0.345, Valid Loss: 0.298088, Accuracy: 0.91
Train Loss: 0.486, Valid Loss: 0.273572, Accuracy: 0.92
Train Loss: 0.339, Valid Loss: 0.253735, Accuracy: 0.93
Train Loss: 0.350, Valid Loss: 0.239259, Accuracy: 0.92
Train Loss: 0.238, Valid Loss: 0.235063, Accuracy: 0.92
Train Loss: 0.276, Valid Loss: 0.225597, Accuracy: 0.93
Train Loss: 0.257, Valid Loss: 0.217012, Accuracy: 0.93
Train Loss: 0.261, Valid Loss: 0.211583, Accuracy: 0.93
Train Loss: 0.245, Valid Loss: 0.212486, Accuracy: 0.94


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▃▃▃▃▂▃▂▃▃▂▂▃▂▁▁▂▁▂▂▂▁▁▁▂▁▁▂▁▁▁▂▁▁▂▂▂▁
val/val_accuracy,▁▃▄▆▅▅▆▇▇█
val/val_loss,█▅▄▃▂▂▂▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.24512
val/val_accuracy,0.9365
val/val_loss,0.21249


Train Loss: 0.519, Valid Loss: 0.347337, Accuracy: 0.90
Train Loss: 0.376, Valid Loss: 0.275725, Accuracy: 0.92
Train Loss: 0.298, Valid Loss: 0.257115, Accuracy: 0.92
Train Loss: 0.363, Valid Loss: 0.236444, Accuracy: 0.93
Train Loss: 0.201, Valid Loss: 0.221504, Accuracy: 0.94
Train Loss: 0.196, Valid Loss: 0.211070, Accuracy: 0.93
Train Loss: 0.315, Valid Loss: 0.206134, Accuracy: 0.94
Train Loss: 0.164, Valid Loss: 0.198603, Accuracy: 0.94
Train Loss: 0.191, Valid Loss: 0.195017, Accuracy: 0.93
Train Loss: 0.209, Valid Loss: 0.191402, Accuracy: 0.94


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▄▃▃▃▃▃▃▂▃▂▃▂▂▂▃▂▂▂▂▂▁▁▂▂▂▂▂▂▁▁▁▂▁▁▁▂▂
val/val_accuracy,▁▅▅▆▇▇█▇▇█
val/val_loss,█▅▄▃▂▂▂▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.20902
val/val_accuracy,0.939
val/val_loss,0.1914


Train Loss: 0.380, Valid Loss: 0.306504, Accuracy: 0.91
Train Loss: 0.332, Valid Loss: 0.246628, Accuracy: 0.93
Train Loss: 0.250, Valid Loss: 0.218304, Accuracy: 0.93
Train Loss: 0.182, Valid Loss: 0.196025, Accuracy: 0.94
Train Loss: 0.190, Valid Loss: 0.194334, Accuracy: 0.94
Train Loss: 0.260, Valid Loss: 0.174082, Accuracy: 0.94
Train Loss: 0.149, Valid Loss: 0.176033, Accuracy: 0.94
Train Loss: 0.157, Valid Loss: 0.170783, Accuracy: 0.95
Train Loss: 0.105, Valid Loss: 0.156935, Accuracy: 0.95
Train Loss: 0.158, Valid Loss: 0.157044, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▃▃▃▄▂▂▂▃▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▂
val/val_accuracy,▁▄▅▆▆▆▇▇██
val/val_loss,█▅▄▃▃▂▂▂▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.15785
val/val_accuracy,0.95
val/val_loss,0.15704


Train Loss: 0.282, Valid Loss: 0.303825, Accuracy: 0.91
Train Loss: 0.259, Valid Loss: 0.247942, Accuracy: 0.93
Train Loss: 0.172, Valid Loss: 0.219394, Accuracy: 0.93
Train Loss: 0.100, Valid Loss: 0.199290, Accuracy: 0.94
Train Loss: 0.182, Valid Loss: 0.181764, Accuracy: 0.94
Train Loss: 0.186, Valid Loss: 0.179963, Accuracy: 0.94
Train Loss: 0.132, Valid Loss: 0.164190, Accuracy: 0.94
Train Loss: 0.173, Valid Loss: 0.159274, Accuracy: 0.95
Train Loss: 0.096, Valid Loss: 0.156011, Accuracy: 0.95
Train Loss: 0.061, Valid Loss: 0.167347, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▃▃▃▂▂▂▂▂▃▂▁▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▂▁▂▁▂▁▁▁▁▁▁
val/val_accuracy,▁▃▄▆▇▇▇▇█▇
val/val_loss,█▅▄▃▂▂▁▁▁▂

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.06111
val/val_accuracy,0.9475
val/val_loss,0.16735


Train Loss: 0.308, Valid Loss: 0.284322, Accuracy: 0.92
Train Loss: 0.162, Valid Loss: 0.222787, Accuracy: 0.94
Train Loss: 0.086, Valid Loss: 0.199396, Accuracy: 0.94
Train Loss: 0.094, Valid Loss: 0.176070, Accuracy: 0.95
Train Loss: 0.139, Valid Loss: 0.172301, Accuracy: 0.94
Train Loss: 0.064, Valid Loss: 0.160672, Accuracy: 0.95
Train Loss: 0.042, Valid Loss: 0.158104, Accuracy: 0.95
Train Loss: 0.025, Valid Loss: 0.156680, Accuracy: 0.95
Train Loss: 0.069, Valid Loss: 0.154472, Accuracy: 0.95
Train Loss: 0.039, Valid Loss: 0.154366, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▆▅▄▂▃▃▃▂▃▂▂▂▂▂▂▁▂▂▁▁▁▁▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▄▅▇▆████▇
val/val_loss,█▅▃▂▂▁▁▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.03868
val/val_accuracy,0.95
val/val_loss,0.15437
