# Experiment Tracking with W&B

### Code is partly adapted from [wandb.me/intro](https://wandb.me/intro)

In [1]:
import wandb
import math, random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
from torchviz import make_dot
import pickle

# Use the first CUDA GPU when PyTorch reports one as available; otherwise
# fall back to the CPU. Every helper below moves tensors to this device.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"device in use is {device}")

def load_data(is_train, batch_size, subset=5):
    """Build a DataLoader over an evenly strided subset of MNIST.

    Args:
        is_train: selects the MNIST training split when truthy, otherwise the
            test split; it also decides whether batches are shuffled.
        batch_size: number of samples per batch.
        subset: stride used when picking samples, i.e. every `subset`-th
            example of the split is kept.

    Returns:
        A `torch.utils.data.DataLoader` over the selected samples.
    """
    mnist = torchvision.datasets.MNIST(
        root=".",
        train=is_train,
        transform=T.ToTensor(),
        download=True,
    )
    # Keep indices 0, subset, 2*subset, ... to shrink the dataset.
    kept_indices = range(0, len(mnist), subset)
    strided_view = torch.utils.data.Subset(mnist, indices=kept_indices)
    return torch.utils.data.DataLoader(
        dataset=strided_view,
        batch_size=batch_size,
        shuffle=bool(is_train),
        pin_memory=True,
        num_workers=4,
    )

def the_model(dropout):
    """Create the MLP classifier and place it on the module-level `device`.

    The network flattens its input, applies a 784 -> 256 linear layer followed
    by batch normalisation, ReLU and dropout, and ends with a 256 -> 10 linear
    layer.

    Args:
        dropout: drop probability handed to the `nn.Dropout` layer.

    Returns:
        The `nn.Sequential` model, already moved to `device`.
    """
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, 10),
    ]
    network = nn.Sequential(*layers)
    return network.to(device)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate `model` on a validation loader and optionally log one batch to W&B.

    Args:
        model: network to evaluate. It is switched to eval mode and left in
            that mode, so callers must call `model.train()` before resuming
            training.
        valid_dl: iterable yielding `(images, labels)` batches.
        loss_func: callable `loss_func(outputs, labels)` returning a scalar.
            It is assumed to be the batch-mean loss (e.g. the default
            `CrossEntropyLoss`), because it is re-weighted by the batch size.
        log_images: when true, the batch at position `batch_idx` is passed to
            `log_image_table`.
        batch_idx: index of the batch to log, so the same batch is shown on
            every call.

    Returns:
        `(val_loss, accuracy)` as Python floats, both averaged over the number
        of samples actually yielded by `valid_dl`.

    Raises:
        ValueError: if `valid_dl` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            # Forward pass
            outputs = model(images)
            # Convert to a Python float right away so the running total (and
            # the returned loss) is a plain number rather than a device tensor.
            total_loss += float(loss_func(outputs, labels)) * batch_size

            # Compute accuracy and accumulate
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            seen += batch_size

            # Log one batch of images to the dashboard, always same batch_idx.
            if log_images and i == batch_idx:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))

    if seen == 0:
        raise ValueError("validate_model: the validation loader yielded no samples")
    # Divide by the samples seen, not len(dataset): the two differ when the
    # loader drops or skips samples (e.g. drop_last=True).
    return total_loss / seen, correct / seen

def log_image_table(images, predicted, labels, probs):
    """Send a wandb.Table of (image, prediction, target, class scores) to W&B.

    Each row holds one sample: `img[0]` of the image scaled by 255 and wrapped
    in `wandb.Image` (presumably the first channel of a (C, H, W) tensor --
    confirm against the caller), the predicted label, the target label, and
    one `score_<k>` column per entry of that sample's `probs` row. The table
    declares ten score columns. It is logged under the key
    "predictions_table".
    """
    header = ["image", "prediction", "target"]
    header.extend(f"score_{cls}" for cls in range(10))
    table = wandb.Table(columns=header)

    # Move everything to the CPU first so `.numpy()` can be used below.
    rows = zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu"))
    for picture, guess, truth, scores in rows:
        rendered = wandb.Image(picture[0].numpy() * 255)
        table.add_data(rendered, guess, truth, *scores.numpy())

    wandb.log({"predictions_table": table})

cuda:0


In [11]:
# Hyperparameters for this run; the dropout probability is sampled uniformly
# at random each time the cell is executed.
config = {"epochs": 10, "batch_size": 128, "lr": 1e-3, "dropout": random.uniform(0.01, 0.80)}

# Start a wandb run. sync_tensorboard=True asks wandb to pick up TensorBoard logs.
wandb.init(project="experiment-tracking", sync_tensorboard=True, config=config)

# We can use config as a dictionary or we can use wandb.config
config = wandb.config

# Loading train set and validation set
train_data = load_data(is_train=True, batch_size=config.batch_size)
valid_data = load_data(is_train=False, batch_size=config.batch_size)
n_steps_per_epoch = math.ceil(len(train_data.dataset) / config.batch_size)

# A multi layer perceptron (MLP) model (the_model already places it on `device`)
model = the_model(config.dropout)

# Visualizing the neural network (model) using four different methods.
# One training batch is used as the example input for all of them.
datum = next(iter(train_data))
example_input = datum[0].to(device)

## 1. Visualizing using TensorBoard
tb_writer = SummaryWriter(log_dir=wandb.run.dir)
tb_writer.add_graph(model, example_input)

## 2. Visualizing using make_dot
out = model(example_input)
model_graph = make_dot(out)
# Use a context manager so the pickle file is flushed and closed.
with open('model_graph.pkl', "wb") as graph_file:
    pickle.dump(model_graph, graph_file)

## 3. Saving model layers using wandb.watch command
wandb.watch(model, log="all")

## 4. Visualizing model using ONNX
torch.onnx.export(
    model,                         # model being run
    example_input,                 # model input (or a tuple for multiple inputs)
    "model.onnx",                  # where to save the model (can be a file or file-like object)
    export_params=True,            # store the trained parameter weights inside the model file
    opset_version=10,              # the ONNX version to export the model to
    do_constant_folding=True,      # whether to execute constant folding for optimization
    input_names=['input'],         # the model's input names
    output_names=['output'],       # the model's output names
    dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                  'output': {0: 'batch_size'}},
)
wandb.save('model.onnx')

# Loss function and optimizer
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Training loop
for epoch in range(config.epochs):
    # validate_model leaves the model in eval mode, so re-enable training here.
    model.train()
    for step, (images, labels) in enumerate(train_data):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        train_loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Log a plain float instead of the graph-attached loss tensor.
        metrics = {"train/train_loss": train_loss.item(), "train/epoch": epoch}

        if step + 1 < n_steps_per_epoch:
            # Log metrics to wandb for every batch except the last one of the
            # epoch; that one is logged together with the validation metrics.
            wandb.log(metrics)

    # The image table is only logged after the final epoch.
    val_loss, accuracy = validate_model(model, valid_data, loss_func, log_images=(epoch==(config.epochs-1)))

    # Log train and validation metrics to wandb for each epoch
    val_metrics = {"val/val_loss": float(val_loss), "val/val_accuracy": accuracy}

    wandb.log({**metrics, **val_metrics})

    print(f"Epoch {epoch}. Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

# If you had a test set, this is how you could log it as a Summary metric.
# NOTE: 0.8 is a hard-coded placeholder, not a measured accuracy.
wandb.summary['test_accuracy'] = 0.8

# Saving the learned model after the last epoch
torch.save(model.state_dict(), 'model_state.pt')

# To save the model state in wandb, you need to use the following command to upload it to your wandb run
wandb.save('model_state.pt')

# Flush and close the TensorBoard event file before the run ends
tb_writer.close()

# Close your wandb run
wandb.finish()



Epoch 0. Train Loss: 0.379, Valid Loss: 0.284890, Accuracy: 0.91
Epoch 1. Train Loss: 0.262, Valid Loss: 0.229349, Accuracy: 0.93
Epoch 2. Train Loss: 0.130, Valid Loss: 0.194712, Accuracy: 0.94
Epoch 3. Train Loss: 0.130, Valid Loss: 0.184114, Accuracy: 0.94
Epoch 4. Train Loss: 0.096, Valid Loss: 0.181357, Accuracy: 0.94
Epoch 5. Train Loss: 0.119, Valid Loss: 0.167953, Accuracy: 0.95
Epoch 6. Train Loss: 0.105, Valid Loss: 0.158096, Accuracy: 0.95
Epoch 7. Train Loss: 0.031, Valid Loss: 0.160057, Accuracy: 0.95
Epoch 8. Train Loss: 0.064, Valid Loss: 0.158707, Accuracy: 0.95
Epoch 9. Train Loss: 0.101, Valid Loss: 0.162270, Accuracy: 0.95


VBox(children=(Label(value='0.869 MB of 1.681 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.517052…

0,1
global_step,▁
train/epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/train_loss,█▅▄▄▂▂▃▂▂▂▂▂▁▂▂▂▂▂▂▁▁▁▁▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▃▆▆▆▇█▇██
val/val_loss,█▅▃▂▂▂▁▁▁▁

0,1
global_step,0.0
test_accuracy,0.8
train/epoch,9.0
train/train_loss,0.10094
val/val_accuracy,0.9535
val/val_loss,0.16227
