## Test models on additional small dataset

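This notebook evaluates the sliding-window FCN dense and FCN transformer detectors (standard and adapted variants), together with a simple increment-based baseline, on the new, smaller dataset that contains only positive samples. Because every sample is positive, no false positives can occur, so recall (which here equals accuracy) is the informative metric.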

To run this notebook, place the file `francois_normalized_dataset.pickle` in the folder `data/real`.

In [6]:
%reload_ext autoreload
%autoreload 2

import sys
# Add the project source folder to the module search path so that the local
# modules imported in the next cell can be resolved
# (presumably `artifact`, `sliding_window_detector` and `data` live there — confirm).
sys.path.append('../src_jobs/')

In [2]:
import pickle
from itertools import repeat
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from artifact import Saw_centered, Saw_centered_Francois
from sliding_window_detector import SlidingWindowTransformerDetector, SlidingWindowLinearDetector, ConvolutionalSlidingWindowDetector

from data import CachedArtifactDataset, TestArtifactDataset, CenteredArtifactDataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Globally switch off gradient tracking: the cells shown in this notebook only
# run forward passes, and the detector outputs are converted with `.numpy()`.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x729b0808bd90>

In [7]:
# Width value carried into the evaluation loops as `window`
# (presumably the input window length of the "_512" detectors — TODO confirm).
test_width_512 = 512
# The loaded object is iterated sample by sample below; each sample exposes a "data" entry.
# NOTE(review): unpickling can execute arbitrary code — only load this file from a trusted source.
data_512 = pd.read_pickle('../data/real/francois_normalized_dataset.pickle')

In [5]:
# Load the four pretrained sliding-window detectors from their checkpoints and
# move each of them to the CPU. `paths_SW` is reused for every checkpoint path.

paths_SW = "../models/SW_CNN_Trans.ckpt"  # SW 1d CNN Transformer
SW_CNNTrans_detector_512 = SlidingWindowTransformerDetector.load_from_checkpoint(paths_SW).cpu()

paths_SW = "../models/SW_adaFCN_Trans.ckpt"  # SW ada 1d CNN Transformer
SW_adaCNNTrans_detector_512 = SlidingWindowTransformerDetector.load_from_checkpoint(paths_SW).cpu()

paths_SW = "../models/SW_CNN_Dense.ckpt"  # SW 1d CNN Dense
SW_CNNDense_detector_512 = SlidingWindowLinearDetector.load_from_checkpoint(paths_SW).cpu()

paths_SW = "../models/SW_adaFCN_Dense.ckpt"  # SW ada 1d CNN Dense
# No trailing comma here: a trailing comma would bind a 1-tuple instead of the
# detector and break the later `.eval()` call on this variable.
SW_adaCNNDense_detector_512 = SlidingWindowLinearDetector.load_from_checkpoint(paths_SW).cpu()

  rank_zero_warn(


In [10]:
def baseline_detector(input: torch.Tensor) -> int:
    """Flag a window as an artifact when its central increment is unusually large.

    The absolute first differences of the signal are computed, and the window
    is flagged when the increment that ends at the window centre exceeds
    ``mean + 2 * std`` of all increments in the window.

    Args:
        input: Signal as a tensor of shape ``(1, length)`` (only the first row
            is used) or as a 1-D tensor of shape ``(length,)``.

    Returns:
        ``1`` if the central increment is an outlier, otherwise ``0``.
        Windows with fewer than two increments have no meaningful spread and
        always yield ``0``.
    """
    # `Tensor.squeeze` is not in-place, so the previous bare `input.squeeze(0)`
    # had no effect; select the signal row explicitly instead.
    signal = input if input.dim() == 1 else input[0]

    center = signal.shape[-1] // 2

    # Absolute increments: the signal shifted by one, subtracted from itself.
    increments = (signal[1:] - signal[:-1]).abs()
    if increments.numel() < 2:
        return 0

    threshold = torch.mean(increments) + 2 * torch.std(increments)

    # increments[center - 1] is the jump from sample (center - 1) to sample center.
    return 1 if increments[center - 1] > threshold else 0

In [11]:
# Baseline-only pass over the dataset: one 0/1 prediction per sample.
# Each sample is given a leading batch dimension before it is passed to the detector.
preds_baseline = [
    baseline_detector(torch.tensor(sample["data"]).unsqueeze(0))
    for sample in data_512
]

# The dataset contains only positive samples, so every ground-truth label is 1.
# Filling `gt` here (instead of leaving it empty) keeps it aligned with
# `preds_baseline` regardless of the order in which the cells are run.
gt = [1] * len(preds_baseline)


In [10]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, mean_squared_error, confusion_matrix, fbeta_score
import pandas as pd

# Per-sample outputs of each detector, collected in dataset order.
preds_FCNTrans = []
preds_adaFCNTrans = []
preds_FCNDense = []
preds_adaFCNDense = []

preds_baseline = []
gt = []

# Switch every detector to evaluation mode once, before the loop.
# `.eval()` returns the module itself, so this is equivalent to calling
# `detector.eval()(x)` on every iteration.
for detector in (
    SW_CNNTrans_detector_512,
    SW_adaCNNTrans_detector_512,
    SW_CNNDense_detector_512,
    SW_adaCNNDense_detector_512,
):
    detector.eval()

for sample in data_512:
    # Add a leading batch dimension; the same tensor is fed to every detector.
    example_data = torch.tensor(sample["data"]).unsqueeze(0)

    preds_FCNTrans.append(SW_CNNTrans_detector_512(example_data).numpy())
    preds_adaFCNTrans.append(SW_adaCNNTrans_detector_512(example_data).numpy())
    preds_FCNDense.append(SW_CNNDense_detector_512(example_data).numpy())
    preds_adaFCNDense.append(SW_adaCNNDense_detector_512(example_data).numpy())

    preds_baseline.append(baseline_detector(example_data))

    # Every sample in this dataset is a positive example.
    gt.append(1)


  return F.conv1d(input, weight, bias, self.stride,


In [12]:
# Decision thresholds: a detector output >= its threshold is counted as a
# positive prediction when the outputs are binarised below.
# NOTE(review): the names suggest these values were selected by maximising
# F-beta elsewhere — confirm where they come from.
best_threshold_fbeta_FCNTrans_512 = 0.252
best_threshold_fbeta_adaFCNTrans_512 = 0.343

best_threshold_fbeta_FCNDense_512 = 0.343
best_threshold_fbeta_adaFCNDense_512 = 0.373

In [12]:
def _binarize(scores, threshold):
    """Map each detector output to 1 if it reaches `threshold`, otherwise 0."""
    return [1 if value >= threshold else 0 for value in scores]


def _detector_metrics(detector, threshold, y_true, y_pred):
    """Build one row of the metrics table for a single detector.

    Args:
        detector: Display name of the detector.
        threshold: Decision threshold reported alongside the metrics.
        y_true: Ground-truth 0/1 labels.
        y_pred: Binarised 0/1 predictions, aligned with `y_true`.

    Returns:
        Dict with the detector name, threshold, F-beta (beta=0.5), accuracy,
        precision, recall, MSE and the confusion-matrix counts tn/fp/fn/tp.
    """
    # labels=[0, 1] forces a 2x2 matrix even when only one class is present.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # zero_division=0 yields the same 0.0 value as the default when a detector
    # predicts no positives, but without emitting an UndefinedMetricWarning.
    return {
        'detector': detector,
        'threshold': threshold,
        'fbeta': fbeta_score(y_true, y_pred, beta=0.5, zero_division=0),
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'mse': mean_squared_error(y_true, y_pred),
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp,
    }


# Binarise the detector outputs with their respective thresholds.
preds_binary_FCNTrans = _binarize(preds_FCNTrans, best_threshold_fbeta_FCNTrans_512)
preds_binary_adaFCNTrans = _binarize(preds_adaFCNTrans, best_threshold_fbeta_adaFCNTrans_512)
preds_binary_FCNDense = _binarize(preds_FCNDense, best_threshold_fbeta_FCNDense_512)
preds_binary_adaFCNDense = _binarize(preds_adaFCNDense, best_threshold_fbeta_adaFCNDense_512)

# The baseline already emits 0/1 predictions.
preds_binary_baseline = preds_baseline

# (display name, reported threshold, binary predictions) for every detector,
# in the order the rows appear in the table.
evaluations = [
    ('FCN Trans', best_threshold_fbeta_FCNTrans_512, preds_binary_FCNTrans),
    ('adapted FCN Trans', best_threshold_fbeta_adaFCNTrans_512, preds_binary_adaFCNTrans),
    ('FCN Dense', best_threshold_fbeta_FCNDense_512, preds_binary_FCNDense),
    ('adapted FCN Dense', best_threshold_fbeta_adaFCNDense_512, preds_binary_adaFCNDense),
    ('baseline', 0.5, preds_binary_baseline),
]

metrics = pd.DataFrame([
    _detector_metrics(name, threshold, gt, preds_binary)
    for name, threshold, preds_binary in evaluations
])

  _warn_prf(average, modifier, msg_start, len(result))


In [13]:
metrics

Unnamed: 0,detector,threshold,fbeta,accuracy,precision,recall,mse,tn,fp,fn,tp
0,FCN Trans,0.252,0.44,0.135802,1.0,0.135802,0.864198,0,0,70,11
1,adapted FCN Trans,0.343,0.80786,0.45679,1.0,0.45679,0.54321,0,0,44,37
2,FCN Dense,0.343,0.247525,0.061728,1.0,0.061728,0.938272,0,0,76,5
3,adapted FCN Dense,0.373,0.11236,0.024691,1.0,0.024691,0.975309,0,0,79,2
4,baseline,0.5,0.0,0.0,0.0,0.0,1.0,0,0,81,0
