In [1]:
import numpy as np
import matplotlib.pyplot as plt
import hdf5storage
import librosa

from model import *
from tb_utils import *
from preprocessing import *

from scipy import io

from sklearn.metrics import accuracy_score

# Per-envelope settings; passed as `config_dict` to renna_preprocess_wave below.
default_envelope_config = dict(
    homomorphic_envelogram_with_hilbert=dict(lpf_frequency=8),
    psd=dict(fl_low=40, fl_high=60, resample=True),
    wavelet=dict(
        wavelet='db1',
        levels=[4],
        start_level=1,
        end_level=5,
        erase_pad=True,
    ),
)

In [2]:
def load_groundtruth(filename, fs=50.0):
    """Load a tab-separated segmentation file and sample its labels uniformly.

    Every non-empty line must hold exactly three tab-separated numbers. The
    second column is used as the end time (in seconds) of a segment and the
    third column as the heart state of that segment; the first column is
    parsed but not used.

    Args:
    -----
        filename (str): Path of the tab-separated annotation file.
        fs (float): Sampling rate, in Hz, of the returned label sequence.
        Defaults to 50 Hz.

    Returns:
    --------
        s (np.ndarray): 1D float array with one heart state per time instant
        of np.arange(0, T, 1/fs), where T is the last value of the second
        column.

    Raises:
    -------
        ValueError: If `fs` is not positive, the file has no data rows, or a
        row does not contain exactly three tab-separated numeric fields.
    """
    if fs <= 0:
        raise ValueError('fs must be positive.')

    rows = []
    with open(filename, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            # Skip blank lines so the result does not depend on whether the
            # file ends with a newline (the last row used to be dropped
            # whenever the trailing newline was missing).
            if not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) != 3:
                raise ValueError(
                    '{}: line {} has {} fields, expected 3'.format(
                        filename, line_number, len(fields)))
            rows.append([float(value) for value in fields])

    if not rows:
        raise ValueError('{} contains no annotation rows'.format(filename))

    input_data = np.asarray(rows, dtype=float)

    # Duration in seconds of the recording (last value of the second column)
    T = input_data[-1, 1]

    # Uniform time axis at the requested sampling rate
    t = np.arange(0, T, 1.0 / fs)

    # For each time instant pick the first row whose end time is strictly
    # greater than it (argmax returns the first True along the row axis),
    # then read that row's heart state.
    segment_ends = input_data[:, 1]
    row_idx = np.argmax(segment_ends[np.newaxis, :] > t[:, np.newaxis], axis=1)
    s = input_data[row_idx, 2]

    return s

In [3]:
def prepare_dataset(x_enve, s_enve, N, tau):
    """Cut envelopes and labels into aligned, strided windows for the model.

    Only samples whose label is non-zero are used. The recording is split at
    every gap between labelled samples, each contiguous labelled run is
    windowed on its own, and runs shorter than `N` are discarded.

    Args:
    -----
        x_enve (np.ndarray): 2D array of envelopes, indexed as
        (channel, sample).
        s_enve (np.ndarray): 1D array of heart states, one per sample, where
        0 marks a sample without heart state information.
        N (int): Window length in samples.
        tau (int): Stride passed to `rolling_strided_window`.

    Returns:
    --------
        X (np.ndarray): Windows in channels_last layout (axes 1 and 2 of the
        stacked windows are swapped).
        S (np.ndarray): Result of `to_categorical(S - 1)` on the label windows
        in the same layout.

    Notes:
    ------
        `rolling_strided_window`, `check_valid_sequence` and `to_categorical`
        come from the project's star imports; this function assumes the first
        returns windows shaped (n_windows, n_channels, N) for 2D input and
        (n_windows, N) for 1D input -- TODO confirm against their definitions.
    """
    # Get envelopes
    x_global = x_enve
    s_global = s_enve

    # Number of envelope channels; taken from the input instead of being
    # hard-coded so the accumulator always matches the data.
    n_channels = x_global.shape[0]

    X = np.zeros((0, n_channels, N))
    S = np.zeros((0, N))

    # Extract the indexes of the s elements with heart state information
    labeled_idxs_global = np.where(s_global != 0)[0]

    # Find 0 intervals between heart state changes
    zero_intervals = np.diff(labeled_idxs_global) - 1 != 0

    # Split x and s between those intervals
    split_points = labeled_idxs_global[1:][zero_intervals]
    x_split = np.split(x_global, split_points, axis=1)
    s_split = np.split(s_global, split_points)

    for k in range(len(x_split)):

        # Extract the indexes of the s elements with heart state information
        labeled_idxs = np.where(s_split[k] != 0)[0]

        # Use only data with heart state information
        x = x_split[k][:, labeled_idxs]
        s = s_split[k][labeled_idxs]

        if x.shape[1] < N:
            # If the segment is shorter than one window, discard it
            continue

        x = rolling_strided_window(x, N, tau)
        s = rolling_strided_window(s, N, tau)

        x, s = check_valid_sequence(x, s, 1)

        # Stack the windows
        X = np.vstack((X, x))
        S = np.vstack((S, s))

    # Give S an explicit singleton axis so it has the same rank as X.
    # The previous implementation concatenated X and S and split them again
    # at `x.shape[1]`, where `x` was whatever the last loop iteration left
    # behind; when the last segment was discarded that was the raw segment's
    # sample count rather than the channel count, corrupting the split.
    # No shuffling is applied, so the concatenation is not needed at all.
    S = S[:, np.newaxis, :]

    # Swap axes to format the data as channels_last
    X = np.swapaxes(X, 1, 2)
    S = np.swapaxes(S, 1, 2)

    # Transform S to categorical (labels are shifted so they start at 0)
    S = to_categorical(S - 1)

    return X, S

In [4]:
def unroll_strided_windows(S: np.ndarray, tau: int) -> np.ndarray:
    """Unrolls the input `S` 3D array with shape (n_windows, N, C), which is
    supposed to be generated with stride `tau`, outputting a 2D array. The
    elements in the overlapping positions are averaged; NaN entries of `S`
    are ignored in that average.

    Args:
    -----
        S (np.ndarray): Input 3D array.
        tau (int): Stride of the input array. Python and numpy integers are
        accepted.

    Returns:
    --------
        s (np.ndarray): Array with shape (tau*(n_windows-1) + N, C), passed
        through np.squeeze (so singleton axes are removed). Positions not
        covered by any window (possible when tau > N) are NaN.

    Raises:
    -------
        TypeError: If `S` is not a np.ndarray or `tau` is not an integer.
        ValueError: If `S` is not 3D or `tau` is not positive.
    """

    # Check that the input array is a np.ndarray
    if not isinstance(S, np.ndarray):
        raise TypeError('Input array must be a np.ndarray')

    # Check that the input array is 3D
    if S.ndim != 3:
        raise ValueError('Input array must be 3D.')

    # Check that the stride is a positive integer
    if not isinstance(tau, (int, np.integer)):
        raise TypeError('Stride must be an integer.')
    if tau < 1:
        raise ValueError('Stride must be a positive integer.')

    # Obtain the number of windows, the window size and the channel count
    n_windows, N, n_channels = S.shape

    # Calculate the length of the output array
    s_len = tau*(n_windows-1) + N

    # Running sum and number of contributions per output position. This
    # needs O(s_len) memory instead of one full-length NaN row per window.
    total = np.zeros((s_len, n_channels))
    count = np.zeros((s_len, n_channels))

    # Accumulate each window at its position, skipping NaN entries
    for i in range(n_windows):
        window = S[i, :, :]
        valid = ~np.isnan(window)
        start = tau*i
        total[start:start+N, :] += np.where(valid, window, 0.0)
        count[start:start+N, :] += valid

    # Mean over the contributing windows; 0/0 (no contribution) yields NaN
    with np.errstate(invalid='ignore', divide='ignore'):
        s = total / count

    return np.squeeze(s)

In [5]:
def seq_max_temporal_model(x: np.ndarray) -> np.ndarray:
    """Implementation of the sequential max temporal modeling. It forces the input
    states sequence to contain only admissible transitions among heart states
    (S1->systolic->S2->diastolic->S1).

    Args:
    -----
        x (np.ndarray): Input 1D sequence of states. The elements must be
        integers between 1 and 4.

    Returns:
    --------
        y (np.ndarray): Output sequence of states (float array with the same
        shape as `x`), where only admissible transitions are present. An
        empty input yields an empty output.

    Raises:
    -------
        TypeError: If `x` is not a 1D numpy array.
    """

    # Check if x is a numpy array of 1D
    if not isinstance(x, np.ndarray) or x.ndim != 1:
        raise TypeError('x must be a numpy array of 1D.')

    # Create y as a zero-filled float array of the same size as x
    y = np.zeros(x.shape)

    # Nothing to constrain in an empty sequence (avoids indexing x[0])
    if x.shape[0] == 0:
        return y

    # The first element of y is the same as the first element of x
    y[0] = x[0]

    # Iterate over the rest of the elements of x
    for i in range(1, x.shape[0]):
        # The only admissible successor of state k is k % 4 + 1 (4 wraps to 1).
        # Accept x[i] when it is that successor of the previous OUTPUT state;
        # otherwise hold the previous output state.
        if y[i-1] % 4 + 1 == x[i]:
            y[i] = x[i]
        else:
            y[i] = y[i-1]

    return y

In [6]:
def _load_and_preprocess(wav_path, config=default_envelope_config):
    """Load one recording and compute its envelopes with the Renna pre-processing.

    The sampling rate reported by librosa for the file (sr=None keeps the
    file's own rate) is forwarded to `renna_preprocess_wave`, instead of
    assuming every file is sampled at 4000 Hz.

    Returns the tuple (signal, sample_rate, envelopes).
    """
    signal, sample_rate = librosa.load(wav_path, sr=None)
    # Original author's note: the pre-processing output includes the final standardization.
    envelopes = renna_preprocess_wave(input_signal=signal, fs=sample_rate, config_dict=config)
    return signal, sample_rate, envelopes


# Load recordings and apply the Renna pre-processing.
# `frequency` keeps the sampling rate of the last file loaded, as before.
recording_2530, frequency, pre_proc_data_2530_orig = _load_and_preprocess('./2530_AV.wav')
recording_14241, frequency, pre_proc_data_14241_orig = _load_and_preprocess('./14241_PV.wav')
recording_23625, frequency, pre_proc_data_23625_orig = _load_and_preprocess('./23625_MV.wav')
recording_24160, frequency, pre_proc_data_24160_orig = _load_and_preprocess('./24160_MV.wav')
recording_40840, frequency, pre_proc_data_40840_orig = _load_and_preprocess('./40840_TV.wav')

# Load CNN
model = get_model()
model.load_weights('parameters.h5')

In [7]:
# (envelope name, start offset in samples). Per the original author's note,
# the order is the same as the one in "preprocessing.py".
# The hilb and psd envelopes are read starting one sample later: a temporary,
# visually assessed correction for group delay that should eventually be
# handled in the logic that produces the files.
_ENVELOPE_LAYOUT = (('homo', 0), ('hilb', 1), ('psd', 1), ('swt', 0))


def load_quantized_envelopes(data_dir, file_tag, n_samples, var_suffix='env'):
    """Load the four envelopes of one recording from .mat files and standardize them.

    For each envelope name in `_ENVELOPE_LAYOUT` the file
    '<data_dir>/<name>_<var_suffix>_<file_tag>.mat' is read with
    hdf5storage and the variable '<name>_<var_suffix>' is extracted,
    squeezed and cut to `n_samples` samples starting at the envelope's offset.

    Args:
    -----
        data_dir (str): Directory holding the .mat files (no trailing slash).
        file_tag (str): Suffix of the file names, e.g. '2530' or '2530Q16'.
        n_samples (int): Number of samples to keep from each envelope.
        var_suffix (str): Second part of the file/variable names
        ('env' for most recordings, 'test' for the 2530 Q32 files).

    Returns:
    --------
        envelopes (np.ndarray): Array with one row per envelope, each row
        standardized to zero mean and unit standard deviation.
    """
    rows = []
    for name, offset in _ENVELOPE_LAYOUT:
        key = '{}_{}'.format(name, var_suffix)
        mat_path = '{}/{}_{}.mat'.format(data_dir, key, file_tag)
        signal = np.squeeze(hdf5storage.loadmat(mat_path)[key])
        rows.append(signal[offset:n_samples + offset])

    stacked = np.stack(rows, axis=0)
    # Per-envelope standardization
    return (stacked - stacked.mean(axis=1, keepdims=True)) / stacked.std(axis=1, keepdims=True)


# Each set is cut to the length of the corresponding reference pre-processing output.

#2530
envelopes_2530Q32 = load_quantized_envelopes('./2530_data', '2530', pre_proc_data_2530_orig.shape[1], var_suffix='test')
envelopes_2530Q16 = load_quantized_envelopes('./2530_data_Q16', '2530Q16', pre_proc_data_2530_orig.shape[1])

#14241
envelopes_14241Q32 = load_quantized_envelopes('./14241_data', '14241', pre_proc_data_14241_orig.shape[1])
envelopes_14241Q16 = load_quantized_envelopes('./14241_data_Q16', '14241Q16', pre_proc_data_14241_orig.shape[1])

#23625
envelopes_23625Q32 = load_quantized_envelopes('./23625_data', '23625Q32', pre_proc_data_23625_orig.shape[1])
envelopes_23625Q16 = load_quantized_envelopes('./23625_data_Q16', '23625Q16', pre_proc_data_23625_orig.shape[1])

#24160
envelopes_24160Q32 = load_quantized_envelopes('./24160_data', '24160', pre_proc_data_24160_orig.shape[1])
envelopes_24160Q16 = load_quantized_envelopes('./24160_data_Q16', '24160Q16', pre_proc_data_24160_orig.shape[1])

#40840
envelopes_40840Q32 = load_quantized_envelopes('./40840_data', '40840', pre_proc_data_40840_orig.shape[1])
envelopes_40840Q16 = load_quantized_envelopes('./40840_data_Q16', '40840Q16', pre_proc_data_40840_orig.shape[1])

In [8]:
# Ground-truth heart-state sequences, one per recording.
# Original author's note on 23625: "it's a plus" (meaning not documented).
s_2530, s_14241, s_23625, s_24160, s_40840 = (
    load_groundtruth(tsv_path)
    for tsv_path in (
        "2530_AV.tsv",
        "14241_PV.tsv",
        "23625_MV.tsv",
        "24160_MV.tsv",
        "40840_TV.tsv",
    )
)

In [9]:
# Window length (samples) and stride between consecutive windows.
N = 64
tau = 8

# prepare_dataset is adapted from Daniel's code "generate_X_S_2022()" in utils.preprocessing.py, see:
# https://github.com/eneriz-daniel/PCG-Segmentation-Model-Optimization/blob/master/training/utils/preprocessing.py
# For every recording the three envelope variants (Q32, Q16, reference
# pre-processing) are windowed against the same ground-truth sequence.
(X_2530Q32, S_2530Q32), (X_2530Q16, S_2530Q16), (X_2530orig, S_2530orig) = (
    prepare_dataset(envelopes, s_2530, N, tau)
    for envelopes in (envelopes_2530Q32, envelopes_2530Q16, pre_proc_data_2530_orig)
)

(X_14241Q32, S_14241Q32), (X_14241Q16, S_14241Q16), (X_14241orig, S_14241orig) = (
    prepare_dataset(envelopes, s_14241, N, tau)
    for envelopes in (envelopes_14241Q32, envelopes_14241Q16, pre_proc_data_14241_orig)
)

(X_23625Q32, S_23625Q32), (X_23625Q16, S_23625Q16), (X_23625orig, S_23625orig) = (
    prepare_dataset(envelopes, s_23625, N, tau)
    for envelopes in (envelopes_23625Q32, envelopes_23625Q16, pre_proc_data_23625_orig)
)

(X_24160Q32, S_24160Q32), (X_24160Q16, S_24160Q16), (X_24160orig, S_24160orig) = (
    prepare_dataset(envelopes, s_24160, N, tau)
    for envelopes in (envelopes_24160Q32, envelopes_24160Q16, pre_proc_data_24160_orig)
)

(X_40840Q32, S_40840Q32), (X_40840Q16, S_40840Q16), (X_40840orig, S_40840orig) = (
    prepare_dataset(envelopes, s_40840, N, tau)
    for envelopes in (envelopes_40840Q32, envelopes_40840Q16, pre_proc_data_40840_orig)
)

In [10]:
# Run the CNN on every windowed dataset (Q32, Q16 and reference, per recording).
pred_2530Q32, pred_2530Q16, pred_2530_orig = (
    model.predict(windows) for windows in (X_2530Q32, X_2530Q16, X_2530orig)
)

pred_14241Q32, pred_14241Q16, pred_14241_orig = (
    model.predict(windows) for windows in (X_14241Q32, X_14241Q16, X_14241orig)
)

pred_23625Q32, pred_23625Q16, pred_23625_orig = (
    model.predict(windows) for windows in (X_23625Q32, X_23625Q16, X_23625orig)
)

pred_24160Q32, pred_24160Q16, pred_24160_orig = (
    model.predict(windows) for windows in (X_24160Q32, X_24160Q16, X_24160orig)
)

pred_40840Q32, pred_40840Q16, pred_40840_orig = (
    model.predict(windows) for windows in (X_40840Q32, X_40840Q16, X_40840orig)
)


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 197ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step 
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4

In [11]:
def _states_from_windows(windows):
    """Unroll strided windows (stride `tau`) into one sequence and return,
    per position, the index of the largest column plus one."""
    unrolled = unroll_strided_windows(windows, tau)
    return np.squeeze(unrolled.argmax(axis=-1) + 1)


# Predicted and true state sequences for every recording / envelope variant.
(s_pred2530Q32, s_true2530Q32,
 s_pred2530Q16, s_true2530Q16,
 s_pred2530orig, s_true2530orig) = map(_states_from_windows, (
    pred_2530Q32, S_2530Q32,
    pred_2530Q16, S_2530Q16,
    pred_2530_orig, S_2530orig,
))

(s_pred14241Q32, s_true14241Q32,
 s_pred14241Q16, s_true14241Q16,
 s_pred14241orig, s_true14241orig) = map(_states_from_windows, (
    pred_14241Q32, S_14241Q32,
    pred_14241Q16, S_14241Q16,
    pred_14241_orig, S_14241orig,
))

(s_pred23625Q32, s_true23625Q32,
 s_pred23625Q16, s_true23625Q16,
 s_pred23625orig, s_true23625orig) = map(_states_from_windows, (
    pred_23625Q32, S_23625Q32,
    pred_23625Q16, S_23625Q16,
    pred_23625_orig, S_23625orig,
))

(s_pred24160Q32, s_true24160Q32,
 s_pred24160Q16, s_true24160Q16,
 s_pred24160orig, s_true24160orig) = map(_states_from_windows, (
    pred_24160Q32, S_24160Q32,
    pred_24160Q16, S_24160Q16,
    pred_24160_orig, S_24160orig,
))

(s_pred40840Q32, s_true40840Q32,
 s_pred40840Q16, s_true40840Q16,
 s_pred40840orig, s_true40840orig) = map(_states_from_windows, (
    pred_40840Q32, S_40840Q32,
    pred_40840Q16, S_40840Q16,
    pred_40840_orig, S_40840orig,
))

In [12]:
# Constrain every predicted sequence to admissible heart-state transitions.
s_pred2530Q32_stm, s_pred2530Q16_stm, s_pred2530orig_stm = map(
    seq_max_temporal_model, (s_pred2530Q32, s_pred2530Q16, s_pred2530orig))

s_pred14241Q32_stm, s_pred14241Q16_stm, s_pred14241orig_stm = map(
    seq_max_temporal_model, (s_pred14241Q32, s_pred14241Q16, s_pred14241orig))

s_pred23625Q32_stm, s_pred23625Q16_stm, s_pred23625orig_stm = map(
    seq_max_temporal_model, (s_pred23625Q32, s_pred23625Q16, s_pred23625orig))

s_pred24160Q32_stm, s_pred24160Q16_stm, s_pred24160orig_stm = map(
    seq_max_temporal_model, (s_pred24160Q32, s_pred24160Q16, s_pred24160orig))

s_pred40840Q32_stm, s_pred40840Q16_stm, s_pred40840orig_stm = map(
    seq_max_temporal_model, (s_pred40840Q32, s_pred40840Q16, s_pred40840orig))

In [13]:
# Per-sample accuracy of the temporally constrained predictions against the
# ground truth unrolled from the same windows.
acc_2530Q32, acc_2530Q16, acc_2530orig = (
    accuracy_score(truth, prediction, normalize=True)
    for truth, prediction in (
        (s_true2530Q32, s_pred2530Q32_stm),
        (s_true2530Q16, s_pred2530Q16_stm),
        (s_true2530orig, s_pred2530orig_stm),
    )
)

acc_14241Q32, acc_14241Q16, acc_14241orig = (
    accuracy_score(truth, prediction, normalize=True)
    for truth, prediction in (
        (s_true14241Q32, s_pred14241Q32_stm),
        (s_true14241Q16, s_pred14241Q16_stm),
        (s_true14241orig, s_pred14241orig_stm),
    )
)

acc_23625Q32, acc_23625Q16, acc_23625orig = (
    accuracy_score(truth, prediction, normalize=True)
    for truth, prediction in (
        (s_true23625Q32, s_pred23625Q32_stm),
        (s_true23625Q16, s_pred23625Q16_stm),
        (s_true23625orig, s_pred23625orig_stm),
    )
)

acc_24160Q32, acc_24160Q16, acc_24160orig = (
    accuracy_score(truth, prediction, normalize=True)
    for truth, prediction in (
        (s_true24160Q32, s_pred24160Q32_stm),
        (s_true24160Q16, s_pred24160Q16_stm),
        (s_true24160orig, s_pred24160orig_stm),
    )
)

acc_40840Q32, acc_40840Q16, acc_40840orig = (
    accuracy_score(truth, prediction, normalize=True)
    for truth, prediction in (
        (s_true40840Q32, s_pred40840Q32_stm),
        (s_true40840Q16, s_pred40840Q16_stm),
        (s_true40840orig, s_pred40840orig_stm),
    )
)

# One group of (label, accuracy) rows per recording; groups are separated by
# a blank line in the printed report.
# NOTE(review): in the recorded output, 24160 Q32/Q16 score 0.68/0.67 against
# 0.88 for the reference pre-processing -- a much larger gap than for the
# other recordings; worth investigating.
_accuracy_report = (
    (("2530 Orig ACC: ", acc_2530orig),
     ("2530 Q32 ACC: ", acc_2530Q32),
     ("2530 Q16 ACC: ", acc_2530Q16)),
    (("14241 Orig ACC: ", acc_14241orig),
     ("14241 Q32 ACC: ", acc_14241Q32),
     ("14241 Q16 ACC: ", acc_14241Q16)),
    (("23625 Orig ACC: ", acc_23625orig),
     ("23625 Q32 ACC: ", acc_23625Q32),
     ("23625 Q16 ACC: ", acc_23625Q16)),
    (("24160 Orig ACC: ", acc_24160orig),
     ("24160 Q32 ACC: ", acc_24160Q32),
     ("24160 Q16 ACC: ", acc_24160Q16)),
    (("40840 Orig ACC: ", acc_40840orig),
     ("40840 Q32 ACC: ", acc_40840Q32),
     ("40840 Q16 ACC: ", acc_40840Q16)),
)

for _group_index, _group in enumerate(_accuracy_report):
    if _group_index > 0:
        print()
    for _label, _accuracy in _group:
        print(_label, "{:.2f}".format(_accuracy))


2530 Orig ACC:  0.92
2530 Q32 ACC:  0.90
2530 Q16 ACC:  0.90

14241 Orig ACC:  0.91
14241 Q32 ACC:  0.90
14241 Q16 ACC:  0.90

23625 Orig ACC:  0.93
23625 Q32 ACC:  0.91
23625 Q16 ACC:  0.91

24160 Orig ACC:  0.88
24160 Q32 ACC:  0.68
24160 Q16 ACC:  0.67

40840 Orig ACC:  0.91
40840 Q32 ACC:  0.91
40840 Q16 ACC:  0.90
