In [5]:
import torch
from torch import nn, optim, utils
import lenet5_cifar10 as models
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

In [6]:
# Images are only converted to tensors; no other preprocessing is applied.
transform = transforms.ToTensor()
root = "./CIFAR10_DATASET"

# Keep the raw training set under its own name and split it with a seeded
# generator so that the train/validation partition is identical on every
# "Restart & Run All".
full_train_dataset = datasets.CIFAR10(root, transform=transform, train=True, download=True)
split_generator = torch.Generator().manual_seed(0)
train_dataset, valid_dataset = utils.data.random_split(
    full_train_dataset, [40000, 10000], generator=split_generator
)
test_dataset = datasets.CIFAR10(root, transform=transform, train=False, download=True)

print("Train : ", len(train_dataset))
print("Validation : ", len(valid_dataset))
print("Test : ", len(test_dataset))

train_batchsize = 64
test_batchsize = 256

# Only the training loader is shuffled; validation and test keep dataset order.
train_dataloader = utils.data.DataLoader(train_dataset, batch_size=train_batchsize, shuffle=True)
valid_dataloader = utils.data.DataLoader(valid_dataset, batch_size=test_batchsize, shuffle=False)
test_dataloader = utils.data.DataLoader(test_dataset, batch_size=test_batchsize, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified
Train :  40000
Validation :  10000
Test :  10000


In [17]:
def cifar10_experiment(bnn_model, train_dataloader, valid_dataloader, use_aleatoric=True, **kwargs):
  """Train one LeNet-5 variant from `models` and track loss/accuracy per epoch.

  Args:
    bnn_model: One of "gaussian", "dropout", "ensemble", "swag" or
      "batchensemble"; selects the constructor used from `models`.
    train_dataloader: Loader yielding (images, labels) batches for training.
    valid_dataloader: Loader yielding (images, labels) batches for validation.
    use_aleatoric: Forwarded to the model constructor. When true, the model
      output is split in two halves along dim 1 (mean, std) and the logits
      fed to the loss are sampled as ``mean + eps * std`` with eps ~ N(0, 1).
    **kwargs: Always required: ``lr``, ``momentum``, ``epoch``.
      Per model:
        "gaussian":      ``var_type``, ``is_lrt``, ``num_sample``, ``valid_num_sample``
        "dropout":       ``dropout_rate``, ``dropout_type``, ``num_sample``, ``valid_num_sample``
        "ensemble":      ``num_ensemble``
        "batchensemble": ``num_models``

  Returns:
    Tuple ``(model, train_loss_res, train_acc_res, valid_loss_res, valid_acc_res)``.
    The four lists have one entry per epoch. Each loss entry is the *sum* over
    batches of the per-batch mean loss (not divided by the number of batches);
    each accuracy entry is the fraction of correctly classified samples of the
    dataset behind the corresponding dataloader.

  Raises:
    ValueError: If `bnn_model` is not one of the supported names.
  """
  weight_decay = 0
  if bnn_model == "gaussian":
    model = models.gaussian_lenet5(kwargs["var_type"], dict(), kwargs["is_lrt"], use_aleatoric)
    # Every forward gives different output.
    bnn_type = "random"
  elif bnn_model == "dropout":
    model = models.dropout_lenet5(kwargs["dropout_rate"], kwargs["dropout_type"], dict(), use_aleatoric)
    weight_decay = (1 - kwargs["dropout_rate"]) / (2 * kwargs["num_sample"] * 10)
    # Every forward gives different output.
    bnn_type = "random"
  elif bnn_model == "ensemble":
    model = models.ensemble_lenet5(kwargs["num_ensemble"], use_aleatoric)
    # Forward gives [batch_size * num_ensemble, output_shape]
    bnn_type = "ensemble"
  elif bnn_model == "swag":
    model = models.lenet5(use_aleatoric)
    # SWAG is trained as simple NN.
    bnn_type = "swag"
  elif bnn_model == "batchensemble":
    model = models.batchensemble_lenet5(kwargs["num_models"], use_aleatoric)
    # Forward gives [batch_size * num_models, output_shape]
    bnn_type = "ensemble"
    kwargs["num_ensemble"] = kwargs["num_models"]
  else:
    raise ValueError("No bnn model choosen.")

  # One device for the model, the batches and the sampled noise. Every model
  # (including "gaussian") is moved before the optimizer captures its parameters.
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = model.to(device)
  optimizer = optim.SGD(model.parameters(), lr=kwargs["lr"], momentum=kwargs["momentum"], weight_decay=weight_decay)
  # Learning rate decays by a factor of 10 every 100 epochs.
  scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1/100))
  criterion = nn.CrossEntropyLoss()

  # Accuracy is normalised by the datasets actually behind the given loaders,
  # not by notebook-level globals.
  num_train = len(train_dataloader.dataset)
  num_valid = len(valid_dataloader.dataset)

  def sample_outputs(outputs):
    # Draw logits from the predicted (mean, std) pair when aleatoric noise is modelled.
    if not use_aleatoric:
      return outputs
    output_mean, output_std = torch.chunk(outputs, 2, dim=1)
    return output_mean + torch.randn_like(output_mean) * output_std

  def count_correct(outputs, labels):
    # Number of rows whose arg-max class equals the label.
    output_pred = torch.argmax(outputs.detach(), dim=1)
    return torch.count_nonzero(output_pred == labels).item()

  train_loss_res = []
  valid_loss_res = []
  train_acc_res = []
  valid_acc_res = []
  for epoch in range(kwargs["epoch"]):
    train_loss = 0.0
    train_acc_count = 0
    model.train()
    for images, labels in train_dataloader:
      images = images.to(device)
      labels = labels.to(device)
      optimizer.zero_grad()
      if bnn_type == "random":
        # Accumulate the loss over several stochastic forward passes.
        num_sample = kwargs["num_sample"]
        loss = 0
        for _ in range(num_sample):
          outputs = sample_outputs(model(images))
          train_acc_count += count_correct(outputs, labels) / num_sample
          loss += criterion(outputs, labels)
        loss_scale = num_sample
      elif bnn_type == "ensemble":
        num_ensemble = kwargs["num_ensemble"]
        outputs = sample_outputs(model(images))
        labels = labels.repeat(num_ensemble) # [y1, y2, ..., y1, y2, ..., ] with num_ensemble times
        # The mean over batch_size * num_ensemble rows is rescaled by num_ensemble.
        loss = criterion(outputs, labels) * num_ensemble
        train_acc_count += count_correct(outputs, labels) / num_ensemble
        loss_scale = num_ensemble
      else:
        outputs = sample_outputs(model(images))
        loss = criterion(outputs, labels)
        train_acc_count += count_correct(outputs, labels)
        loss_scale = 1
      loss.backward()
      optimizer.step()
      train_loss += loss.item() / loss_scale
    train_loss_res.append(train_loss)
    train_acc_res.append(train_acc_count / num_train)
    scheduler.step()

    model.eval()
    valid_loss = 0.0
    valid_acc_count = 0
    with torch.no_grad():
      for images, labels in valid_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        if bnn_type == "random":
          valid_num_sample = kwargs["valid_num_sample"]
          for _ in range(valid_num_sample):
            outputs = sample_outputs(model(images))
            valid_loss += criterion(outputs, labels).item() / valid_num_sample
            valid_acc_count += count_correct(outputs, labels) / valid_num_sample
        elif bnn_type == "ensemble":
          num_ensemble = kwargs["num_ensemble"]
          outputs = sample_outputs(model(images))
          labels = labels.repeat(num_ensemble) # [y1, y2, ..., y1, y2, ..., ] with num_ensemble times
          # Scaling by num_ensemble and dividing again cancels out, so the
          # plain mean loss is accumulated.
          valid_loss += criterion(outputs, labels).item()
          valid_acc_count += count_correct(outputs, labels) / num_ensemble
        else:
          outputs = sample_outputs(model(images))
          valid_loss += criterion(outputs, labels).item()
          valid_acc_count += count_correct(outputs, labels)
    valid_loss_res.append(valid_loss)
    valid_acc_res.append(valid_acc_count / num_valid)
    print(f"Epoch {epoch+1} ended")
    print(f"Training   loss/acc : {train_loss:.3f}/{train_acc_count / num_train:.3f}")
    print(f"Validation loss/acc : {valid_loss:.3f}/{valid_acc_count / num_valid:.3f}")

  return model, train_loss_res, train_acc_res, valid_loss_res, valid_acc_res

In [22]:
# Train the dropout LeNet-5 variant; hyper-parameters are listed one per line.
d_model, tr_loss, tr_ac, val_loss, val_ac = cifar10_experiment(
    "dropout",
    train_dataloader,
    valid_dataloader,
    dropout_rate=0.1,
    dropout_type='f',
    lr=0.01,
    momentum=0.9,
    epoch=200,
    num_sample=3,
    valid_num_sample=3,
)

Epoch 1 ended
Training   loss/acc : 1316.858/0.214
Validation loss/acc : 77.606/0.277
Epoch 2 ended
Training   loss/acc : 1210.018/0.294
Validation loss/acc : 76.893/0.306
Epoch 3 ended
Training   loss/acc : 1161.193/0.327
Validation loss/acc : 71.945/0.348
Epoch 4 ended
Training   loss/acc : 1133.600/0.345
Validation loss/acc : 69.929/0.377
Epoch 5 ended
Training   loss/acc : 1098.421/0.364
Validation loss/acc : 71.279/0.359
Epoch 6 ended
Training   loss/acc : 1071.249/0.375
Validation loss/acc : 69.215/0.371
Epoch 7 ended
Training   loss/acc : 1046.353/0.394
Validation loss/acc : 68.133/0.389
Epoch 8 ended
Training   loss/acc : 1024.819/0.406
Validation loss/acc : 66.114/0.412
Epoch 9 ended
Training   loss/acc : 1003.203/0.418
Validation loss/acc : 65.713/0.410
Epoch 10 ended
Training   loss/acc : 987.155/0.429
Validation loss/acc : 63.090/0.436
Epoch 11 ended
Training   loss/acc : 967.304/0.438
Validation loss/acc : 63.520/0.431
Epoch 12 ended
Training   loss/acc : 959.180/0.446
Val

In [24]:
# Persist the trained dropout model object (not just its state_dict) to disk.
torch.save(obj=d_model, f="dropout.pt")