In [None]:
import os
# Work from the project root; presumably this is what lets the `src.*` imports in this
# cell resolve (Jupyter puts the working directory on sys.path). Set the DUQ_ROOT
# environment variable when the repository is not checked out at /home/mmr/DUQ.
os.chdir(os.path.expanduser(os.environ.get('DUQ_ROOT', '/home/mmr/DUQ')))

import ipdb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from src.model.backbone import Backbone
from src.model.duqmodel import DUQ
from src.utils.utils import grad_penalty, train_duq, ood_detection_eval

In [None]:
# Mini-batch size shared by the training loader and the OOD-evaluation loader.
batch_size  = 100

# In-distribution (FashionMNIST) preprocessing. 0.286 / 0.353 are presumably the
# FashionMNIST pixel mean / std -- confirm.
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.286,), (0.353,))
        ])

# MNIST preprocessing with MNIST's own mean / std (0.1307 / 0.3081).
# NOTE(review): each test set is normalised with its own statistics, not with the
# in-distribution (FashionMNIST) ones. Keep this only if it deliberately mirrors the
# DUQ reference evaluation; otherwise use `transform` for MNIST as well.
M_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

data_root = '~/DUQ/data/fashion_mnist/'
M_data_root = '~/DUQ/data/mnist/'

# download=True only fetches the archives when they are missing, so a fresh
# checkout can Restart & Run All without a manual download step.
FM_train_dataset = torchvision.datasets.FashionMNIST(data_root, transform=transform, download=True)
# shuffle=True: DataLoader defaults to shuffle=False, which would feed SGD the
# training set in the same fixed order every epoch.
FM_train_loader  = torch.utils.data.DataLoader(FM_train_dataset, batch_size=batch_size, shuffle=True)

FM_test_dataset  = torchvision.datasets.FashionMNIST(data_root, transform=transform, train=False, download=True)
M_test_dataset   = torchvision.datasets.MNIST(M_data_root, transform=M_transform, train=False, download=True)

# OOD benchmark: FashionMNIST test images (label 0 = in-distribution) followed by
# MNIST test images (label 1 = out-of-distribution). ood_targets is built in the
# same order as the ConcatDataset, so ood_loader must NOT shuffle.
ood_dataset = torch.utils.data.ConcatDataset([FM_test_dataset, M_test_dataset])
ood_targets = torch.cat([torch.zeros(len(FM_test_dataset)), torch.ones(len(M_test_dataset))])

ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Reproducibility: seed numpy and torch before the model's random initialisation and
# before any DataLoader iteration draws from torch's RNG. torch.manual_seed seeds the
# CPU and all CUDA devices. Nondeterministic cuDNN kernels can still cause small
# run-to-run differences.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# DUQ hyper-parameters, passed positionally as DUQ(sigma, gamma, num_classes, emb_size).
# NOTE(review): the roles below follow the DUQ paper -- confirm against src/model/duqmodel.py.
sigma = 0.1          # presumably the RBF kernel length scale
gamma = .99          # presumably the centroid moving-average momentum (NOT the LR decay below)
num_classes = 10     # FashionMNIST has 10 classes
emb_size = 256       # presumably the per-class embedding (centroid) dimension
lambda_ = 0.05       # passed to train_duq; presumably the gradient-penalty weight

duq_model = DUQ(sigma, gamma, num_classes, emb_size).cuda()
opt       = optim.SGD(duq_model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
# Multiply the learning rate by 0.2 every 10 scheduler steps (the training loop steps it once per epoch).
sched     = StepLR(opt, step_size=10, gamma=0.2)

In [None]:
# Train for a fixed number of epochs. With StepLR(step_size=10), the learning rate
# changes after epochs 9 and 19, giving three LR phases over 30 epochs.
n_epochs = 30
for epoch in range(n_epochs):
    # One training call per epoch -- see src.utils.utils.train_duq for what it does.
    train_duq(duq_model, epoch, 'cuda', FM_train_loader, opt, lambda_)
    # Advance the LR schedule once per epoch, after that epoch's optimiser updates.
    sched.step()

In [None]:
# Score every image in ood_loader: the FashionMNIST test set first, then the MNIST
# test set. The unshuffled loader keeps this order aligned with ood_targets.
# NOTE(review): the sign convention of the scores (confidence vs. uncertainty) is
# defined in src.utils.utils.ood_detection_eval -- confirm it matches OOD label = 1.
eval_scores = ood_detection_eval(
    duq_model,
    'cuda',
    ood_loader,
)

In [None]:
# AUROC for separating MNIST (label 1, OOD) from FashionMNIST (label 0, in-distribution).
# NOTE(review): roc_auc_score treats higher scores as "more likely label 1". If
# eval_scores are confidences (higher = in-distribution), this AUROC will come out
# below 0.5 -- confirm the sign convention in ood_detection_eval.
assert len(eval_scores) == len(ood_targets), 'expected exactly one score per OOD-eval sample'
auroc = roc_auc_score(ood_targets.numpy(), eval_scores)
print(f'FashionMNIST (in) vs MNIST (out) AUROC: {auroc:.4f}')