In [None]:
# !pip install pyro-ppl==1.3.1

In [9]:
import sys
sys.path.append('../')
sys.path.append('../../')
from src.posterior_networks.PosteriorNetwork import PosteriorNetwork
import json
import numpy as np
import pandas as pd
import wandb
from src.dataset_manager.ClassificationDataset import MapillaryDataset
from src.results_manager.metrics_prior import confidence, brier_score, anomaly_detection
import torch
import torchvision.transforms as transforms
from sklearn.metrics import balanced_accuracy_score
import matplotlib.pyplot as plt
from src.posterior_networks.config import config
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [16]:
train_df = pd.read_csv('/lab/project-1/train_label.csv')
val_df = pd.read_csv('/lab/project-1/val_label.csv')
test_df = pd.read_csv('/lab/project-1/test_label.csv')
regions = set(['g1', 'g2', 'g3', 'g4', 'g5', 'g6'])
classes = set(['regu', 'warn', 'comp', 'info'])
full_path = '/lab/project-1/final_models/resnet18_oodg4_lat8_reg1e4_dens6_batch64_lr1e4__'
# full_path = '../../src/posterior_networks/models/resnet18_oodg3g4_RandAug_ops3_mag3_bins31'

config = json.load(open(f'{full_path}/config.json'))


# if 'class_encoding' in config: 
#     class_encoding = config['class_encoding']
# else:
#     class_encoding = {c: i for i, c in enumerate(sorted(train_df.label.unique()))}

if "N" in config:
    N = config['N']
else:
    N = train_df.label_encoded.value_counts().sort_index().values
N = torch.tensor(N)

print(N)

model = PosteriorNetwork(N=N,
                         n_classes=config['num_classes'],
                         hidden_dims=config['hidden_dims'],
                         kernel_dim=None,
                         latent_dim=config['latent_dim'],
                         architecture=config['architecture'],
                         k_lipschitz=config['k_lipschitz'],
                         no_density=config['no_density'],
                         density_type=config['density_type'],
                         n_density=config['n_density'],
                         budget_function=config['budget_function'],
                         batch_size=config['batch_size'],
                         lr=config['lr'],
                         loss=config['loss'],
                         dropout=config['dropout'],
                         regr=config['regr'],
                         seed=config['seed'])

# model.load_state_dict(torch.load(f'{full_path}/best_model.pth')['model_state_dict'])
model.cuda()
transform_val_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

val_dataset = MapillaryDataset(val_df, transform = transform_val_test)

# use a dict to map ground truth vector of ID and OOD
ood_regions_classes = set(config['ood_regions'].split(','))
ood_regions = ood_regions_classes.intersection(regions)
ood_classes = ood_regions_classes.intersection(classes)

grd_truth = torch.tensor(val_df.region.isin(ood_regions).astype(int).values) # 0 as ID, 1 as OOD

tensor([ 7458,  3970, 21210,  9894])


Using cache found in /nfs/homedirs/zhz/.cache/torch/hub/pytorch_vision_main


In [17]:
config['density_type']

'batched_radial_flow'

In [3]:
# if error in model loading:
dict = torch.load(f'{full_path}/best_model.pth')['model_state_dict']
older_val = dict['sequential.11.weight']
dict['sequential.12.weight'] = dict.pop('sequential.11.weight')
older_val = dict['sequential.11.bias']
dict['sequential.12.bias'] = dict.pop('sequential.11.bias')
torch.save(full_path,'./model_changed.pth')
model.load_state_dict(dict)

<All keys matched successfully>

In [25]:
from tqdm import tqdm
@torch.no_grad()
def compute_X_Y_alpha(model, loader, alpha_only=False):
    for batch_index, (X, Y) in tqdm(enumerate(loader)):
        X, Y = X.to('cuda'), Y.to('cuda')
        # print(X.shape)
        alpha_pred = model(X, None, return_output='alpha', compute_loss=False)
        # print('batch:', batch_index)
        if batch_index == 0:
            X_duplicate_all = X.to("cpu")
            orig_Y_all = Y.to("cpu")
            alpha_pred_all = alpha_pred.to("cpu")
        else:
            X_duplicate_all = torch.cat([X_duplicate_all, X.to('cpu')], dim=0)
            orig_Y_all = torch.cat([orig_Y_all, Y.to('cpu')], dim=0)
            alpha_pred_all = torch.cat([alpha_pred_all, alpha_pred.to('cpu')], dim=0)
    if alpha_only:
        return alpha_pred_all
    else:
        return orig_Y_all, X_duplicate_all, alpha_pred_all

In [26]:
# find a threshold for g4 in val, and then test it on testset
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, num_workers=6, pin_memory=True)
model.eval()
val_orig_Y_all, val_X_duplicate_all, val_alpha_pred_all = compute_X_Y_alpha(model, val_loader)

0it [00:00, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper___slow_conv2d_forward)

In [21]:
print((val_alpha_pred_all.max()))

tensor(14.3329)


In [8]:
def true_false_table(grd_truth, pred):
    table = [[[] for i in range(2)] for i in range(2)] # 
    for i in range(len(grd_truth)):
        if grd_truth[i] == 1:
            if pred[i] == 1: # True positive
                table[0][0].append(i)
            elif pred[i] == 0: # False negative
                table[0][1].append(i)
            else:
                KeyError('wrong number')
        elif grd_truth[i] == 0:
            if pred[i] == 1: # False positive
                table[1][0].append(i)
            elif pred[i] == 0: # True negative
                table[1][1].append(i)
            else:
                KeyError('wrong number')
        else:
            KeyError('wrong number')
    return table

In [22]:
# use maximum of alpha as threshold
# torch.concat(torch.max(alpha_pred_all, dim=1).values, grd_truth)
from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score
max_alpha = torch.max(val_alpha_pred_all, dim=1).values
output = [[],[],[],[]]
for threshold in np.linspace(1,14,100):
    pred = (torch.max(val_alpha_pred_all, dim=1).values < threshold).int()
    acc = balanced_accuracy_score(grd_truth, pred)
    acc2 = accuracy_score(grd_truth, pred)
    auc = roc_auc_score(grd_truth, pred)
    output[0].append(threshold)
    output[1].append(acc)
    output[2].append(acc2)
    output[3].append(auc)

index = output[1].index(max(output[1]))
print("###")
print(f'threshold:{output[0][index]}, balanced accuracy:{output[1][index]}, accuracy:{output[2][index]}, area under curve:{output[3][index]}')

pred = (torch.max(val_alpha_pred_all, dim=1).values < output[0][index]).int()
table = true_false_table(grd_truth, pred)
print('TP:',len(table[0][0]))
print('FN:',len(table[0][1]))
print('FP:',len(table[1][0]))
print('TN',len(table[1][1]))

###
threshold:1.1313131313131313, balanced accuracy:0.5757527677076391, accuracy:0.5025345937799699, area under curve:0.575752767707639
TP: 296
FN: 153
FP: 3478
TN 3372


In [18]:
# see the performance on the test set
test_dataset = MapillaryDataset(test_df, transform = transform_val_test)
test_grd_truth = torch.tensor(test_df.region.isin(ood_regions).astype(int).values) # 0 as ID, 1 as OOD
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, num_workers=6, pin_memory=True)
model.eval()
test_orig_Y_all, test_X_duplicate_all, test_alpha_pred_all = compute_X_Y_alpha(model, test_loader)

625it [01:06,  9.45it/s]


In [26]:
# using max alpha as threshold, use data from ood_g4
from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score
threshold = 980
max_alpha = torch.max(test_alpha_pred_all, dim=1).values
pred = (torch.max(test_alpha_pred_all, dim=1).values < threshold).int()
balanced_accuracy = balanced_accuracy_score(test_grd_truth, pred)
accuracy = accuracy_score(test_grd_truth, pred)
auc = roc_auc_score(test_grd_truth, pred)
print(f'threshold:{threshold}, balanced accuracy:{balanced_accuracy}, accuracy:{accuracy}, area under curve:{auc}')

table = true_false_table(test_grd_truth, pred)
print('TP:',len(table[0][0]))
print('FN:',len(table[0][1]))
print('FP:',len(table[1][0]))
print('TN',len(table[1][1]))

threshold:980, balanced accuracy:0.585929837980331, accuracy:0.7656, area under curve:0.585929837980331
TP: 117
FN: 190
FP: 982
TN 3711


In [27]:
other_df = pd.read_csv('/lab/project-1/train_other_signs.csv').iloc[:5000]
test_df_less = pd.read_csv('/lab/project-1/test_label.csv').iloc[:2500]
full_df = pd.concat([test_df_less,other_df])
mix_sign_dataset = MapillaryDataset(full_df, transform = transform_val_test)
grd_truth1 = torch.tensor(test_df_less.region.isin(ood_regions).astype(int).values) # 0 as ID, 1 as OOD
grd_truth2 = torch.ones(5000)
mix_grd_truth = torch.cat([grd_truth1, grd_truth2])
mix_sign_loader = torch.utils.data.DataLoader(mix_sign_dataset,
                                                      batch_size=32,
                                                      num_workers=6, pin_memory=True)
mix_orig_Y_all, mix_X_duplicate_all, mix_alpha_pred_all = compute_X_Y_alpha(model, mix_sign_loader)

235it [00:45,  5.14it/s]


In [28]:
# using max alpha as threshold, use data from ood_g4
from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score
threshold = 980
max_alpha = torch.max(mix_alpha_pred_all, dim=1).values
pred = (torch.max(mix_alpha_pred_all, dim=1).values < threshold).int()
balanced_accuracy = balanced_accuracy_score(mix_grd_truth, pred)
accuracy = accuracy_score(mix_grd_truth, pred)
auc = roc_auc_score(mix_grd_truth, pred)
print(f'threshold:{threshold}, balanced accuracy:{balanced_accuracy}, accuracy:{accuracy}, area under curve:{auc}')

table = true_false_table(mix_grd_truth, pred)
print('TP:',len(table[0][0]))
print('FN:',len(table[0][1]))
print('FP:',len(table[1][0]))
print('TN',len(table[1][1]))

threshold:980, balanced accuracy:0.7577569071755118, accuracy:0.7465333333333334, area under curve:0.7577569071755118
TP: 3756
FN: 1404
FP: 497
TN 1843


In [16]:
mix_alpha_pred_all

tensor([[1.0000e+00, 1.1089e+00, 1.2509e+04, 1.0000e+00],
        [1.0000e+00, 2.0189e+00, 1.0000e+00, 4.9233e+03],
        [1.0000e+00, 1.0012e+00, 1.0374e+04, 1.0000e+00],
        ...,
        [1.2976e+03, 1.4969e+00, 1.7326e+01, 2.2678e+00],
        [4.2470e+01, 1.9358e+00, 2.0373e+03, 1.0024e+00],
        [1.0000e+00, 1.7508e+00, 1.1575e+04, 1.0000e+00]])

In [30]:
other_alpha = mix_alpha_pred_all[-5000:]
threshold = 980
max_alpha = torch.max(other_alpha, dim=1).values
pred = (torch.max(other_alpha, dim=1).values < threshold).int()
balanced_accuracy = balanced_accuracy_score(grd_truth2, pred)
accuracy = accuracy_score(grd_truth2, pred)
# auc = roc_auc_score(grd_truth2, pred)
print(f'threshold:{threshold}, balanced accuracy:{balanced_accuracy}, accuracy:{accuracy}, area under curve:{auc}')

table = true_false_table(grd_truth2, pred)
print('TP:',len(table[0][0]))
print('FN:',len(table[0][1]))
print('FP:',len(table[1][0]))
print('TN',len(table[1][1]))

threshold:980, balanced accuracy:0.7396, accuracy:0.7396, area under curve:0.7577569071755118
TP: 3698
FN: 1302
FP: 0
TN 0




In [31]:
other_alpha

tensor([[1.0000e+00, 1.3928e+02, 4.9065e+01, 1.0000e+00],
        [1.0000e+00, 5.0412e+02, 1.7244e+00, 1.0001e+00],
        [4.1003e+02, 5.8184e+00, 1.6184e+01, 1.0002e+00],
        ...,
        [5.1279e+02, 2.9726e+00, 9.7363e+00, 1.0000e+00],
        [1.0001e+00, 1.0428e+00, 1.1643e+03, 1.0000e+00],
        [1.0058e+00, 1.4489e+00, 8.1337e+02, 1.0000e+00]])