### Clone and install the SIMCODE git repository

In [None]:
# Clone the SIMCODE source code and install it into the current environment.
# Lines starting with % are IPython magics and lines starting with ! are shell
# commands, so this cell only runs inside a Jupyter/IPython session.
%cd /content
!git clone https://github.com/jcnossen/simcode.git
%cd /content/simcode/
# Editable install (-e) of the package located in the cloned repository root.
!pip install -e .

### Load all packages and configure the training object 

In [None]:
import smlmtorch.util.progbar
# Set to False if your Jupyter notebook does not support JavaScript plugins.
# NOTE(review): presumably this makes smlmtorch fall back to a plain-text
# progress bar; it is assigned before the other smlmtorch imports, which may be
# intentional -- confirm in smlmtorch.util.progbar.
smlmtorch.util.progbar.USE_AUTO_TQDM = False

import numpy as np
import matplotlib.pyplot as plt
from smlmtorch import config_dict
from smlmtorch.nn.sf_model import SFLocalizationModel
from smlmtorch.nn.model_trainer import LocalizationModelTrainer
from smlmtorch.nn.benchmark.compare_crlb import CRLBPlotGenerator

# Shared constants referenced by the configuration, trainer and plots.
device = 'cuda:0'

L = 64  # base feature count; the model config uses L, L*2 and L*4
gauss_sigma = 1.3  # used for both entries of the Gaussian2D PSF sigma

moving_wnd_size = 6  # num intensities
center_frame_ix = 2  # within moving window of 6

bg_mean = 5  # for crlb plot

# these are the configuration parameters for the model as used in the paper, most likely overkill though
# Configuration parameters for the model as used in the paper (most likely
# overkill though). The top-level entries are: model, loss, optimizer,
# learning-rate scheduler, dataset/batch sizes, simulation and benchmark.
config = config_dict(
    # --- network settings ---
    model=dict(
        enable_readnoise=False,
        enable3D=False,
        unet_shared_features=[L, L * 2],
        unet_combiner_features=[L, L * 2, L * 4],
        ie_input_features=32,  # number of features going from combiner to IE
        unet_batch_norm=True,
        input_scale=0.01,  # get pixel values into useful range
        input_offset=3,
        unet_combiner_output_features=256,
        output_head_features=48,
        num_intensities=moving_wnd_size,
        max_bg=100,
        max_intensity=20000,
        xyz_scale=[1.1, 1.1, 1],
        unet_intensity_features=[L, L * 2, L * 4],
        output_intensity_features=32,
        use_on_prob=False,
    ),
    # --- loss settings ---
    loss=dict(
        gmm_components=0,
        count_loss_weight=0.01,
        track_intensities_offset=center_frame_ix,
    ),
    # --- optimisation settings ---
    optimizer_type='Lion',
    optimizer=dict(
        lr=2e-5,
        weight_decay=0.01,
    ),
    # Gradient clipping is currently disabled:
    #clip_grad_norm_ = dict( max_norm=0.03, norm_type=2 ),
    lr_scheduler=dict(step_size=30, gamma=0.5),
    train_size=8 * 1024,
    test_size=1024,
    batch_size=6,

    # --- settings for the simulated training data ---
    simulation=dict(
        num_frames=32,
        img_shape=(32, 32),
        z_range=(-0.5, 0.5),
        density_um2=1.5,
        pixelsize_nm=100,
        mean_on_time=6,

        # Same window size and offset as used by the model and loss sections.
        track_intensities=moving_wnd_size,
        track_intensities_offset=center_frame_ix,
        intensity_distr='log-normal',
        intensity_mode=400,
        intensity_mean=600,
        intensity_mean_min=50,
        intensity_mean_max=10000,
        intensity_fluctuation_std=0.5,  # I_frame = I_spot * (1 + fluctuation * randn())
        bg_max=20,
        bg_min=0.5,
        render_args=dict(
            read_noise_mean=0.5,
            read_noise_std=1,
        ),
        psf=dict(
            type='Gaussian2D',
            sigma=[gauss_sigma, gauss_sigma],
        ),
    ),
    # --- benchmark settings ---
    benchmark=dict(
        prob_threshold=0.5,
        match_distance_px=3,
        kdeplot_params=['N0', 'N5'],
        render_frame_offset=center_frame_ix,
    ),
)

# x,y,z,I,start,end,bg
# NOTE(review): the list above is kept from the original notebook; it presumably
# describes a parameter layout used by the library -- confirm or remove.

# Checkpoints are stored in a directory named after the PSF sigma and feature count.
savedir = f'/content/models/sf_conv_g{gauss_sigma}_L{L}'

trainer = LocalizationModelTrainer(
    config,
    SFLocalizationModel,
    device,
    save_dir=savedir,
    # find the latest checkpoint and continue from there
    load_previous_model=True,
)



### Show the distribution of spot intensities in the training data

In [None]:
# Generate a small dataset and plot the intensities of its active spots.
ds = trainer.data_generator.generate(100)
spots = ds.spots_ds.spots

# Entries whose channel 0 is positive are treated as active; channel 1 of those
# entries is what gets plotted as the spot intensity.
active = spots[:, :, :, 0] > 0
intensities = spots[:, :, :, 1][active]

histogram_path = trainer.save_dir + "/intensities_histogram.svg"

plt.figure(figsize=(4, 2))
plt.hist(intensities, bins=100, range=[0, 3000], density=True)
plt.xlabel('Intensity [photons]')
plt.ylabel('Intensity prob.density')
plt.title('Histogram of spot intensities in training data')
plt.savefig(histogram_path)

### Run the training loop

With the current settings, the model should be OK after roughly 80 to 100 epochs. Note that you'll need at least the A100 runtime to do this in a reasonable time. Our model was trained on a standalone Ubuntu machine with an RTX 3090.

In [None]:
# 10 logarithmically spaced photon counts between 10**2 and 10**3.5.
photon_range = np.logspace(2, 3.5, 10)

# Helper that produces the CRLB comparison plots for the current model.
# NOTE(review): the meaning of the positional 100 is not visible here
# (presumably a sample count) -- confirm against CRLBPlotGenerator.
crlb_plotter = CRLBPlotGenerator(
    trainer.model,
    100,
    trainer.data_generator.psf,
    # presumably param_list names model outputs and psf_param_mapping the
    # matching PSF parameters, entry by entry -- TODO confirm
    param_list=['N0', 'x', 'y'],
    psf_param_mapping=['N', 'x', 'y'],
    sim_config=config.simulation,
    device=device,
)

def plot_callback(epoch, batch, test_output):
    """Show intermediate results during training, in addition to the tensorboard logs.

    Intended to be passed to ``trainer.train`` as ``test_callback``. It
    evaluates the model on the first sample of ``batch``, logs a CRLB
    comparison plot through ``crlb_plotter`` and displays the input frame next
    to the first three output channels (titled 'p', 'x' and 'y').

    Args:
        epoch: Forwarded to the CRLB plotter as ``log_step``.
        batch: Tuple ``(data, camera_calib, labels)``; only the first sample of
            ``data`` and ``camera_calib`` is evaluated, ``labels`` is not used.
        test_output: Not used by this callback.
    """
    data, camera_calib, _labels = batch

    # Indexing with [[0]] selects the first sample but keeps the batch dimension.
    output = trainer.eval_batch(data[[0]], camera_calib[[0]])
    y = trainer.model.to_global_coords(output, revert=True)

    crlb_plotter.plot_photon_range(photon_range, background=bg_mean,
        n_frames=moving_wnd_size, log_writer=trainer.writer, log_step=epoch)

    y = y[0, 0]
    panels = (
        ('Input', data[0, 0]),
        ('p', y[0]),
        ('x', y[1]),
        ('y', y[2]),
    )
    fig, axes = plt.subplots(1, len(panels))
    for axis, (title, image) in zip(axes, panels):
        axis.imshow(image.cpu().numpy())
        axis.set_title(title)
    plt.show()
    # Close this specific figure so figures do not accumulate across epochs.
    plt.close(fig)

# Run the training loop; plot_callback is passed as the test callback.
# NOTE(review): the notebook text mentions on the order of 80-100 epochs while
# this call runs 10 -- presumably the cell is re-run and training resumes from
# the latest checkpoint; confirm.
trainer.train(
    num_epochs=10,
    log_interval=1,
    batch_size=config['batch_size'],
    data_refresh_interval=1,
    test_callback=plot_callback,
    report_interval=1,
)
