## Setup

In [None]:
import warnings
warnings.filterwarnings("ignore")

from __future__ import division

import numpy as np
import os
from glob import glob

from PIL import Image
from copy import deepcopy

from IPython.display import clear_output

from sklearn import linear_model, datasets, neighbors
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn import svm

%matplotlib inline
from scipy.misc import imread, imresize
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
sns.set_context('poster')
colors = sns.color_palette("cubehelix", 5)
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

import pandas as pd

import scipy.stats as stats
from scipy.stats import norm
import sklearn
import itertools

from importlib import reload



### Define paths to data

In [None]:
# Directory the notebook is running from; the project root is taken to be two
# levels above it (the expression below is echoed as the cell output).
curr_dir = os.getcwd()
os.path.abspath(os.path.join(curr_dir, os.pardir, os.pardir))

In [None]:
# Define the ROIs (region-of-interest labels) used by the connectivity analyses.
# The same array is handed to `utils` and to the prediction/plotting helpers below.
roi_list_connect = np.array(
    'V1Draw V2Draw LOCDraw '
    'InsulaDraw postCentralDraw preCentralDraw '
    'ParietalDraw FrontalDraw smgDraw'.split()
)

In [None]:
## root paths, all relative to the notebook's working directory (`curr_dir`)
## so the project can be moved; data used to live at an absolute path like 'D:\\data'
proj_dir = os.path.abspath(os.path.join(curr_dir, '..', '..'))
data_dir = os.path.abspath(os.path.join(curr_dir, '..', '..', 'data'))
results_dir = os.path.join(proj_dir, 'csv')
nb_name = '4_connectivity_pattern_during_drawing'

## add the project's `python/` helpers folder to the import path (only once)
import sys
helpers_dir = os.path.join(proj_dir, 'python')
if helpers_dir not in sys.path:
    sys.path.append(helpers_dir)

## module definitions
import utils
# Re-execute utils so edits to the helper module are picked up without a kernel restart.
reload(utils)
# Configure module-level attributes on utils; presumably read by its helpers
# (path_to_connect is read back below) -- confirm against utils.py.
utils.data_dir = data_dir
# Build each path component with os.path.join rather than a hard-coded '/'
# so the separator is consistent on every platform.
utils.path_to_connect = os.path.join(data_dir, 'features', 'connect')
utils.roi_list_connect = roi_list_connect

### Get file list

In [None]:
## get the raw connectivity file list (metadata .csv + feature .npy per subject)
## NOTE(review): the original comment said "recognition runs", but these files feed
## the drawing/production analysis below -- confirm which runs they cover.
path_to_connect = utils.path_to_connect

# List the directory once so metadata and feature lists come from the same snapshot.
connect_files = os.listdir(path_to_connect)

# Match on the real file extension; `name.split('.')[-1]` would also accept an
# extension-less file literally named 'csv' or 'npy'.
CONNECT_METAS = sorted(f for f in connect_files if os.path.splitext(f)[1] == '.csv')
CONNECT_FEATS = sorted(f for f in connect_files if os.path.splitext(f)[1] == '.npy')
# Subject ID is the filename prefix before the first underscore.
CONNECT_SUBS = np.array([f.split('_')[0] for f in CONNECT_FEATS])

sub_list = np.unique(CONNECT_SUBS)
if len(sub_list) == 0:
    # Later cells silently do nothing (or fail obscurely) without subjects; say why.
    print('WARNING: no .npy connectivity feature files found in {}'.format(path_to_connect))

In [None]:
# Report how many distinct subjects have connectivity feature files.
print('We have data from {n} subjects.'.format(n=len(sub_list)))

### DRAWING: How accurately can we classify the drawn target during production runs from the connectivity patterns across trials?

In [None]:
reload(utils)
version = 'phase'  # 'phase' or 'allruns'
logged = True
# NOTE(review): `logged` is not encoded in the output filenames below, so a run with
# logged=False overwrites the results of a run with logged=True (and vice versa).

# NOTE(review): only the first 3 subjects are analysed. This looks like a debugging
# leftover; set `subs_to_run = sub_list` to include every subject.
subs_to_run = sub_list[:3]

# Output paths shared by the save and load branches so the two cannot drift apart.
acc_path = os.path.join(results_dir,
                        'connectivity_{}_accuracy_production.npy'.format(version))
logprobs_path = os.path.join(results_dir,
                             'connectivity_{}_logprobs_production.csv'.format(version))

really_run = True  # True: recompute and save; False: load previously saved results
if really_run:
    ALLDM, Acc = utils.make_drawing_connectivity_predictions(subs_to_run, roi_list_connect,
                                                             version=version, logged=logged)
    ## save out ALLDM & Acc; create the output folder first so saving does not
    ## fail with FileNotFoundError on a fresh checkout
    Acc = np.array(Acc)
    os.makedirs(results_dir, exist_ok=True)
    np.save(acc_path, Acc)
    ALLDM.to_csv(logprobs_path, index=False)
else:
    ## load in existing ALLDM & Acc
    Acc = np.load(acc_path)
    ALLDM = pd.read_csv(logprobs_path)

print('Done!')

### Make summary timecourse plots

In [None]:
reload(utils)
versions = ['phase']
# NOTE(review): `tag` is never passed to the plotting helper, so additional tags
# would only repeat identical plots.
tags = ['logged']
iv_list = ['run_num']
plotType = 'bar'

reallyRun = 1
if reallyRun:
    for version in versions:
        # The results file depends only on `version`: read it once per version
        # instead of once per (tag, iv) combination.
        ALLDM_version = pd.read_csv(os.path.join(results_dir,
                                                 'connectivity_{}_logprobs_production.csv'.format(version)))
        for tag in tags:
            for iv in iv_list:
                # Give each plot a fresh copy (as the original per-plot re-read did)
                # in case the helper modifies the frame it is given.
                ALLDM = ALLDM_version.copy()
                try:
                    utils.plot_connect_timecourse(ALLDM,
                                                  this_iv=iv,
                                                  roi_list=roi_list_connect,
                                                  render_cond=0,
                                                  version=version,
                                                  proj_dir=proj_dir,
                                                  plotType=plotType)
                except KeyError as err:
                    # Best effort: skip plots whose columns are missing, but report
                    # the skip instead of swallowing it silently.
                    print('Skipping {} plot for version={!r}, iv={!r}: missing key {}'.format(
                        plotType, version, iv, err))