# Evaluate the efficacy of gain correction

In [None]:
# Load a trained model
import torch
import math
import gpytorch

from torch.utils.data import TensorDataset, DataLoader


import preprocUtils
import preprocRandomVariables
import preprocLikelihoods
import preprocModels
import preprocKernels

from collections import OrderedDict

# Plotly 
import plotly
from plotly.offline import iplot as plt
from plotly import graph_objs as plt_type
from plotly import graph_objs as go
plotly.offline.init_notebook_mode(connected=True)
import colorcet # For custom colormaps

import nbimporter
from preprocVisualisationTesting import *
from evaluateResults_new import *

import time

# Get a bunch of useful utility functions for loading data and results
from thesis_final_func_defs import *

from IPython.display import clear_output

In [None]:
# Root directory of the Neurofinder data
data_dir = '/nfs/data/gergo/Neurofinder_update/'

# Pieces of the results "stamp" that identify which trained model to load
# (code version, training options, training coverage, model grid)
stamp_git = '_gitsha_' + '2bd0d720de0995be6b0f1795304839f9877cb6c3'
stamp_training_type = '_rPC_1_origPMgain_useNans'
stamp_trainingCoverage = '_targetCoverage_10'
stamp_modelGridType = '_grid_30_7'

# Number of leading principal components removed from the data further down
# (set to None to skip the removal)
remove_PCs = 1

final_stamp = ''.join((stamp_git, stamp_training_type,
                       stamp_trainingCoverage, stamp_modelGridType))

# Prior and likelihood names passed to loadFittedModel
prior, lik = 'expertPrior', 'unampLik'

# When exporting, figures are written out via exportFigure and outputs cleared
exportNow = True
instant_clear_outputs = exportNow
time_stamp = '_20190525T114324'  # hard-coded; getTimestamp() would create a fresh one

In [None]:
# Which Neurofinder recording to analyse: neurofinder.0<data_id>.<subdataset>
data_id = '3'
subdataset = '00'
dataset_name = 'neurofinder.0{}.{}'.format(data_id, subdataset)

In [None]:
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,
#                 data_id = data_id,
#                 subdataset = subdataset,
#                 prior=prior, 
#                 lik=lik,
#                 stamp = final_stamp)

# imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# imagesc(mean_im, pixels_per_micron=1.15)
# imagesc(corr_mean_im, pixels_per_micron=1.15)

In [None]:
device = 'cuda:2'

# Arguments shared by both model loads below; only the likelihood name differs
_model_load_kwargs = dict(
    dataset_name=dataset_name,
    data_dir=data_dir,
    prior=prior,
    stamp=final_stamp,
    device=device,
)

# Fitted model for the likelihood selected by `lik`
(mll, model, likelihood, train_x, train_y,
 dataStats, mean_im, pred_gain_func, corr_mean_im) = loadFittedModel(lik=lik, **_model_load_kwargs)

# The corresponding fit under the linear likelihood, kept as the whole returned tuple
lin_model_list = loadFittedModel(lik='linLik', **_model_load_kwargs)

# Element 0 of the tuple is the mll; take the likelihood parameters from it
lin_likelihood = lin_model_list[0].likelihood
lin_likParams = OrderedDict(lin_likelihood.named_parameters())

# Expected additive "electronic" noise std from the linear model estimate:
# the stored parameter is mapped to a std via sqrt(exp(log_noise))
gauss_noise_std = lin_likParams['log_noise'].data[0].exp().sqrt()


In [None]:
# Show the data statistics returned by loadFittedModel as the cell output
dataStats

In [None]:

# Load the imputed imaging data. The fitted mll is handed to the loader, which
# (per the original notes) lets it correct for the photomultiplier gain.
imgsImputed = loadImputedData(
    dataset_name=dataset_name,
    data_dir=data_dir,
    device=device,
    mll=mll,
)

In [None]:
# Print the fitted likelihood parameters, mapped from their log / logit storage
# back to natural units, plus the resulting maximum photon count.
likParams = OrderedDict(mll.likelihood.named_parameters())


def _param(name):
    """Return the first element of likelihood parameter `name`, on the CPU.

    The model may live on a CUDA device. NumPy functions cannot consume CUDA
    tensors, so every value is moved to the CPU first. The photon-count line
    already did this explicitly.
    """
    return likParams[name].data[0].cpu()


print('offset', _param('offset'))
print('log_noise_pedestal', np.sqrt(np.exp(_param('log_noise_pedestal'))))
print('logit_underamplified_probability', logistic(_param('logit_underamplified_probability')))
print('log_underamplified_amplitude', np.exp(_param('log_underamplified_amplitude')))
print('log_gain', np.exp(_param('log_gain')))
print('log_noise', np.sqrt(np.exp(_param('log_noise'))))
print('max pixel value', imgsImputed.max())
print('max photon count', (imgsImputed.max().cpu() - _param('offset')) / np.exp(_param('log_gain')))

In [None]:
if remove_PCs is not None:
    # Centre each pixel's time course, flatten to (pixels x frames) and take
    # the SVD on the CPU, i.e. principal components of the full dataset
    n_frames = imgsImputed.size(2)
    centred = (imgsImputed - imgsImputed.mean(2).unsqueeze(-1)).view(-1, n_frames).cpu()
    U, S, V = torch.svd(centred)

    # Alternative: PCs from the training data only (this seemed a bad option
    # for locating pixels based on crossCorr)
    # U, S, V = torch.svd((train_y-train_y.mean(1).unsqueeze(-1)))

    # Rank-`remove_PCs` reconstruction U_k diag(S_k) V_k^T of the centred data.
    # Subtracting it in place removes only those components; the per-pixel
    # mean stays in the data.
    n_remove = remove_PCs
    to_remove = U[:, :n_remove] @ torch.diag(S[:n_remove]) @ V[:, :n_remove].t()
    imgsImputed -= to_remove.view(*imgsImputed.size()).to(device)

In [None]:
# Return cached, currently unused GPU memory held by PyTorch's caching allocator
torch.cuda.empty_cache()

In [None]:
# NaN-ignoring maximum grey level of the data; both estimates use the same value
max_gray = preprocUtils.nanmax(imgsImputed.cpu())

# MAP inverse transformation (grey level -> photon count) for the fitted likelihood
gray_levels, inverse_poiss_MAP = getInverseMapEstimate(
    likelihood,
    max_gray_level=max_gray,
    max_photon=float(200)
)

# Same transformation for the linear-likelihood model (smaller photon range)
lin_gray_levels, lin_inverse_poiss_MAP = getInverseMapEstimate(
    lin_likelihood,
    max_gray_level=max_gray,
    max_photon=float(70)
)

# Drop the first two grid entries of each curve (negative values, per the original notes)
gray_levels, inverse_poiss_MAP = gray_levels[2:], inverse_poiss_MAP[2:]
lin_gray_levels, lin_inverse_poiss_MAP = lin_gray_levels[2:], lin_inverse_poiss_MAP[2:]

In [None]:
# Linear-model photon estimate: remove the fitted offset and clip negatives to 0
imgsImputedLin = imgsImputed.cpu().detach() - lin_likelihood.offset.data[0].cpu().detach()
imgsImputedLin[imgsImputedLin < 0] = 0.

# In the linear model the spatial gain and the likelihood gain are only
# determined up to a common factor (an indeterminacy). Resolve it by assuming
# the spatial gain has mean 1 over the training pixels. Then normalise the data
# by the resulting effective total gain.
lin_pred_gain_func = lin_model_list[-2]  # same tuple slot as pred_gain_func
train_pixel_idx = train_x.detach().long().unbind(1)
lin_mean_spatial_gain_train = lin_pred_gain_func[train_pixel_idx].mean().detach()
lin_pred_gain_func = lin_pred_gain_func.div(lin_mean_spatial_gain_train)

total_lin_gain = lin_likelihood.log_gain.data[0].exp().cpu() * lin_mean_spatial_gain_train.cpu()
imgsImputedLin = imgsImputedLin.div(total_lin_gain).detach()

In [None]:
# Convert the movie to photon counts one frame at a time: frames are iterated
# along dim 2 on `device`, each result is moved to the CPU, and the results
# are stacked back along dim 2.
def _frame_to_photon(image):
    return im2photon(image, inverse_poiss_MAP, gray_levels, keep_zeros=True).to('cpu').detach()


photon_frames = []
for frame_idx, frame in enumerate(imgsImputed.permute(2, 0, 1).detach().to(device)):
    photon_frames.append(progress_bar(
        func=_frame_to_photon,
        inp=frame,
        index=frame_idx,
        report=True,
        report_freq=400
    ))
imgsImputedPhoton = torch.stack(photon_frames, dim=2).detach()

# Drop the per-frame list and the last (device) frame so they are not kept alive
del photon_frames, frame

In [None]:
# Divide every frame of the photon movie by the estimated spatial gain map.
# The map gets a trailing axis (unsqueeze(2)) so it broadcasts over frames.
# An earlier variant clamped the gain to [1/100, 100] before dividing.
spatial_gain = pred_gain_func.to('cpu').unsqueeze(2)
imgsImputedCorr = imgsImputedPhoton.div(spatial_gain).detach()

In [None]:
# Load the NaN-marked dataset (on the CPU) so that pixels detected as NaN can be
# excluded from further analyses. The fitted mll is passed through as for the
# imputed data (photomultiplier gain correction).
imgsNan = loadNanData(
    mll=mll,
    device='cpu',
    data_dir=data_dir,
    dataset_name=dataset_name,
)


In [None]:
# Return cached, currently unused GPU memory held by PyTorch's caching allocator
torch.cuda.empty_cache()

In [None]:
# Keep the remaining analysis on the CPU
imgsImputed, train_x = imgsImputed.cpu(), train_x.cpu()

In [None]:
# Copy the NaN pattern (cropped to the number of imputed frames) into every
# derived movie in place, so flagged pixels drop out of later statistics
nanmask = torch.isnan(imgsNan[:, :, :imgsImputed.shape[2]])
for _movie in (imgsImputed, imgsImputedLin, imgsImputedPhoton, imgsImputedCorr):
    _movie[nanmask] = float('nan')
del _movie

#imagesc(preprocUtils.nanmean(imgsImputedCorr.detach(), dim=2))

In [None]:
# Strip any autograd history from the derived movies
imgsImputedLin, imgsImputedPhoton, imgsImputedCorr = (
    t.detach() for t in (imgsImputedLin, imgsImputedPhoton, imgsImputedCorr)
)

## Evaluate likelihood transform (gray -> photon)


Evaluate the affine fit to the mean-variance relationship of the pixels.


In [None]:
#raise("Stop execute all")

In [None]:
vis_skip = 5  # only every vis_skip-th pixel is drawn in the mean-variance scatter plots

#### Affine transform

In [None]:
# Plot the MAP inverse transformation of the *linear* likelihood model
# (grey level -> estimated photon flux). The shaded band spans the grey levels
# shifted by +/- 2 standard deviations of the additive Gaussian noise estimated
# by the linear model.
# NOTE(review): the non-linear figure further down uses a +/- 1 std band;
# confirm which width is intended before comparing the two figures.
band_color = 'rgba(0,0,0,0.2)'
lin_x = lin_gray_levels.view(-1)
lin_y = lin_inverse_poiss_MAP.view(-1)

data = [
    plt_type.Scatter(x=lin_x - 2*gauss_noise_std, y=lin_y,
                     line=dict(color=band_color)),
    plt_type.Scatter(x=lin_x, y=lin_y,
                     fill='tonextx',
                     fillcolor=band_color),
    plt_type.Scatter(x=lin_x + 2*gauss_noise_std, y=lin_y,
                     fill='tonextx',
                     line=dict(color=band_color),
                     fillcolor=band_color),
]

layoutArgs = defaultLayout(scale=1.2)
dict_merge(layoutArgs, dict(
    xaxis=dict(
        title='Grey level in data (a.u.)',
        # Take the extent from the grid that is actually plotted here
        range=[0, lin_gray_levels.max()]
    ),
    yaxis=dict(
        title='Estimated photon flux (photon)',
        showgrid=False
    ),
    colorway=['black'],
    showlegend=False
))

fig = plt_type.Figure(data=data, layout=plt_type.Layout(**layoutArgs))

if exportNow:
    exportFigure(fig, filename='ch1_figResults_InverseML_Lin_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='plot', image='svg')
else:
    plt(fig)

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty

#### Non-linear photomultiplier likelihood transform

In [None]:
# import importlib
# import thesis_final_func_defs
# importlib.reload(thesis_final_func_defs)
# from thesis_final_func_defs import *

In [None]:
# For every grey level on the grid, compute the normalised probability of that
# observation under each candidate photon count and show it as a stacked plot.
# NOTE(review): photon_log_probs is not used in the normalisation below, i.e.
# the photon counts are weighted equally (flat prior) -- confirm this is intended.
max_input = 20.
photon_counts, photon_log_probs = likelihood.getPhotonLogProbs(
    torch.arange(0., max_input).view(1, -1).to(gray_levels.device)
    * torch.ones_like(gray_levels).view(-1, 1),
    max_photon=float(200), reNormalise=False)

# Response distribution at each potential photon count
p_PM = likelihood.createResponseDistributions(photon_counts)

# Log probability of each observed grey level under each photon count
cur_target_slice = gray_levels.view(-1)
targets = cur_target_slice.view(-1, 1)
allLogProbs = torch.cat(
    [p_PM[0].log_prob(targets).view(-1, 1),
     p_PM[1].log_prob(targets).view(-1, 1),
     p_PM[2].log_prob(targets)],
    dim=1)

# For observations at or below zero, use the log CDF at 0 instead of log_prob
non_positive = cur_target_slice <= 0.
if non_positive.sum() > 0:
    allLogProbs[non_positive, :] = torch.cat(
        [p_PM[0].cdf(0.).log().view(-1),
         p_PM[1].cdf(0.).log().view(-1),
         p_PM[2].cdf(0.).log().view(-1)],
        dim=0)

# Normalise each row in a numerically stable way: subtract the row maximum
# before exponentiating. Otherwise rows of very negative log-probabilities
# underflow to 0/0 = NaN.
shifted = allLogProbs - allLogProbs.max(1, keepdim=True)[0]
allProbs = shifted.exp()
allProbs = allProbs.div(allProbs.sum(1, keepdim=True)).detach().data

# Show the first n_max columns individually and lump the rest into one band
n_max = 11

layout_extra = defaultLayout(scale=1.2)
dict_merge(layout_extra, dict(xaxis=dict(title='Grey level in data (a.u.)')))

fig = plotStacked(
    torch.cat([allProbs[:, :n_max], (1. - allProbs[:, :n_max].sum(1)).view(-1, 1)], dim=1),
    gray_levels,
    now=False,
    layout=layout_extra
)

if exportNow:
    exportFigure(fig, filename='ch1_figResults_InverseMarginals_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='plot', image='svg')
else:
    plt(fig)

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # Put sleep statements so browser can clear buffer and doesn't download empty plots
    

In [None]:
# Push the grey levels shifted by -/+ one noise std through the MAP inverse
# transformation (column 0: lower curve, column 1: upper curve)
shifted_gray_levels = torch.stack(
    [gray_levels - gauss_noise_std, gray_levels + gauss_noise_std], dim=1)
gray_level_stds = im2photon(shifted_gray_levels, inverse_poiss_MAP, gray_levels,
                            keep_zeros=False)

In [None]:
# Plot the non-linear MAP inverse transformation (grey level -> photon flux).
# The error band comes from squashing the grey level -/+ one std of the assumed
# Gaussian "electronic" noise through that same photomultiplier-induced
# non-linear transform.
gray_level_stds = im2photon(
    torch.stack([gray_levels - gauss_noise_std, gray_levels + gauss_noise_std], dim=1),
    inverse_poiss_MAP, gray_levels, keep_zeros=False)

band_color = 'rgba(0,0,0,0.2)'
x_grid = gray_levels.view(-1)

# (y values, extra Scatter options) for the lower edge, the curve and the upper edge
trace_specs = [
    (gray_level_stds[:, 0], dict(line=dict(color=band_color))),
    (inverse_poiss_MAP.view(-1), dict(fill='tonextx', fillcolor=band_color)),
    (gray_level_stds[:, 1], dict(fill='tonextx', line=dict(color=band_color),
                                 fillcolor=band_color)),
]
data = [plt_type.Scatter(x=x_grid, y=y_vals, **extra) for y_vals, extra in trace_specs]

layoutArgs = defaultLayout(scale=1.2)
dict_merge(layoutArgs, dict(
    xaxis=dict(title='Grey level in data (a.u.)', range=[0, gray_levels.max()]),
    yaxis=dict(title='Estimated photon flux (photon)', showgrid=False),
    colorway=['black'],
    showlegend=False
))

fig = plt_type.Figure(data=data, layout=plt_type.Layout(**layoutArgs))

if not exportNow:
    plt(fig)
else:
    exportFigure(fig, filename='ch1_figResults_InverseML_NonLin_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='plot', image='svg')

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty
    
    

In [None]:
# # Get the forward model estimates
# inp = gpytorch.random_variables.GaussianRandomVariable(inverse_poiss_MAP.view(-1), 
#                              gpytorch.lazy.DiagLazyVariable(1e-10*torch.ones_like(inverse_poiss_MAP.view(-1)))
#                             )
# pred_at_inverseMAP = likelihood(latent_func = inp, approx = False, max_photon=int(570))
# pred_at_inverseMAP_mean = pred_at_inverseMAP.mean()
# pred_at_inverseMAP_mean

# plot(
#     Y = torch.stack(
#         [pred_at_inverseMAP.mean()-pred_at_inverseMAP.std(),
#         pred_at_inverseMAP.mean(), 
#          pred_at_inverseMAP.mean()+pred_at_inverseMAP.std()],
#         dim = 1
#     ),
#     X =  inverse_poiss_MAP
#     )
         

In [None]:
# plot(
#     X = torch.stack(
#         [gray_levels-2*gauss_noise_std,
#         gray_levels, 
#          gray_levels+2*gauss_noise_std],
#         dim = 1
#     ), 
#     Y=inverse_poiss_MAP.view(-1,1).repeat((1,3)))

In [None]:
# plot(
#     Y = torch.stack(
#         [pred_at_inverseMAP_mean-pred_at_inverseMAP_std/2,
#         pred_at_inverseMAP_mean, 
#          pred_at_inverseMAP_mean+pred_at_inverseMAP_std/2],
#         dim = 1
#     ),
#     X =  inverse_poiss_MAP.view(-1)
#     )

In [None]:
# # # Get the forward model estimates
# inp = gpytorch.random_variables.GaussianRandomVariable(inverse_poiss_MAP.view(-1).log(), 
#                              gpytorch.lazy.DiagLazyVariable(1e-10*torch.ones_like(inverse_poiss_MAP.view(-1)))
#                             )
# pred_at_inverseMAP = likelihood(latent_func = inp, approx = False, max_photon=int(70))
# pred_at_inverseMAP_mean = pred_at_inverseMAP.mean()
# pred_at_inverseMAP_std = pred_at_inverseMAP.std()

# # plot(
# #     X = torch.stack(
# #         [gray_levels-,
# #         gray_levels, 
# #          gray_levels+pred_at_inverseMAP.std()],
# #         dim = 1
# #     ),
# #     Y =  inverse_poiss_MAP.view(-1,1).repeat(1,3)
# #     )

# plot(
#     X = torch.stack(
#         [pred_at_inverseMAP_mean-pred_at_inverseMAP_std/2,
#         gray_levels,
#          pred_at_inverseMAP_mean,
#          pred_at_inverseMAP_mean+pred_at_inverseMAP_std/2],
#         dim = 1
#     ),
#     Y =  inverse_poiss_MAP.view(-1,1).repeat(1,4)
#     )

In [None]:
# # Also, plot the resulting non-linear maximum likelihood transformation curve
# data = plot(X=gray_levels, Y=inverse_poiss_MAP, now = False)
# layoutArgs = defaultLayout(scale=1.2)
# dict_merge(layoutArgs, dict(
#     xaxis = dict(
#         title = 'Grey level in data (a.u.)'
#     ),
#     yaxis = dict(
#         title = 'Estimated photon flux (photon)',
#         showgrid=False
#     ),
#     colorway = ['black']
# ))

# fig = plt_type.Figure(data=data, layout=plt_type.Layout(**layoutArgs))

# if exportNow:
#     exportFigure(fig, filename= 'ch1_figResults_InverseML_NonLin_nf0'+ data_id + subdataset, 
#                  time_stamp = time_stamp, type='plot', image='svg')
# else:
#     plt(fig)

# if instant_clear_outputs:
#     clear_output()

#### Numerical effects of the transforms on the mean-variance relationship (and plots)

In [None]:
# Per-pixel temporal mean and variance of the training pixels under each
# transform, then an affine fit of variance against mean. Under the correct
# gain this tests whether the likelihood recovers it.
def _pixel_mean_var(imgs, pixels, first_frame=1000):
    """Return NaN-aware temporal means and variances of the selected pixels.

    imgs        : movie with frames along the last dimension
    pixels      : pixel coordinates, one pixel per row (as train_x)
    first_frame : frames before this index are skipped

    A pixel is dropped from BOTH outputs when either its mean or its variance
    is NaN. Filtering the two with separate masks could pair one pixel's mean
    with another pixel's variance.
    """
    traces = imgs[pixels.long().unbind(1)][:, first_frame:].contiguous()
    means = preprocUtils.nanmean(traces, -1)
    variances = preprocUtils.nanvar(traces, -1)
    keep = (torch.isnan(means) == 0) & (torch.isnan(variances) == 0)
    return means[keep], variances[keep]


mean_orig, var_orig = _pixel_mean_var(imgsImputed, train_x)
mean_lin, var_lin = _pixel_mean_var(imgsImputedLin, train_x)
mean_photon, var_photon = _pixel_mean_var(imgsImputedPhoton, train_x)
mean_corr, var_corr = _pixel_mean_var(imgsImputedCorr, train_x)

affine_orig_fit = preprocUtils.torchLinReg(mean_orig.view(-1,1).cpu(), var_orig.view(-1,1).cpu(), exact=True)
affine_lin_fit = preprocUtils.torchLinReg(mean_lin.view(-1,1), var_lin.view(-1,1), exact=True)
affine_photon_fit = preprocUtils.torchLinReg(mean_photon.view(-1,1), var_photon.view(-1,1), exact=True)
affine_corr_fit = preprocUtils.torchLinReg(mean_corr.view(-1,1), var_corr.view(-1,1), exact=True)
print(affine_orig_fit[2], affine_orig_fit[0])
print(affine_lin_fit[2], affine_lin_fit[0])
print(affine_photon_fit[2], affine_photon_fit[0])
print(affine_corr_fit[2], affine_corr_fit[0])

In [None]:
# Everything needed to draw one mean-variance figure per transform: the
# subsampled per-pixel means/variances, the affine fit, the export filename
# prefix and the figure title
_gain_plot_specs = [
    ('orig', mean_orig, var_orig, affine_orig_fit,
     'ch1_figResults_GainEstmation_Orig_nf0', 'Original observations'),
    ('lin', mean_lin, var_lin, affine_lin_fit,
     'ch1_figResults_GainEstmation_Lin_nf0', 'Linear likelihood gain correction'),
    ('photon', mean_photon, var_photon, affine_photon_fit,
     'ch1_figResults_GainEstmation_Photon_nf0', 'Photomultiplier likelihood gain correction'),
    ('corr', mean_corr, var_corr, affine_corr_fit,
     'ch1_figResults_GainEstmation_Corr_nf0', 'Photomultiplier likelihood and spatial gain correction'),
]

all_likelihood_gain_plots = {
    name: {
        'mean': means[::vis_skip],
        'var': variances[::vis_skip],
        'fit': fit,
        'fname': fname,
        'title': title,
    }
    for name, means, variances, fit, fname, title in _gain_plot_specs
}

In [None]:
# Draw one mean-variance scatter plot, with its affine fit, per transform
for key, plot_info in all_likelihood_gain_plots.items():

    predictor = plot_info['mean']  # per-pixel temporal mean (x axis)
    target = plot_info['var']      # per-pixel temporal variance (y axis)

    # Elements of the fit: 0 is used as the slope, 1 as the intercept and
    # 2 as the x-offset shown in the annotation
    cur_affine_fit = plot_info['fit']
    y_gain_linear = cur_affine_fit[0]
    y_gain_intercept = cur_affine_fit[1]
    y_gain_offset = cur_affine_fit[2]

    # Fitted line evaluated over the range of observed means
    pred_at = torch.linspace(0, predictor.max(), 100)
    pred_result = (y_gain_linear*pred_at + y_gain_intercept).squeeze()

    # The y axis stops at the top of the fitted line; variances above it are
    # reported as outliers
    y_top = pred_result.max()
    n_outliers = (target > y_top).sum()

    # The lower y limit must come from y-axis quantities (the variances and the
    # fitted line), not from the x-axis means
    y_bottom = min(float(pred_result.min()), float(target.min()))

    layoutArgs = defaultLayout(scale=1.2)
    dict_merge(layoutArgs, dict(
        xaxis=dict(
            title='Mean intensity over time (a.u.)'
        ),
        yaxis=dict(
            title='Variance over time (a.u.<sup>2</sup>)',
            showgrid=True,
            range=[y_bottom, float(y_top)]
        ),
        title=plot_info['title'],

        annotations=[
            dict(
                x=0.25,
                y=0.9,
                xanchor='center',
                yanchor='top',
                xref='paper',
                yref='paper',
                text='{} outliers / {} data'.format(int(n_outliers), target.numel()),
                showarrow=False,
                arrowhead=0,
                ax=0,
                ay=0
            ),
            dict(
                x=0.65,
                y=0.3,
                xanchor='center',
                yanchor='top',
                xref='paper',
                yref='paper',
                text='Affine fit: y = {:0.2f} * (x-{:0.2f})'.format(float(y_gain_linear.data[0]), float(y_gain_offset.data[0])),
                font=dict(color='#CC8C00'),
                showarrow=False,
                arrowhead=0,
                ax=0,
                ay=0
            )
        ]
    ))

    fig = plt_type.Figure(
        data=[
            plt_type.Scattergl(x=predictor, y=target,
                               mode='markers',
                               marker=dict(
                                   opacity=0.3,
                                   color='black',
                                   size=4
                               ),
                               name='Background pixels'
                               ),
            plt_type.Scattergl(x=pred_at, y=pred_result,
                               marker=dict(color='orange'),
                               name='Affine fit'
                               )
        ],
        layout=plt_type.Layout(**layoutArgs)
    )

    if exportNow:
        exportFigure(fig, filename=plot_info['fname'] + data_id + subdataset,
                     time_stamp=time_stamp, type='plot', image='svg')
    else:
        plt(fig)

    if instant_clear_outputs:
        clear_output()
        time.sleep(5)  # Put sleep statements so browser can clear buffer and doesn't download empty plots

# Spatial gain correction

Show images of the spatial gain correction procedure

For each training pixel, get the optimal Lambda (photon flux).

As a numerical result, compare the histograms of the original versus the gain-corrected values.

In [None]:
# Pixels per micron for each dataset family, keyed by data_id ('0'..'4')
all_pixel_per_micron = dict(zip('01234', (1.15, 0.8, 1.15, 1.7, 0.8)))
pixels_per_micron = all_pixel_per_micron[data_id]

In [None]:
#imagesc(pred_gain_func, heatmap=dict(colorscale='div'), pixels_per_micron=pixels_per_micron, image='svg')

In [None]:
# Estimated spatial gain map, drawn with the 'div' colour scale
fig = imagesc(pred_gain_func, pixels_per_micron=pixels_per_micron,
              heatmap=dict(colorscale='div'), now=False)

if not exportNow:
    plt(fig)
else:
    exportFigure(fig, filename='ch1_figResults_SpatialGainMap_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='image', image='svg')

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty

In [None]:
#imagesc(preprocUtils.nanmean(imgsImputed[:,:,1000:].detach(), dim=2), pixels_per_micron=pixels_per_micron, image='svg')

In [None]:
# NaN-aware temporal mean (frames 1000 onward) of the original data
mean_img_orig = preprocUtils.nanmean(imgsImputed[:, :, 1000:].detach(), dim=2)
fig = imagesc(mean_img_orig, pixels_per_micron=pixels_per_micron, now=False)

if not exportNow:
    plt(fig)
else:
    exportFigure(fig, filename='ch1_figResults_SpatialOrigMean_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='image', image='svg')

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty

In [None]:
#imagesc(preprocUtils.nanmean(imgsImputedPhoton[:,:,1000:].detach(), dim=2), pixels_per_micron=pixels_per_micron, image='svg')

In [None]:
# NaN-aware temporal mean (frames 1000 onward) of the photon-count data
mean_img_photon = preprocUtils.nanmean(imgsImputedPhoton[:, :, 1000:].detach(), dim=2)
fig = imagesc(mean_img_photon, pixels_per_micron=pixels_per_micron, now=False)

if not exportNow:
    plt(fig)
else:
    exportFigure(fig, filename='ch1_figResults_SpatialPhotonMean_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='image', image='svg')

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty

In [None]:
#imagesc(preprocUtils.nanmean(imgsImputedCorr[:,:,1000:].detach(), dim=2), pixels_per_micron=pixels_per_micron, image='svg')

In [None]:
# NaN-aware temporal mean (frames 1000 onward) of the spatially gain-corrected data
mean_img_corr = preprocUtils.nanmean(imgsImputedCorr[:, :, 1000:].detach(), dim=2)
fig = imagesc(mean_img_corr, pixels_per_micron=pixels_per_micron, now=False)

if not exportNow:
    plt(fig)
else:
    exportFigure(fig, filename='ch1_figResults_SpatialCorrMean_nf0' + data_id + subdataset,
                 time_stamp=time_stamp, type='image', image='svg')

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty

In [None]:
#raise('Stopping Run All')

#### Get a single Lambda rate per pixel (over the time samples from frame 1000 onward) and compare its histograms before and after gain correction

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Photon-count log-probabilities and response distributions on the grey-level grid
max_input = 20.
photon_inputs = (torch.arange(0., max_input).view(1, -1).to(gray_levels.device)
                 * torch.ones_like(gray_levels).view(-1, 1))
photon_counts, photon_log_probs = likelihood.getPhotonLogProbs(
    photon_inputs, max_photon=float(200), reNormalise=False)

# Response distribution at each potential photon count
p_PM = likelihood.createResponseDistributions(photon_counts)
photon_log_probs_gray = likelihood.getLogProbSumOverTargetSamples(p_PM, gray_levels.view(-1))

# Training-pixel time courses (frames 1000 onward), processed 50 pixels at a time
test_dataset = TensorDataset(imgsImputed[train_x.long().unbind(1)][:, 1000:].contiguous().to(device))
data_loader = DataLoader(test_dataset, batch_size=50, shuffle=False, drop_last=False)

# Loop-invariant device copies
photon_log_probs_gray_dev = photon_log_probs_gray.to(device)
gray_levels_dev = gray_levels.to(device)

out = []
for (batch_traces,) in data_loader:
    log_w = im2logPhotonProb(batch_traces, photon_log_probs_gray_dev, gray_levels_dev)
    # Starting point: argmax index over the last axis of log_w, averaged over
    # the next one (presumably the most likely photon count averaged over
    # frames -- TODO confirm)
    initial_lambda = log_w.max(-1)[1].float().mean(-1)
    out.append(getOptLambda(log_w, lambda_guess=initial_lambda).cpu())

testLambda = torch.cat(out, dim=0).squeeze().detach()
train_gain = pred_gain_func[train_x.long().unbind(1)].view(-1).cpu()
testLambdaCorr = testLambda.div(train_gain).detach()

In [None]:
# Return cached, currently unused GPU memory held by PyTorch's caching allocator
torch.cuda.empty_cache()

In [None]:
# Largest per-pixel rate without / with spatial gain correction (sets histogram ranges)
data_max_vals = [float(preprocUtils.nanmax(rates)) for rates in (testLambda, testLambdaCorr)]
data_max_vals

In [None]:
# Histograms of the per-pixel rates with and without gain correction, to test
# the gain estimation. Both use the common range [0, largest observed rate].
r0, r1, nbins = 0., float(max(data_max_vals)), 100

Ahists, bin_centers = fast_histograms(testLambda.view(1, -1).to(device), r0, r1, nbins, output_device='cpu')
Ahists_corr, bin_centers = fast_histograms(testLambdaCorr.view(1, -1).to(device), r0, r1, nbins, output_device='cpu')

# Column 0: original, column 1: gain corrected
hists = torch.stack([h.detach().view(-1) for h in (Ahists, Ahists_corr)], dim=1).detach()

In [None]:
# Overlaid bar plot of the original vs gain-corrected rate histograms
layoutArgs = defaultLayout()
dict_merge(layoutArgs, dict(
    xaxis=dict(
        title='Mean Inferred Background Photon flux (per pixel, across frames)'
    ),
    yaxis=dict(
        title='Counts per bin',
    ),
    barmode='overlay'
))

# Bar positions and widths, shared by both histograms (computed once). The first
# bar is placed at bin_centers[0]-bin_centers[1], and the last one is shifted by
# half of right_bin_width.
# NOTE(review): the reason for this special treatment of the edge bins is not
# evident here -- confirm.
first_bin_cent = torch.tensor([bin_centers[0]-bin_centers[1]])
right_bin_width = 0.2
bar_x = torch.cat([first_bin_cent, bin_centers[1:-1], torch.tensor([bin_centers[-1]+right_bin_width/2.])])
bar_width = torch.cat([bin_centers[1]-first_bin_cent, bin_centers[2:]-bin_centers[1:-1], torch.tensor([right_bin_width])])

fig = plt_type.Figure(
    data=[
        plt_type.Bar(
            x=bar_x,
            y=hists[:, 0],
            width=bar_width,
            marker=dict(opacity=0.7, color='black'),
            name='Original'
        ),
        plt_type.Bar(
            x=bar_x,
            y=hists[:, 1],
            width=bar_width,
            marker=dict(opacity=0.3, color='orange'),
            name='Gain Corrected'
        ),
    ],
    layout=plt_type.Layout(**layoutArgs)
)

if exportNow:
    exportFigure(fig, filename='ch1_figResultsCorrHistogram_nf0' + data_id + subdataset,
                 time_stamp=time_stamp,
                 type='plot', image='svg')
else:
    plt(fig)

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty

In [None]:
# Show which device the gain-corrected rates are stored on
testLambdaCorr.device

In [None]:
# Same histograms, but with the gain-corrected rates rescaled so that their
# (NaN-aware) mean matches that of the uncorrected rates.
div_corr = preprocUtils.nanmean(testLambdaCorr)/preprocUtils.nanmean(testLambda)

# Scale a copy of the maxima instead of editing data_max_vals in place.
# Otherwise every re-run of this cell divides the corrected maximum again.
matched_max_vals = [data_max_vals[0], data_max_vals[1]/div_corr]

r0 = 0.
r1 = float(max(matched_max_vals))  # common range for both histograms
nbins = 100

Ahists, bin_centers = fast_histograms(testLambda.view(1,-1).to(device), r0, r1, nbins, output_device='cpu')
Ahists_corr, bin_centers = fast_histograms(testLambdaCorr.view(1,-1).div(div_corr).to(device), r0, r1, nbins, output_device='cpu')

# Column 0: original, column 1: gain corrected and mean matched
hists = torch.stack([Ahists.detach().view(-1), Ahists_corr.detach().view(-1)], dim=1).detach()

In [None]:
# Overlaid bar plot of the original vs gain-corrected (mean-matched) histograms
layoutArgs = defaultLayout()
dict_merge(layoutArgs, dict(
    xaxis=dict(
        title='Mean Inferred Background Photon flux (per pixel, across frames)'
    ),
    yaxis=dict(
        title='Counts per bin',
    ),
    barmode='overlay'
))

# Bar positions and widths, shared by both histograms (computed once). The first
# bar is placed at bin_centers[0]-bin_centers[1], and the last one is shifted by
# half of right_bin_width.
# NOTE(review): the reason for this special treatment of the edge bins is not
# evident here -- confirm.
first_bin_cent = torch.tensor([bin_centers[0]-bin_centers[1]])
right_bin_width = 0.2
bar_x = torch.cat([first_bin_cent, bin_centers[1:-1], torch.tensor([bin_centers[-1]+right_bin_width/2.])])
bar_width = torch.cat([bin_centers[1]-first_bin_cent, bin_centers[2:]-bin_centers[1:-1], torch.tensor([right_bin_width])])

fig = plt_type.Figure(
    data=[
        plt_type.Bar(
            x=bar_x,
            y=hists[:, 0],
            width=bar_width,
            marker=dict(opacity=0.7, color='black'),
            name='Original'
        ),
        plt_type.Bar(
            x=bar_x,
            y=hists[:, 1],
            width=bar_width,
            marker=dict(opacity=0.3, color='orange'),
            name='Gain Corrected'
        ),
    ],
    layout=plt_type.Layout(**layoutArgs)
)

if exportNow:
    exportFigure(fig, filename='ch1_figResultsCorrHistogram_MatchedMean_nf0' + data_id + subdataset,
                 time_stamp=time_stamp,
                 type='plot', image='svg')
else:
    plt(fig)

if instant_clear_outputs:
    clear_output()
    time.sleep(5)  # let the browser clear its buffer so exported plots are not empty