In [None]:
!pip install wandb

In [None]:
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T

# Prefer the first CUDA GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def create_dataloader(is_train, batch_size, slice=5):
    """Build a DataLoader over a strided subset of FashionMNIST.

    The function does three things:
      - downloads FashionMNIST into the current directory if needed;
      - keeps every `slice`-th sample of the requested split;
      - wraps that subset in a DataLoader.
    It is used for both the training split and the held-out (test) split.

    Args:
      is_train: bool. If True, load the training split and shuffle it;
        otherwise load the test split in a fixed order.
      batch_size: int. Number of samples per batch.
      slice: int. Keep one sample out of every `slice` (must be >= 1). The
        default of 5 keeps 20% of the split.

    Returns:
      loader: torch.utils.data.DataLoader yielding (images, labels) batches,
        with images converted by T.ToTensor().

    Raises:
      ValueError: if `slice` is smaller than 1.
    """
    # Fail fast: slice=0 would otherwise crash inside range() only after the
    # download, and a negative slice would silently give an empty loader.
    if slice < 1:
        raise ValueError(f"slice must be >= 1, got {slice}")

    full_dataset = torchvision.datasets.FashionMNIST(
        root=".", train=is_train, transform=T.ToTensor(), download=True)
    sub_dataset = torch.utils.data.Subset(
        full_dataset, indices=range(0, len(full_dataset), slice))
    loader = torch.utils.data.DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=bool(is_train),
        # Pinned host memory only speeds up host->GPU copies; without CUDA it
        # just emits a warning, so enable it only when a GPU is present.
        pin_memory=torch.cuda.is_available(),
        num_workers=2)
    return loader

In [None]:
# Sanity-check the input pipeline: pull a single training batch and report tensor shapes.
train_loader = create_dataloader(batch_size=32, is_train=True)
first_batch = next(iter(train_loader))
images, labels = first_batch
print(images.shape, labels.shape)

In [None]:

class SimpleNN(nn.Module):
  """Small CNN for FashionMNIST: one conv block followed by a two-layer MLP head.

  Layer order: Conv2d(-> 16 ch) -> BatchNorm2d -> ReLU -> MaxPool2d(2) ->
  Flatten -> Linear(-> 256) -> BatchNorm1d -> ReLU -> Dropout ->
  Linear(-> n_classes).
  """

  def __init__(self, in_channels: int = 1,
               kernel_size: int = 3, stride: int = 1,
               n_classes: int = 10,
               dropout: float = 0.3,
               input_size: int = 28):
    """Build the network and move it to the first GPU when one is available.

    Args:
      in_channels: number of channels of the input images.
      kernel_size: side length of the square conv kernel (no padding).
      stride: stride of the conv layer.
      n_classes: number of output logits.
      dropout: dropout probability applied before the final Linear layer.
      input_size: height/width of the square input images (28 for
        FashionMNIST). Used to size the first Linear layer.

    Raises:
      ValueError: if the conv + 2x2 max-pool stages leave no spatial extent
        for the given input_size / kernel_size / stride.
    """
    super(SimpleNN, self).__init__()
    self.in_channels = in_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.n_classes = n_classes
    self.dropout = dropout
    self.input_size = input_size
    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Spatial size after Conv2d (padding=0, dilation=1) and MaxPool2d(2).
    # Derived from the constructor arguments instead of hard-coding 13 * 13.
    # That constant only holds for 28x28 inputs with kernel_size=3 and stride=1;
    # any other kernel_size/stride made forward() fail with a shape mismatch.
    conv_size = (input_size - kernel_size) // stride + 1
    pooled_size = conv_size // 2
    if pooled_size < 1:
      raise ValueError(
          f"input_size={input_size}, kernel_size={kernel_size}, stride={stride} "
          "leave no spatial extent after the conv and 2x2 max-pooling layers")
    n_features = 16 * pooled_size * pooled_size

    self.model = nn.Sequential(
        nn.Conv2d(self.in_channels, 16, kernel_size=self.kernel_size, stride=self.stride),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(n_features, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(self.dropout),
        nn.Linear(256, self.n_classes)).to(self.device)

  def forward(self, x):
    """Return class logits for a batch of images `x` (must be on self.device)."""
    output = self.model(x)
    return output

In [None]:
model = SimpleNN(dropout=0.4, in_channels=1)  # other args keep their defaults: kernel_size=3, stride=1, n_classes=10

In [None]:
# Show the model's layer-by-layer repr as this cell's output.
model

SimpleNN(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=2704, out_features=256, bias=True)
    (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Dropout(p=0.4, inplace=False)
    (9): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [None]:
def log_image_table(images, predicted, labels, probs):
  """
  Log one batch of images and predictions to wandb as a table.

  Builds a wandb.Table with one row per sample: the image, the predicted
  class, the target class, and one score column per class. It then logs the
  table under "predictions_table" with commit=False, so the table is attached
  to the next committed wandb.log step.

  Args:
    images: torch.tensor of images; only the first channel of each image is
      logged (assumes channel-first layout — TODO confirm for non-MNIST data).
    predicted: torch.tensor of predicted class indices, one per image.
    labels: torch.tensor of target class indices, one per image.
    probs: torch.tensor of per-class scores with one row per image; the number
      of score columns follows its second dimension instead of a fixed 10.
  """
  n_classes = probs.shape[1]
  columns = ["image", "pred", "target"] + [f"score_{i}" for i in range(n_classes)]
  table = wandb.Table(columns=columns)
  # detach() so .numpy() also works if called outside torch.no_grad().
  for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"),
                                   labels.to("cpu"), probs.detach().to("cpu")):
      # .item() stores plain Python ints instead of 0-d tensors in the table.
      table.add_data(wandb.Image(img[0].numpy() * 255), pred.item(), targ.item(), *prob.numpy())
  wandb.log({"predictions_table": table}, commit=False)

In [None]:
def validate(model, valid_dl, loss_func, log_images=False, batch_idx=0):
  """
  Compute loss and accuracy of the model on a validation DataLoader.

  Runs the model in eval mode without gradients. Each batch's mean loss is
  weighted by the batch size, so the result is a per-sample mean and a
  smaller last batch does not skew it. Optionally logs the batch at position
  `batch_idx` to wandb via log_image_table.

  Args:
    model: torch.nn.Module producing one row of class logits per sample.
    valid_dl: torch.utils.data.DataLoader yielding (images, labels); its
      dataset must support len().
    loss_func: loss with mean reduction, e.g. torch.nn.CrossEntropyLoss.
    log_images: bool, whether to log one batch of predictions to wandb.
    batch_idx: int, index of the batch to log when log_images is True.

  Returns:
    (val_loss, accuracy): per-sample mean loss and fraction of correctly
    classified samples, both as Python floats.

  Raises:
    ValueError: if the validation dataset is empty.
  """
  n_samples = len(valid_dl.dataset)
  if n_samples == 0:
    raise ValueError("validation dataset is empty")

  # Send inputs to wherever the model's weights live; fall back to the
  # notebook-level `device` only for a model without parameters.
  first_param = next(model.parameters(), None)
  target_device = first_param.device if first_param is not None else device

  model.eval()
  val_loss = 0.
  correct = 0
  with torch.no_grad():
      for i, (images, labels) in enumerate(valid_dl):
          images, labels = images.to(target_device), labels.to(target_device)

          # Forward propagation
          outputs = model(images)
          # Accumulate as a Python float so no tensors are kept around.
          val_loss += loss_func(outputs, labels).item() * labels.size(0)

          # Compute accuracy and accumulate
          predicted = outputs.argmax(dim=1)
          correct += (predicted == labels).sum().item()

          # Log one batch of images to the dashboard, always the same batch_idx.
          if log_images and i == batch_idx:
              log_image_table(images, predicted, labels, outputs.softmax(dim=1))
  return val_loss / n_samples, correct / n_samples

In [None]:
# Launch 3 experiments with different dropout rates
for _ in range(3):
    # initialise a wandb run
    wandb.init(
        project="Fashion-MNIST-Classification",
        config={
            "epochs": 10,
            "batch_size": 128,
            "lr": 1e-3,
            "dropout": random.uniform(0.01, 0.80),
            })

    # Copy your config
    config = wandb.config

    # Get the data
    train_dl = create_dataloader(is_train=True, batch_size=config.batch_size)
    valid_dl = create_dataloader(is_train=False, batch_size=2*config.batch_size)
    n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

    # A simple MLP model
    model = SimpleNN(in_channels=1, dropout=config.dropout)

    # Make the loss and optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

   # Training
    example_ct = 0
    step_ct = 0
    for epoch in range(config.epochs):
        model.train()
        for step, (images, labels) in enumerate(train_dl):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            train_loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            example_ct += len(images)
            metrics = {"train/train_loss": train_loss,
                       "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                       "train/example_ct": example_ct}

            if step + 1 < n_steps_per_epoch:
                # Log train metrics to wandb
                wandb.log(metrics)

            step_ct += 1

        val_loss, accuracy = validate(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

        # Log train and validation metrics to wandb
        val_metrics = {"val/val_loss": val_loss,
                       "val/val_accuracy": accuracy}
        wandb.log({**metrics, **val_metrics})

        print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}, Accuracy: {accuracy:.2f}")

    # If you had a test set, this is how you could log it as a Summary metric
    wandb.summary['test_accuracy'] = 0.8

    # Close your wandb run
    wandb.finish()

Train Loss: 0.596, Valid Loss: 0.413211, Accuracy: 0.85
Train Loss: 0.358, Valid Loss: 0.359685, Accuracy: 0.87
Train Loss: 0.313, Valid Loss: 0.364193, Accuracy: 0.87
Train Loss: 0.284, Valid Loss: 0.322773, Accuracy: 0.89
Train Loss: 0.209, Valid Loss: 0.305100, Accuracy: 0.89
Train Loss: 0.155, Valid Loss: 0.321308, Accuracy: 0.89
Train Loss: 0.184, Valid Loss: 0.340890, Accuracy: 0.88
Train Loss: 0.116, Valid Loss: 0.312305, Accuracy: 0.90
Train Loss: 0.119, Valid Loss: 0.346075, Accuracy: 0.89
Train Loss: 0.158, Valid Loss: 0.321600, Accuracy: 0.89


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▆▆▄▅▄▃▃▂▄▂▃▃▂▃▂▂▂▂▂▃▂▂▂▂▁▂▂▂▁▂▂▁▁▁▂▁▂▁▁
val/val_accuracy,▁▄▄▇▇▇▅█▇█
val/val_loss,█▅▅▂▁▂▃▁▄▂

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.15829
val/val_accuracy,0.8945
val/val_loss,0.3216


[34m[1mwandb[0m: Currently logged in as: [33mengrfaizan-ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Train Loss: 0.341, Valid Loss: 0.396059, Accuracy: 0.86
Train Loss: 0.320, Valid Loss: 0.368896, Accuracy: 0.87
Train Loss: 0.438, Valid Loss: 0.322818, Accuracy: 0.89
Train Loss: 0.184, Valid Loss: 0.326449, Accuracy: 0.88
Train Loss: 0.211, Valid Loss: 0.416485, Accuracy: 0.86
Train Loss: 0.128, Valid Loss: 0.311538, Accuracy: 0.90
Train Loss: 0.128, Valid Loss: 0.326986, Accuracy: 0.89
Train Loss: 0.059, Valid Loss: 0.346680, Accuracy: 0.90
Train Loss: 0.080, Valid Loss: 0.333361, Accuracy: 0.90
Train Loss: 0.098, Valid Loss: 0.399065, Accuracy: 0.89


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▆▆▅▄▄▄▃▃▄▃▃▂▃▂▃▁▂▂▂▂▂▁▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁
val/val_accuracy,▁▃▆▆▂█▇▇█▇
val/val_loss,▇▅▂▂█▁▂▃▂▇

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.09782
val/val_accuracy,0.8935
val/val_loss,0.39907


Train Loss: 0.472, Valid Loss: 0.402348, Accuracy: 0.86
Train Loss: 0.460, Valid Loss: 0.336740, Accuracy: 0.89
Train Loss: 0.359, Valid Loss: 0.333175, Accuracy: 0.89
Train Loss: 0.264, Valid Loss: 0.328039, Accuracy: 0.89
Train Loss: 0.332, Valid Loss: 0.304157, Accuracy: 0.90
Train Loss: 0.294, Valid Loss: 0.298192, Accuracy: 0.90
Train Loss: 0.225, Valid Loss: 0.293916, Accuracy: 0.90
Train Loss: 0.297, Valid Loss: 0.323904, Accuracy: 0.89
Train Loss: 0.191, Valid Loss: 0.300282, Accuracy: 0.90
Train Loss: 0.143, Valid Loss: 0.312663, Accuracy: 0.90


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▆▄▃▃▄▄▃▂▃▃▃▂▂▂▂▃▂▂▂▂▂▂▃▂▂▁▁▂▁▂▂▁▁▂▁▁▁▁
val/val_accuracy,▁▆▅▅███▆█▇
val/val_loss,█▄▄▃▂▁▁▃▁▂

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.14301
val/val_accuracy,0.8985
val/val_loss,0.31266
