In [1]:
from torch.utils.data import Dataset, random_split
import os 
from os.path import join 
from PIL import Image
from tqdm import tqdm
import logging
import sys
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
logger.setLevel(logging.INFO)

logger.info("abc")

device = "cuda:0"
root = "/home/ki/datasets/"

from detector import label_to_name, color_to_name

class FruitDataset(Dataset):
    """
    
    """

    class_color_map = {
        "Apple Braeburn": "red",
        "Apple Granny Smith": "green",
        "Apricot": "orange",
        "Avocado": "green",
        "Banana": "yellow",
        "Blueberry": "black",
        "Cactus fruit": "green",
        "Cantaloupe": "yellow",
        "Cherry": "red",
        "Clementine": "orange",
        "Corn": "yellow",
        "Cucumber Ripe": "brown",
        "Grape Blue": "black",
        "Kiwi": "brown",
        "Lemon": "yellow",
        "Limes": "green",
        "Mango": "green",
        "Onion White": "brown",
        "Orange": "orange",
        "Papaya": "green",
        "Passion Fruit": "black",
        "Peach": "orange",
        "Pear": "green", # ??
        "Pepper Green": "green",
        "Pepper Red": "red",
        "Pineapple": "brown",
        "Plum": "red",
        "Pomegranate": "red",
        "Potato Red": "brown",
        "Raspberry": "red",
        "Strawberry": "red",
        "Tomato": "red",
        "Watermelon": "red" 
    }
    
    def __init__(self, root="train", transform=None, target_transform=None):
        root = join(root, "fruits", "train", "train")

        self.classes = os.listdir(root)
        self.files = []
        self.labels = []
        self.colors = []
        
        self.transform = transform
        self.target_transform = target_transform 
        
        for c in self.classes:
            fs = [join(root, c, f) for f in os.listdir(join(root, c))]
            self.files += fs
            self.labels += [c.lower().replace(" ", "_")] * len(fs)
            self.colors += [self.class_color_map[c]] * len(fs)

        self.class_map = {c: n for n, c in enumerate(label_to_name)}
        self.color_map = {c: n for n, c in enumerate(color_to_name)}
        
    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = self.files[index]
        y = self.class_map[self.labels[index]]
        color = self.color_map[self.colors[index]]
        
        img = Image.open(img)

        if self.transform is not None:
            img = self.transform(img)

        y = torch.tensor([y, color]) 
        if self.target_transform is not None:
            y = self.target_transform(y)
        
        return img, y 


abc


In [2]:
ds = FruitDataset(root=root)

In [3]:
from pytorch_ood.utils import ToRGB
from torchvision.transforms import ToTensor, Resize, Compose
import torch 
from torch.utils.data import DataLoader
import numpy as np


trans = Compose([ToRGB(), ToTensor(), Resize((32, 32), antialias=True)])

data = FruitDataset(root=root, transform=trans)
train_data, val_data, test_data = random_split(data, [14000,1000, 1854], generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [4]:
from torch import nn
from pytorch_ood.model import WideResNet

# def override 
def Model(num_classes=None, *args, **kwargs):
    model = WideResNet(*args, num_classes=1000, pretrained="imagenet32", **kwargs)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [5]:
from torch.optim import SGD


def train_model(att_index, num_classes):
    """
    train a model for the given attribute index 
    """
    model = Model(num_classes=num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)

    for epoch in range(1):
        running_loss = 0.0
        model.train()
        bar = tqdm(train_loader)
        for inputs, y in bar:
            labels = y[:, att_index]
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss = 0.8 * running_loss + 0.2 * loss.item()
            bar.set_postfix({"loss": running_loss})

        correct = 0
        total = 0

        with torch.no_grad():
            model.eval()

            for inputs, y in test_loader:
                labels = y[:, att_index]
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Accuracy of the network on the test images: {correct / total:.2%}')

    return model

In [9]:
from pytorch_ood.dataset.img import TinyImages300k
from pytorch_ood.utils import is_known

def train_fruit_model():
    tiny = TinyImages300k(root=root, download=True, transform=trans, target_transform=ToUnknown())
    data_train_out, data_test_out, _ = random_split(tiny, [50000, 10000, 240000], generator=torch.Generator().manual_seed(123))

    data_noatt = FruitDataset(root=root, transform=trans, target_transform=lambda y: int(y[0]))
    train_data_noatt, val_data_noatt, test_data_noatt = random_split(data_noatt, [14000,1000, 1854], generator=torch.Generator().manual_seed(0))

    new_loader = DataLoader(train_data_noatt + data_train_out, batch_size=32, shuffle=True, num_workers=10)
    new_test_loader = DataLoader(test_data_noatt + data_test_out, batch_size=32, shuffle=False, num_workers=10)

    model = Model(num_classes=2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)

    accs = []

    for epoch in range(1):
        running_loss = 0.0
        model.train()

        bar = tqdm(new_loader)
        for inputs, y in bar:
            labels = is_known(y).long()
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss = 0.8 * running_loss + 0.2 * loss.item()
            bar.set_postfix({"loss": running_loss})

        correct = 0
        total = 0

        with torch.no_grad():
            model.eval()

            for inputs, y in new_test_loader:
                labels = is_known(y).long()
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Accuracy of the shape network on the test images: {correct / total:.2%}')
        accs.append(correct / total)

    return model

In [10]:
from pytorch_ood.dataset.img import (LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize)
from pytorch_ood.detector import EnergyBased, MaxSoftmax, MaxLogit, Entropy, Mahalanobis, ViM, ReAct
from pytorch_ood.utils import OODMetrics, ToUnknown
from detector import EnsembleDetector, PrologOOD, Prologic, PrologOODT

def evaluate(label_net, color_net, fruit_net):
    _ = label_net.eval()
    _ = color_net.eval()
    
    results = []

    detectors = {
        "ViM": ViM(label_net.features, w=label_net.fc.weight, b=label_net.fc.bias, d=64),
        "Mahalanobis": Mahalanobis(label_net.features),
        "Entropy": Entropy(label_net),
        "LogicOOD+": PrologOOD("kb.pl", label_net, color_net, fruit_net),
        "Logic": Prologic("kb.pl", label_net, color_net),
        "Logic+": Prologic("kb.pl", label_net, color_net, fruit_net),
        "LogicOOD": PrologOOD("kb.pl", label_net, color_net),
        "LogicT": PrologOODT("kb.pl", label_net, color_net),
        "Ensemble": EnsembleDetector(label_net, color_net),
        "MSP": MaxSoftmax(label_net),
        "ReAct": ReAct(label_net.features, label_net.fc),
        "Energy": EnergyBased(label_net),
        "MaxLogit": MaxLogit(label_net),
    }


    data_fit_label = FruitDataset(root=root, transform=trans,  target_transform=lambda y: torch.tensor(y[0]))
    _ , data_fit_label, _ = random_split(data_fit_label, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))
    data_fit_color = FruitDataset(root=root, transform=trans,  target_transform=lambda y: torch.tensor(y[1]))
    _, data_fit_color, _ = random_split(data_fit_color, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))
    data_fit_color = DataLoader(data_fit_color, batch_size=32, shuffle=False, num_workers=2)
    data_fit_label = DataLoader(data_fit_label, batch_size=32, shuffle=False, num_workers=2)

    data = FruitDataset(root=root, transform=trans, target_transform=lambda y: int(y[0]))
    data_in_train, data_in_val, data_in = random_split(data, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))
    train_in_loader = DataLoader(data_in_train, batch_size=32, shuffle=False, num_workers=2)

    detectors["ViM"].fit(train_in_loader, device=device)
    detectors["LogicT"].fit(data_fit_label, data_fit_color, device=device)
    # detectors["Mahalanobis"].fit(train_in_loader, device=device)

    datasets = {d.__name__: d for d in (LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize)}
    
    for detector_name, detector in detectors.items():
        for data_name, dataset_c in datasets.items():
            print(data_name)
            data_out = dataset_c(root=root, transform=trans, target_transform=ToUnknown(), download=True)
            loader = DataLoader(data_in+data_out, batch_size=256, shuffle=False, num_workers=12)
            
            scores = []
            ys = []
            
            with torch.no_grad():
                for x, y in loader:
                    scores.append(detector(x.to(device)))
                    ys.append(y.to(device))
                    
                scores = torch.cat(scores, dim=0).cpu()
                ys = torch.cat(ys, dim=0).cpu()
            
            metrics = OODMetrics()
            metrics.update(scores, ys)
            r = metrics.compute()
            r.update({
                "Method": detector_name,
                "Dataset": data_name
            })
            print(r)
            results.append(r)
    
    return results 

In [11]:
results = []

for trial in range(1):
    print("label")
    label_net = train_model(att_index=0, num_classes=33)
    print("color")
    color_net = train_model(att_index=1, num_classes=6)
    print("fruit")

    fruit_net = train_fruit_model()

    res = evaluate(label_net, color_net, fruit_net)
    
    for r in res:
        r.update({"Seed": trial})
    
    results += res

  0%|          | 0/438 [00:00<?, ?it/s]

label


100%|██████████| 438/438 [00:12<00:00, 35.12it/s, loss=0.22]  
  0%|          | 0/438 [00:00<?, ?it/s]

Accuracy of the network on the test images: 95.42%
color


100%|██████████| 438/438 [00:12<00:00, 34.64it/s, loss=0.0604]


Accuracy of the network on the test images: 96.17%
fruit


100%|██████████| 2000/2000 [00:57<00:00, 34.67it/s, loss=0.0263]  


Accuracy of the shape network on the test images: 99.88%
Computing principal space ...
Computing alpha ...
self.alpha=4.4100
Fitting with temperature scaling


  data_fit_label = FruitDataset(root=root, transform=trans,  target_transform=lambda y: torch.tensor(y[0]))
  data_fit_label = FruitDataset(root=root, transform=trans,  target_transform=lambda y: torch.tensor(y[0]))
  data_fit_color = FruitDataset(root=root, transform=trans,  target_transform=lambda y: torch.tensor(y[1]))
  data_fit_color = FruitDataset(root=root, transform=trans,  target_transform=lambda y: torch.tensor(y[1]))


Initial T/NLL: 1.000/0.155
Optimal temperature: 0.8572290539741516
NLL after scaling: 0.13'
Initial T/NLL: 1.000/0.110
Optimal temperature: 0.8130077123641968
NLL after scaling: 0.10'
self.scorer_label.t=Parameter containing:
tensor(0.8572, requires_grad=True)
self.scorer_color.t=Parameter containing:
tensor(0.8130, requires_grad=True)
LSUNCrop


  loss = nll_loss(log_softmax(logits / self.t), labels).item()
  loss = nll_loss(log_softmax(logits / self.t), labels)
  loss = nll_loss(log_softmax(logits / self.t), labels).item()


{'AUROC': 0.9957823753356934, 'AUPR-IN': 0.9992374777793884, 'AUPR-OUT': 0.9767327904701233, 'FPR95TPR': 0.004854368977248669, 'Method': 'ViM', 'Dataset': 'LSUNCrop'}
LSUNResize
{'AUROC': 0.9999845027923584, 'AUPR-IN': 0.9999971985816956, 'AUPR-OUT': 0.9999152421951294, 'FPR95TPR': 0.0, 'Method': 'ViM', 'Dataset': 'LSUNResize'}
Textures
Found 5640 texture files.
{'AUROC': 0.9990158677101135, 'AUPR-IN': 0.9996914267539978, 'AUPR-OUT': 0.9967571496963501, 'FPR95TPR': 0.0, 'Method': 'ViM', 'Dataset': 'Textures'}
TinyImageNetCrop
{'AUROC': 0.998676061630249, 'AUPR-IN': 0.9997619390487671, 'AUPR-OUT': 0.9923152923583984, 'FPR95TPR': 0.0, 'Method': 'ViM', 'Dataset': 'TinyImageNetCrop'}
TinyImageNetResize
{'AUROC': 0.9998490810394287, 'AUPR-IN': 0.9999725818634033, 'AUPR-OUT': 0.9991335272789001, 'FPR95TPR': 0.0, 'Method': 'ViM', 'Dataset': 'TinyImageNetResize'}
LSUNCrop
{'AUROC': 0.9773463010787964, 'AUPR-IN': 0.9958349466323853, 'AUPR-OUT': 0.9808893799781799, 'FPR95TPR': 0.0453074425458908

In [12]:
import pandas as pd
result_df = pd.DataFrame(results)
print((result_df.groupby(by="Method").agg(["mean", "sem"]) * 100)[["AUROC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]].to_latex(float_format="%.2f"))

\begin{tabular}{lrrrrrrrr}
\toprule
{} & \multicolumn{2}{l}{AUROC} & \multicolumn{2}{l}{AUPR-IN} & \multicolumn{2}{l}{AUPR-OUT} & \multicolumn{2}{l}{FPR95TPR} \\
{} &  mean &  sem &    mean &  sem &     mean &  sem &     mean &  sem \\
Method    &       &      &         &      &          &      &          &      \\
\midrule
Ensemble  & 82.63 & 1.31 &   95.56 & 0.78 &    50.45 & 5.63 &    66.81 & 6.71 \\
Logic     & 69.95 & 2.22 &   94.06 & 0.72 &    61.60 & 1.48 &   100.00 & 0.00 \\
LogicOOD  & 84.54 & 2.07 &   96.12 & 0.70 &    53.21 & 6.71 &    64.77 & 7.65 \\
LogicOOD+ & 97.72 & 0.02 &   99.51 & 0.07 &    98.06 & 0.03 &     4.53 & 0.00 \\
LogicT    & 83.97 & 2.13 &   95.99 & 0.71 &    52.01 & 6.69 &    66.16 & 7.29 \\
MSP       & 80.15 & 0.66 &   94.71 & 0.58 &    44.40 & 4.31 &    74.38 & 4.64 \\
ReAct     & 82.34 & 2.42 &   94.65 & 1.01 &    62.16 & 3.15 &    50.27 & 3.14 \\
ViM       & 99.87 & 0.08 &   99.97 & 0.01 &    99.30 & 0.43 &     0.10 & 0.10 \\
\bottomrule
\end{tabular}


In [13]:
s = (result_df.groupby(by="Method").agg(["mean", "sem"]) * 100)[["AUROC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]].to_latex(float_format="%.2f")

In [14]:
print(s.replace("& 0.", "& \pm 0.").replace("& 1.", "& \pm 1."))

\begin{tabular}{lrrrrrrrr}
\toprule
{} & \multicolumn{2}{l}{AUROC} & \multicolumn{2}{l}{AUPR-IN} & \multicolumn{2}{l}{AUPR-OUT} & \multicolumn{2}{l}{FPR95TPR} \\
{} &  mean &  sem &    mean &  sem &     mean &  sem &     mean &  sem \\
Method    &       &      &         &      &          &      &          &      \\
\midrule
Ensemble  & 82.63 & \pm 1.31 &   95.56 & \pm 0.78 &    50.45 & 5.63 &    66.81 & 6.71 \\
Logic     & 69.95 & 2.22 &   94.06 & \pm 0.72 &    61.60 & \pm 1.48 &   100.00 & \pm 0.00 \\
LogicOOD  & 84.54 & 2.07 &   96.12 & \pm 0.70 &    53.21 & 6.71 &    64.77 & 7.65 \\
LogicOOD+ & 97.72 & \pm 0.02 &   99.51 & \pm 0.07 &    98.06 & \pm 0.03 &     4.53 & \pm 0.00 \\
LogicT    & 83.97 & 2.13 &   95.99 & \pm 0.71 &    52.01 & 6.69 &    66.16 & 7.29 \\
MSP       & 80.15 & \pm 0.66 &   94.71 & \pm 0.58 &    44.40 & 4.31 &    74.38 & 4.64 \\
ReAct     & 82.34 & 2.42 &   94.65 & \pm 1.01 &    62.16 & 3.15 &    50.27 & 3.14 \\
ViM       & 99.87 & \pm 0.08 &   99.97 & \pm 0.01 &