# Evaluate efficacy of gain-correction

In [1]:
# Load a trained model
import torch
import math
import gpytorch

from torch.utils.data import TensorDataset, DataLoader


import preprocUtils
import preprocRandomVariables
import preprocLikelihoods
import preprocModels
import preprocKernels

from collections import OrderedDict

# Plotly 
import plotly
from plotly.offline import iplot as plt
from plotly import graph_objs as plt_type
plotly.offline.init_notebook_mode(connected=True)
import colorcet # For custom colormaps

import nbimporter
from preprocVisualisationTesting import *

Importing Jupyter notebook from preprocVisualisationTesting.ipynb


In [2]:
#torch.cuda.set_device(0)

In [3]:
def display_results(data_id = '0', prior='noPrior', lik='linLik', stamp = '_00_firstRun_noPCremoved',
                    data_dir='/nfs/data/gergo/Neurofinder_update/', device = 'cpu', 
                    retVars=False, retResults=False, subdataset = '00'):

    dataset_name = 'neurofinder.0' + data_id +'.' + subdataset
    mll = torch.load(data_dir + dataset_name +'/preproc2P/savedModels/mll_' +prior+'_' + lik + stamp, map_location=device)

    model = mll.model
    likelihood = mll.likelihood
    mean_im = mll.mean_im

    train_x = mll.train_x
    train_y = mll.train_y

    dataStats = preprocUtils.getDataStatistics(train_x, train_y)
    
    print(dataStats)
    print(OrderedDict(likelihood.named_parameters()))
    
    
    # Set the model and likelihood in evaluation mode
    model.eval()
    likelihood.eval()
    
    # Create test grids over which we predict for easy visualisations
    n_test_grid = torch.tensor(mean_im.shape)
    n_test_grid_small = 32
    test_x = preprocUtils.create_test_grid(n_test_grid, ndims=2, device=device, a=dataStats['x_minmax'][0,:], b=dataStats['x_minmax'][1,:])
    test_x_small = preprocUtils.create_test_grid(n_test_grid_small, ndims=2, device=device, a=dataStats['x_span'][0][0], b=dataStats['x_span'][0][1])
    
    # Get log_photon counts:
    pred_log_photon = model(test_x)
    if isinstance(model.mean_module, gpytorch.means.ConstantMean):
        pred_gain_func = (pred_log_photon.mean()-model.mean_module.constant.data).exp()
    else:
        pred_gain_func = pred_log_photon.mean().exp()
    
    divBy = pred_gain_func.reshape(*mean_im.shape)
    gainRange = [0.1, 10.]

    corr_mean_im = (mean_im).div(torch.clamp(divBy, min=1./gainRange[1], max=1./gainRange[0]))
    #corr_mean_im = (mean_im-likelihood.offset).div(torch.clamp(divBy, min=1./gainRange[1], max=1./gainRange[0]))+likelihood.offset
    
    if not retVars:
        imagesc(pred_gain_func.reshape(*n_test_grid), 
        title = 'Expected gain function',
           heatmap=dict(colorscale = 'div'))
    
        # Show the original and the gain_corrected images
        imagesc(mean_im, title = 'Mean image')
        imagesc(corr_mean_im, title = 'Corrected mean image')

    if retVars:
        return mll, model, likelihood, train_x, train_y, dataStats, mean_im, pred_gain_func.reshape(*n_test_grid), corr_mean_im
    
    if retResults:
        return mean_im, pred_gain_func.reshape(*n_test_grid), corr_mean_im

In [4]:
stamp_git = '_gitsha_' + '2bd0d720de0995be6b0f1795304839f9877cb6c3'
stamp_training_type = '_rPC_1_origPMgain_useNans'

In [5]:
stamp_trainingCoverage = '_targetCoverage_05'
#stamp_modelGridType = '_grid_30_7'
stamp_modelGridType = '_grid_30_7_largeBatch'
data_id = '0'
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = data_id,  
                prior='noPrior', 
                lik='unampLik', 
                stamp = stamp_git + stamp_training_type + stamp_trainingCoverage + stamp_modelGridType
               )

imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im, pixels_per_micron=1.15)
#imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

SyntaxError: invalid syntax (<ipython-input-5-241a67341eca>, line 10)

In [None]:
data_id = '0'
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = data_id,  prior='noPrior', lik='linLik', 
                stamp = stamp_git + stamp_training_type)

imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im, pixels_per_micron=1.15)
#imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# data_id = '0'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id,  prior='noPrior', lik='unampLik', stamp = '_00_firstRun_noPCremoved')

# imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# imagesc(mean_im, pixels_per_micron=1.15)
# imagesc(corr_mean_im, pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# data_id = '0'
# subdataset = '00'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id, prior='noPrior', lik='poissLik', stamp = '_05_origPMGain_test_03',
#                subdataset = subdataset)

# imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# imagesc(mean_im, pixels_per_micron=1.15)
# imagesc(corr_mean_im, pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# data_id = '4'
# subdataset = '00'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id, prior='expertPrior', lik='poissLik', stamp = '_05_origPMGain_test_05_finegrid_50_3interp',
#                subdataset = subdataset)

# imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# imagesc(mean_im, pixels_per_micron=1.15)
# imagesc(corr_mean_im, pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# data_id = '1'
# subdataset = '00'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id, prior='expertPrior', lik='linLik', stamp = '_05_origPMGain_test_04',
#                subdataset = subdataset)

# imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# imagesc(mean_im, pixels_per_micron=1.15)
# imagesc(corr_mean_im, pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# (train_y[:,0] <= torch.max(likelihood.offset, torch.tensor(0.)).data).sum()

In [None]:
data_id = '0'
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = data_id, prior='expertPrior', lik='unampLik', stamp = '_02_origPMGain_test_05')

imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im.clamp(max=1000), pixels_per_micron=1.15)
#imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# mll.pmGain_y

In [None]:
# imagesc(mll.pmGain_y.reshape(mean_im.shape))

In [None]:
# imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# imagesc(mean_im, pixels_per_micron=1.15)
# imagesc(corr_mean_im, pixels_per_micron=1.15)

In [None]:
# data_id = '0'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id, prior='noPrior', lik='poissLik', stamp = '_02_origPMGain_test_05')

# # imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# # imagesc(mean_im, pixels_per_micron=1.15)
# # imagesc(corr_mean_im.clamp(max=20), pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
# data_id = '0'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id, prior='noPrior', lik='poissLik', stamp = '_02_normPMGain_test_05')

# # imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# # imagesc(mean_im, pixels_per_micron=1.15)
# # imagesc(corr_mean_im.clamp(max=20), pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
data_id = '0'
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = data_id, prior='expertPrior', lik='linLik', stamp = '_05_origPMGain_test_05_finegrid_50_3interp')

imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im, pixels_per_micron=1.15)
#imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
data_id = '0'
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = data_id, prior='noPrior', lik='linLik', stamp = '_05_origPMGain_test_05_finegrid_50_3interp')

imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im, pixels_per_micron=1.15)
#imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
imagesc(corr_mean_im, pixels_per_micron=1.15)

In [None]:
# def logistic(x,  x0=0., k=1., L=1.):
#     return x.add(-x0).mul(-k).exp().add(1).reciprocal().mul(L)

In [None]:
# logistic(likelihood.logit_underamplified_probability)

In [None]:
# data_id = '1'
# mll, model, likelihood, train_x, train_y, \
# dataStats, mean_im, pred_gain_func, corr_mean_im = \
# display_results(retVars=True,data_id = data_id, prior='noPrior', lik='poissLik', stamp = '_02_normPMGain_test_05')

# # imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
# # imagesc(mean_im, pixels_per_micron=1.15)
# # imagesc(corr_mean_im.clamp(max=20), pixels_per_micron=1.15)
# #imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_lik')

In [None]:
def print_priors(module):
    for name, param, prior in module.named_parameter_priors():
        try:
            print(name, float(param), [float(prior.a), float(prior.b)])#, prior.log_prob(param)
        except:
            print(name, prior.log_prob(param))

In [None]:
print_priors(likelihood)

In [None]:
OrderedDict(mll.likelihood.named_parameters())

In [None]:
import torch
import json
import numpy as np
from numpy import array, zeros
from scipy.misc import imread
from glob import glob

# File system management
import os
import errno

from IPython.core.debugger import set_trace
import warnings

from preprocUtils import toTorchParam
import copy

# Evaluation metrics
data_dir='/nfs/data/gergo/Neurofinder_update/'
device = 'cpu'
max_T = 1000

dataset_name = 'neurofinder.0' + data_id +'.00'
#dataset_name = 'neurofinder.0' + data_id +'.00.test'
#dataset_name = 'neurofinder.0' + data_id +'.01.test'
use_validation_data_only = True

files = sorted(glob(data_dir+dataset_name+'/images/*.tiff'))
imgs = np.array([imread(f) for f in files[:min(max_T, len(files))]])

imgs = torch.tensor(imgs.astype(np.float32)).permute(1,2,0).to(device)

savedImgsImputed = sorted(glob(data_dir+dataset_name+'/preproc2P/imgsImputed*.npy'))
        
if savedImgsImputed:
    imgsImputed = torch.tensor(np.load(savedImgsImputed[-1]))
    imgsImputedLoaded = True

    if imgsImputed.size(2) < max_T:
        warnings.warn("""In the saved imgsImputed data there is only {} frames,
                      less than the requested {}, using only available number"""
                      .format(imgsImputed.size(2), max_T))
    else:
        imgsImputed = imgsImputed[:,:,:max_T]
else:
    imgsImputed = imgs

        
if use_validation_data_only:
    imgs = imgs[:,:,500:]
    imgsImputed = imgsImputed[:,:,500:]
        
# Correct for pmGain, if mll has it

if hasattr(mll, 'pmGain_y'):
    imgsImputed.div_(mll.pmGain_y.reshape(*imgsImputed.shape[:2]).unsqueeze(-1))
    
imgsImputed.shape

In [None]:
def getLambdaLogProb(log_w, log_lam, dim=-1, requires_grad=True):
    if dim == -1:
        dim=log_w.ndimension()-1
    # Get the log probabilities (numerically stable), then logsumexp()
    x = torch.arange(log_w.size(dim), device=log_w.device).float().view(*([1]*max(dim,0))+[-1]+[1]*max(log_w.ndimension()-dim-1,0))

    lprobs = (log_w
              - (x+1.).lgamma()
              + x*log_lam
              -log_lam.exp())

    del x
    torch.cuda.empty_cache()

    if requires_grad:
        return lprobs.logsumexp(dim=dim)
    else:
        return lprobs.logsumexp(dim=dim).data
    
# Define an optimiser that starts from MAP estimate and uses lambda log prob as objective
class LambdaOptimiser(torch.nn.Module):
    def __init__(self, log_w, lambda_guess=torch.tensor(1.)):
        super(LambdaOptimiser, self).__init__()
        self.register_buffer("log_w", log_w)
        
        # Put log lambda into the appropriate shape
        log_lam = lambda_guess.float().clamp(min=1e-1).log()
        log_lam = log_lam.view(list(log_lam.size())+[1]*max(0, log_w.ndimension()-log_lam.ndimension()))
        self.register_parameter("log_lam", preprocUtils.toTorchParam(log_lam, 
                                                                     paramShape=log_lam.size(), 
                                                                     device=log_lam.device))
        
    def forward(self):
        return -getLambdaLogProb(self.log_w, self.log_lam).sum()
        

def getOptLambda(log_w, lambda_guess=None): # Opt lambda given discrete estimate of photon distribution
#     if log_w[0]==0: # Special case of certainty
#         return torch.tensor([0.])
    lambda_guess = lambda_guess if lambda_guess is not None else torch.tensor(1., device=log_w.device)
    lamModule = LambdaOptimiser(log_w.detach(), lambda_guess=lambda_guess)
    optim = torch.optim.LBFGS(lamModule.parameters()) 
    # LBFGS does optimisation internally via "closure", no need to iterate outside
    def closure():
        optim.zero_grad()
        loss = lamModule()
        #print(loss, lamModule.log_lam.exp().data)
        loss.backward()
        return loss  

    optim.step(closure)
    
    out = lamModule.log_lam.exp().data
    
    del lamModule
    
    return out
    

    
# class LambdaLikOptimiser(torch.nn.Module):
#     def __init__(self, likelihood, lambda_guess=torch.tensor([1.]), max_photon = 50.):
#         super(LambdaLikOptimiser, self).__init__()
#         self.likelihood = likelihood
#         self.register_parameter("log_lam", preprocUtils.toTorchParam(lambda_guess.float().clamp(min=1e-1).log(), ndims=0))
        
#     def forward(self, cur_target):
#         return -self.likelihood.single_log_prob( 
#                                     self.log_lam.expand(cur_target.size()).view(1,-1), 
#                                     cur_target, batchsize = int(450), max_photon= max_photon).sum()
    
# def getOptLambdaFromLik(data, likelihood, lambda_guess=torch.tensor([1.]), max_photon = 50.):
#     lamModule = LambdaLikOptimiser(likelihood, lambda_guess=lambda_guess, max_photon = max_photon)
#     optim = torch.optim.LBFGS([lamModule.log_lam])
#     def closure():
#         optim.zero_grad()
#         loss = lamModule(data)
#         #print(loss, lamModule.log_lam.exp().data)
#         loss.backward()
#         return loss  
#     optim.step(closure)
        
#     return lamModule.log_lam.exp().data

In [None]:
def im2logPhotonProb(im, photon_log_probs, gray_levels, interpolate=False):
    if not interpolate:
        # Nearest neighbor version
        return photon_log_probs[
            (im.unsqueeze(-1) - gray_levels.view(*([1]*im.ndimension()+[-1]))).abs().min(-1)[1],:]
    else:

        # Interpolation version
        dists = (im.unsqueeze(-1) - gray_levels.view(*([1]*im.ndimension()+[-1]))) # Distances along last dimension

        # Find the last element in gray_levels that im is larger than (so dists>=0), then interpolate or extrapolate appropriately
        last_pos_ind = (((dists>=0).sum(-1))-1).clamp(0, dists.size(-1)-2).unsqueeze(-1)
        interp_inds = torch.cat([last_pos_ind, last_pos_ind+1], dim=-1)

        static_indices = np.indices(interp_inds.shape)
        static_indices[-1] = interp_inds
        dists_sorted = dists[static_indices]


        # Get the two closest distances, and values to linearly inter-/extra-polate
        dist0 = dists_sorted[...,0]
        dist1 = dists_sorted[...,1]

        val0 = photon_log_probs[interp_inds[...,0],:]
        val1 = photon_log_probs[interp_inds[...,1],:]

        dist_sign_same = ((dist0*dist1)>=0.).float()

        denom = (
            # If different sign (interpolate)
            (1.-dist_sign_same)*(dist1.abs()+dist0.abs()) 
            # If same sign (extrapolate)
            + dist_sign_same *((dist0-dist1).abs())
        )


        rx = dist0.unsqueeze(-1)

        m = (val1-val0)/denom.unsqueeze(-1)


        fx = m*rx + val0


        return fx

# imgsImputedCorr = preprocUtils.apply(lambda x: correctImage(x, inverse_poiss_mean, gray_levels), 
#                                      copy.deepcopy(imgsImputed), dim=2)

def im2photon(im, inverse_poiss_MAP, gray_levels, keep_zeros=True):
    # Linear inter/extra-polation version
    
    dists = (im.unsqueeze(-1) - gray_levels.view(*([1]*im.ndimension()+[-1]))) # Distances along last dimension

    # Find the last element in gray_levels that im is larger than (so dists>=0), then interpolate or extrapolate appropriately
    last_pos_ind = (((dists>=0).sum(-1))-1).clamp(0, dists.size(-1)-2).unsqueeze(-1)
    interp_inds = torch.cat([last_pos_ind, last_pos_ind+1], dim=-1)

    static_indices = np.indices(interp_inds.shape)
    static_indices[-1] = interp_inds
    dists_sorted = dists[static_indices]


    # Get the two closest distances, and values to linearly inter-/extra-polate
    dist0 = dists_sorted[...,0]
    dist1 = dists_sorted[...,1]

    val0 = inverse_poiss_MAP[interp_inds[...,0]]
    val1 = inverse_poiss_MAP[interp_inds[...,1]]

    dist_sign_same = ((dist0*dist1)>=0.).float()

    denom = (
        # If different sign (interpolate)
        (1.-dist_sign_same)*(dist1.abs()+dist0.abs()) 
        # If same sign (extrapolate)
        + dist_sign_same *((dist0-dist1).abs())
    )


    rx = dist0

    m = (val1-val0)/denom


    fx = m*rx + val0

    if keep_zeros:
        fx = fx * (im!=0).type(fx.type())
    
    return fx
    
    
    # Nearest neighbor version
    #return inverse_poiss_MAP[(im.unsqueeze(-1) - gray_levels.view(*([1]*im.ndimension()+[-1]))).abs().min(-1)[1]]

# imgsImputedCorr = preprocUtils.apply(lambda x: correctImage(x, inverse_poiss_mean, gray_levels), 
#                                      copy.deepcopy(imgsImputed), dim=2)

In [None]:
# Estimating photon probability from grey level

light_levels = torch.arange(1e-4, 1.,1e-2) # In photon
model_out = gpytorch.random_variables.GaussianRandomVariable(light_levels.log(), gpytorch.lazy.DiagLazyVariable(1e-7*torch.ones_like(light_levels)))


max_photon = float(70)
photon_counts, photon_log_probs = (
    likelihood.getPhotonLogProbs(model_out.mean().view(-1,1).exp(), max_photon=max_photon, reNormalise = False))
p_PM = likelihood.createResponseDistributions(photon_counts)

gray_levels = torch.cat([torch.tensor([-0.02*np.nanmax(train_y), 0., 1e-10]), # Important to add a non-zero value close to zero
                         torch.logspace(-3, imgsImputed.max().log10(),100)])
#                           torch.linspace(0., 3., 100),
#                          torch.linspace(3.1, imgsImputed.max(), 5)])


photon_log_probs = likelihood.getLogProbSumOverTargetSamples(p_PM, gray_levels.view(-1))

# Fix the 0 gray_level issue (by making it certainly 0 photon)
# photon_log_probs[:2, 0] = 0.
# photon_log_probs[:2, 1:] = -float('inf')

#log_photon_prob_marginals = photon_log_probs.exp().div(photon_log_probs.exp().sum(1).view(-1,1)).log()
log_photon_prob_marginals = photon_log_probs - photon_log_probs.logsumexp(1).unsqueeze(1)

plot(log_photon_prob_marginals[:,:10].exp(), gray_levels)

#print(gray_levels[:50])
#plot(photon_prob_marginals[:50,:].t())



tmp = log_photon_prob_marginals.detach().exp().data
n_max = 7
fig = plotStacked(torch.cat([tmp[:,:n_max], (1.-tmp[:,:n_max].sum(1)).view(-1,1)],dim=1), gray_levels, now=False)
plt(fig)

# Plot gamma posterior mean and variance (single sample, weigthed sum of p(num_photon)*num_photon)


#inverse_poiss_mean = (tmp_marginals*torch.arange(max_photon).view(1,-1)).sum(1)

# inverse_poiss_alpha = (tmp_marginals*torch.arange(max_photon).view(1,-1)).sum(1)
# inverse_poiss_beta = torch.arange(max_photon).view(1,-1).sum(1)/max_photon

#inverse_poiss_MAP = tmp_marginals.max(1)[1]
inverse_poiss_MAP = torch.cat([getOptLambda(log_photon_prob_marginals[i,:], 
                                            lambda_guess=log_photon_prob_marginals[i,:].max(0)[1]) 
                               for i in range(log_photon_prob_marginals.size(0))])

inverse_poiss_MAP2 = torch.cat([getOptLambda(photon_log_probs[i,:], 
                                            lambda_guess=photon_log_probs[i,:].max(0)[1]) 
                               for i in range(photon_log_probs.size(0))])

#inverse_poiss_MAP = torch.cat([getOptLambdaFromLik(gray_levels[i].view(-1), likelihood) for i in range(log_photon_prob_marginals.size(0))])

#plot(inverse_poiss_MAP.view(-1,1), gray_levels)

plt(plot(inverse_poiss_MAP.view(-1,1), gray_levels, now=False)
    +plot(inverse_poiss_MAP2.view(-1,1), gray_levels, now=False)
   )


# plot(torch.cat([inverse_poiss_MAP.view(-1,1), 
#                 im2photon(torch.arange(0,gray_levels.max()*1.1,1e-1), inverse_poiss_MAP, gray_levels).view(-1,1)], dim=1), 
#                gray_levels)

#imgsImputedCorr = copy.deepcopy(imgsImputedCorr)



# exportFigure(fig, image='svg', filename='fig1-photon_cum_prob')

# plt(plt_type.Figure(data=data, 
#                 layout=plt_type.Layout(
#                     xaxis=dict(
#                         title='Grey level in data'
#                     ),
#                     yaxis=dict(
#                         title='Cumulative probability of photon count'
#                     )
#                 ))
#    )

# likelihood.__class__

# tmp = likelihood.single_log_prob(torch.arange(0.,20.).view(1,-1)*torch.ones_like(gray_levels).view(-1,1), gray_levels.view(-1))


In [None]:
imagesc(imgsImputed.permute(2,0,1)[0])

In [None]:
imagesc(im2photon(imgsImputed.permute(2,0,1)[0], inverse_poiss_MAP, gray_levels, keep_zeros=True))

In [None]:
imgsImputedPhoton = torch.stack([im2photon(image, inverse_poiss_MAP, gray_levels, keep_zeros=True) for image in imgsImputed.permute(2,0,1)], dim=0)

In [None]:
imagesc(imgsImputedPhoton.mean(0))

In [None]:
gainRange = [1e-20, float('inf')]
imgsImputedCorr = ((imgsImputedPhoton)
                   .div(torch.clamp(pred_gain_func, min=1./gainRange[1], max=1./gainRange[0]).unsqueeze(0))
                  )

#Installing packages with pip or conda within running jupyter kernel
#import sys
#!conda install --yes --prefix {sys.prefix} tifffile
#!{sys.executable} -m pip install tifffile

In [None]:
imagesc(imgsImputedCorr.mean(0))

In [None]:
from tifffile import imsave

for index, image in enumerate(imgsImputedCorr):
    imsave(
        data_dir+dataset_name+'/preproc2P/images/image' + str(index).zfill(5) + '.tif',
        image.detach().numpy().astype('uint16')
    )

In [None]:
# WHAT does this do???
#imagesc(im2logPhotonProb(imgsImputed.permute(2,0,1)[0], photon_log_probs, gray_levels, interpolate=True))

In [None]:
# imgsImputedPhoton = torch.cat([im2photon(im_batch[0].permute(1,2,0), inverse_poiss_MAP, gray_levels) for im_batch in data_loader], dim=-1)

In [None]:
from torch.utils.data import TensorDataset, DataLoader

#orig_dataset = TensorDataset(imgsImputed.cuda())
orig_dataset = TensorDataset(imgsImputed)
data_loader = DataLoader(orig_dataset, batch_size=50, shuffle=False, drop_last=False)

out = []
for im_batch in data_loader:
    #log_w = im2logPhotonProb(im_batch[0], photon_log_probs.cuda(), gray_levels.cuda())
    log_w = im2logPhotonProb(im_batch[0], photon_log_probs, gray_levels)
    out.append(getOptLambda(
        log_w,
        lambda_guess = log_w.max(-1)[1].float().mean(-1)
    ))
    
imgsImputedLambda = torch.cat(out, dim=0).squeeze()


# # Permute things so it works for DataLoader (that splits along first dimension)
# orig_dataset = TensorDataset(imgsImputed.permute(2,0,1))
# data_loader = DataLoader(orig_dataset, batch_size=50, shuffle=False, drop_last=False)

# imgsImputedPhoton = torch.cat([im2photon(im_batch[0].permute(1,2,0), inverse_poiss_MAP, gray_levels) for im_batch in data_loader], dim=-1)

In [None]:
# save the imputed images for CHOMP to use!

In [None]:
del data_loader, orig_dataset, log_w, out

torch.cuda.empty_cache()

imgsImputedLambda = imgsImputedLambda.squeeze()

In [None]:
torch.cuda.empty_cache()

In [None]:
imagesc(imgsImputedLambda)

In [None]:
imagesc(imgsImputedLambda.cpu()/pred_gain_func.cpu())

In [None]:
imagesc(pred_gain_func, heatmap={'colorscale':'div'})

In [None]:
tmp = copy.deepcopy(mean_im)
tmp[train_x.long().unbind(1)] = float('nan')
imagesc(tmp)

# Get goodness of fit for the model

Assuming a pixel has a non-variable lambda over time, and so the noise is only due to Poissonity and photomultiplier properties

In [None]:
def fast_digitise(A, r0, r1, nbins=None, binsize=None  ):
    """Inspired by https://stackoverflow.com/questions/26783719/efficiently-get-indices-of-histogram-bins-in-python
    
    Treat <r0+jitter and r1<= as two seperate bins (so add 1., then clamp at 0)
    
    # This approach works because pytorch seem to remember sign very well even with little jitter 
    (so if A==r0, A-(r0+jitter) is negative), but (r1-(r0+jitter) / (r1-r0)) = 1, which is weird but useful
    """
    
    if binsize is None:
        binsize=1.
    
    if nbins is None:
        nbins = int(torch.tensor((r1-r0)/binsize).floor().add(2.))
    
    jitter = 1e-12
    bin_center_correction = (r1-r0)/(2.*float(nbins-2.)) # So that r0 and r1 are bin centers rather than edges
#     r0 -= bin_edge_correction
#     r1 += bin_edge_correction
    
    
    out_hists = ((A-(r0+jitter)) * (float(nbins-2)/(r1-r0))).floor().long().add(1.).clamp(0, nbins-1)
    bin_centers = torch.cat([torch.tensor(r0).view(-1), 
                                torch.linspace(r0+bin_center_correction, r1-bin_center_correction, nbins-2) , 
                                torch.tensor(r1).view(-1)]).to(A.device)

    return out_hists, bin_centers
    
def fast_histograms(A, r0, r1, nbins, dim = -1, output_device = None):
    dim = dim if dim>0 else (A.ndimension()+dim)
    output_device = output_device if output_device is not None else A.device

    Adigitised, bin_centers = fast_digitise(A, r0, r1, nbins)
    Adims = list(A.size())
    Adims[dim]=int(bin_centers.numel())
    Ahists = torch.zeros(*Adims, device=A.device)
    
    
    for i in range(Ahists.size(dim)):
        ind_range = [slice(None)]*max(dim,0)+[i]+[slice(None)]*max(A.ndimension()-dim-1,0)
        Ahists[ind_range].add_((Adigitised==i).sum(dim).float())
        
    return Ahists.to(output_device), bin_centers.to(output_device)

def fast_predictive_probs(Alambda, bin_centers, photon_log_probs, gray_levels, dim = -1, batch_size=5,
                          interpolate=False,
                          output_device = None):
    """ Computes predictive probabilities of the likelihood with rates Alambda for histogram bins given"""
    dim = dim if dim>0 else (Alambda.ndimension()+1+dim)
    #bin_centers = torch.linspace(r0, r1, nbins, device=Alambda.device)
    
    # The two edges of bin_centers are assumed to be cutoff/saturation points and CDF should be computed instead of pdf
    # At 0 this is done correctly in log_photon_prob_marginals, but not at saturation    
    
    log_w = im2logPhotonProb(bin_centers, photon_log_probs, gray_levels, interpolate=interpolate)
    
    cur_dataset = TensorDataset(Alambda.contiguous().view(-1))
    data_loader = DataLoader(cur_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    
    output_device = output_device if output_device is not None else Alambda.device
    
    A_logpred_probs = []
    for Alambda_batch in data_loader:
        tmp = getLambdaLogProb(
                log_w.unsqueeze(1),
                Alambda_batch[0].log().view(1,-1,1), 
                requires_grad = False
                ).to(output_device)
        A_logpred_probs.append(tmp)
        del tmp
        torch.cuda.empty_cache()
        
    A_logpred_probs = torch.cat(A_logpred_probs, dim=1)
    
    # Get normalised probabilities per bin 
    #(effectively turning log probs into a discrete distribution, 
     # as a stepwise constant approximation to the true pdf)
    A_logpred_probs -= A_logpred_probs.logsumexp(0).unsqueeze(0)
    
    return A_logpred_probs.reshape(-1, *Alambda.size()).permute(1,2,0).exp()


def getWeightedSquaredHistogramError(hist_counts, hist_pred_probs, dim=-1):
    dim = dim if dim>0 else (hist_counts.ndimension()+dim)
    N = hist_counts.sum(dim)
    res_counts = hist_counts - hist_pred_probs*N.unsqueeze(dim)
    expected_variance = N.unsqueeze(dim)*hist_pred_probs*(1.-hist_pred_probs)
    
    return (res_counts.pow(2)/expected_variance).sum(dim)


import scipy.stats
from preprocUtils import nansum
def getChiSquaredHistogramError(hist_counts, hist_pred_probs, dim=-1, n_params = 1):
    """http://maxwell.ucsc.edu/~drip/133/ch4.pdf , eq.(3)
    
    Degrees of freedom (num_hist_bins - num_constraints of pred_probs)
    
    """
    dim = dim if dim>0 else (hist_counts.ndimension()+dim)
    N = nansum(hist_counts, dim)
    Nbins = (torch.isnan(hist_counts)==False).sum(dim)
    res_counts = hist_counts - hist_pred_probs*N.unsqueeze(dim)
    expected_value = N.unsqueeze(dim)*hist_pred_probs
    
    chi_squared_test_statistic = res_counts.pow(2)/expected_value
    
    chi_square_at_point05 = torch.tensor(scipy.stats.chi2.isf(0.05, Nbins-1- n_params))
    
    return nansum(chi_squared_test_statistic, dim), chi_square_at_point05


def getLikelihoodRatioHistogramError(hist_counts, hist_pred_probs, dim=-1, n_params = 1):
    """So-called multinomial test
    
    Degrees of freedom (num_hist_bins - num_constraints of pred_probs)
    
    """
    dim = dim if dim>0 else (hist_counts.ndimension()+dim)
    N = nansum(hist_counts, dim)
    Nbins = (torch.isnan(hist_counts)==False).sum(dim)
    hist_obs_probs = hist_counts/N.unsqueeze(dim)

    lr_test_statistic = 2*hist_counts*(hist_obs_probs/hist_pred_probs).clamp(1e-10, 1e20).log()
    
    
    df = int(hist_counts.size(dim))-1- n_params
    
    #correction_factor = 1. + hist_pred_probs.log().mul(-1.).logsumexp(dim) #Williams (1976)
    
    chi_square_at_point05 = torch.tensor(scipy.stats.chi2.isf(0.05, Nbins-1- n_params))
    
    return nansum(lr_test_statistic, dim), chi_square_at_point05
    
    
def getKolmogorovSmirnovHistogramError(hist_counts, hist_pred_probs, dim=-1):
    """Kolmogorov-Smirnov test for goodness of fit
    
    two-sided test, implemented for histogram data, but test statistics and critical value computed as in scipy
    https://github.com/scipy/scipy/blob/master/scipy/stats/stats.py#L4416
    """
    dim = dim if dim>0 else (hist_counts.ndimension()+dim)
    N = hist_counts.sum(dim)
    hist_obs_cdf = hist_counts.cumsum(dim)/N.unsqueeze(dim)
    hist_pred_cdf = hist_pred_probs.cumsum(dim)

    ks_test_statistic = (hist_obs_cdf-hist_pred_cdf).abs().max(dim)[0]
    
    # Use approx mode for small sample
  
    return ks_test_statistic, (scipy.stats.distributions.kstwobign.sf(ks_test_statistic * N.sqrt())) #(2 * scipy.stats.distributions.ksone.sf(ks_test_statistic, N))

In [None]:
def merge_bins(hist_counts, hist_pred_probs, 
               target_expected_count = 5, 
               remove_first_bin = True,
               adaptive_cats_based_on_lambda = 1,
               hist_lambdas = torch.empty(0)
              ):
    """
    Ensure that on average the histogram bins have enough observations in them in expectation
    Various rules of thumb:
     - https://www.statsdirect.com/help/nonparametric_methods/chisq_goodness_fit.htm
     
     - D. S. Moore, G. P. McCabe, Introduction to the Practice of Statistics, W. H. Freeman Publishing Company, New York, 2007.
    
    (Expected frequency > 1 everywhere and expected frequency > 5 in 80% of the bins)
    
    Compute the corresponding probabilities as well in the merged bins
    """
    
    if adaptive_cats_based_on_lambda > 1:
        lambda_cat_inds = torch.tensor(np.digitize(
            hist_lambdas, 
            np.quantile(hist_lambdas.view(-1), np.linspace(0.0, 0.95, adaptive_cats_based_on_lambda)))-1, 
            #np.linspace(hist_lambdas.min(), hist_lambdas.max()+1, adaptive_cats_based_on_lambda+1))-1,
                                      dtype=torch.uint8)
    else:
        lambda_cat_inds=torch.zeros_like(hist_counts[:,:,0])
        
    if remove_first_bin:
        hist_counts = hist_counts[:,:,1:]
        hist_pred_probs = hist_pred_probs[:,:,1:]
        # Renormalise predictive probabilities
        hist_pred_probs.div_(preprocUtils.nansum(hist_pred_probs, dim=2).unsqueeze(-1)) 
    
    n = hist_counts.sum(2).view(-1)
    hist_expected_freqs = n.unsqueeze(-1) * hist_pred_probs.view(-1, hist_pred_probs.size(-1))
    
    out_counts = torch.zeros_like(hist_counts)
    out_pred_probs = torch.zeros_like(hist_pred_probs)

    for cur_cat in range(0, adaptive_cats_based_on_lambda):
        #print("Current category: {} \n --------------".format(cur_cat))
        cur_pixels_mask = (lambda_cat_inds == cur_cat)

        low_freqs = torch.tensor(
            np.quantile(
                hist_expected_freqs[cur_pixels_mask.view(-1)],
                    q = 0.2, axis = 0)
        )

        # Can't use A[mask1][:,mask2] = ... as assignment (but also doesnt give error.)
        tmp_counts = torch.zeros_like(out_counts[cur_pixels_mask])
        tmp_pred_probs = torch.zeros_like(out_pred_probs[cur_pixels_mask])

        cur_bin = 0
        cur_merged_bin = 0
        cur_expected_count = 0
        while cur_bin < low_freqs.size(0):
            #print(cur_bin)
            cur_bin_low = cur_bin

            cur_expected_count = low_freqs[cur_bin]
            while (cur_expected_count < target_expected_count) and (cur_bin < (low_freqs.size(0)-1)):
                cur_bin += 1
                cur_expected_count += low_freqs[cur_bin]

            tmp_counts[:,cur_merged_bin] += hist_counts[cur_pixels_mask][:, cur_bin_low:(cur_bin+1)].sum(1)
            tmp_pred_probs[:,cur_merged_bin] += hist_pred_probs[cur_pixels_mask][:, cur_bin_low:(cur_bin+1)].sum(1)

            cur_bin += 1
            cur_merged_bin += 1


        # Merge the last bin with the previous one if needed
        if cur_expected_count < target_expected_count:
            cur_merged_bin -= 1
            tmp_counts[:,cur_merged_bin-1] += tmp_counts[:,cur_merged_bin]
            tmp_pred_probs[:,cur_merged_bin-1] += tmp_pred_probs[:,cur_merged_bin]


        tmp_counts[:,cur_merged_bin:] = float('nan')
        tmp_pred_probs[:,cur_merged_bin:] = float('nan')

        out_counts[cur_pixels_mask] = tmp_counts
        out_pred_probs[cur_pixels_mask] = tmp_pred_probs
        
#     out_counts = out_counts[:,:,:cur_merged_bin]
#     out_pred_probs = out_pred_probs[:,:, :cur_merged_bin]
    
    # Renormalise predictive probabilities
    out_pred_probs.div_(preprocUtils.nansum(out_pred_probs, dim=2).unsqueeze(-1)) 
    
    return out_counts, out_pred_probs
    

In [None]:
# hist_counts = Ahists
# hist_pred_probs = Apreds
# hist_lambdas = imgsImputedLambda

# adaptive_cats_based_on_lambda = 5
# target_expected_count = 5.

# if adaptive_cats_based_on_lambda > 1:
#     lambda_cat_inds = torch.tensor(np.digitize(
#         hist_lambdas, np.linspace(hist_lambdas.min(), hist_lambdas.max()+1, adaptive_cats_based_on_lambda+1))-1,
#                                   dtype=torch.uint8)
# else:
#     lambda_cat_inds=torch.zeros_like(hist_counts[:,:,0])

# if remove_first_bin:
#     hist_counts = hist_counts[:,:,1:]
#     hist_pred_probs = hist_pred_probs[:,:,1:]
#     # Renormalise predictive probabilities
#     hist_pred_probs.div_(preprocUtils.nansum(hist_pred_probs, dim=2).unsqueeze(-1)) 

# n = hist_counts.sum(2).view(-1)
# hist_expected_freqs = n.unsqueeze(-1) * hist_pred_probs.view(-1, hist_pred_probs.size(-1))



In [None]:
# out_counts = torch.zeros_like(hist_counts)
# out_pred_probs = torch.zeros_like(hist_pred_probs)

# for cur_cat in range(1, adaptive_cats_based_on_lambda):
#     print("Current category: {} \n --------------".format(cur_cat))
#     cur_pixels_mask = (lambda_cat_inds == cur_cat)

#     low_freqs = torch.tensor(
#         np.quantile(
#             hist_expected_freqs[cur_pixels_mask.view(-1)],
#                 q = 0.2, axis = 0)
#     )

#     # Can't use A[mask1][:,mask2] = ... as assignment (but also doesnt give error.)
#     tmp_counts = torch.zeros_like(out_counts[cur_pixels_mask])
#     tmp_pred_probs = torch.zeros_like(out_pred_probs[cur_pixels_mask])
    
#     cur_bin = 0
#     cur_merged_bin = 0
#     cur_expected_count = 0
#     while cur_bin < low_freqs.size(0):
#         print(cur_bin)
#         cur_bin_low = cur_bin

#         cur_expected_count = low_freqs[cur_bin]
#         while (cur_expected_count < target_expected_count) and (cur_bin < (low_freqs.size(0)-1)):
#             cur_bin += 1
#             cur_expected_count += low_freqs[cur_bin]
            
#         tmp_counts[:,cur_merged_bin] += hist_counts[cur_pixels_mask][:, cur_bin_low:(cur_bin+1)].sum(1)
#         tmp_pred_probs[:,cur_merged_bin] += hist_pred_probs[cur_pixels_mask][:, cur_bin_low:(cur_bin+1)].sum(1)

#         cur_bin += 1
#         cur_merged_bin += 1


#     # Merge the last bin with the previous one if needed
#     if cur_expected_count < target_expected_count:
#         cur_merged_bin -= 1
#         tmp_counts[:,cur_merged_bin-1] += tmp_counts[:,cur_merged_bin]
#         tmp_pred_probs[:,cur_merged_bin-1] += tmp_pred_probs[:,cur_merged_bin]
        

#     tmp_counts[:,cur_merged_bin:] = float('nan')
#     tmp_pred_probs[:,cur_merged_bin:] = float('nan')
    
#     out_counts[cur_pixels_mask] = tmp_counts
#     out_pred_probs[cur_pixels_mask] = tmp_pred_probs
    
# # out_counts = out_counts[:,:,:cur_merged_bin]
# # out_pred_probs = out_pred_probs[:,:, :cur_merged_bin]

# # # Renormalise predictive probabilities
# # out_pred_probs.div_(preprocUtils.nansum(out_pred_probs, dim=2).unsqueeze(-1)) 

In [None]:
train_indeces = torch.zeros_like(imgsImputedLambda.cpu(), dtype=torch.uint8)
train_indeces[train_x.long().unbind(1)] = 1
test_indeces = torch.ones_like(imgsImputedLambda.cpu(), dtype=torch.uint8)
test_indeces[train_x.long().unbind(1)] = 0

In [None]:
r0 = 0.
#r1 = np.floor(float(imgsImputed.max()/2))
r1 = float(imgsImputed.max())
nbins = 1000

Ahists, bin_centers = fast_histograms(imgsImputed.cuda(), r0, r1, nbins, output_device='cpu')
#Ahists_orig, bin_centers = fast_histograms(imgs.cuda(), r0, r1, nbins, output_device='cpu')
#Ahists_dense, bin_centers_dense = fast_histograms(imgsImputed.cuda(), r0, r1, nbins*5, output_device='cpu')


In [None]:
torch.cuda.empty_cache()

In [None]:
Apreds = fast_predictive_probs(imgsImputedLambda.cuda(), bin_centers.cuda(),
                               photon_log_probs.cuda(), gray_levels.cuda(),
                               batch_size = 50,
                              output_device = 'cpu')

# Apreds_dense = fast_predictive_probs(imgsImputedLambda.cuda(), bin_centers_dense.cuda(),
#                                photon_log_probs.cuda(), gray_levels.cuda(),
#                                batch_size = 50,
#                               output_device = 'cpu')

In [None]:
torch.cuda.empty_cache()

In [None]:
A_chiErrors, chi_thres = getChiSquaredHistogramError(Ahists, Apreds, n_params = n_params)

In [None]:
A_ksErrors, ks_thres = getKolmogorovSmirnovHistogramError(Ahists, Apreds, dim=-1)

In [None]:
imagesc(A_ksErrors)

In [None]:
imagesc(torch.tensor(ks_thres).float())

In [None]:
n_params = 0 if use_validation_data_only else sum([p.numel() for p in likelihood.parameters()])

A_chiErrors, chi_thres = getChiSquaredHistogramError(Ahists, Apreds, n_params = n_params)
                                                     #n_params=sum([p.numel() for p in likelihood.parameters()]))

# A_chiErrors_dense, chi_thres_dense = getChiSquaredHistogramError(Ahists_dense, Apreds_dense, 
#                                                      n_params=sum([p.numel() for p in likelihood.parameters()]))

A_lrErrors, chi_thres = getLikelihoodRatioHistogramError(Ahists, Apreds, n_params = n_params)
                                                     #n_params=sum([p.numel() for p in likelihood.parameters()]))

# A_lrErrors_dense, chi_thres_dense = getLikelihoodRatioHistogramError(Ahists_dense, Apreds_dense, 
#                                                      n_params=sum([p.numel() for p in likelihood.parameters()]))

In [None]:
imagesc(A_chiErrors.clamp(min=chi_thres, max=chi_thres*2))

In [None]:
imagesc(A_lrErrors.clamp(min=chi_thres, max=chi_thres*4))

In [None]:
print("""{:.2f}% of pixels are well explained by a 
single underlying poisson process with a fixed rate over time, 
meaning there is no discernible time variation""".format(float((A_chiErrors<chi_thres).sum())/A_chiErrors.numel()*100)
     )

In [None]:
#ind = np.ravel_multi_index((382, 103), imgsImputedLambda.size()) # order is y-x (as displayed in image)

#ind = Apreds[:,:,9].argmax()
#ind = Ahists[:,:,0].argmax()

#ind = A_lrErrors.argmin()
ind = A_chiErrors.argmin()
#ind = A_chiErrors.argmax()

#ind = imgsImputedLambda.view(-1).argmax()

# tmp = fast_predictive_probs(imgsImputedLambda.view(-1)[ind].view(-1,1).cpu(), 
#                             bin_centers, photon_log_probs, gray_levels, dim = -1, batch_size=5, 
#                           output_device = None)

# tmp1 = fast_predictive_probs(imgsImputedLambda.view(-1)[ind].view(-1,1).cpu(), 
#                             bin_centers, log_photon_prob_marginals, gray_levels, dim = -1, batch_size=5, 
#                           output_device = None)

# tmp2 = likelihood.single_log_prob(torch.tensor([-18.]).expand(bin_centers.size()).view(1,-1), bin_centers,batchsize=200)
# tmp2 -= tmp2.logsumexp(1)

# tmp3 = likelihood.single_log_prob(imgsImputedLambda.view(-1)[ind].log().view(-1).cpu().expand(bin_centers.size()).view(1,-1), bin_centers,batchsize=200)
# tmp3 -= tmp3.logsumexp(1)

print(np.unravel_index(ind, imgsImputedLambda.size()), 
      imgsImputedLambda.view(-1)[ind],
      "\n Chi test metric: {} < {} Threshold ".format(A_chiErrors.view(-1)[ind], chi_thres))
plt(plot(Ahists.view(-1,Ahists.size(2))[ind,:].view(-1,1)/imgsImputed.size(2), bin_centers, now=False)
    +plot(Apreds.view(-1,Apreds.size(2))[ind,:].view(-1,1), bin_centers, now=False)
#     +plot(tmp.view(-1,1), bin_centers, now=False)
#     +plot(tmp1.view(-1,1), bin_centers, now=False)
#     +plot(tmp2.view(-1,1).exp(), bin_centers, now=False)
#     +plot(tmp3.view(-1,1).exp(), bin_centers, now=False)
)

In [None]:
out_counts, out_pred_probs = merge_bins(Ahists, Apreds, 
                                        target_expected_count=5., 
                                        remove_first_bin = True,
                                       adaptive_cats_based_on_lambda = 200,
                                       hist_lambdas = imgsImputedLambda)

In [None]:
n_params = 0 if use_validation_data_only else sum([p.numel() for p in likelihood.parameters()])

A_chiErrors_new, chi_thres_new = getChiSquaredHistogramError(out_counts, out_pred_probs, n_params = n_params)

In [None]:
chi_thres_new.min()

In [None]:
imagesc(A_chiErrors_new.clamp(max=400))

In [None]:
imagesc(A_chiErrors_new>(chi_thres_new.float()*1.4))

In [None]:
n_params = 0 if use_validation_data_only else sum([p.numel() for p in likelihood.parameters()])

A_chiErrors_new, chi_thres_new = getChiSquaredHistogramError(out_counts, out_pred_probs, n_params = n_params)
                                                     #n_params=sum([p.numel() for p in likelihood.parameters()]))

# A_chiErrors_dense, chi_thres_dense = getChiSquaredHistogramError(Ahists_dense, Apreds_dense, 
#                                                      n_params=sum([p.numel() for p in likelihood.parameters()]))

A_lrErrors_new, chi_thres_new = getLikelihoodRatioHistogramError(out_counts, out_pred_probs, n_params = n_params)
                                                     #n_params=sum([p.numel() for p in likelihood.parameters()]))

# A_lrErrors_dense, chi_thres_dense = getLikelihoodRatioHistogramError(Ahists_dense, Apreds_dense, 
#                                                      n_params=sum([p.numel() for p in likelihood.parameters()]))

In [None]:
chi_thres_new

In [None]:
A_chiErrors_new

In [None]:
imagesc(A_chiErrors_new)

In [None]:
imagesc(A_chiErrors_new.clamp(min=30, max=200))

In [None]:
imagesc(A_chiErrors_new.clamp(min=chi_thres_new, max=chi_thres_new*3))
imagesc(A_lrErrors_new.clamp(min=chi_thres_new, max=chi_thres_new*3))

In [None]:
print("""{:.2f}% of pixels are well explained by a 
single underlying poisson process with a fixed rate over time, 
meaning there is no discernible time variation""".format(float((A_chiErrors_new<chi_thres_new).sum())/A_chiErrors_new.numel()*100)
     )

In [None]:
#ind = np.ravel_multi_index((382, 103), imgsImputedLambda.size()) # order is y-x (as displayed in image)

#ind = Apreds[:,:,9].argmax()
ind = Ahists[:,:,9].argmax()

#ind = A_lrErrors.argmin()
#ind = A_chiErrors.argmin()
#ind = A_lrErrors_new.argmin()
ind = A_chiErrors_new.argmax()

#ind = imgsImputedLambda.view(-1).argmax()

# tmp = fast_predictive_probs(imgsImputedLambda.view(-1)[ind].view(-1,1).cpu(), 
#                             bin_centers, photon_log_probs, gray_levels, dim = -1, batch_size=5, 
#                           output_device = None)

# tmp1 = fast_predictive_probs(imgsImputedLambda.view(-1)[ind].view(-1,1).cpu(), 
#                             bin_centers, log_photon_prob_marginals, gray_levels, dim = -1, batch_size=5, 
#                           output_device = None)

# tmp2 = likelihood.single_log_prob(torch.tensor([-18.]).expand(bin_centers.size()).view(1,-1), bin_centers,batchsize=200)
# tmp2 -= tmp2.logsumexp(1)

# tmp3 = likelihood.single_log_prob(imgsImputedLambda.view(-1)[ind].log().view(-1).cpu().expand(bin_centers.size()).view(1,-1), bin_centers,batchsize=200)
# tmp3 -= tmp3.logsumexp(1)

# print(np.unravel_index(ind, imgsImputedLambda.size()), 
#       imgsImputedLambda.view(-1)[ind],
#       "\n Chi test metric: {} \n  LR test metric: {} \n       Threshold: {}".format(
#           A_chiErrors_new.view(-1)[ind], A_lrErrors_new.view(-1)[ind], chi_thres_new))
plt(plot(out_counts.view(-1,out_counts.size(2))[ind,:].view(-1,1)/imgsImputed.size(2)*500, now=False)
    +plot(out_pred_probs.view(-1,out_pred_probs.size(2))[ind,:].view(-1,1)*500, now=False)
#     +plot(tmp.view(-1,1), bin_centers, now=False)
#     +plot(tmp1.view(-1,1), bin_centers, now=False)
#     +plot(tmp2.view(-1,1).exp(), bin_centers, now=False)
#     +plot(tmp3.view(-1,1).exp(), bin_centers, now=False)
)

print(np.unravel_index(ind, imgsImputedLambda.size()), 
      imgsImputedLambda.view(-1)[ind],
      "\n Chi test metric: {} \n  LR test metric: {} \n    Threshold: {}".format(
          A_chiErrors.view(-1)[ind], A_lrErrors.view(-1)[ind], chi_thres))
plt(plot(Ahists.view(-1,Ahists.size(2))[ind,:].view(-1,1)/imgsImputed.size(2), bin_centers, now=False)
    +plot(Apreds.view(-1,Apreds.size(2))[ind,:].view(-1,1), bin_centers, now=False)
   )

In [None]:
# #ind = Apreds[:,:,9].argmax()
# ind = Ahists[:,:,1].argmax()

# #ind = A_lrErrors.argmax()
# #ind = A_chiErrors.argmin()

# #ind = imgsImputedLambda.view(-1).argmin()

# tmp2 = likelihood.single_log_prob(torch.tensor([-18.]).expand(bin_centers.size()).view(1,-1), bin_centers,batchsize=200)
# tmp2 -= tmp2.logsumexp(1)

# tmp3 = likelihood.single_log_prob(imgsImputedLambda.view(-1)[ind].log().view(-1).cpu().expand(bin_centers.size()).view(1,-1), bin_centers,batchsize=200)
# tmp3 -= tmp3.logsumexp(1)

# print(imgsImputedLambda.view(-1)[ind])
# plt(plot(Ahists.view(-1,Ahists.size(2))[ind,:].view(-1,1)/imgsImputed.size(2), bin_centers, now=False)
#     +plot(Apreds.view(-1,Apreds.size(2))[ind,:].view(-1,1), bin_centers, now=False)
#     +plot(tmp2.view(-1,1).exp(), bin_centers, now=False)
#     +plot(tmp3.view(-1,1).exp(), bin_centers, now=False)
# )

In [None]:
# ind = Apreds_dense[:,:,1].argmax()
# #ind = A_lrErrors.argmax()
# #ind = A_chiErrors.argmin()

# #ind = imgsImputedLambda.view(-1).argmin()

# tmp = fast_predictive_probs(imgsImputedLambda.view(-1)[ind].view(-1,1).cpu(), bin_centers_dense, log_photon_prob_marginals, gray_levels, dim = -1, batch_size=5, 
#                           output_device = None)
# #tmp -= tmp.logsumexp(1)

# tmp2 = likelihood.single_log_prob(torch.tensor([-18.]).expand(bin_centers_dense.size()).view(1,-1), bin_centers_dense,batchsize=200)
# tmp2 -= tmp2.logsumexp(1)

# tmp3 = likelihood.single_log_prob(imgsImputedLambda.view(-1)[ind].log().view(-1).cpu().expand(bin_centers_dense.size()).view(1,-1), bin_centers_dense,batchsize=200)
# tmp3 -= tmp3.logsumexp(1)

# print(imgsImputedLambda.view(-1)[ind])
# plt(plot(Ahists_dense.view(-1,Ahists_dense.size(2))[ind,:].view(-1,1)/imgsImputed.size(2), bin_centers_dense, now=False)
#     +plot(tmp.view(-1,1), bin_centers_dense, now=False)
#     +plot(Apreds_dense.view(-1,Apreds_dense.size(2))[ind,:].view(-1,1), bin_centers_dense, now=False)
#     +plot(tmp2.view(-1,1).exp(), bin_centers_dense, now=False)
#     +plot(tmp3.view(-1,1).exp(), bin_centers_dense, now=False)
# )

In [None]:
# bin_edges_dense = torch.cat([bin_centers_dense[0].view(-1)-1e10,
#                        bin_centers_dense[1:-1] - (bin_centers_dense[2]-bin_centers_dense[1]).div(2)+1e-7,
#                        bin_centers_dense[-1].view(-1),
#                        bin_centers_dense[-1].view(-1)+1e10
#                       ])

In [None]:
# r0=0.
# r1=400.
# bin_centers_dense = torch.arange(r0,r1, 1.)
# bin_edges_dense = torch.cat([torch.tensor([r0])-1e10,
#                        bin_centers_dense[1:-1],
#                        torch.tensor([r1])+1e10
#                       ])

In [None]:
#float((imgs!=imgsImputed).sum())/imgs.numel()

In [None]:
# hist, bins = np.histogram(imgsImputed.detach().cpu(), bin_edges_dense)
# hist_orig, bins = np.histogram(imgs.detach().cpu(), bin_edges_dense)

# plt(([plt_type.Scatter(y=hist/sum(hist), x=bin_centers_dense)]
#      +[plt_type.Scatter(y=hist_orig/sum(hist_orig), x=bin_centers_dense)]
#      +plot(Ahists_dense.sum(1).sum(0).view(-1,1).div(Ahists_dense.sum()), bin_centers_dense, now=False))
#    )

In [None]:
# plt(plot(Ahists.sum(1).sum(0).view(-1,1).div(Ahists.sum()), bin_centers, now=False)
#          +plot(tmp3.view(-1,1).exp(), bin_centers, now=False))


In [None]:
# plt(plot(Ahists_dense.sum(1).sum(0).view(-1,1).div(Ahists_dense.sum()), bin_centers_dense, now=False)
#          +plot(tmp3.view(-1,1).exp(), bin_centers_dense, now=False))


In [None]:
# tmp.shape

In [None]:
Ahists.view(-1,Ahists.size(2)).shape

In [None]:
imagesc(A_chiErrors)

# Get histrograms of lambda

In [None]:
import itertools
def createLocalHists(imgs, n_hist_grid = 3, nbins = None):
    """
    Creates n_hist_grid**2 histograms tiling the spacial axis of the image
    
    ..note:
        Do the whole thing on cpu, as histc is not supported on GPU and numpy hist is slow
    """
    input_device = imgs.device
    ranges = list()
    for d in range(2):
        ranges.append(list())
        tmp = torch.linspace(0, imgs.shape[d],n_hist_grid+1).round().cpu()
        for n in range(n_hist_grid):
            ranges[d].append(slice(int(tmp[n]), int(tmp[n+1])))
        
    imgs_max = float(imgs.max())
    imgs_min = float(imgs.min())
    n_bins = int(nbins) if nbins is not None else int(imgs_max-imgs_min+1.)
    
    # Modify min and max such that the bin centers actually start on imgs_max and imgs_min 
    bin_edge_correction = (imgs_max-imgs_min)/(2.*float(n_bins-1.))
    
    out = torch.stack([
            torch.histc(imgs[ind0, ind1, :].cpu(), # histc does not support GPU
                bins = n_bins,
                min = imgs_min-bin_edge_correction,
                max = imgs_max+bin_edge_correction)
                     for ind0, ind1 in itertools.product(*ranges)
        ],
        dim = 1)
    
    hist_bin_centers = torch.linspace(imgs_min, imgs_max, n_bins)
    
    return out.to(input_device), hist_bin_centers.to(input_device)

In [None]:
from collections import OrderedDict
def summariseHists(histOut, numQuantXs = 101, retCdf = False, invNormalise=True):
    """
    Summarises the historgrams by their quantile function
    and (normalised) inverse quantile functions with 
        numQuantXs unique inverse quantile X points linspaced between 0 and 1 (inclusive)
        We later assume a piecewise-constant approximation of the true quantile function with these points being centers
        Set numQuantXs = None to avoid the approximation (the full cdf resolution is used then)
    """
    input_device = histOut[0].device
    
    # Get quantile function as OrderedDict ('y' - matrix, 'x' - vector)
    cdfs = OrderedDict( [ # init input is list of tuples
        ('y', torch.cumsum(histOut[0], dim=0)/histOut[0].sum(0)),
        ('x', histOut[1])
        ])
    
    if retCdf:
        return cdfs
    
    # Get quantile function (inverse cdf) as OrderedDict ('y' - matrix, 'x' - vector)
    
    numQuantXs = numQuantXs if numQuantXs is not None else len(histOut[1])
    
    inv_bins = torch.linspace(0., 1., numQuantXs).to(input_device)
    div_by = len(histOut[1]) if invNormalise else 1.
    quantFuncs = OrderedDict( [
            ('y', (cdfs['y'].unsqueeze(-1) - inv_bins.view(1,1,-1)).abs().argmin(0).float().permute(1,0).div(float(div_by))),
            ('x', inv_bins)
            ])
    
    return quantFuncs
    
    # Get non-normalised inverse quantile function

In [None]:
def distWasserstein(quantFuncs, p=2., retPwDistsWeighted = False):
    # Get pairwise distances at every point
    pwDists = (quantFuncs['y'].unsqueeze(2) - quantFuncs['y'].unsqueeze(1)).abs()
    
    xbinSize = quantFuncs['x'][1] - quantFuncs['x'][0]
    pwDistsWeighted = xbinSize * (pwDists.pow(p))
    
    if retPwDistsWeighted:
        return pwDistsWeighted
    
    return pwDistsWeighted.sum(0).pow(1/p)

In [None]:
# Histograms of chi_squared test statistics
histOut = createLocalHists(A_chiErrors.cpu()[train_indeces].view(-1,1,1).clamp(max=chi_thres*4),1, nbins=1000)
histOut2 = createLocalHists(A_chiErrors.cpu()[test_indeces].view(-1,1,1).clamp(max=chi_thres*4),1, nbins=1000)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Histograms of likelihood ratio test statistics
histOut = createLocalHists(A_lrErrors.cpu()[train_indeces].view(-1,1,1),1, nbins=1000)
histOut2 = createLocalHists(A_lrErrors.cpu()[test_indeces].view(-1,1,1),1, nbins=1000)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

### Histograms of corrected vs uncorrected lambda

In [None]:
imgsImputedLambdaCorrected = imgsImputedLambda/(pred_gain_func.cuda())

In [None]:
# Train vs test uncorrected lambda
histOut = createLocalHists(imgsImputedLambda[train_indeces].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambda[test_indeces].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Uncorrected vs corrected lambda
histOut = createLocalHists(imgsImputedLambda.view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambdaCorrected.view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Uncorrected vs corrected training lambda
histOut = createLocalHists(imgsImputedLambda[train_indeces].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambdaCorrected[train_indeces].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Corrected train vs test lambda
histOut = createLocalHists(imgsImputedLambdaCorrected[train_indeces].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambdaCorrected[test_indeces].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Uncorrected vs corrected test lambda
histOut = createLocalHists(imgsImputedLambda[test_indeces].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambdaCorrected[test_indeces].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Histogram of corrected lambdas with low and high LR test statistics
inds1 = A_lrErrors<chi_thres
inds2 = A_lrErrors>=chi_thres

# Uncorrected vs corrected test lambda
histOut = createLocalHists(imgsImputedLambdaCorrected[inds1].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambdaCorrected[inds2].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Histogram of corrected lambdas with low and high Chi test statistics
inds1 = A_chiErrors<chi_thres
inds2 = A_chiErrors>=chi_thres

# Uncorrected vs corrected test lambda
histOut = createLocalHists(imgsImputedLambdaCorrected[inds1].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambdaCorrected[inds2].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# Histogram of corrected lambdas with low and high Chi test statistics
inds1 = A_chiErrors<chi_thres
inds2 = A_chiErrors>=chi_thres

# Uncorrected vs corrected test lambda
histOut = createLocalHists(imgsImputedLambda[inds1].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(imgsImputedLambda[inds2].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
histOut = createLocalHists(imgsImputedLambda.view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(tmp.view(-1,1,1),1, nbins=100)
plt(plot(histOut[0], histOut[1], now=False)+plot(histOut2[0], histOut2[1], now=False))

In [None]:
plot(sumHistOut['y'], sumHistOut['x'])

In [None]:
histOut = createLocalHists(Aerrors[train_indeces].view(-1,1,1),1, nbins=100)
histOut2 = createLocalHists(Aerrors[test_indeces].view(-1,1,1),1, nbins=100)
plt((plot(histOut[0]/histOut[0].sum()/(histOut[1][1]-histOut[1][0]), histOut[1], now=False)
    +plot(histOut2[0]/histOut2[0].sum()/(histOut2[1][1]-histOut2[1][0]), histOut2[1], now=False))
                                              )

In [None]:
# def getLambdaPoly(w):
#     #x = torch.arange(w.size(dim)).float().view(*([1]*max(dim,0))+[-1]+[1]*max(w.ndimension()-dim-1,0))
#     x = torch.arange(w.numel()).float()
    
#     # Get coeffs for coeff[0]*lambda^0, coeff[1]*lambda, ... coeff[N]*lambda^N
#     coeffs = (w[:-1]-w[1:]).div((x[:-1]+1).lgamma().exp())
#     coeffs = torch.cat([coeffs.view(-1), -w[-1].div((x[-1]+1).lgamma().exp()).view(1)])
    
#     return coeffs
    
#     # Get rid of unnecessary too small coeffs
# #     max_coeff = coeffs.abs().max()
# #     coeffs[coeffs.abs()<(1e-3*max_coeff)] = 0

#     # Get rid of zero coeffs (numerical issues)
#     coeffs = coeffs[:int((coeffs>0).max(0)[1]+1)]
    
#     res = np.roots(np.flip(coeffs.detach()))
    
#     return res
    
#     res = res[np.isreal(res)]
#     res = np.real(res)
#     res = res[res>1e-9]
    
#     res = torch.tensor(res)
#     if res.numel()==0:
#         res = torch.tensor([0.])
    
#     return res

In [None]:
imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im, pixels_per_micron=1.15)

In [None]:
imagesc(corr_mean_im.clamp(0,5), pixels_per_micron=1.15)

In [None]:
imagesc(mll.pmGa)

In [None]:
likelihood.offset

In [None]:
# Correct the whole imgsImputes
# gainRange = [1e-20, float('inf')]
# imgsImputedCorr = ((imgsImputed-likelihood.offset)
#                    .div(torch.clamp(pred_gain_func, min=1./gainRange[1], max=1./gainRange[0]).unsqueeze(-1))
#                    +likelihood.offset
#                   ).clamp(min=0.)

gainRange = [1e-20, float('inf')]
imgsImputedCorr = ((imgsImputed)
                   .div(torch.clamp(pred_gain_func, min=1./gainRange[1], max=1./gainRange[0]).unsqueeze(-1))
                  )

In [None]:
imagesc(imgsImputedCorr.mean(2))

# Additive kernel decomposition
Vincent Adam's paper MLSP 2016

In [None]:
pred_gain_func_orig = copy.deepcopy(pred_gain_func.detach())

In [None]:
dict(model.covar_module.named_parameters())

In [None]:
def getMarginals(model):
    """
    Given a model with 0 mean a scaled additive kernel as its covar_module, 
    returns the marginal contributions of kernels.
    
    Also assumes kernels[0] is symmetrized and thus non-invertible. 
    """
    
    inp = model.inducing_points
    kernels = model.covar_module.base_kernel.kernels
    
    Ks = [k(inp).evaluate()*model.covar_module.log_outputscale.exp() 
         for k in kernels]
    Ksum = sum(Ks)
    
    
    muF = model.variational_output().mean()
    if isinstance(model.mean_module, gpytorch.means.ConstantMean):
        muF = muF-model.mean_module.constant.data.squeeze()
    
    
    sigF = model.variational_output().covar().evaluate()
    muF_Scaled = sigF.inverse().matmul(muF)
    
    Ktilde = (sigF.inverse() - Ksum.inverse()).inverse() + Ksum
    Ktilde_inv = Ktilde.inverse()
    
    Vs = [[ Kd1 - Kd1.matmul(Ktilde_inv).matmul(Kd2)
        
    for Kd1 in Ks]
    for Kd2 in Ks]
    
    Ksum_Scaled = Ktilde_inv.matmul(Ksum)
    
    nus = [(Kd - Kd.matmul(Ksum_Scaled)).matmul(muF_Scaled)
           for Kd in Ks]
    
    # Add back the prior mean to component 1
    nus[1] = nus[1]
           
    margRandVars = [
        gpytorch.random_variables.GaussianRandomVariable(
        nus[i].view(-1),
        Vs[i][i])
        #gpytorch.utils.pivoted_cholesky.pivoted_cholesky(Vs[i][i], 10))
        for i in range(len(nus))]
    
    return nus, Vs, margRandVars


In [None]:
def interpolateFromMarginal(model, randVar, inputs):
    interp_indices, interp_values = model._compute_grid(inputs)
    
    # Compute test mean
    # Left multiply samples by interpolation matrix
    test_mean = gpytorch.utils.left_interp(interp_indices, interp_values, randVar.mean().unsqueeze(-1))
    test_mean = test_mean.squeeze(-1)

    # Compute test covar
    test_covar = gpytorch.lazy.InterpolatedLazyVariable(
        randVar.covar(), interp_indices, interp_values, interp_indices, interp_values
    )
        
    return gpytorch.random_variables.GaussianRandomVariable(test_mean, test_covar)

In [None]:
nus, Vs, margRandVars = getMarginals(model)

In [None]:
device='cpu'
n_test_grid = torch.tensor(mean_im.shape)
test_x = preprocUtils.create_test_grid(n_test_grid, ndims=2, device=device, a=dataStats['x_minmax'][0,:], b=dataStats['x_minmax'][1,:])


In [None]:
logGainFunc = interpolateFromMarginal(model, margRandVars[0], test_x)

pred_gain_func = logGainFunc.mean().reshape(*n_test_grid).exp()

In [None]:
logZFunc = interpolateFromMarginal(model, margRandVars[1], test_x)
pred_Z_func = logZFunc.mean().reshape(*n_test_grid).exp()

In [None]:
imagesc(pred_gain_func_orig)
imagesc(pred_gain_func * pred_Z_func)

In [None]:
imagesc(pred_gain_func)

In [None]:
imagesc(pred_Z_func)


# Showing results

In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='noPrior', lik='linLik', stamp = '_00_firstRun_noPCremoved')

imagesc(pred_gain_func, pixels_per_micron=1.15, heatmap=dict(colorscale='div'), image='svg', filename='fig1-res1-gain_noprior_linlik')
imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_noprior_linlik')


mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='mexRadPrior', lik='linLik', stamp = '_00_firstRun_noPCremoved')

imagesc(pred_gain_func, pixels_per_micron=1.15, heatmap=dict(colorscale='div'), image='svg', filename='fig1-res1-gain_mexRadprior_lik')
imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_mexRadprior_lik')

In [None]:
imagesc(corr_mean_im, pixels_per_micron=1.15, image='svg', filename='fig1-res1-corr_mexRadprior_lik')

In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='noPrior', lik='unampLik', stamp = '_00_firstRun_noPCremoved')

In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='mexRadPrior', lik='unampLik', stamp = '_00_firstRun_noPCremoved')

In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='mexRadPrior', lik='linLik', stamp = '_00_firstRun_noPCremoved')

In [None]:
imagesc(pred_gain_func, heatmap=dict(colorscale='div'))


In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='noPrior', lik='unampLik', stamp = '_00_firstRun_noPCremoved')


In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='noPrior', lik='linLik', stamp = '_00_firstRun_noPCremoved')


In [None]:
mll, model, likelihood, train_x, train_y, \
dataStats, mean_im, pred_gain_func, corr_mean_im = \
display_results(retVars=True,data_id = '0', prior='noPrior', lik='linLik', stamp = '_00_firstRun_noPCremoved_smallLinNoiseInit')


In [None]:
imagesc(pred_gain_func, heatmap=dict(colorscale='div'))
imagesc(mean_im, pixels_per_micron=1.15)
imagesc(corr_mean_im, pixels_per_micron=1.15)

In [None]:
# imagesc(pred_gain_func, colorscale='felfire')
# imagesc(mean_im, image='svg', filename='fig1-mean_im', pixels_per_micron=1.15)
# imagesc(corr_mean_im, image='svg', filename='fig1-mean_im_corr', pixels_per_micron=1.15)

In [None]:
imagesc(corr_mean_im, image='svg', filename='fig1-mean_im_corr', pixels_per_micron=1.15)

In [None]:
plotly.__version__


# Estimating photon probability from grey level

In [None]:
light_levels = torch.arange(1e-4, 1.,1e-4) # In photon
model_out = gpytorch.random_variables.GaussianRandomVariable(light_levels.log(), gpytorch.lazy.DiagLazyVariable(1e-7*torch.ones_like(light_levels)))
lik_out = likelihood(model_out)


In [None]:
photon_counts, photon_log_probs = (
    likelihood.getPhotonLogProbs(model_out.mean().view(-1,1).exp(), max_photon=float(20), reNormalise = False))
p_PM = likelihood.createResponseDistributions(photon_counts)

In [None]:
gray_levels = torch.arange(-50., 500.,20.)
# tmp = likelihood.getLogProbSumOverTargetSamples(p_PM, gray_levels.view(-1))
# plot(tmp.exp().div(tmp.exp().sum(1).view(-1,1)), gray_levels)

In [None]:
likelihood.__class__

In [None]:
tmp = likelihood.single_log_prob(torch.arange(0.,20.).view(1,-1)*torch.ones_like(gray_levels).view(-1,1), gray_levels.view(-1,1))

In [None]:
a = tmp.exp().div(tmp.exp().sum(1).view(-1,1)).detach().data
n_max = 5
fig = plotStacked(torch.cat([a[:,:n_max], (1.-a[:,:n_max].sum(1)).view(-1,1)],dim=1), gray_levels, now=False)
plt(fig)



In [None]:
exportFigure(fig, image='svg', filename='fig1-photon_cum_prob')

In [None]:
plt(plt_type.Figure(data=data, 
                layout=plt_type.Layout(
                    xaxis=dict(
                        title='Grey level in data'
                    ),
                    yaxis=dict(
                        title='Cumulative probability of photon count'
                    )
                ))
   )

In [None]:
tmp.shape

In [None]:
tmp.exp().div(tmp.exp().sum(1).view(-1,1)).sum(1)

In [None]:
plot(lik_out.var().sqrt().view(-1,1), light_levels)

In [None]:
plot(lik_out.mean().view(-1,1), light_levels)

In [None]:
# Diagnose non-approximate likelihood and posterior peakiness
min_photon = 0
max_photon = None
import numpy as np
for n_train_ind in [44323, 13450, 6792]:# 38150, 38155, 38159]:#44323, 17500, 13450, 6792]: # 6792 is a good training index, looks bumpy

    hist, bins, bins_extended = croppedHist(train_y[n_train_ind,:], bins = 100)
    model_out = model(train_x[n_train_ind,:].unsqueeze(0))
    print(model_out.mean().exp())
    print(likelihood.log_probability(model_out, train_y[n_train_ind, :]))
    if likelihood.__class__==preprocLikelihoods.LinearGainLikelihood:
        lik_out = likelihood(model_out)
    else:
        lik_out = likelihood(model_out, approx=False)
    
    # Sample from the output (we don't have access to log_prob for Mixture)
    n_samples = 1e5*1.
    if max_photon is None:
        lik_sample = lik_out.sample(int(n_samples))
    else:
        curMixWeights = lik_out.rand_vars[0].weights[min_photon:max_photon+1] / lik_out.rand_vars[0].weights[min_photon:max_photon+1].sum()
        lik_sample = preprocRandomVariables.MixtureRandomVariableWithSampler(
            *lik_out.rand_vars[0].rand_vars[:max_photon+1], weights = curMixWeights.view(-1)).sample(int(n_samples))
        
    lik_hist = np.histogram(lik_sample.detach().cpu(), bins_extended)[0]/n_samples*train_y.size(1)
    
    #cur_out_moved = likelihood(model(train_x[n_train_ind,:].unsqueeze(0)+0.2)).cpu()
    plt([plt_type.Bar(x=bins[:-1], y=hist),
         plt_type.Scatter(x=bins[:-1], 
                          y=lik_hist, mode='lines+markers'),
         plt_type.Scatter(x=torch.cat([model_out.mean().exp()*likelihood.log_gain.exp()]*2), 
                          y=np.array([0,30.]), mode='lines')
        ])

In [None]:
# Good data - model pairs
"""
dataset_name = 'neurofinder.04.00'
tmp = scipy.io.loadmat(data_dir+dataset_name+'/imputed_first_700_frames.mat')
imgsImputed = torch.tensor(tmp['imgsImputed'])
mll = torch.load("savedModels/mll_20180814T170124").cpu()
"""

"""
"""


# Get numerical results

In [None]:
model_out = model(test_x)
bg_mean = likelihood.forward(model_out).mean().reshape(*n_test_grid)

imgsImputedMeanCorr = ((imgsImputed)
                   .sub(bg_mean.unsqueeze(-1))
                  )

gainRange = [1e-20, float('inf')]
imgsImputedCorr = ((imgsImputedMeanCorr)
                   .div(torch.clamp(pred_gain_func, min=1./gainRange[1], max=1./gainRange[0]).unsqueeze(-1))
                  )

imgsImputedOrigCorr = ((imgsImputed)
                   .div(torch.clamp(pred_gain_func_orig, min=1./gainRange[1], max=1./gainRange[0]).unsqueeze(-1))
                  )

In [None]:
def toMat(A):
    A = A.sub(A.min())
    A = A.div(A.max())
    return A

In [None]:
from nbimporter import NotebookLoader
scatterHex = NotebookLoader().load_module("preprocVisualisationTesting").scatterHex

In [None]:
imagesc((imgsImputed.std(2).pow(2)/imgsImputed.mean(2)))

In [None]:
imagesc((imgsImputed.std(2).pow(2)/imgsImputed.mean(2)).clamp(70,170))

In [None]:
imagesc(imgsImputedOrigCorr.std(2).pow(2)/imgsImputedOrigCorr.mean(2))

In [None]:
imagesc(imgsImputedCorr.std(2).pow(2)/imgsImputedCorr.mean(2))

In [None]:
imagesc(imgsImputedCorr.std(2).pow(2).log())

In [None]:
model_out = model(test_x)
bg1 = likelihood.forward(logZFunc)
bg2 = likelihood.forward(model_out)

In [None]:
imagesc(bg2.mean().reshape(*n_test_grid))

In [None]:
imagesc(bg1.mean().reshape(*n_test_grid))

In [None]:
imagesc((imgsImputed.mean(2)-bg2.mean().reshape(*n_test_grid))/pred_gain_func)

In [None]:
imagesc(imgsImputed.mean(2)-bg1.mean().reshape(*n_test_grid))

In [None]:
# We want var / mean to be same everywhere (that would mean normalised gain?)
# Compute Geometric mean of var/mean:
# exp(1/n sum(log(var)-log(mean))) locally, show that corrected image is more stereotypical

cur_filter = torch.ones(71,71)
cur_filter = cur_filter.div(cur_filter.sum()).view(1,1,*cur_filter.size())

xslice = slice(20,490)#slice(0,512)#slice(50,470)#
yslice = slice(0,512)#slice(20,490)#slice(20,490)#
for cur_imgs in [imgsImputed[yslice,xslice], imgsImputedOrigCorr[yslice,xslice]]:
    logVars = torch.nn.functional.conv2d(
        cur_imgs.std(2).pow(2).log().view(1,1,*cur_imgs.shape[:2]),
        cur_filter,
        padding=tuple((torch.tensor(cur_filter.shape)[2:]-1)/2)
    )

    logMeans = torch.nn.functional.conv2d(
        cur_imgs.mean(2).log().view(1,1,*cur_imgs.shape[:2]),
        cur_filter,
        padding=tuple((torch.tensor(cur_filter.shape)[2:]-1)/2)
    )
    
    numDivisor = torch.nn.functional.conv2d(
        torch.ones_like(cur_imgs.mean(2)).view(1,1,*cur_imgs.shape[:2]),
        cur_filter,
        padding=tuple((torch.tensor(cur_filter.shape)[2:]-1)/2)
    )
    
    logMeans.div_(numDivisor)
    logVars.div_(numDivisor)


    imagesc((logVars-logMeans).exp().squeeze())

In [None]:
# We want var / mean to be same everywhere (that would mean normalised gain?)
# Compute Geometric mean of var/mean:
# exp(1/n sum(log(var)-log(mean))) locally, show that corrected image is more stereotypical

cur_filter = torch.ones(71,71)
cur_filter = cur_filter.div(cur_filter.sum()).view(1,1,*cur_filter.size())

xslice = slice(20,490)#slice(0,512)#slice(50,470)#
yslice = slice(0,512)#slice(20,490)#slice(20,490)#
for cur_imgs in [imgsImputed[yslice,xslice], imgsImputedMeanCorr[yslice,xslice], imgsImputedCorr[yslice,xslice]]:
    logVars = torch.nn.functional.conv2d(
        cur_imgs.std(2).pow(2).log().view(1,1,*cur_imgs.shape[:2]),
        cur_filter,
        padding=tuple((torch.tensor(cur_filter.shape)[2:]-1)/2)
    )

    Means = torch.nn.functional.conv2d(
        cur_imgs.mean(2).view(1,1,*cur_imgs.shape[:2]),
        cur_filter,
        padding=tuple((torch.tensor(cur_filter.shape)[2:]-1)/2)
    )
    
    numDivisor = torch.nn.functional.conv2d(
        torch.ones_like(cur_imgs.mean(2)).view(1,1,*cur_imgs.shape[:2]),
        cur_filter,
        padding=tuple((torch.tensor(cur_filter.shape)[2:]-1)/2)
    )
    
    Means.div_(numDivisor)
    logVars.div_(numDivisor)


    imagesc(Means.squeeze())
    imagesc((logVars).exp().squeeze())

In [None]:
imagesc(numDivisor.squeeze())

In [None]:
a = (logVars-logMeans).exp().squeeze()

In [None]:
imagesc(mean_im[slice(40,470),:])

In [None]:
imagesc((logVars-logMeans).exp().squeeze())

In [None]:
# # Look at the corrected vs original mean-var plots in smaller regions
# toUse = imgsImputedCorrected.mean(2)<5000000.
# scatterHex(y=imgsImputedCorrected.std(2)[toUse].pow(2.).view(-1).detach().numpy(), 
#            x=imgsImputedCorrected.mean(2)[toUse].view(-1).detach().numpy())

In [None]:
# # Look at the corrected vs original mean-var plots in smaller regions
# toUse = imgsImputed.mean(2)<5000000.
# scatterHex(y=imgsImputed.std(2)[toUse].pow(2.).view(-1).detach().numpy(), 
#            x=imgsImputed.mean(2)[toUse].view(-1).detach().numpy())

In [None]:
# load the regions (training data only)
import json
import numpy as np
dims = imgsImputed.shape[:2]

with open(data_dir+dataset_name+'/regions/regions.json') as f:
    regions = json.load(f)

def tomask(coords):
    mask = np.zeros(dims)
    mask[zip(*coords)] = 1
    return mask

masks = np.array([tomask(s['coordinates']) for s in regions])

masks= torch.tensor(masks.astype(np.float32)).permute(1,2,0)


In [None]:
imagesc(masks.sum(2))

In [None]:
data_dir