In [15]:
from Error_Compare import *
import numpy as no
from dtw import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [209]:
activity_dir = "data/dataset/effad897-2991-4c2c-9a6a-f85a111d0e3d"
min_time = 655
live_min_time = 660
max_time = 1800
# Settings
error_type = "Percent"
lower_threshold = 0.25
upper_threshold = np.inf
# error_type = "Absolute" #
# lower_threshold = 20
# upper_threshold = np.inf
live_shift = "cross_correlation"
#live_shift = 30

activity = os.path.split(activity_dir)[-1]
clean_dfs = load_data(activity_dir, has_uncleaned = False)

key_metrics = ["breathTime", "VT"]#, "VT", "RRAvg", "instBR"]


# Load Data
raw_chest = clean_dfs["raw_slow_df"][["time","c",]]

live_b3_df = clean_dfs["live_b3_df"][key_metrics]
pp_b3_df = clean_dfs["aws_b3_df"][key_metrics]


# Only look at a subset of the data
raw_chest = raw_chest[(raw_chest["time"] > min_time) & (raw_chest["time"] < max_time)]
live_b3_df = live_b3_df[(live_b3_df["breathTime"] > live_min_time) & (live_b3_df["breathTime"] < max_time)]
pp_b3_df = pp_b3_df[(pp_b3_df["breathTime"] > min_time) & (pp_b3_df["breathTime"] < max_time)]

# Clean Data
pp_df, pp_cleaning_demo = run_cleaning_process(pp_b3_df[key_metrics], demo = True)
live_df = run_cleaning_process(live_b3_df[key_metrics], demo = False)



No pickled data found in activity data directory. Cleaning data...
cleaned_data folder already exists at data/dataset\cleaned_data


In [106]:
# Visualize Error
fig = plot_metrics_compare(raw_chest, live_df, pp_df, show = False)
fig.show()

In [55]:
# Get max index of each time-series
max_live_index = live_df.index.max()
max_pp_index = pp_df.index.max()
print("Max live index: ", max_live_index)
print("Max pp index: ", max_pp_index)
# Get min index of each time-series
min_live_index = live_df.index.min()
min_pp_index = pp_df.index.min()
print("Min live index: ", min_live_index)
print("Min pp index: ", min_pp_index)

Max live index:  197
Max pp index:  196
Min live index:  106
Min pp index:  100


In [56]:
reference = pp_b3_df
query = live_b3_df


In [58]:
# Visualize Reference and query
fig = go.Figure()
fig.add_trace(go.Scatter(x = pp_b3_df["breathTime"], y = pp_b3_df["VT"], name = "Reference", mode = "markers", marker=dict(color="blue")))
fig.add_trace(go.Scatter(x = live_b3_df["breathTime"], y = live_b3_df["VT"], name = "Query", mode = "markers",marker=dict(color="red")))
fig.show()

In [166]:
def distance(r: tuple,q: tuple):
    return abs(r[1]-q[1])

def shift_error(r: tuple,q: tuple, avg_shift = 0):
    return abs(r[0] - q[0] - avg_shift)**2

def dtw_distance(qry, ref, window, avg_shift = False):
    '''
    Rules:
        1. each element of query can only be matched to a preceding element of reference
        2. each element is only matched to one query
        3. each element of query can only be matched to an element of reference that is within the window
        4. if there does not exist an element of reference within the window, then the element of query is matched to the first element outside of the window
    :param ref: (x,y) pairs of the reference signal
    :param qry: (x,y) pairs of the query signal
    :param window: the maximum x distance between the reference and query signal
    :return:
    '''
    # Initialize
    n = len(qry)
    m = len(ref)
    DTW = np.zeros((n,m))
    DTW[0,0] = 0
    i_to_j = np.zeros(n, dtype=int)# a map from each element in the query to an element in the reference
    shift_avg = window/2
    shift_record = []
    # Loop through each element of the query
    for i,q in zip(range(n),qry):
        # i is the index of the query
        # q is the element of the query
        # Loop through each element of the reference
        min_r0 = q[0] - window # Minimum x value of reference that can be matched to query
        max_r0 = q[0]
        min_cost = np.inf
        # get all elements of reference that can be matched to query
        ref_window = ref[(ref[:,0] >= min_r0) & (ref[:,0] <= max_r0)]
        # get index of each element in ref_window
        ref_window_index = np.where((ref[:,0] >= min_r0) & (ref[:,0] <= max_r0))[0]
        #print(f"q: {q}, i, :{i},  ref_window: {ref_window}, valid_j's: {ref_window_index}")
        for j,r in zip(ref_window_index,ref_window):
            if avg_shift:
                cost = distance(r,q) + shift_error(r,q,shift_avg)
            else:
                cost = distance(r,q)
            if cost <= min_cost:
                min_cost = cost
                i_to_j[i] = j

            DTW[i,j] = cost + min(DTW[i-1,j-1],DTW[i-1,j],DTW[i,j-1])
        if avg_shift:
            # get exponential average of the shift
            if i == 0:
                shift_avg = shift_avg
            else:
                shift_avg = 0.6*shift_avg + 0.4*(q[0] - ref[i_to_j[i]][0])
            shift_record.append((q[0] - ref[i_to_j[i]][0],shift_avg))
    if avg_shift:
        return DTW, i_to_j, shift_record
    else:
        return DTW, i_to_j

def get_warp_path(qry, ref, DTW):
    '''
    Warp the query signal to the reference signal by finding the minimum cost path through the DTW matrix

    :param qry:
    :param ref:
    :param DTW:
    :return:
    '''
    # Initialize
    n = len(qry)
    m = len(ref)
    path = []
    i = n
    j = m
    while i >= 0 and j >= 0:
        path.append((i,j))
        if i == 0:
            j = j-1
        elif j == 0:
            i = i-1
        else:
            if DTW[i-1,j] == min(DTW[i-1,j-1],DTW[i-1,j],DTW[i,j-1]):
                i = i-1
            elif DTW[i,j-1] == min(DTW[i-1,j-1],DTW[i-1,j],DTW[i,j-1]):
                j = j-1
            else:
                i = i-1
                j = j-1
    path.append((0,0))
    return path

def remap_query(qry, ref, i_to_j):
    '''
    Remap the query signal to the reference signal by using the i_to_j map
    :param qry:
    :param ref:
    :param i_to_j:
    :return:
    '''
    # Initialize
    n = len(qry)
    m = len(ref)
    remapped_qry = np.zeros((n,2))
    for i in range(n):
        remapped_qry[i,0] = ref[int(i_to_j[i]),0]
        remapped_qry[i,1] = qry[i,1]

    return remapped_qry


In [167]:


qry = np.array([[1,2],[2,4],[4,6],[6,2],[7,3]])
ref = np.array([[1,1],[2,3],[3,4],[5,3],[6,3]])
window = 2
DTW, i_to_j = dtw_distance(qry, ref, window)
#path = get_warp_path(qry, ref, DTW)
warped = remap_query(qry, ref, i_to_j)

In [91]:
print(DTW)

[[1. 0. 0. 0. 0.]
 [3. 1. 0. 0. 0.]
 [0. 3. 2. 0. 0.]
 [0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 0.]]


In [98]:
print(i_to_j)
print(warped)

[0. 1. 2. 4. 4.]
[[1. 1.]
 [2. 3.]
 [3. 4.]
 [6. 3.]
 [6. 3.]]


In [104]:
# Visualize the reference,query, and warped signals
fig = go.Figure()
fig.add_trace(go.Scatter(x = qry[:,0], y = qry[:,1], name = "Query", marker=dict(color="blue")))
fig.add_trace(go.Scatter(x = ref[:,0], y = ref[:,1], name = "Reference",marker=dict(color="red")))
fig.add_trace(go.Scatter(x = warped[:,0], y = warped[:,1], name = "Warped",marker=dict(color="green")))
fig.show()

In [210]:
# Now try DTW remapping on the live and pp signals
qry = live_b3_df[["breathTime","VT"]].to_numpy()
ref = pp_b3_df[["breathTime","VT"]].to_numpy()
# visualize the signals
fig = go.Figure()
fig.add_trace(go.Scatter(x = qry[:,0], y = qry[:,1], name = "Query", marker=dict(color="blue")))
fig.add_trace(go.Scatter(x = ref[:,0], y = ref[:,1], name = "Reference",marker=dict(color="red")))
fig.show()

In [211]:
window = 8
DTW, i_to_j, _ = dtw_distance(qry, ref, window, avg_shift = True)
#path = get_warp_path(qry, ref, DTW)
warped = remap_query(qry, ref, i_to_j)

In [212]:
def fix_i_to_j(i_to_j_og):
    # in reverse order of i_to_j, if
    i_to_j = deepcopy(i_to_j_og)
    for i in range(len(i_to_j)-1,1,-1):
        if i_to_j[i] == i_to_j[i-1]:
            i_to_j[i-1] -= 1
    return i_to_j

In [213]:
i_to_j_fixed = fix_i_to_j(i_to_j)
rewarped = remap_query(qry, ref, i_to_j_fixed)


In [179]:
#print i_to_j
for n in range(len(i_to_j)):
    print(f"i: {n}, j: {i_to_j[n]}| qry[i]: {qry[n]}, ref[j]: {ref[i_to_j[n]]}" )


i: 0, j: 0| qry[i]: [ 59 184], ref[j]: [ 56.6 121.6]
i: 1, j: 0| qry[i]: [ 62 136], ref[j]: [ 56.6 121.6]
i: 2, j: 2| qry[i]: [ 66 136], ref[j]: [ 63.92 136.7 ]
i: 3, j: 2| qry[i]: [ 70 147], ref[j]: [ 63.92 136.7 ]
i: 4, j: 3| qry[i]: [ 73 146], ref[j]: [ 67.48 135.1 ]
i: 5, j: 4| qry[i]: [77 58], ref[j]: [71.16 68.6 ]
i: 6, j: 5| qry[i]: [ 80 168], ref[j]: [ 74.28 161.7 ]
i: 7, j: 6| qry[i]: [ 83 170], ref[j]: [ 77.84 158.1 ]
i: 8, j: 7| qry[i]: [ 86 154], ref[j]: [ 80.84 145.9 ]
i: 9, j: 8| qry[i]: [ 90 133], ref[j]: [ 84.44 121.9 ]
i: 10, j: 9| qry[i]: [ 95 214], ref[j]: [ 89.2 201.8]
i: 11, j: 10| qry[i]: [100 321], ref[j]: [ 94.72 293.4 ]
i: 12, j: 12| qry[i]: [106  93], ref[j]: [103.8 103.4]
i: 13, j: 12| qry[i]: [109 115], ref[j]: [103.8 103.4]
i: 14, j: 13| qry[i]: [114 198], ref[j]: [107.72 180.1 ]
i: 15, j: 14| qry[i]: [117 171], ref[j]: [111.76 157.1 ]


In [214]:
# Visualize the reference,query, and warped signals
fig = go.Figure()
fig.add_trace(go.Scatter(x = qry[:,0], y = qry[:,1], name = "Query", marker=dict(color="blue")))
fig.add_trace(go.Scatter(x = ref[:,0], y = ref[:,1], name = "Reference",marker=dict(color="red")))
fig.add_trace(go.Scatter(x = warped[:,0], y = warped[:,1], name = "Warped",marker=dict(color="green")))
fig.add_trace(go.Scatter(x = rewarped[:,0], y = rewarped[:,1], name = "Re-Warped",marker=dict(color="purple")))
fig.show()

In [140]:
# Visualize DTW
fig = go.Figure(data=go.Heatmap(
                   z=DTW,
                   x=[i for i in range(len(ref))],
                   y=[i for i in range(len(qry))],
                   colorscale='Viridis'))
fig.show()

In [165]:
for item in shift_record:
    print(f"Shift: {item[0]:.2f}; Avg: {item[1]:.2f}")

Shift: 1.56; Avg: 3.02
Shift: 2.08; Avg: 2.65
Shift: 2.52; Avg: 2.60
Shift: 5.52; Avg: 3.77
Shift: 5.84; Avg: 4.60
Shift: 2.16; Avg: 3.62
Shift: 2.16; Avg: 3.04
Shift: 1.56; Avg: 2.45
Shift: 5.56; Avg: 3.69
Shift: 0.28; Avg: 2.33
Shift: 5.28; Avg: 3.51
Shift: 2.20; Avg: 2.98
Shift: 5.20; Avg: 3.87
Shift: 2.24; Avg: 3.22
Shift: 1.60; Avg: 2.57


In [None]:
# Estimate avg shift using cross correlation
e