In [None]:
import os
import ot
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kstest
from tqdm import tqdm
# Make the project's ../scripts directory importable (only once, even on re-run)
# so that the W2ChangePoints module below can be resolved.
module_path = os.path.abspath(os.path.join('..', 'scripts'))
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)
from W2ChangePoints import ChangePointDetector

# Absolute paths to the observed series and to the ground-truth change-point indices.
data_path, truth_path = (
    os.path.abspath(os.path.join('..', 'data', file_name))
    for file_name in ('Langevin_1D.txt', 'ChangePts.txt')
)

In [None]:
# Load the observed series and the ground-truth change-point indices, in that order.
data, truth_data = map(np.loadtxt, (data_path, truth_path))

In [None]:
# Boolean mask over every timestep of `data`: True where truth_data lists a change point.
# np.isin does one vectorized membership pass instead of scanning truth_data per timestep.
true_cps = np.isin(np.arange(len(data)), truth_data).tolist()

In [None]:
# Grid search: recall of the Wasserstein-differential detector for every
# (window size, tolerance) pair. Rows = window sizes, columns = tolerances.
window_sizes = range(50, 550, 50)
tolerances = range(25, 275, 25)
recall_scores = np.full((len(window_sizes), len(tolerances)), np.nan)
for i, window_size in enumerate(window_sizes):
    # W2 distance between the window_size samples before and after each time t;
    # entry k of dw_data scores time k + window_size.
    dw_data = [ot.emd2_1d(data[t - window_size:t], data[t:t + window_size]) for t in range(window_size, len(data) - window_size)]
    # Flag the top 15% of distances as detections. This does not depend on the
    # tolerance, so it is computed once per window size.
    dw_cutoff = np.quantile(dw_data, 0.85)
    dw_cps = [x > dw_cutoff for x in dw_data]
    for j, tolerance in enumerate(tolerances):
        n_tp = n_fp = n_fn = 0
        for t in range(len(data) - 2 * window_size):
            # Shift by window_size (first scorable time). Clamp at 0: a negative
            # slice start would wrap to the end of the list instead of the start.
            left = max(0, t + window_size - tolerance)
            right = t + window_size + tolerance
            has_true_cp = True in true_cps[left:right]
            if dw_cps[t] and has_true_cp:
                n_tp += 1  # detected, and a true change point is within tolerance
            elif dw_cps[t]:
                n_fp += 1  # detected, but no true change point within tolerance
            elif has_true_cp:
                n_fn += 1  # missed a true change point within tolerance
        # Recall is undefined without positives; report 0 rather than crash.
        recall_scores[i, j] = n_tp / (n_tp + n_fn) if n_tp + n_fn else 0.0


In [None]:
# Rows are window sizes, columns are tolerances (same ranges as the grid search).
# Passing the labels to heatmap directly keeps each label attached to its own
# row/column even if seaborn would otherwise thin out the ticks.
ax = sns.heatmap(recall_scores,
                 xticklabels=list(range(25, 275, 25)),
                 yticklabels=list(range(50, 550, 50)))
ax.set(title="Recall Scores for Wasserstein Differentials, True Moving Window",
       xlabel="Error Tolerance",
       ylabel="Distribution/Window Size");


In [None]:
# Evaluate one configuration in detail (window 350, tolerance 25) and keep the
# per-timestep labels needed by the plot below.
window_size = 350
tolerance = 25
# Entry k of dw_data scores time k + window_size.
dw_data = [ot.emd2_1d(data[t - window_size:t], data[t:t + window_size]) for t in range(window_size, len(data) - window_size)]
dw_tp, dw_fp, dw_fn = [], [], []
# Flag the top 15% of distances as detections.
dw_cutoff = np.quantile(dw_data, 0.85)
dw_cps = [x > dw_cutoff for x in dw_data]

# Times (in `data` coordinates) and labels of TP / FN timesteps, for plotting.
dw_cp_times = []
dw_cp_tp_fn = []
for t in range(len(data) - 2 * window_size):
    # Shift by window_size (first scorable time); clamp so a negative start cannot wrap.
    left = max(0, t + window_size - tolerance)
    right = t + window_size + tolerance
    window = true_cps[left:right]
    has_true_cp = True in window
    if dw_cps[t] and has_true_cp:
        # detected, and a true change point is within tolerance -> tp
        dw_tp.append(t)
        dw_cp_times.append(t + window_size)  # shift so markers line up with `data`
        dw_cp_tp_fn.append("true positive")
    elif dw_cps[t]:
        # detected, but no true change point within tolerance -> fp
        dw_fp.append(t)
    elif has_true_cp:
        # did not detect, but a true change point is within tolerance -> fn
        dw_fn.append(t)
        dw_cp_times.append(t + window_size)
        dw_cp_tp_fn.append("false negative")
# Undefined ratios are reported as 0 instead of raising ZeroDivisionError.
dw_recall = len(dw_tp) / (len(dw_tp) + len(dw_fn)) if (dw_tp or dw_fn) else 0.0
dw_precision = len(dw_tp) / (len(dw_tp) + len(dw_fp)) if (dw_tp or dw_fp) else 0.0
dw_f1 = (2 * dw_precision * dw_recall / (dw_precision + dw_recall)
         if dw_precision + dw_recall else 0.0)

In [None]:
# Collapse runs of consecutive flagged timesteps into a single marker per run,
# keeping the last timestep of each run and that timestep's label.
# Runs are split only by gaps in time, so a run mixing TP and FN timesteps
# takes the label of its final timestep.
times_to_graph = []
cats = []
last_index = len(dw_cp_times) - 1
for index, time in enumerate(dw_cp_times):
    # The final element always closes a run (the original loop dropped it).
    if index == last_index or dw_cp_times[index + 1] - time > 1:
        times_to_graph.append(time)
        cats.append(dw_cp_tp_fn[index])
df = pd.DataFrame({'times': times_to_graph, 'val': -2, 'f1class': cats})


In [None]:
# Overlay detected (true positive) and missed (false negative) change points on the raw series.
# The theme/rc must be set before the figure is created for the size to apply.
sns.set(rc={"figure.figsize": (15, 6)})
fig, ax = plt.subplots(figsize=(15, 6))
sns.lineplot(data, color='red', ax=ax)
sns.scatterplot(data=df,
                x='times',
                y='val',
                style='f1class',
                s=50,
                ax=ax)
line_colors = {"false negative": 'magenta', "true positive": 'blue'}
for time, cat in zip(times_to_graph, cats):
    if cat in line_colors:
        ax.vlines(time, -2, 2, linewidth=0.5, colors=line_colors[cat])
ax.legend(loc='upper right')
ax.set_title(f"Change Points Detected via Wasserstein Differentials\n F1: {dw_f1:.2f}, Recall: {dw_recall:.2f}");


In [None]:
##ax.plot(np.arange(window_size, len(data) - window_size), ks_stats, c='b')
#for t in range(window_size, len(data) - window_size):
    #if ks_stats[t - window_size] > 0.9:
        #ax.vlines(t, -1.5, 1.5, colors='blue',linestyle='dashdot')
        #o
#ax.plot(np.arange(len(data)), data, c='orange')
#plt.title("Change Points Detected via Kolmogorov-Smirnov Statistics")

In [None]:
def change_points(data, method, cutoff, window):
    """Flag candidate change points by comparing adjacent windows of `data`.

    For each t in [window, len(data) - window), the `window` samples before t
    are compared with the `window` samples starting at t.

    Args:
        data: 1-D sequence of observations.
        method: 'kolmogorov' (two-sample KS statistic). 'wasserstein' is
            reserved but not implemented.
        cutoff: a time is flagged when its statistic is strictly greater than this.
        window: number of samples on each side of t.

    Returns:
        List of booleans; element k corresponds to time k + window.

    Raises:
        NotImplementedError: for method='wasserstein'.
        ValueError: for any other unknown method.
    """
    if method == 'wasserstein':
        raise NotImplementedError("Not Implemented")
    if method == 'kolmogorov':
        ks_stats = [kstest(data[(t - window):t], data[t:(t + window)]).statistic
                    for t in range(window, len(data) - window)]
        return [stat > cutoff for stat in ks_stats]
    raise ValueError("method not valid")


def f1_components(predicted_change_points, truth, method, cutoff, window, tolerance=25):
    """Score per-timestep change-point predictions against a ground-truth mask.

    Prediction index t refers to time t + window (the first scorable time is
    `window`). A time counts as having a true change point when any entry of
    `truth` in [t + window - tolerance, t + window + tolerance) is True.

    Args:
        predicted_change_points: sequence of booleans, e.g. from change_points().
        truth: boolean mask over every timestep of the original series
            (as built for `true_cps`).
        method, cutoff: unused; kept for call-site compatibility.
        window: window size used to produce the predictions.
        tolerance: half-width of the matching window, in samples.

    Returns:
        (recall, precision, f1). Any ratio with a zero denominator is reported as 0.0.
    """
    tp = fp = fn = 0
    for t, predicted in enumerate(predicted_change_points):
        # Clamp at 0: a negative slice start would wrap to the end of the mask.
        left = max(0, t + window - tolerance)
        right = t + window + tolerance
        has_true_cp = any(truth[left:right])
        if predicted and has_true_cp:
            tp += 1  # detected, and a true change point is within tolerance
        elif predicted:
            fp += 1  # detected, but no true change point within tolerance
        elif has_true_cp:
            fn += 1  # missed a true change point within tolerance
    recall = tp / (tp + fn) if tp + fn else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    # Harmonic mean, written to avoid division by zero when either term is 0.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return recall, precision, f1
    

In [None]:
# KS grid search over detection cutoff (in hundredths of the KS statistic)
# and window size. The matching tolerance is fixed at 100 samples.
cutoff_min = 10
cutoff_max = 90
cutoff_step_size = 5
window_min = 50
window_max = 1000
window_step_size = 50
cutoffs = range(cutoff_min, cutoff_max, cutoff_step_size)
windows = range(window_min, window_max, window_step_size)
# len(range) is exact even when the span is not a multiple of the step.
cutoff_count = len(cutoffs)
window_count = len(windows)
ks_tolerance = 100
# NaN-filled so cells that were never computed are obvious, not garbage.
f1_values = np.full((cutoff_count, window_count), np.nan)
recall_values = np.full((cutoff_count, window_count), np.nan)
print(f1_values.shape)

go = True  # set False to skip the expensive sweep
if go:
    for i, cutoff in enumerate(tqdm(cutoffs, desc='cutoff')):
        for j, window in enumerate(tqdm(windows, desc='window', leave=False)):
            predictions = change_points(data, method='kolmogorov', cutoff=0.01 * cutoff, window=window)
            recall, _, f1 = f1_components(predictions, true_cps, method='kolmogorov', cutoff=0.01 * cutoff, window=window, tolerance=ks_tolerance)
            f1_values[i, j] = f1
            recall_values[i, j] = recall

# Draw each grid on its own axes; drawing both on one axes overlays them.
fig, (ax_f1, ax_recall) = plt.subplots(1, 2, figsize=(10, 5))
sns.heatmap(f1_values, ax=ax_f1)
ax_f1.set(title="F1 Scores")
sns.heatmap(recall_values, ax=ax_recall)
ax_recall.set(title="Recall Scores");

In [None]:
# Side-by-side KS grid-search heatmaps: rows = cutoff (x0.01 KS statistic), columns = window size.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))  # width x height
cutoff_labels = list(range(cutoff_min, cutoff_max, cutoff_step_size))
window_labels = list(range(window_min, window_max, window_step_size))
for ax, values, title in ((ax1, f1_values, "F1 Scores"), (ax2, recall_values, "Recall Scores")):
    sns.heatmap(values,
                ax=ax,
                yticklabels=cutoff_labels,
                xticklabels=window_labels)
    # Label each axes explicitly; plt.xlabel/ylabel only reach the last one.
    ax.set(title=title, xlabel='Window Size', ylabel='Cutoff (KS statistic x 100)')

In [None]:
# Persist the KS grid-search results so later cells (or sessions) can reload them.
for file_name, grid in (('ks_f1.npy', f1_values), ('ks_recall.npy', recall_values)):
    np.save(file_name, grid)

In [None]:
# Reload the saved F1 grid (rows = cutoff, columns = window size).
ks_f1 = np.load(file='ks_f1.npy')

In [None]:
# F1 at row 13 / column 12 of the grid, i.e. cutoff 10 + 13*5 = 75 (KS statistic 0.75)
# and window 50 + 12*50 = 650. The entry is already a scalar, so no reduction is needed.
ks_f1[13, 12]

In [None]:
# Recall at the same grid cell (cutoff 75 -> KS statistic 0.75, window 650).
r = np.load(file='ks_recall.npy')[13, 12]

In [None]:
# Display the recall loaded in the previous cell.
r