# Street sign with Pre-Trained WideResNet

With an additional shield network (a binary classifier that separates street signs from non-sign images).

In [1]:
from torch.optim import SGD
import seaborn as sb 
from gtsrb import GTSRB
from detectors import EnsembleDetector, LogicOOD, PrologOOD
import torch
from pytorch_ood.utils import fix_random_seed

import logging
import sys
# Send INFO-and-above records from the root logger to stdout so that log
# messages appear in the notebook output.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(stream=sys.stdout))

# Fix the global random seed once, before any dataset or model is created.
fix_random_seed(123)

def seed_worker(worker_id: int) -> None:
    """
    Worker initialisation hook for the DataLoaders in this notebook.

    Passed as ``worker_init_fn``; it calls ``fix_random_seed`` with the id of
    the worker process being started.

    NOTE(review): the seed is the bare worker id, so it does not depend on the
    global seed, the epoch or the trial. That is harmless while the transform
    pipeline is deterministic, but would repeat any random augmentation
    identically -- confirm this is intended.
    """
    fix_random_seed(worker_id)

# Dedicated torch generator with a fixed seed.
# NOTE(review): none of the DataLoader calls visible in this notebook pass
# `generator=g`, so it appears to be unused.
g = torch.Generator().manual_seed(0)

# Apply seaborn's default plot styling.
sb.set()

root = "../data/"   # directory that holds all datasets
device = "cuda:0"   # models and batches are moved to this device




In [2]:
from pytorch_ood.utils import ToRGB
from torchvision.transforms import ToTensor, Resize, Compose
import torch 
from torch.utils.data import DataLoader, random_split
import numpy as np


# Preprocessing used throughout the notebook: convert to RGB, turn the image
# into a tensor, then resize the tensor to 32x32.
trans = Compose([
    ToRGB(),
    ToTensor(),
    Resize((32, 32), antialias=True),
])

# Official training split, divided into a fixed 35000 / 4209 train/validation
# partition (the split generator is seeded so the partition is reproducible).
data = GTSRB(root=root, train=True, transforms=trans)
print(len(data))
train_data, val_data = random_split(
    data,
    lengths=[35000, 4209],
    generator=torch.Generator().manual_seed(123),
)

# Official test split.
test_data = GTSRB(root=root, train=False, transforms=trans)

39209


In [3]:
# Settings shared by both loaders; only the training loader shuffles.
_loader_args = dict(batch_size=32, num_workers=2, worker_init_fn=seed_worker)
train_loader = DataLoader(train_data, shuffle=True, **_loader_args)
test_loader = DataLoader(test_data, shuffle=False, **_loader_args)

In [4]:
from torch import nn
from torchvision.models.resnet import resnet18
from pytorch_ood.model import WideResNet

# def override 
def Model(num_classes=None, *args, **kwargs):
    """
    Build a WideResNet initialised with the "imagenet32" pre-trained weights
    and replace its classification head.

    The backbone is always created with 1000 output classes (as requested for
    the pre-trained weights here); its ``fc`` layer is then swapped for a
    freshly initialised linear layer with ``num_classes`` outputs.

    :param num_classes: number of outputs of the new classification head
    :param args: extra positional arguments forwarded to ``WideResNet``
    :param kwargs: extra keyword arguments forwarded to ``WideResNet``
    :return: the model with the replaced head
    """
    backbone = WideResNet(*args, num_classes=1000, pretrained="imagenet32", **kwargs)
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Linear(feature_dim, num_classes)
    return backbone


In [5]:
def train_model(att_index, num_classes):
    """
    Train a classifier for one attribute of the GTSRB targets.

    The model is trained for 20 epochs with SGD (Nesterov momentum) on the
    35000-sample part of the fixed train/validation split. After every epoch
    the accuracy on the official test split is printed.

    :param att_index: column of the target that is used as the class label
        (the final cell uses 0 for the sign class, 1 for color, 2 for shape)
    :param num_classes: number of classes of the selected attribute
    :return: the trained model (on ``device``)
    """
    full_train = GTSRB(root=root, train=True, transforms=trans)
    split_rng = torch.Generator().manual_seed(123)
    # Only the training part is needed here; the validation part is discarded.
    train_data, _ = random_split(full_train, [35000, 4209], generator=split_rng)
    test_data = GTSRB(root=root, train=False, transforms=trans)

    loader_kwargs = dict(batch_size=32, num_workers=2, worker_init_fn=seed_worker)
    train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

    model = Model(num_classes=num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)

    for epoch in range(20):
        # ---- training pass ----
        model.train()
        smoothed_loss = 0.0
        progress = tqdm(train_loader)
        for images, targets in progress:
            images = images.to(device)
            labels = targets[:, att_index].to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            # Exponential moving average of the loss, shown in the progress bar.
            smoothed_loss = 0.8 * smoothed_loss + 0.2 * loss.item()
            progress.set_postfix({"loss": smoothed_loss})

        # ---- evaluation pass on the test split ----
        model.eval()
        n_correct, n_seen = 0, 0
        with torch.no_grad():
            for images, targets in test_loader:
                images = images.to(device)
                labels = targets[:, att_index].to(device)

                predicted = model(images).max(dim=1).indices
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()

        print(f'Accuracy of the network on the test images: {n_correct / n_seen:.2%}')

    return model

# Sign Network 

In [6]:
from torch.utils.data import DataLoader
from pytorch_ood.utils import is_known
from tqdm.notebook import tqdm 
from pytorch_ood.dataset.img import TinyImages300k
from pytorch_ood.utils import ToUnknown
from torch.utils.data import random_split
from torch.utils.data import TensorDataset


# class GANData:
#     def __init__(self):
#         self.data =  torch.tensor(np.load("../data/gtsrb-samples-z-var-100.npz")["x"], dtype=torch.float32)
#         resize = Resize(size=(32, 32), antialias=True)
#         self.data = torch.stack([resize(a) for a in self.data])
#
#     def __len__(self):
#         return len(self.data)
#
#     def __getitem__(self, item):
#         return self.data[item], torch.tensor(-1)


# %%
def train_sign_model():
    """
    Train the binary sign network (the "shield net").

    Training data is the full GTSRB training split plus 5000 images drawn
    from TinyImages300k, whose targets are set to -1. The binary label of a
    batch is ``is_known(y).long()`` (presumably 1 for GTSRB images and 0 for
    the -1 outlier targets -- confirm against pytorch_ood). The model is
    trained for a single epoch and its accuracy is then printed for the GTSRB
    test split, which contains no outlier images.

    :return: the trained two-class model (on ``device``)
    """
    # dataset = GANData()
    # print(len(dataset))
    outlier_pool = TinyImages300k(root=root, download=True, transform=trans, target_transform=lambda x: torch.tensor(-1))
    split_rng = torch.Generator().manual_seed(123)
    # Only the first 5000 outlier images are used; the rest is discarded.
    data_train_out, _, _ = random_split(outlier_pool, [5000, 10000, 285000], generator=split_rng)

    def first_attribute(y):
        # Keep only the first entry of the GTSRB target.
        return torch.tensor(y[0])

    train_data_noatt = GTSRB(root=root, train=True, transforms=trans, target_transform=first_attribute)
    test_data_noatt = GTSRB(root=root, train=False, transforms=trans, target_transform=first_attribute)

    loader_kwargs = dict(batch_size=128, num_workers=10, worker_init_fn=seed_worker)
    new_loader = DataLoader(train_data_noatt + data_train_out, shuffle=True, **loader_kwargs)
    new_test_loader = DataLoader(test_data_noatt, shuffle=False, **loader_kwargs)

    model = Model(num_classes=2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)

    # Per-epoch accuracies; collected here but not returned.
    epoch_accuracies = []

    for epoch in range(1):
        # ---- training pass ----
        model.train()
        smoothed_loss = 0.0

        progress = tqdm(new_loader)
        for images, targets in progress:
            images = images.to(device)
            labels = is_known(targets).long().to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            # Exponential moving average of the loss, shown in the progress bar.
            smoothed_loss = 0.8 * smoothed_loss + 0.2 * loss.item()
            progress.set_postfix({"loss": smoothed_loss})

        # ---- evaluation pass on the GTSRB test split ----
        model.eval()
        n_correct, n_seen = 0, 0
        with torch.no_grad():
            for images, targets in new_test_loader:
                images = images.to(device)
                labels = is_known(targets).long().to(device)

                predicted = model(images).max(dim=1).indices
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()

        print(f'Accuracy of the sign network on the test images: {n_correct / n_seen:.2%}')
        epoch_accuracies.append(n_correct / n_seen)

    return model


# train_sign_model()


# OOD Evaluation 

In [7]:
from importlib import reload

import detectors
# Re-execute the local `detectors` module so that edits to its source are
# picked up without restarting the kernel.
# NOTE(review): names bound with `from detectors import ...` before this call
# (EnsembleDetector, LogicOOD and PrologOOD in the first cell) keep pointing
# at the objects from the earlier import; only names imported after the
# reload (PrologOODT below) come from the reloaded module.
reload(detectors)

from pytorch_ood.dataset.img import (LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize)
from pytorch_ood.detector import EnergyBased, MaxSoftmax, ReAct, MaxLogit, Entropy, Mahalanobis, ViM
from pytorch_ood.utils import ToRGB, OODMetrics

from detectors import PrologOODT

def evaluate(label_net, shape_net, color_net, shield_net):
    """
    Benchmark every OOD detector against every OOD test dataset.

    In-distribution data is the GTSRB test split; each OOD dataset is appended
    to it and the detector's scores on the combined data are summarised with
    ``OODMetrics``. Detectors that need fitting (ViM, Mahalanobis and the
    temperature-scaled Prolog detectors) are fitted on the 4209-sample
    validation part of the fixed GTSRB train/validation split.

    :param label_net: network predicting the sign class (attribute 0)
    :param shape_net: network predicting the sign shape (attribute 2)
    :param color_net: network predicting the sign color (attribute 1)
    :param shield_net: binary sign network, passed to the "+" Prolog detectors
    :return: list with one dict per (detector, dataset) pair, holding the
        values returned by ``OODMetrics.compute()`` plus the keys "Method"
        and "Dataset"; ordered by detector first, dataset second
    """
    for net in (label_net, shape_net, color_net, shield_net):
        net.eval()

    label_file = "../data/GTSRB/labels.txt"

    def make_loader(dataset):
        # All evaluation / fitting loaders share the same settings.
        return DataLoader(dataset, batch_size=128, shuffle=False, worker_init_fn=seed_worker, num_workers=10)

    def validation_loader(att_index):
        # Loader over the validation part of the fixed split whose target is
        # the attribute at position `att_index`.
        data = GTSRB(root=root, train=True, transforms=trans, target_transform=lambda y: torch.tensor(y[att_index]))
        _, val_data = random_split(data, [35000, 4209], generator=torch.Generator().manual_seed(123))
        return make_loader(val_data)

    data_in = GTSRB(root=root, train=False, transforms=trans, target_transform=lambda y: y[0])
    # dataset_out_test = Textures(root=root, transform=trans, target_transform=ToUnknown(), download=True)

    # Named `methods` so the imported `detectors` module is not shadowed.
    methods = {
        "PrologOOD": PrologOOD(
            "kb.pl",
            label_net=label_net,
            shape_net=shape_net,
            color_net=color_net,
            label_file=label_file,
        ),
        "PrologOOD+": PrologOOD(
            "kb.pl",
            label_net,
            shape_net,
            color_net,
            sign_net=shield_net,
            label_file=label_file,
        ),
        "PrologOODT": PrologOODT(
            "kb.pl",
            label_net=label_net,
            shape_net=shape_net,
            color_net=color_net,
            label_file=label_file,
        ),
        "PrologOODT+": PrologOODT(
            "kb.pl",
            label_net,
            shape_net,
            color_net,
            sign_net=shield_net,
            label_file=label_file,
        ),
        "Logic": LogicOOD(
            label_net,
            shape_net,
            color_net,
            data_in.class_to_shape,
            data_in.class_to_color,
        ).consistent,
        "Ensemble": EnsembleDetector(label_net, shape_net, color_net),
        "MSP": MaxSoftmax(label_net),
        "Energy": EnergyBased(label_net),
        "ReAct": ReAct(label_net.features, label_net.fc, threshold=10.0),
        "ViM": ViM(label_net.features, w=label_net.fc.weight, b=label_net.fc.bias, d=64),
        "Mahalanobis": Mahalanobis(label_net.features),
        "Entropy": Entropy(label_net),
        "MaxLogit": MaxLogit(label_net),
    }

    # Attribute positions in the GTSRB target: 0 = label, 1 = color, 2 = shape.
    label_loader = validation_loader(0)
    color_loader = validation_loader(1)
    shape_loader = validation_loader(2)

    methods["ViM"].fit(label_loader, device=device)
    methods["Mahalanobis"].fit(label_loader, device=device)
    methods["PrologOODT"].fit(label_loader, color_loader, shape_loader, device=device)
    methods["PrologOODT+"].fit(label_loader, color_loader, shape_loader, device=device)

    # Build each combined (in + out) loader once instead of once per detector:
    # the datasets do not depend on the detector, and constructing them with
    # download=True inside the detector loop repeated that work 13 times.
    ood_datasets = (LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize)
    eval_loaders = {
        dataset_c.__name__: make_loader(
            data_in + dataset_c(root=root, transform=trans, target_transform=ToUnknown(), download=True)
        )
        for dataset_c in ood_datasets
    }

    results = []

    for detector_name, detector in methods.items():
        for data_name, loader in eval_loaders.items():
            scores = []
            ys = []

            with torch.no_grad():
                for x, y in loader:
                    scores.append(detector(x.to(device)))
                    # Targets are only needed on the CPU for the metrics, so
                    # they are not moved to the device.
                    ys.append(y)

                scores = torch.cat(scores, dim=0).cpu()
                ys = torch.cat(ys, dim=0).cpu()

            metrics = OODMetrics()
            metrics.update(scores, ys)
            r = metrics.compute()
            r.update({
                "Method": detector_name,
                "Dataset": data_name
            })
            print(r)
            results.append(r)

    return results

In [8]:
def evaluate_acc(net, att_idx=0, oe=False):
    """
    Compute the accuracy of a network on the GTSRB test split.

    :param net: network to evaluate; it is switched to eval mode
    :param att_idx: position of the target attribute used as ground truth
        (ignored when ``oe`` is true)
    :param oe: if true, every test image gets the target 1, i.e. the network
        is scored as a binary sign classifier on sign images only
    :return: fraction of correctly classified test images
    """
    _ = net.eval()

    if oe:
        def target_trans(y):
            return torch.tensor(1)
    else:
        def target_trans(y):
            return y[att_idx]

    # Same operation order as the pipeline the networks are trained with
    # (ToRGB -> ToTensor -> Resize). Resizing the PIL image before the tensor
    # conversion would feed the network differently interpolated inputs than
    # it saw during training.
    trans = Compose([ToRGB(), ToTensor(), Resize(size=(32, 32), antialias=True)])
    data_in = GTSRB(root=root, train=False, transforms=trans, target_transform=target_trans)
    loader = DataLoader(data_in, batch_size=1024, shuffle=False, worker_init_fn=seed_worker)

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            predicted = outputs.max(dim=1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

def evaluate_accs(label_net, shape_net, color_net, shield_net):
    """
    Collect the test accuracies of all four networks.

    :return: a one-element list holding a dict with the keys "Label",
        "Color", "Shape" and "Sign"
    """
    # The position in this tuple is also the attribute index passed to
    # evaluate_acc: 0 = label, 1 = color, 2 = shape.
    attribute_nets = (
        ("Label", label_net),
        ("Color", color_net),
        ("Shape", shape_net),
    )
    accuracies = {
        name: evaluate_acc(net, att_idx)
        for att_idx, (name, net) in enumerate(attribute_nets)
    }
    accuracies["Sign"] = evaluate_acc(shield_net, oe=True)

    return [accuracies]

In [None]:
# Accumulated over all trials: one dict per (detector, dataset) pair in
# `results`, one dict of accuracies per trial in `results_acc`.
results = []
results_acc = []

for trial in range(10):
    # Train the four networks from scratch for this trial.
    shield_net = train_sign_model()
    shape_net = train_model(att_index=2, num_classes=5)
    color_net = train_model(att_index=1, num_classes=4)
    label_net = train_model(att_index=0, num_classes=43)

    res = evaluate(label_net, shape_net, color_net, shield_net)
    res_acc = evaluate_accs(label_net, shape_net, color_net, shield_net)

    # NOTE(review): "Seed" stores the trial index; no reseeding happens inside
    # this loop (fix_random_seed(123) is called once in the first cell).
    for record in (*res, *res_acc):
        record["Seed"] = trial

    results.extend(res)
    results_acc.extend(res_acc)

  0%|          | 0/346 [00:00<?, ?it/s]

Accuracy of the sign network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.97%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.97%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.96%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.91%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.97%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.85%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.90%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.92%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.93%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.96%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.94%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.99%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.99%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.99%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.98%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 96.82%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.38%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.56%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.41%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.55%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.95%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.03%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.88%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.80%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.73%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.83%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.07%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.10%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.18%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.18%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.16%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 99.18%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.88%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.90%


  0%|          | 0/1094 [00:00<?, ?it/s]

Accuracy of the network on the test images: 98.94%
Computing principal space ...
Computing alpha ...
self.alpha=8.8574
Fitting with temperature scaling
label: y1.unique()=tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42]) logits_label.shape=torch.Size([4209, 43])
Initial T/NLL: 1.000/0.002
Optimal temperature: 0.999998152256012
NLL after scaling: 0.00'
color: y2.unique()=tensor([0, 1, 2, 3]) logits_color.shape=torch.Size([4209, 4])
Initial T/NLL: 1.000/0.000
Optimal temperature: 0.9999988079071045
NLL after scaling: 0.00'
shape: y3.unique()=tensor([0, 1, 2, 3, 4]) logits_shape.shape=torch.Size([4209, 5])
Initial T/NLL: 1.000/0.000
Optimal temperature: 0.9999987483024597
NLL after scaling: 0.00'
self.scorer_label.t=Parameter containing:
tensor(1.0000, requires_grad=True)
self.scorer_color.t=Parameter containing:
tensor(1.0000, requires_grad=Tr

  loss = nll_loss(log_softmax(logits / self.t), labels).item()
  loss = nll_loss(log_softmax(logits / self.t), labels)
  loss = nll_loss(log_softmax(logits / self.t), labels).item()


label: y1.unique()=tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42]) logits_label.shape=torch.Size([4209, 43])
Initial T/NLL: 1.000/0.002
Optimal temperature: 0.999998152256012
NLL after scaling: 0.00'
color: y2.unique()=tensor([0, 1, 2, 3]) logits_color.shape=torch.Size([4209, 4])
Initial T/NLL: 1.000/0.000
Optimal temperature: 0.9999988079071045
NLL after scaling: 0.00'
shape: y3.unique()=tensor([0, 1, 2, 3, 4]) logits_shape.shape=torch.Size([4209, 5])
Initial T/NLL: 1.000/0.000
Optimal temperature: 0.9999987483024597
NLL after scaling: 0.00'
self.scorer_label.t=Parameter containing:
tensor(1.0000, requires_grad=True)
self.scorer_color.t=Parameter containing:
tensor(1.0000, requires_grad=True)
self.scorer_shape.t=Parameter containing:
tensor(1.0000, requires_grad=True)
{'AUROC': 0.9987390041351318, 'AUPR-IN': 0.9982517957687378, 'AUPR-OUT'

In [None]:
import pandas as pd 
result_df = pd.DataFrame(results)

# Mean and standard error (in percent) of each metric per method, as LaTeX.
# The metric columns are selected *before* aggregating: the frame also holds
# the string column "Dataset", and aggregating it with "mean" raises a
# TypeError on pandas >= 2.0. The resulting table is unchanged.
metric_columns = ["AUROC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
summary = result_df.groupby(by="Method")[metric_columns].agg(["mean", "sem"]) * 100
print(summary.to_latex(float_format="%.2f"))

In [None]:
from scipy.stats import ttest_ind

# Per-seed mean AUROC (averaged over the OOD datasets) for the two methods.
# "AUROC" is selected before `.mean()` so that the string column "Dataset"
# is never averaged, which raises a TypeError on pandas >= 2.0.
sem_auroc = result_df[result_df["Method"] == "PrologOOD"].groupby(by=["Method", "Seed"])["AUROC"].mean()
sem_ensemble = result_df[result_df["Method"] == "Ensemble"].groupby(by=["Method", "Seed"])["AUROC"].mean()

# Welch's t-test (unequal variances) between the two sets of per-seed means.
print(ttest_ind(sem_auroc, sem_ensemble, equal_var=False))

In [None]:
# Mean and standard error (in percent) of the per-trial accuracies, as LaTeX.
# NOTE(review): the "Seed" column is part of the frame, so it is scaled and
# aggregated as well and shows up in the printed table.
acc_percent = pd.DataFrame(results_acc) * 100
acc_summary = acc_percent.agg(["mean", "sem"])
print(acc_summary.to_latex(float_format="%.2f"))