In [1]:
from localmd.decomposition import localmd_decomposition, display, factored_svd
import localmd.tiff_loader as tiff_loader
import scipy
import scipy.sparse
import jax
import jax.scipy
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap

import os

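# NOTE(review): orphaned comment — it describes "a temporary flag used to avoid memory allocation errors", but no such flag is set here (the following lines are plain imports); confirm whether the flag was removed intentionally.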
import functools
from functools import partial
import time
import torch
import tifffile
import matplotlib.pyplot as plt

%load_ext autoreload
%load_ext line_profiler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Default PMD parameters. Edit the values in this dict to tune the decomposition;
# everything below is derived from it.
pmd_params_dict = {
    'block_height': 32,
    'block_width': 32,
    'overlaps_height': 16,
    'overlaps_width': 16,
    'window_range': 10000,
    'background_rank': 15,
    'deconvolve': True,
    'max_consec_failures': 1,
    'rank_prune_factor': 0.25,
    'max_components': 100,
    'window_length': 1000
}

# Spatial block size handed to localmd_decomposition.
block_height = pmd_params_dict['block_height']
block_width = pmd_params_dict['block_width']
block_sizes = [block_height, block_width]

# Overlap between neighbouring blocks; an overlap larger than the block is rejected
# and replaced by a fallback of 5.
overlaps_height = pmd_params_dict['overlaps_height']
overlaps_width = pmd_params_dict['overlaps_width']

if overlaps_height > block_height:
    print("Overlaps height was set to be greater than block height, which is not valid")
    print("Setting overlaps to be 5")
    overlaps_height = 5

if overlaps_width > block_width:
    print("Overlaps width was set to be greater than block width, which is not valid")
    print("Setting overlaps to be 5")
    overlaps_width = 5

overlap = [overlaps_height, overlaps_width]

max_consec_failures = pmd_params_dict['max_consec_failures']
rank_prune_factor = pmd_params_dict['rank_prune_factor']

# Frames [start, end] passed to the decomposition as the temporal window.
window_range = pmd_params_dict['window_range']
if window_range <= 0:
    print("Window range must be positive! Resetting to 6000")
    window_range = 6000
start = 0
end = window_range

window_length = pmd_params_dict['window_length']

background_rank = pmd_params_dict['background_rank']
# Read from the parameter dict instead of hard-coding, so editing the dict takes effect.
deconvolve = pmd_params_dict['deconvolve']
deconv_batch = 1000

###THESE PARAMS ARE NEVER MODIFIED
sim_conf = 5

#@markdown Keep run_deconv true unless you do not want to run maskNMF demixing
run_deconv = True
max_components = pmd_params_dict['max_components']

# Passed as frame_corrector_obj; None means no frame-correction object is supplied.
corrector = None

# Batch sizes / storage options forwarded to localmd_decomposition.
tiff_batch_size = 100
pixel_batch_size = 10000
dtype = "float32"

order = "F"

In [3]:
# Raw TIFF movie to compress with PMD (path relative to this notebook); point it at your own data.
DATA_DIR = "../datasets"
input_file = DATA_DIR + "/demoMovie.tif"

In [1]:
# Run the localized PMD decomposition of `input_file` over the frame window [start, end].
# It returns the factors U (sparse), R, s, V plus the per-pixel std/mean images and the
# shape/order of the loaded data (all used by the cells below).
# To line-profile the inner SVD, prefix this call with `%lprun -f factored_svd`.
U, R, s, V, std_img, mean_img, data_shape, data_order = localmd_decomposition(
    input_file,
    block_sizes,
    overlap,
    [start, end],
    max_components=max_components,
    background_rank=background_rank,
    sim_conf=sim_conf,
    tiff_batch_size=tiff_batch_size,
    pixel_batch_size=pixel_batch_size,
    dtype=dtype,
    order=order,
    num_workers=0,
    frame_corrector_obj=corrector,
    max_consec_failures=max_consec_failures,
    rank_prune_factor=rank_prune_factor,
)

# Keep U in CSR form (its data/indices/indptr are saved below) and all factors in float32.
U = U.tocsr().astype("float32")
R = R.astype("float32")
s = s.astype("float32")
V = V.astype("float32")

In [6]:
# Persist the compressed PMD representation so it can be reloaded without rerunning
# the decomposition. U is stored via its CSR components; rebuild it with
#   scipy.sparse.csr_matrix((U_data, U_indices, U_indptr), shape=U_shape)
save_path = "INSERT_SAVE_NAME_HERE.npz"  # TODO: choose an output filename
np.savez(
    save_path,
    fov_shape=data_shape[:2],
    fov_order=data_order,
    U_data=U.data,
    U_indices=U.indices,
    U_indptr=U.indptr,
    U_shape=U.shape,
    # Store the sparse format as a plain string (e.g. "csr"). Saving the class object
    # (type(U)) creates an object array that np.load rejects unless allow_pickle=True.
    U_format=U.format,
    R=R,
    s=s,
    Vt=V,
    mean_img=mean_img,
    # NOTE: despite the key name, this holds std_img (the per-pixel scale used to rescale
    # the PMD output); the key is kept for compatibility with existing loaders.
    noise_var_img=std_img,
)

# Generate a comparison triptych to show how well PMD retains signal from the original movie

In [5]:
#### MODIFY THIS. time_range specifies which temporal subset (frames) of the dataset you'd like to view in the triptych
time_range = [1500, 3000]
####


### DO NOT MODIFY
order = data_order

# V has one column per decomposed frame. Slicing V silently clamps to the frames that
# exist, but reading TIFF pages does not, so clamp the window explicitly to keep the
# reconstruction and the raw frames aligned.
# NOTE(review): assumes column j of V is frame j of the TIFF (i.e. start == 0) — confirm for start > 0.
frame_start = max(time_range[0], 0)
frame_stop = min(time_range[1], V.shape[1])
if frame_stop <= frame_start:
    raise ValueError(
        f"time_range {time_range} does not overlap the {V.shape[1]} decomposed frames"
    )

# Reconstruct the PMD movie for the selected frames: U @ (R @ (diag(s) @ V_crop)).
V_crop = V[:, frame_start:frame_stop]
sV = s[:, None] * V_crop
RsV = R.dot(sV)
mov = U.tocsr().dot(RsV)
mov = mov.reshape((data_shape[0], data_shape[1], -1), order=data_order)

# Load exactly the same frames from the raw movie for comparison.
keys = list(range(frame_start, frame_stop))
original_dataset = tifffile.imread(input_file, key=keys)

!
!
!


In [7]:
#### MODIFY BELOW TO SPATIALLY CROP FIELD OF VIEW (OTHERWISE BELOW CODE SHOWS FULL FOV)
x_range_0 = 0
x_range_1 = 60
y_range_0 = 0
y_range_1 = 80


### DO NOT MODIFY
# Crops get new names: overwriting `mean_img` with a 3-D crop (as earlier versions did)
# breaks this cell on a second run and corrupts any later save of `mean_img`.
original_dataset_crop = original_dataset[:, x_range_0:x_range_1, y_range_0:y_range_1].transpose(1, 2, 0)
std_img_crop = std_img[x_range_0:x_range_1, y_range_0:y_range_1, None]
mean_img_crop = mean_img[x_range_0:x_range_1, y_range_0:y_range_1, None]
mov_crop = mov[x_range_0:x_range_1, y_range_0:y_range_1, :]

# Scale by the per-pixel std and add back the mean so PMD output is comparable to raw frames.
mov_crop_rescaled = mov_crop * std_img_crop + mean_img_crop

# Triptych panels side by side along the width axis: raw | PMD | residual (raw - PMD).
overall_result = np.concatenate(
    [original_dataset_crop, mov_crop_rescaled, original_dataset_crop - mov_crop_rescaled],
    axis=1,
).astype(np.float64)

In [9]:
### BELOW LINE SAVES TRIPTYCH, FOR LOCAL VIEWING (IMAGEJ, ETC.)
# The triptych has three side-by-side panels: the leftmost panel shows the raw movie,
# the center panel shows the PMD output, and the rightmost panel shows the residual
# (the raw data minus the PMD output).
# Frames are moved to the first axis, giving (frames, height, width), and written as float32.
tifffile.imwrite("Demo_SidebySide_Triptych.tiff", overall_result.transpose(2, 0, 1).astype("float32"))