In [None]:
import numpy as np
import pandas as pd

from flagger import BayesFlaggerBeta, SVMFlagger
from flagger_tools import BayesFlaggerBetaTest, SVMFlaggerTest, ModelError
from flagger_tools import get_sampler_model
from matplotlib import pyplot as plt
from tmvbeta import TMVBeta

### Synthetic Data

Load model parameters

In [None]:
# Load the model parameters from the bundled .npz archive.
# Each of `a`, `b` and `cov` is later indexed with 0 and 1 (one entry per group)
# when the two TMVBeta distributions are constructed.
# np.load on an .npz returns a lazy NpzFile that keeps the file open, so use it
# as a context manager; the arrays are materialized on key access and remain
# valid after the archive is closed.
with np.load("Data/parameters.npz") as params:
    a = params['a']
    b = params['b']
    cov = params['cov']

Create the sampler from the model

In [None]:
# `r` is handed to the sampler together with the two group distributions
# (presumably the mixing proportion between the groups — confirm against
# get_sampler_model).
r = 0.2

# One TMVBeta per group, built from the group's slice of a, b and cov.
P0, P1 = (TMVBeta(a[group], b[group], cov[group]) for group in (0, 1))

sampler = get_sampler_model(r, P0, P1)

Initialize the Bayes and SVM flaggers

In [None]:
# K and M are shared by every flagger so the comparison is like-for-like.
# NOTE(review): their meaning is defined by the flagger classes — confirm there.
K = 20
M = 10

# Two Bayes flaggers that differ only in their flagging rule:
# policy 1 uses "detection", policy 3 uses "mixed".
bayes_flagger_pol_1, bayes_flagger_pol_3 = (
    BayesFlaggerBeta(K, M, rule=rule_name) for rule_name in ("detection", "mixed")
)

# SVM-based flagger used as the baseline in the plots below.
svm_flagger = SVMFlagger(K, M)

Initialize testing classes

In [None]:
# N is passed unchanged to every test harness.
# NOTE(review): presumably the number of samples per administration — confirm
# against the *Test classes in flagger_tools.
N = 100

# Wrap each Bayes flagger in its test harness; all harnesses share one sampler.
bayes_test_pol_1, bayes_test_pol_3 = (
    BayesFlaggerBetaTest(flagger, sampler, N)
    for flagger in (bayes_flagger_pol_1, bayes_flagger_pol_3)
)
svm_test = SVMFlaggerTest(svm_flagger, sampler, N)

Simulate flagging under the given model. (In the paper, results are averaged over 20 runs; here, we simulate a single run to save time.)

In [None]:
# T is the argument given to every run(); the plots below use it as the number
# of administrations on the x-axis.
T = 100
# Each Bayes run returns four values: totals, detections, a sequence of models
# (iterated later to compute model errors) and a fourth value. That fourth value
# is discarded for policy 1 and kept as `phi` for policy 3, where it is plotted
# as the mixing parameter.
# All three runs draw from the same `sampler`, so keep this order if results
# need to be comparable between executions.
total_pol_1, detected_pol_1, model_pol_1, _ = bayes_test_pol_1.run(T)
total_pol_3, detected_pol_3, model_pol_3, phi = bayes_test_pol_3.run(T)
# The SVM run returns only totals and detections.
total_svm, detected_svm = svm_test.run(T)

Plot detection rates

In [None]:
# Detection rate = detected / total per administration, for each flagger.
# Index 0 is skipped for every series; the x-axis therefore runs from 1 to T.
administrations = range(1, T + 1)
for detected, total, label in (
    (detected_pol_1, total_pol_1, "Detection-Greedy"),
    (detected_pol_3, total_pol_3, "Mixed"),
    (detected_svm, total_svm, "SVM"),
):
    plt.plot(administrations, detected[1:] / total[1:], label=label)

plt.title("Detection Rate")
plt.xlabel("Administration")
plt.ylabel("Detection Rate")
plt.legend()
# Keep grid() last: it returns None, so the cell shows only the figure.
plt.grid()

Plot model errors

In [None]:
# Error evaluator built from the same r, a, b and cov that defined the sampler,
# presumably serving as the reference values that estimated models are compared to.
model_err = ModelError(r, a, b, cov)

In [None]:
# MSE of R in dB: evaluate mse_r on every model in each policy's sequence,
# then convert the resulting array with 10 * log10.
mse_r_pol_1, mse_r_pol_3 = (
    10 * np.log10(np.array([model_err.mse_r(snapshot) for snapshot in history]))
    for history in (model_pol_1, model_pol_3)
)

In [None]:
# MSE(R) per administration for both Bayes policies; x runs over 0..T inclusive.
steps = range(T + 1)
for curve, label in (
    (mse_r_pol_1, "Detection-Greedy"),
    (mse_r_pol_3, "Mixed"),
):
    plt.plot(steps, curve, label=label)

plt.title("MSE($R$)")
plt.xlabel("Administration")
plt.ylabel("MSE (dB)")
plt.legend()
# Keep grid() last: it returns None, so the cell shows only the figure.
plt.grid()

In [None]:
# MSE of covariance matrix for critical group in dB: evaluate mse_cov with
# group index 1 on every model in each policy's sequence, then apply 10 * log10.
mse_cov1_pol_1, mse_cov1_pol_3 = (
    10 * np.log10(np.array([model_err.mse_cov(1, snapshot) for snapshot in history]))
    for history in (model_pol_1, model_pol_3)
)

In [None]:
# MSE of the group-1 covariance per administration for both Bayes policies;
# x runs over 0..T inclusive.
steps = range(T + 1)
for curve, label in (
    (mse_cov1_pol_1, "Detection-Greedy"),
    (mse_cov1_pol_3, "Mixed"),
):
    plt.plot(steps, curve, label=label)

plt.title(r"MSE($\Sigma_1$)")
plt.xlabel("Administration")
plt.ylabel("MSE (dB)")
plt.legend()
# Keep grid() last: it returns None, so the cell shows only the figure.
plt.grid()

In [None]:
# Mixing parameter returned by the policy-3 run, plotted for administrations 1..T
# (the entry at index 0 is skipped).
administrations = range(1, T + 1)
phi_trace = phi[1:]
plt.plot(administrations, phi_trace, label=r"$\phi$")

plt.title("Mixing parameter of Policy 3")
plt.xlabel("Administration")
plt.ylabel(r"$\phi$")
plt.legend()
# Keep grid() last: it returns None, so the cell shows only the figure.
plt.grid()