In [None]:
from time_templates.datareader.get_data import fetch_MC_data_from_tree, get_event_from_df
from time_templates.preprocessing.apply_cuts_df import apply_cuts_df
from time_templates.templates.event_templates import EventTemplate
from time_templates.templates.trace_templates import TraceTimeTemplate
import numpy as np
import matplotlib.pyplot as plt
from time_templates.utilities.plot import plot_profile_1d
from time_templates.misc.energy import SdlgE_resolution

In [None]:
# Load the proton MC sample (energy bin '19_19.5', detector 'new_UUB_SSD_rcola').
# The SdCosTheta / SdlgE selection is passed to the loader via `cuts`
# (presumably applied while reading), so no extra query is needed here.
df = fetch_MC_data_from_tree(
    primary='proton',
    energy='19_19.5',
    det='new_UUB_SSD_rcola',
    do_Rc_fit=False,
    cuts={
        'SdCosTheta': (0.6, 1.0),
        'SdlgE': (19, 20.2),
    },
)

In [None]:
# --- Fit configuration -------------------------------------------------------
# Regularization weights, written as 1/width**2.
reg_Rmu = 1 / 0.4**2      # weight for Rmu
reg_lgE = 1               # later divided by SdlgE_resolution**2 (previously tried 0.25)
reg_factor = 10 / 0.1**2  # weight for the Smu / Sem scale factors

# Station cuts: radial window on Sdr and minimum WCD total signal.
RMIN = 450
RMAX = 2500
SMIN = 5
# Event cuts: the closest surviving station must satisfy Sdr < RMIN_CLOSEST,
# and at least MIN_NSTATIONS stations must survive the station cuts.
RMIN_CLOSEST = 1000
MIN_NSTATIONS = 3

df_ = df.query(
    f"Sdr > {RMIN} & Sdr < {RMAX} & LowGainSat == 0 & WCDTotalSignal > {SMIN}"
)
rmin = df_.groupby("EventId")["Sdr"].min()
nstations = df_.groupby("EventId")["Sdr"].count()
eventids = rmin.index[(rmin < RMIN_CLOSEST) & (nstations >= MIN_NSTATIONS)]
print(len(eventids))
# Fail here rather than with an obscure error when picking an event later.
assert len(eventids) > 0, "no events pass the station/event selection"
# NOTE(review): assumes `df` is indexed by EventId so .loc keeps whole events — confirm.
df = df.loc[eventids]

In [None]:
# Take one event from the selected sample; with eventid=None the choice is
# left to get_event_from_df.
event = get_event_from_df(df, eventid=None, MC=False, verbose=False)
print(event)
# MC-labelled values and Rmu, for comparison with the fit results below.
print(event['MCXmax'])
print(event['MClgE'])
print(event['Rmu'][0])

# Build the event template with the radial station cut; start times are
# fitted explicitly in a later cell instead of at construction.
ET = EventTemplate(
    event,
    verbose=False,
    station_cuts={'r': [RMIN, RMAX]},
    do_start_time_fit=False,
)

# lgE regularization weight, scaled by the inverse squared SdlgE resolution
# evaluated at the template's lgE.
_reg_lgE = reg_lgE / SdlgE_resolution(ET.lgE) ** 2

In [None]:
# Fit the station start times explicitly (ET was built with do_start_time_fit=False);
# plot=True presumably draws diagnostic plots of the fit.
ET.fit_start_times(plot=True)

In [None]:
# First pass: fit the total station signals. lgE carries the resolution-scaled
# regularization weight; Rmu gets weight 0.
m = ET.fit_total_signals(
    plot=True,
    reg_lgE=_reg_lgE,
    reg_Rmu=0,
)
m

In [None]:
# Add a cut on the fitted total signal (between SMIN and 2000), then set the
# template up again from the total-signal fit results (Rmu_fit, lgE_fit) and Xmax.
ET.station_cuts['Stotal_fit'] = (SMIN, 2000)
ET.setup_all(
    ET.Rmu_fit,
    ET.lgE_fit,
    ET.Xmax,
    0,
)

In [None]:
# Stage 1: reset, then fit only the start-time offsets (fix_t0s=False) while
# Rmu, lgE, Xmax, Xmumax and the Smu/Sem factors stay fixed; no trace scaling.
ET.reset_fit()
m = ET.fit(
    Rmu_0=ET.Rmu_fit,
    lgE_0=ET.lgE_fit,
    fix_Rmu=True,
    fix_lgE=True,
    fix_Xmax=True,
    fix_Xmumax=True,
    fix_t0s=False,
    fix_factorSmu=True,
    fix_factorSem=True,
    reg_Rmu=reg_Rmu,
    reg_lgE=_reg_lgE,
    reg_factorSmu=reg_factor,
    reg_factorSem=reg_factor,
    no_scale=True,
)

In [None]:
# Stage 2: freeze the t0s and free Rmu, lgE and Xmax (Xmumax and the
# Smu/Sem factors stay fixed); still without trace scaling.
# TODO: the regularization could be set from the total-signal fit uncertainties.
m = ET.fit(
    Rmu_0=ET.Rmu_fit,
    lgE_0=ET.lgE_fit,
    fix_Rmu=False,
    fix_lgE=False,
    fix_Xmax=False,
    fix_Xmumax=True,
    fix_t0s=True,
    fix_factorSmu=True,
    fix_factorSem=True,
    reg_Rmu=reg_Rmu,
    reg_lgE=_reg_lgE,
    reg_factorSmu=reg_factor,
    reg_factorSem=reg_factor,
    no_scale=True,
)
m

In [None]:
# Set up the scale mask with tq_cut=1 and neff_cut=1.
# NOTE(review): presumably selects where the trace scaling applies in the
# following fit with no_scale=False — confirm against EventTemplate.
ET.setup_scale_mask(tq_cut=1, neff_cut=1)

In [None]:
# Stage 3: same free parameters as stage 2, now with trace scaling enabled
# (no_scale=False).
# NOTE(review): lgE_0 is ET.lgE_0 here while the earlier fits start from
# ET.lgE_fit — confirm this is intended.
m = ET.fit(
    Rmu_0=ET.Rmu_fit,
    lgE_0=ET.lgE_0,
    fix_Rmu=False,
    fix_lgE=False,
    fix_Xmax=False,
    fix_Xmumax=True,
    fix_t0s=True,
    fix_factorSmu=True,
    fix_factorSem=True,
    reg_Rmu=reg_Rmu,
    reg_lgE=_reg_lgE,
    reg_factorSmu=reg_factor,
    reg_factorSem=reg_factor,
    no_scale=False,
)
print(ET.ndata, ET.ndof)
m

In [None]:
# Compare data and expected trace for template station i: unscaled (top)
# and multiplied by the per-bin scale (bottom).
# `t` and `ttt` are also used by the next cell.
f, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
i = 0
t = ET.ts[i]
ttt = ET.TTTs[i]
scale = ET.scales[i]
data = ET.data[i]
mu = ET.expected_traces[i]
# Clip at zero before the square root: any negative bins would otherwise give
# NaN error bars (and RuntimeWarnings) instead of a visible point.
ax1.errorbar(t, data, yerr=np.sqrt(np.clip(data, 0, None)), label='data')
ax1.plot(t, mu, label='expected')
ax1.legend()
ax2.errorbar(t, data * scale, yerr=np.sqrt(np.clip(mu * scale, 0, None)), label='data x scale')
ax2.plot(t, mu * scale, label='expected x scale')
ax2.legend()
ax2.set_xlabel('t')
ax2.set_xlim([0, 1500]);

In [None]:
# TODO: check scale, because something is wrong
# Diagnostic: WCD scale of `ttt` (the template picked in the previous cell)
# for Smu=1, Sem=0, clipped to [0.5, 100], with sigma_bl set to 0 and
# correct_total_signal_uncertainty enabled.
# `ttt` is an object owned by the fitted EventTemplate, so both attributes are
# restored afterwards; otherwise later cells (e.g. ET.plot_traces) would
# silently use the modified template.
plt.figure(figsize=(15, 5))
Smu = 1
Sem = 0
_saved_ttt_flags = (ttt.sigma_bl, ttt.correct_total_signal_uncertainty)
try:
    ttt.sigma_bl = 0
    ttt.correct_total_signal_uncertainty = True
    plt.plot(t, np.clip(ttt.get_wcd_scale(t, Smu, Sem, 0, 0), 0.5, 100))
finally:
    ttt.sigma_bl, ttt.correct_total_signal_uncertainty = _saved_ttt_flags
plt.xlim([200, 1000]);

In [None]:
# Draw every station's trace with the fitted templates (plotMC=True) and
# restrict the x-range of each panel to [0, 2000].
axes = ET.plot_traces(plotMC=True);
for ax in axes:
    ax.set_xlim(0, 2000)

In [None]:
# Profile of the last fit result `m` (presumably an iminuit.Minuit) over
# DeltaXmumaxXmax in [0, 300].
# NOTE(review): the preceding fit used fix_Xmumax=True — confirm this profile is what is wanted.
m.draw_profile("DeltaXmumaxXmax", bound=[0, 300]);

In [None]:
from time_templates.utilities.traces import make_new_bins_by_cutoff, rebin

In [None]:
from numba import njit
@njit(fastmath=True)
def rebin(trace, t, new_bins):
    """Collapse ``trace`` onto index bins, dividing each bin sum by its width.

    This numba-compiled definition shadows the ``rebin`` imported from
    ``time_templates.utilities.traces`` earlier in the notebook.

    Parameters
    ----------
    trace : array
        Values to rebin (presumably 1-d).
    t : array
        Unused; accepted so calls of the form ``rebin(trace, t, bins)`` work.
    new_bins : sequence of int
        Bin edges as indices into ``trace``; output bin ``k`` uses
        ``trace[new_bins[k]:new_bins[k + 1]]``.

    Returns
    -------
    ndarray of float, length ``len(new_bins) - 1``
        ``sum(trace[lo:hi]) / (hi - lo)`` for each pair of consecutive edges.
    """
    n_out = len(new_bins) - 1
    averaged = np.zeros(n_out)
    for k in range(n_out):
        lo = new_bins[k]
        hi = new_bins[k + 1]
        averaged[k] = np.sum(trace[lo:hi]) / (hi - lo)
    return averaged

In [None]:
# Effective number of particles for one station, used to derive adaptive bins.
i = 5
trace = event.stations[i].wcd_trace.get_total_trace()
t = event.stations[i].wcd_trace.t
# NOTE(review): `event.stations[i]` and the `ET.*[i]` arrays share the index i,
# but ET applies station cuts — confirm both refer to the same station.
ttt = ET.TTTs[i]
ttt.correct_total_signal_uncertainty = True
Smu = ET.Smu_LDF[i]
Sem = ET.Sem_pure_LDF[i]
Semmu = ET.Sem_mu_LDF[i]
Semhad = ET.Sem_had_LDF[i]
neff = ttt.get_wcd_neff_particles(t, Smu, Sem, Semmu, Semhad)
plt.plot(t, trace)
neff_rebinned, new_bins = make_new_bins_by_cutoff(neff, len(neff), 20)
# Report original vs. rebinned sizes. `new_trace` is only created in the next
# cell, so referencing it here raised NameError on a fresh kernel.
len(neff), len(new_bins)

In [None]:
# Rebin the trace and neff onto the adaptive bins and draw the rebinned trace.
# Convert the edges to an integer array first: the njit `rebin` then gets an
# array instead of a (deprecated) reflected Python list, and the edges can be
# used for fancy indexing below.
new_bins = np.asarray(new_bins)
new_trace = rebin(trace, t, new_bins)
neff_rebinned = rebin(neff, t, new_bins)

# Edges are slice indices, so the last one may equal len(t); indexing `t`
# with it directly would raise IndexError. Extend the time axis by one sample.
# NOTE(review): assumes uniform sampling at the end of `t`.
t_edges = np.append(t, t[-1] + (t[-1] - t[-2]))
new_t = t_edges[new_bins[:-1]]
dt = t_edges[new_bins[1:]] - t_edges[new_bins[:-1]]

f, ax = plt.subplots(1, figsize=(14, 6))
ax.bar(new_t, new_trace, width=dt, ec='b', color='none', align='edge');

In [None]:
# Overlay the data trace of template station i with its effective particle
# number (divided by 10, presumably to share the y-axis).
i = 3
t = ET.ts[i]
plt.plot(t, ET.data[i])
Smu, Sem, Semmu, Semhad = (
    ET.Smu_LDF[i],
    ET.Sem_pure_LDF[i],
    ET.Sem_mu_LDF[i],
    ET.Sem_had_LDF[i],
)
print(Smu, Sem, Semmu, Semhad)
plt.plot(
    t,
    ET.TTTs[i].get_wcd_neff_particles(t, Smu, Sem, Semmu, Semhad) / 10,
)

In [None]:
# Plot station 1 of the template's event with plotTT=True, zoomed to [50, 500].
f, ax = plt.subplots(1, figsize=(15, 6))
ET.event.stations[1].plot_trace(ax=ax, plotTT=True);
ax.set_xlim(50, 500)