In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
import random
import numpy as np
from PIL import Image
import torchvision as tv
from datetime import datetime
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
from sklearn.metrics import roc_auc_score
from src import utils, plots, lossfunc, models, metrics

In [3]:
nn = torch.nn
F = torch.nn.functional
tvt = tv.transforms
tvd = tv.datasets
tu = torch.utils
tud = torch.utils.data

In [4]:
utils.set_seed(2022)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.backends.cudnn.benchmarks = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

In [5]:
# hyperparams
num_classes = 10
batch_size = 128
train_samples = 45000
valid_samples = 5000
epochs = 300
lr = 0.05
momentum = 0.9
weight_decay = 5e-4

### Data

In [6]:
standardize = tvt.Normalize(
    mean = (0.4914, 0.4822, 0.4465),
    std = (0.2023, 0.1994, 0.2010)
)

train_transform = tvt.Compose([
    tvt.Resize(32),
    tvt.RandomCrop(32, padding=4),
    tvt.RandomHorizontalFlip(),
    tvt.RandomApply([tvt.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5),
    tvt.RandomGrayscale(p=0.5),
    tvt.RandomApply(nn.ModuleList([tvt.RandomRotation((0, 360))]), p=0.5),
    tvt.ToTensor(),
    standardize,
])

test_transform = tvt.Compose([
    tvt.ToTensor(),
    standardize,
])

# Train
train_data = tvd.CIFAR10('/tmp/data', train=True, download=True, transform=train_transform) #50K

train, valid = tud.random_split(train_data, [train_samples, valid_samples])

train_loader = tud.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

valid_loader = tud.DataLoader(valid, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# Test
test_data = tvd.CIFAR10('/tmp/data', train=False, download=True, transform=test_transform)

test_loader = tud.DataLoader(test_data, batch_size=batch_size, shuffle=True,num_workers=2, pin_memory=True)

# OOD
fake_data = tvd.FakeData(size=train_samples, image_size=(3, 32, 32),
                         num_classes=num_classes, transform=train_transform)

fake_loader = tud.DataLoader(fake_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

fake_valid_data = tvd.FakeData(size=valid_samples, image_size=(3, 32, 32), num_classes=num_classes, transform=train_transform)

fake_valid_loader = tud.DataLoader(fake_valid_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

ood_data = tvd.CIFAR100('/tmp/data', train=False, download=True, transform=test_transform)

ood_loader = tud.DataLoader(ood_data, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Model

In [7]:
device = "cuda:0"
model = models.WideResNet28x10(num_classes=num_classes, dropout_rate=0.25)
criterion = lossfunc.contrastive_regularized
opt = torch.optim.SGD(model.parameters(), lr, momentum, weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=300, verbose=True)
model.to(device)

Adjusting learning rate of group 0 to 5.0000e-02.


WideResNet28x10(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (layer1): Sequential(
    (0): WideBasic(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(16, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Dropout(p=0.25, inplace=False)
      (bn2): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (shortcut): Sequential(
        (0): Conv2d(16, 160, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): WideBasic(
      (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Dropout(p=0.25, inplace=False)
      (bn2): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): C

### Train

In [9]:
train_loss, train_acc = [], []
valid_loss, valid_acc = [], []
margin, cosine = [], []
best_val_loss = 0.0
data_loader = (train_loader, fake_loader)
validation_loader = (valid_loader, fake_valid_loader)

# Train
for epoch in tqdm(range(epochs)):
    tr_loss, tr_acc, cos = utils.train(model, opt, data_loader, criterion, device)
    train_loss.append(tr_loss), train_acc.append(tr_acc), cosine.append(cos)
    print("Train:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tCosine: {:.4f}".format(tr_loss, 100.0 * tr_acc, cos))
    
    val_loss, val_acc, val_margin, _ = utils.test(model, validation_loader, criterion, device)
    valid_loss.append(val_loss), valid_acc.append(val_acc), margin.append(val_margin)
    print("Test:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tValidation Margin: {:.4f}".format(val_loss, 100.0 * val_acc, val_margin))
    
    # Get bool not ByteTensor
    is_best = True if epoch == 0 else bool(val_loss < best_val_loss)
    
    # Get greater Tensor to keep track best acc
    operator = max if epoch == 0 else min
    best_val_loss = torch.FloatTensor([operator(val_loss, best_val_loss)])
    
    # Save checkpoint if val_loss is a new best
    utils.save_checkpoint({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optim_state_dict': opt.state_dict(),
        'valid_accuracy': valid_acc
    }, is_best, val_acc, filename='{}_{}_{}_{}.pth.tar'.format(
        model.__str__(), 
        train_data.__class__.__name__, 
        criterion.__name__, 
        datetime.now().strftime("%d-%m-%Y-%H:%M:%S")))
    
    scheduler.step()

  0%|          | 0/300 [00:00<?, ?it/s]

Train:	Average Loss: 3.2577,	Accuracy: 22.69%,	Cosine: 0.8527
Test:	Average Loss: 3.4466,	Accuracy: 23.18%,	Validation Margin: -0.3363
=> Saving a new best, best_valid_acc: 0.23177083333333334
Adjusting learning rate of group 0 to 4.9999e-02.
Train:	Average Loss: 3.1360,	Accuracy: 27.22%,	Cosine: 0.8147
Test:	Average Loss: 3.5535,	Accuracy: 24.70%,	Validation Margin: -0.5510
Adjusting learning rate of group 0 to 4.9995e-02.


KeyboardInterrupt: 

### Results

In [None]:
min_valid_loss = min(valid_loss)
best_epoch, = np.where(np.array(valid_loss) == min_valid_loss)
max_valid_acc = valid_acc[best_epoch.item()]
plt.plot(range(epochs), train_loss, label='train loss')
plt.plot(range(epochs), train_acc, label='train acc')
plt.plot(range(epochs), valid_loss, label='valid loss')
plt.plot(range(epochs), valid_acc, label='valid acc')
plt.plot(range(epochs), margin, label='margin')
plt.vlines(best_epoch.item(), ymin=min_valid_loss, ymax=max_valid_acc, colors='black', label='best-valid@{:.2f}%'.format(100.0 * max_valid_acc))
plt.title("Loss vs Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.grid()
plt.show()

In [None]:
chkpt = torch.load('./chkpts/WideResNet28x10_CIFAR10_contrastive_regularised_26-12-2022-19:58:52.pth.tar')
model.load_state_dict(chkpt['model_state_dict'])

In [None]:
loss, acc, te_margin, logits_in = utils.test(model, test_loader, lossfunc.cross_entropy, device)
print("Test:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tMargin: {:.4f}".format(loss, 100.0 * acc, te_margin))
loss, acc, ood_margin, logits_out = utils.test(model, ood_loader, lossfunc.cross_entropy, device)
print("Test OoD:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tMargin: {:.4f}".format(loss, 100.0 * acc, ood_margin))

In [None]:
logits_in = torch.vstack(logits_in)
logits_out = torch.vstack(logits_out)
metrics_in = metrics.dirichlet_uncertainty(logits_in.cpu().numpy())
metrics_out = metrics.dirichlet_uncertainty(logits_out.cpu().numpy())

In [None]:
y_test = len(logits_in)
y_ood = len(logits_out)
for string in ['confidence', 'entropy_of_conf', 'mutual_information']:
    if string == "confidence":
        y_true = np.r_[np.ones(y_test), np.zeros(y_ood)]
        y_scores = np.r_[metrics_in[string], metrics_out[string]]
    else:
        y_true = np.r_[np.ones(y_ood), np.zeros(y_test)]
        y_scores = np.r_[metrics_out[string], metrics_in[string]]
    print("ROC values:\n {} = {},\n".format(string, roc_auc_score(y_true, y_scores)))

In [None]:
f, ax = plt.subplots(1, 3, figsize=(14, 5))

for i, string in enumerate(['confidence', 'entropy_of_conf', 'mutual_information']):
#     ax[i].set_xscale('log')
    ax[i].hist(metrics_in[string], bins=np.linspace(0, 3, num=10), label='Test', alpha=0.5)
    ax[i].hist(metrics_out[string], bins=np.linspace(0, 3, num=10), label='OoD', alpha=0.5)
    ax[i].legend()
    ax[i].set_title(string)

In [None]:
X, y = next(iter(test_loader))
X, y = X.to(device), y.to(device)

In [None]:
# %matplotlib widget
model.eval()
Xi, Yi, Zi = utils.draw_loss(model, X[0:1], y[0:1], device=device)
plots.plot_loss(Xi, Yi, Zi)