# Train a model and log the results on Weights & Biases

In this notebook we will see how to log the results of a model training run to Weights & Biases (W&B).
The code used to train the model is taken from the [pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/convolutional_neural_network/main.py).

In [None]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import wandb
from utils import log_test_predictions

In [None]:
# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001

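# These two additional parameters control how many test images (and the
# corresponding model predictions) are logged to W&B: up to NUM_IMAGES_PER_BATCH
# images from each of the first NUM_BATCHES_TO_LOG test batches.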
NUM_BATCHES_TO_LOG = 10
NUM_IMAGES_PER_BATCH = 32

In [None]:
# Prepare the data 

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

In [None]:
# Define the model we will use

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
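
The input size of `self.fc`, `7*7*32`, follows from the layer definitions. Each convolution uses `kernel_size=5`, `stride=1`, `padding=2` and therefore preserves the spatial size, while each `MaxPool2d(kernel_size=2, stride=2)` halves it. Applying the standard output-size formula to a 28×28 MNIST image:

$$
\begin{aligned}
H_{\text{conv}} &= \left\lfloor \frac{H + 2p - k}{s} \right\rfloor + 1 = \frac{28 + 2\cdot 2 - 5}{1} + 1 = 28 \\
H_{\text{pool}} &= \left\lfloor \frac{H - k}{s} \right\rfloor + 1 = \frac{28 - 2}{2} + 1 = 14 \\
\text{layer2:}\quad 14 &\xrightarrow{\text{conv}} 14 \xrightarrow{\text{pool}} 7, \qquad 7 \cdot 7 \cdot 32 = 1568
\end{aligned}
$$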

In [None]:
model = ConvNet(num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Initialize a W&B Run object

A Run object is what we use to log data to W&B; we create one with the `wandb.init()` function.
`wandb.init()` accepts several parameters to customize the run, including:
- `project`: The name of the project the run belongs to.
- `name`: The display name of the run (W&B generates a random one if it is omitted).
- `config`: A dictionary of configuration parameters for the run. These are usually static values, such as hyperparameters, that do not change during training. W&B stores them with the run, which helps to identify runs and compare them with each other.
- `tags`: A list of tags attached to the run, useful for filtering experiments and grouping them.

In [None]:
wandb_run = wandb.init(project="Temperatures", # name of the project in which we want to store our runs
                        config={
                            "num_epochs": num_epochs,
                            "batch_size": batch_size,
                            "learning_rate": learning_rate,
                            "num_classes": num_classes,
                            "criterion": "CrossEntropyLoss", 
                            "optimizer": "Adam"                   
                        })

## Train the model and log the results

We can use a standard PyTorch training loop and add just a few lines of code to log the results to W&B:
- After each training step we call `wandb.log()` to log the loss of the model, together with a `step` value that counts the training steps across all epochs. We can use it as the "x" axis of the plot.
- After each epoch we log the accuracy of the model on the test set.
- We also log a table with a sample of test images, the model's prediction and the ground-truth label for each image, and the model's confidence score for each digit.

At the end, we log the trained model to W&B. To do so, we save the model to disk, create an `Artifact` object containing the model file, and log it with `wandb.log_artifact()`.
We can do the same with the dataset used to train the model, so that the experiment can be reproduced in the future.
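
A `step` value of `i*epoch` would not work as an x-axis. It is 0 for the entire first epoch and revisits the same values in later epochs. Counting steps from the start of training means batch `i` of epoch `epoch` is step `epoch * total_step + i`. A small pure-Python check, assuming 600 steps per epoch (60,000 MNIST training images with `batch_size = 100`) and two epochs:

```python
def global_step(epoch: int, i: int, steps_per_epoch: int) -> int:
    """Index of batch i of the given epoch, counted from the start of training."""
    return epoch * steps_per_epoch + i


steps_per_epoch = 600  # 60,000 training images / batch_size 100

naive = [i * epoch for epoch in range(2) for i in range(steps_per_epoch)]
fixed = [global_step(epoch, i, steps_per_epoch) for epoch in range(2) for i in range(steps_per_epoch)]

# Distinct x values: i * epoch collapses the first epoch to 0, the global step does not
print(len(set(naive)), len(set(fixed)))  # prints: 600 1200
```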

In [None]:
total_step = len(train_loader)
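for epoch in range(num_epochs):
    model.train()  # training mode: batchnorm uses mini-batch statistics (model.eval() is called for testing)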
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

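        # Log the loss with a global step counter (batches seen since the start of training);
        # i * epoch would not work because it stays at 0 for the whole first epoch.
        wandb.log({"loss": loss.item(), "step": epoch * total_step + i})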
        
    # ✨ W&B: Create a Table to store predictions for each test step
    columns=["id", "image", "guess", "truth"]
    for digit in range(10):
        columns.append("score_" + str(digit))
    test_table = wandb.Table(columns=columns)
    
    # Test the model
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    log_counter = 0
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            if log_counter < NUM_BATCHES_TO_LOG:
                log_test_predictions(images, labels, outputs, predicted, test_table, log_counter, NUM_IMAGES_PER_BATCH)
                log_counter += 1
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        print('Test Accuracy of the model on the 10000 test images: {} %'.format(accuracy))
        wandb.log({"epoch" : epoch, "acc" : accuracy})

    wandb.log({"test_predictions" : test_table})

torch.save(model.state_dict(), 'model.ckpt')

# Log the trained model to W&B as an artifact
artifact = wandb.Artifact(name="model", type="model")
artifact.add_file(local_path="model.ckpt")
wandb.log_artifact(artifact)

wandb.finish()