In [None]:
import numpy as np
import tensorflow as tf
import t3f
tf.set_random_seed(0)
np.random.seed(0)
%matplotlib inline
import matplotlib.pyplot as plt
import metric_util as mt
import data_util as du
from t3f import shapes
from nilearn import image
import nibabel as nib
from math import sqrt
import metric_util


# Tensor completion

In this example we will see how we can do tensor completion with t3f, i.e. observe a fraction of the values in a tensor and recover the rest by assuming that the original tensor has a low TT-rank.
Mathematically it means that we have a binary mask $P$ and a ground truth tensor $A$, but we observe only a noisy and sparsified version of $A$: $P \odot (\hat{A})$, where $\odot$ is the elementwise product (applying the binary mask) and $\hat{A} = A + \text{noise}$. In this case our task reduces to the following optimization problem:
\begin{equation*}
\begin{aligned}
& \underset{X}{\text{minimize}} 
& & \|P \odot (X - \hat{A})\|_F^2 \\
& \text{subject to} 
& & \text{TT-rank}(X) \leq r_0
\end{aligned}
\end{equation*}



### Generating the problem instance
Let's load the ground-truth tensor $A$ (a 4D subject scan read from disk), then generate the noise and the mask $P$.

In [None]:
# Load the subject's scan. The same file is read twice: `x_true_org` is later
# rebound to a single volume, while `x_true_org1` keeps the full image.
subject_scan_path = du.get_full_path_subject1()
print "Subject Path: " + str(subject_scan_path)
x_true_org = mt.read_image_abs_path(subject_scan_path)
#x_true_org = image.index_img(x_true_org,1)
x_true_org1 = mt.read_image_abs_path(subject_scan_path)

In [None]:
# Raw voxel array of the scan; used below as the ground-truth tensor A.
x_true_img = np.array(x_true_org.get_data())

In [None]:
# Ground-truth tensor A: the scan loaded above, cast to float32 so it has the
# same dtype as the float32 TF noise and mask tensors it is combined with below.
ground_truth = x_true_img.astype(np.float32)
# Problem shape taken from the data itself rather than hard-coded.
shape = ground_truth.shape
# Fix random seed so the results are comparable between runs.
tf.set_random_seed(0)
# Store the data in a non-trainable variable so the optimizer never updates it.
ground_truth = tf.get_variable('ground_truth', initializer=ground_truth, trainable=False)
# Additive Gaussian noise with standard deviation 1e-2 (fixed once, stored in a variable).
noise = 1e-2 * tf.get_variable('noise', initializer=tf.random_normal(shape), trainable=False)
noisy_ground_truth = ground_truth + noise
# Observe ~80% of the tensor values: each entry is kept independently with probability 0.8.
sparsity_mask = tf.cast(tf.random_uniform(shape) <= 0.80, tf.float32)
sparsity_mask = tf.get_variable('sparsity_mask', initializer=sparsity_mask, trainable=False)
sparse_observation = noisy_ground_truth * sparsity_mask

### Initialize the variable and compute the loss

In [None]:
def frobenius_norm_tf(x):
    """Frobenius (L2) norm of a TensorFlow tensor of any rank.

    Squares every element, sums over all axes and raises the result to the
    power 0.5, giving a scalar tensor.
    """
    squared_sum = tf.reduce_sum(x ** 2)
    return squared_sum ** 0.5

In [None]:
def relative_error1(x_hat,x_true):
    """Relative Frobenius error ||x_hat - x_true||_F / ||x_true||_F for TF tensors."""
    residual_norm = frobenius_norm_tf(x_hat - x_true)
    reference_norm = frobenius_norm_tf(x_true)
    return residual_norm / reference_norm

In [None]:
# Number of observed entries and total number of entries (not used below).
observed_total = tf.reduce_sum(sparsity_mask)
total = np.prod(shape)
# Hard-coded rank vector; not used below.
ranks_a = np.array([53,63,46,144,1])
# TT-SVD of the raw data (max TT-rank 144). Only its TT-ranks are used: they are
# fetched and printed in the training loop.
tt_with_ranks = t3f.to_tt_tensor(x_true_img, max_tt_rank=144)
ranks = shapes.tt_ranks(tt_with_ranks)
# Trainable TT-tensor estimate, initialised from t3f.random_tensor with TT-rank 10.
initialization = t3f.random_tensor(shape, tt_rank=10)
estimated = t3f.get_variable('estimated', initializer=initialization)
# Build the dense estimate once and reuse the same op in every metric below.
estimated_full = t3f.full(estimated)
# Training loss: squared error on the observed cells, normalised by the squared
# Frobenius norm of the observations (a relative squared error, not an MSE).
# The denominator sums squared entries; it previously squared the plain sum.
loss = (tf.reduce_sum((sparsity_mask * estimated_full - sparse_observation) ** 2)
        / tf.reduce_sum(sparse_observation ** 2))
# Test loss: the same relative squared error against the full, noise-free ground truth.
test_loss = (tf.reduce_sum((estimated_full - ground_truth) ** 2)
             / tf.reduce_sum(ground_truth ** 2))
# Relative Frobenius error ||X - A||_F / ||A||_F (the square root of test_loss).
rel_error1 = relative_error1(estimated_full, ground_truth)

In [None]:
# Display the tensor shape used for the completion problem.
shape

# SGD optimization
The simplest way to solve the optimization problem is Stochastic Gradient Descent: let TensorFlow differentiate the loss w.r.t. the factors (cores) of the TensorTrain decomposition of the estimated tensor and minimize the loss with your favourite SGD variation (Adam is used below).

In [None]:
# Adam on the trainable variables (the TT-cores of `estimated`; the data, noise
# and mask variables are non-trainable). epsilon is far below Adam's default,
# presumably to cope with very small gradients of the normalised loss — TODO confirm.
optimizer = tf.train.AdamOptimizer(learning_rate=0.01, epsilon=1e-18)
step = optimizer.minimize(loss)

In [None]:
# Run the optimisation. Each step also evaluates the training loss, the test
# loss against the noise-free ground truth, the relative Frobenius error and `ranks`.
# NOTE(review): `ranks` are the TT-ranks of the TT-SVD of the raw data, not of
# `estimated`, so they are the same on every iteration.
# NOTE(review): only 2 iterations are run, which looks like a debugging value.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
train_loss_hist = []
test_loss_hist = []
for i in range(2):
    _, tr_loss_v, test_loss_v, rel_error1_v, ranks_v = sess.run([step, loss, test_loss,rel_error1, ranks])
    train_loss_hist.append(tr_loss_v)
    test_loss_hist.append(test_loss_v)
    print(i, tr_loss_v, test_loss_v, rel_error1_v, ranks_v)
    #if i % 1000 == 0:
     #   print(i, tr_loss_v, test_loss_v, rel_error1_v)

In [None]:
# Train and test loss histories on log-log axes.
fig, ax = plt.subplots()
for history, curve_label in ((train_loss_hist, 'train'), (test_loss_hist, 'test')):
    ax.loglog(history, label=curve_label)
ax.set_xlabel('Iteration')
ax.set_ylabel('MSE Loss value')
ax.set_title('SGD completion')
ax.legend()


In [None]:
#ground_truth_var = t3f.get_variable('ground_truth', initializer=ground_truth, reuse=True)

In [None]:
#ground_truth.read_value()
# Read the ground-truth variable's value out of the session as a numpy array.
ground_truth_val = ground_truth.eval(session=sess)

In [None]:
# Materialise the current TT estimate as a full (dense) numpy tensor.
estimated_val = sess.run(t3f.full(estimated))

In [None]:
def relative_error(x_hat,x_true):
    """Relative error of an estimate in the Frobenius norm.

    Returns ||x_hat - x_true|| / ||x_true||. ``np.linalg.norm`` with neither
    ``ord`` nor ``axis`` returns the 2-norm of the flattened array, so inputs
    of any rank are accepted.
    """
    numerator = np.linalg.norm(x_hat - x_true)
    denominator = np.linalg.norm(x_true)
    return numerator / denominator

In [None]:
# Relative Frobenius error of the final estimate (numpy counterpart of rel_error1).
rel_error = relative_error(estimated_val,ground_truth_val)

In [None]:
rel_error

In [None]:
#ten_ones = np.ones_like(mask)
#x_reconstr = mt.reconstruct(x_hat,x_true, ten_ones, mask)

In [None]:
estimated_val.shape

In [None]:
# Re-import; `image` was already imported in the first cell.
from nilearn import image

In [None]:
# Re-sets `shape` to the hard-coded scan shape.
shape = (53,63,46,144)

In [None]:
# Observed (masked, noisy) tensor as a numpy array.
sparse_observation_val=sparse_observation.eval(session=sess)

In [None]:
# Presumably wraps the observed array into an image using x_true_org's affine
# (project helper mt.reconstruct_image_affine) — confirm.
x_miss_img = mt.reconstruct_image_affine(x_true_org, sparse_observation_val)

In [None]:
#x_miss = x_miss_img
x_miss = image.index_img(x_miss_img,1)

In [None]:
x_hat_img = mt.reconstruct_image_affine(x_true_org, estimated_val)


In [None]:
#x_hat = x_hat_img
x_hat = image.index_img(x_hat_img,1)

In [None]:
from nilearn import plotting

In [None]:
# NOTE(review): this rebinds `x_true_org` from the full scan to a single volume
# (index 1). Re-running the reconstruction cells above after this one would use
# the single volume as the reference image.
x_true_org = image.index_img(x_true_org,1)

In [None]:
recovered_image = plotting.plot_epi(x_hat, bg_img=None,black_bg=True, cmap='jet', cut_coords=None) 

In [None]:
x_miss_image = plotting.plot_epi(x_miss, bg_img=None,black_bg=True, cmap='jet', cut_coords=None) 

In [None]:
import ellipsoid_masker as elpm
import ellipsoid_mask as em

In [None]:
def create_corrupted_image(x0,y0,z0, x_r, y_r):
    """Placeholder for building a corrupted image around (x0, y0, z0).

    Not implemented: all arguments are ignored and ``None`` is returned.
    """
    return None

In [None]:
def generate_structural_missing_pattern(x0,y0,z0, x_r, y_r, frames_count, path_folder, z_r=None):
    """Corrupt the leading volumes of subject 1's scan with an ellipsoid-shaped
    missing region and save the resulting 4D image.

    Parameters
    ----------
    x0, y0, z0 : ellipsoid centre, passed unchanged to
        ``elpm.create_ellipsoid_mask`` and ``em.EllipsoidMask``
        (presumably world/mm coordinates — TODO confirm).
    x_r, y_r : ellipsoid radii along x and y.
    frames_count : number of leading volumes (indices ``0 .. frames_count-1``)
        to corrupt.
    path_folder : folder in which the per-volume masked images are written.
    z_r : ellipsoid radius along z. Earlier versions had no such parameter and
        read a module-level ``z_r``. That fallback is kept when ``z_r`` is
        omitted, so existing calls keep working.

    Returns
    -------
    The corrupted 4D image, also saved to ``du.corrupted4D_10_frames_path()``.
    """
    if z_r is None:
        # Backward compatibility: fall back to the notebook-level ``z_r``.
        if 'z_r' not in globals():
            raise NameError("z_r was not passed and no module-level z_r is defined")
        z_r = globals()['z_r']

    subject_scan_path = du.get_full_path_subject1()

    print("3D Random Missing Value Pattern Simulations has started...")
    print("Subject Path: " + str(subject_scan_path))
    print("===Type 1 Experiments====")

    # Volume index -> corrupted volume. The assembly below then does not depend
    # on the corrupted indices being contiguous from 0.
    corrupted_volumes = {}
    for i in range(frames_count):
        file_stem = path_folder + "/" + "size_" + str(x_r) + "_" + str(y_r) + "_" + str(z_r) + "_scan_" + str(i)
        target_img = image.index_img(subject_scan_path, i)
        image_masked_by_ellipsoid = elpm.create_ellipsoid_mask(x0, y0, z0, x_r, y_r, z_r, target_img, file_stem)

        # Presumably create_ellipsoid_mask writes "<stem>.nii", which is re-read
        # here — confirm against elpm.
        ellipsoid = em.EllipsoidMask(x0, y0, z0, x_r, y_r, z_r, file_stem + ".nii")
        ellipsoid_volume = ellipsoid.volume()
        observed_ratio = mt.compute_observed_ratio(image_masked_by_ellipsoid)

        corrupted_volumes[i] = image_masked_by_ellipsoid
        print("Ellipsoid Volume: " + str(ellipsoid_volume) + "; Missing Ratio: " + str(observed_ratio))

    # Rebuild the 4D image: corrupted volumes where available, originals elsewhere.
    volumes_list = []
    for counter, img in enumerate(image.iter_img(subject_scan_path)):
        print("Volume Index: " + str(counter))
        if counter in corrupted_volumes:
            print("Adding corrupted volume to the list " + str(counter))
            volumes_list.append(corrupted_volumes[counter])
        else:
            print("Adding normal volume to the list " + str(counter))
            volumes_list.append(img)

    # now generate corrupted 4D from the list
    corrupted4d_10 = image.concat_imgs(volumes_list)
    print("Corrupted 4D - 10 frames: " + str(corrupted4d_10))
    observed_ratio4D_10 = mt.compute_observed_ratio(corrupted4d_10)
    print("Corrupted 4D - 10 Volume: " + "; Missing Ratio: " + str(observed_ratio4D_10))
    # NOTE(review): the output path helper is named for 10 frames regardless of frames_count.
    corr_file_path4D = du.corrupted4D_10_frames_path()
    nib.save(corrupted4d_10, corr_file_path4D)
    return corrupted4d_10

In [None]:
def get_xyz(i, j, k, epi_img):
    """Map voxel indices ``(i, j, k)`` through the image affine.

    Applies the 3x3 linear part of ``epi_img.affine`` to ``[i, j, k]`` and
    adds the translation column, i.e. ``A[:3, :3] . [i, j, k] + A[:3, 3]``
    (scanner/world coordinates).
    """
    affine = epi_img.affine
    linear_part = affine[:3, :3]
    translation = affine[:3, 3]
    return np.dot(linear_part, [i, j, k]) + translation

In [None]:
# Voxel indices of the volume centre; not used below.
coord = [26,31,23]
path_folder3D = "/work/el/3D"

In [None]:
# First volume of the scan as a 3D image.
# NOTE(review): `path_folder3D` looks like a directory and has no .nii
# extension; confirm nib.save writes where intended.
x_true_org3D = image.index_img(subject_scan_path,0)
nib.save(x_true_org3D,path_folder3D)

In [None]:
# Affine-mapped coordinates of voxel (10, 7, 4).
x_coord  = get_xyz(10,7,4, x_true_org3D)

In [None]:
x_coord

In [None]:
# Affine-mapped coordinates of the central voxel (26, 31, 23).
x_coord_center  = get_xyz(26,31,23, x_true_org3D)

In [None]:
x_coord_center

In [None]:
path_folder3D = "/work/el/3D"

In [None]:
# Output folder for the per-volume masked images.
path_folder = "/work/el/75"

In [None]:
# Ellipsoid centre (x0, y0, z0) and radii (x_r, y_r, z_r) for the experiments;
# the commented lines are earlier alternatives.
#x0, y0, z0 = (-10, -20,17)
#x0, y0, z0 = (-5, -20,17)
#x0, y0, z0 = (2, 32,22)
#x_r, y_r, z_r = (20,17,15)
x0, y0, z0 = (2, 32,22)
# size 1
#x_r, y_r, z_r = (7,10,8)
# size 2
#x_r, y_r, z_r = (9,10,8)
#size 3
x_r, y_r, z_r = (12,10,8)

In [None]:
# Corrupt only the first volume (frames_count=1). This rebinds `x_miss_img`,
# which earlier held the reconstructed sparse observation.
x_miss_img = generate_structural_missing_pattern(x0,y0,z0, x_r, y_r, 1, path_folder)

In [None]:
x_miss_img

In [None]:
x_miss_img_data = np.array(x_miss_img.get_data())

In [None]:
# Ellipsoid mask on the 3D reference volume (project helper elpm).
# NOTE(review): here the radii come before the centre, the opposite order of
# elpm.create_ellipsoid_mask(x0, y0, z0, x_r, y_r, z_r, ...); confirm.
mask = elpm.ellipsoid_masker(x_r, y_r, z_r, x0, y0, z0, x_true_org3D)

In [None]:
mask_data = np.array(mask.get_data())

In [None]:
mask.shape

In [None]:
mask_data.shape

In [None]:
def generate_system_noise_roi_mask(img, snr_db, mask):
    """Add Rician-distributed noise to the voxels of ``img`` inside ``mask``.

    The noise standard deviation is the mean signal over ``mask > 0`` divided
    by the amplitude SNR ``sqrt(10 ** (snr_db / 10))``. Two independent
    Gaussian fields (each scaled by 1/sqrt(2)) act as real and imaginary noise,
    and the magnitude ``sqrt((data + n_re)**2 + n_im**2)`` is taken.

    Parameters
    ----------
    img : object exposing ``get_data()`` (e.g. a NIfTI image).
    snr_db : target SNR in decibels.
    mask : array with the same shape as the image data.

    Returns
    -------
    (data, im_noise):
        data     -- copy of the image data, with every voxel where ``mask != 0``
                    and the noisy magnitude is positive replaced by that magnitude.
                    NOTE(review): keeps the input dtype, so integer data is truncated.
        im_noise -- noisy magnitude, set to 0 wherever ``mask == 0``.
    """
    amplitude_snr = sqrt(np.power(10.0, snr_db / 10.0))
    print ("snr: " + str(amplitude_snr))
    data = np.array(img.get_data())
    roi_mean = data[mask > 0].reshape(-1).mean()
    sigma_n = roi_mean / amplitude_snr
    print ("sigma_n: " + str(sigma_n))
    # The real component is drawn first, then the imaginary one (same RNG order).
    real_part = data + np.random.normal(size=data.shape, scale=sigma_n) / sqrt(2.0)
    imag_part = np.random.normal(size=data.shape, scale=sigma_n) / sqrt(2.0)
    im_noise = np.sqrt(real_part ** 2 + imag_part ** 2)
    im_noise[mask == 0] = 0
    positive = im_noise > 0
    data[positive] = im_noise[positive]
    return data, im_noise

In [None]:
def create_noisy_image(x, snr_db, mask):
    """Return ``(noisy image, noise-only image)`` for the ROI noise model.

    Both arrays returned by ``generate_system_noise_roi_mask`` are turned back
    into images with ``mt.reconstruct_image_affine``, using ``x`` as reference.
    """
    noisy_data, roi_noise = generate_system_noise_roi_mask(x, snr_db, mask)
    return tuple(mt.reconstruct_image_affine(x, arr) for arr in (noisy_data, roi_noise))

In [None]:
# `generate_system_noise` is not defined anywhere in this notebook; use the ROI
# noise generator. It returns (noisy_data, noise_only). The cells below treat
# `noisy_roi` as noise values inside the ROI (zero outside), so keep the second element.
noisy_roi_full_data, noisy_roi = generate_system_noise_roi_mask(x_true_org3D, 2, mask_data)

In [None]:
data_x = np.array(x_true_org3D.get_data())

In [None]:
# Scale the volume to unit Frobenius norm.
data1 = data_x*(1./np.linalg.norm(data_x))

In [None]:
#1./np.linalg.norm(self.x_init)

In [None]:
idxs = np.where(noisy_roi > 0)

In [None]:
# In-place: zero the ROI of the normalised volume (mutates data1).
data1[mask_data == 1] = 0 

In [None]:
# NOTE(review): data1 is unit-norm; check that noisy_roi is on the same scale.
ss = data1 + noisy_roi

In [None]:
# NOTE(review): the ellipsoid radii are used here as voxel indices — confirm intent.
data1[x_r +2 , y_r +2, z_r +2]

In [None]:
noisy_roi_img = mt.reconstruct_image_affine(x_true_org3D, noisy_roi)

In [None]:
# Noisy 3D volume and noise-only image at snr_db=2 inside the ellipsoid mask.
x_n, noise_mask = create_noisy_image(x_true_org3D, 2, mask_data)

In [None]:
ss[x_r +2 , y_r +2, z_r +2]

In [None]:
disp = plotting.plot_img(x_true_org3D, bg_img=None,black_bg=True, cmap='jet', cut_coords=[x0, y0, z0]) 
#disp.add_contours(mask, levels=[0.1, 0.3, 0.4, 0.5], filled=False, colors='b')
# Overlay the noise-only image on the original volume.
disp.add_overlay(noise_mask, alpha = 0.7)

In [None]:
# Build a 4D image in which volume 0 is replaced by the noisy volume `x_n` and
# every other volume is kept unchanged. (The counter used to be incremented only
# in the `else` branch, so it stayed 0 and every volume was replaced by `x_n`.)
volumes_list = []
for counter, img in enumerate(image.iter_img(subject_scan_path)):
    print("Volume Index: " + str(counter))
    if counter == 0:
        print("Adding corrupted volume to the list " + str(counter))
        volumes_list.append(x_n)
    else:
        print("Adding normal volume to the list " + str(counter))
        volumes_list.append(img)

# now generate corrupted 4D from the list
x_corr_img = image.concat_imgs(volumes_list)

In [None]:
# Raw voxel array of the corrupted 4D image.
mask_img_data = np.array(x_corr_img.get_data())


In [None]:
print mask_img_data.shape

In [None]:
print mask_indices_img

In [None]:
mask_indices = np.ones_like(mask_img_data)
mask_indices[mask_img_data == 0] = 0.0
mask_indices_img = mt.reconstruct_image_affine(x_true_org1, mask_indices)

In [None]:
from nilearn.masking import apply_mask
# Time series (timepoints x voxels) of the corrupted 4D image inside the ellipsoid mask.
masked_data = apply_mask(x_corr_img, mask)

In [None]:
print x_corr_img

In [None]:
masked_data.shape

In [None]:
import matplotlib.pyplot as plt
# Plot the time series of the first 25 masked voxels.
plt.figure(figsize=(7, 5))
# NOTE(review): slices up to 148 timepoints, but the x-axis is limited to 0-143.
plt.plot(masked_data[:148, :25])
plt.xlabel('Time [TRs]', fontsize=16)
plt.ylabel('Intensity', fontsize=16)
plt.xlim(0, 143)
plt.subplots_adjust(bottom=.12, top=.95, right=.95, left=.12)


In [None]:
x_miss_img_data.shape

In [None]:
# `x_true_org` holds a single volume at this point (it was rebound with index_img earlier).
print x_true_org

In [None]:
target_img = image.index_img(x_true_org1,0)

In [None]:
masked_img_file_path  = "/work/el" + "/" + "size_" + str(x_r) + "_" + str(y_r) + "_" + str(z_r) + "_scan_" + str(i)

In [None]:
masked_img_file_path

In [None]:
image_masked_by_ellipsoid = elpm.create_ellipsoid_mask(x0, y0, z0, x_r, y_r, z_r, target_img, masked_img_file_path)

In [None]:
target_img

In [None]:
# NOTE(review): nipy is imported but not used anywhere in this notebook.
import nipy

In [None]:
image_masked_by_ellipsoid

In [None]:
def spectrum_mask(size):
    """Creates a mask to filter the image of size size

    For a 2D ``size`` the result is a float array of that shape. It is 1 where
    a pixel's Euclidean distance to the centre ``(size[0] // 2, size[1] // 2)``
    is at most 60% of the largest such distance in the array, and 0 elsewhere.
    A 1x1 size yields ``[[1.]]``.
    """
    import numpy as np
    from scipy.ndimage import distance_transform_edt as distance

    ftmask = np.ones(size)
    # A single zero at the centre: the distance transform then gives every
    # pixel its Euclidean distance to that centre.
    ftmask[size[0] // 2, size[1] // 2] = 0

    ftmask = distance(ftmask)
    max_distance = ftmask.max()
    # Guard the normalisation: with a single pixel all distances are 0 and the
    # division would produce NaN.
    if max_distance > 0:
        ftmask /= max_distance

    # Invert so the centre is 1 and the farthest pixel 0 (kept in this form in
    # case the opposite filter is wanted).
    ftmask *= -1.0
    ftmask += 1.0

    # Binarise: 1 - d/d_max >= 0.4  <=>  d <= 0.6 * d_max.
    ftmask[ftmask >= 0.4] = 1
    ftmask[ftmask < 1] = 0
    return ftmask


In [None]:
def slice_wise_fft(in_file, ftmask=None, spike_thres=3., out_prefix=None):
    """Search for spikes in slices using the 2D FFT

    For every volume ``t`` and slice ``z`` (third axis) of the 4D image
    ``in_file``:
    - take the real part of the slice's 2D FFT;
    - median-filter it (5x5, constant padding) and multiply it by ``ftmask``;
    - robustly z-score each spectral value across time (median and
      statsmodels ``mad`` over the last axis).

    A ``(t, z)`` pair is reported as a spike when more than 10 pixels survive
    thresholding the z-scores at ``spike_thres`` and a binary erosion with a
    3x3 structuring element.

    Parameters
    ----------
    in_file : path to a 4D NIfTI image.
    ftmask : optional 2D array multiplied into each slice spectrum; defaults
        to ``spectrum_mask`` of the in-plane shape.
    spike_thres : z-score threshold.
    out_prefix : prefix for the output files; defaults to the base name of
        ``in_file`` (without ``.nii``/``.nii.gz``) in the current directory.

    Returns
    -------
    (n_spikes, out_spikes, out_fft, spikes_list):
    - ``out_spikes``: a TSV of (TR, Z) pairs;
    - ``out_fft``: the saved z-scored FFT image (identity affine);
    - ``spikes_list``: the list of ``(t, z)`` tuples.
    """
    import os.path as op
    import numpy as np
    import nibabel as nb
    from scipy.ndimage import (binary_erosion, generate_binary_structure,
                               median_filter)
    from statsmodels.robust.scale import mad

    if out_prefix is None:
        fname, ext = op.splitext(op.basename(in_file))
        if ext == '.gz':
            fname, _ = op.splitext(fname)
        out_prefix = op.abspath(fname)

    func_data = nb.load(in_file).get_data()

    if ftmask is None:
        ftmask = spectrum_mask(tuple(func_data.shape[:2]))

    fft_data = []
    for t in range(func_data.shape[-1]):
        func_frame = func_data[..., t]
        fft_slices = []
        for z in range(func_frame.shape[2]):
            sl = func_frame[..., z]
            fftsl = median_filter(np.real(np.fft.fft2(sl)).astype(np.float32),
                                  size=(5, 5), mode='constant') * ftmask
            fft_slices.append(fftsl)
        fft_data.append(np.stack(fft_slices, axis=-1))

    # Recompose the 4D FFT timeseries
    fft_data = np.stack(fft_data, -1)

    # Z-score across t, using robust statistics. sigma stays 3D and is
    # broadcast over time instead of being copied into a full 4D array.
    mu = np.median(fft_data, axis=3)
    sigma = mad(fft_data, axis=3)
    valid = np.abs(sigma) > 1e-4
    fft_zscored = fft_data - mu[..., np.newaxis]
    fft_zscored[valid] /= sigma[valid][:, np.newaxis]

    # save fft z-scored
    out_fft = op.abspath(out_prefix + '_zsfft.nii.gz')
    nii = nb.Nifti1Image(fft_zscored.astype(np.float32), np.eye(4), None)
    nii.to_filename(out_fft)

    # Find peaks
    struc = generate_binary_structure(2, 2)
    spikes_list = []
    for t in range(fft_zscored.shape[-1]):
        fft_frame = fft_zscored[..., t]

        for z in range(fft_frame.shape[-1]):
            sl = fft_frame[..., z]
            if np.all(sl < spike_thres):
                continue

            # Any zscore over spike_thres will be called a spike. Threshold
            # into a new array instead of overwriting fft_zscored in place.
            peaks = (sl > spike_thres).astype(np.uint8)

            # Erode peaks and see how many survive
            peaks = binary_erosion(peaks, structure=struc).astype(np.uint8)

            if peaks.sum() > 10:
                print ((t, z), peaks.sum() )
                spikes_list.append((t, z))

    out_spikes = op.abspath(out_prefix + '_spikes.tsv')
    np.savetxt(out_spikes, spikes_list, fmt='%d', delimiter='\t', header='TR\tZ')

    return len(spikes_list), out_spikes, out_fft, spikes_list

In [None]:
from nilearn.masking import compute_epi_mask
# EPI brain mask as an array (not used further in this notebook).
mask_img1 = np.array(compute_epi_mask(subject_scan_path).get_data())

In [None]:
# Detect spikes: per-slice FFT spectra robustly z-scored over time, threshold 4.
n_spikes, out_spikes, out_fft, spikes_list = slice_wise_fft(subject_scan_path, spike_thres=4.)

In [None]:
out_fft

In [None]:
n_spikes

In [None]:
spiked_fft = mt.read_image_abs_path(out_fft)

In [None]:
print spiked_fft

In [None]:
# Volume 39 of the original scan (the "_6" suffix in the names does not match the index).
tr_6_img = image.index_img(x_true_org1, 39)

In [None]:
mean_fft_img = image.mean_img(spiked_fft)

In [None]:
# Volume 39 of the z-scored FFT image.
spike_6_img = image.index_img(spiked_fft,39)

In [None]:
# Axial cut (z=18) of volume 39 of the z-scored FFT image. The z-score contour
# overlay is drawn in a later cell, once `z_score_d_mask_img` exists; adding it
# here raised a NameError on a fresh kernel.
spike_6 = plotting.plot_img(spike_6_img, display_mode='z', bg_img=None,black_bg=True, cmap='Greys_r', cut_coords=[18])

In [None]:
# Same volume of the original scan, axial cut at z=18, for comparison.
tr_6 = plotting.plot_epi(tr_6_img, display_mode='z', bg_img=None,black_bg=True,cut_coords=[18]) 


In [None]:
#z_score_epi_mask = get_z_score_robust_spatial_mask(tr_6_img, 4) 

In [None]:
def get_z_score_robust_spatial_mask(x_img, z_score_cut_off):
    """Flag voxels of a volume whose robust *spatial* z-score is extreme.

    The median and the normalised median absolute deviation (MAD) are computed
    over all voxels of ``x_img`` together. A voxel is flagged when
    ``|x - median| / MAD > z_score_cut_off``. When the MAD is <= 1e-10 the
    deviations are left unscaled.

    The previous version called ``mad`` with its default ``axis=0``. That gave
    a per-column MAD, which was stacked into a mis-shaped array and used to
    index the data, yielding wrong voxels or an IndexError.

    Parameters
    ----------
    x_img : object exposing ``get_data()`` (e.g. a 3D NIfTI volume).
    z_score_cut_off : absolute z-score threshold.

    Returns
    -------
    int array with the shape of the image data: 1 = flagged, 0 = not flagged.
    """
    data = np.array(x_img.get_data())
    values = data.ravel()
    mu = np.median(values)
    # Normalised MAD: divide by the Gaussian 75th-percentile point
    # (norm.ppf(0.75)), the same default scaling as statsmodels' ``mad``.
    sigma = np.median(np.abs(values - mu)) / 0.6744897501960817
    ground_truth_z_score = data - mu
    if abs(sigma) > 1e-10:
        ground_truth_z_score = ground_truth_z_score / sigma
    mask_z_score_indices = (abs(ground_truth_z_score) > z_score_cut_off).astype('int')
    print ("Z-score indices count: " + str(np.count_nonzero(mask_z_score_indices == 1)))
    return mask_z_score_indices

In [None]:
def get_z_score_robust_mask(x_img, z_score_cut_off):
    """Flag entries of a 4D image that are extreme within their own voxel's time series.

    Each voxel's time series (last axis) is centred on its median and divided
    by its MAD (statsmodels ``mad`` along axis 3). Voxels whose MAD is <= 1e-10
    are only centred. Entries with ``|z| > z_score_cut_off`` are flagged.

    Parameters
    ----------
    x_img : object exposing ``get_data()`` returning a 4D array.
    z_score_cut_off : absolute z-score threshold.

    Returns
    -------
    int array with the shape of the data: 1 = flagged, 0 = not flagged.
    """
    # Read the data once instead of copying it on every access.
    data = np.array(x_img.get_data())
    mu = np.median(data, axis=3)
    sigma = mad(data, axis=3)
    valid = np.abs(sigma) > 1e-10
    ground_truth_z_score = data - mu[..., np.newaxis]
    # Broadcast each voxel's sigma over time instead of stacking a 4D copy.
    ground_truth_z_score[valid] /= sigma[valid][:, np.newaxis]
    mask_z_score_indices = (abs(ground_truth_z_score) > z_score_cut_off).astype('int')
    print ("Z-score indices count: " + str(get_mask_z_indices_count(mask_z_score_indices)))
    return mask_z_score_indices

def get_mask_z_indices_count(mask_z_score):
    """Return the number of entries of ``mask_z_score`` that are equal to 1."""
    return np.count_nonzero(mask_z_score == 1)

In [None]:
from statsmodels.robust.scale import mad
# Spatial robust z-score mask (|z| > 4) of volume 39 of the z-scored FFT image.
z_score_d_mask = get_z_score_robust_spatial_mask(spike_6_img, 4) 

In [None]:
# Temporal robust z-score mask (|z| > 4) of the full 4D scan.
z_score_mask = get_z_score_robust_mask(x_true_org1, 4) 

In [None]:
z_score_mask.shape

In [None]:
z_score_d_mask.shape

In [None]:
z_score_d_mask_img = image.new_img_like(spike_6_img,z_score_d_mask)

In [None]:
# FFT volume with the flagged voxels drawn as filled red contours.
spike_6_with_overlay = plotting.plot_img(spike_6_img, display_mode='z', bg_img=None,black_bg=True, cmap='gray', colorbar=True, cut_coords=10)
spike_6_with_overlay.add_contours(z_score_d_mask_img, levels=[0.5], filled=True, alpha=0.8, colors='r')

In [None]:
tr_6 = plotting.plot_epi(tr_6_img, display_mode='z', bg_img=None,black_bg=True, cut_coords=10) 

In [None]:
z_score_fft = get_z_score_robust_mask(spiked_fft, 4)

In [None]:
z_score_fft_img = image.new_img_like(spiked_fft,z_score_fft)

In [None]:
masked_data = apply_mask(spiked_fft, z_score_d_mask_img)

# masked_data shape is (timepoints, voxels). We plot the first 144
# timepoints of the first three voxels.

# And now plot a few of these
import matplotlib.pyplot as plt
plt.figure(figsize=(7, 5))
plt.plot(masked_data[:144, :3])
plt.xlabel('Time [TRs]', fontsize=16)
plt.ylabel('Intensity', fontsize=16)
plt.xlim(0, 150)
plt.subplots_adjust(bottom=.12, top=.95, right=.95, left=.12)



In [None]:
masked_data.shape

In [None]:
# Same computation as `masked_data` above.
masked_data_t = apply_mask(spiked_fft, z_score_d_mask_img)

In [None]:
out_spikes

In [None]:
import nibabel as nb

In [None]:
def plot_spikes(in_file, in_fft, spikes_list, cols=3,
                labelfmt='t={0:.3f}s (z={1:d})',
                out_file=None):
    """Plot every detected spike slice next to its FFT z-score map and save it as SVG.

    Each ``(t, z)`` in ``spikes_list`` gets one panel made of two stacked axes:
    - slice ``z`` of volumes ``t-1``, ``t`` and ``t+1`` of ``in_file``
      (reoriented to the closest canonical orientation);
    - the same slices of ``in_fft``, drawn with a fixed colour range of [-5, 5].

    Parameters
    ----------
    in_file : path to the 4D scan.
    in_fft : path to the z-scored FFT image (e.g. ``out_fft`` from ``slice_wise_fft``).
    spikes_list : sequence of ``(t, z)`` index pairs.
    cols : number of panel columns; one column is added when there are more
        than ``cols * 7`` spikes.
    labelfmt : panel label format, filled with ``t`` times the last header
        zoom and ``z``.
    out_file : output SVG path; defaults to ``<in_file base name>.svg`` in the
        current working directory.

    Returns
    -------
    The path of the written SVG file.
    """
    import math
    import os.path as op
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    nii = nb.as_closest_canonical(nb.load(in_file))
    fft = nb.load(in_fft).get_data()

    data = nii.get_data()
    zooms = nii.header.get_zooms()[:2]
    tstep = nii.header.get_zooms()[-1]
    ntpoints = data.shape[-1]

    if len(spikes_list) > cols * 7:
        cols += 1

    nspikes = len(spikes_list)
    rows = 1
    if nspikes > cols:
        # float(): under Python 2, nspikes / cols floor-divides (5 / 3 == 1), which
        # left too few rows for all panels. int(): math.ceil returns a float in Python 2.
        rows = int(math.ceil(float(nspikes) / cols))

    fig = plt.figure(figsize=(7 * cols, 5 * rows))

    for i, (t, z) in enumerate(spikes_list):
        prev = None
        pvft = None
        if t > 0:
            prev = data[..., z, t - 1]
            pvft = fft[..., z, t - 1]

        post = None
        psft = None
        if t < (ntpoints - 1):
            post = data[..., z, t + 1]
            psft = fft[..., z, t + 1]

        ax1 = fig.add_subplot(rows, cols, i + 1)
        divider = make_axes_locatable(ax1)
        ax2 = divider.new_vertical(size="100%", pad=0.1)
        fig.add_axes(ax2)

        plot_slice_tern(data[..., z, t], prev=prev, post=post, spacing=zooms,
                        ax=ax2,
                        label=labelfmt.format(t * tstep, z))

        plot_slice_tern(fft[..., z, t], prev=pvft, post=psft, vmin=-5, vmax=5,
                        cmap='binary', ax=ax1)

    plt.tight_layout()
    if out_file is None:
        fname, ext = op.splitext(op.basename(in_file))
        if ext == '.gz':
            fname, _ = op.splitext(fname)
        out_file = op.abspath('%s.svg' % fname)

    fig.savefig(out_file, format='svg', dpi=300, bbox_inches='tight')
    return out_file

In [None]:
import math
import os.path as op
import numpy as np
import nibabel as nb

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.backends.backend_pdf import FigureCanvasPdf as FigureCanvas
import seaborn as sns

In [None]:
def plot_slice_tern(dslice, prev=None, post=None,
                    spacing=None, cmap='Greys_r', label=None, ax=None,
                    vmax=None, vmin=None):
    """Draw ``prev``, ``dslice`` and ``post`` side by side as a single image.

    Parameters
    ----------
    dslice : 2D array, drawn in the middle.
    prev, post : 2D arrays drawn left and right of ``dslice``; arrays of ones
        (same shape as ``dslice``) are used when omitted.
    spacing : optional in-plane voxel size (two values) used to set the
        physical extent of the image.
    cmap : colormap or colormap name.
    label : optional text drawn at the bottom centre of the axes.
    ax : axes to draw on; defaults to the current axes.
    vmax, vmin : colour limits. Each one left as ``None`` is estimated from
        ``dslice`` via ``_get_limits``. An explicit 0 is respected; it used to
        be replaced because the check tested falsiness.
    """
    if isinstance(cmap, (str, bytes)):
        # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in recent matplotlib.
        cmap = plt.get_cmap(cmap)

    # Only estimate the limits that were not given explicitly.
    if vmin is None or vmax is None:
        est_vmin, est_vmax = _get_limits(dslice)
        if vmin is None:
            vmin = est_vmin
        if vmax is None:
            vmax = est_vmax

    if ax is None:
        ax = plt.gca()

    if spacing is None:
        spacing = [1.0, 1.0]
    else:
        spacing = [spacing[1], spacing[0]]

    phys_sp = np.array(spacing) * dslice.shape

    if prev is None:
        prev = np.ones_like(dslice)
    if post is None:
        post = np.ones_like(dslice)

    # Stack along the first axis, then swap so the three slices sit side by side.
    combined = np.swapaxes(np.vstack((prev, dslice, post)), 0, 1)
    ax.imshow(combined, vmin=vmin, vmax=vmax, cmap=cmap,
              interpolation='nearest', origin='lower',
              extent=[0, phys_sp[1] * 3, 0, phys_sp[0]])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(False)

    if label is not None:
        ax.text(.5, .05, label,
                transform=ax.transAxes,
                horizontalalignment='center',
                verticalalignment='top',
                size=14,
                bbox=dict(boxstyle="square,pad=0", ec='k', fc='k'),
                color='w')


In [None]:
def get_parula():
    """Return the 'parula' colormap (as in MATLAB) as a matplotlib colormap.

    Builds a ``LinearSegmentedColormap`` named ``'parula'`` from the 64 RGB
    control points listed below (components in [0, 1]).
    """
    from matplotlib.colors import LinearSegmentedColormap

    # 64 RGB triples, ordered from the low end (dark blue) to the high end (yellow).
    cm_data = [
        [0.2081, 0.1663, 0.5292],
        [0.2116238095, 0.1897809524, 0.5776761905],
        [0.212252381, 0.2137714286, 0.6269714286],
        [0.2081, 0.2386, 0.6770857143],
        [0.1959047619, 0.2644571429, 0.7279],
        [0.1707285714, 0.2919380952, 0.779247619],
        [0.1252714286, 0.3242428571, 0.8302714286],
        [0.0591333333, 0.3598333333, 0.8683333333],
        [0.0116952381, 0.3875095238, 0.8819571429],
        [0.0059571429, 0.4086142857, 0.8828428571],
        [0.0165142857, 0.4266, 0.8786333333],
        [0.032852381, 0.4430428571, 0.8719571429],
        [0.0498142857, 0.4585714286, 0.8640571429],
        [0.0629333333, 0.4736904762, 0.8554380952],
        [0.0722666667, 0.4886666667, 0.8467],
        [0.0779428571, 0.5039857143, 0.8383714286],
        [0.079347619, 0.5200238095, 0.8311809524],
        [0.0749428571, 0.5375428571, 0.8262714286],
        [0.0640571429, 0.5569857143, 0.8239571429],
        [0.0487714286, 0.5772238095, 0.8228285714],
        [0.0343428571, 0.5965809524, 0.819852381],
        [0.0265, 0.6137, 0.8135],
        [0.0238904762, 0.6286619048, 0.8037619048],
        [0.0230904762, 0.6417857143, 0.7912666667],
        [0.0227714286, 0.6534857143, 0.7767571429],
        [0.0266619048, 0.6641952381, 0.7607190476],
        [0.0383714286, 0.6742714286, 0.743552381],
        [0.0589714286, 0.6837571429, 0.7253857143],
        [0.0843, 0.6928333333, 0.7061666667],
        [0.1132952381, 0.7015, 0.6858571429],
        [0.1452714286, 0.7097571429, 0.6646285714],
        [0.1801333333, 0.7176571429, 0.6424333333],
        [0.2178285714, 0.7250428571, 0.6192619048],
        [0.2586428571, 0.7317142857, 0.5954285714],
        [0.3021714286, 0.7376047619, 0.5711857143],
        [0.3481666667, 0.7424333333, 0.5472666667],
        [0.3952571429, 0.7459, 0.5244428571],
        [0.4420095238, 0.7480809524, 0.5033142857],
        [0.4871238095, 0.7490619048, 0.4839761905],
        [0.5300285714, 0.7491142857, 0.4661142857],
        [0.5708571429, 0.7485190476, 0.4493904762],
        [0.609852381, 0.7473142857, 0.4336857143],
        [0.6473, 0.7456, 0.4188],
        [0.6834190476, 0.7434761905, 0.4044333333],
        [0.7184095238, 0.7411333333, 0.3904761905],
        [0.7524857143, 0.7384, 0.3768142857],
        [0.7858428571, 0.7355666667, 0.3632714286],
        [0.8185047619, 0.7327333333, 0.3497904762],
        [0.8506571429, 0.7299, 0.3360285714],
        [0.8824333333, 0.7274333333, 0.3217],
        [0.9139333333, 0.7257857143, 0.3062761905],
        [0.9449571429, 0.7261142857, 0.2886428571],
        [0.9738952381, 0.7313952381, 0.266647619],
        [0.9937714286, 0.7454571429, 0.240347619],
        [0.9990428571, 0.7653142857, 0.2164142857],
        [0.9955333333, 0.7860571429, 0.196652381],
        [0.988, 0.8066, 0.1793666667],
        [0.9788571429, 0.8271428571, 0.1633142857],
        [0.9697, 0.8481380952, 0.147452381],
        [0.9625857143, 0.8705142857, 0.1309],
        [0.9588714286, 0.8949, 0.1132428571],
        [0.9598238095, 0.9218333333, 0.0948380952],
        [0.9661, 0.9514428571, 0.0755333333],
        [0.9763, 0.9831, 0.0538]]

    return LinearSegmentedColormap.from_list('parula', cm_data)

In [None]:
def _get_limits(nifti_file, only_plot_noise=False):
    if isinstance(nifti_file, str):
        nii = nb.as_closest_canonical(nb.load(nifti_file))
        data = nii.get_data()
    else:
        data = nifti_file

    data_mask = np.logical_not(np.isnan(data))

    if only_plot_noise:
        data_mask = np.logical_and(data_mask, data != 0)
        vmin = np.percentile(data[data_mask], 0)
        vmax = np.percentile(data[data_mask], 61)
    else:
        vmin = np.percentile(data[data_mask], 0.5)
        vmax = np.percentile(data[data_mask], 99.5)

    return vmin, vmax


def _bbox(img_data, bbox_data):
    B = np.argwhere(bbox_data)
    (ystart, xstart, zstart), (ystop, xstop, zstop) = B.min(0), B.max(0) + 1
    return img_data[ystart:ystop, xstart:xstop, zstart:zstop]

In [None]:
# Render all detected spikes to "<scan base name>.svg" in the current working directory.
plot_spikes(subject_scan_path, out_fft, spikes_list)

In [None]:
out_spikes

In [None]:
# NOTE(review): scratch debugging loop; it also rebinds the global `i`.
for i in range(3):
        print i