In [1]:
import os 
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-GPU-bb1ccb6e-2bc9-c7a1-b25d-3eef9033e192/6/0"

In [2]:
"""
This requires torchtext
"""
import torch
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from pytorch_ood.dataset.txt import NewsGroup20, Reuters52, WMT16Sentences, Multi30k
from pytorch_ood.model.gruclf import GRUClassifier
from pytorch_ood.utils import ToUnknown, OODMetrics
from pytorch_ood.detector import MaxSoftmax, EnergyBased

torch.manual_seed(123)

n_epochs = 10
lr = 0.001

In [3]:
# download datasets
train_dataset = NewsGroup20("data/", train=True, download=True)
test_dataset = NewsGroup20("data/", train=False, download=True)

In [4]:
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_dataset)) # , max_tokens=10000# , specials=["<unk>"]
# vocab.set_default_index(0)

11293lines [00:00, 15999.73lines/s]


In [5]:
def prep(x):
    return torch.tensor([vocab[v] for v in tokenizer(x)], dtype=torch.int64)

train_dataset = NewsGroup20("data/", train=True, transform=prep)
test_dataset = NewsGroup20("data/", train=False, transform=prep)

In [6]:
# add padding, etc.
def collate_batch(batch):
    texts = [i[0] for i in batch]
    labels = torch.tensor([i[1] for i in batch],  dtype=torch.int64)
    t_lengths = torch.tensor([len(t) for t in texts])
    max_t_length = torch.max(t_lengths)

    padded = []
    for text in texts:
        t = torch.cat([torch.zeros(max_t_length-len(text), dtype=torch.long), text])
        padded.append(t)
    return torch.stack(padded,dim=0), labels

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=True, collate_fn=collate_batch)


In [7]:
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))
model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

In [8]:
import torch.nn.functional as F

model.train()
model.cuda()

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    loss_ema = 0
    correct = 0
    total = 0

    model.train()
    for n, batch in enumerate(train_loader):
        inputs, labels = batch

        inputs = inputs.cuda()
        labels = labels.cuda()
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_ema = loss_ema * 0.9 + loss.data.cpu().numpy() * 0.1

        pred = logits.max(dim=1).indices
        correct += pred.eq(labels).sum().data.cpu().numpy()
        total += pred.shape[0]

        if n % 10 == 0:
            print(f"Loss: {loss_ema.item():.2f} Accuracy {correct/total:.2%}")

    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for n, batch in enumerate(test_loader):
            inputs, labels = batch

            inputs = inputs.cuda()
            labels = labels.cuda()
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().data.cpu().numpy()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct/total:.2%}")

Epoch 0
Loss: 0.30 Accuracy 12.50%
Loss: 2.05 Accuracy 5.40%
Loss: 2.67 Accuracy 5.80%
Loss: 2.86 Accuracy 6.65%
Loss: 2.93 Accuracy 7.39%
Loss: 2.96 Accuracy 8.21%
Loss: 2.96 Accuracy 8.76%
Loss: 2.94 Accuracy 9.24%
Loss: 2.95 Accuracy 9.34%
Loss: 2.93 Accuracy 9.65%
Loss: 2.89 Accuracy 9.93%
Loss: 2.88 Accuracy 10.28%
Loss: 2.90 Accuracy 10.15%
Loss: 2.88 Accuracy 10.47%
Loss: 2.88 Accuracy 10.73%
Loss: 2.88 Accuracy 10.78%
Loss: 2.88 Accuracy 10.85%
Loss: 2.86 Accuracy 10.89%
Loss: 2.86 Accuracy 11.08%
Loss: 2.83 Accuracy 11.65%
Loss: 2.81 Accuracy 11.94%
Loss: 2.77 Accuracy 12.20%
Loss: 2.78 Accuracy 12.34%
Loss: 2.73 Accuracy 12.70%
Loss: 2.73 Accuracy 12.97%
Loss: 2.67 Accuracy 13.22%
Loss: 2.67 Accuracy 13.47%
Loss: 2.65 Accuracy 13.63%
Loss: 2.69 Accuracy 13.69%
Loss: 2.59 Accuracy 14.01%
Loss: 2.56 Accuracy 14.34%
Loss: 2.62 Accuracy 14.52%
Loss: 2.56 Accuracy 14.80%
Loss: 2.58 Accuracy 14.96%
Loss: 2.57 Accuracy 15.19%
Loss: 2.55 Accuracy 15.40%
Test Accuracy: 19.63%
Epoch 1


In [21]:
def test(model, dataset, dataset_name):
    test_loader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=collate_batch)
    metrics = OODMetrics()
    metrics_energy = OODMetrics()
    softmax = MaxSoftmax(model)
    energy = EnergyBased(model)
    model.eval()

    with torch.no_grad():
        for n, batch in enumerate(test_loader):
            inputs, labels = batch

            inputs = inputs.cuda()
            labels = labels.cuda()
            metrics.update(softmax(inputs), labels)
            metrics_energy.update(energy(inputs), labels)

    d1 = metrics.compute()
    d1.update({"Method": "Softmax", "Dataset": dataset_name})
    
    d2 = metrics_energy.compute()
    d2.update({"Method": "Energy", "Dataset": dataset_name})
    return [d1, d2]

res = []

ood_dataset = Reuters52("data/", train=False, download=True, transform=prep, target_transform=ToUnknown())
res+= test(model, test_dataset + ood_dataset, dataset_name="Reuters52")

ood_dataset = Multi30k("data/", train=False, download=True, transform=prep, target_transform=ToUnknown())
res+= test(model, test_dataset + ood_dataset, dataset_name="Multi30k")

ood_dataset = WMT16Sentences("data/", download=True, transform=prep, target_transform=ToUnknown())
res+= test(model, test_dataset + ood_dataset, dataset_name="WMT16Sentences")

In [26]:
import pandas as pd 
df = pd.DataFrame(res)
print((df.groupby("Method").mean() * 100).to_latex(float_format="%.2f"))

\begin{tabular}{lrrrrr}
\toprule
{} &  AUROC &  AUPR-IN &  AUPR-OUT &  ACC95TPR &  FPR95TPR \\
Method  &        &          &           &           &           \\
\midrule
Energy  &  89.35 &    69.95 &     94.56 &     75.59 &     31.93 \\
Softmax &  82.06 &    62.08 &     88.74 &     63.77 &     52.02 \\
\bottomrule
\end{tabular}

