In [166]:
!pip install pytorch-lightning -q
!pip install wandb -q

In [168]:
import numpy as np
import torch
import os
import argparse
import time
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
from tqdm import tqdm
import pytorch_lightning as pl
import torchmetrics
import math
import wandb
import sklearn.metrics
import pandas as pd
from pathlib import Path

from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy, MetricCollection, Precision, Recall

In [36]:
wandb.login()
wandb.init(project="vos-debias", entity="ai-hacks")



VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [75]:
batch_size = 128
learning_rate = 0.003
epochs = 100
momentum = 0.9
decay = 0.0001

num_layers = 22
widen_factor = 2
droprate = 0.3

# energy reg
start_epoch = 40
sample_number = 1000
select = 1
sample_from = 10000
loss_weight = 0.1

In [38]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "/content", batch_size: int = batch_size, val_test_ratio=0.35, num_workers = os.cpu_count()):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_test_ratio = val_test_ratio

        mean, std = (0.491, 0.482, 0.446), (0.247, 0.243, 0.261)
        self.train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(), 
             transforms.RandomCrop(32, padding=4), 
             transforms.ToTensor(), 
             transforms.Normalize(mean, std)]
        )
        self.test_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

        self.num_workers = num_workers

    def setup(self, stage: str):
        self.train_data = datasets.CIFAR10(self.data_dir, transform=self.train_transform, download=True, train=True)
        test_data = datasets.CIFAR10(self.data_dir, transform = self.test_transform, download=True, train=False)
        val_length = int(len(test_data) * self.val_test_ratio)
        test_length = len(test_data) - val_length
        self.val_data, self.test_data = random_split(test_data, [val_length, test_length])
        
    def train_dataloader(self):
        return DataLoader(self.train_data,num_workers = self.num_workers, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_data, num_workers = self.num_workers,batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_data, num_workers = self.num_workers,batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.test_data, num_workers = self.num_workers,batch_size=self.batch_size)

In [39]:
# Wide Resnet 
class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        if self.equalInOut:
            out = self.relu2(self.bn2(self.conv1(out)))
        else:
            out = self.relu2(self.bn2(self.conv1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if not self.equalInOut:
            return torch.add(self.convShortcut(x), out)
        else:
            return torch.add(x, out)


class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(nb_layers):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

    def forward_virtual(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out), out

    def intermediate_forward(self, x, layer_index):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        return out
    
    def feature_list(self, x):
        out_list = [] 
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out_list.append(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out), out_list

In [120]:
class VOSModel(pl.LightningModule):
    def __init__(self, num_classes = 10, threshold = 0.95, learning_rate = learning_rate, momentum = momentum, decay = decay):
        super().__init__()
        self.num_classes = num_classes

        self.data_dict = torch.zeros(self.num_classes, sample_number, 128).cuda()
        self.number_dict = {i:0 for i in range(self.num_classes)}

        self.threshold = threshold
        self.validation_accuracy = torchmetrics.Accuracy(self.threshold)
        self.test_accuracy = torchmetrics.Accuracy(self.threshold)

        self.model = WideResNet(num_layers, num_classes, widen_factor, dropRate=droprate)
        self.weight_energy = torch.nn.Linear(num_classes, 1)
        torch.nn.init.uniform_(self.weight_energy.weight)
        self.logistic_regression = torch.nn.Linear(1, 2)
        
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.decay = decay

        metrics = MetricCollection([Accuracy(), Precision(), Recall()])
        self.train_metrics = metrics.clone(prefix='train_')
        self.valid_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def forward(self, x):
        return self.log_sum_exp(self.model(x), 1)

    def configure_optimizers(self):
        params = list(self.model.parameters()) + list(self.weight_energy.parameters()) + list(self.logistic_regression.parameters())
        optimizer = torch.optim.SGD(params,self.learning_rate, momentum=self.momentum,weight_decay=self.decay, nesterov=True)
        return optimizer
        # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
        #     lr_lambda = lambda step: VOSModel.cosine_annealing(step, epochs * len(), 1, 1e-6 / self.learning_rate))
        # return [optimizer], [scheduler]
    
    def log_sum_exp(self, value, dim=None, keepdim=False):
        """Numerically stable implementation of the operation value.exp().sum(dim, keepdim).log()
        """
        # TODO: torch.max(value, dim=None) threw an error at time of writing
        if dim is not None:
            m, _ = torch.max(value, dim=dim, keepdim=True)
            value0 = value - m
            if keepdim is False:
                m = m.squeeze(dim)
            return m + torch.log(torch.sum(
                F.relu(self.weight_energy.weight) * torch.exp(value0), dim=dim, keepdim=keepdim))
        else:
            m = torch.max(value)
            sum_exp = torch.sum(torch.exp(value - m))
            return m + torch.log(sum_exp)

    @staticmethod
    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

    def training_step(self, batch, batch_idx):
        data, target = batch

        # forward
        x, output = self.model.forward_virtual(data)

        # energy regularization
        sum_temp = 0
        for index in range(self.num_classes):
            sum_temp = self.number_dict[index]
        lr_reg_loss = torch.zeros(1).cuda()[0]

        if sum_temp == self.num_classes * sample_number and epochs < start_epoch:
            # maintaining an ID data queue for each class.
            target_numpy = target.cpu().data.numpy()
            for index in range(len(target)):
                dict_key = target_numpy[index]
                self.data_dict[dict_key] = torch.cat((self.data_dict[dict_key][1:], output[index].detach().view(1, -1)), 0)
        elif sum_temp == self.num_classes * sample_number and epochs >= start_epoch:
            target_numpy = target.cpu().data.numpy()
            for index in range(len(target)):
                dict_key = target_numpy[index]
                self.data_dict[dict_key] = torch.cat((self.data_dict[dict_key][1:],
                                                      output[index].detach().view(1, -1)), 0)
            # the covariance finder needs the data to be centered.
            for index in range(self.num_classes):
                if index == 0:
                    X = self.data_dict[index] - self.data_dict[index].mean(0)
                    mean_embed_id = self.data_dict[index].mean(0).view(1, -1)
                else:
                    X = torch.cat((X, self.data_dict[index] - self.data_dict[index].mean(0)), 0)
                    mean_embed_id = torch.cat((mean_embed_id,
                                               self.data_dict[index].mean(0).view(1, -1)), 0)

            ## add the variance.
            temp_precision = torch.mm(X.t(), X) / len(X)
            temp_precision += 0.0001 * torch.eye(128, device='cuda')


            for index in range(self.num_classes):
                new_dis = torch.distributions.multivariate_normal.MultivariateNormal(
                    mean_embed_id[index], covariance_matrix=temp_precision)
                negative_samples = new_dis.rsample((sample_from,))
                prob_density = new_dis.log_prob(negative_samples)
                # breakpoint()
                # index_prob = (prob_density < - self.threshold).nonzero().view(-1)
                # keep the data in the low density area.
                cur_samples, index_prob = torch.topk(- prob_density, select)
                if index == 0:
                    ood_samples = negative_samples[index_prob]
                else:
                    ood_samples = torch.cat((ood_samples, negative_samples[index_prob]), 0)
            if len(ood_samples) != 0:
                # add some gaussian noise
                # ood_samples = self.noise(ood_samples)
                # energy_score_for_fg = 1 * torch.logsumexp(predictions[0][selected_fg_samples][:, :-1] / 1, 1)
                energy_score_for_fg = self.log_sum_exp(x, 1)
                predictions_ood = self.model.fc(ood_samples)
                # energy_score_for_bg = 1 * torch.logsumexp(predictions_ood[0][:, :-1] / 1, 1)
                energy_score_for_bg = self.log_sum_exp(predictions_ood, 1)

                input_for_lr = torch.cat((energy_score_for_fg, energy_score_for_bg), -1)
                labels_for_lr = torch.cat((torch.ones(len(output)).cuda(),
                                           torch.zeros(len(ood_samples)).cuda()), -1)

                criterion = torch.nn.CrossEntropyLoss()
                output1 = self.logistic_regression(input_for_lr.view(-1, 1))
                lr_reg_loss = criterion(output1, labels_for_lr.long())

                if epochs % 5 == 0:
                    print(lr_reg_loss)
        else:
            target_numpy = target.cpu().data.numpy()
            for index in range(len(target)):
                dict_key = target_numpy[index]
                if self.number_dict[dict_key] < sample_number:
                    self.data_dict[dict_key][self.number_dict[dict_key]] = output[index].detach()
                    self.number_dict[dict_key] += 1
        
        loss = F.cross_entropy(x, target) + loss_weight * lr_reg_loss
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self.model.forward(data)
        loss = F.cross_entropy(output, target)

        # Accuracy
        metrics = self.valid_metrics(output, target)
        self.log_dict(metrics,  on_epoch=True)
        return metrics

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self.model.forward(data)
        loss = F.cross_entropy(output, target)

        # Accuracy
        metrics = self.test_metrics(output, target)
        self.log_dict(metrics, on_epoch=True)
        return metrics

In [41]:
class WideResNetModule(pl.LightningModule):
    def __init__(self, num_classes = 10, threshold = 0.95, learning_rate = learning_rate, momentum = momentum, decay = decay):
        super().__init__()
        self.num_classes = num_classes

        self.data_dict = torch.zeros(self.num_classes, sample_number, 128).cuda()
        self.number_dict = {i:0 for i in range(self.num_classes)}

        self.threshold = threshold
        self.validation_accuracy = torchmetrics.Accuracy(self.threshold)
        self.test_accuracy = torchmetrics.Accuracy(self.threshold)

        self.model = WideResNet(num_layers, num_classes, widen_factor, dropRate=droprate)
        
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.decay = decay

        metrics = MetricCollection(
            [Accuracy(), Precision(num_classes=self.num_classes), 
                Recall(num_classes=self.num_classes)]
                )
        self.train_metrics = metrics.clone(prefix='train_')
        self.valid_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(),
            self.learning_rate, momentum=self.momentum,weight_decay=self.decay, nesterov=True)
        return optimizer
        # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
        #     lr_lambda = lambda step: VOSModel.cosine_annealing(step, epochs * len(self.), 1, 1e-6 / self.learning_rate))
        # return [optimizer], [scheduler]
    

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)

        metrics = self.train_metrics(output, target)
        self.log_dict(metrics, on_epoch = True)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)

        # Accuracy
        metrics = self.valid_metrics(output, target)
        self.log_dict(metrics, on_epoch = True)
        return metrics

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)

        # Accuracy
        metrics = self.test_metrics(output, target)
        self.log_dict(metrics, on_epoch=True)
        return metrics


In [102]:
def get_ood_scores(model, test_loader, batch_size, temperature, in_dist=False):
    _score = []
    _right_score = []
    _wrong_score = []

    ood_num_examples = len(test_loader.dataset)

    to_np = lambda x: x.data.cpu().numpy()
    concat = lambda x: np.concatenate(x, axis = 0)

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            if batch_idx >= ood_num_examples // batch_size and in_dist is False:
                break

            data = data.cuda()

            output = model.model(data)
            smax = to_np(F.softmax(output, dim=1))

            _score.append(-to_np((temperature*torch.logsumexp(output / temperature, dim=1))))
            if in_dist:
                preds = np.argmax(smax, axis=1)
                targets = target.numpy().squeeze()
                right_indices = preds == targets
                wrong_indices = np.invert(right_indices)

                _right_score.append(-np.max(smax[right_indices], axis=1))
                _wrong_score.append(-np.max(smax[wrong_indices], axis=1))

    if in_dist:
        return concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy()
    else:
        return concat(_score)[:ood_num_examples].copy()

def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Use high precision for cumsum and check that final value matches sum
    Parameters
    ----------
    arr : array-like
        To be cumulatively summed as flat
    rtol : float
        Relative tolerance, see ``np.allclose``
    atol : float
        Absolute tolerance, see ``np.allclose``
    """
    out = np.cumsum(arr, dtype=np.float64)
    expected = np.sum(arr, dtype=np.float64)
    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
        raise RuntimeError('cumsum was found to be unstable: '
                           'its last element does not correspond to sum')
    return out

def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
    classes = np.unique(y_true)
    if (pos_label is None and
            not (np.array_equal(classes, [0, 1]) or
                     np.array_equal(classes, [-1, 1]) or
                     np.array_equal(classes, [0]) or
                     np.array_equal(classes, [-1]) or
                     np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    # sort scores and corresponding truth values
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # accumulate the true positives with decreasing threshold
    tps = stable_cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps      # add one because of zero-based indexing

    thresholds = y_score[threshold_idxs]

    recall = tps / tps[-1]

    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)      # [last_ind::-1]
    recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]

    cutoff = np.argmin(np.abs(recall - recall_level))

    return fps[cutoff] / (np.sum(np.logical_not(y_true)))   # , fps[cutoff]/(fps[cutoff] + tps[cutoff])


def get_measures(_pos, _neg, recall_level=0.95):
    pos = np.array(_pos[:]).reshape((-1, 1))
    neg = np.array(_neg[:]).reshape((-1, 1))
    examples = np.squeeze(np.vstack((pos, neg)))
    labels = np.zeros(len(examples), dtype=np.int32)
    labels[:len(pos)] += 1

    auroc = sklearn.metrics.roc_auc_score(labels, examples)
    aupr = sklearn.metrics.average_precision_score(labels, examples)
    fpr = fpr_and_fdr_at_recall(labels, examples, recall_level)

    return auroc, aupr, fpr

def print_measures(auroc, aupr, fpr, method_name='Ours', recall_level=0.95):
    print('\t\t\t\t' + method_name)
    print('  FPR{:d} AUROC AUPR'.format(int(100*recall_level)))
    print('& {:.2f} & {:.2f} & {:.2f}'.format(100*fpr, 100*auroc, 100*aupr))

def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=0.95):
    print('\t\t\t\t' + method_name)
    print('  FPR{:d} AUROC AUPR'.format(int(100*recall_level)))
    print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.mean(fprs), 100*np.mean(aurocs), 100*np.mean(auprs)))
    print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.std(fprs), 100*np.std(aurocs), 100*np.std(auprs)))


def get_and_print_results(model, id_loader, ood_loader, batch_size, temperature = 1, num_to_avg=1):
    auroc_list, aupr_list, fpr_list = [], [], []

    aurocs, auprs, fprs = [], [], []
    in_score, right_score, wrong_score = get_ood_scores(model, id_loader, batch_size, temperature, in_dist=True)

    for _ in range(num_to_avg):
        out_score = get_ood_scores(model, ood_loader, batch_size, temperature)
        measures = get_measures(-in_score, -out_score)
        aurocs.append(measures[0]); auprs.append(measures[1]); fprs.append(measures[2])
    # print(in_score[:3], out_score[:3])
    auroc = np.mean(aurocs); aupr = np.mean(auprs); fpr = np.mean(fprs)
    auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr)

    if num_to_avg >= 5:
        print_measures_with_std(aurocs, auprs, fprs)
    else:
        print_measures(auroc, aupr, fpr)

In [77]:
model = VOSModel()

datamodule = CIFAR10DataModule()

trainer = pl.Trainer(max_epochs = epochs, logger=pl.loggers.WandbLogger(), gpus=1)
trainer.fit(model, datamodule)

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                | Type             | Params
---------------------------------------------------------
0 | validation_accuracy | Accuracy         | 0     
1 | test_accuracy       | Accuracy         | 0     
2 | model               | WideResNet       | 1.1 M 
3 | weight_energy       | Linear           | 11    
4 | logistic_regression | Linear           | 4     
5 | train_metrics       | MetricCollection | 0     
6 | valid_metrics       | MetricCollection | 0     
7 | test_metrics        | MetricCollection | 0     
---------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.319     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


In [78]:
trainer.save_checkpoint("example.ckpt")

In [121]:
model = VOSModel.load_from_checkpoint(checkpoint_path="example.ckpt")

In [99]:
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

ood_data = datasets.SVHN('/content','test', transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]), download = True)

ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True)

get_and_print_results(model, datamodule.test_dataloader(), ood_loader, batch_size)

Using downloaded and verified file: /content/test_32x32.mat
				Ours
  FPR95 AUROC AUPR
& 81.94 & 77.42 & 53.61


In [141]:
mnist = datasets.MNIST('/content', 'test', transform = transforms.Compose([
    transforms.Resize(32) ,transforms.ToTensor(), 
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)), 
    transforms.Normalize(mean, std)]),download = True
)

mnist_loader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True)

In [143]:
fashion_mnist = datasets.MNIST('/content', 'test', transform = transforms.Compose([
    transforms.Resize(32), transforms.ToTensor(), 
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)), 
    transforms.Normalize(mean, std)]),download = True
)

fashion_mnist_loader = torch.utils.data.DataLoader(fashion_mnist, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True)

In [178]:
def get_output_uncertainty(dataloader, dataloader_name = None):
    mean = 0
    maximum = 0
    minimum = 1
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            x = model.model(data)
            energy_score_for_fg = model.log_sum_exp(x, 1)
            criterion = torch.nn.CrossEntropyLoss()
            temp_output = model.logistic_regression(energy_score_for_fg.view(-1, 1))
            lr_reg_loss = criterion(temp_output, torch.ones(len(x)).long())
            mean += lr_reg_loss
            maximum = max(lr_reg_loss, maximum)
            minimum = min(lr_reg_loss, minimum)
    if dataloader_name:
        print(f'{dataloader_name} Output Mean: {mean / len(dataloader)}')
        print(f'{dataloader_name} Output Max: {maximum}')
        print(f'{dataloader_name} Output Min: {minimum}')

    return mean / len(dataloader)

def get_weights(dataloaders):
    uncertainty = [get_output_uncertainty(dataloader) for dataloader in dataloaders]
    heuristic = lambda x: 1 / (x + 1e-8) # add epsilon to prevent divide by 0
    output = list(map(heuristic, uncertainty))
    normalized_output = [o / sum(output) for o in output]
    return normalized_output

In [179]:
get_output_uncertainty(datamodule.val_dataloader(), 'CIFAR10 Validation')
get_output_uncertainty(datamodule.test_dataloader(), 'CIFAR10 Test')
get_output_uncertainty(mnist_loader, 'MNIST Test')
get_output_uncertainty(fashion_mnist_loader, 'Fashion MNIST Test')
get_output_uncertainty(ood_loader, 'SVHN Test')

CIFAR10 Validation Output Mean: 0.03586018458008766
CIFAR10 Validation Output Max: 0.06040087714791298
CIFAR10 Validation Output Min: 0.018747715279459953
CIFAR10 Test Output Mean: 0.03424934670329094
CIFAR10 Test Output Max: 0.051549073308706284
CIFAR10 Test Output Min: 0.017143435776233673
MNIST Test Output Mean: 0.10397102683782578
MNIST Test Output Max: 0.13040177524089813
MNIST Test Output Min: 0.07635048776865005
Fashion MNIST Test Output Mean: 0.24943573772907257
Fashion MNIST Test Output Max: 0.2883538007736206
Fashion MNIST Test Output Min: 0.2145966738462448
SVHN Test Output Mean: 0.10591708868741989
SVHN Test Output Max: 0.12683247029781342
SVHN Test Output Min: 0.0770074650645256


tensor(0.1059)

In [195]:
from torchmetrics import AveragePrecision, Accuracy, StatScores

labels = torch.from_numpy(np.asarray(list(zip(*datamodule.test_dataloader().dataset))[1]))

output = list()

with torch.no_grad():
    for batch_idx, (data, _) in enumerate(datamodule.test_dataloader()):
        output.append(model.model(data).max(1)[1])

output = torch.cat(output, 0)

def compute_stat(preds, target, class_num):
    stat_scores = StatScores(reduce='macro', num_classes=class_num)
    stat = stat_scores(preds, target)
    FPR_stat = torch.div(stat[:, 1], (stat[:, 1] + stat[:, 2]))
    FNR_stat = torch.div(stat[:, 3], (stat[:, 3] + stat[:, 0]))
    acc_stat = torch.div((stat[:, 0] + stat[:, 2]), (stat[:, 3] + stat[:, 1] + stat[:, 0] + stat[:, 2]))
    return FPR_stat, FNR_stat, acc_stat

FPR, FNR, acc = compute_stat(output, labels, 10)
print(f'FPR: {FPR}, FNR: {FNR}, Accuracy: {acc}')


FPR: tensor([0.0182, 0.0082, 0.0145, 0.0434, 0.0164, 0.0213, 0.0176, 0.0115, 0.0121,
        0.0135]), FNR: tensor([0.1571, 0.1028, 0.2130, 0.2273, 0.1921, 0.2417, 0.1176, 0.1446, 0.0951,
        0.0955]), Accuracy: tensor([0.9677, 0.9825, 0.9657, 0.9375, 0.9658, 0.9574, 0.9725, 0.9749, 0.9795,
        0.9786])
