In [1]:
import os 
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-GPU-8ab9a0c8-909c-3f13-97e6-7376d6d4a029/0/0"

In [2]:
"""
This requires torchtext
"""
import torch
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from pytorch_ood.dataset.txt import NewsGroup20, Reuters52, WMT16Sentences, Multi30k, WikiText2
from pytorch_ood.model.gruclf import GRUClassifier
from pytorch_ood.utils import ToUnknown, OODMetrics
from pytorch_ood.detector import MaxSoftmax, EnergyBased
import torch.nn.functional as F
from pytorch_ood.loss import OutlierExposureLoss

torch.manual_seed(123)

n_epochs = 10
lr = 0.001

In [3]:
# download datasets
train_dataset = NewsGroup20("data/", train=True, download=True)
test_dataset = NewsGroup20("data/", train=False, download=True)

In [4]:
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_dataset)) # , max_tokens=10000# , specials=["<unk>"]
# vocab.set_default_index(0)

11293lines [00:00, 15757.50lines/s]


In [5]:
# add padding, etc.
def collate_batch(batch):
    texts = [i[0] for i in batch]
    labels = torch.tensor([i[1] for i in batch],  dtype=torch.int64)
    t_lengths = torch.tensor([len(t) for t in texts])
    max_t_length = torch.max(t_lengths)

    padded = []
    for text in texts:
        t = torch.cat([torch.zeros(max_t_length-len(text), dtype=torch.long), text])
        padded.append(t)
    return torch.stack(padded,dim=0), labels

In [6]:
def test(model, dataset, dataset_name):
    test_loader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=collate_batch)
    metrics = OODMetrics()
    metrics_energy = OODMetrics()
    softmax = MaxSoftmax(model)
    energy = EnergyBased(model)
    model.eval()

    with torch.no_grad():
        for n, batch in enumerate(test_loader):
            inputs, labels = batch

            inputs = inputs.cuda()
            labels = labels.cuda()
            metrics.update(softmax(inputs), labels)
            metrics_energy.update(energy(inputs), labels)

    d1 = metrics.compute()
    d1.update({"Method": "Softmax", "Dataset": dataset_name})
    
    d2 = metrics_energy.compute()
    d2.update({"Method": "Energy", "Dataset": dataset_name})
    return [d1, d2]

In [7]:
import pandas as pd 


model = GRUClassifier(num_classes=20, n_vocab=len(vocab))
model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

def prep(x):
    return torch.tensor([vocab[v] for v in tokenizer(x)], dtype=torch.int64)

train_in_dataset = NewsGroup20("data/", train=True, transform=prep)
test_dataset = NewsGroup20("data/", train=False, transform=prep)

train_out_dataset = WikiText2("data/", split="train", transform=prep, target_transform=ToUnknown()) 

train_loader = DataLoader(train_in_dataset + train_out_dataset, batch_size=32, shuffle=True, 
                          collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=True, collate_fn=collate_batch)


model.train()
model.cuda()

criterion = OutlierExposureLoss(alpha=0.5)


for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    loss_ema = 0
    correct = 0
    total = 0

    model.train()
    for n, batch in enumerate(train_loader):
        inputs, labels = batch

        inputs = inputs.cuda()
        labels = labels.cuda()
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_ema = loss_ema * 0.9 + loss.data.cpu().numpy() * 0.1

        pred = logits.max(dim=1).indices
        correct += pred[labels >= 0].eq(labels[labels >= 0]).sum().data.cpu().numpy()
        total += pred[labels >= 0].shape[0]

        if n % 10 == 0:
            print(f"Loss: {loss_ema.item():.2f} Accuracy {correct/total:.2%}")

    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for n, batch in enumerate(test_loader):
            inputs, labels = batch

            inputs = inputs.cuda()
            labels = labels.cuda()
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().data.cpu().numpy()
            total += pred.shape[0]

        print(f"Loss {loss_ema:.2f} Test Accuracy: {correct/total:.2%}")


Epoch 0
Loss: 0.19 Accuracy 0.00%
Loss: 1.26 Accuracy 1.14%
Loss: 1.64 Accuracy 1.79%
Loss: 1.75 Accuracy 1.92%
Loss: 1.82 Accuracy 1.91%
Loss: 1.79 Accuracy 2.02%
Loss: 1.84 Accuracy 1.90%
Loss: 1.82 Accuracy 1.94%
Loss: 1.84 Accuracy 2.01%
Loss: 1.85 Accuracy 1.96%
Loss: 1.83 Accuracy 1.98%
Loss: 1.89 Accuracy 2.08%
Loss: 1.83 Accuracy 2.04%
Loss: 1.84 Accuracy 1.98%
Loss: 1.84 Accuracy 2.02%
Loss: 1.83 Accuracy 1.95%
Loss: 1.85 Accuracy 1.98%
Loss: 1.83 Accuracy 1.90%
Loss: 1.83 Accuracy 1.95%
Loss: 1.88 Accuracy 1.91%
Loss: 1.84 Accuracy 1.97%
Loss: 1.81 Accuracy 1.95%
Loss: 1.82 Accuracy 1.98%
Loss: 1.80 Accuracy 1.95%
Loss: 1.81 Accuracy 1.95%
Loss: 1.84 Accuracy 1.98%
Loss: 1.82 Accuracy 2.11%
Loss: 1.82 Accuracy 2.13%
Loss: 1.81 Accuracy 2.16%
Loss: 1.86 Accuracy 2.18%
Loss: 1.86 Accuracy 2.22%
Loss: 1.89 Accuracy 2.27%
Loss: 1.84 Accuracy 2.28%
Loss: 1.85 Accuracy 2.31%
Loss: 1.86 Accuracy 2.36%
Loss: 1.81 Accuracy 2.39%
Loss: 1.80 Accuracy 2.38%
Loss: 1.82 Accuracy 2.40%
Loss

In [9]:
res = []

ood_dataset = Reuters52("data/", train=False, download=True, transform=prep, target_transform=ToUnknown())
res+= test(model, test_dataset + ood_dataset, dataset_name="Reuters52")

ood_dataset = Multi30k("data/", train=False, download=True, transform=prep, target_transform=ToUnknown())
res+= test(model, test_dataset + ood_dataset, dataset_name="Multi30k")

ood_dataset = WMT16Sentences("data/", download=True, transform=prep, target_transform=ToUnknown())
res+= test(model, test_dataset + ood_dataset, dataset_name="WMT16Sentences")

df = pd.DataFrame(res)
# df["Temperature"] = temp
# dfs.append(df)

NameError: name 'dfs' is not defined

In [10]:
import pandas as pd 
df = pd.DataFrame(res)
print((df.groupby("Method").mean() * 100).to_latex(float_format="%.2f"))

\begin{tabular}{lrrrrr}
\toprule
{} &  AUROC &  AUPR-IN &  AUPR-OUT &  ACC95TPR &  FPR95TPR \\
Method  &        &          &           &           &           \\
\midrule
Energy  &  94.41 &    86.23 &     97.87 &     85.59 &     16.85 \\
Softmax &  93.84 &    86.34 &     97.56 &     83.59 &     19.55 \\
\bottomrule
\end{tabular}



In [None]:
#!ls /data_slow/kirchheim/gan_oe/text-generation/work_language_model/