In [36]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle

In [93]:
# Load the pickled object from a path relative to the notebook's working directory.
# NOTE(review): pickle.load can run arbitrary code while unpickling - only open files from a trusted source.
with open('../files/phys_data.pkl', 'rb') as f:
    whole_df = pickle.load(f)

In [94]:
# Display the loaded object; the cells below treat it as a pandas DataFrame.
whole_df

Unnamed: 0,subj,label,hr,hbo,eda,hrv,temp
0,1,0,"[81.0, 82.0, 83.0, 81.0, 80.0, 78.0, 77.0, 78....","[-0.8554872369111797, -0.57496029561218, -0.49...","[0.2815550556451113, 0.28822303060880383, 0.29...","[nan, nan, nan, nan, nan]","[30.51, 30.51, 30.53, 30.53, 30.53, 30.53, 30...."
1,1,0,"[70.0, 69.0, 69.0, 70.0, 71.0, 72.0, 73.0, 73....","[-3.9035742869604984, -3.6690829794927207, -3....","[0.44406848525288256, 0.44719377405633287, 0.4...","[nan, nan, nan, nan, nan]","[31.55, 31.55, 31.55, 31.51, 31.51, 31.51, 31...."
2,1,0,"[64.0, 64.0, 63.0, 63.0, 63.0, 63.0, 63.0, 64....","[1.3921675744672737, 1.1353524592800208, 0.991...","[0.3726680935510207, 0.3753903461879757, 0.375...","[nan, nan, nan, nan, nan]","[31.51, 31.51, 31.51, 31.49, 31.49, 31.49, 31...."
3,1,0,"[77.0, 78.0, 78.0, 78.0, 78.0, 76.0, 72.0, 69....","[7.245555498720603, 7.329383052879299, 7.18260...","[0.3778913025642409, 0.3804111139699416, 0.381...","[nan, nan, nan, nan, nan]","[31.51, 31.51, 31.51, 31.49, 31.49, 31.49, 31...."
4,1,0,"[76.0, 76.0, 76.0, 76.0, 78.0, 79.0, 79.0, 80....","[-0.6213875555732835, -0.5119793317549247, -0....","[0.38933916610036545, 0.3953971906270052, 0.39...","[nan, nan, nan, nan, nan]","[31.07, 31.07, 31.07, 31.07, 31.07, 31.07, 31...."
...,...,...,...,...,...,...,...
923,29,3,"[149.0, 149.0, 149.0, 149.0, 149.0, 149.0, 150...","[-23.10538728883509, -23.05403200438451, -23.1...","[6.123965544413717, 6.09872522717756, 6.066599...","[391.0, 392.0, 391.0, 389.0, 390.0, 388.0, 389...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
924,29,3,"[138.0, 138.0, 139.0, 140.0, 140.0, 140.0, 140...","[3.058798375795244, 2.655746357782251, 2.55250...","[20.37602942760572, 20.297138745295264, 20.182...","[425.0, 423.0, 421.0, 420.0, 421.0, 423.0, 420...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
925,29,3,"[138.0, 138.0, 138.0, 138.0, 139.0, 139.0, 139...","[-2.7496252732924376, -2.6766007643420946, -2....","[10.658320792222952, 10.677367210273614, 10.69...","[415.0, 416.0, 417.0, 413.0, 414.0, 416.0, 416...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
926,29,3,"[140.0, 140.0, 141.0, 141.0, 141.0, 142.0, 142...","[17.369855603108817, 17.267840840287466, 17.67...","[12.171594831168214, 12.151301662457728, 12.13...","[419.0, 416.0, 417.0, 418.0, 419.0, 417.0, 416...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."


In [95]:
# Signal channels of the first trial (row position 0), used below as a worked example.
input_signal = whole_df[['hr', 'hbo', 'eda', 'hrv', 'temp']].iloc[0]
input_signal

hr      [81.0, 82.0, 83.0, 81.0, 80.0, 78.0, 77.0, 78....
hbo     [-0.8554872369111797, -0.57496029561218, -0.49...
eda     [0.2815550556451113, 0.28822303060880383, 0.29...
hrv                             [nan, nan, nan, nan, nan]
temp    [30.51, 30.51, 30.53, 30.53, 30.53, 30.53, 30....
Name: 0, dtype: object

In [96]:
# Keep only the five signal columns; subj / label stay behind in whole_df.
signal_df = whole_df.loc[:, ['hr', 'hbo', 'eda', 'hrv', 'temp']]
signal_df.shape

(928, 5)

In [97]:
def find_long_nan_sequences(arr):
    """Return the longest interior run of NaN values as a fraction of len(arr).

    A run only counts once a valid (non-NaN) sample follows it, so a block of
    NaNs that reaches the end of the array is ignored on purpose.

    Returns 1.0 when every element is NaN (this also covers an empty input),
    otherwise longest_run / len(arr).
    """
    is_missing = np.isnan(arr)

    # special case: nothing valid at all -> report the channel as 100% missing
    if is_missing.all():
        return 1.0

    longest_run = 0
    current_run = 0

    for missing in is_missing:
        if not missing:
            # a valid sample closes the current run: record it and start over.
            # Runs are recorded only here, which is why a trailing run never counts.
            longest_run = max(current_run, longest_run)
            current_run = 0
            continue
        # still inside a run of consecutive NaNs
        current_run += 1

    return longest_run / len(arr)


In [98]:
# First channel of the example trial. input_signal is indexed by channel name, so a bare
# integer key is deprecated positional access (it triggers a FutureWarning); .iloc is explicit.
input_signal.iloc[0]

  input_signal[0]


array([81., 82., 83., 81., 80., 78., 77., 78., 78., 80., 80., 80., 80.,
       81., 81., 82., 82., 81., 79., 77., 75., 74., 73., 73., 72., 72.,
       72., 72., 72., 72., 71., 71., 70., 70., 70., 70., 70., 70., 69.,
       69., 68., 68., 68., 68., 68., 68., 68., 68., 67., 67., 67., 67.,
       68., 68., 68., 68., 68., 68., 68., 68., 68., 68., 68., 68., 68.,
       69., 69., 69., 69., 69., 69., 69., 69., 69., 69., 69., 69., 69.,
       69., 69., 70., 70., 70., 70., 69., 69., 68., 68., 68., 68., 68.,
       69., 69., 69., 69., 69., 68., 68., 68., 68., 67., 67., 67., 67.,
       67., 67., 67., 67., 67., 67., 67., 67., 67., 68., 68., 68., 68.,
       68., 67., 67., 67., 67., 67., 67., 67., 68., 67., 67., nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na

In [99]:
# Interior NaN runs only (trailing NaNs are ignored by the helper).
# hrv comes out as 1.0 because that channel holds no valid samples in this trial.
input_signal.map(find_long_nan_sequences)

hr      0.0
hbo     0.0
eda     0.0
hrv     1.0
temp    0.0
Name: 0, dtype: float64

In [9]:
# find max percent missing for each trial
# The NaN-run fraction does not depend on the threshold, so scan every array once
# instead of once per threshold (the scan is a Python loop over each sample).
interior_nan_fraction = signal_df.map(find_long_nan_sequences)

# Sweep thresholds 0.0, 0.1, ..., 1.0 and count, per channel, the trials above each one.
for fraction_threshold in range(11):
    above_threshold = interior_nan_fraction > fraction_threshold/10
    print(f'# of signal channels above nan threshold %{fraction_threshold/10}: ', [f'{c}:{above_threshold[c].sum()}' for c in ['hr', 'hbo', 'eda', 'hrv', 'temp']])

# of signal channels above nan threshold %0.0:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.1:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.2:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.3:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.4:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.5:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.6:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.7:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.8:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels above nan threshold %0.9:  ['hr:15', 'hbo:15', 'eda:7', 'hrv:151', 'temp:7']
# of signal channels

In [100]:
# The counts do not change between the 0% and 90% thresholds: every flagged channel (hr:15, hbo:15, eda:7, hrv:151, temp:7) is above 90% NaN values.

In [101]:
# Flag, per trial and channel, a NaN-run fraction above 0.9 (this includes channels with no valid samples).
# This cell only builds the boolean mask; the trials are removed further down.
all_nan = signal_df.map(find_long_nan_sequences).gt(0.9)

In [102]:
# A trial is flagged when at least one of its channels is flagged.
# 151 trials are affected (trailing NaN values are not counted by the helper).
rows_with_missing_channels = all_nan.any(axis=1)
rows_with_missing_channels, rows_with_missing_channels.sum()

(0       True
 1       True
 2       True
 3       True
 4       True
        ...  
 923    False
 924    False
 925    False
 926    False
 927    False
 Length: 928, dtype: bool,
 np.int64(151))

In [103]:
# Keep only the trials in which no channel was flagged as (almost) entirely missing.
rows_to_keep = ~rows_with_missing_channels
signal_df = signal_df[rows_to_keep]
signal_df

Unnamed: 0,hr,hbo,eda,hrv,temp
96,"[86.0, 86.0, 85.0, 85.0, 85.0, 85.0, 85.0, 85....","[-0.3046305720835594, -0.31955881048304857, -0...","[2.36582778551346, 2.3711798555679184, 2.37761...","[728.0, 723.0, 696.0, 695.0, 698.0, 713.0, 716...","[34.34, 34.34, 34.34, 34.34, 34.34, 34.34, 34...."
97,"[89.0, 90.0, 89.0, 83.0, 76.0, 75.0, 74.0, 75....","[8.052234975262033, 8.033719458490888, 8.13857...","[1.4318178800804344, 1.4347531463715082, 1.438...","[649.0, 776.0, 964.0, 958.0, 938.0, 882.0, 809...","[34.15, 34.15, 34.15, 34.15, 34.15, 34.15, 34...."
98,"[81.0, 80.0, 79.0, 79.0, 80.0, 80.0, 80.0, 80....","[11.730046881495785, 11.693972329257496, 11.85...","[2.93365803430117, 2.916177078764427, 2.905066...","[820.0, 787.0, 743.0, 725.0, 729.0, 744.0, 775...","[34.41, 34.43, 34.43, 34.43, 34.43, 34.41, 34...."
99,"[90.0, 90.0, 89.0, 89.0, 88.0, 88.0, 88.0, 89....","[-3.7118986439529715, -3.6020856713586022, -3....","[1.360926244516401, 1.368126748476356, 1.37547...","[677.0, 697.0, 679.0, 691.0, 696.0, 691.0, 691...","[33.41, 33.39, 33.39, 33.39, 33.39, 33.41, 33...."
100,"[85.0, 85.0, 85.0, 85.0, 83.0, 82.0, 80.0, 80....","[2.9662142523880535, 2.6790208296620466, 2.716...","[1.2437807082337249, 1.24384202999343, 1.24382...","[692.0, 709.0, 707.0, 735.0, 761.0, 844.0, 868...","[34.37, 34.37, 34.37, 34.37, 34.37, 34.37, 34...."
...,...,...,...,...,...
923,"[149.0, 149.0, 149.0, 149.0, 149.0, 149.0, 150...","[-23.10538728883509, -23.05403200438451, -23.1...","[6.123965544413717, 6.09872522717756, 6.066599...","[391.0, 392.0, 391.0, 389.0, 390.0, 388.0, 389...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
924,"[138.0, 138.0, 139.0, 140.0, 140.0, 140.0, 140...","[3.058798375795244, 2.655746357782251, 2.55250...","[20.37602942760572, 20.297138745295264, 20.182...","[425.0, 423.0, 421.0, 420.0, 421.0, 423.0, 420...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
925,"[138.0, 138.0, 138.0, 138.0, 139.0, 139.0, 139...","[-2.7496252732924376, -2.6766007643420946, -2....","[10.658320792222952, 10.677367210273614, 10.69...","[415.0, 416.0, 417.0, 413.0, 414.0, 416.0, 416...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
926,"[140.0, 140.0, 141.0, 141.0, 141.0, 142.0, 142...","[17.369855603108817, 17.267840840287466, 17.67...","[12.171594831168214, 12.151301662457728, 12.13...","[419.0, 416.0, 417.0, 418.0, 419.0, 417.0, 416...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."


In [104]:
# Arrays either have NaN values only at the tail or are missing their data completely.
# Question: does every channel of a trial lose the same fraction of samples to the NaN tail?
# Checked here on the example trial: print the NaN share of each channel.
arr = []
for c in ['hr', 'hbo', 'eda', 'hrv', 'temp']:
    s = input_signal[c]
    r = np.isnan(s).sum() / len(s)
    print(r)
    arr.append(r)

# Channels with data share roughly the same tail fraction, so the variance of the
# nan_count / array_length ratio across channels should be close to 0.
# Variance is the mean squared deviation from the mean, which makes it very sensitive to outliers:
# a value above some threshold suggests one channel is missing far more samples than the others.
print('variance of nan to signal length ratio across channels', np.var(arr))

0.3155080213903743
0.31852423355943466
0.3181818181818182
1.0
0.3181818181818182
variance of nan to signal length ratio across channels 0.07450856742213492


In [105]:
# Per trial: NaN share of every channel, then the variance (np.var) of those shares across channels.
nan_share = signal_df.map(lambda x: np.sum(np.isnan(x)) / len(x))
nan_variance = nan_share.apply(np.var, axis=1, raw=True)
# With the trials with missing channels already removed, a high variance now points to
# channels with an uneven amount of trailing nan values.
print(nan_variance,'median accross all rows', nan_variance.median())

96     0.001232
97     0.001839
98     0.002100
99     0.000942
100    0.001849
         ...   
923    0.017855
924    0.011198
925    0.007222
926    0.011874
927    0.013168
Length: 777, dtype: float64 median accross all rows 0.0030050632295451035


In [106]:
# Robust outlier score: the median absolute deviation (MAD) measures the spread around the median.
# Dividing the absolute deviation by 1.4826 * MAD gives a z-score-like value; by convention
# scores above 3 are flagged as potential outliers.
nan_var_med = nan_variance.median()
abs_deviation = (nan_variance - nan_var_med).abs()
mad = np.median(abs_deviation)
mad_z_scores = abs_deviation / (1.4826 * mad)
mad_z_scores

96     0.459403
97     0.302170
98     0.234548
99     0.534518
100    0.299469
         ...   
923    3.846486
924    2.122214
925    1.092324
926    2.297422
927    2.632424
Length: 777, dtype: float64

In [107]:
# Boolean mask of potential outliers: trials whose MAD z-score exceeds 3.
outliers = mad_z_scores.gt(3)
outliers, outliers.sum()

(96     False
 97     False
 98     False
 99     False
 100    False
        ...  
 923     True
 924    False
 925    False
 926    False
 927    False
 Length: 777, dtype: bool,
 np.int64(47))

In [108]:
# Show the trials flagged as outliers (MAD z-score > 3).
# In the shorter channels the irregular NaN pattern is visible directly in the preview.
signal_df[outliers]

Unnamed: 0,hr,hbo,eda,hrv,temp
146,"[102.0, 103.0, 103.0, 103.0, 104.0, 104.0, 105...","[-8.823993915085032, -8.840882687392488, -8.77...","[0.1041230192439944, 0.10358191494526042, 0.10...","[589.0, 581.0, 577.0, 584.0, 584.0, 577.0, 573...","[30.49, 30.49, 30.49, 30.49, 30.49, 30.47, 30...."
147,"[100.0, 101.0, 102.0, 103.0, 104.0, 104.0, 105...","[-7.828278666123876, -7.938407956434872, -7.95...","[0.13142537833554185, 0.13358221389344, 0.1363...","[586.0, 582.0, 587.0, 584.0, 576.0, 564.0, 561...","[30.33, 30.33, 30.33, 30.31, 30.31, 30.31, 30...."
150,"[103.0, 102.0, 102.0, 102.0, 103.0, 104.0, 105...","[-7.787689556400292, -7.747496611991344, -7.65...","[0.21886583097416304, 0.22113014904364703, 0.2...","[589.0, 584.0, 584.0, 593.0, 587.0, 583.0, 587...","[32.63, 32.63, 32.63, 32.66, 32.66, 32.66, 32...."
151,"[107.0, 106.0, 105.0, 105.0, 104.0, 104.0, 105...","[-2.110626571324452, -1.9622803175438148, -1.9...","[0.22368499399303163, 0.22392059952748186, 0.2...","[563.0, 560.0, 577.0, 578.0, 576.0, 570.0, 591...","[33.33, 33.33, 33.33, 33.33, 33.33, 33.33, 33...."
292,"[75.0, 76.0, 76.0, 77.0, 79.0, 80.0, 80.0, 81....","[-4.649254304993229, -4.61831301089607, -4.615...","[2.394393902079656, 2.395028785246666, 2.39595...","[814.0, 766.0, 805.0, 727.0, 713.0, 688.0, 671...","[33.0, 33.0, 33.0, 33.0, 33.0, 33.0, 33.0, 32...."
293,"[78.0, 79.0, 80.0, 81.0, 80.0, 79.0, 78.0, 78....","[3.2149497072749393, 3.276974031340498, 3.2076...","[2.5772797625690678, 2.575284154546029, 2.5730...","[718.0, 678.0, 695.0, 721.0, 836.0, 832.0, 846...","[32.41, 32.41, 32.41, 32.41, 32.41, 32.41, 32...."
320,"[96.0, nan, nan, nan, nan, nan, nan, nan, nan,...","[-1.261301074625934, -1.223298045035024, -1.21...","[4.589267828668605, 4.577069339759905, 4.56793...","[655.0, nan, nan, nan, nan, nan, nan, nan, nan...","[35.71, 35.71, 35.71, 35.68, 35.68, 35.68, 35...."
321,"[96.0, nan, nan, nan, nan, nan, nan, nan, nan,...","[-4.83593773867081, -4.830249535933329, -4.790...","[3.952255300848163, 3.9830227102247435, 4.0130...","[655.0, nan, nan, nan, nan, nan, nan, nan, nan...","[35.47, 35.47, 35.5, 35.5, 35.5, 35.5, 35.5, 3..."
322,"[96.0, nan, nan, nan, nan, nan, nan, nan, nan,...","[1.4378971712766797, 1.4225714894052117, 1.441...","[3.9115963478904026, 3.9186041870467747, 3.929...","[655.0, nan, nan, nan, nan, nan, nan, nan, nan...","[35.5, 35.5, 35.5, 35.55, 35.55, 35.55, 35.55,..."
323,"[96.0, nan, nan, nan, nan, nan, nan, nan, nan,...","[1.251462946587023, 1.2434395620048506, 1.2038...","[4.455995135060351, 4.454483908740345, 4.45388...","[655.0, nan, nan, nan, nan, nan, nan, nan, nan...","[35.53, 35.53, 35.53, 35.49, 35.49, 35.49, 35...."


In [109]:
# Drop every trial whose MAD z-score is above 3.
not_outlier = ~outliers
signal_df = signal_df[not_outlier]

In [110]:
# Inspect the trials that remain after the missing-channel and outlier filters.
signal_df

Unnamed: 0,hr,hbo,eda,hrv,temp
96,"[86.0, 86.0, 85.0, 85.0, 85.0, 85.0, 85.0, 85....","[-0.3046305720835594, -0.31955881048304857, -0...","[2.36582778551346, 2.3711798555679184, 2.37761...","[728.0, 723.0, 696.0, 695.0, 698.0, 713.0, 716...","[34.34, 34.34, 34.34, 34.34, 34.34, 34.34, 34...."
97,"[89.0, 90.0, 89.0, 83.0, 76.0, 75.0, 74.0, 75....","[8.052234975262033, 8.033719458490888, 8.13857...","[1.4318178800804344, 1.4347531463715082, 1.438...","[649.0, 776.0, 964.0, 958.0, 938.0, 882.0, 809...","[34.15, 34.15, 34.15, 34.15, 34.15, 34.15, 34...."
98,"[81.0, 80.0, 79.0, 79.0, 80.0, 80.0, 80.0, 80....","[11.730046881495785, 11.693972329257496, 11.85...","[2.93365803430117, 2.916177078764427, 2.905066...","[820.0, 787.0, 743.0, 725.0, 729.0, 744.0, 775...","[34.41, 34.43, 34.43, 34.43, 34.43, 34.41, 34...."
99,"[90.0, 90.0, 89.0, 89.0, 88.0, 88.0, 88.0, 89....","[-3.7118986439529715, -3.6020856713586022, -3....","[1.360926244516401, 1.368126748476356, 1.37547...","[677.0, 697.0, 679.0, 691.0, 696.0, 691.0, 691...","[33.41, 33.39, 33.39, 33.39, 33.39, 33.41, 33...."
100,"[85.0, 85.0, 85.0, 85.0, 83.0, 82.0, 80.0, 80....","[2.9662142523880535, 2.6790208296620466, 2.716...","[1.2437807082337249, 1.24384202999343, 1.24382...","[692.0, 709.0, 707.0, 735.0, 761.0, 844.0, 868...","[34.37, 34.37, 34.37, 34.37, 34.37, 34.37, 34...."
...,...,...,...,...,...
921,"[146.0, 146.0, 146.0, 146.0, 147.0, 147.0, 148...","[-41.49453834236974, -41.42930016955021, -41.4...","[1.7414557833087962, 1.727774149878133, 1.7175...","[409.0, 408.0, 406.0, 407.0, 408.0, 405.0, 406...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
924,"[138.0, 138.0, 139.0, 140.0, 140.0, 140.0, 140...","[3.058798375795244, 2.655746357782251, 2.55250...","[20.37602942760572, 20.297138745295264, 20.182...","[425.0, 423.0, 421.0, 420.0, 421.0, 423.0, 420...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
925,"[138.0, 138.0, 138.0, 138.0, 139.0, 139.0, 139...","[-2.7496252732924376, -2.6766007643420946, -2....","[10.658320792222952, 10.677367210273614, 10.69...","[415.0, 416.0, 417.0, 413.0, 414.0, 416.0, 416...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."
926,"[140.0, 140.0, 141.0, 141.0, 141.0, 142.0, 142...","[17.369855603108817, 17.267840840287466, 17.67...","[12.171594831168214, 12.151301662457728, 12.13...","[419.0, 416.0, 417.0, 418.0, 419.0, 417.0, 416...","[382.21, 382.21, 382.21, 382.21, 382.21, 382.2..."


In [111]:
# Per trial: the largest NaN share found in any of its channels.
nan_share_per_channel = signal_df.map(lambda x: np.sum(np.isnan(x)) / len(x))
nan_to_length_ratio = nan_share_per_channel.max(axis=1)
nan_to_length_ratio

96     0.482143
97     0.500000
98     0.500000
99     0.467262
100    0.502976
         ...   
921    0.350785
924    0.350785
925    0.331080
926    0.334054
927    0.335889
Length: 730, dtype: float64

In [112]:
# Copy each trial's ratio into 5 columns (one per channel) so it lines up with the (trials, 5) length table.
percent_nan = np.repeat(nan_to_length_ratio.to_numpy()[:, np.newaxis], 5, axis=1)
percent_nan

array([[0.48214286, 0.48214286, 0.48214286, 0.48214286, 0.48214286],
       [0.5       , 0.5       , 0.5       , 0.5       , 0.5       ],
       [0.5       , 0.5       , 0.5       , 0.5       , 0.5       ],
       ...,
       [0.33107958, 0.33107958, 0.33107958, 0.33107958, 0.33107958],
       [0.33405411, 0.33405411, 0.33405411, 0.33405411, 0.33405411],
       [0.33588946, 0.33588946, 0.33588946, 0.33588946, 0.33588946]])

In [113]:
# Shape check: one row per remaining trial, one column per channel.
percent_nan.shape

(730, 5)

In [114]:
# Number of samples stored in every channel array of every trial.
arr_length = signal_df.map(len)
arr_length

Unnamed: 0,hr,hbo,eda,hrv,temp
96,203,101041,809,336,809
97,203,101041,809,336,809
98,203,101041,809,336,809
99,203,101041,809,336,809
100,203,101041,809,336,809
...,...,...,...,...,...
921,191,94805,759,332,759
924,191,94805,759,332,759
925,191,94805,759,332,759
926,191,94805,759,332,759


In [115]:
# Shape check: must match percent_nan for the element-wise product in the next step.
arr_length.values.shape

(730, 5)

In [116]:
# NaN share times array length = (fractional) number of trailing samples to cut from each channel.
indexes = np.multiply(percent_nan, arr_length.to_numpy())
indexes, indexes.shape

(array([[   97.875     , 48716.19642857,   390.05357143,   162.        ,
           390.05357143],
        [  101.5       , 50520.5       ,   404.5       ,   168.        ,
           404.5       ],
        [  101.5       , 50520.5       ,   404.5       ,   168.        ,
           404.5       ],
        ...,
        [   63.23620062, 31388.        ,   251.28940457,   109.91842202,
           251.28940457],
        [   63.80433521, 31670.        ,   253.5470703 ,   110.90596488,
           253.5470703 ],
        [   64.15488635, 31844.        ,   254.9400981 ,   111.51529983,
           254.9400981 ]]),
 (730, 5))

In [118]:
# Object array view of the frame: each cell holds one channel's sample array.
signal_np = signal_df.to_numpy()
signal_np.shape

(730, 5)

In [None]:
# Cut the trailing samples off every channel of every trial.
# indexes[i, j] is the (fractional) number of samples to drop from the end of signal_np[i, j].
# Sizes are read from the data instead of being hard-coded, so the cell keeps working
# when the filtering steps keep a different number of trials.
n_trials, n_channels = signal_np.shape
signal_np_truncated = np.empty((n_trials, n_channels), dtype=object)

for i in range(n_trials):
    for j in range(n_channels):
        idx = int(indexes[i, j])  # truncate the fractional count toward zero
        seq = signal_np[i, j]
        if idx < len(seq):
            # Slice with an explicit end position: seq[:-idx] is seq[:0] when idx == 0,
            # which would silently return an EMPTY array for sequences with nothing to cut.
            signal_np_truncated[i, j] = seq[:len(seq) - idx]
        else:
            # cutting idx samples would leave nothing of this sequence
            raise ValueError(f'There are {idx} nan values in a sequence of lenght {len(seq)}')



In [126]:
# Shape check: the truncated container keeps the (trials, channels) layout.
signal_np_truncated.shape

(730, 5)

In [127]:
# Length of each channel array in the example trial; anything that is not an ndarray is passed through unchanged.
input_signal.map(lambda x: x if type(x) is not np.ndarray else len(x))

hr        187
hbo     93321
eda       748
hrv         5
temp      748
Name: 0, dtype: int64

In [128]:
# Shortest array length found in each signal channel across all trials.
whole_df[['hr', 'hbo', 'eda', 'hrv', 'temp']].map(len).min()

hr        185
hbo     91811
eda       735
hrv         5
temp      735
dtype: int64

In [129]:
# Shortest eda array over all trials.
min_timeframe = whole_df['eda'].map(len).min()
min_timeframe

np.int64(735)

In [14]:
def stretch_arr(t, target_len):
    """Resample a 1-D sequence to target_len samples by linear interpolation.

    The samples of t are placed at evenly spaced positions spanning
    [0, target_len - 1] and the result is read off at the integer positions
    0 .. target_len - 1, so the first and last samples are preserved.

    NaN values in t are not handled: they propagate into the interpolated output.

    Returns a 1-D torch.float32 tensor of length target_len.
    """
    # where the original samples sit on the target axis
    source_positions = np.linspace(0, target_len - 1, num=len(t))
    # where the output is sampled
    target_positions = np.arange(target_len)

    resampled = np.interp(target_positions, source_positions, t)
    return torch.tensor(resampled, dtype=torch.float32)

In [22]:
# Toy arrays of different lengths, some with NaN gaps, to try out stretch_arr.
t1 = np.array([1, 2, 3, np.nan, 5, np.nan, 7, 8, 9, 10, np.nan, 8, 7, 6, 5, 4, 3, 2, 1])
t2 = np.array([5, 6])
t3 = np.array([7, 8, np.nan, 10, 11])

tensors = [t1, t2, t3]

# common length: the shorter arrays are stretched to the length of the longest one (t1)
resample = t1.shape[0]

# after resampling every array has the same length, so they can be stacked into one tensor
t_stretched = torch.stack(tuple(stretch_arr(signal, resample) for signal in tensors))

t_stretched

tensor([[ 1.0000,  2.0000,  3.0000,     nan,  5.0000,     nan,  7.0000,  8.0000,
          9.0000, 10.0000,     nan,  8.0000,  7.0000,  6.0000,  5.0000,  4.0000,
          3.0000,  2.0000,  1.0000],
        [ 5.0000,  5.0556,  5.1111,  5.1667,  5.2222,  5.2778,  5.3333,  5.3889,
          5.4444,  5.5000,  5.5556,  5.6111,  5.6667,  5.7222,  5.7778,  5.8333,
          5.8889,  5.9444,  6.0000],
        [ 7.0000,  7.2222,  7.4444,  7.6667,  7.8889,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan, 10.1111, 10.3333,
         10.5556, 10.7778, 11.0000]])

In [19]:
def fill_nan_running_mean(arr):
    """Return a copy of a 1-D array with each NaN replaced by the running mean.

    The running mean at position k is the mean of all valid (non-NaN) samples
    at positions <= k.

    Edge cases:
    - all values NaN (or empty input): an array of zeros of the same shape is returned.
    - NaNs before the first valid sample have no running mean yet; they are
      filled with the first valid value. (Previously these positions received
      uninitialized memory, because np.divide was called with `where=` but no `out=`.)

    The input is never modified; callers must use the returned array.
    """
    filled = np.array(arr)  # work on a copy so the caller's data stays intact
    mask = np.isnan(filled)
    if np.all(mask):
        return np.zeros_like(filled)

    # cumulative sum of the valid samples (NaN counts as 0)
    cumsum = np.nancumsum(filled)
    # cumulative count of valid samples
    count = np.cumsum(~mask)

    # Pre-allocate the output: with `where=` numpy only writes the selected
    # positions and leaves every other position of `out` untouched.
    running_mean = np.full(cumsum.shape, np.nan, dtype=float)
    np.divide(cumsum, count, out=running_mean, where=(count != 0))

    # positions before the first valid sample: fall back to that first valid value
    running_mean[count == 0] = filled[~mask][0]

    filled[mask] = running_mean[mask]
    return filled

In [23]:
# Inspect t1, the toy array defined for the stretch_arr demo (NaN gaps included).
t1

array([ 1.,  2.,  3., nan,  5., nan,  7.,  8.,  9., 10., nan,  8.,  7.,
        6.,  5.,  4.,  3.,  2.,  1.])

In [133]:
# Split the full (unfiltered) frame into model inputs and targets.
# features: the five signal channels; labels: subject id and class label.
features = whole_df[['hr', 'hbo', 'eda', 'hrv', 'temp']].to_numpy()
labels = whole_df[['subj', 'label']].to_numpy()

features.shape, labels.shape

((928, 5), (928, 2))

In [134]:
# Count the distinct values in the subj and label columns.
whole_df[['subj', 'label']].nunique()

subj     29
label     4
dtype: int64

In [135]:
# Map every channel name to its column position (same order as the signal columns).
c = {name: position for position, name in enumerate(['hr', 'hbo', 'eda', 'hrv', 'temp'])}
list(c)

['hr', 'hbo', 'eda', 'hrv', 'temp']

In [1]:
# Check which Python type is stored in a single hr cell (row label 0).
# NOTE(review): the recorded output is a NameError - this cell was run on a fresh kernel
# before whole_df was loaded; re-run it after the load cell.
type(whole_df.loc[0, 'hr'])

NameError: name 'whole_df' is not defined