In [None]:
import matplotlib.pyplot as plt
from tn4ml.eval import compare_AUC, compare_FPR_per_TPR, compare_TPR_per_FPR

In [None]:
# Global matplotlib style: rcParams apply to every figure created after this
# cell runs in the current kernel.
plt.rcParams.update({
    # Centre the axis labels along their axes.
    'yaxis.labellocation': 'center',
    'xaxis.labellocation': 'center',
    # Marker size and marker edge width for line plots.
    'lines.markersize': 10,
    'lines.markeredgewidth': 2.0,
    # Minor ticks: disabled on the x axis (top and bottom) ...
    'xtick.minor.top': False,
    'xtick.minor.bottom': False,
    # ... and enabled on the y axis (left and right).
    'ytick.minor.left': True,
    'ytick.minor.right': True,
    # Font sizes for tick labels, legend entries and default text.
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'font.size': 16,
})

In [None]:
# Short tags, presumably one per entry of `initializers` below (same order).
# NOTE(review): not referenced in this section; confirm they are used elsewhere.
initializers_strings = ["glor", "he", "ortho", "gram", "randn"]

# Initializer names passed to the compare_* helpers.
initializers = ["glorot_n", "he_n", "orthogonal", "gramschmidt_n_1e-1", "randn_1e-1"]

# Embedding name passed to the compare_* helpers.
embedding_string = 'trigonometric'

In [None]:
# Experiment grid swept by the comparison cell.
BONDS = [5, 10, 30, 50]
SPACINGS = [2 ** exponent for exponent in range(2, 7)]  # 4, 8, 16, 32, 64
NORMAL_CLASSES = [0, 3, 4]

In [None]:
# Plot style per bond dimension: str(bond) -> (legend label, marker, colour).
LABELS = {
    str(bond): (f'bond = {bond}', marker, colour)
    for bond, marker, colour in (
        (5, 'o', '#016c59'),
        (10, 'X', '#7a5195'),
        (30, 'v', '#67a9cf'),
        (50, 'd', '#ffa600'),
    )
}

In [None]:
save_dir = 'results'  # root directory; each normal class uses <save_dir>/normal_class_<k>

In [None]:
# For each normal class, compare AUC, TPR at a fixed FPR and FPR at a fixed TPR
# across the bond-dimension / spacing / initializer grid defined above.
FPR_FIXED = 0.1   # passed as FPR_fixed to compare_TPR_per_FPR
TPR_FIXED = 0.95  # passed as TPR_fixed to compare_FPR_per_TPR
N_RUNS = 1        # passed as nruns to every compare_* call

# Arguments shared by all three comparisons; unpacked with ** so each call
# receives its own copy.
common_kwargs = dict(
    bond_dims=BONDS,
    spacings=SPACINGS,
    initializers=initializers,
    embedding=embedding_string,
    nruns=N_RUNS,
    labels=LABELS,
    anomaly_det=True,
)

for normal_class in NORMAL_CLASSES:
    # Same path string as save_dir + f'/normal_class_{normal_class}'.
    class_dir = f'{save_dir}/normal_class_{normal_class}'
    compare_AUC(save_dir=class_dir, **common_kwargs)
    compare_TPR_per_FPR(save_dir=class_dir, FPR_fixed=FPR_FIXED, **common_kwargs)
    compare_FPR_per_TPR(save_dir=class_dir, TPR_fixed=TPR_FIXED, **common_kwargs)