In [None]:
# numerics, tables and result loading
import numpy as np
import pandas as pd
import pickle
import os
# project helper; called later to get the clip name -> label dict
from cc_utils import _get_clip_labels

# plotting: plotly in offline notebook mode, 'plotly_white' as
# the default template for every figure created below
import plotly.offline as py
py.init_notebook_mode(connected=True)
import plotly.graph_objs as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default = 'plotly_white'
from plot_utils import (_hex_to_rgb, _plot_ts,
    _add_box, _highlight_max)

# qualitative palette; colors[0] is used for lstm traces and
# colors[1] for cpm traces in the plotting functions
colors = px.colors.qualitative.Plotly

In [None]:
class ARGS:
    '''
    settings of the runs whose results are analysed here

    every field except select_bhv is formatted into the
    result file names built by _get_results
    '''
    # data selection
    roi = 300
    net = 7
    subnet = 'wb'

    # NOTE(review): meaning of cutoff / zscore is not visible
    # in this notebook; they only appear in the file names
    cutoff = 0.1
    zscore = 1

    # number of cross-validation folds
    k_fold = 10

    # lstm
    k_hidden = 150
    k_layers = 1
    batch_size = 16
    num_epochs = 20

    # cpm
    corr_thresh = 0.2

    # bhv measures to load and plot
    select_bhv = [
        'PMAT24_A_CR',
        'PicVocab_Unadj',
        'NEOFAC_A',
        'NEOFAC_E',
        'NEOFAC_C',
        'NEOFAC_N',
        'NEOFAC_O',
    ]

In [None]:
args = ARGS()
K_RUNS = 4

# SCORES:
# 'mse': mean squared error
# 'p': pearson correlation
# 's': spearman correlation
SCORES = ['mse', 'p', 's']

# layout defaults: black helvetica text, 16pt axis tick labels
GLOBAL_LAYOUT = {
    'font': {'family': 'helvetica', 'color': 'black'},
    'xaxis': {'tickfont': {'size': 16}},
    'yaxis': {'tickfont': {'size': 16}},
}

In [None]:
# display name of each bhv measure: the raw column name
# unless a readable label is given below
pretty_bhv = dict(zip(args.select_bhv, args.select_bhv))
pretty_bhv.update({
    'PMAT24_A_CR': 'Fluid Intelligence',
    'PicVocab_Unadj': 'Verbal IQ',
})

In [None]:
# folders holding the pickled results of each model;
# _get_results appends the run-specific file name
LSTM_DIR, CPM_DIR = 'results/bhv_lstm', 'results/bhv_cpm'

In [None]:
def _get_results(args, bhv):
    '''
    load the pickled lstm and cpm results of one bhv measure

    args: settings object; roi, net, subnet, cutoff, k_fold,
        k_hidden, k_layers, batch_size, num_epochs,
        corr_thresh and zscore are formatted into the file names
    bhv: name of the bhv measure (part of both file names)

    returns dict with keys 'lstm' and 'cpm' holding the
    objects unpickled from LSTM_DIR and CPM_DIR
    '''
    def _unpickle(path):
        # NOTE: pickle.load runs arbitrary code; only use on
        # result files produced by this project
        with open(path, 'rb') as f:
            return pickle.load(f)

    # file name part shared by both models
    stem = '/roi_%d_net_%d_nw_%s_bhv_%s_cutoff_%0.1f' % (
        args.roi, args.net, args.subnet, bhv, args.cutoff)

    results = {}

    lstm_tail = ('_kfold_%d_k_hidden_%d'
        '_k_layers_%d_batch_size_%d'
        '_num_epochs_%d_z_%d.pkl') % (
        args.k_fold, args.k_hidden,
        args.k_layers, args.batch_size,
        args.num_epochs, args.zscore)
    results['lstm'] = _unpickle(LSTM_DIR + stem + lstm_tail)

    cpm_tail = '_corrthresh_%0.1f_kfold_%d_z_%d.pkl' % (
        args.corr_thresh, args.k_fold, args.zscore)
    results['cpm'] = _unpickle(CPM_DIR + stem + cpm_tail)

    return results

In [None]:
# r[bhv]['lstm' | 'cpm']: results of every selected bhv measure
r = {bhv: _get_results(args, bhv) for bhv in args.select_bhv}

In [None]:
# clip name -> integer label, plus a 'testretest' entry at label 0
clip_y = _get_clip_labels()
clip_y['testretest'] = 0
k_clip = np.unique(list(clip_y.values())).size
print('number of clips = %d' % k_clip)

# inverse lookup: clip_names[label] = clip name
# label 0 is always shown as 'testretest'
clip_names = np.zeros(k_clip).astype(str)
clip_names[0] = 'testretest'
for name, label in clip_y.items():
    if label == 0:
        continue
    clip_names[label] = name

# Accuracy: LSTM vs CPM for each clip and behavioral measure

In [None]:
def _plot_clip_acc(r, bhv_measures, clip_names, score):
    '''
    compare lstm and cpm validation scores
    for each clip and bhv measure

    lstm: scores across folds at the time point whose
        fold-averaged score is best (min for 'mse', else max)
    cpm: scores across folds (nan entries are ignored)

    r: r[bhv]['lstm'] and r[bhv]['cpm'] results
    bhv_measures: bhv names; keys of r and of pretty_bhv
    clip_names: clip names; keys of the global clip_y
    score: 'mse', 'p' or 's'

    also reads the notebook globals args.k_fold and colors

    returns:
    fig1: grouped bars (mean, error bars of 3 sem),
        one subplot per clip
    fig2: heatmap of mean lstm - mean cpm (bhv x clip)
    fig3: split violins of the per-clip means for each bhv,
        with a connector per clip between lstm and cpm
    '''
    k_cols = 3
    k_clips = len(clip_names)
    k_bhv = len(bhv_measures)
    k_rows = int(np.ceil(k_clips/k_cols))
    bhv_ticks = np.arange(k_bhv)
    # sem = std across folds / sqrt(number of folds)
    sem_scale = 1/np.sqrt(args.k_fold)

    fig1 = make_subplots(rows=k_rows, cols=k_cols,
        subplot_titles=clip_names, print_grid=False,
        shared_xaxes=True, shared_yaxes=True)

    # (clip x bhv) mean and sem of each model
    mu_lstm = np.zeros((k_clips, k_bhv))
    mu_cpm = np.zeros((k_clips, k_bhv))
    ste_lstm = np.zeros((k_clips, k_bhv))
    ste_cpm = np.zeros((k_clips, k_bhv))
    for ii, clip in enumerate(clip_names):
        c = clip_y[clip]
        showlegend = (ii==0)

        row = ii//k_cols + 1
        col = ii%k_cols + 1

        for jj, bhv in enumerate(bhv_measures):
            a = r[bhv]

            # presumably (fold x time) -- TODO confirm
            t_val = a['lstm']['t_val_%s'%score][c]
            # averaged across folds
            v = np.mean(t_val, axis=0)
            # find time instant with best score
            if score=='mse':
                tmax = np.argmin(v)
            else:
                tmax = np.argmax(v)

            # scores of all folds at that time instant
            la = t_val[:, tmax]
            mu_lstm[ii, jj] = np.mean(la)
            ste_lstm[ii, jj] = sem_scale*np.std(la)

            # nan-aware spread, consistent with the nan-aware
            # mean; np.std would turn the error bar into nan
            ca = a['cpm']['c_val_%s'%score][c]
            mu_cpm[ii, jj] = np.nanmean(ca)
            ste_cpm[ii, jj] = sem_scale*np.nanstd(ca)

        # val acc: one bar trace per model
        # the sem is a numpy array here, so 3*ste scales it
        # (3*list would repeat the list instead)
        for name, mu, ste, color in (
                ('lstm', mu_lstm[ii], ste_lstm[ii], colors[0]),
                ('cpm', mu_cpm[ii], ste_cpm[ii], colors[1])):
            trace = go.Bar(name=name,
                x=bhv_ticks,
                y=mu,
                error_y=dict(type='data',
                    array=3*ste),
                marker_color=color,
                showlegend=showlegend,
                legendgroup=name)
            fig1.add_trace(trace, row, col)

        if row==k_rows:
            fig1.update_xaxes(dtick=1,
                tickvals=bhv_ticks,
                ticktext=bhv_measures,
                title_text='bhv measure', row=row, col=col)
        if col==1:
            fig1.update_yaxes(title_text=score,
                row=row, col=col)

    fig1.update_xaxes(dtick=1)
    fig1.update_yaxes(dtick=0.1)

    fig1.update_layout(barmode='group',
        height=int(250*k_rows), width=950,
        title_text='lstm vs cpm accuracy (3 sem)')
    '''
    plot heatmap of (lstm - cpm) mean values
    '''
    fig2 = go.Figure()
    z = (mu_lstm - mu_cpm).T
    # NOTE(review): fixed colour range [0, 0.9]; cells where
    # cpm beats lstm (negative) all get the lowest colour
    trace = go.Heatmap(z=z,
        zmax=0.9, zmin=0,
        colorscale='ylorrd',
        colorbar=dict(tickfont_size=20))
    fig2.add_trace(trace)
    fig2.update_yaxes(tickvals=bhv_ticks,
        ticktext=[pretty_bhv[bhv] for bhv in bhv_measures],
        tickfont=dict(size=20))
    fig2.update_xaxes(tickvals=np.arange(k_clips),
        ticktext=clip_names, tickfont=dict(size=20), tickangle=45)
    fig2.update_layout(height=400, width=650)
    '''
    plot violin of acc distribution across clips
    for each bhv
    '''
    # axis title follows the requested score
    score_label = {'mse': 'Mean squared error',
        'p': 'Pearson correlation',
        's': 'Spearman correlation'}.get(score, score)

    fig3 = go.Figure()
    for ii, bhv in enumerate(bhv_measures):

        # violins: lstm on the left half, cpm on the right
        showlegend = (ii==0)
        y_l = mu_lstm[:, ii]
        x_l = ii-0.2
        trace = go.Violin(x=x_l*np.ones(k_clips),
            y=y_l, text=clip_names, width=0.5,
            side='negative', jitter=0.0, # remove all jitter
            box_visible=True, meanline_visible=True,
            marker_color=colors[0], name='LSTM',
            showlegend=showlegend, pointpos=0, points='all',
            scalemode='count', legendgroup='LSTM')
        fig3.add_trace(trace)
        y_c = mu_cpm[:, ii]
        x_c = ii+0.2
        trace = go.Violin(x=x_c*np.ones(k_clips),
            y=y_c, text=clip_names, width=0.5,
            side='positive', jitter=0.0, # remove all jitter
            box_visible=True, meanline_visible=True,
            marker_color=colors[1], name='CPM',
            showlegend=showlegend, pointpos=0, points='all',
            scalemode='count', legendgroup='CPM')
        fig3.add_trace(trace)

        # connectors: one line per clip from lstm to cpm
        # NOTE(review): the line gets the lstm colour when
        # cpm > lstm -- confirm this is the intended coding
        for jj in range(k_clips):
            y = [y_l[jj], y_c[jj]]
            color = colors[0] if y[1] > y[0] else colors[1]
            line = go.Scatter(x=[x_l, x_c],
                y=y, mode='lines', showlegend=False,
                line=dict(color=color, width=0.5))
            fig3.add_trace(line)

    fig3.update_xaxes(
        tickvals=bhv_ticks,
        ticktext=[pretty_bhv[bhv] for bhv in bhv_measures],
        tickfont=dict(size=16))
    fig3.update_yaxes(
        title=dict(text=score_label,
            font_size=20),
        dtick=0.2,
        gridwidth=0.8, gridcolor='#bfbfbf',
        tickfont=dict(size=20))
    fig3.update_layout(violingap=1,
        font_color='black',
        violinmode='group',
        height=500, width=620,
        legend_orientation='h',
        legend_font_size=20,
        legend=dict(x=0, y=-0.4))

    return fig1, fig2, fig3

In [None]:
# score to plot; 's' is the spearman correlation (see SCORES)
score = 's'

# fig1: bars per clip, fig2: lstm - cpm heatmap, fig3: violins
fig1, fig2, fig3 = _plot_clip_acc(r, args.select_bhv, clip_names, score)
fig1.show()

In [None]:
# heatmap of mean lstm - mean cpm score (bhv x clip)
fig2.show()

In [None]:
# fix the y-range of the violin plot before showing it
fig3.update_yaxes(range=[-0.5, 1.05])
fig3.show()

# Scatter of true vs predicted behavioral scores

In [None]:
def _scatter_score(r, bhv_measures, clip_names, 
    score, criteria='max'):
    '''
    scatter true and predicted scores

    for each bhv measure the best clip is picked separately
    for lstm and cpm (lowest score for 'mse', highest
    otherwise) and the predictions of all folds for that
    clip are plotted against the true values

    r: r[bhv]['lstm'] and r[bhv]['cpm'] results
    bhv_measures: bhv names; keys of r
    clip_names: clip names; keys of the global clip_y
    score: 'mse', 'p' or 's'
    criteria: how the lstm score of a clip is summarised
        when ranking clips
        'max': best value across time
        'mean': mean across time
        (the scattered lstm predictions always come from the
        time point with the best fold-averaged score)

    also reads the notebook global args.k_fold

    returns dict {bhv: plotly express scatter figure}
    raises ValueError for an unknown criteria
    '''
    if criteria not in ('max', 'mean'):
        raise ValueError(
            "criteria must be 'max' or 'mean', got %r" %(criteria,))

    lower_is_better = (score=='mse')

    def _best_idx(b):
        # index of the best clip; a nan score never wins
        if lower_is_better:
            return np.argmin(np.where(np.isnan(b), np.inf, b))
        return np.argmax(np.where(np.isnan(b), -np.inf, b))

    fig = {}
    k_clips = len(clip_names)
    for ii, bhv in enumerate(bhv_measures):

        a = r[bhv]
        b_l = np.zeros(k_clips)
        b_c = np.zeros(k_clips)

        # get best clip
        for jj, clip in enumerate(clip_names):

            c = clip_y[clip]
            # averaged across folds
            v = np.mean(a['lstm']['t_val_%s'%score][c], axis=0)
            if criteria=='max':
                # best score across time
                b_l[jj] = np.min(v) if lower_is_better else np.max(v)
            else:
                b_l[jj] = np.mean(v)

            v = a['cpm']['c_val_%s'%score][c]
            b_c[jj] = np.nanmean(v)

        best_l = clip_names[_best_idx(b_l)]
        best_c = clip_names[_best_idx(b_c)]
        c_l = clip_y[best_l]
        c_c = clip_y[best_c]
        title_text = ('<b>%s</b> <br> '%bhv +
            'lstm: <i>%s</i>, cpm: <i>%s</i>'%(best_l, best_c))

        # collect true and predicted scores
        y = []
        y_hat = []
        tag = []
        '''
        lstm
        '''
        # get best time index
        v = np.mean(a['lstm']['t_val_%s'%score][c_l], axis=0)
        if lower_is_better:
            tmax = np.argmin(v)
        else:
            tmax = np.argmax(v)

        for f in range(args.k_fold):
            # samples of fold f that belong to best clip 'c_l'
            # (clip labels taken at tmax)
            mask = a['lstm']['c'][f][:, tmax]==c_l

            # true values and predictions at tmax
            y_f = list(a['lstm']['y'][f][:, tmax][mask])
            y += y_f
            y_hat += list(a['lstm']['y_hat'][f][:, tmax][mask])
            tag += ['lstm']*len(y_f)
        '''
        cpm
        '''
        for f in range(args.k_fold):
            # samples of fold f that belong to best clip 'c_c'
            mask = a['cpm']['c'][f]==c_c

            y_f = list(a['cpm']['y'][f][mask])
            y += y_f
            y_hat += list(a['cpm']['y_hat'][f][mask])
            tag += ['cpm']*len(y_f)

        # df for px
        df = pd.DataFrame({'true':y, 'predicted':y_hat,
            'tag':tag})
        fig[bhv] = px.scatter(df, x='true',
            y='predicted', color='tag',
            trendline='ols')
        fig[bhv].update_layout(height=350, width=400,
            title_text=title_text)
        fig[bhv].update_xaxes(dtick=0.2)
        fig[bhv].update_yaxes(dtick=0.2)

    return fig

In [None]:
# 'p' is the pearson correlation (see SCORES); clips are ranked
# by their best score across time (criteria='max')
score = 'p'
fig = _scatter_score(r, args.select_bhv, clip_names, 
    score, criteria='max')
# show the scatter of a single bhv measure
bhv = 'PicVocab_Unadj'
fig[bhv].show()