In [None]:
# Enable IPython's autoreload extension so edits to imported project modules
# (e.g. `utils.display_results`) take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Standard library
import argparse
import csv
import math
import os
import pickle
import random
import re
import sys
import time
from bisect import bisect_left

# Third-party
import numpy as np
import pandas as pd
import spacy
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchtext
import torchvision.datasets as dset
import torchvision.transforms as trn
import tqdm
from torch.autograd import Variable as V
from torchtext import data
from torchtext import datasets
from tqdm import tqdm_notebook

# Project-local
from utils.display_results import get_performance

# Raise the csv module's per-field size limit to sys.maxsize.
csv.field_size_limit(sys.maxsize)

In [None]:
# Seed every RNG the notebook can touch, not only NumPy's, so a
# Restart & Run All reproduces the same numbers.
# NOTE(review): legacy torchtext iterators are understood to shuffle using
# Python's `random` state captured when the iterator is built - confirm for the
# installed torchtext version. Seeding must therefore happen before the
# iterators are created in the cells below.
SEED = 1
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Notebook stand-in for command-line arguments.
args = argparse.Namespace(
    batch_size=64,              # batch size for every iterator
    in_dist_dataset='sst',      # path component: in-distribution dataset
    method='OECC',              # path component: training method
    save='results',             # root directory for written results
    load='results',             # root directory the checkpoint is read from
    oe_dataset='wikitext2',     # path component: outlier-exposure dataset
)

# This notebook only evaluates, so autograd is switched off globally.
torch.set_grad_enabled(False)
cudnn.benchmark = True  # fire on all cylinders

In [None]:
# ============================ SST ============================ #
# In-distribution data: SST with 'neutral' examples filtered out.
# set up fields
TEXT_sst = data.Field(pad_first=True)
LABEL_sst = data.Field(sequential=False)

# make splits for data
train_sst, val_sst, test_sst = datasets.SST.splits(
    TEXT_sst, LABEL_sst, fine_grained=False, train_subtrees=False,
    filter_pred=lambda ex: ex.label != 'neutral')

# build vocab
# The vocabularies are built from the training split only; TEXT_sst's size
# also determines the embedding table size of ClfGRU further down.
TEXT_sst.build_vocab(train_sst, max_size=10000)
LABEL_sst.build_vocab(train_sst, max_size=10000)
print('vocab length for SST(including special tokens):', len(TEXT_sst.vocab))
# NOTE(review): the model is later built with `num_classes-1`, so this count
# presumably includes one non-class entry of the label vocab - confirm.
num_classes = len(LABEL_sst.vocab)
print('num labels:', len(LABEL_sst.vocab))
# create our own iterator, avoiding the calls to build_vocab in SST.iters
train_iter_sst, val_iter_sst, test_iter_sst = data.BucketIterator.splits(
    (train_sst, val_sst, test_sst), batch_size=args.batch_size, repeat=False)


# Target number of OOD examples per trial: one fifth (floored) of the
# in-distribution test set. get_scores uses it as its early-stop threshold.
ood_num_examples = len(test_iter_sst.dataset) // 5
# Fraction of OOD examples in the combined OOD + in-distribution pool;
# passed to get_performance.
expected_ap = ood_num_examples / (ood_num_examples + len(test_iter_sst.dataset))
# Recall level passed to get_performance and used in the "FPR<level>" label.
recall_level = 0.9

In [4]:
class ClfGRU(nn.Module):
    """Two-layer unidirectional GRU text classifier.

    Token ids are embedded into 50 dimensions, passed through a 2-layer GRU
    with hidden size 128, and the final hidden state of the top layer is
    projected to `num_classes` logits.

    Args:
        num_classes: number of output logits.
        vocab_size: number of rows in the embedding table. Defaults to
            `len(TEXT_sst.vocab)` (the notebook-level SST vocabulary), which
            preserves the original behaviour; pass it explicitly to build the
            model without that global.
    """

    def __init__(self, num_classes, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            # Fall back to the SST vocabulary defined earlier in the notebook.
            vocab_size = len(TEXT_sst.vocab)
        # padding_idx=1 - presumably torchtext's default `<pad>` index; confirm.
        self.embedding = nn.Embedding(vocab_size, 50, padding_idx=1)
        self.gru = nn.GRU(input_size=50, hidden_size=128, num_layers=2, bias=True, batch_first=True, bidirectional=False)
        self.linear = nn.Linear(128, num_classes)
        self.num_classes = num_classes

    def forward(self, x):
        """Return logits for a batch of token ids.

        Args:
            x: integer tensor of token ids; the GRU is `batch_first`, so the
                batch dimension comes first.

        Returns:
            Tensor of logits with `num_classes` entries per example.
        """
        embeds = self.embedding(x)
        # nn.GRU returns (output, h_n); only the final hidden states are used.
        _, h_n = self.gru(embeds)
        # With 2 unidirectional layers, the last entry is the 2nd (top) layer.
        hidden = h_n[-1]
        logits = self.linear(hidden)
        return logits



# Rebuild the classifier and restore the fine-tuned weights.
# NOTE(review): `num_classes-1` presumably discounts a non-class entry in the
# label vocabulary - confirm against the training script.
model = ClfGRU(num_classes-1)
checkpoint_path = f'./{args.load}/{args.in_dist_dataset}/{args.method}/{args.oe_dataset}/model_finetune.dict'  # change location as per our method
# map_location='cpu': scoring runs on CPU (get_scores calls model.cpu()), and
# without it a checkpoint saved from a GPU cannot be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
print('\nLoaded model.\n')


Loaded model.



### Use 20 Newsgroups and TREC as validation OOD data

In [None]:
# ============================ 20 Newsgroups ============================ #
# OOD validation data read from local CSV files (paths are relative to the
# notebook's working directory). Text is lower-cased and fixed to length 100.
# NOTE(review): this field builds its own vocabulary below, so its token
# indices are not aligned with TEXT_sst's vocabulary, which sized the model's
# embedding - confirm this is intended for OOD scoring.
TEXT_20ng = data.Field(pad_first=True, lower=True, fix_length=100)
LABEL_20ng = data.Field(sequential=False)

train_20ng = data.TabularDataset(path='20_newsgroup_train.csv',
                                 format='csv',
                                 fields=[('text', TEXT_20ng), ('label', LABEL_20ng)])

test_20ng = data.TabularDataset(path='20_newsgroup_test.csv',
                                 format='csv',
                                 fields=[('text', TEXT_20ng), ('label', LABEL_20ng)])

# Vocabularies come from the training split only.
TEXT_20ng.build_vocab(train_20ng, max_size=10000)
LABEL_20ng.build_vocab(train_20ng, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_20ng.vocab))
#num_classes = len(LABEL_20ng.vocab)
print('num labels:', len(LABEL_20ng.vocab))
# The evaluation cell scores `train_iter_20ng`; `test_iter_20ng` is not used
# in the visible part of the notebook.
train_iter_20ng = data.BucketIterator(train_20ng, batch_size=args.batch_size, repeat=False)
test_iter_20ng = data.BucketIterator(test_20ng, batch_size=args.batch_size, repeat=False)


# ============================ TREC ============================ #
# OOD validation data: TREC with fine-grained labels, text lower-cased.
# NOTE(review): as with the other OOD sets, the vocabulary is built from TREC
# itself rather than shared with TEXT_sst - confirm this is intended.
# set up fields
TEXT_trec = data.Field(pad_first=True, lower=True)
LABEL_trec = data.Field(sequential=False)

# make splits for data
train_trec, test_trec = datasets.TREC.splits(TEXT_trec, LABEL_trec, fine_grained=True)


# build vocab
TEXT_trec.build_vocab(train_trec, max_size=10000)
LABEL_trec.build_vocab(train_trec, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_trec.vocab))
#num_classes = len(LABEL_trec.vocab)
print('num labels:', len(LABEL_trec.vocab))

# make iterators
# The evaluation cell scores `train_iter_trec`; `test_iter_trec` is not used
# in the visible part of the notebook.
train_iter_trec, test_iter_trec = data.BucketIterator.splits(
    (train_trec, test_trec), batch_size=args.batch_size, repeat=False)

In [None]:
def get_scores(dataset_iterator, ood=False, translation_dataset = False, snli=False):
    """Score every example of an iterator with the negative max softmax probability.

    Uses the notebook-level `model`, `args` and `ood_num_examples`. The model
    is put in eval mode and moved to CPU before scoring.

    Args:
        dataset_iterator: iterable of batches exposing a `text`, `src` or
            `hypothesis` attribute (see the flags below).
        ood: when True, stop once `batch_idx * args.batch_size` exceeds
            `ood_num_examples`, so only about that many examples are scored.
        translation_dataset: read token ids from `batch.src` instead of
            `batch.text`.
        snli: read token ids from `batch.hypothesis`; takes precedence over
            `translation_dataset`.

    Returns:
        A flat list with one score per scored example, each equal to
        `-max(softmax(logits))`.
    """
    model.eval()
    model.cpu()

    # Decide once which batch attribute carries the token ids.
    if snli:
        field_name = 'hypothesis'
    elif translation_dataset:
        field_name = 'src'
    else:
        field_name = 'text'

    collected = []
    for step, batch in enumerate(iter(dataset_iterator)):
        # For OOD data, stop after roughly `ood_num_examples` examples.
        if ood and step * args.batch_size > ood_num_examples:
            break

        # Transpose the batch before feeding the batch_first model.
        token_ids = getattr(batch, field_name).t()
        logits = model(token_ids)

        # Shift by the row-wise max before the softmax, then negate the
        # maximum probability of each row.
        shifted = logits - torch.max(logits, dim=1, keepdim=True)[0]
        probs = F.softmax(shifted, dim=1)
        neg_msp = -1 * torch.max(probs, dim=1)[0]

        collected.extend(list(neg_msp.data.numpy()))

    return collected



# ============================ OECC ============================ #

# In-distribution scores are computed once and reused for every OOD comparison.
test_scores = get_scores(test_iter_sst)

titles = ['20 Newsgroup', 'TREC']

iterators = [train_iter_20ng, train_iter_trec]

# Number of repeated OOD scoring passes per dataset; the mean and standard
# deviation over these passes are reported.
num_trials = 10

mean_fprs = []
mean_aurocs = []
mean_auprs = []

results_dir = f'./{args.save}/{args.in_dist_dataset}/{args.method}/{args.oe_dataset}'
# Make sure the output directory exists so opening the results file cannot
# fail on a missing path.
os.makedirs(results_dir, exist_ok=True)


def _report(line, out_file):
    """Print `line` and write it to `out_file` preceded by a newline."""
    print(line)
    out_file.write(f'\n{line}')


# The context manager closes the results file even if an evaluation step raises.
with open(f'{results_dir}/OECC_eval_results.txt', 'w') as f:
    for title, iterator in zip(titles, iterators):
        # Titles containing '30K' or '16' are read from `batch.src` and titles
        # containing 'SNLI' from `batch.hypothesis` (see get_scores); neither
        # of the current titles matches, so `batch.text` is used.
        translation_dataset = '30K' in title or '16' in title
        snli = 'SNLI' in title

        print(f'\n{title}')
        f.write(f'\n{title}')

        fprs, aurocs, auprs = [], [], []
        for _trial in range(num_trials):
            ood_scores = get_scores(iterator, ood=True, translation_dataset=translation_dataset, snli=snli)
            fpr, auroc, aupr = get_performance(ood_scores, test_scores, expected_ap, recall_level=recall_level)
            fprs.append(fpr)
            aurocs.append(auroc)
            auprs.append(aupr)

        _report(f'FPR{int(100 * recall_level):d}:\t\t\t{np.mean(fprs):.4f} ({np.std(fprs):.4f})', f)
        _report(f'AUROC:\t\t\t{np.mean(aurocs):.4f} ({np.std(aurocs):.4f})', f)
        _report(f'AUPR:\t\t\t{np.mean(auprs):.4f} ({np.std(auprs):.4f})', f)
        # Trailing newline in the file only, matching the original layout.
        f.write('\n')

        mean_fprs.append(np.mean(fprs))
        mean_aurocs.append(np.mean(aurocs))
        mean_auprs.append(np.mean(auprs))

    # Averages over the OOD datasets.
    print()
    _report(f'OOD dataset mean FPR: {np.mean(mean_fprs):.4f}', f)
    _report(f'OOD dataset mean AUROC: {np.mean(mean_aurocs):.4f}', f)
    _report(f'OOD dataset mean AUPR: {np.mean(mean_auprs):.4f}', f)