In [48]:
import os
import numpy as np
import glob
import csv
import matplotlib
import glob
import matplotlib.pyplot as plt
import pandas as pd
from pandas import DataFrame, Series
import seaborn as sns
import scipy
import scipy.io
from collections import Counter
import sys
import nibabel as nib
import pickle
from moss.mosaic import Mosaic
from scipy.optimize import minimize
from ipyparallel import Client
import operator
%matplotlib inline

In [21]:
home_dir ='/Users/ianballard/Dropbox/Decision Neuroscience Lab/Pilot_MRI_Data/jesse_scan/'

In [None]:
cmap = sns.cubehelix_palette(as_cmap=True, reverse=True,light=1, dark=0) #colormap for plotting

In [15]:
prefix = 's0'
scan_map = {'06':(0,30,0),#acquisition angle, TE, pepolar
            '07':(0,30,1),
            '08':(0,30,1),
            '12':(0,27,1),
            '36':(0,27,0),
            '37':(0,27,1),
            '13':(15,27,0),
            '14':(15,27,0),  
            '16':(15,27,0),
            '17':(15,27,1),
            '18':(15,27,1),
            '20':(15,30,0),
            '21':(15,30,1),
            '23':(15,30,1),
            '24':(15,30,0),
            '26':(30,30,1),
            '27':(30,30,0),  
            '29':(30,30,0),
            '30':(30,30,1),
            '31':(30,27,0),
            '32':(30,27,1),
            '33':(30,27,1),
            '34':(30,27,0),
           }

In [16]:
#motion correct data (to a common template (random slice taken from one of the runs))
for s in scan_map.keys():
    
    scan = home_dir + prefix + s + 'a1001.nii.gz'
    cmd = ['mcflirt','-in',scan,'-reffile',home_dir + '/slice_target.nii.gz']
    os.system(' '.join(cmd))

In [153]:
def find_spikes(d, spike_thresh):
    slice_mean = d.mean(axis=0).mean(axis=0)
    t_z = (slice_mean - np.atleast_2d(slice_mean.mean(axis=1)).T) / np.atleast_2d(slice_mean.std(axis=1)).T
    spikes = np.abs(t_z)>spike_thresh
    spike_inds = np.transpose(spikes.nonzero())
    spike_inds = [x[0] for x in spike_inds]
    return spike_inds, t_z

In [154]:
def process_data(scan):
    scan = home_dir + prefix + s + 'a1001_mcf.nii.gz'

    d = nib.load(scan).get_data()
    
    #drop T1 saturation
    d = d[...,2:]

    #find and remove spikes
    spike_inds, t_z = find_spikes(d, 3.5)
    mask = np.ones(d.shape[-1], dtype=bool)
    mask[spike_inds] = False
    d = d[...,mask]

    #compute mean and tsnr
    mean_signal = d.mean(axis=3)
    tsnr = mean_signal / d.std(axis=3)

    #compute a rough mask
    counts, bins = np.histogram(mean_signal[mean_signal > 0], 50)
    thresh = bins[np.diff(counts) > 0][0]
    mask = mean_signal>thresh

    return mean_signal,tsnr,mask

In [163]:
#get median tsnr for each scan as well load up all data
tsnr = {}
all_data = {}
for s in scan_map.keys():
    scan = home_dir + prefix + s + 'a1001_mcf.nii.gz'
    d1,tsnr1,mask1 = process_data(scan)
    all_data[s] = (d1,tsnr1,mask1)
    med_tsnr = np.median(tsnr1[mask1])
    tsnr[s] = med_tsnr



In [164]:
sorted_tsnr = sorted(tsnr.items(), key=operator.itemgetter(1))
sorted_tsnr = [(scan_map[x],y) for x,y in sorted_tsnr]
sorted_tsnr

[((15, 30, 1), 76.029144353909516),
 ((15, 30, 0), 78.779035202770302),
 ((15, 27, 0), 82.492650837849084),
 ((15, 27, 1), 83.147995923933919),
 ((30, 30, 0), 87.843098911206098),
 ((15, 27, 1), 91.827082361196986),
 ((30, 30, 0), 93.95820849713742),
 ((15, 27, 0), 96.662831134074551),
 ((15, 27, 0), 97.723920939552016),
 ((0, 27, 1), 99.761887608176835),
 ((30, 30, 1), 103.21623946665611),
 ((30, 27, 1), 109.63703485004439),
 ((30, 27, 0), 113.30422287110817),
 ((30, 27, 1), 117.90740931975849),
 ((30, 27, 0), 125.08152630661495),
 ((30, 30, 1), 125.19361100232072),
 ((15, 30, 1), 127.36821455890288),
 ((15, 30, 0), 131.20591567357221),
 ((0, 27, 1), 131.86249808201251),
 ((0, 27, 0), 134.03668886227422),
 ((0, 30, 1), 135.45174207312377),
 ((0, 30, 1), 143.55631709856956),
 ((0, 30, 0), 152.63841343239233)]

In [169]:
def make_diff_pairwise(params1,params2):
    
    tsnr_trg = []
    tsnr_cmp = []
    for s in scan_map.keys():
        if scan_map[s] == params1: #target scan
            d1,tsnr,mask1 = all_data[s]
            tsnr_trg.append(tsnr)
        if scan_map[s] == params2: #comparison scan
            d2,tsnr,mask2 = all_data[s]
            tsnr_cmp.append(tsnr) 
    
    #take mean tsnr across exemplars
    tsnr1 = np.stack(tsnr_trg)
    tsnr1 = np.mean(tsnr1,axis=0)

    tsnr2 = np.stack(tsnr_cmp)
    tsnr2 = np.mean(tsnr2,axis=0)

    #compute tsnr difference
    diff = (tsnr1 - tsnr2)/tsnr2
    diff = np.nan_to_num(diff)
    mask = np.ma.mask_or(mask1,mask2).astype(int)

    #plot and save
    m = Mosaic(d2,stat = diff,mask = mask, step=1)
    m.plot_activation(thresh = .05,vmin=.05,vmax = 1,neg_cmap='Blues')
    m.savefig(home_dir + '/figs/pairwise/angle_' + str(params1[0]) + '_TE_' +
              str(params1[1]) + '_pe_'+  str(params1[2]) +
              '_grtr_than_angle_' + str(params2[0]) + '_TE_' +
              str(params2[1]) + '_pe_'+  str(params2[2]) + '.png')

In [171]:
def make_diff(params):
    #Average TSNR of target parameters as well as TSNR of all other parameters
    tsnr_trg = []
    tsnr_others = []
    for s in scan_map.keys():
        if scan_map[s] == params: #target scan
            d1,tsnr,mask1 = all_data[s]
            tsnr_trg.append(tsnr)
        else:
            d2,tsnr,mask2 = all_data[s]
            tsnr_others.append(tsnr)        

    #take mean tsnr across exemplars
    tsnr1 = np.stack(tsnr_trg)
    tsnr1 = np.mean(tsnr1,axis=0)

    tsnr2 = np.stack(tsnr_others)
    tsnr2 = np.mean(tsnr2,axis=0)

    #compute tsnr difference
    diff = (tsnr1 - tsnr2)/tsnr2
    diff = np.nan_to_num(diff)
    mask = np.ma.mask_or(mask1,mask2).astype(int)

    #plot and save
    m = Mosaic(d2,stat = diff,mask = mask, step=1)
    m.plot_activation(thresh = .05,vmin=.05,vmax = 1,neg_cmap='Blues')
    m.savefig(home_dir + '/figs/angle_' + str(params[0]) + '_TE_' +
              str(params[1]) + '_pe_'+  str(params[2]) + '_grtr_than_mean.png')

In [4]:
for p in set(scan_map.values()):
    make_diff(p)

In [3]:
params = (30,27,1)
for p in set(scan_map.values()):
    if p != params:
        make_diff_pairwise(params,p)

In [2]:
params = (0,30,0)
for p in set(scan_map.values()):
    if p != params:
        make_diff_pairwise(params,p)

In [1]:
params = (0,27,0)
for p in set(scan_map.values()):
    if p != params:
        make_diff_pairwise(params,p)