In [1]:
import torch
from torch import optim, nn, utils
import numpy as np
import layers
import lenet_mnist as models
from torchvision import datasets, transforms

# Single source of truth for every RNG used in this notebook.
seed = 2022
np.random.seed(seed)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x1f207135690>

In [2]:
# Load MNIST as tensors and carve a 50k/10k train/validation split out of the training set.
root = "./MNIST_DATASET"
transform = transforms.ToTensor()
train_dataset, valid_dataset = utils.data.random_split(
  datasets.MNIST(root, transform=transform, train=True, download=True),
  [50000, 10000],
)
test_dataset = datasets.MNIST(root, transform=transform, train=False, download=True)

for caption, split in (("Train : ", train_dataset), ("Validation : ", valid_dataset), ("Test : ", test_dataset)):
  print(caption, len(split))

train_batchsize = 64
test_batchsize = 256

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = utils.data.DataLoader(train_dataset, batch_size=train_batchsize, shuffle=True)
valid_dataloader, test_dataloader = (
  utils.data.DataLoader(split, batch_size=test_batchsize, shuffle=False)
  for split in (valid_dataset, test_dataset)
)

# Peek at a single training batch to confirm the tensor shapes.
image, labels = next(iter(train_dataloader))
print(image.shape)
print(labels.shape)

Train :  50000
Validation :  10000
Test :  10000
torch.Size([64, 1, 28, 28])
torch.Size([64])


In [3]:
def _sample_aleatoric_logits(outputs):
  """Split `outputs` into [mean | std] halves along dim 1 and return mean + eps * std."""
  output_mean, output_std = torch.chunk(outputs, 2, dim=1)
  # randn_like inherits device and dtype from the model output, so no CUDA check is needed.
  return output_mean + torch.randn_like(output_mean) * output_std


def _build_experiment(bnn_model, use_aleatoric, kwargs):
  """Create the model, SGD optimizer, LR scheduler and batch-handling mode for `bnn_model`."""
  weight_decay = 0
  if bnn_model == "gaussian":
    model = models.gaussian_lenet5(kwargs["var_type"], dict(), kwargs["is_lrt"], use_aleatoric)
    # Every forward gives different output.
    bnn_type = "random"
  elif bnn_model == "dropout":
    model = models.dropout_lenet5(kwargs["dropout_rate"], kwargs["dropout_type"], dict(), use_aleatoric)
    # NOTE(review): operator precedence makes this (1 - rate) * num_sample / 2.
    # Confirm that (1 - rate) / (2 * num_sample) was not intended.
    weight_decay = (1 - kwargs["dropout_rate"]) / 2 * kwargs["num_sample"]
    # Every forward gives different output.
    bnn_type = "random"
  elif bnn_model == "ensemble":
    model = models.ensemble_lenet5(kwargs["num_ensemble"], use_aleatoric)
    # Forward gives [batch_size * num_ensemble, output_shape]
    bnn_type = "ensemble"
  elif bnn_model == "swag":
    model = models.lenet5(use_aleatoric)
    # SWAG is trained as simple NN.
    bnn_type = "swag"
  elif bnn_model == "batchensemble":
    model = models.batchensemble_lenet5(kwargs["num_models"], use_aleatoric)
    # Forward gives [batch_size * num_models, output_shape]
    bnn_type = "ensemble"
    kwargs["num_ensemble"] = kwargs["num_models"]
  else:
    raise ValueError("No bnn model choosen.")

  # Move the model before building the optimizer so it tracks the on-device parameters.
  if torch.cuda.is_available():
    model = model.cuda()
  optimizer = optim.SGD(model.parameters(), lr=kwargs["lr"], momentum=kwargs["momentum"], weight_decay=weight_decay)
  # Decays the learning rate by a factor of 0.1 every 100 scheduler steps (one step per epoch).
  scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1 ** (1/100))
  return model, optimizer, scheduler, bnn_type


def _run_epoch(model, dataloader, criterion, bnn_type, use_aleatoric, num_passes, num_ensemble, optimizer=None):
  """Run one pass over `dataloader`; trains when `optimizer` is given, otherwise only evaluates.

  Returns (sum of per-batch losses, accuracy over len(dataloader.dataset)).
  """
  training = optimizer is not None
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # Ensemble-style models stack `num_ensemble` predictions per input along dim 0.
  copies = num_ensemble if bnn_type == "ensemble" else 1
  normaliser = num_passes * copies
  loss_sum = 0.0
  correct = 0.0

  model.train(training)
  with torch.set_grad_enabled(training):
    for images, labels in dataloader:
      images = images.to(device)
      labels = labels.to(device)
      if bnn_type == "ensemble":
        labels = labels.repeat(copies) # [y1, y2, ..., y1, y2, ..., ] with num_ensemble times
      for _ in range(num_passes):
        if training:
          # Clear gradients left over from the previous step; otherwise they accumulate.
          optimizer.zero_grad()
        outputs = model(images)
        if use_aleatoric:
          outputs = _sample_aleatoric_logits(outputs)
        # The mean loss is rescaled by the number of ensemble members (factor 1 otherwise).
        loss = criterion(outputs, labels) * copies
        if training:
          loss.backward()
          optimizer.step()

        loss_sum += loss.item() / normaliser
        output_pred = torch.argmax(outputs.detach(), dim=1)
        correct += torch.count_nonzero(output_pred == labels).item() / normaliser

  return loss_sum, correct / len(dataloader.dataset)


def mnist_experiment(bnn_model, train_dataloader, valid_dataloader, use_aleatoric=True, **kwargs):
  """Train and validate one LeNet-5 variant and return it with its learning curves.

  Args:
    bnn_model: one of "gaussian", "dropout", "ensemble", "swag", "batchensemble".
    train_dataloader: loader yielding (images, labels) batches used for training.
    valid_dataloader: loader yielding (images, labels) batches used for validation.
    use_aleatoric: when True the model output is split in two halves along dim 1
      (mean, std) and a noisy sample mean + eps * std is fed to the loss.
    **kwargs: always "epoch", "lr", "momentum"; additionally
      gaussian: "var_type", "is_lrt", "num_sample", "valid_num_sample";
      dropout: "dropout_rate", "dropout_type", "num_sample", "valid_num_sample";
      ensemble: "num_ensemble"; batchensemble: "num_models".

  Returns:
    (model, train_loss_res, train_acc_res, valid_loss_res, valid_acc_res), where each
    list has one entry per epoch. Losses are the sum of per-batch mean losses;
    accuracies are divided by the size of the corresponding loader's dataset.

  Raises:
    ValueError: if `bnn_model` is not one of the supported names.
  """
  model, optimizer, scheduler, bnn_type = _build_experiment(bnn_model, use_aleatoric, kwargs)

  # Stochastic models are sampled several times per batch; the others need a single pass.
  if bnn_type == "random":
    train_passes, valid_passes = kwargs["num_sample"], kwargs["valid_num_sample"]
  else:
    train_passes = valid_passes = 1
  num_ensemble = kwargs["num_ensemble"] if bnn_type == "ensemble" else 1

  criterion = nn.CrossEntropyLoss()
  train_loss_res = []
  valid_loss_res = []
  train_acc_res = []
  valid_acc_res = []
  for epoch in range(kwargs["epoch"]):
    train_loss, train_acc = _run_epoch(
      model, train_dataloader, criterion, bnn_type, use_aleatoric, train_passes, num_ensemble, optimizer=optimizer)
    train_loss_res.append(train_loss)
    train_acc_res.append(train_acc)
    scheduler.step()

    valid_loss, valid_acc = _run_epoch(
      model, valid_dataloader, criterion, bnn_type, use_aleatoric, valid_passes, num_ensemble)
    valid_loss_res.append(valid_loss)
    valid_acc_res.append(valid_acc)

  return model, train_loss_res, train_acc_res, valid_loss_res, valid_acc_res

In [4]:
# Collects (train_loss, train_acc, valid_loss, valid_acc) curves keyed by experiment name.
res = {}

# Settings shared by both Gaussian runs; only the variance parameterisation differs.
gaussian_kwargs = dict(is_lrt=True, num_sample=5, valid_num_sample=20, epoch=5, lr=0.00001, momentum=0.1)

g_model_sq, tr_los, tr_ac, te_los, te_ac = mnist_experiment(
  "gaussian", train_dataloader, valid_dataloader, var_type='sq', **gaussian_kwargs)
res["g_sq"] = (tr_los, tr_ac, te_los, te_ac)

model, tr_los, tr_ac, te_los, te_ac = mnist_experiment(
  "gaussian", train_dataloader, valid_dataloader, var_type='exp', **gaussian_kwargs)
res["g_exp"] = (tr_los, tr_ac, te_los, te_ac)

In [6]:
# NOTE(review): the call that produced the "d_w" curves was missing from the notebook, so
# this cell used to store whatever results the previous cell left behind. The call below
# is reconstructed from the sibling runs; dropout_type='w' is inferred from the key name
# -- confirm it is a type that models.dropout_lenet5 accepts.
d_model_w, tr_los, tr_ac, te_los, te_ac = mnist_experiment("dropout", train_dataloader, valid_dataloader, dropout_rate=0.2, dropout_type='w', num_sample=5, valid_num_sample=20, epoch=5, lr=0.00001, momentum=0.1)
res["d_w"] = (tr_los, tr_ac, te_los, te_ac)
d_model_f, tr_los, tr_ac, te_los, te_ac = mnist_experiment("dropout", train_dataloader, valid_dataloader, dropout_rate=0.2, dropout_type='f', num_sample=5, valid_num_sample=20, epoch=5, lr=0.00001, momentum=0.1)
res["d_f"] = (tr_los, tr_ac, te_los, te_ac)
d_model_c, tr_los, tr_ac, te_los, te_ac = mnist_experiment("dropout", train_dataloader, valid_dataloader, dropout_rate=0.2, dropout_type='c', num_sample=5, valid_num_sample=20, epoch=5, lr=0.00001, momentum=0.1)
res["d_c"] = (tr_los, tr_ac, te_los, te_ac)

KeyboardInterrupt: 

In [None]:
# Deep ensemble with five members, trained with the same optimiser settings as the other runs.
ensemble_run = mnist_experiment(
  "ensemble",
  train_dataloader,
  valid_dataloader,
  num_ensemble=5,
  epoch=5,
  lr=0.00001,
  momentum=0.1,
)
e_model, tr_los, tr_ac, te_los, te_ac = ensemble_run
res["e"] = (tr_los, tr_ac, te_los, te_ac)

In [None]:
# SWAG baseline: a plain network trained with the same optimiser settings.
swag_run = mnist_experiment(
  "swag",
  train_dataloader,
  valid_dataloader,
  epoch=5,
  lr=0.00001,
  momentum=0.1,
)
model, tr_los, tr_ac, te_los, te_ac = swag_run
res["s"] = (tr_los, tr_ac, te_los, te_ac)

In [None]:
# BatchEnsemble with five members.
b_model, tr_los, tr_ac, te_los, te_ac = mnist_experiment("batchensemble", train_dataloader, valid_dataloader, num_models=5, epoch=5, lr=0.00001, momentum=0.1)
# Store the curves like every other experiment so the summary cell includes this run.
res["b"] = (tr_los, tr_ac, te_los, te_ac)

In [None]:
# Show the per-epoch training-loss curve (first element of each result tuple) per experiment.
for experiment_name, curves in res.items():
  train_losses = curves[0]
  print(experiment_name, train_losses)