# Calculate and plot Ensembles: project Erica

> Marcos Duarte  
> [Laboratory of Biomechanics and Motor Control](http://demotu.org/)  
> Federal University of ABC, Brazil

<h1>Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Python-setup" data-toc-modified-id="Python-setup-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Python setup</a></span><ul class="toc-item"><li><span><a href="#Some-color-and-plot-configuration" data-toc-modified-id="Some-color-and-plot-configuration-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Some color and plot configuration</a></span></li></ul></li><li><span><a href="#Helping-functions" data-toc-modified-id="Helping-functions-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Helping functions</a></span></li><li><span><a href="#Load-features" data-toc-modified-id="Load-features-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Load features</a></span></li><li><span><a href="#Create-DataArray-for-conditions-and-plot" data-toc-modified-id="Create-DataArray-for-conditions-and-plot-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Create DataArray for conditions and plot</a></span></li></ul></div>

## Python setup

In [1]:
import numpy as np
import pandas as pd
from scipy import stats
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import sys, os
import glob
import xarray as xr
from tqdm.notebook import tqdm
sys.path.insert(1, r'./../functions')
import read_c3d_xr
xr.set_options(keep_attrs=True)

<xarray.core.options.set_options at 0x1be2b410790>

### Some color and plot configuration

In [2]:
# Global seaborn look for every figure produced in this notebook.
sns.set_context('notebook', font_scale=1.05, rc={"lines.linewidth": 1.5})
sns.set_style("darkgrid")  # {darkgrid, whitegrid, dark, white, ticks}
# ANSI escape sequences used to color the text printed by the helper functions
# (e.g. `similarity` prints its header with `cb` and its rows with `cr`).
cr = "\033[1;31m"  # red
cb = "\033[1;34m"  # blue
cg = "\033[1;32m"  # green
rs = "\033[0m"     # reset
# Line colors of the ensembles: `color1` is passed for assessment 'T00' and
# `color2` for assessment 'T12' by the `processa_plot*` functions.
# https://matplotlib.org/3.1.0/gallery/color/named_colors.html
color1 = 'royalblue'
color2 = 'darkblue'

## Helping functions

In [3]:
def processa(fname, prm, var='GRF', methodt='interp', cycle='stride'):
    """Read one variable from a c3d file and return it trimmed and normalized.

    The heavy lifting is delegated to the project module `read_c3d_xr`
    (`read_c3d`, `find_ev_GRFcycle`, `trimmer`, `normala`, `normalt`); this
    function only chains those calls and applies per-variable post-processing.

    Parameters
    ----------
    fname : str
        Path of the c3d file, forwarded to `read_c3d_xr.read_c3d`.
    prm : dict
        Trial parameters. The keys read here are 'mass' (GRF, Moment, Power),
        'LL' (Moment, Power) and 'FL' (Oxford); 'LL' and 'FL' are indexed by
        the side label.
    var : str, optional
        One of 'GRF', 'Angle', 'Moment', 'Power', 'Oxford', 'Oxford2'.
        'Oxford2' is read from the file as 'Oxford' and only renamed.
    methodt : str, optional
        Forwarded as `method` to `read_c3d_xr.normalt`.
    cycle : str, optional
        If different from 'stride' and the event stored under key '1' is a
        foot strike ('?FS'), the events are rewritten so that event '1'
        becomes the foot off ('?FO') of the same side before trimming.

    Returns
    -------
    data
        Object returned by `read_c3d_xr.normalt` (used as an xarray DataArray
        with dimensions Time, Var, Axis), restricted to the variables of the
        side stored in `data.Time.attrs['side']` for every `var` but 'GRF'.
    """
    # Variables kept (prefixed with the side label) for each kind of data.
    joints = ['Hip', 'Knee', 'Ankle']
    foot_segments = ['HFTBA', 'FFHFA', 'FFTBA', 'HXFFA']
    kept = {'Angle': joints, 'Moment': joints, 'Power': joints,
            'Oxford': foot_segments, 'Oxford2': foot_segments}

    # 'Oxford2' shares the 'Oxford' data in the file; only the name differs.
    source_var = var[:-1] if var == 'Oxford2' else var
    data = read_c3d_xr.read_c3d(fname, var=source_var, prm=prm)
    if var == 'Oxford2':
        data.name = var

    events = read_c3d_xr.find_ev_GRFcycle(data)
    if cycle != 'stride':
        first_label = events['1'][1]
        if first_label[1:] == 'FS':
            ev_side = first_label[0]
            events['1'] = [events[ev_side + 'FO'][0], ev_side + 'FO']
            events[ev_side + 'FS'] = [events[ev_side + 'FS'][:1]]

    side = data.Time.attrs['side']
    data = read_c3d_xr.trimmer(data, evs=events, trim=1)

    # Amplitude normalization (only for kinetics), then time normalization.
    if var == 'GRF':
        data = read_c3d_xr.normala(data, method='BW', mass=prm['mass'])
    elif var in ['Moment', 'Power']:
        data = read_c3d_xr.normala(data, method='BM', mass=prm['mass'], LL=prm['LL'][side])
    data = read_c3d_xr.normalt(data, method=methodt, step=1)

    if var == 'GRF':
        # First component of the first variable: remove the offset given by
        # the last sample and flip the sign so that its mean is not positive.
        values = data.values
        values[:, 0, 0] = values[:, 0, 0] - values[-1, 0, 0]
        if np.nanmean(values[:, 0, 0]) > 0:
            values[:, 0, 0] = -values[:, 0, 0]
    elif var in kept:
        data = data.sel(Var=[side + name for name in kept[var]])
        if var == 'Oxford':
            # Express the AP component of HXFFA as a percentage of prm['FL'].
            data.loc[dict(Var=side+'HXFFA', Axis='AP')] = data.sel(Var=side+'HXFFA', Axis='AP').values /\
                                                          prm['FL'][side] * 100

    return data


def create_da(features, cycle='stride', variable='GRF'):
    """Stack every trial listed in `features` into one labelled DataArray.

    Each file is read and processed with `processa` and stored in a
    NaN-filled array at the position given by the integer codes in the
    columns 'T' (trial), 'S' (subject), 'A' (assessment) and 'G' (group).

    Parameters
    ----------
    features : pandas.DataFrame
        Must contain the columns 'Filename', 'G', 'A', 'S', 'T', 'Subject',
        'Assessment', 'Group' and, when `cycle` is 'stride', 'Cycle'.
    cycle : str, optional
        'stride' stores only the rows whose 'Cycle' column equals 'stride';
        'stance' stores every row. Any other value stores nothing.
    variable : str, optional
        One of 'GRF', 'Angle', 'Moment', 'Power', 'Oxford', 'Oxford2'.

    Returns
    -------
    xarray.DataArray
        Dimensions (Time, Var, Axis, Trial, Subject, Assessment, Group); the
        Time dimension has 101 samples and positions without data are NaN.

    Notes
    -----
    Reads the module-level variable `path2` to locate the '.c3d' files.
    Every file is processed, including those that end up not being stored;
    the Time/Var/Axis coordinates, name and units are taken from the last one.
    """
    ft = features.copy(deep=True)
    # array with dimensions [time, var, axis, trial, subject, assessment, group]
    n_vars = {'GRF': 1, 'Angle': 3, 'Moment': 3, 'Power': 3, 'Oxford': 4, 'Oxford2': 4}
    X = np.full((101, n_vars[variable], 3, features['T'].max()+1, features['S'].max()+1,
                 features['A'].max()+1, features['G'].max()+1), np.nan)
    print('Creating DataArray for variable "{}" cycle "{}"'.format(variable, cycle))
    print('Dimensions [time, var, axis, trial, subject, assessment, group]:')
    print(X.shape)
    with tqdm(total=features.shape[0]) as pbar:
        for (i, fname, g, a, s, t) in features[['Filename', 'G', 'A', 'S', 'T']].itertuples(name=None):
            pbar.update()
            filename = os.path.join(path2, fname + '.c3d')
            prm = read_c3d_xr.get_parameters(filename)
            data = processa(filename, prm, var=variable, cycle=cycle)
            if cycle == 'stance' or (cycle == 'stride' and features.loc[i, 'Cycle'] == 'stride'):
                X[:, :, :, t, s, a, g] = data.values

    coords = {}
    coords['Time'] = data.Time.values
    if data.name in ['Angle', 'Moment', 'Power', 'Oxford', 'Oxford2']:
        # Drop the leading side character so both sides share the same labels.
        coords['Var'] = [x[1:] for x in data.Var.values]
    else:
        coords['Var'] = data.Var.values
    coords['Axis'] = data.Axis.values
    coords['Trial'] = range(X.shape[3])
    # X is indexed by the integer codes, so the labels must be ordered by
    # code (not by order of first appearance in the table) to stay aligned.
    coords['Subject'] = ft.drop_duplicates('S').sort_values('S')['Subject'].values
    coords['Assessment'] = ft.drop_duplicates('A').sort_values('A')['Assessment'].values
    coords['Group'] = ft.drop_duplicates('G').sort_values('G')['Group'].values
    dims = ('Time', 'Var', 'Axis', 'Trial', 'Subject', 'Assessment', 'Group')
    da = xr.DataArray(data=X, dims=dims, coords=coords, name=data.name)
    da.attrs['units'] = data.attrs['units']
    da.Time.attrs['units'] = data.Time.attrs['units']
    return da


def processa_plot(features, cycle='stride', variable='GRF'):
    """Build the DataArray of `variable`, discard outlier trials and plot it.

    For groups 1 and 2 (in this order) one figure is created with `plot`
    for assessment 'T00' (color `color1`) and assessment 'T12' is overlaid
    on it (color `color2`).

    Parameters
    ----------
    features : pandas.DataFrame
        Table of trials, forwarded to `create_da` and `plot`.
    cycle : str, optional
        Forwarded to `create_da` and `plot`.
    variable : str, optional
        Forwarded to `create_da`.

    Returns
    -------
    tuple
        `(g, footoff)` as returned by the last `plot` call (group 2, 'T12').
    """
    ensemble = similarity(create_da(features, cycle=cycle, variable=variable),
                          dim='Trial', dim2='Time', central=np.nanquantile,
                          normalize=1, threshold=4, min_trials=5, q=.5)
    names = list(ensemble.Var.values)
    for group in (1, 2):
        # The first call creates the figure; the second one draws on it.
        g, footoff = plot(ensemble, features, names, 'T00', group, cycle=cycle,
                          color=color1)
        g, footoff = plot(ensemble, features, names, 'T12', group, cycle=cycle,
                          color=color2, g=g, FO_=footoff)

    return g, footoff


def MSE(x, dim='Trial', dim2='Time', central=np.nanquantile, normalize=1,
        **kwargs):
    """Mean squared error of `x` w.r.t. `central` across `dim`, averaged over `dim2`.

    Parameters
    ----------
    x : xarray.DataArray
        Data containing the dimensions `dim` and `dim2`.
    dim : str, optional
        Dimension reduced by `central` to obtain the reference curve.
    dim2 : str, optional
        Dimension over which the squared deviations are averaged (NaN-aware).
    central : callable, optional
        Reducing function handed to `x.reduce` together with `**kwargs`
        (e.g. `np.nanquantile` with `q=.5`).
    normalize : int, optional
        1: divide by the NaN-aware sum of the MSE across `dim`;
        2: divide by the squared NaN-aware standard deviation of `x` over
        `dim2`; any other value: no normalization.

    Returns
    -------
    xarray.DataArray
        `x` reduced over `dim2`.

    Notes
    -----
    Installs a warnings filter that ignores 'All-NaN slice encountered'; the
    filter is not removed when the function returns.
    """
    import warnings
    warnings.filterwarnings('ignore', message=r'All-NaN slice encountered')

    reference = x.reduce(central, dim=dim, **kwargs)
    squared_dev = (x - reference)**2
    mse = squared_dev.reduce(np.nanmean, dim=dim2)
    if normalize == 1:  # normalize by the total MSE
        total = mse.reduce(np.nansum, dim=dim)
        mse = mse/total
    elif normalize == 2:  # normalize by the data variance
        spread = x.reduce(np.nanstd, dim=dim2)
        mse = mse/spread**2

    return mse


def similarity(da, dim='Trial', dim2='Time', central=np.nanquantile, normalize=1,
               threshold=4, min_trials=5, recursive=True, **kwargs):
    """Set to NaN the trials that deviate too much from the central curve.

    For every combination of Group, Assessment, Subject, Axis and Var the
    `MSE` of each trial is computed. While more than `min_trials` trials have
    a non-NaN MSE and the largest of them exceeds `threshold/n` (`n` being
    the number of non-NaN MSE values), the trial with the largest MSE is set
    to NaN and the MSE is recomputed.

    Parameters
    ----------
    da : xarray.DataArray
        Data with the coordinates 'Group', 'Assessment', 'Subject', 'Axis',
        'Var' and 'Trial'. It is not modified; a deep copy is edited.
    dim, dim2, central, normalize, **kwargs
        Forwarded to `MSE`.
    threshold : float, optional
        A trial is rejected when its MSE is greater than `threshold/n`.
    min_trials : int, optional
        No trial is rejected once `n` is not greater than this value.
    recursive : bool, optional
        If False, at most one trial is rejected per combination.

    Returns
    -------
    xarray.DataArray
        Copy of `da` with the rejected trials replaced by NaN.

    Notes
    -----
    Prints a header and, per combination, one line when the rejection ends.
    NOTE(review): that line reports the state *after* the last rejection (the
    remaining count and the remaining trial with the largest MSE), although
    the header calls the second field "outlier trial" -- confirm the intent.
    The hard-coded `Trial=` keys assume `dim` is 'Trial'.
    """
    # TODO: replace loops with iteration over dict
    # https://stackoverflow.com/questions/31686899/how-to-iterate-over-a-dynamic-object
    x = da.copy(deep=True)
    #coords = [coord for coord in x.coords.keys() if coord not in [dim, dim2]]
    print(cb, end='')
    print('Number of trials, outlier trial, MSE, Variable, Axis, Subject, Assessment, Group')
    print(rs, end='')
    for g in x['Group'].values:
        for a in x['Assessment'].values:
            for s in x['Subject'].values:
                for ax in x['Axis'].values:
                    for v in x['Var'].values:
                        mse = MSE(x.sel(Group=g, Assessment=a, Subject=s, Axis=ax, Var=v),
                                  dim=dim, dim2=dim2, central=central, normalize=normalize,
                                  **kwargs)
                        if not np.all(np.isnan(mse)):
                            # Ascending sort: index n-1 is taken as the largest
                            # non-NaN MSE (assumes NaNs are sorted last).
                            mse = mse.sortby(mse, ascending=True)
                            n = np.count_nonzero(~np.isnan(mse.values))
                            while n > min_trials and mse.values[n-1] > threshold/n:
                                #print(n, mse['Trial'].values[n-1], mse.values[n-1], v, ax, s, a, g)
                                # Reject the worst trial and recompute the MSE
                                # without it.
                                x.loc[dict(Group=g, Assessment=a, Subject=s, Axis=ax,
                                           Var=v, Trial=mse['Trial'].values[n-1])] = np.nan
                                mse = MSE(x.sel(Group=g, Assessment=a, Subject=s, Axis=ax, Var=v),
                                          dim=dim, dim2=dim2, central=central, normalize=normalize,
                                          **kwargs)
                                if not np.all(np.isnan(mse)):
                                    mse = mse.sortby(mse, ascending=True)
                                    n = np.count_nonzero(~np.isnan(mse.values))
                                    # Report only when the loop is about to
                                    # stop for this combination.
                                    if n <= min_trials or mse.values[n-1] <= threshold/n:
                                        print(cr, end='')
                                        print(n, mse['Trial'].values[n-1],
                                              mse.values[n-1], v, ax, s, a, g)
                                        print(rs, end='')
                                else:
                                    pass
                                if not recursive:
                                    break
                            
    return x


def plot(da, features, var, assessment, group, mean=True, cycle='stride',
         color='b', g=None, FO_=None):
    """Plot the across-subject ensemble of one assessment of one group.

    Each subject is represented by the median across trials; the line is the
    mean of those medians across subjects and the band is +/- one standard
    error. With `cycle='stride'` the mean foot-off instant (stance time as a
    percentage of stride time) is drawn as a dashed vertical line with a
    +/- one standard error span.

    The first call (`g=None`) creates the figure. A later call with the
    returned `g` overlays another assessment and also adds the legend, axis
    labels and the 'FS'/'FO' text marks.

    Parameters
    ----------
    da : xarray.DataArray
        Array with dimensions Time, Var, Axis, Trial, Subject, Assessment and
        Group (as built by `create_da`). `da.name` selects the layout.
    features : pandas.DataFrame
        Table with the columns 'Group', 'Assessment', 'S', 'Subject' and, for
        `cycle='stride'`, 'Cycle', 'Stance Time' and 'Stride Time'.
    var : list
        Labels of the 'Var' coordinate to plot.
    assessment, group
        Labels of the 'Assessment' and 'Group' coordinates to plot.
    mean : bool, optional
        Unused; kept for backward compatibility.
    cycle : str, optional
        'stride' enables the foot-off marks.
    color : str, optional
        Matplotlib color of the line, band and foot-off marks.
    g : xarray FacetGrid, optional
        Grid returned by a previous call, to draw on.
    FO_ : float, optional
        Foot-off value returned by a previous call; used only to place the
        'FO' text between both foot-off lines.

    Returns
    -------
    tuple
        `(g, footoff)`; `footoff` is None unless `cycle` is 'stride'.

    Notes
    -----
    Saves the figure as 'G{group}_{da.name}_{cycle}.jpg' in the 'figures'
    folder under the module-level `path2`.
    """
    ft = features.copy(deep=True)
    in_condition = (ft['Group'] == group) & (ft['Assessment'] == assessment)
    subject = ft[in_condition].drop_duplicates('S')['Subject'].values
    axis = 'VT' if da.name == 'Power' else 'ML'
    N = da.sel(Var=var[0], Axis=axis, Subject=subject, Assessment=assessment,
               Group=group).sum(dim='Time', skipna=False).squeeze()
    NT = N.count().values
    # NOTE(review): a NaN-skipping sum of an all-NaN subject presumably yields
    # 0 rather than NaN, so NS may count subjects without valid trials.
    NS = N.sum(dim='Trial', skipna=True).squeeze().count().values
    print('{} subjects'.format(NS))
    print('{} trials'.format(NT))

    if cycle == 'stride':
        strides = ft[ft['Cycle'] == 'stride']
        ft['FO'] = strides['Stance Time']/strides['Stride Time']*100
        # Invalidate only the impossible foot-off values; the other columns
        # of those rows must be left untouched.
        ft.loc[ft['FO'] > 100, 'FO'] = np.nan
        fo_stats = ft[in_condition][['Subject', 'FO']].groupby('Subject').agg(
            np.nanmedian).agg(['mean', 'std'])
        fo_mean = fo_stats.loc['mean', 'FO']
        fo_se = fo_stats.loc['std', 'FO']/np.sqrt(NS)
        footoff = fo_mean
    else:
        footoff = None

    # Power and Oxford2 show one axis only, with the variables in the columns;
    # everything else shows variables in rows and axes in columns.
    var_in_cols = da.name in ['Power', 'Oxford2']
    if da.name == 'Power':
        axes = ['VT']
    elif da.name == 'Oxford2':
        axes = ['ML']
    else:
        axes = da.Axis.values
    row, col = ('Axis', 'Var') if var_in_cols else ('Var', 'Axis')

    damed = da.sel(Var=var, Subject=subject, Assessment=assessment,
                   Group=group).reduce(np.nanquantile, q=.5, dim='Trial')
    damean = damed.sel(Var=var, Subject=subject).reduce(np.nanmean, dim='Subject')
    #ci = stats.t.ppf((1+95/100)/2, NS-1)/np.sqrt(NS)
    dastd = damed.sel(Var=var, Subject=subject).reduce(np.nanstd, ddof=1, dim='Subject')/np.sqrt(NS)
    damean_ax = damean.sel(Axis=axes)
    dastd_ax = dastd.sel(Axis=axes)

    def panel(arr, r, c):
        """Return the curve of `arr` shown in the subplot at row r, column c."""
        if var_in_cols:
            return arr.isel(Var=c, Axis=r)
        return arr.isel(Var=r, Axis=c)

    if da.name in ['GRF', 'Power', 'Oxford2']:
        size = 3.6; aspect = 1.1
    elif da.name == 'Oxford':
        size = 2.9; aspect = 1.30
    else:
        size = 3; aspect = 1.25

    if g is None:
        g = damean_ax.plot.line(x='Time', row=row, col=col, sharey=False,
                                size=size, aspect=aspect, color=color,
                                alpha=.8, lw=2, label=assessment)
        axs = np.atleast_2d(g.axs)
    else:
        axs = np.atleast_2d(g.axs)
        variables = da['Var'].values
        for r in range(axs.shape[0]):
            for c in range(axs.shape[1]):
                ax = axs[r, c]
                ax.plot(damean['Time'], panel(damean_ax, r, c),
                        color=color, alpha=.8, lw=2, label=assessment)
                if var_in_cols:
                    ax.set_title(variables[c])
                if r == 0 and c == 0:
                    loc = 'upper left' if cycle == 'stride' else 'upper center'
                    ax.legend(title='Assessment', frameon=False, loc=loc)

                ax.set_xlim(0, 100)
                if da.name == 'Oxford' and r == 3 and c == 1:
                    ax.set_ylabel('Arch Height Length [%FL]')
                else:
                    ax.axhline(y=0, c='k', lw=1)
                if c == 0:
                    if da.name == 'Oxford':
                        ax.set_ylabel('{} Angle [{}]'.format(variables[r],
                                                             da.attrs['units']))
                    elif da.name == 'Oxford2':
                        ax.set_ylabel('Angle [{}]'.format(da.attrs['units']))
                    elif da.name not in ['GRF', 'Power']:
                        ax.set_ylabel('{} {} [{}]'.format(variables[r], da.name,
                                                          da.attrs['units']))
                if da.name == 'Oxford' and r == 3 and c == 2:
                    ax.set_ylabel('FTA Angle [{}]'.format(da.attrs['units']))
                # Remove the first annotation found on the axes (presumably the
                # row label added by the facet grid).
                ant = [ch for ch in ax.get_children() if isinstance(ch, plt.Annotation)]
                if len(ant):
                    ant[0].remove()
                if r == 0:
                    ax.text(x=0, y=1.03, s='FS', fontsize=12, horizontalalignment='center',
                            transform=ax.transAxes)
                    FSFO = 'FS' if cycle == 'stride' else 'FO'
                    ax.text(x=1, y=1.03, s=FSFO, fontsize=12, horizontalalignment='center',
                            transform=ax.transAxes)
                    if cycle == 'stride' and FO_ is not None:
                        # Midway between both foot-off lines, shifted right by
                        # three standard errors (at least 1% of the axes width).
                        x = (FO_ + fo_mean)/2/100 + np.max([3*fo_se/100, .01])
                        ax.text(x=x, y=.92, s='FO', fontsize=12, horizontalalignment='left',
                                transform=ax.transAxes)

    for r in range(axs.shape[0]):
        for c in range(axs.shape[1]):
            ax = axs[r, c]
            center = panel(damean_ax, r, c).squeeze()
            spread = panel(dastd_ax, r, c).squeeze()
            ax.fill_between(da.Time, center + spread, center - spread,
                            facecolor=color, alpha=.4, edgecolor='none', zorder=3)
            if cycle == 'stride':
                ax.axvline(fo_mean, ls='--', lw=2, color=color, zorder=2)
                ax.axvspan(fo_mean - fo_se, fo_mean + fo_se, alpha=.2, facecolor=color,
                           edgecolor='none', zorder=1)
            ax.margins(y=.2)

    # Title and save once, after every subplot has been completed.
    y = .97 if axs.size == 1 else .95 if axs.size == 3 else .98
    g.fig.suptitle('Group: {}'.format(group), fontsize=14, y=y)
    plt.tight_layout()
    figname = 'G{}_{}_{}.jpg'.format(group, da.name, cycle)
    plt.savefig(os.path.join(path2, 'figures', figname), dpi=300)

    return g, footoff


## Load features

In [4]:
# Data folder. `path2` is also read as a global by `create_da` (to locate the
# '.c3d' files) and by the plot functions (to save into its 'figures' folder).
path2 = './../data'
fname = os.path.join(path2, 'features_all.csv')
# One row per trial; 'Assessment' is forced to string, first column is the index.
features = pd.read_csv(fname, sep=',', header=0, index_col=0, dtype={'Assessment':'str'})
display(features)

Unnamed: 0,Filename,Subject,Group,Assessment,Trial,Mass,Height,LegLength,FootLength,Cadence,...,PAvtAMfot,PAvtVLw1,PAvtVLw1t,PAvtPKw1,PAvtPKw1t,PAvtRNw1,G,A,S,T
0,S_S16_T00_2,S16,1,T00,2,76.5,1.55,0.746,0.194,,...,0.66,,,,,,0,0,0,0
1,S_S16_T00_3,S16,1,T00,3,76.5,1.55,0.746,0.194,,...,0.68,,,,,,0,0,0,1
2,S_S16_T00_4,S16,1,T00,4,76.5,1.55,0.746,0.194,,...,0.68,,,,,,0,0,0,2
3,S_S16_T00_5,S16,1,T00,5,76.5,1.55,0.746,0.194,,...,-0.43,-17.014318,-0.26,-8.841038,-0.01,8.173281,0,0,0,3
4,S_S16_T00_6,S16,1,T00,6,76.5,1.55,0.746,0.194,,...,-0.43,-15.725980,-0.25,-6.857044,-0.01,8.868936,0,0,0,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1232,S_S67_T24_11,S67,2,T24,11,75.9,1.66,0.760,0.199,103.627,...,0.73,-19.755606,0.95,-10.454477,0.73,9.301128,1,2,55,6
1233,S_S67_T24_12,S67,2,T24,12,75.9,1.66,0.760,0.199,103.986,...,0.72,-16.659187,0.95,-6.640570,0.73,10.018618,1,2,55,7
1234,S_S67_T24_13,S67,2,T24,13,75.9,1.66,0.760,0.199,103.004,...,0.73,-18.596239,0.96,-7.901696,0.74,10.694543,1,2,55,8
1235,S_S67_T24_14,S67,2,T24,14,75.9,1.66,0.760,0.199,101.868,...,0.73,-19.356586,0.95,-9.146155,0.73,10.210431,1,2,55,9


## Create DataArray for conditions and plot

In [None]:
# Build, clean and plot the 'GRF' ensembles over the stride cycle.
processa_plot(features, cycle='stride', variable='GRF'); 

Creating DataArray for variable "GRF" cycle "stride"
Dimensions [time, var, axis, trial, subject, assessment, group]:
(101, 1, 3, 12, 59, 3, 2)


  0%|          | 0/1237 [00:00<?, ?it/s]

In [None]:
# Build, clean and plot the 'Angle' ensembles over the stride cycle.
processa_plot(features, cycle='stride', variable='Angle'); 

In [None]:
# Build, clean and plot the 'Moment' ensembles over the stride cycle.
processa_plot(features, cycle='stride', variable='Moment'); 

In [None]:
# Build, clean and plot the 'Power' ensembles over the stride cycle.
processa_plot(features, cycle='stride', variable='Power'); 

In [None]:
# Build, clean and plot the 'Oxford' ensembles over the stride cycle.
processa_plot(features, cycle='stride', variable='Oxford'); 

In [None]:
# Build, clean and plot the 'Oxford2' ensembles over the stride cycle, keeping
# the grid and foot-off value returned by the last plot call.
g, footoff = processa_plot(features, cycle='stride', variable='Oxford2') 

In [None]:
def plot2(da, features, var, assessment, group, mean=True, cycle='stride',
         color='b', g=None, FO_=None):
    """Variant of `plot` with descriptive labels.

    The computation is the same as in `plot` (per-subject median across
    trials, then mean +/- one standard error across subjects, plus the
    foot-off line and span for `cycle='stride'`). The presentation differs:
    the legend shows 'Baseline'/'12-week' and sits on the last column of the
    first row, Oxford2 panels are titled with the foot segment names, the
    figure title is the group name, and no 'FS'/'FO' text marks are drawn.

    Parameters
    ----------
    da : xarray.DataArray
        Array with dimensions Time, Var, Axis, Trial, Subject, Assessment and
        Group (as built by `create_da`). `da.name` selects the layout.
    features : pandas.DataFrame
        Table with the columns 'Group', 'Assessment', 'S', 'Subject' and, for
        `cycle='stride'`, 'Cycle', 'Stance Time' and 'Stride Time'.
    var : list
        Labels of the 'Var' coordinate to plot.
    assessment, group
        Labels of the 'Assessment' and 'Group' coordinates to plot.
    mean : bool, optional
        Unused; kept for backward compatibility.
    cycle : str, optional
        'stride' enables the foot-off line and span.
    color : str, optional
        Matplotlib color of the line, band and foot-off marks.
    g : xarray FacetGrid, optional
        Grid returned by a previous call, to draw on.
    FO_ : float, optional
        Unused here; kept so the signature matches `plot`.

    Returns
    -------
    tuple
        `(g, footoff)`; `footoff` is None unless `cycle` is 'stride'.

    Notes
    -----
    Saves the figure as 'G{group}_{da.name}_{cycle}.jpg' in the 'figures'
    folder under the module-level `path2` (the same file name used by `plot`).
    """
    # NOTE(review): every assessment other than 'T00' is labelled '12-week';
    # confirm before calling this with another assessment (e.g. 'T24').
    if assessment == 'T00':
        assessment_s = 'Baseline'
    else:
        assessment_s = '12-week'
    ft = features.copy(deep=True)
    in_condition = (ft['Group'] == group) & (ft['Assessment'] == assessment)
    subject = ft[in_condition].drop_duplicates('S')['Subject'].values
    axis = 'VT' if da.name == 'Power' else 'ML'
    N = da.sel(Var=var[0], Axis=axis, Subject=subject, Assessment=assessment,
               Group=group).sum(dim='Time', skipna=False).squeeze()
    NT = N.count().values
    NS = N.sum(dim='Trial', skipna=True).squeeze().count().values
    print('{} subjects'.format(NS))
    print('{} trials'.format(NT))

    if cycle == 'stride':
        strides = ft[ft['Cycle'] == 'stride']
        ft['FO'] = strides['Stance Time']/strides['Stride Time']*100
        # Invalidate only the impossible foot-off values; the other columns
        # of those rows must be left untouched.
        ft.loc[ft['FO'] > 100, 'FO'] = np.nan
        fo_stats = ft[in_condition][['Subject', 'FO']].groupby('Subject').agg(
            np.nanmedian).agg(['mean', 'std'])
        fo_mean = fo_stats.loc['mean', 'FO']
        fo_se = fo_stats.loc['std', 'FO']/np.sqrt(NS)
        footoff = fo_mean
    else:
        footoff = None

    # Power and Oxford2 show one axis only, with the variables in the columns;
    # everything else shows variables in rows and axes in columns.
    var_in_cols = da.name in ['Power', 'Oxford2']
    if da.name == 'Power':
        axes = ['VT']
    elif da.name == 'Oxford2':
        axes = ['ML']
    else:
        axes = da.Axis.values
    row, col = ('Axis', 'Var') if var_in_cols else ('Var', 'Axis')

    damed = da.sel(Var=var, Subject=subject, Assessment=assessment,
                   Group=group).reduce(np.nanquantile, q=.5, dim='Trial')
    damean = damed.sel(Var=var, Subject=subject).reduce(np.nanmean, dim='Subject')
    #ci = stats.t.ppf((1+95/100)/2, NS-1)/np.sqrt(NS)
    dastd = damed.sel(Var=var, Subject=subject).reduce(np.nanstd, ddof=1, dim='Subject')/np.sqrt(NS)
    damean_ax = damean.sel(Axis=axes)
    dastd_ax = dastd.sel(Axis=axes)

    def panel(arr, r, c):
        """Return the curve of `arr` shown in the subplot at row r, column c."""
        if var_in_cols:
            return arr.isel(Var=c, Axis=r)
        return arr.isel(Var=r, Axis=c)

    if da.name in ['GRF', 'Power', 'Oxford2']:
        size = 3.6; aspect = 1.1
    elif da.name == 'Oxford':
        size = 2.9; aspect = 1.30
    else:
        size = 3; aspect = 1.25

    if g is None:
        g = damean_ax.plot.line(x='Time', row=row, col=col, sharey=False,
                                size=size, aspect=aspect, color=color,
                                alpha=.8, lw=2, label=assessment_s)
        axs = np.atleast_2d(g.axs)
    else:
        axs = np.atleast_2d(g.axs)
        variables = da['Var'].values
        # Descriptive titles of the four Oxford foot model angles, in the
        # order HFTBA, FFHFA, FFTBA, HXFFA used by `processa`.
        titles = ['Hindfoot to tibia', 'Forefoot to hindfoot',
                  'Forefoot to tibia', 'Hallux to forefoot']
        for r in range(axs.shape[0]):
            for c in range(axs.shape[1]):
                ax = axs[r, c]
                ax.plot(damean['Time'], panel(damean_ax, r, c),
                        color=color, alpha=.8, lw=2, label=assessment_s)
                if da.name == 'Oxford2':
                    ax.set_title(titles[c], fontsize=14)
                elif da.name == 'Power':
                    # The foot segment names do not apply to joint power.
                    ax.set_title(variables[c], fontsize=14)
                if r == 0 and c == axs.shape[1]-1:
                    ax.legend(title='Assessment', frameon=False, loc='best')

                ax.set_xlim(0, 100)
                if da.name == 'Oxford' and r == 3 and c == 1:
                    ax.set_ylabel('Arch Height Length [%FL]')
                else:
                    ax.axhline(y=0, c='k', lw=1)
                if c == 0:
                    if da.name == 'Oxford':
                        ax.set_ylabel('{} Angle [{}]'.format(variables[r],
                                                             da.attrs['units']))
                    elif da.name == 'Oxford2':
                        ax.set_ylabel('Plantarflexion(-) / Dorsiflexion(+) [$^o$]')
                    elif da.name not in ['GRF', 'Power']:
                        ax.set_ylabel('{} {} [{}]'.format(variables[r], da.name,
                                                          da.attrs['units']))
                if da.name == 'Oxford' and r == 3 and c == 2:
                    ax.set_ylabel('FTA Angle [{}]'.format(da.attrs['units']))
                # Remove the first annotation found on the axes (presumably the
                # row label added by the facet grid).
                ant = [ch for ch in ax.get_children() if isinstance(ch, plt.Annotation)]
                if len(ant):
                    ant[0].remove()

    for r in range(axs.shape[0]):
        for c in range(axs.shape[1]):
            ax = axs[r, c]
            center = panel(damean_ax, r, c).squeeze()
            spread = panel(dastd_ax, r, c).squeeze()
            ax.fill_between(da.Time, center + spread, center - spread,
                            facecolor=color, alpha=.4, edgecolor='none', zorder=3)
            if cycle == 'stride':
                ax.axvline(fo_mean, ls='--', lw=2, color=color, zorder=2)
                ax.axvspan(fo_mean - fo_se, fo_mean + fo_se, alpha=.2, facecolor=color,
                           edgecolor='none', zorder=1)
            ax.margins(y=.1)

    # Title and save once, after every subplot has been completed.
    if group == 2:
        texto = 'Control Group'
    else:
        texto = 'Intervention Group'
    g.fig.suptitle(texto, fontsize=18, y=.95)

    plt.tight_layout()
    figname = 'G{}_{}_{}.jpg'.format(group, da.name, cycle)
    plt.savefig(os.path.join(path2, 'figures', figname), dpi=300)

    return g, footoff


def processa_plot2(features, cycle='stride', variable='GRF'):
    """Build the DataArray of `variable`, discard outlier trials and plot it with `plot2`.

    For groups 2 and 1 (in this order) one figure is created with `plot2`
    for assessment 'T00' (color `color1`) and assessment 'T12' is overlaid
    on it (color `color2`).

    Parameters
    ----------
    features : pandas.DataFrame
        Table of trials, forwarded to `create_da` and `plot2`.
    cycle : str, optional
        Forwarded to `create_da` and `plot2`.
    variable : str, optional
        Forwarded to `create_da`.

    Returns
    -------
    tuple
        `(g, footoff)` as returned by the last `plot2` call (group 1, 'T12').
    """
    ensemble = similarity(create_da(features, cycle=cycle, variable=variable),
                          dim='Trial', dim2='Time', central=np.nanquantile,
                          normalize=1, threshold=4, min_trials=5, q=.5)
    names = list(ensemble.Var.values)
    for group in (2, 1):
        # The first call creates the figure; the second one draws on it.
        g, footoff = plot2(ensemble, features, names, 'T00', group, cycle=cycle,
                           color=color1)
        g, footoff = plot2(ensemble, features, names, 'T12', group, cycle=cycle,
                           color=color2, g=g, FO_=footoff)

    return g, footoff

In [None]:
# Same 'Oxford2' ensembles, drawn with the descriptive labels of `plot2`.
g, footoff = processa_plot2(features, cycle='stride', variable='Oxford2') 