In [None]:
import os
import random
import pickle
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import (classification_report, confusion_matrix, 
precision_score, recall_score, f1_score, 
roc_curve, roc_auc_score, accuracy_score)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# Pick the compute device once; on a CUDA machine also turn on cuDNN benchmark
# mode (lets cuDNN choose the fastest kernels) and report which GPU is used.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    cudnn.benchmark = True
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

In [None]:
from main import setup_seed, DrebinLoader, FGSM, Net, logits_acc

In [None]:
# Load the pickled feature mapping that ships with the Drebin data and print
# one entry as a spot-check.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# features.pkl from a source you trust.
features_path = os.path.join("./drebin", "features.pkl")
with open(features_path, "rb") as pkl_file:
    features = pickle.load(pkl_file)
print(features["activity::DCMetroQ"])

In [None]:
# Evaluation loader over the Drebin data.
# NOTE(review): the positional arguments (64, 0.045, False) are interpreted by
# main.DrebinLoader -- presumably batch size, a split fraction and a
# train/shuffle flag; confirm there.
# NOTE(review): later cells read `test_loader.num_features` and iterate this
# object repeatedly, which only works if DrebinLoader.__iter__ returns a
# re-iterable object that carries `num_features` (e.g. `self`) -- confirm.
test_loader = iter(DrebinLoader("./drebin", 64, 0.045, False))

In [None]:
# Seed, build the classifier sized to the loader's feature count, and restore
# the trained weights from "AT.pth".
setup_seed(0)
net = Net(test_loader.num_features)
# `net` is still on the CPU at this point, so map the checkpoint's tensors to
# the CPU as well: without map_location, a checkpoint saved from a GPU run
# fails to deserialize on a machine without CUDA.
net.load_state_dict(torch.load("AT.pth", map_location="cpu"))
net = net.to(device)

In [None]:
@torch.no_grad()
def test(model, dataloader, eps):
    model.eval()
    Acc = 0
    Labels, Preds, Values = [], [], []
    with tqdm(enumerate(test_loader), total=606) as t:
        for i, (x, y) in t:
            x, y = x.to(device), y.to(device)
            if eps>0:
                x = FGSM(model, x, y, eps)
            logits = F.softmax(model(x), dim=-1)
            preds = logits.argmax(dim=-1)
            values = logits[:, 1]
            Labels.append(y.cpu().numpy())
            Preds.append(preds.cpu().numpy())
            Values.append(values.cpu().numpy())
            Acc += logits_acc(logits, y)
            t.set_postfix(acc=f"{Acc/(i+1):6.2%}")
    Labels = np.hstack(Labels)
    Preds = np.hstack(Preds)
    Values = np.hstack(Values)
    return Labels, Preds, Values

In [None]:
def metric(labels, preds, values):
    """Score binary predictions, treating label 1 as the positive class.

    Args:
        labels: ground-truth labels.
        preds: hard predicted labels.
        values: continuous positive-class scores, used only for ROC AUC.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc)``.
    """
    hard = dict(y_true=labels, y_pred=preds)
    positive = dict(hard, pos_label=1)
    return (
        accuracy_score(**hard),
        precision_score(**positive),
        recall_score(**positive),
        f1_score(**positive),
        roc_auc_score(y_true=labels, y_score=values),
    )

In [None]:
# Sweep the FGSM budget over eps = 0..9. Each row of `mertrics` holds
# (accuracy, precision, recall, F1, AUC) for one eps, in `metric`'s order.
sweep_rows = []
for eps in range(10):
    labels_eps, preds_eps, values_eps = test(net, test_loader, eps)
    acc, p, r, f1, auc = metric(labels_eps, preds_eps, values_eps)
    sweep_rows.append([acc, p, r, f1, auc])
mertrics = np.vstack(sweep_rows)

In [None]:
# Accuracy vs. eps on its own figure (column 0 of `mertrics`).
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=150)
ax.plot(mertrics[:, 0], color="C0")
ax.set(title="Accuracy", xlabel=r"$\epsilon$", xticks=np.arange(10), ylim=(0.7, 1.0))
plt.show()

# Remaining metrics side by side: (column in `mertrics`, title, y-limits).
# Line colour "C<column>" matches each metric's column index.
panels = [
    (1, "Precision", (0, 1)),
    (2, "Recall", (0.9, 1)),
    (3, "F1", (0, 1)),
    (4, "AUC", (0.9, 1)),
]
fig, axs = plt.subplots(1, 4, figsize=(12, 3), dpi=150)
for ax, (col, title, ylim) in zip(axs, panels):
    ax.plot(mertrics[:, col], color=f"C{col}")
    ax.set_title(title)
    ax.set_ylim(*ylim)
    ax.set_xlabel(r"$\epsilon$")
    ax.set_xticks(np.arange(5) * 2)
plt.show()

In [None]:
# Evaluate at a single FGSM budget; the raw arrays feed the ROC /
# confusion-matrix cell, and the metric tuple is displayed as cell output.
ATTACK_EPS = 8
Labels, Preds, Values = test(net, test_loader, eps=ATTACK_EPS)
metric(labels=Labels, preds=Preds, values=Values)

In [None]:
# ROC curve and row-normalised confusion matrix for the Labels / Preds / Values
# produced by the previous cell.
fpr, tpr, thresholds = roc_curve(y_true=Labels, y_score=Values, pos_label=1)

# Pin the class order to [0, 1] so `cm` is always 2x2 and lines up with the
# "Neg"/"Pos" tick labels below, even if a class is absent from Labels and
# Preds. Assumes negative = 0 / positive = 1, consistent with pos_label=1.
cm = confusion_matrix(Labels, Preds, labels=[0, 1])
# Per-true-class rates; a row with no samples stays 0 instead of becoming NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cmn = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

font = {"color": "darkred", "size": 13, "family": "serif"}
class_names = ["Neg", "Pos"]

fig, axs = plt.subplots(1, 2, figsize=(7, 3), dpi=160)

axs[0].set_title("ROC curve \n", fontdict=font)
axs[0].set_xlabel("FP rate", fontdict=font)
axs[0].set_ylabel("TP rate", fontdict=font)
axs[0].plot(fpr, tpr, color="C1", lw=1)
# Chance-level diagonal for reference.
axs[0].plot([0, 1], [0, 1], color="gray", lw=2, linestyle="--")

axs[1].set_title("Confusion Matrix", fontdict=font)
axs[1].set_xlabel("Pred", fontdict=font)
axs[1].set_ylabel("True", fontdict=font)
axs[1].matshow(cmn)
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        # Each cell shows the row-normalised rate and the raw count. Text is
        # black above 0.5 and white otherwise, presumably chosen for a colormap
        # that is bright at high values (e.g. matplotlib's default viridis).
        axs[1].text(
            j, i, f"{cmn[i, j]:6.2%}\n({cm[i, j]})",
            verticalalignment="center", horizontalalignment="center",
            color="black" if cmn[i, j] > 0.5 else "white",
        )
# Label the confusion-matrix axes directly rather than through pyplot's
# implicit "current axes".
axs[1].set_xticks([0, 1])
axs[1].set_xticklabels(class_names)
axs[1].set_yticks([0, 1])
axs[1].set_yticklabels(class_names, rotation=90)

# NOTE(review): the weights evaluated here come from "AT.pth" but the figure is
# written to "ST.png"; confirm the intended file name so this does not
# overwrite another model's figure.
fig.savefig("ST.png", format="PNG", dpi=120)
plt.show()
