In [None]:
import os, sys, pickle, json, collections, re, subprocess, functools, shutil

pkg_path = os.path.abspath('/home/jupyter/code')
if pkg_path not in sys.path:
    sys.path.insert(0, pkg_path)

import numpy as np    
import matplotlib.pylab as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable


import torch

from optical_electrophysiology import load_file, detrend_linear
from visualization import imshow, plot_image_label_overlay, make_video_ffmpeg
from utility import get_topk_indices, get_cor_map_4d

# Run on the GPU when it is requested and present; otherwise fall back to the CPU.
use_gpu = True
device = torch.device('cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu')
    
%load_ext autoreload
%autoreload 2

In [9]:
# List experiments that already have processed outputs on GCS (informational only;
# the processing loop below does not read this). `gsutil ls` exits non-zero with
# "matched no objects" when the prefix is still empty (e.g. before the first run),
# which is treated as "nothing processed yet" rather than an error.
command = ['gsutil', 'ls', 'gs://broad-opp-voltage/sami_2015/processed']
response = subprocess.run(command, capture_output=True)
if response.returncode == 0:
    # Entries look like gs://.../processed/<exp_id>/ -> keep the last path component.
    processed_exp_ids = sorted(p.rstrip('/').split('/')[-1] for p in response.stdout.decode().split())
elif b'matched no objects' in response.stderr:
    processed_exp_ids = []
else:
    raise AssertionError(f'{command} failed: {response.stderr.decode()}')
processed_exp_ids

In [None]:
# List the raw recordings to process and define the preprocessing configuration that
# preprocess_detrend / preprocess_cor_map below read as module-level globals.
command = ['gsutil', 'ls', 'gs://broad-opp-voltage/sami_2015']
response = subprocess.run(command, capture_output=True)
assert response.returncode == 0, f'{command} failed: {response.stderr.decode()}'

# Raw string with an escaped dot: the former pattern '.bin$' also matched names like 'foo_xbin'.
filepaths = sorted([f for f in response.stdout.decode().split() if re.search(r'\.bin$', f)])
data_folder = '.'  # local working directory for downloads and per-experiment outputs
n_frames = 7499    # frame count passed to load_file as size[0]
width = 512        # frame width passed to load_file and used to reshape frames
height = 180       # frame height passed to load_file and used to reshape frames
# Segment i, for i in range(start_segment, end_segment), keeps frames
# [start - skip + period * i, end - skip + period * i) of the recording.
start_segment = 0
end_segment = 10
period = 750
start = 70
end = 720
skip = 1
# Passed to detrend_linear as train_idx for every segment (presumably frame indices
# within one segment -- TODO confirm against detrend_linear).
train_idx = np.array(list(range(170))+list(range(550, 630)))
linear_order = 3              # passed to detrend_linear as linear_order
input_transformation = None   # passed to detrend_linear as input_transformation

# Fail early with a clear message (instead of a torch.stack shape error inside
# preprocess_detrend) if any segment would fall outside the recording.
assert start - skip + period * start_segment >= 0, 'first segment starts before frame 0'
assert end - skip + period * (end_segment - 1) <= n_frames, 'last segment runs past n_frames'

In [None]:
def preprocess_detrend(Mat, save_folder=None):
    """Split a recording into segments, detrend each one and plot the spatial means.

    Segment ``i`` for ``i`` in ``range(start_segment, end_segment)`` covers frames
    ``[start - skip + period * i, end - skip + period * i)`` of ``Mat``; frames outside
    every segment are only plotted (as "trimmed"). Each segment is passed to
    ``detrend_linear`` with the module-level ``train_idx``, ``linear_order`` and
    ``input_transformation`` settings.

    Args:
        Mat: torch tensor with frames along dim 0; ``Mat.mean((1, 2))`` is plotted as the
            per-frame spatial mean (presumably shape (frames, height, width) -- TODO confirm).
        save_folder: if not None, the figure is saved as ``spatial_mean_with_detrending.png``
            and ``mat.npy`` / ``trend.npy`` are written into this folder; otherwise the
            figure is shown.

    Returns:
        (mat, trend): detrended frames and fitted trend of all segments, concatenated
        along dim 0 and reshaped to (-1, height, width).
    """
    spatial_mean = Mat.mean((1, 2)).cpu().numpy()
    # Half-open frame range of every kept segment; shared by the plot indices and the slicing
    # so the two can never disagree.
    bounds = [(start - skip + period * i, end - skip + period * i)
              for i in range(start_segment, end_segment)]
    # Flat comprehension instead of functools.reduce(+) over lists, which was quadratic.
    selected = [frame for lo, hi in bounds for frame in range(lo, hi)]
    trimmed = sorted(set(range(len(Mat))).difference(selected))

    segments = torch.stack([Mat[lo:hi] for lo, hi in bounds], dim=0)
    # Per segment, stack detrend_linear's two outputs along a new dim 1:
    # index 0 is used as the detrended video and index 1 as the trend.
    stacked = torch.stack([
        torch.stack(detrend_linear(segment, train_idx=train_idx, linear_order=linear_order,
                                   return_trend=True, input_transformation=input_transformation), dim=0)
        for segment in segments])
    del segments  # release the raw segment copy as soon as it is no longer needed
    trend = stacked[:, 1].reshape(-1, height, width)
    mat = stacked[:, 0].reshape(-1, height, width)

    fig, axes = plt.subplots(3, 1, sharex=True, figsize=(20, 20))
    axes[0].set_title('Spatial mean fluorescence intensity with detrending')
    axes[0].scatter(selected, spatial_mean[selected], c='g', label='selected', s=1)
    axes[0].scatter(trimmed, spatial_mean[trimmed], c='b', label='trimmed', s=1)
    axes[1].scatter(selected, spatial_mean[selected], c='g', label='original', s=1)
    axes[1].scatter(selected, trend.mean((1, 2)).cpu().numpy(), c='k', label='trend', s=1)
    axes[2].scatter(selected, mat.mean((1, 2)).cpu().numpy(), c='r', label='detrended', s=1)
    for ax in axes:
        ax.legend()
        ax.set_xlabel('frame')
        ax.set_ylabel('mean fluorescence intensity')

    if save_folder is not None:
        fig.savefig(f'{save_folder}/spatial_mean_with_detrending.png')
        plt.close(fig)  # avoid accumulating open figures when called in a loop
        np.save(f'{save_folder}/mat.npy', mat.cpu().numpy())
        np.save(f'{save_folder}/trend.npy', trend.cpu().numpy())
    else:
        plt.show()
    return mat, trend

def _show_with_colorbar(fig, ax, image, title):
    """Draw a 2-D tensor on `ax` (origin='lower') with a title and a slim colorbar on the right."""
    im = ax.imshow(image.cpu().numpy(), origin='lower')
    ax.set_title(title)
    cax = make_axes_locatable(ax).append_axes("right", size="2%", pad=0.1)
    fig.colorbar(im, cax=cax)


def preprocess_cor_map(mat, trend, save_folder=None):
    """Compute correlation maps of the detrended and trend videos and plot a 2x2 summary.

    Both videos are regrouped to (end_segment - start_segment, -1, height, width) before
    being passed to ``get_cor_map_4d``. The figure shows the temporal means (left column)
    and the correlation maps (right column) of the detrended (top) and trend (bottom) videos.

    Args:
        mat: detrended frames, as returned by ``preprocess_detrend``.
        trend: trend frames, as returned by ``preprocess_detrend``.
        save_folder: if not None, saves ``temporal_mean_and_correlation_map.png`` and
            ``cor_mat.npy`` there and closes the figure; otherwise shows the figure.

    Returns:
        The ``cor_mat_all`` output of ``get_cor_map_4d(..., return_all=True)`` for ``mat``.
    """
    n_segments = end_segment - start_segment
    # reshape rather than view: view raises if the incoming tensor is not contiguous,
    # reshape returns a view when possible and copies otherwise.
    cor_mat, cor_mat_all = get_cor_map_4d(
        mat.reshape(n_segments, -1, height, width), top_cor_map_percentage=20, padding=2,
        shift_times=[0, 1, 2], select_frames=True, return_all=True, plot=False)
    cor_trend = get_cor_map_4d(
        trend.reshape(n_segments, -1, height, width), top_cor_map_percentage=20, padding=2,
        shift_times=[0, 1, 2], select_frames=False, return_all=False, plot=False)

    fig, axes = plt.subplots(2, 2, figsize=(20, 10))
    panels = [
        ((0, 0), mat.mean(0), 'Temporal mean fluorescence intensity of detrended video'),
        ((1, 0), trend.mean(0), 'Temporal mean fluorescence intensity of trend video'),
        ((0, 1), cor_mat, 'Correlation map of detrended video'),
        ((1, 1), cor_trend, 'Correlation map of trend video'),
    ]
    for (row, col), image, title in panels:
        _show_with_colorbar(fig, axes[row, col], image, title)
    fig.tight_layout()

    if save_folder is not None:
        fig.savefig(f'{save_folder}/temporal_mean_and_correlation_map.png')
        np.save(f'{save_folder}/cor_mat.npy', cor_mat.cpu().numpy())
        plt.close(fig)  # avoid accumulating open figures when called in a loop
    else:
        plt.show()
    return cor_mat_all

In [None]:
# For every raw recording: download it, detrend, compute correlation maps and summary
# features, upload the outputs to GCS, then delete the local copies.
for filepath in filepaths:
    # Experiment id = object name without its extension. str.strip('.bin') was wrong: it strips
    # any of the characters '.', 'b', 'i', 'n' from both ends (e.g. 'neuron1.bin' -> 'euron1'),
    # after which the local file path below did not exist.
    filename = filepath.split('/')[-1]
    exp_id = os.path.splitext(filename)[0]
    print(exp_id)
    save_folder = f'{data_folder}/{exp_id}'
    if not os.path.exists(save_folder):
        print(f'Create {save_folder}')
        os.makedirs(save_folder)
    # `gsutil cp <object> <dir>` keeps the object's file name inside <dir>.
    local_path = f'{data_folder}/{filename}'
    command = ['gsutil', '-m', 'cp', filepath, data_folder]
    response = subprocess.run(command, capture_output=True)
    assert response.returncode == 0, f'{command} failed: {response.stderr.decode()}'

    Mat = load_file(local_path, size=(n_frames, height, width))
    mat, trend = preprocess_detrend(Mat, save_folder=save_folder)
    del Mat
    torch.cuda.empty_cache()  # release cached GPU memory that is no longer in use
    cor_mat_all = preprocess_cor_map(mat, trend, save_folder=save_folder)
    # reshape rather than view so a non-contiguous result does not raise.
    cor_mat_all = cor_mat_all.reshape(-1, height, width)
    # Four per-pixel summary maps; feature_names.npy lists them in the same order.
    features = torch.stack([mat.mean(0), mat.std(0), cor_mat_all.mean(0), cor_mat_all.std(0)], dim=0)
    np.save(f'{save_folder}/features.npy', features.cpu().numpy())
    np.save(f'{save_folder}/feature_names.npy', ['mean', 'std', 'cor_mean', 'cor_std'])

    # NOTE(review): `gsutil cp -r <dir> gs://.../processed/<exp_id>` may nest the upload as
    # .../<exp_id>/<exp_id>/ when that destination already exists (e.g. on a re-run) -- confirm.
    command = ['gsutil', '-m', 'cp', '-r', save_folder, f'gs://broad-opp-voltage/sami_2015/processed/{exp_id}']
    response = subprocess.run(command, capture_output=True)
    assert response.returncode == 0, f'{command} failed: {response.stderr.decode()}'

    os.remove(local_path)
    shutil.rmtree(save_folder)

    # Drop every large tensor from this iteration so its memory can be reused by the next one.
    del mat, trend, cor_mat_all, features
    torch.cuda.empty_cache()