In [2]:
import os
import torch
import random
import numpy as np
import torchvision as tv
import torch.nn.functional as F
import matplotlib.pyplot as plt
from src import utils, plots, metrics
from sklearn.metrics import roc_auc_score

In [3]:
utils.set_seed(2022)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.backends.cudnn.benchmarks = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

In [None]:
test_transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
])
dataset = tv.datasets.FashionMNIST('/tmp/data', train=True, download=True, transform=test_transform) # 60
loader = torch.utils.data.DataLoader(
    dataset, batch_size=10, num_workers=0, shuffle=False)

mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0) # batch size (the last batch can have smaller size!)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(loader.dataset)
std /= len(loader.dataset)

In [3]:
batch_size = 128

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform, ood_type=None, distance=15, no_jsd=True):
        if isinstance(data, tuple):
            self.data_in, self.data_ood = data
        else:
            self.data_in = data
        self.ood_type = ood_type
        self.distance = distance
        self.no_jsd = no_jsd
        self.clean = transform
        
    def __getitem__(self, index):
        x, y = self.data_in[index]
        if self.ood_type == 'uniform':
            x_tensor = tv.transforms.ToTensor()(x)
            x_ood = torch.FloatTensor(x_tensor.shape).uniform_(x_tensor.min() - self.distance, x_tensor.max() + self.distance)
            x_ood = tv.transforms.ToPILImage()(x_ood).convert("RGB")
            x, x_ood = self.clean(x), self.clean(x_ood) if self.no_jsd else (self.clean(x), self.ood(x_ood), self._ood2(x_ood))
            x = (x, x_ood)
        elif self.ood_type == 'augmix':
            x = (self.clean(x), self.ood(x), self.ood2(x))
        elif self.ood_type == 'datafusion':
            x_ood, y_ood = self.data_ood[index]
            x = (self.clean(x), self.clean(x_ood)) if self.no_jsd else (self.clean(x), self.ood(x_ood), self.ood2(x_ood))
        else:
            x = self.clean(x)
        return x, y

    def __len__(self):
        return len(self.data_in)

# normalize data==============================
normalize_cifar10 = tv.transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010)
)

normalize_cifar100 = tv.transforms.Normalize(
    mean=(0.5071, 0.4867, 0.4408),
    std=(0.2675, 0.2565, 0.2761)
)

normalize_svhn = tv.transforms.Normalize(
    mean=(0.4376, 0.4437, 0.4728),
    std=(0.1980, 0.2010, 0.1970)
)

normalize_fmnist = tv.transforms.Normalize(
    mean=(0.2860, 0.2860, 0.2860),
    std=(0.3205, 0.3205, 0.3205)
)

# normalize data
normalize = normalize_svhn

test_transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    normalize
])

fmnist_transform = tv.transforms.Compose([
#     tv.transforms.Resize((32, 32)),
    tv.transforms.Grayscale(num_output_channels=3),
])

# download train datasets===================================================
cifar10 = tv.datasets.CIFAR10('/tmp/data', train=True, download=True) #50K
cifar100 = tv.datasets.CIFAR100('/tmp/data', train=True, download=True) # 50K
svhn = tv.datasets.SVHN('/tmp/data', split='train', download=True) # 70K
fmnist = tv.datasets.FashionMNIST('/tmp/data', train=True, download=True, transform=fmnist_transform) # 60K

# download test datasets===================================================
cifar10_test = tv.datasets.CIFAR10('/tmp/data', train=False, download=True, transform=test_transform)
cifar100_test = tv.datasets.CIFAR100('/tmp/data', train=False, download=True, transform=test_transform)
svhn_test = tv.datasets.SVHN('/tmp/data', split='test', download=True, transform=test_transform)
fmnist_test = tv.datasets.FashionMNIST('/tmp/data', train=False, download=True, transform=test_transform) # 60K

# choose train and test dataset
dataset_list = [cifar10, ] * 2
cifar10 = torch.utils.data.ConcatDataset(dataset_list)
train_data = (svhn, cifar10)
train_samples = 60000
valid_samples = len(svhn) - train_samples
test_data = svhn_test

# indomain data augmentations=========================================
train_transform = tv.transforms.Compose([
    tv.transforms.Resize(32),
    tv.transforms.RandomCrop(32, padding=4),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
    normalize,
])

# setup dataloaders========================================================
transforms = train_transform
train_ood_data = torch.utils.data.ConcatDataset([cifar100, svhn, fmnist])
dataset = MyDataset(train_data, transforms, ood_type='datafusion')

train, valid = torch.utils.data.random_split(dataset, [train_samples, valid_samples])
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: /tmp/data/train_32x32.mat
Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: /tmp/data/test_32x32.mat


In [None]:
cifar100_test = torch.utils.data.DataLoader(
    tv.datasets.CIFAR100('/tmp/data', train=False, download=True, transform=test_transform),
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True)

In [4]:
from models.wideresnet import WideResNet28x10
model = WideResNet28x10(num_classes=10, dropout_rate=0.25)

# hyperparams
epochs = 300

# optim
lr = 0.05
momentum = 0.9
weight_decay = 5e-4

opt = torch.optim.SGD(model.parameters(), lr, momentum, weight_decay)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.95, patience=25, verbose=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=300, verbose=True)

# def get_lr(step, total_steps, lr_max, lr_min):
#     """Compute learning rate according to cosine annealing schedule."""
#     return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

# learning_rate = 0.1
# scheduler = torch.optim.lr_scheduler.LambdaLR(opt,
#           lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
#           step,
#           epochs * len(cifar_train),
#           1,  # lr_lambda computes multiplicative factor
#           1e-6 / learning_rate))

Adjusting learning rate of group 0 to 5.0000e-02.


### Train Cosine

In [None]:
def compute_cosine(model, inputs, labels):
    model.eval()
    inputs = torch.cat(inputs, dim=0).detach()
    labels = torch.cat([labels] * 2, dim=0).detach()
    logits = model(inputs)
    crossent = F.cross_entropy(logits, labels)
#     caclulate grad of loss w.r.t. logits
    grad = torch.autograd.grad(crossent, logits)[0]
    grad_in, grad_ood = grad.detach().chunk(2, dim=0)
    cosine = F.cosine_similarity(grad_in, grad_ood, dim=1).mean()
    model.train()
    return cosine

In [5]:
def train_cosine(model, opt, data_loader, criterion, device):
    n_samples, error, correct, = 0.0, 0.0, 0.0
    cosine_error, psnr_error = 0.0, 0.0
    zero = torch.tensor([0.0]).to(device)
    one = torch.tensor([1.0]).to(device)
    model.train()
    for x, y in data_loader:
        x, x_ood, y = map(lambda var: var.to(device), (x[0], x[1], y))
        bsize = x.size(0)

        opt.zero_grad()
        inputs = torch.cat([x, x_ood], dim=0)
        logits = model(inputs)
        probs = F.softmax(logits, dim=1)
        probs_in, probs_ood = probs.split(bsize)
        cosine = F.cosine_similarity(probs_in, probs_ood, dim=1).mean()
        logits_in, logits_ood = logits.split(bsize)
        xent = criterion(logits_in, y)
#         margin = torch.maximum(zero, cosine + one)
#         correct_prob = probs_in[range(bsize), y]
#         l2 = 0.08 * ((correct_prob - 0.9954)**2).sum()
#         l1 = 0.05 * (probs_ood - (1.0/10.0)).abs().sum()
    
#         loss = margin + l1 + l2
        loss = 2 * xent + (-1.0 * cosine)
        y_hat = logits_in.max(dim=1)[1]
        
        correct += y.eq(y_hat.view_as(y)).sum().item()
        error += bsize * loss.item()
        cosine_error += bsize * cosine.item()
#         psnr_error += bsize * l1.item()
        n_samples += bsize

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
        opt.step()
#         scheduler.step()
    
    avg_loss = error / n_samples
    avg_acc = correct / n_samples

    return avg_loss, avg_acc, cosine_error/n_samples, psnr_error/n_samples

In [None]:
from tqdm.autonotebook import tqdm

device = "cuda:0"
criterion = torch.nn.CrossEntropyLoss()
model.to(device)

train_loss, train_acc = [], []
valid_loss, valid_acc = [], []
margin, cosine = [], []
best_val_loss = 0.0

# Train network
for epoch in tqdm(range(epochs)):
    
#     newlr = utils.swa_schedule(epoch, lr_init=lr)
#     for group in opt.param_groups:
#         group['lr'] = newlr
#         print("new learning rate: ", newlr)
    
    tr_loss, tr_acc, cos, psnr = train_cosine(model, opt, train_loader, criterion, device)
    train_loss.append(tr_loss), train_acc.append(tr_acc), cosine.append(cos)
    print("Train:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tCosine: {:.4f},\tPSNR: {:.4f}".format(tr_loss, 100.0 * tr_acc, cos, psnr))
    
    val_loss, val_acc, val_margin, _ = utils.test(model, valid_loader, criterion, device)
    valid_loss.append(val_loss), valid_acc.append(val_acc), margin.append(val_margin)
    print("Test:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tValidation Margin: {:.4f}".format(val_loss, 100.0 * val_acc, val_margin))
    
    # Get bool not ByteTensor
    is_best = True if epoch == 0 else bool(val_loss < best_val_loss)
    # Get greater Tensor to keep track best acc
    operator = max if epoch == 0 else min
    best_val_loss = torch.FloatTensor([operator(val_loss, best_val_loss)])
    # Save checkpoint if is a new best
    utils.save_checkpoint({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optim_state_dict': opt.state_dict(),
        'valid_accuracy': valid_acc
    }, is_best, val_acc, filename='./checkpoints/alternative/wideres28x10_drop_svhn_xent_cosine_ood_cifar10.pth.tar')
    
    scheduler.step()

  0%|          | 0/300 [00:00<?, ?it/s]

Train:	Average Loss: 3.5453,	Accuracy: 18.59%,	Cosine: 0.9727,	PSNR: 0.0000
Test:	Average Loss: 2.2449,	Accuracy: 18.55%,	Validation Margin: -0.1301
=> Saving a new best, best_valid_acc: 0.18548691257448896
Adjusting learning rate of group 0 to 4.9999e-02.
Train:	Average Loss: 3.4868,	Accuracy: 19.21%,	Cosine: 0.9845,	PSNR: 0.0000
Test:	Average Loss: 2.2172,	Accuracy: 19.97%,	Validation Margin: -0.1965
=> Saving a new best, best_valid_acc: 0.19966809987176587
Adjusting learning rate of group 0 to 4.9995e-02.
Train:	Average Loss: 2.3982,	Accuracy: 46.02%,	Cosine: 0.6727,	PSNR: 0.0000
Test:	Average Loss: 0.8087,	Accuracy: 73.24%,	Validation Margin: -0.6254
=> Saving a new best, best_valid_acc: 0.7324432375348873
Adjusting learning rate of group 0 to 4.9988e-02.
Train:	Average Loss: 0.7585,	Accuracy: 81.66%,	Cosine: 0.4057,	PSNR: 0.0000
Test:	Average Loss: 0.5966,	Accuracy: 81.83%,	Validation Margin: -0.7185
=> Saving a new best, best_valid_acc: 0.8182846797918081
Adjusting learning rate 

### Results

In [None]:
epochs=256
min_valid_loss = min(valid_loss)
best_epoch, = np.where(np.array(valid_loss) == min_valid_loss)
max_valid_acc = valid_acc[best_epoch.item()]
%matplotlib inline
plt.plot(range(epochs), train_loss, label='train loss')
plt.plot(range(epochs), train_acc, label='train acc')
plt.plot(range(epochs), valid_loss, label='valid loss')
plt.plot(range(epochs), valid_acc, label='valid acc')
plt.plot(range(epochs), margin, label='margin')
plt.vlines(best_epoch.item(), ymin=min_valid_loss, ymax=max_valid_acc, colors='black', label='best-valid@{:.2f}%'.format(100.0 * max_valid_acc))
plt.title("Loss vs Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.grid()
plt.show()

In [None]:
chkpt = torch.load('./checkpoints/baseline/allconv_svhn.pth.tar')
model.load_state_dict(chkpt['model_state_dict'])

In [None]:
criterion = torch.nn.CrossEntropyLoss()
loss, acc, te_margin, logits_in = utils.test(model, test_loader, criterion, device)
print("Test:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tMargin: {:.4f}".format(loss, 100.0 * acc, te_margin))
loss, acc, ood_margin, logits_out = utils.test(model, cifar100_test, criterion, device, ood=True)
print("Test OoD:\tAverage Loss: {:.4f},\tAccuracy: {:.2f}%,\tMaring: {:.4f}".format(loss, 100.0 * acc, ood_margin))

In [None]:
logits_in = torch.vstack(logits_in)
logits_out = torch.vstack(logits_out)
metrics_in = metrics.dirichlet_uncertainty(logits_in.cpu().numpy())
metrics_out = metrics.dirichlet_uncertainty(logits_out.cpu().numpy())

In [None]:
y_test = len(logits_in)
y_ood = len(logits_out)
for string in ['confidence', 'entropy_of_conf', 'mutual_information']:
    if string == "confidence":
        y_true = np.r_[np.ones(y_test), np.zeros(y_ood)]
        y_scores = np.r_[metrics_in[string], metrics_out[string]]
    else:
        y_true = np.r_[np.ones(y_ood), np.zeros(y_test)]
        y_scores = np.r_[metrics_out[string], metrics_in[string]]
    print("ROC values:\n {} = {},\n".format(string, roc_auc_score(y_true, y_scores)))

In [None]:
%matplotlib inline
f, ax = plt.subplots(1, 3, figsize=(14, 5))

for i, string in enumerate(['confidence', 'entropy_of_conf', 'mutual_information']):
#     ax[i].set_xscale('log')
    ax[i].hist(metrics_in[string], bins=np.linspace(0, 3, num=10), label='Test', alpha=0.5)
    ax[i].hist(metrics_out[string], bins=np.linspace(0, 3, num=10), label='OoD', alpha=0.5)
    ax[i].legend()
    ax[i].set_title(string)

In [None]:
X, y = next(iter(test_loader))
X, y = X.to(device), y.to(device)

In [None]:
%matplotlib widget
# model.to(device)
model.eval()
Xi, Yi, Zi = utils.draw_loss(model, X[0:1], y[0:1], device=device)
plots.plot_loss(Xi, Yi, Zi)