In [None]:
import random, os, sys, re, math, functools, itertools, collections, time, pickle, io
import cv2
sys.path.append('/home/jupyter/code')

import numpy as np
import pandas
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import colors as mcolors
from skimage.filters import threshold_otsu
from PIL import Image, ImageSequence
from sklearn.decomposition import FastICA, PCA
from IPython.display import HTML
from scipy.io import loadmat

import torch
import torch.nn as nn

from utility import densenet_regression, linear_regression, power_series, neighbor_cor, cosine_similarity
from utility import weighted_mse_loss, empty_cache, detect_outliers, svd, get_label_image
from visualization import plot_tensor, imshow, plot_cdf, plot_3d_scatter, plot_image, plot_images, plot_curves
from visualization import plot_hist, get_image, make_video, make_3d_video, plot_trace, plot_image_label_overlay
from models import UNet, MultiConv, get_mask, get_bg_mat, restore_image_noise2self
from nmf import non_negative_factorization
from pmd import denoise, total_variation, second_order_difference, pmd_compress, rank_one_decomposition 
from pmd import get_threshold
from optical_electrophysiology import load_mat, detrend, extract_super_pixels, get_submat_traces, prep_train_data
from optical_electrophysiology import detrend_high_magnification, load_file, get_size_from_txt
from optical_electrophysiology import refine_segmentation, detrend_linear, extract_single_trace, extract_traces
from train import train_model, step_decompose, rank_k_decompose

# Run on the GPU when it is requested and available, otherwise on the CPU.
use_gpu = True
device = torch.device('cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu')

# Plot palette: single-letter matplotlib base colors first, then the CSS4 named colors in random
# order; 'w' and every name matching one of the excluded substrings below are dropped.
colors = sorted(mcolors.CSS4_COLORS)
random.shuffle(colors)
colors = sorted(mcolors.BASE_COLORS) + colors
excluded_color_names = ['white', 'light', 'gray', 'mint', 'rebeccapurple', 'steelblue', 'darkkhaki', 'ivory',
                        'cornsilk', 'honeydew', 'peru', 'alice', 'azure']
bad_colors = '|'.join(excluded_color_names)
colors = [name for name in colors if name != 'w' and re.search(bad_colors, name) is None]


def get_cm(sel_colors):
    """Build a LinearSegmentedColormap from `sel_colors`, quantised to len(sel_colors) levels."""
    return LinearSegmentedColormap.from_list('cmap_name', sel_colors, N=len(sel_colors))


plt.rcParams['figure.figsize'] = 20, 15

# Pick up edits to the project modules imported above (utility, visualization, models, ...)
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Dataset handled by the dataset-configuration cell; set to one of:
# dataset_name = 'low_mag_cultured_neurons_2015-12-18'
# dataset_name = 'pooled_ipsc_2018-09-17'
# dataset_name = 'pooled_ipsc_2018-12-07'
# dataset_name = 'high_mag_adrenal_cortex'
# dataset_name = 'high_mag_beta_cell'
# dataset_name = 'low_mag_beta_cell'
# Default to None so the name is always defined (None selects no dataset-specific branch).
dataset_name = None
plot = False  # draw diagnostic plots while loading data
figsize = (20, 15)
random_file = True  # pick a random file/experiment instead of the hard-coded default
no_detrend = False


def predict(x, model, filepath=None, return_detached=True, device=torch.device('cuda')):
    """Run ``model`` on ``x`` without gradient tracking.

    Args:
        x: model input.
        model: callable module; if ``filepath`` exists its state dict is loaded into it first.
        filepath (str or None): optional checkpoint path, loaded only if the file exists.
        return_detached (bool): detach the output before returning it.
        device: unused; kept for backward compatibility with existing callers.

    Returns:
        The model output.
    """
    if filepath is not None and os.path.exists(filepath):
        model.load_state_dict(torch.load(filepath))
    with torch.no_grad():
        y = model(x)
    if return_detached:
        y = y.detach()
    # Locals go out of scope on return; just release cached CUDA memory.
    torch.cuda.empty_cache()
    return y
    
def denoise_trace(trace, model=None, filepath='checkpoints/denoise_trace.pt', return_detached=True, 
                  device=torch.device('cuda')):
    """Denoise a trace with the 1-D UNet, working on the z-scored trace.

    Args:
        trace (torch.Tensor): trace to denoise (normalised by its global mean/std).
        model: denoising model; if None a 1-D UNet is built on ``device`` and its weights are
            loaded from ``filepath``.
        filepath (str): checkpoint used only when ``model`` is None.
        return_detached (bool): detach the output before returning it.
        device: device for the model built when ``model`` is None.

    Returns:
        torch.Tensor: the denoised trace, mapped back to the original mean/scale.
    """
    if model is None:
        model = UNet(in_channels=1, num_classes=1, out_channels=[8, 16, 32], num_conv=2, 
                     n_dim=1, kernel_size=3).to(device)
        model.load_state_dict(torch.load(filepath))
    with torch.no_grad():
        mean = trace.mean()
        std = trace.std()
        # A constant trace has std 0 and a single sample has std NaN; dividing by either
        # would turn the whole output into NaN/inf, so fall back to a unit scale.
        if not torch.isfinite(std) or std == 0:
            std = torch.ones_like(std)
        pred = model((trace-mean)/std)
        # The model is applied a second time to its own (still normalised) output.
        # NOTE(review): confirm this two-pass denoising is intended.
        pred = model(pred)
        pred = pred * std + mean
    if return_detached:
        pred = pred.detach()
    # Locals go out of scope on return; just release cached CUDA memory.
    torch.cuda.empty_cache()
    return pred

def denoise_3d(mat, model=None, filepath='checkpoints/3d_denoise.pt', return_detached=True, 
               batch_size=5000, device=torch.device('cuda')):
    """Apply a 3-D denoising model to ``mat`` in chunks along its first dimension.

    Args:
        mat (torch.Tensor): input; split into slices of ``batch_size`` along dim 0.
        model: denoising model; if None a 3-D UNet is built on ``device`` and loaded from ``filepath``.
        filepath (str): checkpoint used only when ``model`` is None.
        return_detached (bool): detach the output before returning it.
        batch_size (int): number of leading-dim entries per forward pass.
        device: device for the model built when ``model`` is None.

    Returns:
        torch.Tensor: the per-chunk model outputs concatenated along dim 0.
    """
    if model is None:
        model = UNet(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3, 
                     kernel_size=[3, 3, 3], same_shape=True).to(device)
        model.load_state_dict(torch.load(filepath))
    with torch.no_grad():
        # Chunk along dim 0 to bound peak memory of a single forward pass.
        mat = torch.cat([model(mat[start:start + batch_size]) for start in range(0, mat.size(0), batch_size)],
                        dim=0)
    if return_detached:
        mat = mat.detach()
    # Locals go out of scope on return; just release cached CUDA memory.
    torch.cuda.empty_cache()
    return mat

def attention_map(mat, model=None, filepath='checkpoints/segmentation_count_hardmask.pt', 
                  batch_size=5000, return_detached=True, device=torch.device('cuda')):
    """Average the segmentation model's output over frames to get a soft attention map.

    Args:
        mat (torch.Tensor): 3-D input; ``mat.shape[1:]`` is unpacked as (nrow, ncol).
        model: segmentation model; if None a 3-D UNet is built on ``device`` and loaded from ``filepath``.
        filepath (str): checkpoint used only when ``model`` is None.
        batch_size (int): frames per forward pass. It is capped so that one batch holds at most
            about 1e7 pixels, but is always at least 1.
        return_detached (bool): detach the output before returning it.
        device: device for the model built when ``model`` is None.

    Returns:
        torch.Tensor: the concatenated model outputs averaged over dim 0.
    """
    if model is None:
        model = UNet(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3, 
                     kernel_size=[3, 3, 3], same_shape=True).to(device)
        model.load_state_dict(torch.load(filepath))
    nrow, ncol = mat.shape[1:]
    if batch_size*nrow*ncol > 1e7:
        # For frames larger than 1e7 pixels the cap would round down to 0 and break the
        # batching below (division by zero), so keep at least one frame per batch.
        batch_size = max(1, int(1e7 / (nrow*ncol)))
    with torch.no_grad():
        outputs = [model(mat[start:start + batch_size]) for start in range(0, mat.size(0), batch_size)]
        mat = torch.cat(outputs, dim=0).mean(0)
    if return_detached:
        mat = mat.detach()
    # Locals go out of scope on return; just release cached CUDA memory.
    torch.cuda.empty_cache()
    return mat

def refine_one_label(submat, min_pixels=50, return_traces=False, percentile=50):
    """Re-segment one region's sub-movie from the UNet soft-attention map.

    Args:
        submat: sub-movie of one region, passed to ``attention_map``.
        min_pixels (int): minimum component size forwarded to ``get_label_image``.
        return_traces (bool): also extract per-component traces.
        percentile (int): forwarded to ``extract_traces``.

    Returns:
        The label image alone, or, if ``return_traces`` is set, the tuple
        ``(submats, traces, soft_attention, label_image, regions)``.
    """
    soft_attention = attention_map(submat)
    label_image, regions = get_label_image(soft_attention, min_pixels=min_pixels)
    if not return_traces:
        return label_image
    submats, traces = extract_traces(submat, softmask=soft_attention, label_image=label_image,
                                     regions=regions, percentile=percentile)
    return submats, traces, soft_attention, label_image, regions

def refine_segmentation(submats, regions, label_image, min_pixels=50, connectivity=None):
    """Refine each region with ``refine_one_label``, then relabel connected components.

    NOTE(review): this definition shadows ``refine_segmentation`` imported from
    ``optical_electrophysiology``. It also writes into the caller's ``label_image`` in place.
    Each region's whole bounding box is overwritten, including pixels of neighbouring labels
    that fall inside it.

    Args:
        submats: per-region sub-movies, aligned with ``regions``.
        regions: region objects whose ``.bbox`` is (minr, minc, maxr, maxc).
        label_image: label array updated in place.
        min_pixels (int): forwarded to ``refine_one_label``.
        connectivity: forwarded to ``skimage.measure.label``.

    Returns:
        (label_image, regions): relabelled foreground (``label_image > 0``) and its regionprops.
    """
    for pos in range(len(submats)):
        minr, minc, maxr, maxc = regions[pos].bbox
        label_image[minr:maxr, minc:maxc] = refine_one_label(submats[pos], min_pixels=min_pixels)
    from skimage.measure import label, regionprops
    relabeled = label(label_image > 0, connectivity=connectivity)
    return relabeled, regionprops(relabeled)

def zoom_in(seq, batch_size=700, seq2=None, figsize=(15, 10), label1='seq1', label2='seq2'):
    """Plot a whole 1-D sequence, then consecutive zoomed windows of ``batch_size`` samples.

    Args:
        seq: sequence to plot (torch tensor or array-like); tensors are detached and moved to CPU numpy.
        batch_size (int): samples per zoomed window.
        seq2: optional second sequence overlaid as a dashed black line; converted like ``seq``.
        figsize (tuple): matplotlib figure size for every figure.
        label1, label2 (str): legend labels for ``seq`` and ``seq2`` in every figure.
    """
    def _as_numpy(s):
        # Tensors (possibly on GPU / requiring grad) -> CPU numpy; anything else is returned unchanged.
        if isinstance(s, torch.Tensor):
            return s.detach().cpu().numpy()
        return s

    # Convert each input independently, so a tensor seq2 works even when seq is not a tensor.
    seq = _as_numpy(seq)
    seq2 = _as_numpy(seq2)
    plt.figure(figsize=figsize)
    plt.plot(seq, 'b-', alpha=0.5, label=label1)
    if seq2 is not None:
        plt.plot(seq2, 'k--', alpha=0.5, label=label2)
        plt.legend()
    plt.show()
    for start in range(0, len(seq), batch_size):
        stop = start + batch_size
        plt.figure(figsize=figsize)
        plt.plot(seq[start:stop], 'b-', alpha=1, label=label1)
        if seq2 is not None:
            plt.plot(seq2[start:stop], 'k--', alpha=0.5, label=label2)
            plt.legend()
        plt.show()
        
def plot_pca_result(W, H, percentile=99.5, cor_map=None, title=None, num_frames=700, offset=0.5):
    """Show PCA/ICA components as a label image and plot the tail of each component trace.

    Args:
        W: component maps indexed ``W[i]``. Each map must be 2-D; it is thresholded into a 2-D
            label image, which is transposed for display.
        H: traces indexed ``H[:, i]``, one column per component.
        percentile (float): pixels with ``|W[i]|`` above this percentile get label ``i+1``.
            Later components overwrite earlier ones where they overlap.
        cor_map: optional background image for the overlay. Defaults to zeros with the
            (transposed) label-image shape.
        title (str): optional title for the trace figure.
        num_frames (int): number of trailing frames of each trace to plot.
        offset (float): vertical offset added per component so the traces do not overlap.
    """
    from skimage.measure import regionprops
    num = W.shape[0]
    # Size the label image from W instead of assuming 512 x 180 maps.
    label_image = np.zeros(W.shape[1:])
    for i in range(num):
        mask = np.abs(W[i])
        label_image[mask > np.percentile(mask, percentile)] = i+1
    label_image = label_image.T.astype('int')
    regions = regionprops(label_image)
    imshow(label_image)
    background = np.zeros(label_image.shape) if cor_map is None else cor_map
    plot_image_label_overlay(background, label_image=label_image, regions=regions)
    plt.figure(figsize=(20, 15))
    for i in range(num):
        trace = H[-num_frames:, i] + i*offset
        plt.plot(trace, label=i+1)
    plt.legend()
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
folder = '/home/jupyter/disk/data/sami/Adrenal Cortex'
filepath = f'{folder}/floxopatch_glomerulous_20hz.tif'
# Read every page of the multi-page TIFF into one array (first axis = page). The context
# manager closes the file handle that Image.open keeps open.
with Image.open(filepath) as im:
    array = np.array([np.array(page) for page in ImageSequence.Iterator(im)])
mat = torch.from_numpy(array.astype('float32')).to(device)

In [None]:
# Load the saved PCA/ICA result here so this cell does not depend on a later cell having run.
res = loadmat(f'{folder}/pca_ica_out.mat')
A = res['out']   # spatial components, indexed A[:, :, i]
C = res['vica']  # component traces, indexed C[:, i]
plot_pca_result(np.transpose(A, (2,0,1)), C)

In [None]:
# Show each PCA/ICA component's spatial map followed by its trace.
res = loadmat(f'{folder}/pca_ica_out.mat')
A = res['out'] 
C = res['vica']
for i in range(A.shape[-1]):
    imshow(A[:, :, i])
    plt.plot(C[:, i])
    plt.title(i)
    plt.show()

# KNN attention pooling: prior (location + correlation) + dynamic distance 

# Train a local transfer model

In [None]:
class ConvLayer(nn.Module):
    r"""Convolution -> GroupNorm -> activation, with optional "same"-style padding.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        n_dim (int): spatial dimensionality (1, 2 or 3); selects Conv1d / Conv2d / Conv3d.
        kernel_size (int or sequence of int): kernel size per spatial dim.
        stride (int or sequence of int): stride per spatial dim.
        dilation (int or sequence of int): dilation per spatial dim.
        groups (int): number of blocked connections, passed to the convolution.
        bias (bool): whether the convolution has a bias term.
        padding (bool): if True, ``forward`` pads the input so that every spatial output
            length is ``ceil(L / stride)``. If False, no padding is applied ("valid").
        padding_mode (str): mode passed to ``nn.functional.pad`` (e.g. 'replicate').
        normalization: 'layer_norm' (one group), 'instance_norm' (one group per channel),
            or an int number of groups for ``nn.GroupNorm``.
        activation (nn.Module): applied after the normalization.

    Shape:
        Input: :math:`(N, in\_channels, *spatial)`; missing leading dims are added by unsqueezing.
        Output: :math:`(N, out\_channels, *spatial_{out})`

    Examples::

        >>> x = torch.randn(2, 3, 5, 7)
        >>> layer = ConvLayer(3, 11)
        >>> layer(x).shape
        torch.Size([2, 11, 5, 7])
    """
    def __init__(self, in_channels, out_channels, n_dim=2, kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=True,
                 padding_mode='replicate', normalization='layer_norm', activation=nn.LeakyReLU(negative_slope=0.01, inplace=True)):
        super(ConvLayer, self).__init__()
        conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
        if n_dim not in conv_classes:
            raise ValueError(f'n_dim = {n_dim} not supported; expected 1, 2 or 3')
        self.padding = padding
        self.padding_mode = padding_mode
        self.n_dim = n_dim
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * n_dim
        if isinstance(dilation, int):
            dilation = [dilation] * n_dim
        if isinstance(stride, int):
            stride = [stride] * n_dim
        # Always store a per-dim list so forward() works for both int and sequence strides.
        self.stride = list(stride)
        # Span of a dilated kernel: d * (k - 1) + 1.
        self.effective_kernel_size = [dilation[i] * (kernel_size[i]-1) + 1 for i in range(n_dim)]
        # padding=0 here: padding, when enabled, is applied manually in forward() with padding_mode.
        self.conv = conv_classes[n_dim](in_channels, out_channels, kernel_size, stride=stride, padding=0,
                                        dilation=dilation, groups=groups, bias=bias, padding_mode='zeros')
        if normalization == 'layer_norm':
            num_groups = 1
        elif normalization == 'instance_norm':
            num_groups = out_channels
        elif isinstance(normalization, int):
            num_groups = normalization
        else:
            raise ValueError(f'normalization = {normalization} not defined!')
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        self.activation = activation
        
    def forward(self, x):
        # Prepend missing batch/channel dims so x has n_dim + 2 dims.
        for _ in range(self.n_dim + 2 - x.ndim):
            x = x.unsqueeze(0)
        if self.padding:
            # Total "same" padding per spatial dim: the smallest pad giving ceil(L / s) outputs.
            padding = [max(((L + s - 1) // s - 1) * s + k - L, 0)
                       for L, k, s in zip(x.shape[2:], self.effective_kernel_size, self.stride)]
            # F.pad takes (left, right) pairs starting from the LAST dim; an odd pad puts the extra pixel on the left.
            expanded_padding = []
            for p in reversed(padding):
                expanded_padding += [(p + 1) // 2, p // 2]
            if sum(expanded_padding) > 0:
                x = nn.functional.pad(x, expanded_padding, mode=self.padding_mode)
        return self.activation(self.norm(self.conv(x)))

In [None]:
class ConvNet(nn.Module):
    """Stack of ConvLayer blocks, global average pooling over space, then a linear head.

    Args:
        num_targets (int): output features of the final linear layer.
        in_channels (int): input channels of the first ConvLayer.
        out_channels (list of int): output channels per layer; its length sets the depth.
        kernel_sizes, strides, dilations (list of int): per-layer ConvLayer settings. Each
            needs at least ``len(out_channels)`` entries.
        n_dim (int): spatial dimensionality passed to every ConvLayer.
        padding (bool): passed to every ConvLayer.

    Shape:
        Input: (N, in_channels, *spatial)
        Output: (N, num_targets)
    """
    def __init__(self, num_targets=1, in_channels=1, out_channels=[8, 16, 32, 64, 128], kernel_sizes=[3, 5, 5, 5, 5], 
                 strides=[1, 1, 2, 2, 2], dilations=[1, 1, 2, 3, 3], n_dim=1, padding=True):
        super(ConvNet, self).__init__()
        self.num_layers = len(out_channels)
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                ConvLayer(in_channels=in_channels if i==0 else out_channels[i-1], out_channels=out_channels[i], n_dim=n_dim, 
                          kernel_size=kernel_sizes[i], stride=strides[i], dilation=dilations[i], padding=padding)
            )
        self.linear = nn.Linear(out_channels[-1], num_targets)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # Average over *all* spatial dims so the head always receives (N, C). Averaging only
        # the last dim would leave extra spatial axes when n_dim > 1.
        x = self.linear(x.flatten(2).mean(-1))
        return x

In [None]:
model = ConvNet(padding=False)
# Total number of parameter elements in the model.
cnt = sum(param.numel() for param in model.parameters())
print(cnt)
x = torch.randn(300, 1, 700)

In [None]:
# 3-D UNet segmentation model restored from its checkpoint.
filepath = 'checkpoints/segmentation_count_hardmask.pt'
unet_config = dict(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3,
                   kernel_size=[3, 3, 3], same_shape=True)
model = UNet(**unet_config).to(device)
model.load_state_dict(torch.load(filepath))

In [None]:
folder = '/home/tma/projects/optical_profiling/adam'
# sorted() so that "the first experiment" does not depend on arbitrary os.listdir order.
exp_ids = sorted(f for f in os.listdir(folder) if os.path.isdir(f'{folder}/{f}'))
exp_id = exp_ids[0]
with open(f'{folder}/{exp_id}/experimental_parameters.txt', 'r') as f:
    param_str = f.read().split('\n')
# Tab-separated "<name>\t<value>" lines hold the frame width/height.
params = [s.split('\t') for s in param_str if re.search('Horizontal pixel|Vertical pixel', s)]
# Reset first so stale nrow/ncol from earlier cells cannot be used silently.
nrow = ncol = None
for p in params:
    if re.search('Horizontal pixel', p[0]):
        ncol = int(p[1])
    elif re.search('Vertical pixel', p[0]):
        nrow = int(p[1])
if nrow is None or ncol is None:
    raise ValueError(f'pixel dimensions not found in {folder}/{exp_id}/experimental_parameters.txt')
mat = load_file(f'{folder}/{exp_id}/Sq_camera.bin', size=(-1, nrow, ncol))
# Drop the first 1000 frames, then detrend.
mat = detrend_linear(mat[1000:])
# Mean neighbour-correlation map over the first 16 chunks of 4000 frames.
cor_map = torch.stack([neighbor_cor(mat=mat[4000*i:4000*(i+1)], neighbors=8, choice='mean', nonnegative=True) 
           for i in range(16)], dim=0).mean(0)

In [None]:
# Segment every low-magnification experiment once and cache the label images, correlation maps
# and traces under train_data/; later runs reload the cache.
folder = '/home/jupyter/disk/data/sami/low_mag_cultured_neurons/'
# pickle.load can execute arbitrary code: only load this trusted, locally produced metadata file.
with open(os.path.join(folder, 'meta_data.pkl'), 'rb') as f:
    meta_data = pickle.load(f)
    exp_id_dict = {s[:10]:s for s in sorted(meta_data)}
all_exp_ids = sorted(meta_data)
trace_idx = {n: i for i, n in enumerate(all_exp_ids)}

label_images_path = 'train_data/label_images.npz'
traces_path = 'train_data/traces.pkl'
if not (os.path.exists(label_images_path) and os.path.exists(traces_path)):
    all_traces = []
    all_images = []
    cor_maps = []
    for exp_id in all_exp_ids:
        print(exp_id)
        mat = load_mat(exp_id, meta_data, folder, device=device)
        # the first frame of 7500 frames is missing; skip the initial 50 frames
        mat_list = [mat[799+750*i:1499+750*i] for i in range(9)]
        train_idx = list(range(200)) + list(range(550, 700))
        mat_adj = [detrend_linear(mat=m, train_idx=train_idx) for m in mat_list]
        mat = torch.cat(mat_adj, dim=0)
        del mat_adj, mat_list
        torch.cuda.empty_cache()
        cor_map = neighbor_cor(mat=mat, neighbors=8, choice='mean', nonnegative=True)
        label_image, regions = get_label_image(cor_map, min_pixels=50)
        submats, traces = extract_traces(mat, softmask=cor_map, label_image=label_image, regions=regions, percentile=50)
        label_image, regions = refine_segmentation(submats, regions, label_image)
        submats, traces = extract_traces(mat, softmask=cor_map, label_image=label_image, regions=regions, percentile=50)
        all_traces.append(torch.stack(traces, dim=0).cpu().numpy())
        all_images.append(label_image)
        cor_maps.append(cor_map.detach().cpu().numpy())
    all_images = np.stack(all_images, axis=0)
    cor_maps = np.stack(cor_maps, axis=0)
    # np.savez / open() do not create missing directories.
    os.makedirs(os.path.dirname(label_images_path), exist_ok=True)
    np.savez(label_images_path, all_images=all_images, cor_maps=cor_maps)
    with open(traces_path, 'wb') as f:
        pickle.dump(all_traces, f)
npzfile = np.load(label_images_path)
all_images = npzfile['all_images']
cor_maps = npzfile['cor_maps']
with open(traces_path, 'rb') as f:
    all_traces = pickle.load(f)

In [None]:
# Detrend and segment one fixed experiment, then show the refined labels.
exp_id = 'D1_FOV3_W2_at135105'
mat = load_mat(exp_id, meta_data, folder, device=device)
train_idx = [*range(200), *range(550, 700)]
# the first frame of 7500 frames is missing; skip the initial 50 frames
segments = [mat[799 + 750*i : 1499 + 750*i] for i in range(9)]
mat = torch.cat([detrend_linear(mat=seg, train_idx=train_idx) for seg in segments], dim=0)
del segments
torch.cuda.empty_cache()
cor_map = neighbor_cor(mat=mat, neighbors=8, choice='mean', nonnegative=True)
label_image, regions = get_label_image(cor_map, min_pixels=50)
trace_kwargs = dict(softmask=cor_map, percentile=50)
submats, traces = extract_traces(mat, label_image=label_image, regions=regions, **trace_kwargs)
label_image, regions = refine_segmentation(submats, regions, label_image)
submats, traces = extract_traces(mat, label_image=label_image, regions=regions, **trace_kwargs)
plot_image_label_overlay(cor_map, label_image=label_image, regions=regions)

In [None]:
# Inspect one label: its mean image, plus raw and denoised traces in window k.
sel_label = 1
sel_pos = sel_label - 1
submat, region, trace = submats[sel_pos], regions[sel_pos], traces[sel_pos]
minr, minc, maxr, maxc = region.bbox
imshow(submat.mean(0))
k = 8
window = slice(700*k, 700*(k+1))
plt.plot(trace[window].cpu(), '-.')
trace2 = denoise_trace(trace[window])
plt.plot(trace2.cpu(), '-')
plt.show()

In [None]:
# One row per leading-dim entry of submat, all remaining dims flattened into columns.
R = submat.reshape(len(submat), -1)

In [None]:
# Factorise R with the project's NMF implementation (returns factors and iteration count).
nmf_result = non_negative_factorization(R, init=None)
U, V, n_iter = nmf_result

In [None]:
from sklearn.decomposition import NMF
# 5-component sklearn NMF with a fixed seed for its random initialisation.
model_nmf = NMF(random_state=0, init='random', n_components=5)

In [None]:
# Shift R so its minimum is 0 (NMF expects non-negative data), then fit on CPU.
R_shifted = (R - R.min()).detach().cpu()
W = model_nmf.fit_transform(R_shifted)

In [None]:
# Fitted NMF components: row i is component i's loading over the columns of R
# (reshaped to the sub-movie's spatial shape when plotted).
H = model_nmf.components_

In [None]:
# For each NMF component: spatial map, then its weights in window k (raw and denoised).
k = 8
window = slice(700*k, 700*(k+1))
for comp in range(5):
    imshow(H[comp].reshape(submat.shape[1:]))
    weights = W[window, comp]
    plt.plot(weights, '-.')
    trace2 = denoise_trace(torch.from_numpy(weights).float().to(device))
    plt.plot(trace2.cpu(), '-')
    plt.show()

In [None]:
# Relative reconstruction error of the NMF factorisation.
reconstruction = torch.mm(U, V)
torch.norm(R - reconstruction) / torch.norm(R)

In [None]:
# Thresholds from the correlation map inside the selected region's bounding box:
# the median inside the label (high) and the median over background pixels (low).
cor = cor_map[minr:maxr, minc:maxc]
label_sub_img = label_image[minr:maxr, minc:maxc]
in_label = (label_sub_img == sel_label).tolist()
in_background = (label_sub_img == 0).tolist()
high_threshold = cor[in_label].median()
low_threshold = cor[in_background].median()
target = (cor > high_threshold).to(cor.dtype)
mask = ((cor > high_threshold) | (cor < low_threshold)).float()

In [None]:
# Histogram of correlations, with Otsu / low / high thresholds marked.
sorted_cor = cor.reshape(-1).sort().values.cpu()
plot_hist(sorted_cor, n_bins=100)
otsu_threshold = threshold_otsu(cor.cpu().numpy())
plt.axvline(x=otsu_threshold)
plt.axvline(x=low_threshold.cpu().numpy())
plt.axvline(x=high_threshold.cpu().numpy())

In [None]:
# Raw (dash-dot) and denoised (solid) traces of window k, stacked downwards by label.
plot_image_label_overlay(cor_map, label_image=label_image, regions=regions)
plt.figure(figsize=(20, 20))
k = 8
for j, full_trace in enumerate(traces):
    window_trace = full_trace[700*k:700*(k+1)]
    denoised = denoise_trace(window_trace)
    offset = -j * 200
    plt.plot(window_trace.cpu() + offset, '-.', c=colors[j], label=j+1)
    plt.plot(denoised.cpu() + offset, '-', c=colors[j], label=j+1)
plt.legend()
plt.title(exp_id)
plt.show()

In [None]:
# Re-run segmentation and refinement on the current movie, then plot traces stacked upwards.
cor_map = neighbor_cor(mat=mat, neighbors=8, choice='mean', nonnegative=True)
label_image, regions = get_label_image(cor_map, min_pixels=50)
trace_kwargs = dict(softmask=cor_map, percentile=50)
submats, traces = extract_traces(mat, label_image=label_image, regions=regions, **trace_kwargs)
label_image, regions = refine_segmentation(submats, regions, label_image, min_pixels=50)
submats, traces = extract_traces(mat, label_image=label_image, regions=regions, **trace_kwargs)
plot_image_label_overlay(cor_map, label_image=label_image, regions=regions)
plt.figure(figsize=(20, 20))
k = 8
for j, full_trace in enumerate(traces):
    window_trace = full_trace[700*k:700*(k+1)]
    denoised = denoise_trace(window_trace)
    offset = j * 200
    plt.plot(window_trace.cpu() + offset, '-.', c=colors[j], label=j+1)
    plt.plot(denoised.cpu() + offset, '-', c=colors[j], label=j+1)
plt.legend()
plt.title(exp_id)
plt.show()

# Current pipeline

In [None]:
# Denoise and plot window k of every cached trace, one figure per experiment.
model = UNet(in_channels=1, num_classes=1, out_channels=[8, 16, 32], num_conv=2, 
             n_dim=1, kernel_size=3).to(device)
filepath = 'checkpoints/denoise_trace.pt'
model.load_state_dict(torch.load(filepath))
k = 8
for i in range(len(all_traces)):
    plot_image_label_overlay(cor_maps[i], label_image=all_images[i], title=all_exp_ids[i])
    plt.figure(figsize=(20, 20))
    for j, trace in enumerate(all_traces[i]):
        trace = trace[700*k:700*(k+1)]
        # Reuse the model loaded above; without model= denoise_trace rebuilds the UNet and
        # reloads the checkpoint for every single trace.
        trace2 = denoise_trace(torch.from_numpy(trace).to(device), model=model)
        plt.plot(trace + j*200, '-.', c=colors[j], label=j+1)
        plt.plot(trace2.cpu() + j*200, '-', c=colors[j], label=j+1)
    plt.legend()
    plt.title(all_exp_ids[i])
    plt.show()

In [None]:
# Detrend and segment a randomly chosen experiment.
all_ids = sorted(meta_data)
exp_id = all_ids[np.random.choice(len(meta_data))]
print(exp_id)
mat = load_mat(exp_id, meta_data, folder, device=device)
train_idx = [*range(200), *range(550, 700)]
# the first frame of 7500 frames is missing; skip the initial 50 frames
segments = [mat[799 + 750*i : 1499 + 750*i] for i in range(9)]
mat = torch.cat([detrend_linear(mat=seg, train_idx=train_idx) for seg in segments], dim=0)
del segments
torch.cuda.empty_cache()
cor_map = neighbor_cor(mat=mat, neighbors=8, choice='mean', nonnegative=True)
label_image, regions = get_label_image(cor_map, min_pixels=50)
trace_kwargs = dict(softmask=cor_map, percentile=50)
submats, traces = extract_traces(mat, label_image=label_image, regions=regions, **trace_kwargs)
label_image, regions = refine_segmentation(submats, regions, label_image)
submats, traces = extract_traces(mat, label_image=label_image, regions=regions, **trace_kwargs)

In [None]:
# Overlay the current segmentation on the correlation map.
plot_image_label_overlay(cor_map, label_image, regions=regions)

In [None]:
# Re-segment one label's sub-movie. NOTE(review): this rebinds submats/traces/label_image/regions
# to the sub-movie's results, which later cells then mix with the full-frame cor_map.
label_idx = 6
chosen_submat = submats[label_idx - 1]
(submats, traces, soft_attention,
 label_image, regions) = refine_one_label(chosen_submat, min_pixels=10, return_traces=True)
imshow(soft_attention)

In [None]:
# Denoise one refined trace and compare it with the raw trace window by window.
label_idx = 1
sel = label_idx - 1
trace = traces[sel]
pred = denoise_trace(trace)
plot_image_label_overlay(soft_attention, label_image, regions=regions, sel_idx=sel)
zoom_in(pred, seq2=trace)

In [None]:
# For each label: highlight it on the correlation map, then plot raw vs denoised trace.
num_labels = len(traces)
for pos in range(num_labels):
    label_idx = pos + 1
    submat, trace = submats[pos], traces[pos]
    pred = denoise_trace(trace)
    plot_image_label_overlay(cor_map, label_image, regions=regions, sel_idx=pos)
    plt.figure(figsize=figsize)
    plt.plot(trace.cpu(), linestyle='--', linewidth=1, color='k', markersize=2, alpha=0.5, label='raw trace')
    plt.plot(pred.detach().cpu(), '-', color='b', alpha=1, label='denoised')
    plt.legend()
    plt.title(label_idx)
    plt.show()

In [None]:
##### Dataset-specific loading and segment configuration, selected by `dataset_name`
##### Low magnification cultured neurons  
if dataset_name == 'low_mag_cultured_neurons_2015-12-18':
    folder = '/home/tma/projects/optical_profiling/data'
    with open(os.path.join(folder, 'meta_data.pkl'), 'rb') as f:
        meta_data = pickle.load(f)
        exp_id_dict = {s[:10]:s for s in sorted(meta_data)}
#     with open(os.path.join(folder, 'videos.pkl'), 'rb') as f:
#         videos = pickle.load(f)
    if random_file:
        exp_id = sorted(meta_data)[np.random.choice(len(meta_data))]
    # exp_id = 'D1_FOV3_W2_at135105'
    # exp_id = 'F1_FOV3_W1_at143440'
    # exp_id = 'D1_FOV3_W1_at135000'
    # exp_id = 'D1_FOV2_W2_at134631'
    # exp_id = 'E3_FOV2_W2_at155722'
    # exp_id = 'E3_FOV2_W1_at155622'
    period = 750
    signal_start = 250
    signal_end = 500
    shift = -1
    num_segments = 10
    skip_segments = 1
    trim_size_left = 75
    trim_size_right = 25
    train_left_skip = 0 # skip from right
    train_right_skip = 75 # skip from left
    
#     # PCA/ICA result from Sami
#     mat = loadmat(f'{folder}/D1_FOV3_W2_at135105_segmentation.mat')
#     W = mat['sourceImage_out'].reshape(-1, 512, 180)
#     H = mat['trace_out']
#     for i in range(42):
#         imshow(W[i].T)
#         plt.figure(figsize=figsize)
#         plt.plot(H[:, i])
#         plt.title(i)
#         plt.show()

if dataset_name == 'pooled_ipsc_2018-09-17':
    root = '/home/tma/tma-disk/sami/2018-09-17'
    use_worldstar_only = True
    # sorted() so the random pick depends only on the RNG, not on os.listdir order.
    folders = sorted(s for s in os.listdir(root) if os.path.isdir(f'{root}/{s}') and s not in ['AnalysisCodeExternal', 
                                                                                              'Coke can movies'])
    if use_worldstar_only:
        folders = [f for f in folders if re.search('worldstar$', f)]
    # Default movie, replaced by a random one when random_file is set. It used to be assigned
    # after the random pick, which silently disabled random_file for this dataset.
    folder = '165201_FOV3_DIV14_MOI5_FOV2_200ms_2OD_v3_worldstar'
#     folder = '164231_FOV3_DIV14_MOI10_FOV7_200ms_2OD_v2_worldstar'
    if random_file:
        folder = folders[np.random.choice(len(folders))]
    print(folder)
    nrow, ncol = 180, 300
    size = [-1, nrow, ncol]
    filepath = f'{root}/{folder}/movie.bin'
    mat = load_file(filepath, size)
    L = mat.size(0)
    if plot:
#         daq = np.loadtxt(f'{root}/{folder}/movie_DAQ.txt', skiprows=1)
#         x = loadmat(os.path.join(root, 'MOI10_img.mat'))['MOI10_imgs']
#         y = loadmat(os.path.join(root, 'MOI10_traces.mat'))['MOI10_traces']
#         for i in range(x.shape[2]):
#             imshow(x[:,:,i], cmap='binary')
#             plt.figure(figsize=(20, 10))
#             plt.plot(y[:, i], 'o--', linewidth=1, markersize=2)
#             plt.title(i)
#             plt.show()
        plt.figure(figsize=(20, 10))
        plt.plot(mat.mean(-1).mean(-1).cpu())
        plt.show()
        extract_super_pixels(mat_adj=None, test_left=None, test_right=None, mat_cat=mat, plot=True)
    period = 500
    signal_start = 0
    signal_end = 100
    shift = 0
    trim_size_left = 0
    trim_size_right = 0
    num_segments = L // period
    skip_segments = 1
    train_left_skip = 0 # skip from right
    train_right_skip = 50 # skip from left

if dataset_name in ['pooled_ipsc_2018-12-07', 'low_mag_beta_cell']:
    remove_outliers = True
    if dataset_name == 'pooled_ipsc_2018-12-07':
        root = '/home/tma/tma-disk/sami/pooled_ipsc_2018-12-07'
    if dataset_name == 'low_mag_beta_cell':
        root = '/home/tma/tma-disk/sami/2018-12-21 Beta cells'    
    # Base names of the *.bin movies. endswith() replaces the unescaped regex '.bin$', which also
    # matched names such as 'xbin'; sorted() makes filenames[0] deterministic.
    filenames = sorted(f[:-4] for f in os.listdir(root) if f.endswith('.bin'))
    if plot and len(filenames) < 10:
        for filename in filenames:
            size = get_size_from_txt(f'{root}/{filename}.txt')
            mat = load_file(f'{root}/{filename}.bin', size=size)
            plt.figure(figsize=(20, 10))
            plt.plot(mat.mean(-1).mean(-1).cpu())
            plt.title(f'{filename}: {size}')
            plt.show()
            extract_super_pixels(mat_adj=None, test_left=None, test_right=None, mat_cat=mat, plot=True)
    filename = 'D2_NoBlue_highG_3p5V_max_FOV4_at154522' if dataset_name=='low_mag_beta_cell' else filenames[0]
    if random_file:
        filename = filenames[np.random.choice(len(filenames))]
        print(filename)
    size = get_size_from_txt(f'{root}/{filename}.txt')
    mat = load_file(f'{root}/{filename}.bin', size=size)
    if remove_outliers:
        # Record the frame count before filtering; comparing against len(mat) afterwards can
        # never report a removal.
        num_frames = len(mat)
        frame_mean = mat.mean(-1).mean(-1).cpu()
        mask = detect_outliers(frame_mean)
        mat = mat[mask]
        num_removed = num_frames - len(mat)
        if num_removed > 0:
            print(f'Removed {num_removed} outlier frames')  

if dataset_name == 'high_mag_adrenal_cortex':
    folder = '/home/tma/tma-disk/sami/Adrenal Cortex'
    filepath = f'{folder}/floxopatch_glomerulous_20hz.tif'
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(filepath) as im:
        array = np.array([np.array(page) for page in ImageSequence.Iterator(im)])
    mat = torch.from_numpy(array.astype('float32')).to(device)
    if plot:
        plt.plot(mat.mean(-1).mean(-1).cpu())
        plt.show()
        extract_super_pixels(mat_adj=None, test_left=None, test_right=None, mat_cat=mat, plot=True)
        res = loadmat(f'{folder}/pca_ica_out.mat')
        A = res['out'] 
        C = res['vica']
        for i in range(A.shape[-1]):
            imshow(A[:, :, i])
            plt.plot(C[:, i])
            plt.title(i)
            plt.show()

if dataset_name == 'high_mag_beta_cell':
    root = '/home/tma/tma-disk/sami/2018-09-21 DRH347 test on screening rig'
    # sorted() so folders[0] and the random pick do not depend on os.listdir order.
    folders = sorted(s for s in os.listdir(root) if os.path.isdir(f'{root}/{s}') and s != 'AnalysisCodeExternal')
    folder = folders[0]
    if random_file:
        folder = folders[np.random.choice(len(folders))]
    print(folder)
    nrow, ncol = 90, 150
    size = [-1, nrow, ncol]
    filepath = f'{root}/{folder}/movie.bin'
    mat = load_file(filepath, size)
    L = mat.size(0)
    # Pass a CPU tensor to detect_outliers, as the other dataset branch does.
    frame_mean = mat.mean(-1).mean(-1).cpu()
    sel_idx = detect_outliers(frame_mean, whis=5, return_outliers=False)
    if np.sum(sel_idx) < L:
        print(f'Detect {L - np.sum(sel_idx)} outliers and remove these {L - np.sum(sel_idx)} '
              'frames from downstream analysis')
        mat = mat[sel_idx]
    if plot:
    #     daq = np.loadtxt(f'{root}/{folder}/movie_DAQ.txt', skiprows=1)
        plt.figure(figsize=figsize)
        plt.plot(mat.mean(-1).mean(-1).cpu())
        plt.show()
        extract_super_pixels(mat_adj=None, test_left=None, test_right=None, mat_cat=mat, plot=True)
        
if dataset_name in ['low_mag_cultured_neurons_2015-12-18', 'pooled_ipsc_2018-09-17']:
    # Per-period windows: [left0, right0) is the trimmed segment, [left1, right1) the signal part
    # (both relative to the period start, shifted by `shift`).
    signal_length = signal_end - signal_start
    left0 = trim_size_left + shift
    right0 = period - trim_size_right + shift
    left1 = signal_start + shift
    right1 = signal_end + shift
    train_size_left = left1 - left0 - train_left_skip
    train_size_right = right0 - right1 - train_right_skip
    seg_idx = range(skip_segments, num_segments)
    num_seg = num_segments - skip_segments
    start0 = [period*i + left0 for i in seg_idx]
    end0 = [period*i + right0 for i in seg_idx]
    start1 = [period*i + left1 for i in seg_idx]
    end1 = [period*i + right1 for i in seg_idx]
    length0 = right0 - left0
    length1 = right1 - left1
    test_left = left1 - left0
    test_right = right1 - left0

In [None]:
# Denoise one ICA component trace and compare it with the original in 500-sample windows.
i = 0
component_trace = C[:, i]
pred = denoise_trace(torch.from_numpy(component_trace).float().to(device))
zoom_in(seq=pred, seq2=component_trace, batch_size=500)

In [None]:
def run_pipeline(mat, denoise=0, num_neighbors=4, cor_choice='mean', connectivity=1, use_detrend=True, weight_percentile=50, 
                 low_magnification=True, use_mean_image=False):
    """Detrend a movie, segment it into super-pixels and extract one trace per region.

    Relies on these notebook globals, set by the dataset-configuration cell: ``start0``, ``end0``,
    ``train_size_left``, ``train_size_right``, ``test_left``, ``test_right``, ``period``,
    ``skip_segments``, ``signal_start``, ``signal_end``, ``plot``, ``device``, ``exp_id``,
    ``meta_data`` and ``folder``.

    Args:
        mat: movie tensor; ``mat.size(0)`` counts frames (split into ``period``-long segments).
        denoise (int): 0 = none, 1 = ``pmd_compress`` low-rank reconstruction,
            2 = ``step_decompose`` with 10 components.
        num_neighbors, cor_choice, use_mean_image: forwarded to ``extract_super_pixels``.
        connectivity: NOTE(review): accepted but not forwarded. ``extract_super_pixels`` is always
            called with ``connectivity=None``; confirm whether that is intended.
        use_detrend (bool): if False, traces come from the raw ``mat[start0:end0]`` segments
            instead of the detrended movie.
        weight_percentile: forwarded to ``get_submat_traces``.
        low_magnification (bool): use per-segment ``detrend`` (True) or
            ``detrend_high_magnification`` (False).

    Returns:
        (cor_global, label_image, regions, traces, submats)
    """
    if low_magnification:
        mat_adj = detrend(mat, start0, end0, train_size_left, train_size_right, linear_order=3, use_mean_bg=False, plot=plot, 
                          test_left=test_left, test_right=test_right, device=device, exp_id=exp_id, meta_data=meta_data, 
                          folder=folder, show_singular_values=False)
        cor_global, label_image, regions = extract_super_pixels(mat_adj=mat_adj, test_left=test_left, test_right=test_right, 
                                                        num_neighbors=num_neighbors, cor_choice=cor_choice, connectivity=None, 
                                                        plot=plot, use_mean_image=use_mean_image)
        mat_adj = torch.cat(mat_adj)
    else:
        L = mat.size(0)
        num_segments = L // period
        # Use the notebook's selected device (as the low-magnification branch does) instead of
        # hard-coding CUDA, which fails on CPU-only machines.
        mat_adj = detrend_high_magnification(mat, skip_segments=skip_segments, num_segments=num_segments, period=period, 
                                             train_size_left=train_size_left, train_size_right=train_size_right, 
                                             linear_order=3, plot=False, signal_start=signal_start, signal_end=signal_end, 
                                             filepath=None, size=None, device=device, start0=None, end0=None, 
                                             return_mat=False)
        cor_global, label_image, regions = extract_super_pixels(mat_cat=mat_adj, mat_adj=None, test_left=None, test_right=None, 
                                                                num_neighbors=num_neighbors, cor_choice=cor_choice, 
                                                                connectivity=None, 
                                                                plot=plot, use_mean_image=use_mean_image)
    nframe, nrow, ncol = mat_adj.shape
    if denoise == 1:
        size = (nrow, ncol)
        spatial_threshold = get_threshold(size, loss_fn=total_variation)
        temporal_threshold = get_threshold(nframe, loss_fn=second_order_difference)
        # Compress the (pixels x frames) matrix and rebuild the movie from the low-rank factors.
        U, V = pmd_compress(mat_adj.reshape(nframe, -1).T, size, max_num_fails=5, max_num_components=10, tol=1e-1, 
                    spatial_threshold=spatial_threshold, temporal_threshold=temporal_threshold, 
                    verbose=False)
        mat_adj = torch.mm(U, V).T.reshape(nframe, nrow, ncol)
    elif denoise == 2:
        A, B, loss_history = step_decompose(mat_adj.reshape(nframe, -1), num_components=10, verbose=False)
        mat_adj = torch.mm(A, B).reshape(mat_adj.shape)
    traces, submats = get_submat_traces(seg_idx=0, regions=regions, label_image=label_image, 
                            mat_adj=[mat_adj if use_detrend else torch.cat([mat[s:e] for s, e in zip(start0, end0)])], 
                            weight_percentile=weight_percentile, sig_list=None, mat_list=None, cor=cor_global, 
                            weighted_denominator=True, return_name='mat_adj', compare=False, test_left=test_left, 
                            test_right=test_right)
    return cor_global, label_image, regions, traces, submats


def zoom_in_trace(trace, markers='-', trace2=None, figsize=(20, 10), alpha1=1, alpha2=0.5, num_seg=None, 
                  length0=None, test_left=None, test_right=None, title=None):
    """Plot a trace in full, then one zoomed-in figure per segment.

    Green dashed vertical lines mark the test window ``[test_left, test_right]`` inside every segment.

    Args:
        trace: 1-D torch.Tensor or array-like (torch tensors are squeezed, detached and moved to CPU).
        markers: matplotlib format string used for both traces.
        trace2: optional second trace, overlaid in red for comparison.
        figsize: size of every figure.
        alpha1, alpha2: transparency of ``trace`` and ``trace2`` (used in the overview and segment plots).
        num_seg, length0, test_left, test_right: segmentation parameters. Each one left as ``None`` is
            read from the notebook global of the same name when it exists (the previous behaviour);
            otherwise it defaults to ``num_seg=10``, ``length0=ceil(len(trace) / num_seg)``,
            ``test_left=0`` and ``test_right=length0``.
        title: title of the overview figure; defaults to ``f'Label {label_idx}'`` when a global
            ``label_idx`` exists, otherwise no title is drawn.
    """
    def _to_numpy(x):
        # Accept torch tensors (any device, possibly requiring grad) as well as numpy arrays / lists.
        if isinstance(x, torch.Tensor):
            return x.squeeze().detach().cpu().numpy()
        return np.squeeze(np.asarray(x))

    # The original implementation silently depended on these notebook globals; keep honouring them
    # but fall back to sensible defaults so the function also works on a fresh kernel.
    env = globals()
    trace = _to_numpy(trace)
    if trace2 is not None:
        trace2 = _to_numpy(trace2)
    if num_seg is None:
        num_seg = env.get('num_seg', 10)
    if length0 is None:
        length0 = env.get('length0', int(math.ceil(len(trace) / num_seg)))
    if test_left is None:
        test_left = env.get('test_left', 0)
    if test_right is None:
        test_right = env.get('test_right', length0)
    if title is None and 'label_idx' in env:
        title = f'Label {env["label_idx"]}'

    # Overview of the whole trace, with the test window of every segment marked.
    plt.figure(figsize=figsize)
    plt.plot(trace, markers, linewidth=1, color='b', markersize=2, label='trace 1', alpha=alpha1)
    if trace2 is not None:
        plt.plot(trace2, markers, linewidth=1, color='r', markersize=2, label='trace 2', alpha=alpha2)
        plt.legend()
    for i in range(num_seg):
        plt.axvline(test_left + length0 * i, linestyle='--', linewidth=1, color='g')
        plt.axvline(test_right + length0 * i, linestyle='--', linewidth=1, color='g')
    if title is not None:
        plt.title(title)
    plt.show()

    # One figure per segment; x-coordinates restart at 0 inside each segment.
    for i in range(num_seg):
        seg = slice(length0 * i, length0 * (i + 1))
        plt.figure(figsize=figsize)
        plt.plot(trace[seg], markers, linewidth=1, color='b', markersize=2, label='trace 1', alpha=alpha1)
        if trace2 is not None:
            plt.plot(trace2[seg], markers, linewidth=1, color='r', markersize=2, label='trace 2', 
                     alpha=alpha2)
            plt.legend()
        plt.axvline(test_left, linestyle='--', linewidth=1, color='g')
        plt.axvline(test_right, linestyle='--', linewidth=1, color='g')
        plt.title(f'Segment {i+1}')
        plt.show()

In [None]:
# Experiment to analyse. Other IDs used before: 'D1_FOV3_W1', 'F1_FOV3_W1', 'D1_FOV1_W1',
# or an entry of sorted(meta_data), either a fixed index or np.random.choice(len(meta_data)).
exp_id = exp_id_dict['D2_FOV2_W1']
print(exp_id)

# Load the movie for this experiment and report the elapsed wall-clock time in seconds.
start_time = time.time()
mat = load_mat(exp_id, meta_data, folder)
print(time.time() - start_time)

In [None]:
# Detrending option: `detrend` with a 3rd-order linear_order setting. This cell and the two
# detrending cells after it all reassign `mat_adj`; run only the one you want.
# NOTE(review): start0, end0, train_size_left, train_size_right, test_left and test_right are not
# assigned in this part of the notebook — presumably set by an earlier cell; confirm before running.
mat_adj = detrend(mat, start0, end0, train_size_left, train_size_right, linear_order=3, use_mean_bg=False, plot=False, 
                  test_left=test_left, test_right=test_right, device=device, show_singular_values=False)

In [None]:
# Detrending option for high-magnification recordings; reassigns `mat_adj`.
# NOTE(review): skip_segments, num_segments, period, train_size_left/right and signal_start/end are
# presumably defined by an earlier cell — confirm before running.
# Use the notebook-wide `device` (which falls back to CPU) rather than hard-coding CUDA, which fails
# on machines without a GPU.
mat_adj = detrend_high_magnification(mat, skip_segments=skip_segments, num_segments=num_segments, period=period, 
                                     train_size_left=train_size_left, train_size_right=train_size_right, 
                                     linear_order=3, plot=False, signal_start=signal_start, signal_end=signal_end, 
                                     filepath=None, size=None, device=device, start0=None, end0=None, 
                                     return_mat=False)

In [None]:
# Detrending option via `detrend_linear`; also keeps the fitted `trend` so the next cell can plot it
# against the raw movie. Reassigns `mat_adj`, like the two detrending cells above.
mat_adj, trend = detrend_linear(mat, return_trend=True)

In [None]:
# Sanity-check the detrending: average each movie over its last two axes (one value per frame,
# assuming the (frames, rows, cols) layout used elsewhere in the notebook).
plt.figure()
plt.plot(mat.mean(-1).mean(-1).cpu(), 'r-', label='raw mean')
plt.plot(trend.mean(-1).mean(-1).cpu(), 'b--', label='fitted trend mean')
plt.xlabel('frame')
plt.title('Raw movie vs. fitted trend')
plt.legend()
plt.show()

# Explicit figure + show so the detrended curve gets its own titled plot instead of relying on
# inline auto-display (which also echoed the Line2D repr).
plt.figure()
plt.plot(mat_adj.mean(-1).mean(-1).cpu(), 'r-')
plt.xlabel('frame')
plt.title('Detrended movie mean')
plt.show()

In [None]:
# Segmentation pipeline, setting 1: 8-neighbour 'mean' correlation and no connectivity constraint.
# Prints the elapsed seconds, then overlays the resulting labels on the correlation map.
start_time = time.time()
cor_global, label_image, regions, traces, submats = run_pipeline(
    mat,
    num_neighbors=8,
    cor_choice='mean',
    connectivity=None,
    use_detrend=True,
    weight_percentile=50,
    use_mean_image=False,
)
print(time.time() - start_time)
plot_image_label_overlay(cor_global, label_image)

In [None]:
# Segmentation pipeline, setting 2: 4-neighbour 'mean' correlation, connectivity=1 and
# low_magnification=False. Reports the run time, then overlays the labels on the correlation map.
start_time = time.time()
cor_global, label_image, regions, traces, submats = run_pipeline(
    mat,
    num_neighbors=4,
    cor_choice='mean',
    connectivity=1,
    use_detrend=True,
    weight_percentile=50,
    low_magnification=False,
    use_mean_image=False,
)
print(f'Time spent: {time.time() - start_time}')
plot_image_label_overlay(cor_global, label_image)

In [None]:
# One pair of figures per segmented region: the region highlighted on the correlation map, then its
# raw trace. Labels are displayed 1-based while `traces` / `submats` are 0-based lists.
# NOTE(review): `figsize` is a notebook global not set in this part of the notebook.
num_labels = len(traces)
for label_idx in range(1, num_labels + 1):
    submat, label_mask, weight = submats[label_idx - 1]
    trace = traces[label_idx - 1]
    plot_image_label_overlay(cor_global, label_image, sel_idx=label_idx - 1)
    plt.figure(figsize=figsize)
    plt.plot(trace.cpu(), linestyle='-', linewidth=1, color='b', markersize=2, alpha=1,
             label='raw trace')
    plt.title(label_idx)
    plt.show()

In [None]:
# Inspect one region: overlay its raw trace with the output of `denoise_trace` (defined elsewhere in
# the notebook), first over the whole recording, then in consecutive windows of `seg_len` frames.
label_idx = 2
submat, label_mask, weight = submats[label_idx-1]
trace = traces[label_idx-1]
trace_clean = denoise_trace(trace)
plot_image_label_overlay(cor_global, label_image, sel_idx=label_idx-1)
# zoom_in_trace(trace, markers='-', figsize=(20, 15))
plt.figure(figsize=(20, 15))
plt.plot(trace.cpu(), 'k--', alpha=0.5, label='raw')
plt.plot(trace_clean.detach().cpu(), 'b-', label='denoised')
plt.legend()
plt.show()

# Window length in frames. The number of windows is derived from the trace length; the previous
# hard-coded range(9) only covered the first 9 * 650 = 5850 frames.
seg_len = 650
for i in range(math.ceil(len(trace) / seg_len)):
    window = slice(seg_len * i, seg_len * (i + 1))
    plt.figure(figsize=(20, 15))
    plt.plot(trace[window].cpu(), 'k-', alpha=0.5, label='raw')
    plt.plot(trace_clean[window].detach().cpu(), 'b-', label='denoised')
    plt.title(f'Label {label_idx}, frames {window.start}-{min(window.stop, len(trace)) - 1}')
    plt.legend()
    plt.show()

In [None]:
def topk_pos(mat, k=3):
    """Return the multi-dimensional indices of the ``k`` largest entries of ``mat``.

    Args:
        mat: np.ndarray or torch.Tensor of any shape.
        k: number of positions to return; clipped to the number of elements of ``mat``.

    Returns:
        np.ndarray of shape ``(min(k, mat.size), mat.ndim)``. Row ``i`` is the index tuple of the
        ``i``-th largest entry, largest first.

    Raises:
        ValueError: if ``mat`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(mat, np.ndarray):
        flat = mat.reshape(-1)
        # Negating to obtain a descending sort wraps around for unsigned ints (e.g. uint8/uint16
        # images) and raises for bools, so use an exact order-reversing key for those dtypes.
        if flat.dtype.kind == 'u':
            key = np.iinfo(flat.dtype).max - flat
        elif flat.dtype.kind == 'b':
            key = ~flat
        else:
            key = -flat
        # A stable sort makes ties deterministic: the lower flat index comes first.
        topk = np.argsort(key, kind='stable')[:k]
    elif isinstance(mat, torch.Tensor):
        # torch.topk raises when k exceeds the element count; clip so both branches agree.
        topk = mat.reshape(-1).topk(min(k, mat.numel()))[1]
        topk = topk.detach().cpu().numpy()
    else:
        raise ValueError(f'Only handle np.ndarray and torch.Tensor, but mat is of type {type(mat)}')
    shape = mat.shape
    return np.array(np.unravel_index(topk, shape)).T

In [None]:
# Overlay the raw traces of the two brightest pixels of `submat`, ranked by the mean over axis 0
# (frames). To zoom into one 650-frame window instead of the whole recording:
# i = 5
# sel_idx = slice(650*i, 650*(i+1))
sel_idx = slice(len(submat))
sel_pos = topk_pos(submat.mean(0), k=20)
plt.figure(figsize=figsize)
for pos in sel_pos[[0, 1]]:
    trace = submat[sel_idx, pos[0], pos[1]]
    plt.plot(trace.detach().cpu(), alpha=0.5, label=f'{pos}')
plt.legend()
plt.show()

In [None]:
# Denoise the same two pixel traces (`sel_pos`, `sel_idx` from the previous cell) with `model`.
# NOTE(review): `model` must be the trace denoiser presumably loaded in an earlier cell; the cells
# below rebind `model` to 3-D UNets, so re-running this cell after them would use the wrong network.
# i = 5
# sel_idx = slice(650*i, 650*(i+1))
plt.figure(figsize=figsize)
with torch.no_grad():  # inference only; no autograd graph is needed
    for pos in sel_pos[[0, 1]]:  # [[10, 13], [12, 16]]
        trace = submat[sel_idx, pos[0], pos[1]]
        mean = trace.mean()
        std = trace.std()
        # Standardize, apply the model twice, then map back to the original scale.
        pred = model((trace - mean) / std)
        pred = model(pred)
        pred = pred * std + mean
        # Label by pixel position so the two curves are distinguishable in the legend.
        plt.plot(pred.detach().cpu(), label=f'{pos} denoised', alpha=0.5)
plt.legend()
plt.show()

# 3D Denoise

In [None]:
# 3-D denoising: rebuild the UNet, restore its checkpoint and denoise the selected region's movie.
model = UNet(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3, 
             kernel_size=[3, 3, 3], same_shape=True).to(device)
# map_location lets a checkpoint saved on GPU load when `device` has fallen back to CPU.
model.load_state_dict(torch.load('checkpoints/3d_denoise.pt', map_location=device))
# Inference mode: disables dropout / uses running batch-norm statistics if the UNet has such layers.
model.eval()

with torch.no_grad():
    submat_clean = model(submat)

In [None]:
# Compare the region's image averaged over axis 0 (frames) before and after 3-D denoising.
imshow(submat.mean(0))
imshow(submat_clean.mean(0))

In [None]:
# Raw vs. 3-D-denoised time course of a single pixel over a fixed frame window.
# NOTE(review): the pixel and window are hard-coded for one particular region; adjust per region.
pix_row, pix_col = 15, 31
frame_win = slice(3000, 4000)
plt.plot(submat[frame_win, pix_row, pix_col].detach().cpu(), alpha=0.5)
plt.title(f'raw: pixel ({pix_row}, {pix_col}), frames {frame_win.start}-{frame_win.stop - 1}')
plt.show()
plt.plot(submat_clean[frame_win, pix_row, pix_col].detach().cpu(), alpha=0.5)
plt.title(f'3D denoised: pixel ({pix_row}, {pix_col}), frames {frame_win.start}-{frame_win.stop - 1}')
plt.show()

In [None]:
# Segmentation network: same UNet architecture, restored from the count/hard-mask checkpoint.
# Its output on the denoised movie, averaged over axis 0, becomes `pred`, which the next cell
# passes to extract_super_pixels as `image`.
model = UNet(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3, kernel_size=[3, 3, 3], 
             same_shape=True).to(device)

filepath = 'checkpoints/segmentation_count_hardmask.pt'
# map_location lets a checkpoint saved on GPU load when `device` has fallen back to CPU.
model.load_state_dict(torch.load(filepath, map_location=device))
# Inference mode: disables dropout / uses running batch-norm statistics if the UNet has such layers.
model.eval()

with torch.no_grad():
    pred = model(submat_clean).mean(0)

In [None]:
# Re-segment the denoised region: super-pixels computed from `submat_clean`, with the segmentation
# network's output `pred` passed as the guiding image.
# NOTE(review): min_pixels=10 presumably drops regions smaller than 10 pixels — confirm in extract_super_pixels.
cor_global, label_image, regions = extract_super_pixels(mat_adj=None, test_left=None, test_right=None, mat_cat=submat_clean, 
                                                        num_neighbors=4, cor_choice='mean', connectivity=None, 
                                                        min_pixels=10, image=pred, plot=True, use_mean_image=False)

In [None]:
# Extract one trace per new region from the denoised movie, using `cor_global` for weighting (see
# get_submat_traces for the exact weighting rule). Each `submats` entry is unpacked later as
# (submat, label_mask, weight).
traces, submats = get_submat_traces(seg_idx=0, regions=regions, label_image=label_image, 
                        mat_adj=[submat_clean], 
                        weight_percentile=50, sig_list=None, mat_list=None, cor=cor_global, weighted_denominator=True, 
                        return_name='mat_adj', compare=False, test_left=None, test_right=None)

In [None]:
# Overview of the regions found on the denoised movie: for each one, highlight it on the correlation
# map and plot its trace. Labels are displayed 1-based; `traces` / `submats` are 0-based lists.
# NOTE(review): `figsize` is a notebook global not set in this part of the notebook.
num_labels = len(traces)
for label_idx in range(1, num_labels + 1):
    submat, label_mask, weight = submats[label_idx - 1]
    trace = traces[label_idx - 1]
    plot_image_label_overlay(cor_global, label_image, sel_idx=label_idx - 1)
    plt.figure(figsize=figsize)
    plt.plot(trace.cpu(), linestyle='-', linewidth=1, color='b', markersize=2, alpha=1,
             label='raw trace')
    plt.title(label_idx)
    plt.show()

In [None]:
# Inspect one region of the denoised segmentation: raw trace vs. `denoise_trace` output (defined
# elsewhere in the notebook), shown in consecutive windows of `seg_len` frames.
label_idx = 1
submat, label_mask, weight = submats[label_idx-1]
trace = traces[label_idx-1]
trace_clean = denoise_trace(trace)
plot_image_label_overlay(cor_global, label_image, sel_idx=label_idx-1)
# Window length in frames. The number of windows is derived from the trace length; the previous
# hard-coded range(9) only covered the first 9 * 650 = 5850 frames.
seg_len = 650
for i in range(math.ceil(len(trace) / seg_len)):
    window = slice(seg_len * i, seg_len * (i + 1))
    plt.figure(figsize=(20, 15))
    plt.plot(trace[window].cpu(), 'k--', alpha=0.5, label='raw')
    plt.plot(trace_clean[window].detach().cpu(), 'b-', label='denoised')
    plt.title(f'Label {label_idx}, frames {window.start}-{min(window.stop, len(trace)) - 1}')
    plt.legend()
    plt.show()
# zoom_in_trace(trace, markers='-', figsize=(20, 15), trace2=trace_clean)

# Data generated by Trinh

In [None]:
# Recording FOV9 (.bin): load, drop the first 100 frames and keep indices 600-1199 of the last axis
# (assuming `size` is (frames, rows, cols)), then segment it. Note this rebinds the notebook's `mat`.
filepath = '/home/tma/tma-disk/sami/trinh/20190924_155117_FOV9.bin'
mat = load_file(filepath, size=(4000, 200, 2000))

mat = mat[100:, :, 600:1200]

cor_global = neighbor_cor(mat, neighbors=8, plot=True, choice='mean', title='correlation map')

# Weight the correlation map by the mean image and normalise it to a maximum of 1.
cor_global = mat.mean(0) * cor_global
cor_global = cor_global / cor_global.max()

image = cor_global.detach().cpu().numpy()
label_image, regions = get_label_image(image, min_pixels=50, connectivity=1, plot=True)

traces, submats = get_submat_traces(seg_idx=0, regions=regions, label_image=label_image, 
                        mat_adj=[mat], 
                        weight_percentile=50, sig_list=None, mat_list=None, cor=cor_global, weighted_denominator=True, 
                        return_name='mat_adj', compare=False)

plot_image_label_overlay(image, label_image)

# Display every TIFF image in the folder, in sorted order so re-runs are reproducible.
from utility import read_tiff_file
tiff_dir = '/home/tma/tma-disk-tmp/trinh'
for f in sorted(os.listdir(tiff_dir)):
    # Match on the extension only; the previous re.search('.tif', f) treated '.' as a wildcard and
    # matched 'tif' preceded by any character anywhere in the name (e.g. 'motif.txt').
    if f.lower().endswith(('.tif', '.tiff')):
        tiff = read_tiff_file(os.path.join(tiff_dir, f))
        imshow(tiff.T, title=f)