# DECODE - NeuroCyto Fit Batch 2

The purpose of this notebook is to fit all Nikon nd2 sequences within a folder (acquired using the N-STORM microscope). It identifies the channel used and applies the trained model that matches that channel. For a multi-channel acquisition, it processes each channel within the nd2 file with its own trained model.

Currently the two use cases are:
- single channel STORM acquisition with 647 nm excitation and 405 nm pump, resulting in a '405/647 R1' channel
- two-channel PAINT acquisition with 561 and 647 nm excitation, resulting in two '561 R1' and '647 R1' channels
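
As a minimal sketch of how channel names like the ones above can be mapped to a laser wavelength, assuming a plain substring match where the last matching identifier wins. `find_channel_id` is a hypothetical helper written for illustration; it is not part of DECODE or nd2reader.

```python
def find_channel_id(channel_name, channel_ids, default=None):
    """Return the last identifier from channel_ids contained in channel_name.

    Hypothetical helper: falls back to `default` when nothing matches.
    """
    match = default
    for channel_id in channel_ids:
        if channel_id in channel_name:
            match = channel_id
    return match


channel_ids = ['488', '561', '647']
for name in ['405/647 R1', '561 R1', '647 R1']:
    print(name, '->', find_channel_id(name, channel_ids))
```

Note that the 405 nm pump in '405/647 R1' is ignored simply because '405' is not in the list of identifiers.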

In [None]:
import sys

import decode
import decode.utils

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import glob #added for batch

# Read Nikon nd2 files (nd2reader 3.3.0)
# installed using pip in environment, see here: https://rbnvrw.github.io/nd2reader/tutorial.html#installation
from nd2reader import ND2Reader # added for reading nd2

# Progress bar during file reading
# installed ipywidgets and activated them in environment, see "Installation", Point 2 here: https://towardsdatascience.com/ever-wanted-progress-bars-in-jupyter-bdb3988d9cfc
from tqdm.notebook import tqdm, trange # added for progress bar
import time  # added for progress bar

# Currently using DECODE 0.10.0 in the decode10_env environment 
print(f"DECODE version: {decode.utils.bookkeeping.decode_state()}")

## Set parameters: Acquisition, Inference, Camera parameters

**Acquisition parameters:** Set the type of acquisition (STORM or PAINT, 2D or 3D) and define the channel identifiers (possible channels = wavelength of each excitation laser)

In [None]:
# Acquisition type
acqu_type = 'STORM3D' #or STORM2D
#acqu_type = 'PAINT3D' #or PAINT2D

# Channel identifiers
channel_ids = ['488', '561', '647']

**Inference parameters:** Set the device for inference (CUDA or CPU; on our setup, inference on the CPU is about 10 times slower). If you fit on the CPU, you may want to adjust the number of threads when running on a machine with many cores (see below).
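
One possible way to choose these settings automatically is sketched here. `pick_device_and_threads` and its `reserve` argument are hypothetical and only illustrate the idea of falling back to the CPU and leaving one core free; in the notebook, `cuda_available` would typically come from `torch.cuda.is_available()`.

```python
import os


def pick_device_and_threads(cuda_available, n_cores=None, reserve=1):
    """Return (device, threads) for inference.

    Hypothetical helper: uses the first GPU when available, otherwise the CPU,
    and keeps `reserve` cores free for the rest of the system.
    """
    if n_cores is None:
        n_cores = os.cpu_count() or 1
    threads = max(1, n_cores - reserve)
    device = 'cuda:0' if cuda_available else 'cpu'
    return device, threads


# Assumption: no GPU on this machine, so the CPU is selected
print(pick_device_and_threads(cuda_available=False))
```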

In [None]:
device = 'cuda:0'  # or 'cpu'; change the index to use another CUDA device
threads = 8  #  number of threads, useful for CPU heavy computation. Change if you know what you are doing.
worker = 0  # number of workers for data loading. Change only if you know what you are doing.
batch = 100 # 32-40 works for 8GB VRAM and 256x256 px, 96 for the RTX 3090

torch.set_num_threads(threads)  # set num threads

**Camera parameters:** If the camera parameters used for training differ from those of the data to be fitted (e.g. a different EM gain), you can still try to use the model, but you must specify the actual parameters here, because frames are converted to photon units before being passed to the model.
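
To show why these values matter, here is a simplified sketch of the conversion from camera counts (ADU) to photons. It assumes the usual EMCCD relation (subtract the baseline, multiply by electrons per ADU, divide by the EM gain) and ignores quantum efficiency and spurious noise, which DECODE's own `camera.backward` may also take into account. `adu_to_photons` is a hypothetical name.

```python
def adu_to_photons(adu, baseline, e_per_adu, em_gain=None):
    """Approximate photon count for a raw camera value.

    Simplified assumption: photons = (adu - baseline) * e_per_adu / em_gain.
    """
    electrons = (adu - baseline) * e_per_adu
    if em_gain:
        electrons /= em_gain
    return electrons


# Example values matching the N-STORM EMCCD settings used in this notebook
camera = {'baseline': 100, 'e_per_adu': 12.48, 'em_gain': 100}
print(adu_to_photons(1100, **camera))
```

With a wrong EM gain the same raw value maps to a very different photon count, which is why a mismatch between training and acquisition settings degrades the fit.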

In [None]:
# Specify camera parameters of source sequences (if different from model)
over_cam = 0  # set to 1 to override the model's camera parameters with the values below
meta = {
    'Camera': {
        'baseline': 100, # N-STORM EMCCD
        'e_per_adu': 12.48,
        'em_gain': 100,
        'read_sigma': 74.4,
        'spur_noise': 0.002  # if you don't know, you can set this to 0
    }
}

## Specify paths: models, source sequences, and output files paths

Here we set up the paths:
- root path is where your library of trained models is found
- framefolder path is the folder where the source sequences are (the folder should contain all nd2 files to be processed and no other nd2 files, as the script will try to process every nd2 file it finds in this folder)
- outfolder path is the folder where the localization h5 files will be saved after fitting each channel of each nd2 file
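
The relation between these folders can be sketched with `pathlib`, under the assumption that each output file keeps the name of its source sequence with the channel identifier appended. `plan_outputs` is a hypothetical helper and the demo runs in a temporary folder rather than on real data.

```python
import tempfile
from pathlib import Path


def plan_outputs(framefolder, outfolder, channel_id):
    """Map every .nd2 file in framefolder to its .h5 output path (hypothetical helper)."""
    plan = []
    for nd2_file in sorted(Path(framefolder).glob('*.nd2')):
        h5_name = f"{nd2_file.stem}_{channel_id}.h5"
        plan.append((nd2_file, Path(outfolder) / h5_name))
    return plan


# Demo on a temporary folder; only the .nd2 files are picked up
with tempfile.TemporaryDirectory() as tmp:
    for name in ('cell1.nd2', 'cell2.nd2', 'notes.txt'):
        (Path(tmp) / name).touch()
    demo_plan = plan_outputs(tmp, 'out', '647')
    demo_names = [(src.name, dst.name) for src, dst in demo_plan]

print(demo_names)
```

Using `Path` objects also removes the need to remember the trailing `/` on folder paths.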

To find the right model, the loops will use two things:
- the acquisition type defined at step 2 (STORM3D, PAINT3D)
- the wavelength of the excitation laser that is contained in the channel name in the file metadata (488, 561, 647...)

For each model, copy two files, model_1.pt and param_run.yaml, from the output folder of the corresponding model training into a subfolder of the 'Current Models' folder (defined as the root path). Name each subfolder after the acquisition type and the channel identifier, like this:
- Current Models/STORM3D_647/
- Current Models/STORM3D_561/
- Current Models/PAINT3D_647/
- etc.
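
A quick check that a model folder is complete before starting a long batch might look like the sketch below. It assumes the folder layout described above; `model_files` is a hypothetical helper and the demo builds a throwaway folder instead of touching real models.

```python
import tempfile
from pathlib import Path


def model_files(root_path, acqu_type, channel_id):
    """Return (param_path, model_path) or raise if a file is missing (hypothetical helper)."""
    folder = Path(root_path) / f"{acqu_type}_{channel_id}"
    param_path = folder / 'param_run.yaml'
    model_path = folder / 'model_1.pt'
    missing = [p.name for p in (param_path, model_path) if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"{folder} is missing: {', '.join(missing)}")
    return param_path, model_path


# Demo: one complete model folder, one that does not exist
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp) / 'STORM3D_647'
    folder.mkdir()
    (folder / 'param_run.yaml').touch()
    (folder / 'model_1.pt').touch()
    found = [p.name for p in model_files(tmp, 'STORM3D', '647')]
    try:
        model_files(tmp, 'PAINT3D', '561')
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True

print(found, missing_raised)
```

Failing early with a clear message avoids discovering a missing model only after the first files have already been loaded.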

In [None]:
# Model root path
root_path = 'C:/path_to_your_decode_folder/Current models/'

# Path to folder containing source sequences
framefolder_path = 'W:/path_to_your_nd2_files/' # can be on a server, don't forget the / at the end

# Output path for h5 localizations files
outfolder_path = 'G:/path_where_you_save_the_ah5_loc_files/'# don't forget the / at the end


# Generate the list of nd2 files inside the framefolder
filelist = glob.glob(framefolder_path + "*.nd2")
filelist.sort()
for filepath in filelist :
    filepath = filepath.replace("\\", "/")
    print(filepath)
    

## Main part
### Loop on nd2 files and channels within, fit the data (with the right model) and save the resulting localisations (h5 file)

In [None]:
# Loop on nd2 files
for frame_path in filelist : 

    # Sanitize path
    frame_path = frame_path.replace("\\", "/")
    print("")
    print("")
    print("processing file " + frame_path)
   
    # Load nd2 file    
    ndx = ND2Reader(frame_path)
    sizes = ndx.sizes
    channels = ndx.metadata['channels']
    nchan = len(channels)

    print("")
    print("   loading nd2 sequence with " + str(sizes) + " dimensions with " + str(nchan) + " channel(s): " + str(channels))

    ndx.bundle_axes = 'yx'
    ndx.iter_axes = 't'
    n = len(ndx)   

    # Loop on channels
    for curr_ch, channel in enumerate(channels):

        # Select the current channel (only needed for multi-channel files)
        if nchan > 1:
            ndx.default_coords['c'] = curr_ch

        # Make image stack to host single-channel frames
        shape = (sizes['t'], sizes['y'], sizes['x'])
        image  = np.zeros(shape, dtype=np.float32)

        # Assign frames from nd2 file to image stack
        for i in tqdm(range(n)):
            image[i] = ndx.get_frame(i)

        # Convert image stack to torch stack
        image = np.squeeze(image)
        frames = torch.from_numpy(image)

        # Log the import
        print("")
        print("")
        print("      imported torch stack with shape " + str(tuple(frames.shape)) + " using channel '" + str(channels[curr_ch]) + "'")

        # Find the channel identifier of the current channel
        # (falls back to '647' if no identifier is found in the channel name)
        curr_id = '647'
        for channel_id in channel_ids:
            if channels[curr_ch].find(channel_id) > -1:
                curr_id = channel_id

        # Build the path to the trained model for the current channel
        model_folder = acqu_type + "_" + curr_id + "/"

        # Define path to the model used for fitting
        print("      using trained model stored in " + root_path + model_folder)
        print("")
        param_path = root_path + model_folder + "param_run.yaml"
        model_path = root_path + model_folder + "model_1.pt"

        # Load the trained model
        param = decode.utils.param_io.load_params(param_path)
        model = decode.neuralfitter.models.SigmaMUNet.parse(param)
        model = decode.utils.model_io.LoadSaveModel(model,
                                        input_file=model_path,
                                        output_file=None).load_init(device=device)

        # Overwrite camera if necessary
        if over_cam > 0:
            param = decode.utils.param_io.autofill_dict(meta['Camera'], param.to_dict(), mode_missing='include')
            param = decode.utils.param_io.RecursiveNamespace(**param)

        # Set camera
        camera = decode.simulation.camera.Photon2Camera.parse(param)
        camera.device = 'cpu'


        # Set up frame processing according to the parameters the model was trained with
        frame_proc = decode.neuralfitter.utils.processing.TransformSequence([
            decode.neuralfitter.utils.processing.wrap_callable(camera.backward),
            decode.neuralfitter.frame_processing.AutoCenterCrop(8),
            #decode.neuralfitter.frame_processing.Mirror2D(dims=-1),  # WARNING: mirroring is disabled here; uncomment this line if your data needs it.
            decode.neuralfitter.scale_transform.AmplitudeRescale.parse(param)
        ])


        # Determine extent of frame and its dimension after frame_processing
        size_procced = decode.neuralfitter.frame_processing.get_frame_extent(frames.unsqueeze(1).size(), frame_proc.forward)  # frame size after processing
        frame_extent = ((-0.5, size_procced[-2] - 0.5), (-0.5, size_procced[-1] - 0.5))


        # Set up post-processing
        # Sequence: back-scaling, relative-to-absolute coordinate conversion, and frame-to-emitter conversion
        post_proc = decode.neuralfitter.utils.processing.TransformSequence([

            decode.neuralfitter.scale_transform.InverseParamListRescale.parse(param),

            decode.neuralfitter.coord_transform.Offset2Coordinate(xextent=frame_extent[0],
                                                                  yextent=frame_extent[1],
                                                                  img_shape=size_procced[-2:]),

            decode.neuralfitter.post_processing.SpatialIntegration(raw_th=0.1,
                                                                  xy_unit='px',
                                                                  px_size=param.Camera.px_size)


        ])


        # Fit the data
        infer = decode.neuralfitter.Infer(model=model, ch_in=param.HyperParameter.channels_in,
                                          frame_proc=frame_proc, post_proc=post_proc,
                                          device=device, num_workers=worker, batch_size = batch)

        emitter = infer.forward(frames[:])


        # Check on the output
        print(emitter)


        # Check if the predictions look reasonable on a random frame
        random_ix = torch.randint(frames.size(0), size=(1, )).item()
        em_subset = emitter.get_subset_frame(random_ix, random_ix)

        plt.figure(figsize=(12, 6))
        plt.subplot(121)
        decode.plot.PlotFrameCoord(frame=frame_proc.forward(frames[[random_ix]])).plot()
        plt.subplot(122)
        decode.plot.PlotFrameCoord(frame=frame_proc.forward(frames[[random_ix]]),
                                   pos_out=em_subset.xyz_px, phot_out=em_subset.prob).plot()
        plt.show()


        # Compare the inferred distribution of the photon numbers and background values with the ranges used during training
        plt.figure(figsize=(14,4))

        plt.subplot(131)
        mu, sig = param.Simulation.intensity_mu_sig
        plt.axvspan(0, mu+sig*3, color='green', alpha=0.1)
        sns.distplot(emitter.phot.numpy())
        plt.xlabel('Inferred number of photons')
        plt.xlim(0, 10000)

        plt.subplot(132)
        plt.axvspan(*param.Simulation.bg_uniform, color='green', alpha=0.1)
        sns.distplot(emitter.bg.numpy())
        plt.xlabel('Inferred background values')
        plt.xlim(0, 400)

        plt.show()


        # plot coordinates histograms
        plt.figure(figsize=(18,4))
        plt.subplot(131)
        plt.hist(emitter.xyz_nm[:, 0].numpy(),100)
        plt.xlabel('X (nm)')

        plt.subplot(132)
        plt.hist(emitter.xyz_nm[:, 1].numpy(),100)
        plt.xlabel('Y (nm)')

        plt.subplot(133)
        plt.hist(emitter.xyz_nm[:, 2].numpy(),100)
        plt.xlabel('Z (nm)')

        plt.show()


        # plot uncertainties histograms
        plt.figure(figsize=(18,4))
        plt.subplot(131)
        sns.distplot(emitter.xyz_sig_nm[:, 0].numpy())
        plt.xlabel('Sigma Estimate in X (nm)')
        plt.xlim(0, 100)

        plt.subplot(132)
        sns.distplot(emitter.xyz_sig_nm[:, 1].numpy())
        plt.xlabel('Sigma Estimate in Y (nm)')
        plt.xlim(0, 100)

        plt.subplot(133)
        sns.distplot(emitter.xyz_sig_nm[:, 2].numpy())
        plt.xlabel('Sigma Estimate in Z (nm)')
        plt.xlim(0, 300)

        plt.show()


        # plot raw emitter set
        fig, axs = plt.subplots(2,2,figsize=(24, 12), sharex='col', gridspec_kw={'height_ratios':[1,1200/20000]})

        decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=5, zextent=[-600,600], colextent=[-500,500], plot_axis=(0,1), contrast=1.25).render(emitter, emitter.xyz_nm[:,2], ax=axs[0,0])
        decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=50, zextent=[-600,600], plot_axis=(0,2)).render(emitter, ax=axs[1,0])

        decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=5, zextent=[-600,600], colextent=[0,75], plot_axis=(0,1), contrast=1.25).render(emitter, emitter.xyz_sig_weighted_tot_nm, ax=axs[0,1])
        decode.renderer.Renderer2D(px_size=10., sigma_blur=5., rel_clip=None, abs_clip=50, zextent=[-600,600], colextent=[0,75], plot_axis=(0,2)).render(emitter, emitter.xyz_sig_weighted_tot_nm, ax=axs[1,1])

        plt.show()

        # h5 save
        h5out_path = frame_path.replace(framefolder_path, outfolder_path)
        h5out_path = h5out_path.replace(".nd2", "_" + curr_id + ".h5")

        emitter.save(h5out_path)  # can be loaded later via decode.EmitterSet.load(h5out_path)
        print("saved file " + h5out_path)

