In [1]:
import math
import numpy as np
import cmath as cm
import pandas as pd
import random
import pickle
import os
import datetime

import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import chainer
from chainer import configuration
from chainer.dataset import convert
import sklearn
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from tabulate import tabulate

In [2]:
#constants
hbarc = 197.3  # MeV fm; dividing an energy in MeV by hbarc gives fm^-1

#Heavy quark sector (masses in MeV)
JPsi = 3096.90
# Pion = 138.0394

Proton = 938.27

Dplus = 1869.66
# NOTE: despite its name, Dstar0 holds the D^0 mass (see the PDG link below);
# the name is kept because other cells refer to it.
Dstar0 = 1864.83
#https://pdglive.lbl.gov/Particle.action?init=0&node=S032&home=MXXX035
Sigmaplus = 1189.37
Sigmaplus_C = 2452.9 #https://pdglive.lbl.gov/Particle.action?init=0&node=B104&home=BXXX040


# Channel thresholds in fm^-1: channel 1 = J/psi p, channel 2 = Sigma_c^+ D^0
T1 = (JPsi + Proton)/hbarc
T2 = (Sigmaplus_C + Dstar0)/hbarc

T4 = 4500/hbarc

# Reduced masses (MeV) and the same quantities in fm^-1
mu_JPsiP   = 1/(1/JPsi + 1/Proton)
mu_Sigmaplus_CDstar0  = 1/(1/Sigmaplus_C + 1/Dstar0)

mu1 = mu_JPsiP/hbarc
mu2 = mu_Sigmaplus_CDstar0/hbarc

# Optional seed for reproducible samples; None keeps NumPy's unseeded global state.
SEED = None
if SEED is not None:
    np.random.seed(SEED)

#Generate energy axis (MeV): one uniform draw inside each of NEpoints equal-width
#bins covering [E_MIN, E_MIN + E_SPAN]. (An earlier variant used T1*hbarc as the
#lower edge instead of E_MIN.)
E_MIN = 4212
E_SPAN = 200
NEpoints = 100 #37
delE = E_SPAN/(NEpoints)
E_edges = E_MIN + np.arange(NEpoints + 1)*delE
# Array-valued low/high draw the bins in order, same stream as one call per bin.
Einput = np.random.uniform(low=E_edges[:-1], high=E_edges[1:])

# Channel momenta (MeV) on the energy axis. k1 uses the real square root (the axis
# lies above the J/psi p threshold); k2 uses the complex principal root, so below
# the channel-2 threshold it is purely imaginary with Im(k2) > 0.
k1 = np.sqrt(Einput**2.0-(T1*hbarc)**2.0)
k2 = np.sqrt((Einput**2.0-(T2*hbarc)**2.0).astype(np.complex128))

Nreal = 3000 #2000 #2000 #300 #1000
Nimag = 3000 #2000 #2000 #300 #1000

#Generate Nreal real parts (half within 100 MeV below, half within 100 MeV above the
#channel-2 threshold) and Nimag imaginary parts (0.5-50 MeV) for poles inside the
#counting region. Units: MeV.
Erealbelow = np.random.uniform(low=T2*hbarc-100, high=T2*hbarc, size=int(Nreal/2))
Erealabove = np.random.uniform(low=T2*hbarc, high=T2*hbarc+100, size=int(Nreal/2))
Ereal = np.concatenate((Erealbelow, Erealabove))
Eimag = np.random.uniform(low=0.5, high=50, size=Nimag)

#Generate poles beyond the counting region
#units in MeV (note: Erealbelow/Erealabove are reused for the far poles)
Erealbelow = np.random.uniform(low=T2*hbarc-2500, high=T2*hbarc-2000, size=int(Nreal/2))
#previous choice: np.random.uniform(low=T1*hbarc-2000, high=T1*hbarc-100, size=int(Nreal/2))
Erealabove = np.random.uniform(low=T2*hbarc+1000, high=T2*hbarc+1500, size=int(Nreal/2))
#previous choice: np.random.uniform(low=T2*hbarc+500, high=T2*hbarc+600, size=int(Nreal/2))

Erealfar = np.concatenate((Erealbelow, Erealabove))
Eimagfar = np.random.uniform(low=700.0, high=2000.0, size=Nimag)

In [4]:
class unif_pole:
    """One pole of a two-channel S-matrix in the uniformized (omega) plane, plus its regulator.

    The pole energy ``Epole = Ereal - 1j*Eimag`` (MeV) is converted to channel
    momenta (fm^-1), moved onto the Riemann sheet requested by ``RS`` and mapped
    to ``omega = (p1 + p2)/Delta`` with ``Delta = sqrt(T2**2 - T1**2)``.  A
    regulator pole is placed at ``-1j*|1/omega_pole|``.  ``smat11``, ``smat22``
    and ``smatdet`` return this pole's factor of S11, S22 and det(S) at real
    c.m. energies.

    Uses the notebook globals ``hbarc`` (MeV fm) and the channel thresholds
    ``T1``, ``T2`` (fm^-1).

    Parameters
    ----------
    RS : sequence of two numbers, each +1 or -1
        Sign of Im(p1) and Im(p2) of the assigned pole:
        ``[-1, 1]`` -> [bt], ``[-1, -1]`` -> [bb], ``[1, -1]`` -> [tb].
    Ereal : float
        Real part of the pole energy in MeV.
    Eimag : float
        Magnitude of the (negative) imaginary part of the pole energy in MeV.

    Raises
    ------
    ValueError
        If ``RS`` is not one of the three supported sheets.
    """

    # Sheets for which a regulator placement is defined ([tt] is not supported).
    _SUPPORTED_SHEETS = ([-1, 1], [-1, -1], [1, -1])

    def __init__(self, RS, Ereal, Eimag):
        # Accept any length-2 sequence (list, tuple, ndarray) for RS.
        sheet = list(RS)
        if sheet not in self._SUPPORTED_SHEETS:
            raise ValueError(
                'RS must be [-1, 1] ([bt]), [-1, -1] ([bb]) or [1, -1] ([tb]); got {!r}'.format(RS))

        self.Ereal = Ereal
        self.Eimag = Eimag
        self.RS  = RS
        Epole = Ereal - (1j)*Eimag
        self.pos = Epole

        #compute channel momenta of the pole (principal square root)
        #for channel 1 and channel 2
        k1pole = cm.sqrt((Epole/hbarc)**2-T1**2)
        k2pole = cm.sqrt((Epole/hbarc)**2-T2**2)

        #Riemann sheet assignment: sign of Im(p_i) is taken from RS[i]
        beta1  = sheet[0]*abs(k1pole.imag)
        beta2  = sheet[1]*abs(k2pole.imag)

        #Real part gets the sign opposite to the imaginary part (consistent convention)
        alpha1 = -np.sign(beta1)*np.abs(k1pole.real)
        alpha2 = -np.sign(beta2)*np.abs(k2pole.real)

        #Kept for counterchecking: signs of beta1 and beta2 should agree
        #with RS[0] and RS[1], respectively
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.beta1 = beta1
        self.beta2 = beta2

        #Construct pole channel momenta
        polep1 = (1j)*beta1 + alpha1
        polep2 = (1j)*beta2 + alpha2

        Delta = cm.sqrt(T2**2 - T1**2)
        self.Delta = Delta

        #Uniformization of the assigned pole
        omega_pole = (polep1 + polep2)/Delta
        recip_omg_pol = 1/omega_pole
        self.omega_pole = omega_pole
        self.recip_omg_pol = recip_omg_pol

        #Pole regulator: on the negative imaginary omega axis at radius |1/omega_pole|.
        #The same placement is used for every supported sheet.
        omega_reg = np.abs(recip_omg_pol)*cm.exp(-0.5*np.pi*(1j))
        self.omega_reg = omega_reg
        self.recip_omg_reg = 1/omega_reg

        #Invert omega = (p1 + p2)/Delta using p1**2 - p2**2 = Delta**2:
        #  p1 = Delta/2*(omega + 1/omega),  p2 = Delta/2*(omega - 1/omega)
        p1_reg = 0.5*Delta*(omega_reg + 1/omega_reg)
        p2_reg = 0.5*Delta*(omega_reg - 1/omega_reg)

        #For checking: [energy in MeV formatted to 2 decimals, Riemann-sheet label]
        self.regulator = ['{:.2f}'.format(np.sqrt(p1_reg**2.0+T1**2.0)*hbarc), self._sheet_label(p1_reg.imag, p2_reg.imag)]
        self.assignedpole = ['{:.2f}'.format(np.sqrt(polep1**2.0+T1**2.0)*hbarc), self._sheet_label(polep1.imag, polep2.imag)]

    @staticmethod
    def _sheet_label(pimag1, pimag2):
        """Riemann-sheet label from the signs of Im(p1), Im(p2).

        'tt' (sheet 1), 'bt' (sheet 2), 'bb' (sheet 3), 'tb' (sheet 4);
        'undetermined' when either imaginary part is zero or NaN.
        """
        labels = {(1, 1): 'tt', (-1, 1): 'bt', (-1, -1): 'bb', (1, -1): 'tb'}
        return labels.get((np.sign(pimag1), np.sign(pimag2)), 'undetermined')

    def _uniformized_omega(self, Ecm):
        """omega = (p1 + p2)/Delta for real c.m. energies ``Ecm`` (MeV), any length.

        p1 uses the real square root (NaN below the channel-1 threshold); p2 uses
        the complex principal root, so below the channel-2 threshold Im(p2) > 0.
        """
        Ecm = np.asarray(Ecm, dtype='float64')
        p1 = np.sqrt((Ecm/hbarc)**2.0-T1**2.0)
        p2 = np.sqrt(((Ecm/hbarc)**2.0-T2**2.0).astype(np.complex128))
        return (p1 + p2)/self.Delta

    def _normalised_ratio(self, Num, omega):
        """|omega_pole*omega_reg|**2 * Num / Den, with Den shared by S11, S22 and det(S)."""
        Denpol = (omega-self.omega_pole)*(omega+np.conj(self.omega_pole))
        Denreg = (omega-self.omega_reg)*(omega+np.conj(self.omega_reg))
        Den = Denpol*Denreg
        return np.abs((self.omega_pole*self.omega_reg))**2.0*Num/Den

    #Calculate the S-matrix contribution of the uniformized pole
    def smat11(self, Ecm):
        """S11 factor of this pole and its regulator at energies ``Ecm`` (MeV)."""
        omega = self._uniformized_omega(Ecm)
        Numpol = (omega-np.conj(self.recip_omg_pol))*(omega+self.recip_omg_pol)
        Numreg = (omega-np.conj(self.recip_omg_reg))*(omega+self.recip_omg_reg)
        return self._normalised_ratio(Numpol*Numreg, omega)

    def smat22(self, Ecm):
        """S22 factor of this pole and its regulator at energies ``Ecm`` (MeV)."""
        omega = self._uniformized_omega(Ecm)
        Numpol = (omega+np.conj(self.recip_omg_pol))*(omega-self.recip_omg_pol)
        Numreg = (omega+np.conj(self.recip_omg_reg))*(omega-self.recip_omg_reg)
        return self._normalised_ratio(Numpol*Numreg, omega)

    def smatdet(self, Ecm):
        """det(S) factor of this pole and its regulator at energies ``Ecm`` (MeV)."""
        omega = self._uniformized_omega(Ecm)
        Numpol = (omega-np.conj(self.omega_pole))*(omega+self.omega_pole)
        Numreg = (omega-np.conj(self.omega_reg))*(omega+self.omega_reg)
        return self._normalised_ratio(Numpol*Numreg, omega)

### On cusps

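Because the two-channel $S$-matrix is symmetric ($S_{21} = S_{12}$), we have $\text{det}(S) = S_{11}S_{22} - S_{12}^2$, so the square of the off-diagonal element reads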
\begin{equation}
S_{12}^2 = S_{11}S_{22} - \text{det}(S).
\end{equation}

When plotting $S_{12}$, we encounter discontinuities in our plots. We build $S_{12}$ from $S_{12}^2$ by halving its argument, and ```np.angle(z)``` has its branch cut along the negative real axis (it returns values in $(-\pi, \pi]$). Whenever $S_{12}^2$ crosses that axis, its argument jumps by $2\pi$, so $S_{12}$ abruptly flips sign. At best, we could rotate the cut and shift the discontinuity; however, this fails when we want to consider a wide range of $E$. To circumvent this problem, we introduce the ```detect_cusp``` function. It detects the discontinuities of the input and flips the sign of every element from that point onwards, until it detects another discontinuity (which flips the sign back). For example, take an array with $20$ elements, ```Z[0], Z[1], ..., Z[19]```, and suppose there are discontinuities at
```Z[9], Z[14], Z[18]```. The output of the function is an array with:

1) The original ```Z[0]``` to ```Z[8]``` of the input array
2) Negated ( $Z \to -Z$) ```Z[9]``` to ```Z[13]``` of the input array 
3) Original ```Z[14]``` to ```Z[17]``` of the input array
4) Negated ```Z[18]``` and ```Z[19]``` of the input array 

These four pieces are then concatenated, so that ```len(output) == len(input)```.

We use this function as follows:

1) Compute $\text{det}(S)$:
```smatdet = np.prod(smatdet, axis = 0)```
2) Construct $S_{12}^2$:
```pwat12sqr = (-1.0)*(smat11 * smat22 - smatdet)/4.0```
3) Construct $S_{12}$ using the polar representation. Get the modulus first:
```mod12sqr = np.abs(pwat12sqr)```
4) Get the argument of $S_{12}^2$, which we divide by 2 when constructing the polar representation:
```arg12sqrBC = np.array([np.angle(z) for z in pwat12sqr])```
5) Express in polar representation:
```pwat12withdisc = np.sqrt(mod12sqr)*np.exp(1j* arg12sqrBC/ 2)```
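6) Obtain the final pwat12 by applying the ```detect_cusp``` function and then multiplying by the Heaviside step function (which sets $S_{12}$ to zero at and below the $\Sigma_c^+ D^0$ threshold):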
```detect_cusp(pwat12withdisc)*np.heaviside(Einput-T2*197.3,0)```

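Multiply by the Heaviside function only *after* calling ```detect_cusp```.
If the Heaviside factor is applied first, the entries below threshold become zero. The jump from those zeros to the first non-zero value at threshold can then itself be detected as a cusp. Every point above threshold is then negated, so the real and imaginary parts in our plots appear with inverted signs in some cases.
Signed zeros add to the confusion: Python treats ```0. + 0.j``` as having a positive imaginary part and ```0. - 0.j``` as having a negative one.
In short, putting the Heaviside factor inside ```detect_cusp``` makes it detect a spurious cusp right at the threshold.

Note that there is no physics behind this ordering; it is purely a consequence of how the numerical output behaves in Python.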

In [5]:
def detect_cusp(input_array, threshold=0.1):
    """Undo abrupt sign flips in a sampled (complex or real) curve.

    A "cusp" at index ``i`` is a step from sample ``i-1`` to sample ``i`` where
    the real or the imaginary part changes by more than ``threshold``.  Each
    cusp toggles the sign of everything from ``i`` onwards, so element ``i`` is
    negated iff an odd number of cusps lie at indices ``<= i``.

    Parameters
    ----------
    input_array : 1-D array-like (complex or real)
        Sampled values, e.g. a square root taken with a fixed branch cut.
    threshold : float, optional
        Minimum absolute jump (in either the real or the imaginary part)
        that counts as a cusp. Default 0.1.

    Returns
    -------
    numpy.ndarray
        A new array with the sign-corrected values; the input is not modified.
    """
    values = np.array(input_array, copy=True)

    # Finite differences between neighbouring samples; a big step in either part is a cusp.
    jumps = (np.abs(np.diff(values.real)) > threshold) | (np.abs(np.diff(values.imag)) > threshold)

    # n_flips[i] = number of cusps at indices <= i (a jump between i-1 and i is a cusp at i).
    n_flips = np.zeros(values.shape, dtype=np.intp)
    n_flips[1:] = np.cumsum(jumps)

    # Use exact unary negation (not multiplication by -1) so signed zeros behave as before.
    return np.where(n_flips % 2 == 1, -values, values)

In [20]:
# Flag presumably read by later cells (it is not used in this cell).
inspect = True

# Output directory for this run. Available curriculum stages and the dataset
# classes each one contains:
#   curr01 : 00, 01, 11, 21
#   curr02 : 00, 01, 11, 21, 02                           ('curriculum02_training')
#   curr03 : 00, 01, 11, 21, 02, 12                       ('curriculum03_training')
#   curr04 : 00, 01, 11, 21, 02, 12, 22                   ('curriculum04_training')
#   curr05 : 00, 01, 11, 21, 02, 12, 22, 10               ('curriculum05_training')
#   curr06 : 00, 01, 11, 21, 02, 12, 22, 10, 20           ('curriculum06_training')
#   curr07 : 00, 01, 11, 21, 02, 12, 22, 10, 20, 30       ('curriculum07_training')
#   curr08 : 00, 01, 11, 21, 02, 12, 22, 10, 20, 30, 03   ('curriculum08_training')
#   all datasets                                          ('curriculum32_training')
#   sample figures                                        ('sample_plot')
directory = 'curriculum02_training'

# exist_ok makes this a no-op when the directory is already there.
os.makedirs(directory, exist_ok=True)

print('Number of poles to be generated per class:', Nreal*Nimag)
print('Ndata to be generated=', 4*Nreal*Nimag)
print('Your directory is:', directory)
print('Your directory is:', directory)

Number of poles to be generated per class: 4000000
Ndata to be generated= 16000000
Your directory is: curriculumXX_training


In [7]:
#descriptive labels of network output
#at most 4 poles in all RS
# labelz[k] is the human-readable description of network output class k;
# the trailing "#kk" comment on each entry is its index k.
# Sheet labels: [bt], [bb], [tb] give the sign (top/bottom) of Im(p1), Im(p2).
# NOTE: the strings contain deliberate double spaces ("1 pole  in") for
# alignment -- keep them byte-identical if other code matches on them.
labelz = [
#default no pole
    'no nearby pole',                          #00
#poles in [bt]    
    '1 pole  in [bt]',                          #01
    '2 poles in [bt]',                         #02
    '3 poles in [bt]',                         #03
    '4 poles in [bt]',                         #04
#[bt] and [bb] no shadow pair    
    '3 poles in [bt] and 1 pole  in [bb]',      #05
    '2 poles in [bt] and 1 pole  in [bb]',      #06
    '2 poles in [bt] and 2 poles in [bb]',     #07 
    '1 pole  in [bt] and 2 poles in [bb]',      #08
    '1 pole  in [bt] and 3 poles in [bb]',     #09
    '1 pole  in [bt] and 1 pole  in [bb]',      #10
#poles in [bb] only    
    '1 pole  in [bb]',                          #11
    '2 poles in [bb]',                         #12
    '3 poles in [bb]',                         #13
    '4 poles in [bb]',                         #14
#[bb] and [tb] no shadow pair    
    '3 poles in [bb] and 1 pole  in [tb]',      #15
    '2 poles in [bb] and 1 pole  in [tb]',      #16
    '2 poles in [bb] and 2 poles in [tb]',     #17   
    '1 pole  in [bb] and 2 poles in [tb]',      #18
    '1 pole  in [bb] and 3 poles in [tb]',     #19
    '1 pole  in [bb] and 1 pole  in [tb]',      #20
#poles in [tb] only    
    '1 pole  in [tb]',                          #21
    '2 poles in [tb]',                         #22
    '3 poles in [tb]',                         #23
    '4 poles in [tb]',                         #24    
#[tb] and [bt]
    '3 poles in [tb] and 1 pole  in [bt]',      #25
    '2 poles in [tb] and 1 pole  in [bt]',      #26
    '2 poles in [tb] and 2 poles in [bt]',     #27
    '1 pole  in [tb] and 2 poles in [bt]',      #28
    '1 pole  in [tb] and 3 poles in [bt]',     #29 
    '1 pole  in [tb] and 1 pole  in [bt]',      #30      
#poles in all three
    '2 poles in [bt] and 1 pole  in [bb] and 1 pole  in [tb]',    #31
    '1 pole  in [bt] and 2 poles in [bb] and 1 pole  in [tb]',    #32
    '1 pole  in [bt] and 1 pole  in [bb] and 2 poles in [tb]',    #33
    '1 pole  in [bt] and 1 pole  in [bb] and 1 pole  in [tb]'      #34
]

In [8]:
# Experimental energy points: the first column ("MEV") of hep.csv, restricted to
# the open window 4260 < E < 4380 MeV.
hepdata = pd.read_csv("hep.csv", usecols=[0])

hep_data = hepdata.loc[lambda frame: frame["MEV"].gt(4260) & frame["MEV"].lt(4380)].copy()

E_exp = hep_data["MEV"].tolist()

In [9]:
def skip_duplicate(real1, imag1, Nreal, Nimag, n_choices=10):
    """Build index lists for one sample: a fixed first pole plus distinct extra ones.

    Parameters
    ----------
    real1, imag1 : int
        Index of the first pole's real / imaginary part; always placed first.
    Nreal, Nimag : int
        Sizes of the real- and imaginary-part pools.
    n_choices : int, optional
        Number of additional indices drawn without replacement per axis
        (default 10, the previously hard-coded value).

    Returns
    -------
    list
        ``[real_values, imag_values]``, each a list of length ``n_choices + 1``
        whose first entry is ``real1`` / ``imag1``.

    Raises
    ------
    ValueError
        If a pool has fewer than ``n_choices`` candidates.

    Notes
    -----
    Extra indices are drawn from ``1 .. N-1`` excluding the fixed index, so
    index 0 is never drawn as an extra pole.
    NOTE(review): confirm that excluding index 0 is intentional.
    """
    # Candidate pools: 1..N-1 without the fixed index.
    real_pool = np.arange(1, Nreal)
    real_pool = real_pool[real_pool != real1]
    imag_pool = np.arange(1, Nimag)
    imag_pool = imag_pool[imag_pool != imag1]

    if real_pool.size < n_choices or imag_pool.size < n_choices:
        raise ValueError(
            'need at least {} candidates per axis, got {} real and {} imag'.format(
                n_choices, real_pool.size, imag_pool.size))

    # Randomly choose values without duplication (same draws as the list-based version).
    real_choices = np.random.choice(real_pool, n_choices, replace=False)
    imag_choices = np.random.choice(imag_pool, n_choices, replace=False)

    real_values = [real1] + list(real_choices)
    imag_values = [imag1] + list(imag_choices)

    return [real_values, imag_values]

In [10]:
def export_data(Einput, ReT11, ImT11, labelout, data_info, output_directory, ModEsq=None):
    """Pickle one generated dataset into ``output_directory``.

    Writes ``Einput.pkl``, ``ModEsq.pkl``, ``labelout.pkl``, ``data_info.pkl``
    and ``EinputMod.pkl`` (``Einput`` and ``ModEsq`` concatenated along axis 1,
    the array used for the network input layer), all with pickle protocol 4.

    Parameters
    ----------
    Einput, ModEsq : array-like, at least 2-D
        Must agree in every dimension except axis 1 so they can be concatenated.
    ReT11, ImT11 : any
        Not used; kept for signature compatibility (see the commented-out
        alternative input layer below).
    labelout, data_info : any picklable object
    output_directory : str
        Created if it does not exist.
    ModEsq : array-like, optional
        Pass explicitly. If omitted, falls back to a module-global ``ModEsq``
        (the behaviour of earlier versions of this function).

    Raises
    ------
    NameError
        If ``ModEsq`` is neither passed nor defined globally.
    """
    if ModEsq is None:
        if 'ModEsq' not in globals():
            raise NameError("export_data needs ModEsq: pass it as ModEsq=...")
        ModEsq = globals()['ModEsq']

    # Collect data for the input layer first, so a shape mismatch fails before
    # any file is written (no partial export).
    EinputMod = np.concatenate((Einput, ModEsq), axis=1)

#     # Alternatively, you can design a DNN with Einput, ReT11, and ImT11 in the input layer
#     T11 = np.concatenate((Einput, ReT11, ImT11), axis=1)

    # Create the specified output directory if it doesn't exist
    os.makedirs(output_directory, exist_ok=True)

    data_to_export = {
        'Einput.pkl': Einput,
        'ModEsq.pkl': ModEsq,
        'labelout.pkl': labelout,
        'data_info.pkl': data_info,
        'EinputMod.pkl': EinputMod,
    }

    for file_name, data in data_to_export.items():
        with open(os.path.join(output_directory, file_name), 'wb') as file:
            pickle.dump(data, file, protocol=4)

    print('Export completed.')

In [11]:
def import_data(directory):
    """Load the four pickled arrays stored in ``directory``.

    Reads ``Einput.pkl``, ``ModEsq.pkl``, ``labelout.pkl`` and
    ``data_info.pkl`` (in that order).

    Returns
    -------
    tuple
        ``(Einput, ModEsq, labelout, data_info)``.

    Raises
    ------
    FileNotFoundError
        If any of the four files is missing.

    Security: ``pickle.load`` can execute arbitrary code -- only load files
    you created yourself.
    """
    names = ('Einput', 'ModEsq', 'labelout', 'data_info')
    loaded = {}
    for name in names:
        with open(os.path.join(directory, name + '.pkl'), 'rb') as handle:
            loaded[name] = pickle.load(handle)
    return tuple(loaded[name] for name in names)

# Example usage:
# Einput, ModEsq, labelout, data_info = import_data('input_directory')


In [12]:
# def seerealimagpart(Einput, ModEsq, RePWAT, ImPWAT, labelout, data_info):
#     # Randomly select an index for data sample
#     chckind = np.random.randint(0, len(labelout))
    
#     # Define threshold values and corresponding points for plotting
# #     thresholds = [
# #         ('T_JPsiP', T1),
# #         ('T_Sigmaplus_CDstar0', T2),
# #         ('T_4', T4)
# #     ]
    
# #     threshold_values = [
# #     [-max(max(RePWAT[chckind]**2.0+ImPWAT[chckind]**2.0), max(np.abs(RePWAT[chckind])), max(np.abs(ImPWAT[chckind]))),
# #      max(max(RePWAT[chckind]**2.0+ImPWAT[chckind]**2.0), max(np.abs(RePWAT[chckind])), max(np.abs(ImPWAT[chckind])))
# #     ] for _ in range(3)
# #                         ]

#     # Set up the plot
#     # max_value = max(max(RePWAT[chckind]**2.0 + ImPWAT[chckind]**2.0), max(np.abs(RePWAT[chckind])), max(np.abs(ImPWAT[chckind])))
#     # plt.ylim(-max_value - 0.05, max_value + 0.05)
#     # plt.plot([T1, T1], [0, 0], 'red')  # Horizontal line
#     # plt.plot([T2, T2], [0, 0], 'red')  # Horizontal line
#     # plt.plot([T4, T4], [0, 0], 'red')  # Horizontal line
#     # for threshold, values in zip(thresholds, threshold_values):
#     #     plt.plot(threshold[1], values, 'red')
    
#     # Plot data points
#     plt.axhline(y = 0, color = 'r', linestyle = '-')
#     plt.axvline(x = T1*hbarc, color = 'r' )
#     plt.axvline(x = T2*hbarc, color = 'r')
    
#     plt.plot(Einput[chckind], RePWAT[chckind], '+', label='RePWAT')
#     plt.plot(Einput[chckind], ImPWAT[chckind], '*', label='ImPWAT')
#     plt.plot(Einput[chckind], ModEsq[chckind], 'o', label='ModSq')
#     plt.legend(loc = 'upper left',frameon=True)
    
#     # Set plot labels and title
#     plt.title('Input Data', fontsize=15)
#     plt.xlabel('$E_{cm}$ (MeV)', fontsize=15)
#     plt.xticks(fontsize=15)
#     plt.ylabel('$Re T_{11}$, $Im T_{11}$', fontsize=15)
#     plt.yticks(fontsize=15)
#     plt.tight_layout()

#     isactive = ['No', 'Yes']
    

    
#     # Display class information in a table
#     print('class', '{:02d}'.format(labelout[chckind]), ':', labelz[labelout[chckind]])
#     table_data = [
#         ['n', 'Energy Pole (MeV)', 'RS', 'Active']
#     ]
#     for i in range(7):
#         table_data.append([str(i+1), '{:.2f}'.format(data_info[chckind][0][i]),
#                            data_info[chckind][1][i], isactive[data_info[chckind][2][i]]])
    
#     print(tabulate(table_data))
    
#     table = plt.table(cellText=table_data, loc='upper left', colWidths=[0.1] * 4, cellLoc='center', edges='open')
#     table.auto_set_font_size(False)
#     table.set_fontsize(10)
#     table.scale(1, 1.5)  # Adjust the scaling as needed
    
#     # Show the plot and return the selected index
#     return plt.show(), chckind

In [13]:
def generate_timestamp():
    """Return the current local time as a 14-digit ``YYYYmmddHHMMSS`` string."""
    return f"{datetime.datetime.now():%Y%m%d%H%M%S}"

In [14]:
def _plot_channel_panel(ax, Ecm, re_part, im_part, mod_sq, labels, title, ylabel):
    """Draw Re, Im and |T|^2 of one channel against energy, with the T2 threshold marked."""
    re_label, im_label, mod_label = labels
    ax.axhline(y=0, color='r', linestyle='-')
    # ax.axvline(x=T1*hbarc, color='r')
    ax.axvline(x=T2*hbarc, color='r')
    ax.plot(Ecm, re_part, '+', label=re_label)
    ax.plot(Ecm, im_part, '*', label=im_label)
    ax.plot(Ecm, mod_sq, 'o', label=mod_label)
    ax.legend(loc='upper left', frameon=True)
    ax.set_title(title)
    ax.set_xlabel('$E_{cm}$ (MeV)')
    ax.set_ylabel(ylabel)


def _draw_pole_table(ax, sample_info, n_rows):
    """Render the first ``n_rows`` poles of one sample (energy, sheet, active flag) as a table."""
    isactive = ['No', 'Yes']
    ax.axis('off')  # Turn off axis for the table subplot
    table_data = [
        ['n', 'Energy Pole (MeV)', 'RS', 'Active']
    ]
    for i in range(n_rows):
        table_data.append([str(i+1), '{:.2f}'.format(sample_info[0][i]),
                           sample_info[1][i], isactive[sample_info[2][i]]])
    table = ax.table(cellText=table_data, cellLoc='center', colWidths=[0.2] * 4, loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.5, 1.5)  # Adjust the scaling as needed


def seerealimagpart(Einput, ModEsq, Re11, Im11, Re22, Im22, labelout, data_info,
                    save_dir='path/to/save', filename_prefix='21bt', n_table_rows=10):
    """Plot one randomly chosen sample (T11 and T22 panels plus pole tables) and save it.

    Parameters
    ----------
    Einput, ModEsq, Re11, Im11, Re22, Im22 : indexable by sample
        Per-sample energy axis and curves; row ``chckind`` of each is plotted.
    labelout : sized
        Only its length is used, to draw the random sample index.
    data_info : indexable by sample
        ``data_info[k]`` is assumed to be ``[energies, sheet labels, active flags (0/1)]``
        with at least ``n_table_rows`` entries each -- TODO confirm against the generator.
    save_dir : str, optional
        Directory for the PNG (created if missing).
    filename_prefix : str, optional
        File name is ``f"{filename_prefix}_{timestamp}.png"``; default keeps the old ``21bt_...``.
    n_table_rows : int, optional
        Number of poles listed in each table (default 10).

    Returns
    -------
    tuple
        ``(plt.show(), chckind)``; the first element is ``None``.
    """
    # Randomly select an index for data sample
    chckind = np.random.randint(0, len(labelout))

    os.makedirs(save_dir, exist_ok=True)

    # 2x2 grid: channel panels on the left, pole tables on the right
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    _plot_channel_panel(axs[0, 0], Einput[chckind], Re11[chckind], Im11[chckind], ModEsq[chckind],
                        ('Re11', 'Im11', 'ModSq'), '$T_{11}$',
                        '$Re T_{11}$, $Im T_{11}$, $|T_{11}+T_{12}|^2$')
    _draw_pole_table(axs[0, 1], data_info[chckind], n_table_rows)

    _plot_channel_panel(axs[1, 0], Einput[chckind], Re22[chckind], Im22[chckind], ModEsq[chckind],
                        ('Re22', 'Im22', 'ModSq2'), '$T_{22}$',
                        '$Re T_{22}$, $Im T_{22}$, $|T_{11}+T_{12}|^2$')
    _draw_pole_table(axs[1, 1], data_info[chckind], n_table_rows)

    # Lay out BEFORE saving so the PNG gets the tight layout too
    fig.tight_layout()
    figure_filename = f"{filename_prefix}_{generate_timestamp()}.png"
    fig.savefig(os.path.join(save_dir, figure_filename))

    # Show the plot and return the selected index
    return plt.show(), chckind

In [15]:
def get_traintest(directory,curriculum):
    """Split one prepared curriculum dataset into chainer train/test sets and pickle them.

    Reads ``labelin_currNN.pkl`` (inputs) and ``labelout_currNN.pkl`` (class
    labels) from ``directory``, with ``NN = '{:02d}'.format(curriculum)``.
    Performs a stratified 80/20 split with a fixed ``random_state=42``. Writes
    ``chainer_train_currNN.pkl`` and ``chainer_test_currNN.pkl`` back to
    ``directory`` (pickle protocol 4) and prints the split sizes and test-class
    counts.

    Security: ``pickle.load`` can execute arbitrary code -- only load files you
    created yourself.

    Returns
    -------
    None
    """
    tag = '{:02d}'.format(curriculum)

    # import prepared dataset (files are closed promptly via `with`)
    with open(os.path.join(directory, 'labelin_curr{}.pkl'.format(tag)), 'rb') as fh:
        inputtraining = pickle.load(fh)
    with open(os.path.join(directory, 'labelout_curr{}.pkl'.format(tag)), 'rb') as fh:
        outputtraining = pickle.load(fh)
    inputtraining = np.float32(np.asarray(inputtraining))

    # shuffle and split into training / testing sets, preserving class proportions
    X_train, X_test, y_train, y_test = train_test_split(inputtraining, outputtraining, test_size=0.2, random_state=42, stratify=outputtraining)

    train = chainer.datasets.TupleDataset(X_train, y_train)
    with open(os.path.join(directory, 'chainer_train_curr{}.pkl'.format(tag)), 'wb') as fh:
        pickle.dump(train, fh, protocol=4)

    test = chainer.datasets.TupleDataset(X_test, y_test)
    with open(os.path.join(directory, 'chainer_test_curr{}.pkl'.format(tag)), 'wb') as fh:
        pickle.dump(test, fh, protocol=4)

    test_classes, test_counts = np.unique(y_test, return_counts=True)
    print(f'Size of training dataset: {len(X_train)}')
    print(f'Size of testing dataset: {len(X_test)}')
    print(f"Test output values: {test_classes}")
    print(f"Test output value counts: {test_counts}")

    return


In [16]:
# Energy grid for the lineshape: N evenly spaced points on [4212, 4412] MeV.
# NOTE: this rebinds Einput, replacing any earlier definition in the notebook.
N = 100 #49
Einput = np.linspace(4212, 4412, N)

# Masses (MeV) for the Lambda_b^0 -> J/psi K^- p Dalitz phase space
M = 5619.6                                   # parent: Lambda_b^0
m1, m2, m3 = 3096.9, 493.677, 938.27208816   # J/psi, K^-, p
proj_axis = 1                                # axis summed over in the Dalitz projection

In [17]:
def Dalitz(M, m1, m2, m3, proj_axis, Ecm=None, n_m23=10000):
    """Project the kinematically allowed three-body (Dalitz) region onto one axis.

    Builds a grid of E_1, E_2 (energies of final-state particles 1 and 2 in the
    parent rest frame) from m23^2 (``n_m23`` points spanning its full range) and
    m13^2 = ``Ecm**2``. It marks the allowed cells and counts them along
    ``proj_axis``.

    Parameters
    ----------
    M : float
        Parent mass.
    m1, m2, m3 : float
        Final-state masses. All masses and ``Ecm`` must share one unit
        (this notebook passes MeV).
    proj_axis : int
        Axis to sum over. The grid has shape ``(len(Ecm), n_m23)``: axis 1 gives
        one count per ``Ecm`` point, axis 0 one count per m23^2 point.
    Ecm : array-like, optional
        m13 invariant masses. Defaults to the notebook-global ``Einput``
        (the previous, implicit behaviour).
    n_m23 : int, optional
        Number of m23^2 grid points (default 10000).

    Returns
    -------
    numpy.ndarray of int
        Number of allowed grid cells per point of the remaining axis
        (an un-normalised phase-space projection).
    """
    if Ecm is None:
        Ecm = Einput

    m23sq = np.linspace((m2 + m3)**2, (M - m1)**2, n_m23)
    m13sq = np.asarray(Ecm)**2

    E_1 = (M**2 + m1**2 - m23sq)/(2*M)
    E_2 = (M**2 + m2**2 - m13sq)/(2*M)

    X, Y = np.meshgrid(E_1, E_2)
    # Physical region: (2 p1.p2)^2 <= 4 |p1|^2 |p2|^2, with 2 p1.p2 obtained from m3^2 = (P - p1 - p2)^2
    inside_boundary = (4*(X**2 - m1**2)*(Y**2 - m2**2) - (M**2 + m1**2 + m2**2 - m3**2 - 2*M*(X + Y) + 2*X*Y)**2) >= 0
    # Particle 3 must carry non-negative energy
    E_3_nonneg = (M - X - Y) >= 0

    allowed = inside_boundary & E_3_nonneg
    return np.sum(allowed, axis=proj_axis)

# Phase-space projection onto the m(J/psi p) axis, scaled to unit L2 norm
phase_space0 = Dalitz(M=M, m1=m1, m2=m2, m3=m3, proj_axis=proj_axis)
norm = np.linalg.norm(phase_space0)
phase_space = np.true_divide(phase_space0, norm)

In [18]:
def polynomial(coeff, x):
    """Evaluate ``sum_i coeff[i] * x**i`` and scale the result to unit L2 norm.

    Parameters
    ----------
    coeff : sequence of numbers
        Polynomial coefficients, lowest order first.
    x : float or numpy.ndarray
        Evaluation point(s).

    Returns
    -------
    float or numpy.ndarray
        The polynomial values divided by their L2 norm. If the polynomial is
        identically zero on ``x`` (zero norm), the unscaled zeros are returned
        instead of NaN.
    """
    total = np.sum([coeff[i]*(x**i) for i in range(len(coeff))], axis=0)
    norm = np.linalg.norm([total])
    if norm == 0:
        # Avoid 0/0 -> NaN leaking into downstream data.
        return total
    return total/norm