# DECODE - Batch Fit SMLM Data
The purpose of this notebook is to batch-fit acquired sequences using a trained model. Sequences can be tif stacks or nd2 files. For the latter, it uses [nd2reader](https://github.com/rbnvrw/nd2reader), which can be installed in a DECODE environment using:

```pip install nd2reader```

In [None]:
import sys

import decode
import decode.utils

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml

import glob #added for batch
from nd2reader import ND2Reader # added for reading nd2

# report which DECODE state/version this notebook is running with
print("DECODE version: {}".format(decode.utils.bookkeeping.decode_state()))

## Set parameters
Set the device for inference (i.e. CUDA vs. CPU; on our setup, inference on the CPU is about 10 times slower). If you fit on the CPU, you may want to change the number of threads if you have a big machine (see below).

In [None]:
# Inference settings (adapt to your machine):
#   device  - 'cuda:0' for the first GPU, another cuda index for a different GPU, or 'cpu'
#   threads - number of CPU threads torch may use; useful for CPU-heavy computation.
#             Change only if you know what you are doing.
#   worker  - number of data-loading workers; on Windows only 0 works at the moment
#   batch   - frames per inference batch; 32-40 works for 8GB VRAM and 256x256 px frames,
#             96-100 for the RTX 3090
device = 'cuda:0'
threads = 12
worker = 0
batch = 100

torch.set_num_threads(threads)

Here we set whether the fit uses tif sequences or nd2 files as the source.

In [None]:
# Type of source sequence to fit: 'nd2' (read with nd2reader) or 'tif' (tif stacks).
fit_source = 'nd2'
#fit_source = 'tif'

# Later cells treat every value other than 'nd2' as tif, so a typo here would silently
# fit the wrong kind of file. Fail early instead.
if fit_source not in ('nd2', 'tif'):
    raise ValueError(f"fit_source must be 'nd2' or 'tif', got {fit_source!r}")

## Specify paths for the model, parameters and frames (folder containing source sequences)

**Important** If the camera parameters used for training differ from those of the data to be fitted (e.g. a different EM gain), you can still try to use the model, but you must specify the camera parameters of the data here, since the frames are converted to photon units before being forwarded through the model.

In [None]:
# paths to parameters and model

#STORM 2D
#param_path = 'C:/Users/chris/christo/DECODE/out_NSTORM_STORM647_2D_210415/2021-04-24_10-04-38_NeuroCyto-Proc/param_run.yaml'
#model_path = 'C:/Users/chris/christo/DECODE/out_NSTORM_STORM647_2D_210415/2021-04-24_10-04-38_NeuroCyto-Proc/model_0.pt'

#STORM 3D
param_path = 'C:/Users/Christo/Work/Processing/DECODE10/out_NSTORM_STORM647_3D_210415/2021-04-27_15-54-13_NCIS-Analyse1/param_run.yaml'
model_path = 'C:/Users/Christo/Work/Processing/DECODE10/out_NSTORM_STORM647_3D_210415/2021-04-27_15-54-13_NCIS-Analyse1/model_2.pt'

# path to folder containing source sequences
framefolder_path = 'W:/NC_DATA_NSTORM1_#1/Christo/210430 MAP2#3 (div16) SR/' # don't forget the / at the end

# Collect the source sequences of the type selected by `fit_source`
# ('nd2' -> *.nd2, anything else -> *.tif).
# - The folder part is escaped, so characters such as '[' in folder names are not
#   interpreted as glob patterns.
# - Paths are normalised to forward slashes (glob may join with '\\' on Windows) and
#   sorted, so the files are fitted in name order.
file_pattern = "*.nd2" if fit_source == 'nd2' else "*.tif"
filelist = sorted(p.replace("\\", "/") for p in glob.glob(glob.escape(framefolder_path) + file_pattern))
for filepath in filelist:
    print(filepath)

# output path
# NOTE(review): the input folder above is 'MAP2#3' while this output folder is 'MAP2#2'
# -- confirm this is intended and not a stale path from a previous run.
outfolder_path = 'D:/NeuroCyto/data SMLM/210430 MAP2#2 (div16)/Locs ah5/' # don't forget the / at the end
    
# specify camera parameters of source sequences (if different from model)
meta = {
    'Camera': { # N-STORM EMCCD parameters
        'baseline': 100,
        'e_per_adu': 12.48,
        'em_gain': 100,
        'read_sigma': 74.4,
        'spur_noise': 0.002  # if you don't know, you can set this to 0
    }
}

## Load Parameters and Model
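Load the parameter file the model was trained with, together with the trained model itself. The frame processing and post-processing below are set up according to these training parameters.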

In [None]:
# Training parameters (read from param_path).
param = decode.utils.param_io.load_params(param_path)

# Network architecture as configured by the training parameters (no trained weights yet).
architecture = decode.neuralfitter.models.SigmaMUNet.parse(param)

# Load the trained weights from model_path and initialise the model on `device`.
loader = decode.utils.model_io.LoadSaveModel(architecture, input_file=model_path, output_file=None)
model = loader.load_init(device=device)

In [None]:
# Overwrite the training camera parameters with those of the source sequences (`meta`).
#
# autofill_dict(x, reference) fills `x` with the entries it lacks from `reference`
# (per DECODE's param_io), so `x` must have the same nested layout as the full
# parameter dict, i.e. {'Camera': {...}}.
# Passing meta['Camera'] directly, as before, put the camera values at the top level
# and left param.Camera at the training values, so the override had no effect.
#
# Each section is copied with dict(), so the fill cannot mutate `meta` and re-running
# this cell gives the same result.
camera_override = {section: dict(values) for section, values in meta.items()}
param = decode.utils.param_io.autofill_dict(camera_override, param.to_dict(), mode_missing='include')
param = decode.utils.param_io.RecursiveNamespace(**param)

## Fit the Data and Export the Resulting Coordinates as .h5 file

In [None]:
# The camera model and the frame processing depend only on `param`, not on the file
# being fitted, so they are set up once for the whole batch.
camera = decode.simulation.camera.Photon2Camera.parse(param)
camera.device = 'cpu'

# setup frame processing as by the parameter with which the model was trained
frame_proc = decode.neuralfitter.utils.processing.TransformSequence([
    decode.neuralfitter.utils.processing.wrap_callable(camera.backward),
    decode.neuralfitter.frame_processing.AutoCenterCrop(8),
    #decode.neuralfitter.frame_processing.Mirror2D(dims=-1),  # WARNING: You might need to comment this line out.
    decode.neuralfitter.scale_transform.AmplitudeRescale.parse(param)
])


def _load_frames(frame_path):
    """Load one source sequence as a torch tensor.

    For fit_source == 'nd2':
        The file is read frame by frame with ND2Reader into a float32 array of shape
        (t, y, x). The array is squeezed, so singleton dimensions are dropped. The
        file handle is closed afterwards.
    Otherwise:
        The file is read with decode.utils.frames_io.load_tif (presumably a (t, y, x)
        tensor -- TODO confirm).

    Args:
        frame_path: path of the sequence file to load.

    Returns:
        The frames as a torch tensor.
    """
    if fit_source != 'nd2':
        return decode.utils.frames_io.load_tif(frame_path)

    # The context manager closes the nd2 file, so long batches do not leak open handles.
    with ND2Reader(frame_path) as ndx:
        sizes = ndx.sizes
        ndx.bundle_axes = 'yx'  # one y-x image per item ...
        ndx.iter_axes = 't'     # ... with items iterated over t
        image = np.zeros((sizes['t'], sizes['y'], sizes['x']), dtype=np.float32)
        for i in range(len(ndx)):
            image[i] = ndx.get_frame(i)

    return torch.from_numpy(np.squeeze(image))


def _h5_output_path(frame_path):
    """Return the .h5 output path for `frame_path`.

    The file keeps the source file name with its last suffix replaced by '.h5' and is
    placed in `outfolder_path`. Only the file name is changed, never folder names, and a
    missing trailing '/' on outfolder_path is tolerated.
    """
    return (Path(outfolder_path) / Path(frame_path).with_suffix('.h5').name).as_posix()


def _plot_diagnostics(emitter, frames, frame_proc):
    """Show sanity-check figures for one fitted sequence.

    Figures shown:
      1. The processed input of one random frame, with and without the predicted
         localisations.
      2. The inferred photon numbers and background values, with the ranges used
         during training shaded in green.
      3. Histograms of the X/Y/Z coordinates (nm).
      4. Histograms of the estimated X/Y/Z uncertainties (nm).
      5. Rendered views of the emitter set: the left column is coloured by z, the right
         column by the weighted total uncertainty.

    Args:
        emitter: fitted emitter set returned by Infer.forward.
        frames: raw frames the emitters were fitted on (frames are indexed along dim 0).
        frame_proc: the frame processing used for the fit (for the random-frame view).
    """
    # Check that the predictions look reasonable on a random frame.
    random_ix = torch.randint(frames.size(0), size=(1, )).item()
    em_subset = emitter.get_subset_frame(random_ix, random_ix)
    # Indexing with a list keeps the leading frame dimension; processed once, shown twice.
    frame_shown = frame_proc.forward(frames[[random_ix]])

    plt.figure(figsize=(12, 6))
    plt.subplot(121)
    decode.plot.PlotFrameCoord(frame=frame_shown).plot()
    plt.subplot(122)
    decode.plot.PlotFrameCoord(frame=frame_shown,
                               pos_out=em_subset.xyz_px, phot_out=em_subset.prob).plot()
    plt.show()

    # Compare the inferred photon numbers and background values with the ranges used during training.
    # NOTE(review): sns.distplot is deprecated in recent seaborn releases; if the installed
    # seaborn supports it, sns.histplot(..., kde=True, stat='density') is the replacement.
    plt.figure(figsize=(14,4))

    plt.subplot(131)
    mu, sig = param.Simulation.intensity_mu_sig
    plt.axvspan(0, mu+sig*3, color='green', alpha=0.1)
    sns.distplot(emitter.phot.numpy())
    plt.xlabel('Inferred number of photons')
    plt.xlim(0)

    plt.subplot(132)
    plt.axvspan(*param.Simulation.bg_uniform, color='green', alpha=0.1)
    sns.distplot(emitter.bg.numpy())
    plt.xlabel('Inferred background values')

    plt.show()

    # coordinate histograms, one panel per axis
    plt.figure(figsize=(18,4))
    for col, axis_name in enumerate('XYZ'):
        plt.subplot(1, 3, col + 1)
        plt.hist(emitter.xyz_nm[:, col].numpy(), 100)
        plt.xlabel(f'{axis_name} (nm)')
    plt.show()

    # estimated uncertainty histograms, one panel per axis
    plt.figure(figsize=(18,4))
    for col, axis_name in enumerate('XYZ'):
        plt.subplot(1, 3, col + 1)
        sns.distplot(emitter.xyz_sig_nm[:, col].numpy())
        plt.xlabel(f'Sigma Estimate in {axis_name} (nm)')
    plt.show()

    # Rendered emitter set.
    # Top row uses plot_axis=(0, 1); bottom row uses plot_axis=(0, 2), presumably xy and xz views.
    fig, axs = plt.subplots(2,2,figsize=(24, 12), sharex='col', gridspec_kw={'height_ratios':[1,1200/20000]})

    decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=5, zextent=[-600,600], colextent=[-500,500], plot_axis=(0,1), contrast=1.25).render(emitter, emitter.xyz_nm[:,2], ax=axs[0,0])
    decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=50, zextent=[-600,600], plot_axis=(0,2)).render(emitter, ax=axs[1,0])

    decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=5, zextent=[-600,600], colextent=[0,75], plot_axis=(0,1), contrast=1.25).render(emitter, emitter.xyz_sig_weighted_tot_nm, ax=axs[0,1])
    decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=50, zextent=[-600,600], colextent=[0,75], plot_axis=(0,2)).render(emitter, emitter.xyz_sig_weighted_tot_nm, ax=axs[1,1])

    plt.show()


# Make sure the output folder exists before the first save.
Path(outfolder_path).mkdir(parents=True, exist_ok=True)

for frame_path in filelist:

    # sanitize path: normalise Windows backslashes to forward slashes
    frame_path = frame_path.replace("\\", "/")

    frames = _load_frames(frame_path)
    print("processing file " + frame_path)
    print(frames.shape)

    # Extent of a frame after frame processing (AutoCenterCrop may change the frame size).
    # This is computed per file, because sequences may differ in frame size.
    size_procced = decode.neuralfitter.frame_processing.get_frame_extent(frames.unsqueeze(1).size(), frame_proc.forward)  # frame size after processing
    frame_extent = ((-0.5, size_procced[-2] - 0.5), (-0.5, size_procced[-1] - 0.5))

    # Setup post-processing
    # It's a sequence of backscaling, relative to abs. coord conversion and frame2emitter conversion
    post_proc = decode.neuralfitter.utils.processing.TransformSequence([

        decode.neuralfitter.scale_transform.InverseParamListRescale.parse(param),

        decode.neuralfitter.coord_transform.Offset2Coordinate(xextent=frame_extent[0],
                                                              yextent=frame_extent[1],
                                                              img_shape=size_procced[-2:]),

        decode.neuralfitter.post_processing.SpatialIntegration(raw_th=0.1,
                                                              xy_unit='px',
                                                              px_size=param.Camera.px_size)

    ])

    # fit the data
    infer = decode.neuralfitter.Infer(model=model, ch_in=param.HyperParameter.channels_in,
                                      frame_proc=frame_proc, post_proc=post_proc,
                                      device=device, num_workers=worker, batch_size = batch)

    emitter = infer.forward(frames)

    # check on the output
    print(emitter)

    _plot_diagnostics(emitter, frames, frame_proc)

    # h5 save.
    # The output always goes to outfolder_path: the tif branch previously derived the
    # path from frame_path and therefore wrote into the input folder.
    h5out_path = _h5_output_path(frame_path)
    emitter.save(h5out_path)  # can be loaded via 'decode.EmitterSet.load('emitter.h5')'
    print("saved file " + h5out_path)