In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from time_templates.utilities.plot import plot_profile_1d
from time_templates.utilities.fitting import plot_fit_curve
from matplotlib.gridspec import GridSpec
from scipy.interpolate import UnivariateSpline, LinearNDInterpolator, RegularGridInterpolator, NearestNDInterpolator, interp2d, RectBivariateSpline, griddata
from time_templates.templates.universality.lognormal_templates import get_m_s_lognormal_comp
from time_templates.utilities import atmosphere
from lognormal_templates import ms_parameters_func, get_m_s_lognormal_comp, get_interpolated_r_ms_parameters
from names import DICT_COMP_LABELS
# Element-wise (np.vectorize) wrapper, so array arguments such as DX or r can be passed.
v_get_m_s_lognormal_comp = np.vectorize(pyfunc=get_m_s_lognormal_comp)

In [None]:
MEAN_DF_PATH = '/home/mart/auger/data/time_templates/mean_df/df_means_merged_fitted_lognormal.pl'
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
df = pd.read_pickle(MEAN_DF_PATH)
# Also expose the last index level (used below as cos(psi)) as a regular column.
df['MCcospsi_round'] = df.index.get_level_values(-1)

In [None]:
ct2_bins = df.index.get_level_values(level=0).unique()

def make_DX_plot(r, psi=None):
    """Plot the fitted lognormal ``m`` against DX at radius ``r``, one series per zenith bin.

    One panel per component in ``DICT_COMP_LABELS`` shows the binned fit values
    (``wcd_{comp}_trace_mfit``, error bars inflated by sqrt(reduced chi2) where that
    exceeds 1) together with the model curve from ``v_get_m_s_lognormal_comp``.

    Parameters
    ----------
    r : float
        Value of index level 2 of ``df``.
    psi : float or None, optional
        Azimuth in radians; rows are selected by ``cos(psi)`` on the last index level.
        If None, rows are averaged over that level and the model is evaluated at psi = pi/2.

    NOTE(review): passes ``interp_kind`` to ``v_get_m_s_lognormal_comp``; a later cell
    rebinds that name to a version without this argument, so re-running this after that
    cell raises TypeError.
    """
    # Read the zenith bins here instead of from the global `ct2_bins`, which a later cell
    # rebinds to plain (left, right) tuples that have no `.mid`.
    zenith_bins = df.index.get_level_values(level=0).unique()
    f, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True)

    for ct2 in zenith_bins:
        theta = np.arccos(np.sqrt(ct2.mid))
        if psi is None:
            df_bin = df.loc[ct2, :, r].groupby(level=0).mean()
            psi_ = np.pi/2
        else:
            psi_ = psi
            df_bin = df.loc[ct2, :, r, np.cos(psi)]

        DX = df_bin['MCDXstation'].values
        for comp, ax in zip(DICT_COMP_LABELS, axes.flatten()):
            y = df_bin[f'wcd_{comp}_trace_mfit'].values
            yerr = df_bin[f'wcd_{comp}_trace_merr'].values
            red_chi2 = df_bin[f'wcd_{comp}_trace_redchi2'].values
            mask = np.isfinite(DX*y*yerr)
            if not mask.any():
                # No usable points in this bin, so there is also no DX range for the model curve.
                continue
            pl = ax.errorbar(DX[mask], y[mask], yerr[mask]*np.maximum(1, np.sqrt(red_chi2[mask])),
                             marker='o', ls='')

            # min/max (not first/last element) so the range does not depend on row order.
            DXspace = np.linspace(DX[mask].min(), DX[mask].max())
            m, s = v_get_m_s_lognormal_comp(DXspace, theta, r, psi_, comp, interp_kind='linear')
            ax.plot(DXspace, m, color=pl[0].get_color())

    for comp, ax in zip(DICT_COMP_LABELS, axes.flatten()):
        ax.set_title(comp)
        ax.set_ylabel('m')
        ax.grid(True)
    for ax in axes[-1]:
        ax.set_xlabel('DX g/cm2')

In [None]:
# Radius 800, psi = 0 (cos psi = 1).
make_DX_plot(r=800, psi=0)

In [None]:
ct2_bins = df.index.get_level_values(level=0).unique()


def make_r_plot(theta, DX, r_max=2100, rspace=None):
    """Plot the fitted lognormal ``m`` against r for one (zenith, DX) bin, one series per cos(psi).

    One panel per component in ``DICT_COMP_LABELS`` shows the binned fit values
    (error bars inflated by sqrt(reduced chi2) where that exceeds 1) and the model curve
    from ``v_get_m_s_lognormal_comp`` evaluated on ``rspace``.

    Parameters
    ----------
    theta : float
        Zenith angle in radians; ``cos(theta)**2`` is looked up on index level 0 of ``df``.
    DX : float
        Value looked up on index level 1 of ``df``.
    r_max : float, optional
        Only points with ``MCr < r_max`` are drawn (default 2100).
    rspace : array_like, optional
        Radii for the model curve (default ``np.linspace(500, 2000)``).

    NOTE(review): passes ``interp_kind`` to ``v_get_m_s_lognormal_comp``; a later cell
    rebinds that name to a version without this argument.
    """
    if rspace is None:
        rspace = np.linspace(500, 2000)

    f, axes = plt.subplots(2, 2, figsize=(12, 12), sharex=True)

    ct2 = np.cos(theta)**2
    cospsis = sorted(df.index.get_level_values(level=-1).unique())
    for cp in cospsis:
        try:
            df_bin = df.loc[ct2, DX, :, cp]
        except KeyError:
            # Not every (ct2, DX, cos psi) combination is populated.
            continue

        for comp, ax in zip(DICT_COMP_LABELS, axes.flatten()):
            r = df_bin['MCr'].values
            y = df_bin[f'wcd_{comp}_trace_mfit'].values
            yerr = df_bin[f'wcd_{comp}_trace_merr'].values
            mask = np.isfinite(r*y*yerr) & (r < r_max)
            red_chi2 = df_bin[f'wcd_{comp}_trace_redchi2'].values
            pl = ax.errorbar(r[mask], y[mask], yerr[mask]*np.maximum(1, np.sqrt(red_chi2[mask])),
                             marker='o', ls='')

            m, s = v_get_m_s_lognormal_comp(DX, theta, rspace, np.arccos(cp), comp, interp_kind='linear')
            ax.plot(rspace, m, color=pl[0].get_color())

    for comp, ax in zip(DICT_COMP_LABELS, axes.flatten()):
        ax.set_title(comp)
        ax.set_ylabel('m')
        ax.grid(True)
    for ax in axes[-1]:
        ax.set_xlabel('r')

In [None]:
# Scratch check: 6 evenly spaced values in [0.36, 1] (presumably candidate cos^2(theta) edges).
np.linspace(start=0.36, stop=1, num=6)

In [None]:
# Centre of the first zenith bin, DX = 400.
make_r_plot(theta=np.arccos(np.sqrt(ct2_bins[0].mid)), DX=400)

In [None]:
from scipy.optimize import least_squares

def ms_func(DX, sintheta, cospsi, x):
    """Parametrisation of a lognormal shape parameter (m or s) in DX and geometry.

    From A. Schulz 2016: a cubic polynomial in ``DX/400`` plus the geometry term
    ``sintheta * (ageo*cospsi + bgeo*DX/400)``.

    Parameters
    ----------
    DX, sintheta, cospsi : float or array_like
    x : sequence of 6 floats
        ``(aX, bX, cX, dX, ageo, bgeo)``.
    """
    aX, bX, cX, dX, ageo, bgeo = x
    DXref = DX/400
    fX = aX + DXref*(bX + DXref*(cX + dX*DXref))
    fgeo = sintheta*(ageo*cospsi + bgeo*DXref)
    return fX + fgeo

def fit_m_s_func_params(df_bin, comp, key):
    """Weighted least-squares fit of ``ms_func`` to the binned m or s values of one component.

    Parameters
    ----------
    df_bin : pandas.DataFrame
        Must contain ``wcd_{comp}_trace_{key}fit``, ``wcd_{comp}_trace_{key}err``,
        ``wcd_{comp}_trace_redchi2``, ``MCTheta``, ``MCDXstation``, ``MCcospsi_round``
        and ``MCr``.
    comp : str
        Component name used in the column names.
    key : {'m', 's'}
        Which lognormal parameter to fit; selects the starting point.

    Returns
    -------
    params, errors : ndarray of length 6
        Best-fit parameters and 1-sigma errors from ``inv(J^T J) * chi2/ndof``.
        Both are all-NaN (two distinct arrays) when fewer than 10 degrees of freedom
        remain or the fit does not converge; errors are NaN if ``J^T J`` is singular.
    """
    n_params = 6

    def _nan_result():
        return np.full(n_params, np.nan), np.full(n_params, np.nan)

    m = df_bin[f'wcd_{comp}_trace_{key}fit'].to_numpy()
    # Inflate the per-bin fit errors by sqrt(reduced chi2) of those fits.
    merr = (df_bin[f'wcd_{comp}_trace_{key}err'].to_numpy()
            * np.sqrt(df_bin[f'wcd_{comp}_trace_redchi2'].to_numpy()))

    sintheta = np.sin(df_bin['MCTheta'].to_numpy())
    DX = df_bin['MCDXstation'].to_numpy()
    cospsi = df_bin['MCcospsi_round'].to_numpy()

    mask = np.isfinite(sintheta*DX*cospsi*m*merr) & (merr > 0)

    if key == 'm':
        x0 = np.array([6, -0.5, -0.1, 0.07, 0.05, 0.6])
    else:
        x0 = np.array([0.7, -0.5, 0.2, -0.03, 0.05, -0.4])

    ndof = np.count_nonzero(mask) - len(x0)
    if ndof < 10:
        print("no success")
        return _nan_result()

    DX, sintheta, cospsi, m, merr = DX[mask], sintheta[mask], cospsi[mask], m[mask], merr[mask]

    def lq_func(x):
        return (ms_func(DX, sintheta, cospsi, x) - m)/merr

    res = least_squares(lq_func, x0)
    if not res.success:
        print("no success", comp, key, df_bin['MCr'].mean())
        return _nan_result()

    # least_squares reports cost = 0.5 * sum(residuals**2); chi2 is twice that.
    chi2 = 2*res.cost
    J = res.jac
    try:
        cov = np.linalg.inv(J.T @ J) * chi2/ndof
    except np.linalg.LinAlgError:
        return res.x, np.full(n_params, np.nan)
    return res.x, np.sqrt(np.diag(cov))

In [None]:
# Fit ms_func separately at every radius (the largest radius is skipped), pooling all
# other index levels; keep parameters and errors per component as (nx, nrs) arrays.
rs = sorted(df.index.get_level_values(level=2).unique())[:-1]

nx = 6  # number of ms_func parameters
nrs = len(rs)
empty_nan = np.full((nx, nrs), np.nan)
d_comps = {}
for comp in ['muon', 'em', 'em_mu', 'em_had']:
    print("at", comp)
    d_comps[comp] = {name: empty_nan.copy() for name in ('m', 's', 'merr', 'serr')}
    for i, r in enumerate(rs):
        df_bin = df.loc[:, :, r, :]
        for key in ('m', 's'):
            params, errors = fit_m_s_func_params(df_bin, comp, key)
            d_comps[comp][key][:, i] = params
            d_comps[comp][key + 'err'][:, i] = errors

In [None]:
# Smooth one fitted ms_func parameter (index ip) as a function of r with a weighted spline.
comp = 'em'
key = 's'
ip = 4
x_all = np.asarray(rs, dtype=float)
y_all = d_comps[comp][key][ip]
yerr_all = d_comps[comp][key+'err'][ip]
# Failed fits are stored as NaN and zero errors would give infinite weights; drop both
# before building the spline.
good = np.isfinite(y_all) & np.isfinite(yerr_all) & (yerr_all > 0)
x, y, yerr = x_all[good], y_all[good], yerr_all[good]

f, ax = plt.subplots(1)
ax.errorbar(x, y, yerr, marker='o', ls='')
from scipy.interpolate import interp1d, splrep, splev
# ext=3: return the boundary value outside the data range.
interp = UnivariateSpline(x, y, w=1/yerr, k=3, s=len(y), ext=3)
xspace = np.linspace(400, 2500, 100)
ax.plot(xspace, interp(xspace))
ax.set_xlabel('r')
ax.set_ylabel(f'{comp} {key} parameter {ip}')

In [None]:
%timeit -n 1000 UnivariateSpline(x, y, w=1/yerr, k=3, s=len(y))

In [None]:
%timeit -n 1000 splrep(x, y, w=1/yerr, k=3, s=len(y))

In [None]:
%timeit -n 1000 splev(xspace, interp)

In [None]:
# Compare the pooled ms_func fit with the binned `{key}` values against DX at one radius,
# for the two extreme cos(psi) values; one colour per zenith bin.
# The radius is chosen explicitly (the last one fitted above) instead of relying on the
# `df_bin` left over from the fitting loop; the curve uses the fitted parameter vector
# for that radius (the original referenced the undefined names `m_func` and `res`).
ir = len(rs) - 1
df_r = df.loc[:, :, rs[ir], :]
params = d_comps[comp][key][:, ir]

f, ax = plt.subplots(1)
cospsi = sorted(df['MCcospsi_round'].unique())
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
for cp, marker in zip([cospsi[0], cospsi[-1]], ['o', 'x']):
    for ct2, color in zip(df.index.get_level_values(level=0).unique(), colors):
        df_ = df_r.loc[ct2.mid, :, cp]

        pl = ax.errorbar(df_['MCDXstation'], df_[f'wcd_{comp}_trace_{key}fit'],
                         yerr=df_[f'wcd_{comp}_trace_{key}err']*np.sqrt(df_[f'wcd_{comp}_trace_redchi2']),
                         ls='', marker=marker, color=color)

        DX = np.linspace(df_['MCDXstation'].min(), df_['MCDXstation'].max())
        sintheta = np.sin(df_['MCTheta']).mean()
        ax.plot(DX, ms_func(DX, sintheta, cp, params), color=pl[0].get_color())
ax.set_xlabel('DX g/cm2')
ax.set_ylabel(key)

In [None]:
# Fit the DX dependence of m and s with `fit_ab` for every (zenith bin, r, cos psi)
# combination and store the three coefficients of each on an (r, cos psi) grid.
# NOTE(review): the zenith bins used to be read from `df_odd`, which is not defined anywhere
# in this notebook (NameError on a fresh kernel); they are taken from `df`, like `rs`/`cps`.
ct2_bins = df.index.get_level_values(level=0).unique()
print(ct2_bins)
rs = np.array(sorted(df.index.get_level_values(level=2).unique()))
cps = np.array(sorted(df.index.get_level_values(level=3).unique()))

# Plain (left, right) tuples; used as dictionary keys below.
ct2_bins = [(ct2.left, ct2.right) for ct2 in ct2_bins]
f, axes = plt.subplots(4, 2, figsize=(15, 20), sharey=False)

nct2 = len(ct2_bins)
nrs = len(rs)
ncps = len(cps)

comp_labels = ['muon', 'em', 'em_mu', 'em_had']
fit_keys = ['ma', 'mb', 'mc', 'sa', 'sb', 'sc']

# d_comps[ct2_bin][comp][key]: (nrs, ncps) grid of fit coefficients, NaN where no data.
empty_nan = np.full((nrs, ncps), np.nan)
d_comps = {
    ct2_bin: {comp: {key: empty_nan.copy() for key in fit_keys} for comp in comp_labels}
    for ct2_bin in ct2_bins
}

theta_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
r_markers = ['o', 's', 'x', 'v', '^', '>', '*', '<', 'd', 'h', '1', '2', '3', '4', '+']
from collections import defaultdict

# Panel decoration does not depend on the data, so set it once rather than per fit.
for l, comp in enumerate(comp_labels):
    for ax, ylabel in zip(axes[l], ['m', 's']):
        ax.set_ylabel(ylabel)
        ax.set_xlabel('DX g/cm2')
        ax.set_title(comp)
        ax.grid(True)

dd = defaultdict(list)
for i, ct2_bin in enumerate(ct2_bins):
    # NOTE(review): the lookup value is the bin centre shifted by +0.01; the reason for the
    # shift is not visible here.
    ct2 = np.mean(ct2_bin)+0.01
    print(f"at {ct2}")
    for j, r in enumerate(rs):
        for k, cp in enumerate(cps):
            dd['ct2'].append(ct2)
            dd['r'].append(r)
            dd['cp'].append(cp)
            try:
                df_bin = df.loc[ct2, :, r, cp]
            except KeyError:
                continue

            # Only draw the fits for one (r, cos psi) combination to keep the figure readable.
            show = cp == cps[1] and r == 800
            for l, comp in enumerate(comp_labels):
                ax1 = axes[l, 0] if show else None
                ax2 = axes[l, 1] if show else None
                pm = fit_ab(df_bin, comp, 'm', ax=ax1, color=theta_colors[i], marker='s', plt_fit=True)
                ps = fit_ab(df_bin, comp, 's', ax=ax2, color=theta_colors[i], marker='s', plt_fit=True)
                comp_grid = d_comps[ct2_bin][comp]
                for n_coef, coef in enumerate('abc'):
                    comp_grid['m' + coef][j, k] = pm[n_coef]
                    comp_grid['s' + coef][j, k] = ps[n_coef]
plt.tight_layout()

In [None]:
# em-component m (a + b*DX/400 + c*sin(theta)*DX/400) against DX for each zenith bin,
# at fixed r = rs[i] and psi = arccos(cps[j]), scanning Xmax.
i, j = 2, 2
r = rs[i]
psi = np.arccos(cps[j])

f, ax = plt.subplots(1)

Xmaxs = np.linspace(600, 1000)

for ct2 in ct2_bins:
    theta = np.arccos(np.sqrt(np.mean(ct2)))
    em_params = d_comps[ct2]['em']
    a = em_params['ma'][i, j]
    b = em_params['mb'][i, j]
    c = em_params['mc'][i, j]
    sintheta = np.sin(theta)
    DX = atmosphere.DX_at_station_isothermal(r, psi, theta, Xmaxs)
    m = a + b*DX/400 + c*sintheta*DX/400
    ax.plot(DX, m, label=f"theta = {np.rad2deg(theta):.0f} deg")

ax.set_xlabel('DX g/cm2')
ax.set_ylabel('m (em)')
ax.legend()

In [None]:
# TODO: think about how to do this

In [None]:
# NOTE(review): `m_s_DX_func` is used here and in `get_m_s_lognormal_comp` but was not
# defined anywhere in this notebook (NameError on a fresh kernel). It is reconstructed from
# the model fitted at the end of this cell (a + b*DX/400) -- confirm this is the intended form.
def m_s_DX_func(DX, a, b):
    """Linear-in-DX model ``a + b*DX/400`` for a lognormal m or s parameter."""
    return a + b*DX/400


# em-component m across zenith bins for a fixed Xmax, at fixed r = rs[i], psi = arccos(cps[j]).
Xmax = 800

i, j = 2, 2
r = rs[i]
psi = np.arccos(cps[j])

ms = []
DXs = []
for ct2 in ct2_bins:
    theta = np.arccos(np.sqrt(np.mean(ct2)))
    a = d_comps[ct2]['em']['ma'][i, j]
    b = d_comps[ct2]['em']['mb'][i, j]
    DX = atmosphere.DX_at_station_isothermal(r, psi, theta, Xmax)
    DXs.append(DX)
    ms.append(m_s_DX_func(DX, a, b))
ms = np.array(ms)
DXs = np.array(DXs)
mask = np.isfinite(ms*DXs)
f, ax = plt.subplots(1)
plot_fit_curve(DXs[mask], ms[mask], lambda x, a, b: a + b*x/400, ax=ax, smoother_x=True, p0=[6, -0.1])
ax.set_xlabel('DX g/cm2')
ax.set_ylabel('m (em)')
ax.legend()

In [None]:
def make_interpolation_ab(ct2_bin, comp, key='ma', nspline=3):
    """Build a smooth interpolator of one fit coefficient over the (r, cos psi) grid.

    Reads the ``(len(rs), len(cps))`` grid ``d_comps[ct2_bin][comp][key]`` (NaN where no
    fit is available). NaN cells are first filled with the nearest finite value
    (NearestNDInterpolator); then a RectBivariateSpline is fitted on the regular grid
    for fast evaluation.

    Open issues from the original notes: boundaries are still tricky; use more
    simulations / larger energy.

    Parameters
    ----------
    ct2_bin, comp, key
        Keys into ``d_comps``.
    nspline : int, optional
        Spline degree in both directions (default 3).

    Returns
    -------
    scipy.interpolate.RectBivariateSpline
        Call as ``interp(r, cospsi)``.

    Raises
    ------
    ValueError
        If the grid shape does not match ``(len(rs), len(cps))`` or it has no finite values.
    """
    values = np.asarray(d_comps[ct2_bin][comp][key], dtype=float)
    # 'ij' indexing gives meshes of shape (len(rs), len(cps)), so their flattened order
    # matches values.ravel(); the default 'xy' indexing would be transposed.
    r_mesh, cp_mesh = np.meshgrid(rs, cps, indexing='ij')
    if values.shape != r_mesh.shape:
        raise ValueError(f"expected a {r_mesh.shape} grid for {comp}/{key}, got {values.shape}")

    points = np.column_stack([r_mesh.ravel(), cp_mesh.ravel()])
    flat = values.ravel()
    finite = np.isfinite(flat)
    if not finite.any():
        raise ValueError(f"no finite '{key}' values for {comp} in bin {ct2_bin}")

    # Distances are in raw (r, cos psi) units, so r dominates the nearest-neighbour choice.
    nearest = NearestNDInterpolator(points[finite], flat[finite])
    filled = nearest(points).reshape(values.shape)
    return RectBivariateSpline(rs, cps, filled, kx=nspline, ky=nspline)

In [None]:
# One (r, cos psi) interpolator per zenith bin, component and coefficient.
d_interps = {
    ct2_bin: {
        comp: {key: make_interpolation_ab(ct2_bin, comp, key) for key in ('ma', 'mb', 'sa', 'sb')}
        for comp in comp_labels
    }
    for ct2_bin in ct2_bins
}

In [None]:
def find_ct2(ct2):
    """Return the ``(left, right)`` bin from ``ct2_bins`` that contains ``ct2``.

    Every bin is treated as ``[left, right)``, and the last bin also includes its right
    edge. Interior edges therefore belong to the upper bin instead of to no bin at all
    (assuming ``ct2_bins`` is contiguous and ascending -- TODO confirm).

    Raises
    ------
    ValueError
        If ``ct2`` lies outside all bins.
    """
    last = len(ct2_bins) - 1
    for i, ct2_bin in enumerate(ct2_bins):
        left, right = ct2_bin
        if left <= ct2 < right or (i == last and ct2 == right):
            return ct2_bin
    raise ValueError(f"{ct2} was not found in {ct2_bins}")


def get_ab_ms_parameters_comp(theta, r, psi, comp):
    """Interpolated ``(ma, mb, sa, sb)`` of component ``comp`` at (theta, r, psi).

    Only has to be called once for fixed theta, r, psi. Uses the interpolators in
    ``d_interps`` for the zenith bin containing ``cos(theta)**2``.

    The splines are evaluated pointwise (``.ev``), so scalar ``r``/``psi`` give 0-d
    arrays rather than the 1x1 grids of ``interp(r, cp)``; this keeps
    ``np.vectorize`` from having to convert ndim > 0 arrays to scalars.
    """
    ct2_bin = find_ct2(np.cos(theta)**2)
    cp = np.cos(psi)
    interps = d_interps[ct2_bin][comp]
    return tuple(interps[name].ev(r, cp) for name in ('ma', 'mb', 'sa', 'sb'))



def get_m_s_lognormal_comp(DX, theta, r, psi, comp):
    """Lognormal ``(m, s)`` of component ``comp`` at depth ``DX`` for the given geometry.

    NOTE(review): this shadows the ``get_m_s_lognormal_comp`` imported at the top of the
    notebook and, unlike calls made to that one, takes no ``interp_kind`` argument.
    Relies on ``m_s_DX_func``, which is not defined in this cell.
    """
    ma, mb, sa, sb = get_ab_ms_parameters_comp(theta, r, psi, comp)
    m = m_s_DX_func(DX, ma, mb)
    s = m_s_DX_func(DX, sa, sb)
    return m, s

v_get_m_s_lognormal_comp = np.vectorize(get_m_s_lognormal_comp)

In [None]:
get_m_s_lognormal_comp(DX=1000, theta=np.deg2rad(30), r=1000, psi=0, comp='muon')

In [None]:
%timeit -n 100 get_m_s_lognormal_comp(200, 0, 1000, 0, 'muon')

In [None]:
# Interpolated em_had m and s against r at fixed DX, theta = 30 deg, psi = 0.
DX = 150
r = np.linspace(500, 2000)
m, s = v_get_m_s_lognormal_comp(DX, np.deg2rad(30), r, 0, 'em_had')
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.plot(r, m)
ax2.plot(r, s)
for ax, label in ((ax1, 'm'), (ax2, 's')):
    ax.set_xlabel('r')
    ax.set_ylabel(label)
f.suptitle(f"em_had, DX = {DX} g/cm2, theta = 30 deg, psi = 0")


In [None]:

f, axes = plt.subplots(4, 2, figsize=(12, 15), sharey=False)

# NOTE(review): `ct2s` was not defined anywhere in this notebook (NameError on a fresh
# kernel); the zenith-bin centres are assumed here -- confirm.
ct2s = [np.mean(ct2_bin) for ct2_bin in ct2_bins]

# Decorate each panel once. The original called ax.grid() without arguments on every fit,
# which toggles the grid, so its final visibility depended on how many fits succeeded.
for l, comp in enumerate(comp_labels):
    for ax, ylabel in zip(axes[l], ['m', 's']):
        ax.set_ylabel(ylabel)
        ax.set_xlabel('DX g/cm2')
        ax.set_title(comp)
        ax.grid(True)

for i, ct2 in enumerate(ct2s):

    print(f"at {ct2}")
    for j, r in enumerate(rs):
        for k, cp in enumerate(cps):
            dd['ct2'].append(ct2)
            dd['r'].append(r)
            dd['cp'].append(cp)
            try:
                df_bin = df.loc[ct2, :, r, cp]
            except KeyError:
                print(ct2, r, cp)
                continue

            # Only draw one (r, cos psi) combination, overlaid with the interpolated model.
            show = cp == cps[1] and r == 600
            for l, comp in enumerate(comp_labels):
                ax1 = axes[l, 0] if show else None
                ax2 = axes[l, 1] if show else None
                # Same (df_bin, comp, key) call signature as in the fitting loop above.
                pm = fit_ab(df_bin, comp, 'm', ax=ax1, color=theta_colors[i], marker=r_markers[j])
                ps = fit_ab(df_bin, comp, 's', ax=ax2, color=theta_colors[i], marker=r_markers[j])

                if show:
                    DX = np.linspace(df_bin['MCDXstation'].min(), df_bin['MCDXstation'].max())
                    m, s = v_get_m_s_lognormal_comp(DX, np.arccos(np.sqrt(ct2)), r, np.arccos(cp), comp)
                    ax1.plot(DX, m, ls='--')
                    ax2.plot(DX, s, ls='--')

plt.tight_layout()