In [1]:
import numpy as np

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import seaborn as sn
import pandas as pd


from tkinter import *
from tkinter import _setit
matplotlib.use("TkAgg")

plt.rcParams.update({'font.size': 11})
plt.rc('axes', titlesize=14)
plt.rc('axes', labelsize=16)
"""
plt.rcParams.update({'font.size': 20})
plt.rcParams["figure.figsize"] = (8,6)
plt.rc('axes', titlesize=26)
plt.rc('axes', labelsize=26)
"""

'\nplt.rcParams.update({\'font.size\': 20})\nplt.rcParams["figure.figsize"] = (8,6)\nplt.rc(\'axes\', titlesize=26)\nplt.rc(\'axes\', labelsize=26)\n'

In [2]:
class Plotwindow():
    def __init__(self, masterframe, size): 
        
        (w,h)=size    
        self.inchsize=(w/25.4, h/25.4)
        self.figure = Figure(self.inchsize)
        self.axes = self.figure.add_subplot(111)
        
        # create canvas as matplotlib drawing area
        self.canvas = FigureCanvasTkAgg(self.figure, master=masterframe)
        self.canvas.get_tk_widget().pack()
        
        self.models = ['attm', 'ctrw', 'fbm', 'lw', 'sbm']
        
    def plotxy(self, x, y):
        self.axes.plot(x,y)
        self.canvas.draw()
        
    def clearplot(self):
        self.figure.clear()
        self.axes = self.figure.add_subplot(111)
        self.canvas.draw()
        
    
    def plot_errorbar(self,data,popout=False):
        true_values = data[0]
        pred_values = data[1]
        pred_std = data[2]
        
        choices = []
        while len(choices) < 20:
            randint = np.random.randint(0,len(true_values))
            if randint not in choices:
                include = True
                if len(filtered_models) != 0:
                    model = self.models[int(data[3,randint])]
                    if not model in filtered_models:
                        include = False
                if len(filtered_snr) != 0:
                    noise = "snr "+str(int(1/data[4,randint]+0.5))
                    if not noise in filtered_snr:
                        include = False
                if include == True:
                    choices.append(randint)
                
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)       
        
        
        plotaxes.errorbar(true_values[choices],pred_values[choices],pred_std[choices],
                           fmt = "o", capsize = 4, markersize = 10)
        plotaxes.plot(np.arange(0,2.01,0.1),np.arange(0,2.01,0.1),"grey")
        plotaxes.set_xlabel("True exponent",fontsize=26)
        plotaxes.set_ylabel("Predicted exponent",fontsize=26)
        
        
        if popout == False:
            self.canvas.draw()
    
    def plot_conf_acc(self,data,popout=False):
        target_values = data[0]
        pred_values = data[1]
        pred_std = data[2]
        
        #observed error over predicted error
        predicted_errors = np.arange(0.,1,0.02)
        observed_errors = np.zeros(len(predicted_errors))
        n_interval = np.zeros(len(predicted_errors))
        mean_conf = 0
        mean_mse = 0
        mae = 0
        n_total = 0
        
        for l in range(len(target_values)):
            include = True
            if len(filtered_models) != 0:
                model = self.models[int(data[3,l])]
                if not model in filtered_models:
                    include = False
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/data[4,l]+0.5))
                if not noise in filtered_snr:
                    include = False
            if include == True:
                index = np.where(predicted_errors <= pred_std[l].item())[-1][-1]
                #print(index,pred_std[l]**2)
                n_interval[index] += 1
                observed_errors[index] += np.square(target_values[l] - pred_values[l])
                mean_conf += pred_std[l]**2
                mean_mse += np.square(target_values[l] - pred_values[l])
                mae += np.abs(target_values[l] - pred_values[l])
                n_total += 1
                
        for l in range(len(predicted_errors)):
            if n_interval[l] > 50:
                observed_errors[l] = np.sqrt(observed_errors[l]/n_interval[l])
            else:
                observed_errors[l] = np.nan
        
        mean_conf /= n_total
        mean_mse /= n_total
        mae /= n_total
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
            
        plotaxes.plot(predicted_errors+0.01,observed_errors,"o-",linewidth=1,ms=10)
        plotaxes.plot(predicted_errors,predicted_errors,"grey",lw=0.8)
        
        plotaxes.set_xlabel("predicted std",fontsize=26)
        plotaxes.set_ylabel("observed std",fontsize=26)
        plotaxes.set_title(f"mean predicted squared error = {mean_conf:.4f}, observed mse = {mean_mse:.4f},\n observed mae={mae:.2f}",fontsize=24)
        plotaxes.set_xlim(xmin=0,xmax=0.6)
        plotaxes.set_ylim(ymin=0,ymax=0.6)
        
        
        if popout == False:
            self.canvas.draw()
    
    def plot_exponent_per_gt(self,data, popout = False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        
        alpha_gts = np.arange(0.05,2.05,0.05)
        predval_per_gt_and_model = np.zeros((5,len(alpha_gts)))
        predvar_per_gt_and_model = np.zeros((5,len(alpha_gts)))
        ncount_per_gt_and_model = np.zeros((5,len(alpha_gts)))
        for i in range(len(pred_values)):
            if len(filtered_snr) != 0:
                noise = "snr " + str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            condition = target_values[i].item() - 0.025 <= alpha_gts
            index = np.where(condition)[-1]
            if len(index) > 1:
                index = index[0]
            predval_per_gt_and_model[int(all_models[i]),index] += pred_values[i]
            predvar_per_gt_and_model[int(all_models[i]),index] += pred_std[i]**2
            ncount_per_gt_and_model[int(all_models[i]),index] += 1


        predval_per_gt = predval_per_gt_and_model.sum(axis=0)/ncount_per_gt_and_model.sum(axis=0)
        predvar_per_gt = predvar_per_gt_and_model.sum(axis=0)/ncount_per_gt_and_model.sum(axis=0)
        #print(predval_per_gt)
        predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
        predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
        
        plotaxes.plot(alpha_gts ,predval_per_gt_and_model[0] ,"o-" ,markersize=6 ,linewidth=0.7 ,label="ATTM")
        plotaxes.plot(alpha_gts ,predval_per_gt_and_model[1] ,"v-" ,markersize=6 ,linewidth=0.7 ,label="CTRW")
        plotaxes.plot(alpha_gts ,predval_per_gt_and_model[2] ,"<-" ,markersize=6 ,linewidth=0.7 ,label="FBM")
        plotaxes.plot(alpha_gts ,predval_per_gt_and_model[3] ,"s-" ,markersize=6 ,linewidth=0.7 ,label="LW")
        plotaxes.plot(alpha_gts ,predval_per_gt_and_model[4] ,"d-" ,markersize=6 ,linewidth=0.7 ,label="SBM")
        plotaxes.plot(alpha_gts ,alpha_gts ,"-" ,color="grey" ,markersize=7 ,linewidth=0.7 ,label="Truth")
        plotaxes.legend()
        plotaxes.set_xlabel("True Exponent "+r"$\alpha$")
        plotaxes.set_ylabel("Predicted Exponent "+r"$\alpha$")
        
        if popout == False:
            self.canvas.draw()
    
    def plot_error_histogram(self,data,option,popout=False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        
        #observed error over predicted error
        predicted_errors = np.arange(0.,1,0.02)
        observed_errors = np.zeros(len(predicted_errors))
        n_interval = np.zeros(len(predicted_errors))
        n_interval_per_model = np.zeros((5,len(predicted_errors)))

        for l in range(len(target_values)):
            if len(filtered_models) != 0:
                model = self.models[int(data[3,l])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/data[4,l]+0.5))
                if not noise in filtered_snr:
                    continue
            
            
            index = np.where(predicted_errors <= pred_std[l].item())[-1][-1]
            #print(index,pred_std[l]**2)
            n_interval[index] += 1
            observed_errors[index] += np.square(target_values[l] - pred_values[l])
            n_interval_per_model[int(all_models[l]),index] += 1

        for l in range(len(predicted_errors)):
            if n_interval[l] > 50:
                observed_errors[l] = np.sqrt(observed_errors[l]/n_interval[l])
            else:
                observed_errors[l] = np.nan
                
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
        
        
        models = ["attm","ctrw","fbm","lw","sbm"]
        if option == 0:
            plotaxes.bar(predicted_errors+0.01,n_interval,width=0.02)
            plotaxes.set_xlabel("predicted error (std)")
            plotaxes.set_ylabel("count")
        if option == 1:
            y_offset = np.zeros((len(n_interval_per_model[0])))
            for model in range(5):
                plotaxes.bar(predicted_errors+0.01,n_interval_per_model[model],width=0.02,bottom=y_offset,label=models[model])
                y_offset = y_offset + np.asarray(n_interval_per_model[model])
            plotaxes.legend()
            plotaxes.set_xlabel("predicted error (std)")
            plotaxes.set_ylabel("count")
        if option == 2:
            for model in range(5):
                plotaxes.plot(predicted_errors+0.01,n_interval_per_model[model],label=models[model],lw=3)
            plotaxes.legend(fontsize=26)
            plotaxes.set_xlabel("predicted error (std)",fontsize=26)
            plotaxes.set_ylabel("count",fontsize=26)
            plotaxes.set_ylim(ymin = 0)
            plotaxes.set_xlim(xmin = 0, xmax=0.8)
        
        if popout == False:
            self.canvas.draw()
               
    
    def plot_error_histogram_exponent_split(self,data,option,
                                            exponent_split = np.asarray([0,0.4,0.8,1.2,1.6]), popout=False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        
        #label for the exponent split
        exp_labels = []
        for k in range(len(exponent_split)-1):
            labeltxt = f"{exponent_split[k]}"+r"$<\alpha\leq$"+f"{exponent_split[k+1]}"
            exp_labels.append(labeltxt)
        labeltxt = f"{exponent_split[-1]}"+r"$<\alpha\leq$2"
        exp_labels.append(labeltxt)
        
        #observed error over predicted error
        predicted_errors = np.arange(0.,1,0.02)
        observed_errors = np.zeros((len(predicted_errors),len(exponent_split)))
        n_interval = np.zeros((len(predicted_errors),len(exponent_split)))
        n_interval_per_model = np.zeros((5,len(exponent_split),len(predicted_errors)))

        for l in range(len(target_values)):
            if len(filtered_models) != 0:
                model = self.models[int(data[3,l])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/data[4,l]+0.5))
                if not noise in filtered_snr:
                    continue
            
            
            index = np.where(predicted_errors <= pred_std[l].item())[-1][-1]
            index_exp = np.where(exponent_split < target_values[l].item())[-1][-1]
            #print(index,pred_std[l]**2)
            n_interval[index,index_exp] += 1
            observed_errors[index,index_exp] += np.square(target_values[l] - pred_values[l])
            n_interval_per_model[int(all_models[l]),index_exp,index] += 1

        for l in range(len(predicted_errors)):
            for k in range(len(exponent_split)):
                if n_interval[l,k] > 10:
                    observed_errors[l,k] = np.sqrt(observed_errors[l,k]/n_interval[l,k])
                else:
                    observed_errors[l,k] = np.nan
                
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
        
        
        models = ["attm","ctrw","fbm","lw","sbm"]
        if option == 0:
            for k in range(len(exponent_split)):
                plotaxes.plot(predicted_errors+0.01,n_interval[:,k],label=exp_labels[k],lw=3)
            plotaxes.set_xlabel("predicted error (std)")
            plotaxes.set_ylabel("count")
            plotaxes.legend(fontsize=16)
            plotaxes.set_ylim(ymin = 0)
            plotaxes.set_xlim(xmin = 0, xmax=0.8)
        if option == 2:
            for model in range(5):
                for k in range(len(exponent_split)):
                    plotaxes.plot(predicted_errors+0.01,n_interval_per_model[model,k],
                                  label=models[model]+","+exp_labels[k],lw=3)
            plotaxes.legend(fontsize=16)
            plotaxes.set_xlabel("predicted error (std)",fontsize=26)
            plotaxes.set_ylabel("count",fontsize=26)
            plotaxes.set_ylim(ymin = 0)
            plotaxes.set_xlim(xmin = 0, xmax=0.8)
        
        if popout == False:
            self.canvas.draw()
               
    def plot_error_histogram_noise_split(self,data,option, popout=False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        snrs = ["1","2","10"]
        
        
        #observed error over predicted error
        predicted_errors = np.arange(0.,1,0.02)
        observed_errors = np.zeros((len(predicted_errors),len(snrs)))
        n_interval = np.zeros((len(predicted_errors),len(snrs)))
        n_interval_per_model = np.zeros((5,len(snrs),len(predicted_errors)))

        for l in range(len(target_values)):
            if len(filtered_models) != 0:
                model = self.models[int(data[3,l])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/data[4,l]+0.5))
                if not noise in filtered_snr:
                    continue
            
            
            index = np.where(predicted_errors <= pred_std[l].item())[-1][-1]
            snr = int(1/all_noises[l]+0.5)
            index_snr = 0
            if snr == 2:
                index_snr = 1
            elif snr > 2:
                index_snr = 2
            #print(index,pred_std[l]**2)
            n_interval[index,index_snr] += 1
            observed_errors[index,index_snr] += np.square(target_values[l] - pred_values[l])
            n_interval_per_model[int(all_models[l]),index_snr,index] += 1

        for l in range(len(predicted_errors)):
            for k in range(len(snrs)):
                if n_interval[l,k] > 10:
                    observed_errors[l,k] = np.sqrt(observed_errors[l,k]/n_interval[l,k])
                else:
                    observed_errors[l,k] = np.nan
                
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
        
        
        models = ["attm","ctrw","fbm","lw","sbm"]
        if option == 0:
            for k in range(len(snrs)):
                plotaxes.plot(predicted_errors+0.01,n_interval[:,k],label=snrs[k],lw=3)
            plotaxes.set_xlabel("predicted error (std)")
            plotaxes.set_ylabel("count")
            plotaxes.legend(fontsize=16)
            plotaxes.set_ylim(ymin = 0)
            plotaxes.set_xlim(xmin = 0, xmax=0.8)
        if option == 2:
            for model in range(5):
                for k in range(len(snrs)):
                    plotaxes.plot(predicted_errors+0.01,n_interval_per_model[model,k],
                                  label=models[model]+","+snrs[k],lw=3)
            plotaxes.legend(fontsize=16)
            plotaxes.set_xlabel("predicted error (std)")
            plotaxes.set_ylabel("count")
            plotaxes.set_ylim(ymin = 0)
            plotaxes.set_xlim(xmin = 0, xmax=0.8)
        
        if popout == False:
            self.canvas.draw()
                
    def plot_error_per_snr(self,data,popout=False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        
        
        per_snr_mse = np.zeros(3)
        per_snr_mae = np.zeros(3)
        per_snr_predvar = np.zeros(3)
        snr_count = np.zeros(3)

        for i in range(len(target_values)):
            if len(filtered_models) != 0:
                model = self.models[int(data[3,i])]
                if not model in filtered_models:
                    continue
            #noise values are [0.1,0.5,1] for snr 10,2,1
            snr = int(1/all_noises[i]+0.5)
            if snr == 10:
                snr_index = 0
            elif snr == 2:
                snr_index = 1
            elif snr == 1:
                snr_index = 2
                
            per_snr_mae[snr_index] += abs(target_values[i]-pred_values[i])
            per_snr_mse[snr_index] += np.square(target_values[i]-pred_values[i])
            per_snr_predvar[snr_index] += pred_std[i]**2
            
            snr_count[snr_index] += 1

        per_snr_mae = per_snr_mae/snr_count
        per_snr_mse = per_snr_mse/snr_count
        per_snr_predvar = per_snr_predvar/snr_count
        
        x_val = [0,1,2]
        classes = ['snr 10', 'snr 2', 'snr 1']
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)    
        
        plotaxes.plot(x_val,per_snr_mse,"o",ms = 11,label="observed error")
        plotaxes.plot(x_val,per_snr_predvar,"o",ms = 11, label="predicted error")
        plotaxes.set_xticks(ticks=x_val)
        plotaxes.set_xticklabels(labels=classes, size = 16)
        plotaxes.set_ylabel("MSE")
        plotaxes.legend()
        
        if popout == False:
            self.canvas.draw()
            
    def plot_error_per_model(self,data,popout=False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        
        
        per_model_mse = np.zeros(5)
        per_model_mae = np.zeros(5)
        per_model_predvar = np.zeros(5)
        model_count = np.zeros(5)

        for i in range(len(target_values)):
            if len(filtered_snr) != 0:
                noise = "snr " + str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            
            per_model_mae[int(all_models[i])] += abs(target_values[i]-pred_values[i])
            per_model_mse[int(all_models[i])] += np.square(target_values[i]-pred_values[i])
            per_model_predvar[int(all_models[i])] += pred_std[i]**2
            model_count[int(all_models[i])] += 1

        per_model_mae = per_model_mae/model_count
        per_model_mse = per_model_mse/model_count
        per_model_predvar = per_model_predvar/model_count
        
        x_val = [0,1,2,3,4]
        classes = ['attm', 'ctrw', 'fbm', 'lw', 'sbm']
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)    
        
        plotaxes.plot(x_val,per_model_mse,"o",ms = 11,label="observed error")
        plotaxes.plot(x_val,per_model_predvar,"o",ms = 11, label="predicted error")
        plotaxes.set_xticks(ticks=x_val)
        plotaxes.set_xticklabels(labels=classes, size = 16)
        plotaxes.set_ylabel("MSE")
        plotaxes.legend()
        
        if popout == False:
            self.canvas.draw()
    
    def plot_error_per_gt_exp(self,data,popout=False):
        pred_values = data[1]
        target_values = data[0]
        pred_std = data[2]
        all_models = data[3]
        all_noises = data[4]
        
        #observed and predicted error over ground truth anomalous exponent
        alpha_gt_values = np.arange(0.05,2.05,0.05)
        acc_errors = np.zeros(len(alpha_gt_values))
        acc_pred_error = np.zeros(len(alpha_gt_values))
        n_error = np.zeros(len(alpha_gt_values))

        for l in range(len(target_values)):
            if len(filtered_models) != 0:
                model = self.models[int(data[3,l])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/data[4,l]+0.5))
                if not noise in filtered_snr:
                    continue
            
            
            condition = target_values[l].item() - 0.025 <= alpha_gt_values 
            index = np.where(condition)[-1]
            if len(index) > 1:
                index = index[0]

            acc_errors[index] += np.square(target_values[l] - pred_values[l])
            acc_pred_error[index] += np.square(pred_std[l])
            n_error[index] += 1

        errors = acc_errors/n_error
        pred_errors = acc_pred_error/n_error
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)  
        
        plotaxes.plot(alpha_gt_values,errors,'bo',label="observed error", markersize = 8)
        plotaxes.plot(alpha_gt_values,pred_errors,"ro",label="predicted error", markersize = 8)
        plotaxes.set_xlabel("True exponent")
        plotaxes.set_ylabel("MSE")
        plotaxes.legend()
        
        if popout == False:
            self.canvas.draw()
    
    def plot_confusion_matrix(self,data,popout=False):
        all_gt_models = data[0]
        all_confidences = data[1:6]
        all_exponents = data[6]
        all_noises = data[7]
        #print(all_confidences[:,0])
        
        predicted_models = np.argmax(all_confidences, 0)
        #print(all_confidences[predicted_models[0],0],predicted_models[0])
        
        accuracy = (predicted_models == all_gt_models).sum()/len(predicted_models)
        #print(accuracy)
        
        #init conf matrix
        conf_matrix = np.zeros((5,5))
        #loop over all test samples
        for i in range(len(all_gt_models)):
            if len(filtered_models) != 0:
                model = self.models[int(all_gt_models[i])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            
            prediction = -1
            confidence = 0
            conf_norm = 0
            for j in range(5):
                if len(filtered_models) != 0:
                    model = self.models[j]
                    if not model in filtered_models:
                        continue
                conf_norm += all_confidences[j,i]
                if all_confidences[j,i] > confidence:
                    confidence = all_confidences[j,i]
                    prediction = j
            confidence /= conf_norm
            
            
            conf_matrix[prediction,int(all_gt_models[i])] += 1
            
        conf_matrix = conf_matrix/conf_matrix.sum(axis=0)
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)  
            
        df_cm = pd.DataFrame(conf_matrix, index = [i for i in self.models],
                          columns = [i for i in self.models])
        # plt.figure(figsize=(7,7))
        #sn.set(font_scale=1.4) # for label size
        sn.heatmap(df_cm, annot=True, ax=plotaxes, cbar = False, cmap = "Blues")
        plotaxes.set_xlabel("True model")
        plotaxes.set_ylabel("Predicted model")
        
        if popout == False:
            self.canvas.draw()
            
    def plot_conf_acc_classi(self,data,popout=False):
        all_gt_models = data[0]
        all_confidences = data[1:6]
        all_exponents = data[6]
        all_noises = data[7]
        
        
        confidence_intervals = np.arange(0,1,0.05) #confidence intervals
        n_interval = np.zeros(len(confidence_intervals))
        accuracy_interval = np.zeros(len(confidence_intervals))
        total_accuracy  = 0
        mean_confidence = 0
        n_samples = 0
        
        for i in range(len(all_gt_models)):
            label = int(all_gt_models[i])
            if len(filtered_models) != 0:
                model = self.models[int(all_gt_models[i])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            
            prediction = -1
            confidence = 0
            conf_norm = 0
            for j in range(5):
                if len(filtered_models) != 0:
                    model = self.models[j]
                    if not model in filtered_models:
                        continue
                conf_norm += all_confidences[j,i]
                if all_confidences[j,i] > confidence:
                    confidence = all_confidences[j,i]
                    prediction = j
            confidence /= conf_norm
            
            index = np.where(confidence_intervals <= confidence)[-1][-1]
            #print(index)
            #print(outputs_prob[i][pred])
            n_interval[index] += 1
            if (label == prediction):
                accuracy_interval[index] += 1
                total_accuracy += 1
            mean_confidence += confidence
            n_samples += 1
        
        total_accuracy /= n_samples
        mean_confidence /= n_samples
        accuracy_interval /= n_interval
        for i in np.where(n_interval<50):
            accuracy_interval[i] = np.nan
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
    
        #plot accuracy over confidence
        plotaxes.plot(confidence_intervals+0.025,accuracy_interval,"o-",ms=10,lw=1)
        plotaxes.plot(np.arange(0,1.06,0.05),np.arange(0,1.06,0.05),"grey")
        plotaxes.set_xlim(xmin=0.2,xmax=1.025)
        plotaxes.set_ylim(ymin=0.2,ymax=1.025)
        plotaxes.set_xlabel("(Predicted) Confidence")
        plotaxes.set_ylabel("(Observed) Accuracy")
        plotaxes.set_title(f"total accuracy = {total_accuracy*100:.2f}%, mean confidence = {mean_confidence*100:.2f}%")
        
        if popout == False:
            self.canvas.draw()
    
    def plot_accu_per_exponent(self,data,popout=False):
        all_gt_models = data[0]
        all_confidences = data[1:6]
        all_exponents = data[6]
        all_noises = data[7]
        
        alpha_interval = np.arange(0.05,2,0.1)
        acc_per_alpha = np.zeros(len(alpha_interval))
        pred_acc_per_alpha = np.zeros(len(alpha_interval))
        n_per_alpha = np.zeros(len(alpha_interval))
        
        for i in range(len(all_gt_models)):
            label = int(all_gt_models[i])
            if len(filtered_models) != 0:
                model = self.models[int(all_gt_models[i])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            
            prediction = -1
            confidence = 0
            conf_norm = 0
            for j in range(5):
                if len(filtered_models) != 0:
                    model = self.models[j]
                    if not model in filtered_models:
                        continue
                conf_norm += all_confidences[j,i]
                if all_confidences[j,i] > confidence:
                    confidence = all_confidences[j,i]
                    prediction = j
            confidence /= conf_norm
            
            
            #accuracy in alpha interval
            index_alpha = np.where(alpha_interval <= all_exponents[i])[-1][-1]
            n_per_alpha[index_alpha] += 1
            if label == prediction:
                acc_per_alpha[index_alpha] += 1
            pred_acc_per_alpha[index_alpha] += confidence
        
        acc_per_alpha = (acc_per_alpha/n_per_alpha)
        pred_acc_per_alpha = (pred_acc_per_alpha/n_per_alpha)
    
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
    
        plotaxes.plot(alpha_interval+0.05,acc_per_alpha*100,"o-",linewidth=0.7,label="observed")
        plotaxes.plot(alpha_interval+0.05,pred_acc_per_alpha*100,"o-",linewidth=0.7,label="predicted")


        plotaxes.set_xlabel("Ground Truth Anomalous Exponent")
        plotaxes.set_ylabel("Accuracy %")
        plotaxes.legend()

        
        if popout == False:
            self.canvas.draw()
    
    def plot_accu_per_snr(self,data,popout=False):
        all_gt_models = data[0]
        all_confidences = data[1:6]
        all_exponents = data[6]
        all_noises = data[7]
        
        total_accu_per_snr = np.zeros(3)
        total_conf_per_snr = np.zeros(3)
        n_total_per_snr = np.zeros(3)
        
        for i in range(len(all_gt_models)):
            label = int(all_gt_models[i])
            if len(filtered_models) != 0:
                model = self.models[int(all_gt_models[i])]
                if not model in filtered_models:
                    continue
            
            prediction = -1
            confidence = 0
            conf_norm = 0
            for j in range(5):
                if len(filtered_models) != 0:
                    model = self.models[j]
                    if not model in filtered_models:
                        continue
                conf_norm += all_confidences[j,i]
                if all_confidences[j,i] > confidence:
                    confidence = all_confidences[j,i]
                    prediction = j
            confidence /= conf_norm
            
            snr = int(1/all_noises[i]+0.5)
            if snr == 10:
                snr_index = 0
            elif snr == 2:
                snr_index = 1
            elif snr == 1:
                snr_index = 2
            else:
                raise("Wrong snr - not included in [10,2,1], this should not be possible")
            
            if label == prediction:
                total_accu_per_snr[snr_index] += 1
            total_conf_per_snr[snr_index] += confidence
            n_total_per_snr[snr_index] += 1
            
        total_accu_per_snr /= n_total_per_snr
        total_conf_per_snr /= n_total_per_snr
            
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
        
        plotaxes.plot([0,1,2],total_accu_per_snr*100,"o",ms=20,label="Observed")
        plotaxes.plot([0,1,2],total_conf_per_snr*100,"o",ms=20,label="Predicted")
        plotaxes.set_xticks([0,1,2])
        plotaxes.set_xticklabels(["SNR 10", "SNR 2", "SNR 1"])
        plotaxes.set_ylabel("Accuracy %")
        plotaxes.legend()

        
        if popout == False:
            self.canvas.draw()
        
    def plot_fn_fp_rates(self,data,popout=False):
        all_gt_models = data[0]
        all_confidences = data[1:6]
        all_exponents = data[6]
        all_noises = data[7]
        
        alpha_interval = np.arange(0.05,2,0.1)
        truepos_per_alpha_n_model = np.zeros((len(alpha_interval),5))
        falsepos_per_alpha_n_model = np.zeros((len(alpha_interval),5))
        falseneg_per_alpha_n_model = np.zeros((len(alpha_interval),5))
        count_per_alpha_n_model_pos = np.zeros((len(alpha_interval),5))
        count_per_alpha_n_model_neg = np.zeros((len(alpha_interval),5))
        
        total_accuracy  = 0
        mean_confidence = 0
        n_samples = 0
        
        for i in range(len(all_gt_models)):
            label = int(all_gt_models[i])
            if len(filtered_models) != 0:
                model = self.models[int(all_gt_models[i])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            
            prediction = -1
            confidence = 0
            conf_norm = 0
            for j in range(5):
                if len(filtered_models) != 0:
                    model = self.models[j]
                    if not model in filtered_models:
                        continue
                conf_norm += all_confidences[j,i]
                if all_confidences[j,i] > confidence:
                    confidence = all_confidences[j,i]
                    prediction = j
            confidence /= conf_norm
            
            index_alpha = np.where(alpha_interval <= all_exponents[i])[-1][-1]
            
            if (label == prediction):
                truepos_per_alpha_n_model[index_alpha,label] += 1
            else:
                falsepos_per_alpha_n_model[index_alpha,prediction] += 1
                falseneg_per_alpha_n_model[index_alpha,label] += 1
            count_per_alpha_n_model_pos[index_alpha,label] += 1
            for neg in range(5):
                if neg != label:
                    count_per_alpha_n_model_neg[index_alpha,neg] += 1
            
            #continue working from here next time
        truepos_per_alpha_n_model /= count_per_alpha_n_model_pos
        falseneg_per_alpha_n_model /= count_per_alpha_n_model_pos
        falsepos_per_alpha_n_model /= count_per_alpha_n_model_neg
        
        
        #plots
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
        
        colors = ["red","blue","green","black","magenta"]
        for i in range(5):
            model = self.models[i]
            if len(filtered_models) != 0:
                if not model in filtered_models:
                    continue
            plotaxes.plot(alpha_interval+0.05, falsepos_per_alpha_n_model[:,i],color=colors[i],ls="-",label=model)
            plotaxes.plot(alpha_interval+0.05, falseneg_per_alpha_n_model[:,i],color=colors[i],ls="--")
        plotaxes.legend()
        plotaxes.set_xlabel("gt-exponent")
        plotaxes.set_ylabel("- false positive rate\n -- false negative rate")
        
        if popout == False:
            self.canvas.draw()
        
    def plot_error_histogram_classi(self,data,option,popout=False):
        all_gt_models = data[0]
        all_confidences = data[1:6]
        all_exponents = data[6]
        all_noises = data[7]
        
        confidence_intervals = np.arange(0,1.01,0.05) #confidence intervals
        n_interval = np.zeros(len(confidence_intervals))
        n_interval_per_model = np.zeros((5,len(confidence_intervals)))
        accuracy_interval = np.zeros(len(confidence_intervals))
        total_accuracy  = 0
        mean_confidence = 0
        n_samples = 0
        
        
        for i in range(len(all_gt_models)):
            label = int(all_gt_models[i])
            if len(filtered_models) != 0:
                model = self.models[int(all_gt_models[i])]
                if not model in filtered_models:
                    continue
            if len(filtered_snr) != 0:
                noise = "snr "+str(int(1/all_noises[i]+0.5))
                if not noise in filtered_snr:
                    continue
            
            prediction = -1
            confidence = 0
            conf_norm = 0
            for j in range(5):
                if len(filtered_models) != 0:
                    model = self.models[j]
                    if not model in filtered_models:
                        continue
                conf_norm += all_confidences[j,i]
                if all_confidences[j,i] > confidence:
                    confidence = all_confidences[j,i]
                    prediction = j
            confidence /= conf_norm
        
            index = np.where(confidence_intervals <= confidence)[-1][-1]
            #print(index)
            #print(outputs_prob[i][pred])
            n_interval[index] += 1
            n_interval_per_model[label,index] += 1
            if (label == prediction):
                accuracy_interval[index] += 1
                total_accuracy += 1
            mean_confidence += confidence
            n_samples += 1
        
        total_accuracy /= n_samples
        mean_confidence /= n_samples
        accuracy_interval /= n_interval
        for i in np.where(n_interval<50):
            accuracy_interval[i] = np.nan
        
        if popout == False:
            plotaxes = self.axes
        else:
            newfigure = plt.figure(figsize=self.inchsize)
            plotaxes = newfigure.add_subplot(111)
            
            
        models = ["attm","ctrw","fbm","lw","sbm"]
        if option == 0:
            plotaxes.bar((confidence_intervals+0.025)*100,n_interval,width=5)
            plotaxes.set_xlabel("predicted confidence %")
            plotaxes.set_ylabel("count")
        if option == 1:
            y_offset = np.zeros((len(n_interval_per_model[0])))
            for model in range(5):
                plotaxes.bar((confidence_intervals+0.025)*100,n_interval_per_model[model],width=5,bottom=y_offset,label=models[model])
                y_offset = y_offset + np.asarray(n_interval_per_model[model])
            plotaxes.legend()
            plotaxes.set_xlabel("predicted confidence %")
            plotaxes.set_ylabel("count")
        if option == 2:
            for model in range(5):
                plotaxes.plot((confidence_intervals+0.025)*100,n_interval_per_model[model],label=models[model])
            plotaxes.legend()
            plotaxes.set_xlabel("predicted confidence %")
            plotaxes.set_ylabel("count")
        
        if popout == False:
            self.canvas.draw()
        
class Optionwindow():
    def __init__(self,masterframe,options):
        self.options = options
        self.choice = StringVar(masterframe)
        self.choice.set(options[0])
        
        self.menu = OptionMenu(masterframe,self.choice,*options)
        self.menu.pack(side="left")
    def change_options(self,new_options,chosen_option = False):
        if new_options != self.options:
            self.options = new_options
            #clear old options
            self.choice.set("")
            self.menu['menu'].delete(0,'end')
            #create new options
            for choice in new_options:
                self.menu['menu'].add_command(label=choice, command=_setit(self.choice,choice))
            self.choice.set(new_options[0])
            if not chosen_option == False:
                self.choice.set(chosen_option)
    
    
def submitmenues():
    #function for submitting for which length and task to plot
    #and loading the corresponding data
    global plotdata
    #menu1
    if menu1.choice.get() == "Classification":
        #classification data loading
        T = int(menu2.choice.get())
        
        #change to appropriate optionmenu
        classiplotoptions = ["Confusion Matrix", "Confidence vs accuracy","Per gt-exponent", "Per snr", 
                             "False negatives and positives", 
                             "Error histogram1","Error histogram2","Error histogram3"]
        plotmenu.change_options(classiplotoptions)
        
        lengthoptions = ["10","25","50","100","250","500","999"]
        menu2.change_options(lengthoptions,chosen_option=str(T))
        
        
        #load data
        plotdata = np.loadtxt("plotdata/"+f"classification_length{T}")
    elif menu1.choice.get() == "Regression":
        #regression data loading
        T = int(menu2.choice.get())
        
        #change to appropriate optionmenu
        regressionplotoptions = ["Errorbar","Confidence vs accuracy","Mean prediction per ground truth",
                                 "Error histogram1","Error histogram2","Error histogram3", 
                                 "Error histogram Expsplit", "Error histogram Noisesplit",
                                 "Error per model",
                                 "Error per gt exponent", "Error per snr"]
        plotmenu.change_options(regressionplotoptions)
        
        lengthoptions = ["10","25","50","100","250","500","999"]
        menu2.change_options(lengthoptions,chosen_option=str(T))
        
        #load data
        plotdata = np.loadtxt("plotdata/"+f"regression_length{T}")
    elif menu1.choice.get() == "Single Model Regression":
        #loading single model regression data
        T = int(menu2.choice.get())
        if T not in [10,100,500]:
            T = 10
        
       
        #change to appropriate optionmenu
        singlemodelplotoptions = ["Errorbar","Confidence vs accuracy","Mean prediction per ground truth",
                                 "Error histogram1","Error histogram2","Error histogram3", 
                                 "Error histogram Expsplit", "Error histogram Noisesplit",
                                 "Error per model",
                                 "Error per gt exponent", "Error per snr"]
        plotmenu.change_options(singlemodelplotoptions)
        
        
        lengthoptions = ["10","100","500"]
        menu2.change_options(lengthoptions,chosen_option=str(T))
        
        #load data
        if len(filtered_models) == 0: #for this we always need at least one filtered model, if there is none add fbm
            modelname = "fbm"
            filtered_models.add("fbm")
            filterlabel.configure(text="You have filtered for models: "
                                  +str(filtered_models)+" and "+str(filtered_snr))
        count = 0
        for modelname in filtered_models: #load in all the plotdata for the different models
            modeldata = np.loadtxt("plotdata/"+modelname+f"regression_length{T}")
            if count > 0: #add data for this model to complete pool of data
                print(plotdata)
                print(modeldata)
                plotdata = np.concatenate((plotdata,modeldata),axis=1)
                print(plotdata)
            else: #plotdata not yet initialized, so do so
                plotdata = modeldata
            count += 1
            
    submitplot()

def submitplot(popout=False):
    maindraw.clearplot()
    if menu1.choice.get() == "Regression" or menu1.choice.get() == "Single Model Regression":
        #plotmenu for regression
        if plotmenu.choice.get() == "Errorbar":
            maindraw.plot_errorbar(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Confidence vs accuracy":
            maindraw.plot_conf_acc(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Mean prediction per ground truth":
            maindraw.plot_exponent_per_gt(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Error histogram1":
            maindraw.plot_error_histogram(plotdata,0,popout=popout)
        elif plotmenu.choice.get() == "Error histogram2":
            maindraw.plot_error_histogram(plotdata,1,popout=popout)
        elif plotmenu.choice.get() == "Error histogram3":
            maindraw.plot_error_histogram(plotdata,2,popout=popout)
        elif plotmenu.choice.get() == "Error histogram Expsplit":
            maindraw.plot_error_histogram_exponent_split(plotdata,0,popout=popout)
        elif plotmenu.choice.get() == "Error histogram Noisesplit":
            maindraw.plot_error_histogram_noise_split(plotdata,0,popout=popout)
        elif plotmenu.choice.get() == "Error per model":
            maindraw.plot_error_per_model(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Error per gt exponent":
            maindraw.plot_error_per_gt_exp(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Error per snr":
            maindraw.plot_error_per_snr(plotdata,popout=popout)
    if menu1.choice.get() == "Classification":
        #plotmenu for classification
        if plotmenu.choice.get() == "Confusion Matrix":
            maindraw.plot_confusion_matrix(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Confidence vs accuracy":
            maindraw.plot_conf_acc_classi(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Per gt-exponent":
            maindraw.plot_accu_per_exponent(plotdata,popout=popout)        
        elif plotmenu.choice.get() == "Per snr":
            maindraw.plot_accu_per_snr(plotdata,popout=popout)
        elif plotmenu.choice.get() == "False negatives and positives":
            maindraw.plot_fn_fp_rates(plotdata,popout=popout)
        elif plotmenu.choice.get() == "Error histogram1":
            maindraw.plot_error_histogram_classi(plotdata,0,popout=popout)
        elif plotmenu.choice.get() == "Error histogram2":
            maindraw.plot_error_histogram_classi(plotdata,1,popout=popout)
        elif plotmenu.choice.get() == "Error histogram3":
            maindraw.plot_error_histogram_classi(plotdata,2,popout=popout)
        
        
def submitfilter(action):
    if action == "add":
        if filtermenu1.choice.get() in ['attm', 'ctrw', 'fbm', 'lw', 'sbm']:
            filtered_models.add(filtermenu1.choice.get())
        else:
            filtered_snr.add(filtermenu1.choice.get())
    elif action == "remove":
        if filtermenu1.choice.get() in ['attm', 'ctrw', 'fbm', 'lw', 'sbm']:
            try: filtered_models.remove(filtermenu1.choice.get())
            except: pass
        else:
            try: filtered_snr.remove(filtermenu1.choice.get())
            except: pass
    filterlabel.configure(text="You have filtered for models: "+str(filtered_models)+" and "+str(filtered_snr))

In [3]:
#define TK
tk = Tk()
tk.title("BDL AnDi-Evaluate")
tk.resizable(0,0)

top = Frame(tk)
top2 = Frame(tk)
top3 = Frame(tk)
middle = Frame(tk)
bottom = Frame(tk)

top.pack(side="top",fill="x", expand=True)
top2.pack(side="top",fill="x", expand=True)
top3.pack(side="top",fill="x", expand=True)
middle.pack(side="top",fill="both",expand=True)
bottom.pack(side="bottom",fill="x",expand=True)



#choice menues
options1 = ["Regression","Classification","Single Model Regression"]
menu1 = Optionwindow(top,options1)
options2 = ["10","25","50","100","250","500","999"]
menu2 = Optionwindow(top,options2)

submitbutton = Button(top,text="submit", command=submitmenues)
submitbutton.pack()
#choice1 = StringVar(tk)
#choice1.set(options1[0])
#menu1 = OptionMenu(tk,choice1,*options1)

#menu1.config(width=90,font=("Helvetica",12))
#menu1.pack()
#menu1.grid(row=1,column=1)
#more choice menues (what to plot)
plotoptions = ["Errorbar","Confidence vs accuracy","Mean prediction per ground truth",
                                 "Error histogram1","Error histogram2","Error histogram3", "Error per model",
                                 "Error per gt exponent", "Error per snr"]
plotmenu = Optionwindow(top2,plotoptions)
plotbutton = Button(top2,text="plot", command=submitplot)
plotbutton.pack(side="left")
plotbutton2 = Button(top2,text="extra plot", command=lambda:submitplot(popout=True))
plotbutton2.pack(side="right")

#filtering for specific noise/model
filtered_models = set() #list of filtered in models / snr-ratios
filtered_snr = set()

filters1 = ["attm","ctrw","fbm","lw","sbm", "snr 10", "snr 2", "snr 1"]
filtermenu1 = Optionwindow(top3,filters1)
addbutton = Button(top3,text="add",command=lambda:submitfilter("add"))
addbutton.pack(side="left")
removebutton = Button(top3,text="remove",command=lambda:submitfilter("remove"))
removebutton.pack(side="left")

filterlabel = Label(top3,text="")
filterlabel.configure(text="All models and noise allowed")
filterlabel.pack(side="left")


#define central figure
maindraw = Plotwindow(middle,(250,200))
maindraw.plotxy(np.arange(0,1,0.1),np.arange(0,1,0.1))
"""
figure = Figure(figsize=(6,8))
axes = Axes(figure,[0.,0.,1.,1.])
figure.add_axes(axes)

centerimage = FigureCanvasTkAgg(figure,master=tk)
centerimage.get_tk_widget().pack()

axes.plot(np.arange(0,1,0.1),np.arange(0,1,0.1))
"""
#some buttons
testbutton = Button(bottom,text="quit",command=tk.destroy)
testbutton.pack()#grid(row=1,column=1,columnspan=2)

tk.update()
tk.mainloop()

  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
  predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
  predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
  predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
  predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
  predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per_gt_and_model
  predvar_per_gt_and_model = predvar_per_gt_and_model/ncount_per_gt_and_model
  predval_per_gt_and_model = predval_per_gt_and_model/ncount_per