In [1]:
import torch
import torchvision.transforms as tvt
from torch.utils.data import DataLoader
from tqdm import tqdm
from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import OpenMax
from pytorch_ood.utils import OODMetrics, ToUnknown

import pandas as pd
import os
torch.manual_seed(1234)
device = "cuda:2"

  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# load dev set
devdf = pd.read_csv("xinwang_vocoders/data/voc.v4/protocol.txt", sep=" ", header=None)
devdf.columns = ["path", "subset","label"]

In [8]:
# load out-of-distribution set
melgan_dir = 'xinwang_vocoders/data/voc.v4/MelGAN'
file_list = os.listdir(melgan_dir)

# make dataframe for out-of-distribution set
outdf = pd.DataFrame()
outdf['path'] = file_list
outdf['subset'] = 'test'
outdf['label'] = 'unknown'
outdf['path'] = outdf['path'].apply(lambda x: os.path.join('MelGAN', x))

In [11]:
outdf.head()

Unnamed: 0,path,subset,label
0,MelGan/LA_D_3727888.wav,test,unknown
1,MelGan/LA_T_4565832.wav,test,unknown
2,MelGan/LA_D_6446182.wav,test,unknown
3,MelGan/LA_T_1407047.wav,test,unknown
4,MelGan/LA_D_5542285.wav,test,unknown


In [16]:
testdf = pd.concat([devdf[devdf['subset'] == 'dev'], outdf[:2028]])

In [17]:
testdf.label.value_counts()

bonafide            2074
hn-sinc-nsf         2072
waveglow            2062
unknown             2028
hifigan             2026
hn-sinc-nsf-hifi    2023
Name: label, dtype: int64

In [22]:
# write to protocol_test.txt
testdf['subset'] = 'dev'
testdf.to_csv('xinwang_vocoders/data/voc.v4/protocol_test.txt', sep=" ", header=None, index=False)

## Load model and data to fit the OpenMax

In [2]:
from datautils.vocv4 import genList, Dataset_for, Dataset_for_eval
from model.wav2vec2_resnet import Model
import yaml

In [3]:
# load config
config = yaml.load(open("configs/wav2vec2_resnet.yaml", "r"), Loader=yaml.FullLoader)
# load model
model = Model(args=config['model'], device=device, emb=False)
# load state dict
model.load_state_dict(torch.load('out/model_weighted_CCE_100_8_1e-06_wav2vec2_resnet/epoch_31.pth',map_location=device))
model = model.to(device)

In [4]:
# load train data
d_label_trn,file_train = genList( dir_meta =  os.path.join('xinwang_vocoders/data/voc.v4/','protocol.txt'),is_train=True,is_eval=False,is_dev=False)
print('no. of training trials',len(file_train))
train_set=Dataset_for(config,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join('xinwang_vocoders/data/voc.v4/'),algo=1)
train_loader = DataLoader(train_set, batch_size=16,num_workers=8, shuffle=True,drop_last = True)
del train_set,d_label_trn

no. of training trials 15383


In [6]:
# Load test data
d_label_dev, file_dev = genList(dir_meta = os.path.join('xinwang_vocoders/data/voc.v4/','protocol_test.txt'),is_train=False,is_eval=False, is_dev=True)
    
print('no. of validation trials',len(file_dev))
    
dev_set = Dataset_for(config,list_IDs = file_dev,
	labels = d_label_dev,
	base_dir = os.path.join('xinwang_vocoders/data/voc.v4/'),algo=1)
test_loader = DataLoader(dev_set, batch_size=16,num_workers=8, shuffle=False)
# del dev_set,d_label_dev


no. of validation trials 12285


In [8]:
detector = OpenMax(model, alpha=5)
detector.fit(train_loader, device=device)

<pytorch_ood.detector.openmax.torch.OpenMax at 0x7f6c7dbfcc10>

In [21]:
for i in range(10):
    x, y = dev_set[len(dev_set)-i-1]
    print(detector.predict(x.unsqueeze(0).to(device)))

tensor([0.1759], dtype=torch.float64)
tensor([0.1727], dtype=torch.float64)
tensor([0.1786], dtype=torch.float64)
tensor([0.1750], dtype=torch.float64)
tensor([0.1662], dtype=torch.float64)
tensor([0.1767], dtype=torch.float64)
tensor([0.1696], dtype=torch.float64)
tensor([0.1623], dtype=torch.float64)
tensor([0.1827], dtype=torch.float64)
tensor([0.1751], dtype=torch.float64)


In [9]:
metrics = OODMetrics()
for x, y in test_loader:
    score = detector(x.to(device))
    # print(score, y)
    metrics.update(score, y)


In [10]:
scores = metrics.buffer.get('scores')
labels = metrics.buffer.get('labels')

In [11]:
scores, scores_idx = torch.sort(scores, stable=True)
labels = labels[scores_idx]

In [14]:
precision, recall, threshold = binary_precision_recall_curve(scores, labels)
print(len(threshold))


12286


### MaxSoftmax


In [12]:
from torchmetrics.functional.classification import (
    binary_auroc,
    binary_precision_recall_curve,
    binary_roc,
)

In [7]:
from pytorch_ood.detector import MaxSoftmax

detector = MaxSoftmax(model)
metrics = OODMetrics()
for x, y in test_loader:
    metrics.update(detector(x.to(device)), y)

In [8]:
scores = metrics.buffer.get('scores')
labels = metrics.buffer.get('labels')
scores, scores_idx = torch.sort(scores, stable=True)
labels = labels[scores_idx]

In [11]:
auroc = binary_auroc(scores, labels)
print(auroc)
print(metrics.compute())

tensor(0.9968, dtype=torch.float64)
{'AUROC': 0.996752917766571, 'AUPR-IN': 0.9879894256591797, 'AUPR-OUT': 0.9992644190788269, 'FPR95TPR': 0.007019596174359322}


In [36]:
def compute_det_curve(target_scores, nontarget_scores):

    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Sort labels based on scores
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Compute false rejection and false acceptance rates
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores

    return frr, far, thresholds

def calculate_confusion_matrix(target_scores, nontarget_scores, threshold):
    """
    Calculate the confusion matrix for a given threshold.
    return: tp, tn, fp, fn
    """
    tp = np.sum(target_scores > threshold)
    tn = np.sum(nontarget_scores <= threshold)
    fn = np.sum(target_scores <= threshold)
    fp = np.sum(nontarget_scores > threshold)
    return tp, tn, fp, fn

def compute_eer(target_scores, nontarget_scores):
    """ Returns equal error rate (EER) and the corresponding threshold. """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    abs_diffs = np.abs(frr - far)
    min_index = np.argmin(abs_diffs)
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, thresholds[min_index]

In [37]:
known_scores = scores[labels == 0]
unknown_scores = scores[labels == 1]

In [45]:
eer, th = compute_eer(unknown_scores.cpu().numpy(), known_scores.cpu().numpy())
tp, tn, fp, fn = calculate_confusion_matrix(unknown_scores.cpu().numpy(), known_scores.cpu().numpy(), th)

print("EER: {:.2f}%, threshold: {:.4f}".format(eer * 100, th))
print("TP: {}, TN: {}, FP: {}, FN: {}".format(tp, tn, fp, fn))

EER: 1.92%, threshold: -0.9961
TP: 1989, TN: 10060, FP: 197, FN: 39


In [46]:
from torcheval.metrics import BinaryAccuracy
metric = BinaryAccuracy(threshold=-0.9961)
metric.update(scores, labels)
print(metric.compute())



tensor(0.9809)
