# FIRST, LET'S TRAIN DFSTRANS WITH LESS SENSORS (JUST FOR CLARITY OF INTERPRETATIONS

In [7]:
# import torch
import os
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import *
import h5py
import numpy as np
import torch.optim as optim
from torch.backends import cudnn
from torch.autograd import Variable
import torch.multiprocessing as mp
import time
import math
import random
from sklearn.model_selection import train_test_split
from torch.nn import Module


os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
gpu = str(2)
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
print(os.environ['CUDA_VISIBLE_DEVICES'])



# custom class to generate batch data on the fly
class CustomDataGenerator(Dataset):
    def __init__(self, path, data_index, sequence_length, time_steps, window_length,
                 window_step, n_channels):
        """ Method called at the initialization of the class (Constructor).

            Args:
                path (string): path where the dataset is located
                data_index (array): index list of the dataset
                batch_size (int): the size of the batch
                sequence_length (int): the length of the time series
                time_steps (int): number of windows in which the time series are divided
                window_length (int): the length (data point) of the window
                window_step (int): indicates how much to slice de window over the time series
                n_channels (int): number of sensors
                isTrainig (bool): whether the generator is for training or test

        """
        self.path = path
        self.data_index = data_index
        self.sequence_length = sequence_length
        self.time_steps = time_steps
        self.window_length = window_length
        self.window_step = window_step
        self.n_channels = n_channels
        self.len = data_index.shape[0]
        
        # divide data for each sensor channel
        output_params = {'Alpha': 0, 'Ax': 1, 'Ay': 2, 'Az': 3, 'Fc': 4, 'Fcw': 5, 'FrictionCW': 6, 'FrictionCabin': 7,
                 'Fsupport': 8, 'Id': 9, 'Iq': 10, 'Omega': 11, 'Phi': 12, 'Vc': 13, 'Vcw': 14, 'Vd': 15, 'Vq': 16,
                 'Zc': 17, 'Zcw': 18, 'pulleyAz': 19}
        
        params_to_show = ['Ax', 'Ay', 'Az', 'Id', 'Iq','Omega', 'Phi', 'Vc', 'Vcw', 'Vd', 'Vq', 'Zc', 'Zcw']
        
        self.params_id = [output_params[param] for param in params_to_show]

    def __len__(self):
        """ Method called at the time of requiring the number of batches per epoch

                Returns:
                    int: number of batches per epoch

            """
        return self.len

    def __getitem__(self, idx):
        
        index = self.data_index[idx]

        col_size = self.sequence_length * self.n_channels
        with h5py.File(self.path, 'r') as hf:
            x = hf['dataset'][index, :col_size]
            y = hf['dataset'][index, -1:]

        # reshape data to 4D (time_steps, window_length, channels)
        x = reshape_data(x, self.sequence_length, self.time_steps, self.window_length, self.window_step,
                          self.n_channels)
        
        input_data_x_list = []
        for channel in self.params_id:
            input_data_x_list.append(x[:, :, channel:channel + 1])
        
        return input_data_x_list, y


def reshape_data(data, sequence_length, time_steps, window_length, window_step, n_channels):

    x_reshaped = np.zeros((time_steps, window_length, n_channels))

    current_simulation_values = np.zeros((time_steps, window_length, n_channels))
    for channel in range(n_channels):
        channel_start_point = channel * sequence_length
        channel_end_point = channel_start_point + sequence_length
        param_values = data[channel_start_point:channel_end_point]

        for step in range(time_steps):
            time_step_start_point = step * window_step
            time_step_stop_point = time_step_start_point + window_length
            time_step_values = param_values[time_step_start_point:time_step_stop_point]
            current_simulation_values[step, :, channel] = time_step_values

        x_reshaped[:, :, :] = current_simulation_values

    return x_reshaped
      
def scale_data_between_a_b(batch_x,a,b):
    '''Scales data between values (a,b)
    
    Args:
    
        batch_x (1D mumpy array): features data of the current batch
        a (int): min value of transformed data
        b (int): max value of transformed data
    
    Returns:
        
        1D numpy array: Transformed data
    '''
    return (b-a)*((batch_x-torch.amin(batch_x))/(torch.amax(batch_x)-torch.amin(batch_x)))+a



class DFTEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super(DFTEncoding, self).__init__()
        torch.pi = torch.acos(torch.zeros(1)).item() * 2
        w_s = 2 * torch.pi / d_model

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.arange(0, d_model, 2).float() // 2
        pe[:, 1::2] = torch.sin(position * w_s * div_term)
        pe[:, 0::2] = torch.cos(position * w_s * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.d_model = d_model
        self.register_buffer('pe', pe)

    def forward(self, x):
        pe_multidim = self.pe[1:x.size()[0] + 1, :].unsqueeze(-1)
        pe_multidim[0] = math.sqrt(1 / self.d_model) * pe_multidim[0]
        pe_multidim[-1] = math.sqrt(1 / self.d_model) * pe_multidim[-1]
        pe_multidim[1:-1] = torch.mul(pe_multidim[1:-1], math.sqrt(2 / self.d_model))

        pe_multidim.repeat(1, x.size()[1], 1, x.size()[3]).size()

        return x + pe_multidim


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class Flatten(nn.Module):

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

class TimeDistributed(nn.Module):
    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):

        if len(x.size()) <= 2:
            return self.module(x)

        x_reshape = x.contiguous().view(-1, *(x.size()[2:]))
        y = self.module(x_reshape)

        # We have to reshape Y
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.contiguous().view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)

        return y
    
class MultiHead1DCNN(nn.Module):

    def __init__(self,conv_filters = 20, time_steps = 80):
        super(MultiHead1DCNN, self).__init__()

        self.conv_filters = conv_filters
        self.time_steps = time_steps

        self.conv1d1 = nn.Conv1d(in_channels=1, out_channels=self.conv_filters, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(self.conv_filters,track_running_stats=False)
        self.conv1d2 = nn.Conv1d(in_channels=self.conv_filters, out_channels=self.conv_filters, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(self.conv_filters,track_running_stats=False)
        self.conv1d3 = nn.Conv1d(in_channels=self.conv_filters, out_channels=self.conv_filters, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm1d(self.conv_filters,track_running_stats=False)
        self.maxpool = nn.MaxPool1d(2, 2)

    def forward(self, x):
        x = self.conv1d1(x)
        x = F.relu(x)
        x = self.maxpool(x)

        x = self.bn1(x)
        X = x.view(-1, x.size()[0] // self.time_steps, self.time_steps, self.conv_filters)
        X = x.view(-1, *(x.size()[2:]))
        x = self.conv1d2(x)

        x = F.relu(x)
        x = self.maxpool(x)

        x = self.bn2(x)
        X = x.view(-1, x.size()[0] // self.time_steps, self.time_steps, self.conv_filters)
        X = x.view(-1, *(x.size()[2:]))
        x = self.conv1d3(x)

        x = F.relu(x)
        x = self.maxpool(x)

        x = self.bn3(x)
        x = x.view(x.size()[0], x.size()[1], -1)
        return x
        
class TemporalEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TemporalEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TemporalEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        

        src2, weights = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)
        
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        
        return src, weights
    
class TransTS(nn.Module):
    def __init__(self,feature_size=240,num_layers=1,dropout=0.1,temp_gap = False):
        super(TransTS, self).__init__()
        self.model_type = 'Transformer'
        
        self.src_mask = None
        self.encoder_layer = TemporalEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.temp_gap = temp_gap
        

    def forward(self,src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        output,weights = self.transformer_encoder(src,self.src_mask)#, self.src_mask)

        output = output.permute(1,0,2) 

        return output,weights

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float(0.0)).masked_fill(mask == 1, float(0.0))
        return mask
    

    
class SpatialEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(SpatialEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(SpatialEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        
        src2, weights = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        
        return src, weights
    
class TransSensor(nn.Module):
    def __init__(self, feature_size=240, num_layers=1, dropout=0.1):
        super(TransSensor, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.encoder_layer = SpatialEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

    def forward(self, src):

        bs = src.size()[1]

        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        output, weights = self.transformer_encoder(src, self.src_mask)

        output = output.permute(1, 0, 2)

        return output, weights

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float(0.0)).masked_fill(mask == 1, float(0.0))
        return mask

class DFSTrans_model(nn.Module):

    def __init__(self, activation="relu", d_model=240, dim_feedforward=2048,
                 dropout=0.1,n_channels = 13,n_time_steps = 80,output_dim = 3120, n_units_l1 = 512, conv_filters = 20):
        super(DFSTrans_model, self).__init__()
        self.conv_cell = nn.ModuleList([MultiHead1DCNN(conv_filters = conv_filters, time_steps = n_time_steps) for i in range(n_channels)])
        self.TimeDistributed_flatten = nn.ModuleList([TimeDistributed(Flatten) for i in range(n_channels)])
        self.trace = []
        self.TransformerTS_list = nn.ModuleList([TransTS() for i in range(n_channels)])
        self.TransformerS_list = nn.ModuleList([TransSensor() for i in range(n_time_steps)])
        self.d_model = d_model

        self.pos_encoder = DFTEncoding(self.d_model)

        self.sigmoid = nn.Sigmoid()

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.dense1 = nn.Linear(output_dim, n_units_l1)
        self.dropout_out1 = nn.Dropout(dropout)
        self.dense2 = nn.Linear(n_units_l1, 1)

        self.n_time_steps = n_time_steps
        self.n_channels = n_channels

    def forward(self, input_x):

        trace = []
        for sensor_n in range(self.n_channels):
            input_layer = input_x[sensor_n]
            input_layer_reshape = input_layer.view(-1, *(input_layer.size()[2:]))
            x = self.conv_cell[sensor_n](input_layer_reshape)
            x = x.view(x.size()[0] // self.n_time_steps, self.n_time_steps, -1)
            trace.append(x)

        x = torch.stack(trace)
        x = x.permute(2, 1, 3, 0)
        x = self.pos_encoder(x)

        input_ts = torch.clone(x)
        input_sensor = torch.clone(x)
        input_ts = input_ts.permute(3, 0, 1, 2)
        input_sensor = input_sensor.permute(0, 3, 1, 2)

        output_ts_list = []
        ts_weights_list = []

        for i in range(self.n_channels):
            channel = input_ts[i, :, :, :]
            trans_ts = self.TransformerTS_list[i]

            output_ts, weights_ts = trans_ts(channel)
            output_ts_list.append(output_ts)
            ts_weights_list.append(weights_ts)

        output_s_list = []
        s_weights_list = []

        for i in range(self.n_time_steps):
            ts = input_sensor[i, :, :, :]
            trans_s = self.TransformerS_list[i]

            output_s, weights_s = trans_s(ts)
            output_s_list.append(output_s)
            s_weights_list.append(weights_s)

        output_ts = torch.stack(output_ts_list)
        output_ts = output_ts.permute(1, 2, 3, 0)
        output_sensor = torch.stack(output_s_list)
        output_sensor = output_sensor.permute(1, 0, 3, 2)

        output_ts_sensor = output_ts + output_sensor

        bs = output_ts_sensor.size()[0]

        output_ts_sensor = output_ts_sensor.view(-1, output_ts_sensor.size()[2], output_ts_sensor.size()[3])
        output_ts_sensor = output_ts_sensor.permute(0, 2, 1)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(output_ts_sensor))))

        output_ts_sensor = output_ts_sensor + self.dropout2(src2)
        src = self.norm2(output_ts_sensor)
        src = src.permute(0, 2, 1)
        src = src.reshape(bs, self.n_time_steps, src.size()[1] * src.size()[2])
        src = src.mean(1)

        x = self.dense1(src)
        x = F.relu(x)
        x = self.dropout_out1(x)
        output = self.dense2(x)

        return output


def custom_MINMAX(batch_x, min_val, max_val):
    '''Scales data between values (a,b)

    Args:

        batch_x (1D mumpy array): features data of the current batch
        a (int): min value of transformed data
        b (int): max value of transformed data

    Returns:

        1D numpy array: Transformed data
    '''
    return (batch_x - min_val) / (max_val - min_val)

from sklearn.model_selection import StratifiedShuffleSplit
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1,AveragePrecision

# data_path = '/data/mcanizo/shap/point_anomaly_dataset.h5'
data_path = '/data/mcanizo/data/simulations_dataset.h5'

sequence_length = 8000       # the length of the time series
window_length = 100          # the length (data points) of the window
window_step = window_length  # indicates how much to slice de window over the time series
time_steps = int(((sequence_length - window_length) / window_step) + 1)  # number of windows in which the time series are divided
n_channels = 20              # number of sensors
batch_size = 16             # the size of the batch
n_epochs = 200                 # number of epochs in training
n_conv_layers = 3            # number of convolutional layers
conv_filters = 20            # number of filters on each convolutional layer
n_rnn_layers = 2             # number of recurrent layers
n_rnn_units = 128            # number of units on each recurrent layer
rnn_layer_type = ['LSTM']   # the type of the recurrent layer (LSTM, Bi-LSTM, GRU, Bi-GRU, SimpleRNN)
learning_rate = 0.001      # the value of the learning rate

# set a random seed to reproduce the results over different executions
random.seed(7)


with h5py.File(data_path, 'r') as hf:
    labels = hf['dataset'][:, -1]

# get an ordered an ascending index of labels to be able to match each instance with its corresponding label
index = np.array(range(len(labels)))
data_index_list = np.array(range(len(labels)))
normal_data_index = [index for index in data_index_list if labels[index] == 0]
anomalous_data_index = [index for index in data_index_list if labels[index] == 1]

np.random.shuffle(anomalous_data_index)
n_anomalies_to_get = int(3 * len(labels) / 100)
reduced_anomalous_data_index = anomalous_data_index[:n_anomalies_to_get]
reduced_data_index = np.hstack((normal_data_index, reduced_anomalous_data_index))
reduced_labels = labels[reduced_data_index]


sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, train_size=0.7, random_state=42)
for train_index, test_index in sss.split(reduced_data_index, reduced_labels):
    # train_index and test_index are not the real indexes but new indexes made from reduced_data_index
    real_train_index, real_test_index = reduced_data_index[train_index], reduced_data_index[test_index]
    
    
    trainset = CustomDataGenerator(data_path, real_train_index, sequence_length, time_steps, window_length, window_step, n_channels)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=12)
    testset = CustomDataGenerator(data_path, real_test_index, sequence_length, time_steps, window_length, window_step, n_channels)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=12)


    # custom class to generate batch data on the fly
    class CustomDataGenerator_minmax(Dataset):
        def __init__(self, path, data_index, sequence_length, time_steps, window_length,
                     window_step, n_channels):
            """ Method called at the initialization of the class (Constructor).

                Args:
                    path (string): path where the dataset is located
                    data_index (array): index list of the dataset
                    batch_size (int): the size of the batch
                    sequence_length (int): the length of the time series
                    time_steps (int): number of windows in which the time series are divided
                    window_length (int): the length (data point) of the window
                    window_step (int): indicates how much to slice de window over the time series
                    n_channels (int): number of sensors
                    isTrainig (bool): whether the generator is for training or test

            """
            self.path = path
            self.data_index = data_index
            self.sequence_length = sequence_length
            self.time_steps = time_steps
            self.window_length = window_length
            self.window_step = window_step
            self.n_channels = n_channels
            self.len = data_index.shape[0]

        def __len__(self):
            """ Method called at the time of requiring the number of batches per epoch

                    Returns:
                        int: number of batches per epoch

                """
            return self.len

        def __getitem__(self, idx):
            index = self.data_index[idx]

            col_size = self.sequence_length * self.n_channels
            with h5py.File(self.path, 'r') as hf:
                x = hf['dataset'][index, :col_size]
                y = hf['dataset'][index, -1:]

            # reshape data to 3D (time_steps* window_length, channels)
            x = reshape_data_minmax(x, self.sequence_length, self.time_steps, self.window_length, self.window_step,
                             self.n_channels)    
            return x

    def reshape_data_minmax(data, sequence_length, time_steps, window_length, window_step, n_channels):
        x_reshaped = np.zeros((time_steps* window_length, n_channels))
        for channel in range(n_channels):
            channel_start_point = channel * sequence_length
            channel_end_point = channel_start_point + sequence_length
            param_values = data[channel_start_point:channel_end_point]
            x_reshaped[:,channel] = param_values

        return x_reshaped

    def MinMax_total(trainloader,n_channels):
        global_min = 1000000000000
        global_max = -1000000000000
        minmax_dict = {'channel_{}'.format(i): [global_min,global_max] for i in range(n_channels)}
        for i, data in enumerate(trainloader, 0):
            batch_x = data
#                 print(batch_x.shape)
            for channel in range(n_channels):
                channel_values = batch_x[:,:,channel]
                min_batch = torch.amin(channel_values)
                max_batch = torch.amax(channel_values)
                if min_batch<minmax_dict['channel_{}'.format(channel)][0]:
                    minmax_dict['channel_{}'.format(channel)][0] = min_batch
                if  max_batch>minmax_dict['channel_{}'.format(channel)][1]:
                    minmax_dict['channel_{}'.format(channel)][1] = max_batch
        return minmax_dict

    trainset_minmax = CustomDataGenerator_minmax(data_path, np.array(real_train_index), sequence_length, time_steps, window_length,
                                   window_step, n_channels)
    trainloader_minmax = torch.utils.data.DataLoader(trainset_minmax, batch_size=batch_size,
                                              shuffle=True, num_workers=6)
    minmax_dict =  MinMax_total(trainloader_minmax,n_channels)
    
    net = DFSTrans_model()
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.00001)
    cudnn.benchmark = True        
    net = net.cuda()
    f1_best = 0
    recall_best = 0
    precision_best = 0
    for epoch in range(n_epochs):  # loop over the dataset multiple times
        
        running_loss = 0.0
        steps_per_epoch = 0
        start_time = time.time()
        n_samples = 0
        
        tp_all = 0
        tn_all = 0
        fp_all = 0
        fn_all = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            inputs = torch.cat(inputs).view(13,inputs[0].size()[0],80,1,100)
            for sensor in range(13):
                channel_values = inputs[sensor].view(inputs[sensor].shape[0], -1)
                min_val,max_val = minmax_dict['channel_{}'.format(sensor)]
                scaled_data = custom_MINMAX(channel_values,min_val,max_val)
                inputs[sensor] = torch.Tensor(scaled_data.float()).view(inputs[0].size()[0], 80, 1, 100)
            
            inputs, labels = Variable(inputs.cuda().type(torch.cuda.FloatTensor)), Variable(labels.cuda())
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            output = net(inputs)

            loss = criterion(output.float(), labels.float())
            output = (torch.sigmoid(output)>0.5).int()
            tp = (labels * output).sum().to(torch.float32)
            tn = ((1 - labels) * (1 - output)).sum().to(torch.float32)
            fp = ((1 - labels) * output).sum().to(torch.float32)
            fn = (labels * (1 - output)).sum().to(torch.float32)
            tp_all+=tp
            tn_all+=tn
            fp_all+=fp
            fn_all+=fn
    
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            steps_per_epoch += 1
            
        precision = tp_all / (tp_all + fp_all)
        recall = tp_all / (tp_all + fn_all)

        f1 = 2* (precision*recall) / (precision + recall)
        
        print('Epoch %d loss: %.5f ' % (epoch+1,running_loss/steps_per_epoch))
        print('Precision train %.5f , Recall train %.5f , F1 train %.5f' % (precision,recall,f1))


        with torch.no_grad():
            tp_all = 0
            tn_all = 0
            fp_all = 0
            fn_all = 0            
            val_loss = 0.0
            steps_per_epoch = 0
            start_time = time.time()
            correct = 0
            for i, data in enumerate(testloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data
                inputs = torch.cat(inputs).view(13,inputs[0].size()[0],80,1,100)
                for sensor in range(13):
                    channel_values = inputs[sensor].view(inputs[sensor].shape[0], -1)
                    min_val,max_val = minmax_dict['channel_{}'.format(sensor)]
                    scaled_data = custom_MINMAX(channel_values,min_val,max_val)
                    inputs[sensor] = torch.Tensor(scaled_data.float()).view(inputs[0].size()[0], 80, 1, 100)

                inputs, labels = Variable(inputs.cuda().type(torch.cuda.FloatTensor)), Variable(labels.cuda())

                # forward + backward + optimize
                output = net(inputs)

                output = output.flatten()
                labels = labels.flatten()
                output = output.view(output.size()[0],1)
                labels = labels.view(labels.size()[0],1)

                loss = criterion(output.float(), labels.float())
                output = (torch.sigmoid(output)>0.5).int()

                tp = (labels * output).sum().to(torch.float32)
                tn = ((1 - labels) * (1 - output)).sum().to(torch.float32)
                fp = ((1 - labels) * output).sum().to(torch.float32)
                fn = (labels * (1 - output)).sum().to(torch.float32)

                tp_all+=tp
                tn_all+=tn
                fp_all+=fp
                fn_all+=fn

                val_loss += loss.item()
                steps_per_epoch += 1
            
            
            precision = tp_all / (tp_all + fp_all)
            recall = tp_all / (tp_all + fn_all)

            f1 = 2* (precision*recall) / (precision + recall)

            print('test loss: %.5f ' % (val_loss/steps_per_epoch))
            print('Precision test %.5f , Recall test %.5f , F1 test %.5f' % (precision,recall,f1))
            
            if f1 > f1_best:
                f1_best = f1
                checkpoint = {'epoch': epoch + 1,
                          'state_dict': net.state_dict(),
                          'optimizer': optimizer.state_dict()}
                torch.save(checkpoint, 'DFStrans_reduced_sensors.pt')



2
Epoch 1 loss: 0.13435 
Precision train 0.53125 , Recall train 0.05923 , F1 train 0.10658
test loss: 0.09143 
Precision test 1.00000 , Recall test 0.21138 , F1 test 0.34899
Epoch 2 loss: 0.06393 
Precision train 0.89412 , Recall train 0.52962 , F1 train 0.66521
test loss: 0.05733 
Precision test 0.87912 , Recall test 0.65041 , F1 test 0.74766


KeyboardInterrupt: 

In [8]:

class DFSTrans_model(nn.Module):

    def __init__(self, activation="relu", d_model=240, dim_feedforward=2048,
                 dropout=0.1,n_channels = 13,n_time_steps = 80,output_dim = 3120, n_units_l1 = 512, conv_filters = 20):
        super(DFSTrans_model, self).__init__()
        self.conv_cell = nn.ModuleList([MultiHead1DCNN(conv_filters = conv_filters, time_steps = n_time_steps) for i in range(n_channels)])
        self.TimeDistributed_flatten = nn.ModuleList([TimeDistributed(Flatten) for i in range(n_channels)])
        self.trace = []
        self.TransformerTS_list = nn.ModuleList([TransTS() for i in range(n_channels)])
        self.TransformerS_list = nn.ModuleList([TransSensor() for i in range(n_time_steps)])
        self.d_model = d_model

        self.pos_encoder = DFTEncoding(self.d_model)

        self.sigmoid = nn.Sigmoid()

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.dense1 = nn.Linear(output_dim, n_units_l1)
        self.dropout_out1 = nn.Dropout(dropout)
        self.dense2 = nn.Linear(n_units_l1, 1)

        self.n_time_steps = n_time_steps
        self.n_channels = n_channels

    def forward(self, input_x):

        trace = []
        for sensor_n in range(self.n_channels):
            input_layer = input_x[sensor_n]
            input_layer_reshape = input_layer.view(-1, *(input_layer.size()[2:]))
            x = self.conv_cell[sensor_n](input_layer_reshape)
            x = x.view(x.size()[0] // self.n_time_steps, self.n_time_steps, -1)
            trace.append(x)

        x = torch.stack(trace)
        x = x.permute(2, 1, 3, 0)
        x = self.pos_encoder(x)

        input_ts = torch.clone(x)
        input_sensor = torch.clone(x)
        input_ts = input_ts.permute(3, 0, 1, 2)
        input_sensor = input_sensor.permute(0, 3, 1, 2)

        output_ts_list = []
        ts_weights_list = []

        for i in range(self.n_channels):
            channel = input_ts[i, :, :, :]
            trans_ts = self.TransformerTS_list[i]

            output_ts, weights_ts = trans_ts(channel)
            output_ts_list.append(output_ts)
            ts_weights_list.append(weights_ts)

        output_s_list = []
        s_weights_list = []

        for i in range(self.n_time_steps):
            ts = input_sensor[i, :, :, :]
            trans_s = self.TransformerS_list[i]

            output_s, weights_s = trans_s(ts)
            output_s_list.append(output_s)
            s_weights_list.append(weights_s)

        output_ts = torch.stack(output_ts_list)
        output_ts = output_ts.permute(1, 2, 3, 0)
        output_sensor = torch.stack(output_s_list)
        output_sensor = output_sensor.permute(1, 0, 3, 2)

        output_ts_sensor = output_ts + output_sensor

        bs = output_ts_sensor.size()[0]

        output_ts_sensor = output_ts_sensor.view(-1, output_ts_sensor.size()[2], output_ts_sensor.size()[3])
        output_ts_sensor = output_ts_sensor.permute(0, 2, 1)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(output_ts_sensor))))

        output_ts_sensor = output_ts_sensor + self.dropout2(src2)
        src = self.norm2(output_ts_sensor)
        src = src.permute(0, 2, 1)
        src = src.reshape(bs, self.n_time_steps, src.size()[1] * src.size()[2])
        src = src.mean(1)

        x = self.dense1(src)
        x = F.relu(x)
        x = self.dropout_out1(x)
        output = self.dense2(x)

        return output,ts_weights_list,s_weights_list


In [9]:
from sklearn.model_selection import StratifiedShuffleSplit
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1,AveragePrecision
# data_path = '/data/mcanizo/shap/point_anomaly_dataset.h5'


data_path = '/data/mcanizo/data/simulations_dataset.h5'

sequence_length = 8000       # the length of the time series
window_length = 100          # the length (data points) of the window
window_step = window_length  # indicates how much to slice de window over the time series
time_steps = int(((sequence_length - window_length) / window_step) + 1)  # number of windows in which the time series are divided
n_channels = 20              # number of sensors
batch_size = 16             # the size of the batch
n_epochs = 3                 # number of epochs in training
n_conv_layers = 3            # number of convolutional layers
conv_filters = 20            # number of filters on each convolutional layer
n_rnn_layers = 2             # number of recurrent layers
n_rnn_units = 128            # number of units on each recurrent layer
rnn_layer_type = ['LSTM']   # the type of the recurrent layer (LSTM, Bi-LSTM, GRU, Bi-GRU, SimpleRNN)
learning_rate = 0.001      # the value of the learning rate
# dom seed to reproduce the results over different executions
random.seed(7)


with h5py.File(data_path, 'r') as hf:
    labels = hf['dataset'][:, -1]

# get an ordered an ascending index of labels to be able to match each instance with its corresponding label
index = np.array(range(len(labels)))

data_index_list = np.array(range(len(labels)))
normal_data_index = [index for index in data_index_list if labels[index] == 0]
anomalous_data_index = [index for index in data_index_list if labels[index] == 1]


np.random.shuffle(anomalous_data_index)
n_anomalies_to_get = int(60 * len(labels) / 100)
reduced_anomalous_data_index = anomalous_data_index[:n_anomalies_to_get]

reduced_data_index = np.hstack((normal_data_index, reduced_anomalous_data_index))
reduced_labels = labels[reduced_data_index]

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, train_size=0.7, random_state=42)
for train_index, test_index in sss.split(reduced_data_index, reduced_labels):
    
    trainset = CustomDataGenerator(data_path, train_index, sequence_length, time_steps, window_length, window_step, n_channels)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=12)
    testset = CustomDataGenerator(data_path, test_index, sequence_length, time_steps, window_length, window_step, n_channels)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                          shuffle=False, num_workers=12)
    


net = DFSTrans_model()
optimizer = optim.Adam(net.parameters(), lr=0.00001)
net = net.cuda()
checkpoint = torch.load('DFStrans_reduced_sensors.pt')

net.load_state_dict(checkpoint['state_dict'],strict=False)
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
net.eval()



DFSTrans_model(
  (conv_cell): ModuleList(
    (0): MultiHead1DCNN(
      (conv1d1): Conv1d(1, 20, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn1): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (conv1d2): Conv1d(20, 20, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn2): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (conv1d3): Conv1d(20, 20, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn3): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): MultiHead1DCNN(
      (conv1d1): Conv1d(1, 20, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn1): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (conv1d2): Conv1d(20, 20, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn2): BatchNorm1d(20, eps=1e-05, momentum=0

In [13]:

n_batch = 0
with torch.no_grad():

    for i, data in enumerate(testloader, 0):

        inputs, labels = data

        if 1 in labels:
            if n_batch == 0: 

                inputs = torch.cat(inputs).view(13,inputs[0].size()[0],80,1,100)
                for sensor in range(13):
                    channel_values = inputs[sensor].view(inputs[sensor].shape[0], -1)
                    min_val,max_val = minmax_dict['channel_{}'.format(sensor)]
                    scaled_data = custom_MINMAX(channel_values,min_val,max_val)
                    inputs[sensor] = torch.Tensor(scaled_data.float()).view(inputs[0].size()[0], 80, 1, 100)
                inputs, labels = Variable(inputs.cuda().type(torch.cuda.FloatTensor)), Variable(labels.cuda().type(torch.cuda.FloatTensor))
                # zero the parameter gradients
                # forward + backward + optimize

                output, weights_ts, weights_sensor = net(inputs)
                weights_ts = torch.stack(weights_ts)
                weights_sensor = torch.stack(weights_sensor)
                
                break
            n_batch+=1

weights_ts = weights_ts.permute(1,0,2,3)
weights_sensor = weights_sensor.permute(1,0,2,3)

print(torch.sigmoid(output)>0.5)
print(labels)


tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False

In [15]:

%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}


<IPython.core.display.Javascript object>

**GLOBAL INTERPRETATIONS: WHEN AND WHERE THE ANOMALY HAS OCURRED**

In [16]:


from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import shap
import matplotlib.pyplot as plt
from plotly import tools,subplots
import plotly
import plotly.graph_objs as go
import plotly.io as pio
import h5py
from random import uniform
import matplotlib.pyplot as plt
import shap
%matplotlib inline
# weights_ts = weights_ts.permute(1,0,2,3)

# print(np.where(labels.cpu()==1))
def plot_global_att(idx,type_of_explanation,average):
    font = {'family': 'serif',
        'color':  'black',
        'weight': 'normal',
        'size': 30,
        }
    n_channels = 20
    sequence_length = 8000
    output_params = {'Ax':0, 'Ay':1, 'Az':2, 'Id':3, 'Iq':4,'Om':5, 'Phi':6, 'Vc':7, 'Vcw':8, 'Vd':9, 'Vq':10, 'Zc':11, 'Zcw':12}
    params_to_show = ['Ax', 'Ay', 'Az', 'Id', 'Iq','Om', 'Phi', 'Vc', 'Vcw', 'Vd', 'Vq', 'Zc', 'Zcw']

    x = list(range(sequence_length))
    simulation_id = 0

    data = inputs.permute(1,0,2,3,4)
    fig = tools.make_subplots(rows=len(params_to_show), cols=1, shared_xaxes=True)
    plot_row = 0
    
    for param in params_to_show:
        channel_number = output_params[param]
        data_channel = data[idx,channel_number].flatten().cpu()
        trace = go.Scatter(x=x, y=data_channel, name=param)
        plot_row += 1
        fig.append_trace(trace, plot_row, 1)
        
    for row in range(len(params_to_show)):
        fig.update_yaxes(title_text=params_to_show[row], row=row+1, col=1,titlefont=dict(size=40))


    fig.update_layout(legend=dict(title="",font=dict(size = 10),
        orientation="h"),legend_title=dict(font=dict(size=12)),)
    fig.update_layout(
        autosize=False,
        width=1000,
        height=1000,)
    fig.update_xaxes(title="",tickfont=dict(size=20), titlefont=dict(size=20),linecolor="black", showline=True, linewidth=2,  gridcolor='#BCCCDC')

    fig.update_layout(showlegend=False, plot_bgcolor='rgba(0,0,0,0)',legend=dict(title="",font=dict(size = 30)),)
    fig.show()
    
    plt.figure(figsize=(10,10))
    plt.style.use('default')


    if type_of_explanation == 'Temporal':
        plt.figure(figsize=(10,10))
        plt.rcParams['axes.facecolor']='w'
#         plt.grid(True,color = 'grey')
        plt.imshow(weights_ts[idx].mean(0).detach().cpu().numpy())
        plt.colorbar(fraction=0.046, pad=0.04)

        plt.xlabel('Global temporal attention matrix, $A_{t \prime ,t}^{\mathrm{G}}$',fontdict=font)
        plt.savefig('global_temp_matrix_FJ_{}.pdf'.format(idx),bbox_inches='tight',fontdict=font)  
        plt.show()

        plt.figure(figsize=(10,10))    
        plt.rcParams['axes.facecolor']='w'
        plt.grid(True,color = 'grey')
        plt.bar(np.arange(80),np.sum(weights_ts[idx].mean(0).detach().cpu().numpy(),axis=0))
        plt.xlabel('Global temporal relevance scores, $a^{\mathrm{G}}_t$',fontdict=font)
        plt.show()

        
    if type_of_explanation == 'Spatial':
        
            weights_by_ts = np.sum(weights_ts[idx].mean(0).detach().cpu().numpy(),axis=0)

            weighted_avg_sensor = np.average(weights_sensor[idx].detach().cpu().numpy(), weights=weights_by_ts,axis=0)
            plt.figure(figsize=(10,10))
            plt.rcParams['axes.facecolor']='w'
            plt.imshow(weighted_avg_sensor)
            plt.colorbar(fraction=0.046, pad=0.04)
            plt.xlabel('Global spatial attention matrix, $B_{s \prime, s}^{\mathrm{G}} $ ',fontdict=font)
            labels = params_to_show
            plt.xticks([i for i in range(len(labels))],labels,fontsize=30, rotation=90)
            plt.yticks([i for i in range(len(labels))],labels,fontsize=30)
            plt.show()

            labels_n = labels
#             labels_n = ['Ax', 'Id', 'Iq','Vq','Vd','Phi', 'Az','Zc', 'Vc', 'Zcw','Vcw','Om','Ay']
#             labels_n = [labels_n[len(labels_n)-i-1] for i in range(len(labels_n)) ]
            plt.figure()
            plt.rcParams['axes.facecolor']='w'
            plt.grid(True,color = 'grey')
            shap.summary_plot(weighted_avg_sensor,
                              weighted_avg_sensor,
                              params_to_show, plot_type='bar',show=False)
            plt.xlabel('Global spatial relevance scores, $ b^{\mathrm{G}}_s$',fontdict=font)
            plt.yticks([i for i in range(len(labels_n))],labels_n,fontsize=30)
            plt.show()


bs = 100

normal_idx = np.where(labels.cpu()==0)[0]
anom_idx = np.where(labels.cpu()==1)[0]

print('ANOMALOUS INDEXES: ',anom_idx)
print('NON-ANOMALOUS INDEXES: ',normal_idx)
        
interactive(plot_global_att, idx=[i for i in range(bs)],type_of_explanation = ['Temporal','Spatial'],average = ['Weighted'])


ANOMALOUS INDEXES:  [11 16 20 24 36 40 49 67 71 75 79 81 88 90 91 96 99]
NON-ANOMALOUS INDEXES:  [ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 15 17 18 19 21 22 23 25 26 27
 28 29 30 31 32 33 34 35 37 38 39 41 42 43 44 45 46 47 48 50 51 52 53 54
 55 56 57 58 59 60 61 62 63 64 65 66 68 69 70 72 73 74 76 77 78 80 82 83
 84 85 86 87 89 92 93 94 95 97 98]


interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…

**LOCAL INTERPRETATIONS: At what time-step is focusing for a particular sensor?**

In [22]:

from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import shap
import matplotlib.pyplot as plt
from plotly import tools,subplots
import plotly
import plotly.graph_objs as go
import plotly.io as pio
import h5py
from random import uniform
import matplotlib.pyplot as plt
import shap
# import plotly.express as px

normal_idx = np.where(labels.cpu()==0)[0]
anom_idx = np.where(labels.cpu()==1)[0]

print('ANOMALOUS INDEXES: ',anom_idx)
print('NON-ANOMALOUS INDEXES: ',normal_idx)
def local_temporal_att(sensor_name,anom_index):
    n_channels = 20
    sequence_length = 8000
    output_params = {'Ax':0, 'Ay':1, 'Az':2, 'Id':3, 'Iq':4,'Omega':5, 'Phi':6, 'Vc':7, 'Vcw':8, 'Vd':9, 'Vq':10, 'Zc':11, 'Zcw':12}
    params_to_show = ['Ax', 'Ay', 'Az', 'Id', 'Iq','Omega', 'Phi', 'Vc', 'Vcw', 'Vd', 'Vq', 'Zc', 'Zcw']

    x = list(range(sequence_length))
    simulation_id = 0

    data = inputs.permute(1,0,2,3,4)

    x = list(range(sequence_length))
    colors = ['blue','red','orange','green','purple']


    fig = subplots.make_subplots(specs=[[{"secondary_y": True}]])
    plot_row = 0
    data = inputs.permute(1,0,2,3,4)
    sensors_dict = {'Ax':0, 'Ay':1, 'Az':2, 'Id':3, 'Iq':4,'Omega':5, 'Phi':6, 'Vc':7, 'Vcw':8, 'Vd':9, 'Vq':10, 'Zc':11, 'Zcw':12}
    sensor_id = sensors_dict[sensor_name]
    n = 0
    for anom_index in [11,16,20,24,36]:
        bars = np.sum(weights_ts[anom_index][sensor_id].detach().cpu().numpy(),axis=0)
        zero_idx = np.argsort(bars)[:-60]
        
#         bars[zero_idx] = 0
        
        
        bars_extended = np.repeat(bars,100)

        fig.add_trace(go.Scatter(x=np.arange(8000), y=bars_extended, mode='lines', name=sensor_name,
                                 line={'color':colors[n],'width':2}),secondary_y=True)

        
        if sensor_selected=='Iq':
            if anom_index not in [4,5]:
                data_channel = data[anom_index,sensor_id].flatten().cpu()+0.5-0.2*n
            else:
                data_channel = data[anom_index,sensor_id].flatten().cpu()+0.5-0.05*n
        else:
             data_channel = data[anom_index,sensor_id].flatten().cpu()+0.5-0.3*n

        trace = go.Scatter(x=x, y=data_channel, name=sensor_name,line=dict(color=colors[n]))
        fig.add_trace(trace,secondary_y=False)

        n+=1
        


        fig.update_traces(marker_line_width=1.5, opacity=0.6)

        fig.update_layout(legend=dict(title="",font=dict(size = 20),
            orientation="h",
            yanchor="top",
            y=1,
            xanchor="right",
            x=0.97
        ),legend_title=dict(font=dict(size=12)),)
        fig.update_layout(
            autosize=False,
            width=1000,
            height=500,)
        fig.update_xaxes(title="",tickfont=dict(size=20), titlefont=dict(size=20),linecolor="black", showline=True, linewidth=2,  gridcolor='#BCCCDC')
        fig.update_layout(showlegend=False, plot_bgcolor='rgba(0,0,0,0)',legend=dict(title="",font=dict(size = 20)),yaxis_range=[-1.5,1.5])
        fig.update_yaxes(visible=False)


        print('THESE ARE THE MOST INFLUENTIAL TIME STEPS FOR SENSOR {}: '.format(sensor_name))
        plt.figure(figsize=(10,10))
        plt.rcParams['axes.facecolor']='w'
        plt.imshow(weights_ts[anom_index][sensor_id].detach().cpu().numpy())
        plt.colorbar(fraction=0.046, pad=0.04)

        plt.figure()
        plt.rcParams['axes.facecolor']='w'
        plt.grid(True,color = 'grey')
        plt.bar(np.arange(80),np.sum(weights_ts[anom_index][sensor_id].detach().cpu().numpy(),axis=0))
    fig.show()

sensor_selected = 'Iq'
interactive(local_temporal_att, sensor_name= [sensor_selected], anom_index=[i for i in range(100)])


ANOMALOUS INDEXES:  [11 16 20 24 36 40 49 67 71 75 79 81 88 90 91 96 99]
NON-ANOMALOUS INDEXES:  [ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 15 17 18 19 21 22 23 25 26 27
 28 29 30 31 32 33 34 35 37 38 39 41 42 43 44 45 46 47 48 50 51 52 53 54
 55 56 57 58 59 60 61 62 63 64 65 66 68 69 70 72 73 74 76 77 78 80 82 83
 84 85 86 87 89 92 93 94 95 97 98]


interactive(children=(Dropdown(description='sensor_name', options=('Iq',), value='Iq'), Dropdown(description='…

In [23]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import shap
import matplotlib.pyplot as plt
from plotly import tools,subplots
import plotly
import plotly.graph_objs as go
import plotly.io as pio
import h5py
from random import uniform
import matplotlib.pyplot as plt
import shap
normal_idx = np.where(labels.cpu()==0)[0]
anom_idx = np.where(labels.cpu()==1)[0]

print('ANOMALOUS INDEXES: ',anom_idx)
print('NON-ANOMALOUS INDEXES: ',normal_idx)

def local_spatial_att(time_step,anom_index):

    a = np.sum(weights_ts[anom_index].mean(0).detach().cpu().numpy(),axis=0)
    print('MOST INFLUENTIAL TIME STEP: ',np.where(a==max(a)))
    
    print('THESE ARE THE MOST INFLUENTIAL SENSORS AT TIME-STEP {}: '.format(time_step))
    plt.figure(figsize=(15,15))
    plt.imshow(weights_sensor[anom_index][time_step].detach().cpu().numpy())
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.clim(0,0.15)
        
    params_to_show = ['Ax', 'Ay', 'Az', 'Id', 'Iq','Omega', 'Phi', 'Vc', 'Vcw', 'Vd', 'Vq', 'Zc', 'Zcw']

    plt.figure()
    shap.summary_plot(weights_sensor[anom_index][time_step].detach().cpu().numpy(),
                      weights_sensor[anom_index][time_step].detach().cpu().numpy(),
                      params_to_show, plot_type='bar',show=False)
    plt.xlabel('mean(attention weights per sensor) (average impact on the model output magnitude)')

interactive(local_spatial_att, time_step=[i for i in range(80)], anom_index=[i for i in range(100)])

ANOMALOUS INDEXES:  [11 16 20 24 36 40 49 67 71 75 79 81 88 90 91 96 99]
NON-ANOMALOUS INDEXES:  [ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 15 17 18 19 21 22 23 25 26 27
 28 29 30 31 32 33 34 35 37 38 39 41 42 43 44 45 46 47 48 50 51 52 53 54
 55 56 57 58 59 60 61 62 63 64 65 66 68 69 70 72 73 74 76 77 78 80 82 83
 84 85 86 87 89 92 93 94 95 97 98]


interactive(children=(Dropdown(description='time_step', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,…