## Confusion Matrix

This notebook generates a confusion matrix for QUIC and TCP traces
in the setting where the classifier is trained on TCP samples but 
evaluated on a mixture of QUIC and TCP traces.

---

In [1]:
import itertools
import json
import pathlib

import h5py
import yaml
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

import lab.metrics

In [2]:
# Resolve input/output paths. When the notebook is executed by Snakemake the
# `snakemake` object is injected into the namespace; otherwise looking it up
# raises NameError and the hard-coded paths below are used instead.
try:
    FILES = {
        key: str(snakemake.input[key])
        for key in ("dataset", "splits", "predictions")
    }
    FILES["table"] = str(snakemake.output)
except NameError:
    FILES = dict(
        dataset="../results/open-world-dataset.hdf",
        splits="../results/splits-quic.json",
        predictions="../results/predictions-varcnn-quic-all.csv",
        table="../results/plots/confusion-matrix.tex",
    )

In [3]:
def load_test_protocols():
    """Return the protocol of every test sample, indexed by (run, sample).

    Reads one JSON object per line from ``FILES["splits"]`` and uses each
    object's ``"test"`` entry as positional row indices into the ``labels``
    table stored in the HDF5 file ``FILES["dataset"]``.

    Returns a DataFrame with a single ``protocol`` column (decoded from
    ASCII bytes to ``str``) and a two-level index: ``run`` is the position of
    the split within the splits file and ``sample`` is the position of the
    sample within that split's test set.
    """
    split_records = []
    with open(FILES["splits"], mode="r") as handle:
        for raw_line in handle:
            split_records.append(json.loads(raw_line))

    with h5py.File(FILES["dataset"], mode="r") as dataset:
        label_table = pd.DataFrame.from_records(np.asarray(dataset["labels"]))

    # One row per split, one column per position in that split's test set.
    per_run = [
        label_table.iloc[entry["test"]]["protocol"].values
        for entry in split_records
    ]

    # Stack to long form: (row, column) becomes the (run, sample) pair.
    long_form = pd.DataFrame(per_run).stack().reset_index()
    long_form.columns = ["run", "sample", "protocol"]
    long_form["protocol"] = long_form["protocol"].str.decode("ascii")

    return long_form.set_index(["run", "sample"])
    
# Build the (run, sample) -> protocol lookup once; the bare name below renders
# the resulting frame as this cell's output.
PROTOCOLS = load_test_protocols()
PROTOCOLS

Unnamed: 0_level_0,Unnamed: 1_level_0,protocol
run,sample,Unnamed: 2_level_1
0,0,quic
0,1,tcp
0,2,quic
0,3,tcp
0,4,quic
...,...,...
19,5852,tcp
19,5853,tcp
19,5854,tcp
19,5855,tcp


In [4]:
def load_data(path=None):
    """Load the classifier's prediction table, indexed by (run, sample).

    Parameters
    ----------
    path : str or path-like, optional
        CSV file to read. It must contain a ``run`` column. When omitted,
        ``FILES["predictions"]`` is used.

    Returns
    -------
    pandas.DataFrame
        The CSV contents with a two-level index: the file's ``run`` column
        and a ``sample`` counter that restarts at 0 for each run and numbers
        that run's rows in file order.
    """
    if path is None:
        path = FILES["predictions"]

    data = pd.read_csv(path, index_col="run")
    # The CSV carries no per-run sample id, so number the rows within each
    # run in the order they appear in the file.
    data["sample"] = data.groupby("run").cumcount()
    return data.set_index("sample", append=True)

# Read the prediction table; the bare name below renders it as this cell's
# output.
data = load_data()
data

Unnamed: 0_level_0,Unnamed: 1_level_0,y_true,-1,0,1,2,3,4,5,6,7,...,90,91,92,93,94,95,96,97,98,99
run,sample,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
0.0,0,-1.0,0.999720,1.152128e-09,1.314624e-12,8.793840e-16,1.697146e-09,2.295330e-10,1.724003e-06,1.360647e-06,9.091011e-14,...,6.537304e-10,9.363732e-11,2.124916e-17,6.351560e-11,3.305244e-14,4.440015e-13,6.139023e-15,2.099051e-05,1.510373e-10,1.546469e-13
0.0,1,-1.0,0.992349,2.041564e-05,1.877511e-10,1.248981e-13,2.375705e-05,9.011805e-06,6.414177e-07,7.170300e-07,6.842717e-13,...,2.174591e-08,2.621962e-09,8.258945e-14,4.143330e-11,3.222561e-14,1.118325e-11,7.000126e-13,1.150927e-04,4.240752e-08,5.698959e-09
0.0,2,92.0,0.999907,3.219322e-20,3.458482e-20,1.322707e-13,9.390239e-19,3.294562e-19,2.446975e-12,5.700321e-17,8.045181e-10,...,5.474686e-12,2.824317e-18,3.608101e-07,4.027124e-14,2.930350e-15,9.397168e-09,6.216293e-12,1.171210e-17,8.343560e-22,5.803802e-08
0.0,3,-1.0,0.998586,7.973108e-06,1.163289e-06,2.161589e-11,4.863753e-06,1.024014e-10,3.002585e-08,8.146264e-10,4.860521e-12,...,1.304890e-08,6.245508e-09,1.902460e-16,4.776269e-11,4.610695e-09,3.232716e-11,4.513575e-11,3.367262e-07,4.279430e-10,2.634683e-08
0.0,4,-1.0,0.999630,4.454251e-25,5.424241e-15,1.000717e-19,2.890122e-18,3.122108e-25,2.263338e-12,1.715319e-22,4.584334e-10,...,6.049110e-12,7.757189e-23,8.955893e-16,3.615912e-17,9.012982e-18,1.471384e-12,4.095025e-15,6.619180e-14,7.232364e-25,3.489004e-08
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9.0,5852,-1.0,0.986896,3.528975e-23,9.546258e-20,1.952789e-08,7.716641e-23,1.333886e-22,2.924288e-10,8.358027e-20,1.519002e-07,...,1.560439e-17,4.705288e-20,3.914173e-07,8.760079e-11,1.888168e-10,3.466265e-04,2.692394e-09,5.267773e-22,3.701092e-28,2.888826e-12
9.0,5853,-1.0,0.991453,3.429138e-19,1.414835e-08,1.661993e-12,1.344338e-12,7.003602e-14,5.866649e-06,3.004987e-10,6.657966e-08,...,6.076485e-14,7.275848e-17,3.125576e-09,1.388958e-07,1.731548e-09,5.643792e-07,8.430226e-12,7.598989e-10,5.921206e-13,1.871475e-05
9.0,5854,40.0,0.493850,2.590780e-08,2.089568e-11,2.318362e-17,1.161703e-07,2.548839e-08,3.147360e-16,5.757980e-12,6.605399e-19,...,3.134577e-14,3.182961e-11,6.201848e-16,1.439186e-08,6.755401e-16,1.374892e-13,7.255454e-16,5.766947e-09,5.072624e-09,9.463481e-17
9.0,5855,56.0,0.996009,4.900925e-14,7.256077e-19,6.767293e-07,4.461758e-17,1.126489e-13,5.993189e-08,4.598536e-13,3.690154e-13,...,2.265150e-11,4.780270e-10,1.241584e-11,2.374251e-08,3.846271e-19,2.750442e-12,6.890900e-15,3.981186e-11,2.080682e-15,8.218090e-15


In [5]:
# Left-join the protocol lookup onto the predictions by (run, sample): every
# prediction row is kept and gains a "protocol" column.
PROTOCOL_DATA = pd.merge(data, PROTOCOLS, how="left", on=["run", "sample"])
PROTOCOL_DATA

Unnamed: 0_level_0,Unnamed: 1_level_0,y_true,-1,0,1,2,3,4,5,6,7,...,91,92,93,94,95,96,97,98,99,protocol
run,sample,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
0.0,0,-1.0,0.999720,1.152128e-09,1.314624e-12,8.793840e-16,1.697146e-09,2.295330e-10,1.724003e-06,1.360647e-06,9.091011e-14,...,9.363732e-11,2.124916e-17,6.351560e-11,3.305244e-14,4.440015e-13,6.139023e-15,2.099051e-05,1.510373e-10,1.546469e-13,quic
0.0,1,-1.0,0.992349,2.041564e-05,1.877511e-10,1.248981e-13,2.375705e-05,9.011805e-06,6.414177e-07,7.170300e-07,6.842717e-13,...,2.621962e-09,8.258945e-14,4.143330e-11,3.222561e-14,1.118325e-11,7.000126e-13,1.150927e-04,4.240752e-08,5.698959e-09,tcp
0.0,2,92.0,0.999907,3.219322e-20,3.458482e-20,1.322707e-13,9.390239e-19,3.294562e-19,2.446975e-12,5.700321e-17,8.045181e-10,...,2.824317e-18,3.608101e-07,4.027124e-14,2.930350e-15,9.397168e-09,6.216293e-12,1.171210e-17,8.343560e-22,5.803802e-08,quic
0.0,3,-1.0,0.998586,7.973108e-06,1.163289e-06,2.161589e-11,4.863753e-06,1.024014e-10,3.002585e-08,8.146264e-10,4.860521e-12,...,6.245508e-09,1.902460e-16,4.776269e-11,4.610695e-09,3.232716e-11,4.513575e-11,3.367262e-07,4.279430e-10,2.634683e-08,tcp
0.0,4,-1.0,0.999630,4.454251e-25,5.424241e-15,1.000717e-19,2.890122e-18,3.122108e-25,2.263338e-12,1.715319e-22,4.584334e-10,...,7.757189e-23,8.955893e-16,3.615912e-17,9.012982e-18,1.471384e-12,4.095025e-15,6.619180e-14,7.232364e-25,3.489004e-08,quic
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9.0,5852,-1.0,0.986896,3.528975e-23,9.546258e-20,1.952789e-08,7.716641e-23,1.333886e-22,2.924288e-10,8.358027e-20,1.519002e-07,...,4.705288e-20,3.914173e-07,8.760079e-11,1.888168e-10,3.466265e-04,2.692394e-09,5.267773e-22,3.701092e-28,2.888826e-12,quic
9.0,5853,-1.0,0.991453,3.429138e-19,1.414835e-08,1.661993e-12,1.344338e-12,7.003602e-14,5.866649e-06,3.004987e-10,6.657966e-08,...,7.275848e-17,3.125576e-09,1.388958e-07,1.731548e-09,5.643792e-07,8.430226e-12,7.598989e-10,5.921206e-13,1.871475e-05,quic
9.0,5854,40.0,0.493850,2.590780e-08,2.089568e-11,2.318362e-17,1.161703e-07,2.548839e-08,3.147360e-16,5.757980e-12,6.605399e-19,...,3.182961e-11,6.201848e-16,1.439186e-08,6.755401e-16,1.374892e-13,7.255454e-16,5.766947e-09,5.072624e-09,9.463481e-17,tcp
9.0,5855,56.0,0.996009,4.900925e-14,7.256077e-19,6.767293e-07,4.461758e-17,1.126489e-13,5.993189e-08,4.598536e-13,3.690154e-13,...,4.780270e-10,1.241584e-11,2.374251e-08,3.846271e-19,2.750442e-12,6.890900e-15,3.981186e-11,2.080682e-15,8.218090e-15,quic


In [6]:
def make_predictions(frame):
    """Reduce per-class probabilities to a single predicted label per row.

    Parameters
    ----------
    frame : pandas.DataFrame
        Must contain one probability column per class, named ``"-1"``,
        ``"0"``, ..., ``"99"``, plus ``y_true`` and ``protocol`` columns.

    Returns
    -------
    pandas.DataFrame
        A new frame with the columns ``y_true``, ``y_pred`` and ``protocol``
        (in that order) and the same index as ``frame``. ``y_pred`` is the
        class whose probability column is largest in that row. ``frame``
        itself is left unmodified.
    """
    classes = np.arange(-1, 100)
    class_cols = [str(class_) for class_ in classes]
    probabilities = frame.loc[:, class_cols].to_numpy()
    y_pred = classes[np.argmax(probabilities, axis=1)]

    # Build a fresh frame instead of writing `y_pred` into the caller's
    # frame: the input stays untouched (no hidden state change between
    # cells) and a slice passed in cannot trigger SettingWithCopyWarning.
    result = frame[["y_true", "protocol"]].copy()
    result.insert(1, "y_pred", y_pred)
    return result

# Derive the predicted label for every row, then order the rows by their
# index; the bare name below renders the frame as this cell's output.
MATRIX_DATA = make_predictions(PROTOCOL_DATA).sort_index()
MATRIX_DATA

Unnamed: 0_level_0,Unnamed: 1_level_0,y_true,y_pred,protocol
run,sample,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0.0,0,-1.0,-1,quic
0.0,1,-1.0,-1,tcp
0.0,2,92.0,-1,quic
0.0,3,-1.0,-1,tcp
0.0,4,-1.0,-1,quic
...,...,...,...,...
9.0,5852,-1.0,-1,quic
9.0,5853,-1.0,-1,quic
9.0,5854,40.0,40,tcp
9.0,5855,56.0,-1,quic


In [7]:
TABLE_DATA = MATRIX_DATA.copy()

# Collapse the labels to two values: -1 is kept as-is and every other label
# is replaced by 1.
label_cols = ["y_true", "y_pred"]
TABLE_DATA[label_cols] = TABLE_DATA[label_cols].where(TABLE_DATA[label_cols] == -1, 1)

# Count samples per (actual, predicted, protocol) combination, give the two
# label values readable names, upper-case the protocol, and pivot the
# predicted label into columns.
TABLE_DATA = (
    TABLE_DATA.groupby(["y_true", "y_pred", "protocol"])
    .size()
    .rename({-1: "Unmonitored", 1: "Monitored"})
    .rename(str.upper, level="protocol")
    # A combination with no samples has no group, so unstack would leave NaN
    # in its cell; a missing combination means a count of zero.
    .unstack("y_pred", fill_value=0)
)

TABLE_DATA

Unnamed: 0_level_0,y_pred,Monitored,Unmonitored
y_true,protocol,Unnamed: 2_level_1,Unnamed: 3_level_1
Monitored,QUIC,199,4801
Monitored,TCP,4776,224
Unmonitored,QUIC,7,24293
Unmonitored,TCP,40,24230


In [8]:
# Number of traces in each (actual class, protocol) row.
row_totals = TABLE_DATA["Monitored"] + TABLE_DATA["Unmonitored"]

# The wrong prediction is "Unmonitored" for actually-monitored rows and
# "Monitored" for actually-unmonitored rows. Pick the rows by their `y_true`
# index label rather than by position, so the result does not depend on the
# table having exactly four rows in one particular order.
is_unmonitored = TABLE_DATA.index.get_level_values("y_true") == "Unmonitored"
wrong = TABLE_DATA["Unmonitored"].where(~is_unmonitored, TABLE_DATA["Monitored"])

# Misclassification rate as a percentage of each row's traces.
TABLE_DATA["misclassify"] = wrong / row_totals * 100
TABLE_DATA

Unnamed: 0_level_0,y_pred,Monitored,Unmonitored,misclassify
y_true,protocol,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Monitored,QUIC,199,4801,96.02
Monitored,TCP,4776,224,4.48
Unmonitored,QUIC,7,24293,0.028807
Unmonitored,TCP,40,24230,0.164813


In [9]:
# The row and column labels in the template below are hard-coded, so select
# the cells explicitly by label in that same order. A missing row or column
# then raises a KeyError instead of silently producing a mislabelled table.
TABLE_ROWS = [
    ("Monitored", "QUIC"),
    ("Monitored", "TCP"),
    ("Unmonitored", "QUIC"),
    ("Unmonitored", "TCP"),
]
TABLE_COLUMNS = ["Monitored", "Unmonitored", "misclassify"]
table_values = TABLE_DATA.loc[TABLE_ROWS, TABLE_COLUMNS].to_numpy()

# Doubled braces are literal LaTeX braces; single-brace fields index into
# `table_values` as table[row][column].
table = r"""\begin{{tabular}}{{lrrr}}
    \toprule
    Actual & \multicolumn{{2}}{{c}}{{Predicted}} & Misclassify (\%)\\
    \cmidrule{{2-3}}
      &  Monitored &  Unmonitored \\
    \midrule
    Monitored \\
    \quad QUIC &  {table[0][0]:,.0f} & {table[0][1]:,.0f} & {table[0][2]:.2f} \\
    \quad TCP &   {table[1][0]:,.0f} & {table[1][1]:,.0f} & {table[1][2]:.2f} \\
    Unmonitored \\
    \quad QUIC &  {table[2][0]:,.0f} & {table[2][1]:,.0f} & {table[2][2]:.2f} \\
    \quad TCP &   {table[3][0]:,.0f} & {table[3][1]:,.0f} & {table[3][2]:.2f} \\
    \bottomrule
\end{{tabular}}
""".format(table=table_values)

print(table)

# Create the output directory if it is missing so the write cannot fail with
# FileNotFoundError when the notebook is run by hand.
table_path = pathlib.Path(FILES["table"])
table_path.parent.mkdir(parents=True, exist_ok=True)
table_path.write_text(table)

\begin{tabular}{lrrr}
    \toprule
    Actual & \multicolumn{2}{c}{Predicted} & Misclassify (\%)\\
    \cmidrule{2-3}
      &  Monitored &  Unmonitored \\
    \midrule
    Monitored \\
    \quad QUIC &  199 & 4,801 & 96.02 \\
    \quad TCP &   4,776 & 224 & 4.48 \\
    Unmonitored \\
    \quad QUIC &  7 & 24,293 & 0.03 \\
    \quad TCP &   40 & 24,230 & 0.16 \\
    \bottomrule
\end{tabular}

