In [2]:
from functools import partial
from os import path
from typing import List

import plotly.graph_objects as go
import plotly.express as px
import lmfit.model
import numpy as np
import altair as alt
from lmfit import Parameters, Parameter, Model
from lmfit.lineshapes import log2, s2pi, s2, tiny
from numpy import pi, exp, log, sqrt
from plotly.subplots import make_subplots
from scipy import linalg
from scipy.optimize import curve_fit
from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion
from utils import EcalDataIO
from quantum_clustering.ws_decomposition import convert_to_array_and_expand, weight_shape_decomp
from fit_to_gaussians.fit_to_gaussians import one_2d_gaussian
import matplotlib.pyplot as plt
from astroML import sum_of_norms
from lmfit.models import Gaussian2dModel

In [4]:
# --- Configuration: which file/event to analyse and decomposition settings ---
file = 5
event_id = '708'
# NOTE(review): `n` is not used in this cell; presumably consumed by a later
# cell — TODO confirm its meaning.
n = 2
expansion_factor = 1
sigma = (1, 2, 2)
kernel_size = 7

# --- Load the energy-deposit list and the energy table for the chosen file ---
# NOTE(review): the "IP0" prefix hard-codes a leading zero, so file >= 10 would
# produce e.g. "IP010" — confirm the on-disk naming if other files are used.
data_dir = path.join(path.curdir, 'data')
calo = EcalDataIO.ecalmatio(
    path.join(data_dir, f"signal.al.elaser.IP0{file}.edeplist.mat")
)
energies = EcalDataIO.energymatio(
    path.join(data_dir, f"signal.al.elaser.IP0{file}.energy.mat")
)

# --- Select the event and expand it onto an array ---
calo_event = calo[event_id]
energy_event = energies[event_id]
expanded_array, e_list = convert_to_array_and_expand(
    calo_event, energy_event, t=expansion_factor
)

# --- Weight/shape decomposition ---
# sigma is scaled by the expansion factor (presumably to keep it in the units
# of the original grid).
sigma_effective = np.array(sigma) * expansion_factor
# NOTE(review): unlike sigma_effective this uses a fixed factor of 2 rather
# than expansion_factor; the two agree only when expansion_factor == 2 —
# confirm this is intended.
kernel_size_effective = 2 * (kernel_size - 1) + 1
P, V, W, S = weight_shape_decomp(expanded_array, kernel_size_effective, sigma_effective)
PxS = P * S

# --- 2-D map to fit: sum out axis 1, plus integer index grids over it ---
# With indexing='ij', X varies along axis 0 and Y along axis 1 of data_to_fit.
data_to_fit = PxS.sum(axis=1)
X, Y = np.meshgrid(*(np.arange(extent) for extent in data_to_fit.shape[:2]), indexing='ij')

In [95]:
# Shared layout: transparent backgrounds, light text, and one colour axis
# ('coloraxis') that every heatmap below is bound to, so all panels use the
# same colour scale.
layout = dict(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    font=dict(color='rgb(220,220,220)'),
    height=600, width=750,
    title_text="Side By Side Subplots",
    coloraxis=dict(colorscale='viridis'),
)

fig = make_subplots(rows=2, cols=2)

# Heatmaps are drawn from data_to_fit.T, so plot-x indexes axis 0 of
# data_to_fit and plot-y indexes axis 1; the labels follow the X/Z naming used
# here. %{z} shows the cell value; <extra></extra> hides the trace-name box.
hover_template = 'X: %{x}<br>Z: %{y}<br>value: %{z}<extra></extra>'

# Every heatmap shares the colour axis and hover template. Note: `hoverinfo`
# is not passed — plotly ignores it whenever `hovertemplate` is set.
hm = partial(go.Heatmap,
             coloraxis='coloraxis',
             hovertemplate=hover_template)

# Top-left: the map as-is.
fig.add_trace(hm(z=data_to_fit.T), row=1, col=1)

# Right column: the map scaled by 2 (presumably to check that the panels share
# one colour scale). row='all' adds a separate copy of this trace to every row
# of column 2, i.e. to (1, 2) and (2, 2); it does not create one tall panel.
fig.add_trace(hm(z=data_to_fit.T * 2), row='all', col=2)

# Bottom-left: the map as-is again.
fig.add_trace(hm(z=data_to_fit.T), row=2, col=1)

# Label every subplot's axes to match the hover labels.
fig.update_xaxes(title_text='X')
fig.update_yaxes(title_text='Z')
fig.update_layout(layout)
fig.show()

In [84]:
# Replace the first colour stop (position 0) of the shared colour axis with red
# and keep all remaining stops — presumably so the minimum of the scale stands
# out.
# NOTE(review): assumes the plotting cell above has already set
# layout.coloraxis.colorscale (it sets 'viridis'); if it were unset, the slice
# below would fail.
fig.layout.coloraxis.colorscale = ((0, '#FF0000'), *fig.layout.coloraxis.colorscale[1:])
# A plain go.Figure is not a live widget: without re-rendering, the new colour
# scale is never displayed.
fig.show()