In [None]:
from __future__ import division

import io
import os
import warnings

import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import scipy.stats as stats
import imageio

from scipy.stats import norm
from IPython.display import clear_output
from PIL import Image

#### Do you want to show figures in the notebook?
#### Do you want to save them?

In [None]:
# show_fig: display figures inline; save_fig: write figures (and GIFs) to disk.
show_fig, save_fig = False, True

In [None]:
# Assign paths where features can be found.
# The feature root defaults to the original hard-coded location; set the
# SKETCHLOOP_FEATURES_DIR environment variable to use a different copy of the data.
features_root = os.environ.get('SKETCHLOOP_FEATURES_DIR',
                               '/Volumes/ntb/projects/sketchloop02/data/features')
path_to_recog = os.path.join(features_root, 'recog')
path_to_draw = os.path.join(features_root, 'drawing')
roi_list = np.array(['V1Draw', 'V2Draw', 'LOCDraw', 'InsulaDraw', 'postCentralDraw',
                     'preCentralDraw', 'ParietalDraw', 'FrontalDraw', 'smgDraw'])

In [None]:
# Point the project helper module at this notebook's feature directories
import analysis_helpers
from importlib import reload
reload(analysis_helpers)  # re-execute the module so edits to analysis_helpers.py take effect

vars(analysis_helpers).update(path_to_recog=path_to_recog, path_to_draw=path_to_draw)

# Suppress warnings notebook-wide (all categories, plus FutureWarning explicitly).
# Note this also hides pandas/seaborn deprecation notices.
warnings.filterwarnings(action='ignore')
warnings.simplefilter('ignore', category=FutureWarning)

In [None]:
#### Helper data loader functions

def preprocess_recog(RECOG_METAS, RECOG_FEATS):
    """Filter recognition-run file names down to the expected naming pattern.

    A name is kept when its stem (everything before the first '.') splits into
    exactly four '_'-separated fields. Input order is preserved.

    Returns:
        (metas, feats): the filtered metadata and feature file-name lists.
    """
    def _four_fields(fname):
        return len(fname.split('.')[0].split('_')) == 4

    kept_metas = list(filter(_four_fields, RECOG_METAS))
    kept_feats = list(filter(_four_fields, RECOG_FEATS))
    return kept_metas, kept_feats

def extract_good_sessions(DRAW_METAS, DRAW_FEATS, subjects=None):
    """Keep only drawing-run files that belong to the given subjects.

    Metadata files are matched on their second '_'-field
    ('metadata_{sub}_drawing.csv'); feature files on their first
    ('{sub}_{roi}_featurematrix.npy').

    Args:
        DRAW_METAS: list of metadata file names.
        DRAW_FEATS: list of feature-matrix file names.
        subjects: iterable of subject IDs to keep. Defaults to the notebook-level
            ``sub_list`` (the original behaviour), which must then already be defined.

    Returns:
        (metas, feats): filtered lists, input order preserved.
    """
    if subjects is None:
        subjects = sub_list
    keep = set(subjects)  # set membership instead of repeated array scans
    metas = [m for m in DRAW_METAS if m.split('_')[1] in keep]
    feats = [f for f in DRAW_FEATS if f.split('_')[0] in keep]
    return metas, feats

def cleanup_df(df):
    """Return a new DataFrame without any column whose name contains 'Unnamed'.

    (Such columns are typically left over from a CSV written with its index.)
    """
    unnamed_cols = [col for col in df.columns if 'Unnamed' in col]
    return df.drop(columns=unnamed_cols)

def load_draw_meta(this_sub):
    """Read one subject's drawing-run metadata CSV and add a 'trial_num' column.

    Reads ``metadata_{this_sub}_drawing.csv`` from the global ``path_to_draw``.
    Rows are assumed to be ordered trial by trial: 40 trials x 23 rows each.
    The assignment fails unless the CSV has exactly 920 rows after cleanup.
    """
    n_trials, rows_per_trial = 40, 23
    csv_path = os.path.join(path_to_draw, 'metadata_{}_drawing.csv'.format(this_sub))
    meta = cleanup_df(pd.read_csv(csv_path))
    meta['trial_num'] = np.repeat(np.arange(n_trials), rows_per_trial)
    return meta
    
def load_draw_feats(this_sub, this_roi):
    """Load ``{this_sub}_{this_roi}_featurematrix.npy`` from ``path_to_draw`` and return it transposed."""
    npy_path = os.path.join(path_to_draw, '{}_{}_featurematrix.npy'.format(this_sub, this_roi))
    return np.load(npy_path).T

def load_draw_data(this_sub, this_roi):
    """Return (metadata DataFrame, feature matrix) for one subject and ROI.

    Asserts that the feature matrix has one row per metadata row.
    """
    meta = load_draw_meta(this_sub)
    feats = load_draw_feats(this_sub, this_roi)
    assert feats.shape[0] == meta.shape[0]
    return meta, feats

    
def objects(sub_list):
    """Return a DataFrame with one row per subject naming the object pair they drew.

    Columns:
        objects: the subject's two unique drawing labels (from their V1Draw
            metadata), sorted by np.unique and concatenated, e.g. 'bedchair'.
        subject: the subject ID.
    """
    records = []
    for this_sub in sub_list:
        DM, _ = load_draw_data(this_sub, 'V1Draw')
        trained_objs = np.unique(DM.label.values)  # sorted, so the key is alphabetical
        records.append({'objects': str(trained_objs[0]) + str(trained_objs[1]),
                        'subject': this_sub})
    # DataFrame.append was removed in pandas 2.0 (and was quadratic); build once instead.
    return pd.DataFrame(records, columns=['objects', 'subject'])

In [None]:
# LOAD FILE LIST FOR RECOGNITION RUNS
RECOG_METAS = sorted(fname for fname in os.listdir(path_to_recog) if fname.split('.')[-1] == 'csv')
RECOG_FEATS = sorted(fname for fname in os.listdir(path_to_recog) if fname.split('.')[-1] == 'npy')
# first '_'-separated field of each feature file name
RECOG_SUBS = np.array([fname.split('_')[0] for fname in RECOG_FEATS])
recog_sub_list = np.unique(RECOG_SUBS)

# keep only files whose stem has exactly four '_'-separated fields
RECOG_METAS, RECOG_FEATS = preprocess_recog(RECOG_METAS, RECOG_FEATS)

In [None]:
# LOAD FILE LIST FOR DRAWING RUNS
DRAW_METAS = sorted(fname for fname in os.listdir(path_to_draw) if fname.split('.')[-1] == 'csv')
DRAW_FEATS = sorted(fname for fname in os.listdir(path_to_draw) if fname.split('.')[-1] == 'npy')
# feature files are named '{subject}_{roi}_featurematrix.npy'
DRAW_SUBS = np.array([fname.split('_')[0] for fname in DRAW_FEATS])
draw_sub_list = np.unique(DRAW_SUBS)

In [None]:
## Subjects to analyse: every subject that has drawing-run feature files.
## NOTE(review): no check is made that the recognition data is also complete
## (e.g. against recog_sub_list) — confirm that is intended.
sub_list = draw_sub_list

In [None]:
## restrict the drawing file lists to sessions from subjects in sub_list
DRAW_METAS, DRAW_FEATS = extract_good_sessions(DRAW_METAS, DRAW_FEATS)

# feature files are named '{subject}_{roi}_featurematrix.npy'
DRAW_SUBS = np.array(['{}_neurosketch'.format(fname.split('_')[0]) for fname in DRAW_FEATS])
DRAW_ROIS = np.array([fname.split('_')[1] for fname in DRAW_FEATS])

In [None]:
## general plotting params: large 'poster' fonts and a 5-colour cubehelix palette
sns.set_context(context='poster')
colors = sns.color_palette(palette='cubehelix', n_colors=5)

In [None]:
# Group subjects by the pair of objects they drew. Keys are the two sorted labels
# concatenated, as produced by objects(). These arrays are used below to look up
# each subject's object pair when preprocessing the VGG probabilities.
sub_obj = objects(sub_list)

bedbench = np.array(sub_obj.loc[sub_obj['objects'] == 'bedbench', 'subject'])
bedchair = np.array(sub_obj.loc[sub_obj['objects'] == 'bedchair', 'subject'])
bedtable = np.array(sub_obj.loc[sub_obj['objects'] == 'bedtable', 'subject'])
benchchair = np.array(sub_obj.loc[sub_obj['objects'] == 'benchchair', 'subject'])
benchtable = np.array(sub_obj.loc[sub_obj['objects'] == 'benchtable', 'subject'])
chairtable = np.array(sub_obj.loc[sub_obj['objects'] == 'chairtable', 'subject'])

### BUILD MATRIX CONTAINING CLASSIFIER PROBABILITIES FOR ALL SUBJECTS, ROIS, DRAWING RUNS

In [None]:
# Re-import the helpers so the latest analysis_helpers.py is used
import analysis_helpers
from importlib import reload
reload(analysis_helpers)

logged = [True]          # each value is passed as logged=...; True -> 'logged' outputs, False -> 'raw'
versions = ['2wayDraw']

# Set to 0 to skip recomputing. NOTE: later cells use ALLDM and Acc from this cell.
really_run = 1
if really_run:
    for l in logged:
        tag = 'raw'
        if l:
            tag = 'logged'
        print(tag)
        for version in versions:
            roi_list_short, sub_list_short = roi_list, sub_list
            # train on recognition runs, test on drawing runs
            ALLDM, Acc = analysis_helpers.make_drawing_predictions(
                sub_list_short, roi_list_short, version=version, logged=l)
            ALLDM.to_csv('./logistic_timeseries_drawing_neural_{}_{}.csv'.format(version, tag))

In [None]:
# Turn accuracy into array, then a DataFrame with one column per ROI
# (ROI names with their trailing 'Draw' suffix removed).
Acc = np.array(Acc)
roi_reduce = [roi[:-len('Draw')] for roi in roi_list_short]
x = pd.DataFrame(Acc.T, columns=roi_reduce)

In [None]:
# Plot raw classifier accuracy per ROI, with a dotted line at chance for this version
sns.set_context('talk')
chance_dict = {'4way': 0.25, '3way': 0.33, '2way': 0.5, '2wayDraw': 0.5}

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(data=x, palette='Set2', ci=95, ax=ax)
ax.axhline(chance_dict[version], linestyle=':', color='k')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Intersect of Drawing Task Map and ROI')
ax.set_title('Object Classification during Drawing Runs')
# Lay out after the title is set so the title is not clipped.
fig.tight_layout()
if save_fig:
    fig.savefig('drawing_classification.pdf')
if show_fig:
    plt.show()
else:
    plt.close(fig)

## IMPORT AND PREPROCESS VGG PROBABILITIES

In [None]:
# Lookup tables keyed by object-pair name:
#   obdict[pair]  -> the two object names in that pair
#   subdict[pair] -> the subjects who drew that pair
types = ['bedbench', 'bedchair', 'bedtable', 'benchchair', 'benchtable', 'chairtable']
subs = [bedbench, bedchair, bedtable, benchchair, benchtable, chairtable]
# Named object_pairs (not `objects`) so the objects() function is not shadowed on re-runs.
object_pairs = [['bed', 'bench'], ['bed', 'chair'], ['bed', 'table'],
                ['bench', 'chair'], ['bench', 'table'], ['chair', 'table']]
obdict = dict(zip(types, object_pairs))
subdict = dict(zip(types, subs))

In [None]:
# Load VGG class probabilities for the partial sketches. For each subject, attach the
# log-probabilities of the two objects that subject drew as 't1_prob' / 't2_prob'.
VGG_feat = pd.read_csv('partial_sketch_full.csv')
VGG_feat['time_point'] = VGG_feat['numSketch'] + 1
# NOTE(review): the 320 offset presumably maps the VGG 'trial' index onto the 0-based
# trial_num used in the drawing metadata — confirm against the export.
VGG_feat['trial_num'] = VGG_feat['trial'] - 320
VGG_feat['subj'] = [wid.split('_')[0] for wid in VGG_feat['wID']]

per_subject = []
for subject in sub_list:
    matching_types = [typer for typer in types if subject in subdict[typer]]
    if len(matching_types) != 1:
        # Fail loudly: otherwise t1/t2 would be undefined or stale from the previous subject.
        raise ValueError('Subject {} should belong to exactly one object pair, found {}'
                         .format(subject, matching_types))
    t1, t2 = obdict[matching_types[0]]
    # .copy() so the new columns go into an independent frame, not a view of VGG_feat.
    sub_only = VGG_feat[VGG_feat['subj'] == subject].copy()
    sub_only['t1'] = t1
    sub_only['t2'] = t2
    sub_only['t1_prob'] = np.log(sub_only[t1])
    sub_only['t2_prob'] = np.log(sub_only[t2])
    per_subject.append(sub_only)

# Concatenate once rather than growing the frame inside the loop.
updated = pd.concat(per_subject)
updated.to_csv('VGG.csv')

## Plot correspondence between probabilities across regions - all time points

#### Choose whether you want a plot for every subject

In [None]:
# Do you want to plot this subject-wise?
# True: draw one correlation matrix per subject in the next cell (saved, and added to a GIF, when save_fig is True).
plot_subs = True

#### Cycle through subjects, and compute the correlations

In [None]:
#ALLDM = pd.read_csv('./logistic_timeseries_drawing_neural_2wayDraw_logged.csv')
#ALLDM = pd.read_csv('./logistic_timeseries_drawing_neural_2wayDraw_raw.csv')
#updated = pd.read_csv('./VGG.csv')
# (Uncomment the lines above to reuse previously saved CSVs instead of recomputing.)

# For each subject: build one vector per region (t1 and t2 log-probabilities concatenated),
# correlate every pair of regions, and stack the per-subject matrices along axis 2.

_sub_list = sub_list
_roi_list = ['V1Draw', 'V2Draw', 'LOCDraw', 'ParietalDraw', 'smgDraw',
             'postCentralDraw', 'preCentralDraw', 'FrontalDraw', 'VGG']
roi_name = ['V1', 'V2', 'LOC', 'Par', 'SMG', 'Sens', 'Mot', 'Front', 'VGG']

images = []
all_subs_all_rois = []
subject_corrs = []

for sub in _sub_list:
    sub_only = ALLDM[ALLDM['subj'] == sub]
    sub_only_vgg = updated[updated['subj'] == sub]
    drawpreds = []
    for roi in _roi_list:
        roi_only = sub_only_vgg if roi == 'VGG' else sub_only[sub_only['roi'] == roi]
        t1 = np.array(roi_only['t1_prob'])
        t2 = np.array(roi_only['t2_prob'])
        assert t1.shape == t2.shape
        drawpreds.append(np.hstack((t1, t2)))
    drawpreds = np.array(drawpreds)

    # correlation across all regions (rows of drawpreds) for one subject
    corrs = np.corrcoef(drawpreds)
    subject_corrs.append(corrs)

    if plot_subs:
        clear_output(wait=True)
        # Mask a copy for display so the stored correlations are left untouched.
        masked = corrs.copy()
        np.place(masked, masked > 0.9, np.nan)
        fig, ax = plt.subplots(figsize=(10, 10))
        im = ax.matshow(masked, vmin=-0.15, vmax=0.65)
        # Fix tick positions explicitly so each label sits on its own row/column.
        ax.set_xticks(range(len(roi_name)))
        ax.set_yticks(range(len(roi_name)))
        ax.set_xticklabels(roi_name)
        ax.set_yticklabels(roi_name)
        fig.colorbar(im, ax=ax)
        ax.set_title(str(sub), y=1.08)
        fig.tight_layout()
        if save_fig:
            fig.savefig(str(sub) + '.pdf')
            # Pillow cannot read PDFs, so render a PNG in memory for the GIF frame.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            images.append(np.array(Image.open(buf)))
        if show_fig:
            plt.show()
        else:
            plt.close(fig)

# Shape (n_rois, n_rois, n_subjects); dstack keeps the subject axis even for one subject.
all_corrs = np.dstack(subject_corrs)
t, p = stats.ttest_1samp(all_corrs, 0, axis=2)

# giffify results
if plot_subs and save_fig and images:
    imageio.mimsave('all_mats_subs.gif', images, duration=1)

#### Plot the average across subjects, first as a matrix, then as a bar plot with 95% CI error bars

In [None]:
# Plot matrices showing average across subjects

all_mean = np.mean(all_corrs, axis=2)
all_std = (np.std(all_corrs, axis=2)/np.sqrt(len(sub_list)))*1.96


fig, ax = plt.subplots(figsize=(10,10))
im = ax.matshow(np.absolute(all_mean), vmin=0, vmax = 0.5)
ax.set_xticklabels(['']+roi_name)
ax.set_yticklabels(['']+roi_name)
plt.colorbar(im)
colours = im.cmap(im.norm(all_mean))
if save_fig:
    plt.savefig('all_subjects.pdf')
plt.show() if show_fig else plt.close()


# Plot average across subjects as barplot with error bars

indices = np.arange(0,9)
fig, axes = plt.subplots(9,1,figsize=(10,10), sharey=True, sharex=True)
for i in range(all_mean.shape[0]):
    ax = axes[i]
    means = all_mean[i,:]
    sterr = all_std[i,:]
    ax.bar(indices, means, 1, color=colours[i], yerr=sterr, zorder=-1, edgecolor='black', lw=2, 
           ecolor='red', error_kw=dict(lw=2, capsize=2.5, capthick=1))
    ax.set_facecolor('white')
    ax.set_xticks(indices)
    ax.set_xticklabels(roi_name)
    ax.text(9.5, 0.25, roi_name[i], fontsize=17, fontname="Arial", verticalalignment='center', 
            horizontalalignment='center')
    ax.set_ylim([-0.05,0.5])
    plt.errorbar
plt.tight_layout()
if save_fig:
    plt.savefig('all_subjects_bar.pdf')
plt.show() if show_fig else plt.close()

## Plot correspondence between probabilities across regions - by timepoint

#### Choose what time unit to plot by, and whether you want a moving window, or to plot every unit

In [None]:
# Settings for the time-resolved correlation plots in the cell below.

# Do you want to plot by trial number (trial_num), or by TR (time_point)?
by = 'trial_num'

# Do you want to plot as a moving window, or (if False) by each unit of time?
moving = True

# Choose the size of the moving window (in units of `by`; forced to 1 when moving is False)
windowsize = 3

# Choose whether you want the windows to be non-overlapping/discrete
# (True: each window starts where the previous one ended)
discrete = False

#### This cell creates one correlation-matrix plot per time window and, if `save_fig` is True, combines them into an animated GIF

In [None]:
#ALLDM = pd.read_csv('./logistic_timeseries_drawing_neural_2wayDraw_logged.csv')
#ALLDM = pd.read_csv('./logistic_timeseries_drawing_neural_2wayDraw_raw.csv')

# For each window [point, point + windowsize) of `by`, compute per-subject region-by-region
# correlations of the t1/t2 log-probabilities, average across subjects, and plot.

if moving:
    # Discrete windows do not overlap; otherwise slide the window one unit at a time.
    # (Previously `step` was left undefined when moving=True and discrete=False.)
    step = windowsize if discrete else 1
else:
    step = 1
    windowsize = 1

# [low, high) is the range of values of `by` that is scanned
if by == 'trial_num':
    low, high = 0, 40
elif by == 'time_point':
    low, high = 1, 24
else:
    raise ValueError("by must be 'trial_num' or 'time_point', got {!r}".format(by))

_sub_list = sub_list
_roi_list = ['V1Draw', 'V2Draw', 'LOCDraw', 'ParietalDraw', 'smgDraw',
             'postCentralDraw', 'preCentralDraw', 'FrontalDraw', 'VGG']
roi_name = ['V1', 'V2', 'LOC', 'Par', 'SMG', 'Sens', 'Mot', 'Front', 'VGG']
images = []
all_all = []
for point in range(low, high - windowsize + 1, step):
    bottom, top = point, point + windowsize
    # rows whose time unit falls in the half-open window [bottom, top)
    __ALLDM = ALLDM[(ALLDM[by] >= bottom) & (ALLDM[by] < top)]
    __updated = updated[(updated[by] >= bottom) & (updated[by] < top)]
    subject_corrs = []
    for sub in _sub_list:
        sub_only = __ALLDM[__ALLDM['subj'] == sub]
        sub_only_vgg = __updated[__updated['subj'] == sub]
        drawpreds = []
        for roi in _roi_list:
            roi_only = sub_only_vgg if roi == 'VGG' else sub_only[sub_only['roi'] == roi]
            t1 = np.array(roi_only['t1_prob'])
            t2 = np.array(roi_only['t2_prob'])
            assert t1.shape == t2.shape
            drawpreds.append(np.hstack((t1, t2)))
        drawpreds = np.array(drawpreds)
        corrs = np.corrcoef(drawpreds)
        subject_corrs.append(corrs)
    # Shape (n_rois, n_rois, n_subjects); dstack keeps the subject axis even for one subject.
    all_corrs = np.dstack(subject_corrs)
    t, p = stats.ttest_1samp(all_corrs, 0, axis=2)
    clear_output(wait=True)
    all_collapse = np.mean(all_corrs, axis=2)
    np.place(all_collapse, all_collapse > 0.9, np.nan)  # hide the (self-correlation) diagonal
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.matshow(all_collapse, vmin=0, vmax=0.5)
    ax.set_xticks(range(len(roi_name)))
    ax.set_yticks(range(len(roi_name)))
    ax.set_xticklabels(roi_name)
    ax.set_yticklabels(roi_name)
    fig.colorbar(im, ax=ax)
    smoothing = ', win_size = {}'.format(windowsize) if windowsize > 1 else ''
    ax.set_title(str(by) + ': ' + str(point) + str(smoothing), y=1.08)
    fig.tight_layout()
    if save_fig:
        fig.savefig(str(point) + '.pdf')
        # Pillow cannot read PDFs, so render a PNG in memory for the GIF frame.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        images.append(np.array(Image.open(buf)))
    if show_fig:
        plt.show()
    else:
        plt.close(fig)

# giffify results
if save_fig and images:
    imageio.mimsave('all_mats_' + str(by) + '.gif', images, duration=1)