In [None]:
import logging
import os
import sys
from os.path import join

import torch
from PIL import Image
from torch.utils.data import Dataset, random_split
from tqdm import tqdm
# Route log records to the notebook's stdout. Re-running this cell must not
# stack up duplicate handlers (every message would then be printed several
# times), so a stdout handler is only attached when none is present yet.
logger = logging.getLogger()
if not any(
    isinstance(handler, logging.StreamHandler) and getattr(handler, "stream", None) is sys.stdout
    for handler in logger.handlers
):
    logger.addHandler(logging.StreamHandler(stream=sys.stdout))
logger.setLevel(logging.INFO)

logger.info("abc")

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Directory passed as `root` to every dataset constructed in this notebook.
root = "/home/ki/datasets/"

from detector import label_to_name, color_to_name

class FruitDataset(Dataset):
    """
    Fruit images annotated with two targets: the fruit class and a color.

    Images are read from ``<root>/fruits/train/train/<class directory>/<file>``.
    The class directory name determines both targets:

    * the label name is the directory name lower-cased with spaces replaced by
      underscores, mapped to its position in ``label_to_name``;
    * the color name is looked up in ``class_color_map`` and mapped to its
      position in ``color_to_name``.

    Class directories and the files inside them are visited in sorted order, so
    the index -> sample mapping is identical on every machine. This matters
    because the notebook splits the dataset with seeded ``random_split`` calls,
    which select samples by index.

    :param root: directory that contains the ``fruits`` folder
    :param transform: optional callable applied to the opened PIL image
    :param target_transform: optional callable applied to the target tensor
    :raises KeyError: if a class directory has no entry in ``class_color_map``
    """

    # Color assigned to every image of a class, keyed by class directory name.
    class_color_map = {
        "Apple Braeburn": "red",
        "Apple Granny Smith": "green",
        "Apricot": "orange",
        "Avocado": "green",
        "Banana": "yellow",
        "Blueberry": "black",
        "Cactus fruit": "green",
        "Cantaloupe": "yellow",
        "Cherry": "red",
        "Clementine": "orange",
        "Corn": "yellow",
        "Cucumber Ripe": "brown",
        "Grape Blue": "black",
        "Kiwi": "brown",
        "Lemon": "yellow",
        "Limes": "green",
        "Mango": "green",
        "Onion White": "brown",
        "Orange": "orange",
        "Papaya": "green",
        "Passion Fruit": "black",
        "Peach": "orange",
        "Pear": "green", # ??
        "Pepper Green": "green",
        "Pepper Red": "red",
        "Pineapple": "brown",
        "Plum": "red",
        "Pomegranate": "red",
        "Potato Red": "brown",
        "Raspberry": "red",
        "Strawberry": "red",
        "Tomato": "red",
        "Watermelon": "red"
    }

    def __init__(self, root="train", transform=None, target_transform=None):
        root = join(root, "fruits", "train", "train")

        # os.listdir returns entries in arbitrary, file-system dependent order.
        # Sorting pins the sample order; filtering to directories keeps stray
        # files next to the class folders from crashing the scan below.
        self.classes = sorted(
            entry for entry in os.listdir(root) if os.path.isdir(join(root, entry))
        )
        self.files = []
        self.labels = []
        self.colors = []

        self.transform = transform
        self.target_transform = target_transform

        for c in self.classes:
            fs = sorted(join(root, c, f) for f in os.listdir(join(root, c)))
            self.files += fs
            self.labels += [c.lower().replace(" ", "_")] * len(fs)
            self.colors += [self.class_color_map[c]] * len(fs)

        # name -> integer index, following the order of the detector module's lists
        self.class_map = {c: n for n, c in enumerate(label_to_name)}
        self.color_map = {c: n for n, c in enumerate(color_to_name)}

    def __len__(self):
        """Number of image files found below the dataset root."""
        return len(self.files)

    def __getitem__(self, index):
        """
        Load one sample.

        :param index: position of the sample in ``self.files``
        :return: tuple ``(img, y)``; ``y`` is ``torch.tensor([label_index, color_index])``
            unless ``target_transform`` replaced it
        """
        path = self.files[index]
        label = self.class_map[self.labels[index]]
        color = self.color_map[self.colors[index]]

        img = Image.open(path)

        if self.transform is not None:
            img = self.transform(img)

        y = torch.tensor([label, color])
        if self.target_transform is not None:
            y = self.target_transform(y)

        return img, y


In [None]:
# Build the dataset once without transforms (construction only lists the image
# files). NOTE(review): `ds` is not referenced by any later cell shown here.
ds = FruitDataset(root=root)

In [None]:
from pytorch_ood.utils import ToRGB
from torchvision.transforms import ToTensor, Resize, Compose
import torch 
from torch.utils.data import DataLoader
import numpy as np


# Preprocessing shared by every dataset in this notebook: ToRGB (from
# pytorch_ood; presumably converts to 3-channel RGB), tensor conversion, then a
# resize to 32x32.
trans = Compose([ToRGB(), ToTensor(), Resize((32, 32), antialias=True)])

# Fixed-seed split into train / validation / test subsets. The hard-coded sizes
# sum to 16854, i.e. they assume the dataset contains exactly that many images.
data = FruitDataset(root=root, transform=trans)
train_data, val_data, test_data = random_split(data, [14000,1000, 1854], generator=torch.Generator().manual_seed(0))

# Only the training loader shuffles; `val_data` gets no loader in this cell.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [None]:
from torch import nn
from pytorch_ood.model import WideResNet

# def override 
def Model(num_classes=None, *args, **kwargs):
    """
    Create a WideResNet initialised with the "imagenet32" weights and a new head.

    The backbone is always constructed with 1000 output classes so that the
    pretrained weights fit; its final ``fc`` layer is then replaced by a freshly
    initialised linear layer with ``num_classes`` outputs.

    :param num_classes: number of outputs of the replacement ``fc`` layer
    :param args: extra positional arguments forwarded to ``WideResNet``
    :param kwargs: extra keyword arguments forwarded to ``WideResNet``
    :return: the WideResNet instance with the swapped head
    """
    backbone = WideResNet(*args, num_classes=1000, pretrained="imagenet32", **kwargs)
    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, num_classes)
    return backbone

In [None]:
from torch.optim import SGD


def train_model(att_index, num_classes, epochs=5, lr=0.001):
    """
    Train a classifier for one attribute of the fruit targets.

    Batches come from the notebook-level ``train_loader``; after every epoch the
    accuracy on the notebook-level ``test_loader`` is printed. Targets are
    expected as a 2-D tensor per batch, from which column ``att_index`` is used
    as the class label.

    :param att_index: column of the target tensor to learn
        (the notebook uses 0 for the fruit label and 1 for the color)
    :param num_classes: number of output classes of the model
    :param epochs: number of passes over ``train_loader`` (default 5)
    :param lr: SGD learning rate (default 0.001)
    :return: the trained model, in eval mode
    """
    model = Model(num_classes=num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)

    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        bar = tqdm(train_loader)
        for inputs, y in bar:
            labels = y[:, att_index]
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # exponential moving average, only used for the progress bar
            running_loss = 0.8 * running_loss + 0.2 * loss.item()
            bar.set_postfix({"loss": running_loss})

        correct = 0
        total = 0

        model.eval()
        with torch.no_grad():
            for inputs, y in test_loader:
                labels = y[:, att_index]
                inputs, labels = inputs.to(device), labels.to(device)

                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # an empty test loader must not abort training with a ZeroDivisionError
        accuracy = correct / total if total else float("nan")
        print(f'Accuracy of the network on the test images: {accuracy:.2%}')

    # guarantee eval mode even when epochs == 0
    model.eval()
    return model

In [None]:
from pytorch_ood.dataset.img import TinyImages300k
from pytorch_ood.utils import is_known

def train_fruit_model(epochs=10):
    """
    Train a binary classifier that separates fruit images from TinyImages300k.

    Fruit samples keep their class index as target (``int(y[0])``), TinyImages
    targets are passed through ``ToUnknown()``. The binary training label of
    every sample is ``is_known(y).long()``. After each epoch the accuracy on the
    held-out mixture of fruit test images and TinyImages is printed.

    NOTE: ``ToUnknown`` is imported in a later notebook cell; that cell has to
    be executed before this function is called.

    :param epochs: number of passes over the training mixture (default 10)
    :return: the trained two-class model, in eval mode
    """
    tiny = TinyImages300k(root=root, download=True, transform=trans, target_transform=ToUnknown())
    data_train_out, data_test_out, _ = random_split(tiny, [50000, 10000, 240000], generator=torch.Generator().manual_seed(123))

    # Same sizes and seed as the other fruit splits in this notebook, so the
    # train/test partition of the fruit images is identical.
    data_noatt = FruitDataset(root=root, transform=trans, target_transform=lambda y: int(y[0]))
    train_data_noatt, _, test_data_noatt = random_split(data_noatt, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))

    new_loader = DataLoader(train_data_noatt + data_train_out, batch_size=32, shuffle=True, num_workers=10)
    new_test_loader = DataLoader(test_data_noatt + data_test_out, batch_size=32, shuffle=False, num_workers=10)

    model = Model(num_classes=2).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)

    for epoch in range(epochs):
        running_loss = 0.0
        model.train()

        bar = tqdm(new_loader)
        for inputs, y in bar:
            labels = is_known(y).long()
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # exponential moving average, only used for the progress bar
            running_loss = 0.8 * running_loss + 0.2 * loss.item()
            bar.set_postfix({"loss": running_loss})

        correct = 0
        total = 0

        model.eval()
        with torch.no_grad():
            for inputs, y in new_test_loader:
                labels = is_known(y).long()
                inputs, labels = inputs.to(device), labels.to(device)

                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total if total else float("nan")
        print(f'Accuracy of the fruit network on the test images: {accuracy:.2%}')

    # guarantee eval mode even when epochs == 0
    model.eval()
    return model

In [None]:
from pytorch_ood.dataset.img import (LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize)
from pytorch_ood.detector import EnergyBased, MaxSoftmax, MaxLogit, Entropy, Mahalanobis, ViM, ReAct
from pytorch_ood.utils import OODMetrics, ToUnknown
from detector import EnsembleDetector, PrologOOD, Prologic, PrologOODT

def evaluate(label_net, color_net, fruit_net):
    """
    Evaluate all OOD detectors on the fruit test split against several OOD datasets.

    For every detector and every OOD dataset the fruit test split is
    concatenated with the OOD dataset (targets mapped by ``ToUnknown()``), the
    detector scores are collected and ``OODMetrics`` are computed.

    Detectors that need fitting use these splits (same sizes and seed as the
    training cells): ViM and Mahalanobis are fitted on the fruit training split,
    the LogicOODT variants on the fruit validation split.

    :param label_net: classifier for the fruit label
    :param color_net: classifier for the color
    :param fruit_net: binary fruit-vs-other classifier
    :return: list of dicts, one per (detector, OOD dataset) pair, holding the
        computed metrics plus the keys ``"Method"`` and ``"Dataset"``
    """
    _ = label_net.eval()
    _ = color_net.eval()
    # The fruit network is used by the "+" detectors as well, so it has to be
    # in eval mode just like the other two.
    if fruit_net is not None:
        _ = fruit_net.eval()

    results = []

    detectors = {
        "ViM": ViM(label_net.features, w=label_net.fc.weight, b=label_net.fc.bias, d=64),
        "Mahalanobis": Mahalanobis(label_net.features),
        "Entropy": Entropy(label_net),
        "LogicOOD+": PrologOOD("kb.pl", label_net, color_net, fruit_net),
        "Logic": Prologic("kb.pl", label_net, color_net),
        "Logic+": Prologic("kb.pl", label_net, color_net, fruit_net),
        "LogicOOD": PrologOOD("kb.pl", label_net, color_net),
        "LogicOODT": PrologOODT("kb.pl", label_net, color_net),
        "LogicOODT+": PrologOODT("kb.pl", label_net, color_net, fruit_net),
        # "LogicT+": PrologOODT("kb.pl", label_net, color_net, fruit_net), # this should be exactly the same
        "Ensemble": EnsembleDetector(label_net, color_net),
        "MSP": MaxSoftmax(label_net),
        "ReAct": ReAct(label_net.features, label_net.fc),
        "Energy": EnergyBased(label_net),
        "MaxLogit": MaxLogit(label_net),
    }

    # Validation splits used to fit the LogicOODT detectors, one with the label
    # and one with the color as target. y[0] / y[1] are already tensors, so
    # they are returned directly instead of being wrapped in torch.tensor(),
    # which would emit a copy-construct warning for every sample.
    data_fit_label = FruitDataset(root=root, transform=trans, target_transform=lambda y: y[0])
    _, data_fit_label, _ = random_split(data_fit_label, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))
    data_fit_color = FruitDataset(root=root, transform=trans, target_transform=lambda y: y[1])
    _, data_fit_color, _ = random_split(data_fit_color, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))
    data_fit_color = DataLoader(data_fit_color, batch_size=32, shuffle=False, num_workers=2)
    data_fit_label = DataLoader(data_fit_label, batch_size=32, shuffle=False, num_workers=2)

    # In-distribution data with the plain class index as target.
    data = FruitDataset(root=root, transform=trans, target_transform=lambda y: int(y[0]))
    data_in_train, _, data_in = random_split(data, [14000, 1000, 1854], generator=torch.Generator().manual_seed(0))
    train_in_loader = DataLoader(data_in_train, batch_size=32, shuffle=False, num_workers=2)

    detectors["ViM"].fit(train_in_loader, device=device)
    detectors["LogicOODT"].fit(data_fit_label, data_fit_color, device=device)
    detectors["LogicOODT+"].fit(data_fit_label, data_fit_color, device=device)
    detectors["Mahalanobis"].fit(train_in_loader, device=device)

    # The evaluation data does not depend on the detector: build each mixed
    # ID + OOD loader once instead of once per detector.
    loaders = {}
    for dataset_c in (LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize):
        data_out = dataset_c(root=root, transform=trans, target_transform=ToUnknown(), download=True)
        loaders[dataset_c.__name__] = DataLoader(data_in + data_out, batch_size=256, shuffle=False, num_workers=12)

    for detector_name, detector in detectors.items():
        for data_name, loader in loaders.items():
            print(data_name)

            batch_scores = []
            batch_targets = []

            with torch.no_grad():
                for x, y in loader:
                    batch_scores.append(detector(x.to(device)).cpu())
                    # targets are only consumed by the metrics on the CPU
                    batch_targets.append(y)

            scores = torch.cat(batch_scores, dim=0)
            ys = torch.cat(batch_targets, dim=0).cpu()

            metrics = OODMetrics()
            metrics.update(scores, ys)
            r = metrics.compute()
            r.update({
                "Method": detector_name,
                "Dataset": data_name
            })
            print(r)
            results.append(r)

    return results

In [None]:
# One entry per (trial, detector, OOD dataset); filled by the loop below.
results = []

for trial in range(10):
    # Every result row is tagged with "Seed": trial, so the trial index is
    # actually used as the seed. This makes a single trial reproducible and
    # the trials differ from each other in a controlled way.
    torch.manual_seed(trial)

    print("label")
    label_net = train_model(att_index=0, num_classes=33)
    print("color")
    color_net = train_model(att_index=1, num_classes=6)
    print("fruit")

    fruit_net = train_fruit_model()

    res = evaluate(label_net, color_net, fruit_net)

    for r in res:
        r.update({"Seed": trial})

    results += res

In [None]:
import pandas as pd
# One row per (trial, detector, OOD dataset): the metric values returned by
# OODMetrics.compute() plus the "Method", "Dataset" and "Seed" columns.
result_df = pd.DataFrame(results)
# print((result_df.groupby(by="Method").agg(["mean", "sem"]) * 100)[["AUROC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]].to_latex(float_format="%.2f"))

In [None]:
# s = (result_df.groupby(by="Method").agg(["mean", "sem"]) * 100)[["AUROC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]].to_latex(float_format="%.2f")

In [None]:
# Row order of the final table.
order = ['MSP', 'Energy', 'MaxLogit', 'Entropy', 'ReAct', 'Mahalanobis', 'ViM', 'Ensemble', 'Logic', 'Logic+', 'LogicOOD', 'LogicOOD+', 'LogicOODT', 'LogicOODT+']

# Average over the OOD datasets within each (method, trial) first, then report
# mean and standard error across trials, in percent. numeric_only=True keeps
# the string-valued "Dataset" column out of the mean; without it pandas >= 2.0
# raises a TypeError.
per_trial = result_df.groupby(by=["Method", "Seed"]).mean(numeric_only=True) * 100
summary = per_trial.groupby("Method").agg(["mean", "sem"]).reindex(order)
latex = summary.to_latex(float_format="%.2f")

# Prefix every cell that starts with 0.-5. with \pm (this is how the standard
# error columns are marked). Raw strings are required: "\p" is an invalid
# escape sequence in a normal string literal.
for digit in "023145":
    latex = latex.replace(f"& {digit}.", rf"& $\pm$ {digit}.")

print(latex)


# print(s.replace("& 0.", "& \pm 0.").replace("& 1.", "& \pm 1.").replace("& 2.", "& \pm 2.").replace("& 4.", "& \pm 4."))