# Analyze synthetic gap experiments
### Includes 
* Aggregate median error heatmaps
* Individual median error heatmaps
* Example imputation experiment trace

In [1]:
import numpy as np
from scipy.io import loadmat
import os
import h5py
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from pathlib import Path
import re

In [2]:
# Find all of the files that completed. 
def find_files(base_path, key):
    return [f for f in Path(base_path).glob('**/' + key)]

base_path = '/n/home02/daldarondo/LabDir/Diego/data'
files = find_files(base_path, 'analyze_6_26_19/errors.mat')
print("Data from: ")
[print(f) for f in files]

# Load the data
data = [loadmat(f) for f in files]
export_folder = '/n/holylfs02/LABS/olveczky_lab/Diego/code/MarkerBasedImputation/mbi/analysis/SyntheticGapAnalysis/images'
if not os.path.isdir(export_folder):
    os.makedirs(export_folder)
plt.rcParams['font.size'] = 20


Data from: 
/n/home02/daldarondo/LabDir/Diego/data/JDM27/20171207/models/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM25/20170917/models/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM25/20170919/models/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM25/20170916/models/stride_5/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM32/20171024/models/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM32/20171023/models/stride_5/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM31_imputation_test/models/model_ensemble/viz/analyze_6_26_19/errors.mat
/n/home02/daldarondo/LabDir/Diego/data/JDM33/20171125/models/model_ensemble/viz/analyze_6_26_19/errors.mat


In [3]:
for dat in data:
    print(dat['delta_markers'].shape)

(1, 10)
(1, 10)
(1, 10)
(1, 1)
(1, 10)
(1, 10)
(1, 1)
(1, 10)


## Find the median error at the mid-gap in sythetically generated gaps

In [4]:
def get_midgap(data):
    midgap = np.floor(data.shape[1]/2).astype('int32')
    return np.squeeze(data[:, midgap, :])

# Get data for all individuals
n_lengths = data[0]['delta_markers'].size
ind_median_error = [None] * len(files)
for i, d in enumerate(data):
    error = [get_midgap(d['delta_markers'][0, length]) for length in range(n_lengths)]
    ind_median_error[i] = [np.nanmedian(e, axis=0) for e in error]
    ind_median_error[i] = np.stack(ind_median_error[i],axis=1)

# Get data for aggregate
agg_median_error = [None] * n_lengths
for i, length in enumerate(range(n_lengths)):
    error = [get_midgap(d['delta_markers'][0, length]) for d in data]
    agg_median_error[i] = np.nanmedian(np.concatenate(error, axis=0), axis=0)

agg_median_error = np.stack(agg_median_error, axis=1)

IndexError: index 1 is out of bounds for axis 1 with size 1

## Plot aggregate data as a heat map

In [None]:
f, ax = plt.subplots(figsize=(12,8))
kp_names = ['HeadF', 'HeadB', 'HeadL', 'SpineF', 'SpineM',
            'SpineL', 'Offset1', 'Offset2', 'HipL', 'HipR',
            'ElbowL', 'ArmL', 'ShoulderL', 'ShoulderR',
            'ElbowR', 'ArmR', 'KneeR', 'KneeL', 'ShinL', 'ShinR']

def delta_heatmap(data, ax):
    n_lengths = data.shape[1]
    inds = np.argsort(data[:,-1])
    plt.imshow(data[inds, :],aspect=.5)
    ax.set_yticks(range(data.shape[0]))
    ax.set_yticklabels([kp_names[i] for i in inds])
    ax.set_xticks(range(n_lengths))
    ax.set_xticklabels([i*10 for i in range(1, n_lengths+1)])
    ax.set_xlabel('Synthetic gap duration (frames)')
    c = plt.colorbar()
    c.ax.set_ylabel('Median error at mid-gap (mm)')
    plt.rcParams['font.size'] = 20
    
delta_heatmap(agg_median_error, ax)
save_path = os.path.join(export_folder, 'Aggregate_median_error_at_midgap.eps')
plt.savefig(save_path)

# Make the same plot for each individual

In [None]:
ind_median_error[0].shape
for i, f in enumerate(files):
    fig, ax = plt.subplots(figsize=(12,8))
    delta_heatmap(ind_median_error[i], ax)
    rat_name = re.findall('JDM.*/*/', str(f))
    rat_name = re.sub('/models/.*', '', rat_name[0]) 
    rat_name =  re.sub('/', ' ', rat_name)
    ax.set_title(rat_name)
    save_path = os.path.join(export_folder, rat_name + '_median_error_at_midgap.eps')
    plt.savefig(save_path)
    


In [None]:
kp_names_dim = np.repeat(kp_names, 3)
for i, name in enumerate(kp_names_dim):
    if np.mod(i, 3) == 0:
        kp_names_dim[i] += '_x'
    elif np.mod(i, 3) == 1:
        kp_names_dim[i] += '_y'
    elif np.mod(i, 3) == 2:
        kp_names_dim[i] += '_z'

In [None]:
data[0].keys()

In [None]:
exp_params = {
    'exp': 2,
    'trace': 7500,
    'length_id': 9,
    'win': 75,
    'input_length': 9,
}

def get_trace_context(data, exp=2, trace=7500, length_id=9, win=75):
    lengths = np.arange(10,101,10)
    input_length = 9
    dat = data[exp]
    input_ids = dat['input_ids'][0, length_id][trace, :]
    start = np.min(input_ids)
    ids = np.arange(start - win, start + win + input_length + lengths[length_id])
    marker_stds = data[exp]['marker_stds']
    marker_means = data[exp]['marker_means']
    markers = dat['markers'][ids, :]*marker_stds + marker_means
    return markers

def names2logical(m_ids):
    if m_ids is None:
        m_ids = np.ones((len(kp_names_dim),), dtype=bool)
    else:
        m_ids = [any(part in name for part in m_ids) for name in kp_names_dim]
    return np.array(m_ids)

def get_prediction(data, exp=2, trace=270, length_id=0, win=75, m_ids=None):
    marker_ids = names2logical(m_ids)
    marker_stds = data[exp]['marker_stds']
    marker_means = data[exp]['marker_means']
    preds = data[exp]['predictions'][0, length_id]
#     imputed_trace = np.round(trace / data[exp]['skip'].item()).astype('int32')
    preds = preds[trace, :, marker_ids]
    return preds.T

# def get_prediction(data, exp=2, trace=7500, length_id=9, win=75, m_ids=None):
#     marker_ids = names2logical(m_ids)
#     target_ids = data[exp]['target_ids'][0, length_id][trace, :]
#     imputed_trace = np.round(trace / data[exp]['skip'].item()).astype('int32')
#     marker_stds = data[exp]['marker_stds']
#     marker_means = data[exp]['marker_means']
#     markers = data[exp]['markers'][target_ids[:-9], :].squeeze()*marker_stds + marker_means
#     markers = markers[:, marker_ids]
#     delta = data[exp]['delta_markers'][0,length_id][imputed_trace, :].squeeze()
#     delta = np.repeat(delta, 3, axis=1)
#     delta = delta[:, marker_ids]
#     return delta + markers


markers = get_trace_context(data)
print(markers.shape)
ids = names2logical(['Head'])
predictions = get_prediction(data)

In [None]:
def zero_offset(y):
    offset = np.nanmean(y)
    y -= offset
    return y, offset
    
def plot_traces(markers, predictions, m_names, p_names=None, win=75):
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12,8))

    # Get the marker ids
    m_ids = np.argwhere(names2logical(m_names))
    m_ids = [num.item() for num in m_ids]
    p_ids = np.argwhere(names2logical(p_names))
    p_ids = [num.item() for num in p_ids]
    
    # find which predictions to plot
    marker_preds_to_plot = [any(pred in marker_name for pred in p_names) for marker_name in m_names]
    # Set graph parameters
    cmap = matplotlib.cm.get_cmap('Dark2')
    dist = 30
    ub = dist * len(m_ids)
    spacing = np.arange(0, ub, dist)
    spacing = spacing[::-1]
    fs = 60
    input_length = 9
    mean_offset = [None]*len(m_ids)
    
    # Plot the context and ground truth
    for i, marker_id in enumerate(m_ids):
        y, mean_offset[i] = zero_offset(markers[:, marker_id])
        y += spacing[i]
        plt.plot(np.arange(0, markers.shape[0])/fs,
                 y, color=cmap(i), linewidth=3)
        
    # Plot the imputed values
    # Be wary of this one.
    prediction_offset = -5
    for i, marker_id in enumerate(m_ids):
        if not marker_preds_to_plot[i]:
            continue;
        y = predictions[:, marker_id] - mean_offset[i]   
        y += spacing[i]
        y += prediction_offset
        color = [c*.5 for c in cmap(i)]
#         plt.plot(
#             np.arange(
#                 win + input_length, 
#                 win + input_length + predictions.shape[0])/fs,
#             y, color=color, linewidth=3)
        
        # Plot original trace, replacing the gap with a prediction
        orig, _ = zero_offset(markers[:, marker_id])
        orig += spacing[i]
        orig += prediction_offset
        orig[win + input_length:win + input_length + predictions.shape[0]] = y
        plt.plot(np.arange(0, markers.shape[0])/fs,
                 orig, color=color, linewidth=3)
    
    # Set up legends, ticks, and vertical lines
    h = [None]*len(m_ids)
    for i in range(len(m_ids)):
        h[i] = matplotlib.patches.Patch(color=cmap(i), label=m_names[i])
    plt.legend(handles=h)
    plt.xlabel('Time (seconds)')
    plt.yticks(spacing)
    ax.set_yticklabels([])
    ymin, ymax = plt.ylim()
    plt.vlines((win + input_length)/fs, ymin, ymax, linestyle='--')
    plt.vlines((win + input_length + predictions.shape[0] )/fs, ymin, ymax, linestyle='--')
    plt.ylim([ymin, ymax])
    
skip = data[0]['skip'].item()
exp=2
trace=500
length_id=5
win=51
markers = get_trace_context(data,  exp=exp, trace=trace*skip, length_id=length_id, win=win)
predictions = get_prediction(data, exp=exp, trace=trace, length_id=length_id, win=win)
plot_traces(markers, predictions, ['HeadF_x', 'ArmR_x', 'HipL_x'], ['HeadF_x', 'ArmR_x', 'HipL_x'], win=win)

save_path = os.path.join(export_folder, 'trace_example.eps')
plt.savefig(save_path)