# Unique Amplitudes/Squared Amplitudes
In this notebook I want to explore how many amplitudes are unique.
I don't expect the same number of squared amplitudes to be unique, since two different amplitudes can have the same squared amplitude.

The motivation is that I don't want to create new data, but I want to have unique amplitudes so that I can be sure that my model is not overfitting.
Actually I think the way my data is now, I am basically always testing on training data, because every amplitude appears so often.

Ideally I want to train on completely unique amplitudes and then test on another set of completely unique amplitudes.
By this I mean structurally completely unique, not just related by a relabelling such as $m_e \to m_\mu$.

In [92]:
from icecream import ic
import sympy as sp
from itertools import (takewhile,repeat)
from tqdm import tqdm
import numpy as np
import sys
import os
import importlib.util
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from tensorflow.keras.layers import TextVectorization
from tensorflow.keras import layers
import tensorflow as tf
from tensorflow import keras
import pickle

In [93]:
# Ask TensorFlow which GPU devices it can see; the returned list is shown as the cell output.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [94]:
# Put the parent directory at the front of the import search path (once),
# presumably so the project packages one level up can be imported -- TODO confirm.
PARENT_DIR = ".."
if PARENT_DIR not in sys.path:
    sys.path.insert(0, PARENT_DIR)

In [95]:
import data_preprocessing.tree.sympy_to_tree as sp2tree
from data_preprocessing.sympy_prefix.source.SympyPrefix import prefix_to_sympy, sympy_to_prefix, sympy_to_hybrid_prefix, hybrid_prefix_to_sympy

In [96]:
# Folder that holds the pickled export read below.
export_folder = "../data.nosync/2022-11-09/"

# NOTE: pickle.load executes arbitrary code from the file -- only use on trusted exports.
with open(export_folder + "amplitudes.pickle", "rb") as amp_file:
    amplitudes = pickle.load(amp_file)
# Every stored amplitude is a comma-separated string; split each one into its tokens.
amplitudes = [[entry.split(",") for entry in group] for group in amplitudes]

with open(export_folder + "sqamplitudes.pickle", "rb") as sq_file:
    sqamplitudes = pickle.load(sq_file)

KeyboardInterrupt: 

In [None]:
def get_unique_indices(amps):
    """Return the indices of the first occurrence of every distinct amplitude.

    Parameters
    ----------
    amps : iterable of sequences of str
        Each entry is one amplitude, given as a sequence of tokens.

    Returns
    -------
    numpy.ndarray
        Integer indices into ``amps`` in ascending order, one per distinct
        token sequence (the first position at which that sequence appears).
    """
    first_seen = {}
    for index, tokens in enumerate(amps):
        # Key on the token tuple itself instead of a space-joined string, so a
        # token that contains a space cannot make two different sequences collide.
        first_seen.setdefault(tuple(tokens), index)
    # dicts keep insertion order, so the recorded indices are already ascending.
    return np.fromiter(first_seen.values(), dtype=np.intp, count=len(first_seen))

In [None]:
# Spot check on one group: indices of the distinct amplitudes in amplitudes[3];
# the trailing expression displays how many there are.
tmp = get_unique_indices(amplitudes[3])
len(tmp)

58332

In [None]:
# Total number of entries in the same group, for comparison with the unique count.
len(amplitudes[3])

129023

In [None]:
# One array of first-occurrence indices per group in `amplitudes`.
unique_indices = list(map(get_unique_indices, amplitudes))

In [None]:
# Keep, for every group, only the amplitudes at that group's unique indices.
amplitudes_unique = [
    [group[pos] for pos in positions]
    for group, positions in zip(amplitudes, unique_indices)
]

In [None]:
# Pick the squared amplitudes at the same positions as the unique amplitudes,
# so both collections stay aligned entry by entry.
sqamplitudes_corresponding = [
    [sqamplitudes[group_no][pos] for pos in positions]
    for group_no, positions in enumerate(unique_indices)
]

In [101]:
# Report how many unique amplitudes each group contains.
# (ic echoes the expression text, so the loop variable name shows up in the output.)
for a in amplitudes_unique:
    ic(len(a))

ic| len(a): 54
ic| len(a): 54
ic| len(a): 2988
ic| len(a): 58332
ic| len(a): 58361


In [None]:
# convert squared amplitudes to prefix
ctr = 0
def try_sympy_to_prefix(expr):
    """Convert ``expr`` with ``sympy_to_prefix``, reporting failures instead of raising.

    Every call increments the module-level counter ``ctr`` so that a failure
    message can say which call it was.

    Returns
    -------
    The result of ``sympy_to_prefix(expr)``, or ``0`` if the conversion
    raised an ``Exception``.
    """
    global ctr
    ctr = ctr + 1
    try:
        return sympy_to_prefix(expr)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt / SystemExit must still be
        # able to stop a long conversion loop.
        print("problem with:", expr, "at ctr =", ctr)
        return 0
# Convert every selected squared amplitude, one progress bar per group.
sqampl_prefix = [
    [try_sympy_to_prefix(expr) for expr in tqdm(group)]
    for group in sqamplitudes_corresponding
]
# Display the first converted expression as an array of tokens.
np.array(sqampl_prefix[0][0])

  0%|          | 0/54 [00:00<?, ?it/s]

  0%|          | 0/54 [00:00<?, ?it/s]

  0%|          | 0/2988 [00:00<?, ?it/s]

  0%|          | 0/58332 [00:00<?, ?it/s]

  0%|          | 0/58361 [00:00<?, ?it/s]

array(['mul', 's-', '4', 'mul', 'pow', 'e', '2', 'add', 'mul', 's-', '1',
       's_12', 'mul', '2', 'pow', 'm_mu', '2'], dtype='<U4')

In [116]:
# Print entries 0..24 of amplitudes_unique[0], followed by its last entry.
for idx in [*range(25), -1]:
    print(amplitudes_unique[0][idx])

['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_0', 'alpha_1', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_3)', 'Prod', 'mu', 'i_0', 'alpha_1', '(p_1)_u', 'mu^(*)', 'i_1', 'alpha_0', '(p_2)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_1', 'alpha_0', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_2)', 'Prod', 'mu^(*)', 'i_0', 'alpha_1', '(p_1)_u', 'mu', 'i_1', 'alpha_0', '(p_3)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_0', 'alpha_1', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_3)', 'Prod', 'ee', 'i_0', 'alpha_1', '(p_1)_u', 'ee^(*)', 'i_1', 'alpha_0', '(p_2)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_1', 'alpha_0', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_2)', 'Prod', 'ee^(*)', 'i_0', 'alpha_1', '(p_1)_u', 'ee', 'i_1', 'alpha_0', '(p_3)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_1', 'alpha_0', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_3)

In [117]:
# Print entries 0..24 of amplitudes_unique[1], followed by its last entry.
for idx in [*range(25), -1]:
    print(amplitudes_unique[1][idx])

['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_0', 'alpha_1', 'Prod', 'A', 'i_2', 'alpha_2', '(p_2)', 'Prod', 'ee', 'i_0', 'alpha_1', '(p_1)_u', 'ee^(*)', 'i_1', 'alpha_0', '(p_3)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_1', 'alpha_0', 'Prod', 'A', 'i_2', 'alpha_2', '(p_3)', 'Prod', 'ee^(*)', 'i_0', 'alpha_1', '(p_1)_u', 'ee', 'i_1', 'alpha_0', '(p_2)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_0', 'alpha_1', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_3)', 'Prod', 'ee', 'i_0', 'alpha_1', '(p_1)_u', 'ee^(*)', 'i_1', 'alpha_0', '(p_2)_v']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_0', 'alpha_1', 'Prod', 'A^(*)', 'i_0', 'alpha_2', '(p_1)', 'Prod', 'ee', 'i_2', 'alpha_1', '(p_2)_u', 'ee^(*)', 'i_1', 'alpha_0', '(p_3)_v']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_1', 'alpha_0', 'Prod', 'A', 'i_2', 'alpha_2', '(p_2)', 'Prod', '

In [99]:
# convert squared amplitudes to trees:
ctr = 0
def try_sympy_to_tree(expr):
    """Convert ``expr`` with ``sp2tree.sympy_to_tree``, reporting failures instead of raising.

    Every call increments the module-level counter ``ctr`` so that a failure
    message can say which call it was.

    Returns
    -------
    The result of ``sp2tree.sympy_to_tree(expr)``, or ``0`` if the conversion
    raised an ``Exception``.
    """
    global ctr
    ctr = ctr + 1
    try:
        return sp2tree.sympy_to_tree(expr)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt / SystemExit must still be
        # able to stop a long conversion loop.
        print("problem with:", expr, "at ctr =", ctr)
        return 0
# Convert every selected squared amplitude to a tree, one progress bar per group.
sqampl_tree = [
    [try_sympy_to_tree(expr) for expr in tqdm(group)]
    for group in sqamplitudes_corresponding
]
# Inspect one example: the source expression followed by its tree rendering.
display(sqamplitudes_corresponding[2][0])
example_tree = sqampl_tree[2][0]
example_tree.pretty_print(unicodelines=True)

# convert back
print("tree_to_sympy:")
display(sp2tree.tree_to_sympy(example_tree))

4*e**4*(2*m_e**4 + m_e**2*(-s_14 - s_23) + s_12*s_34 + s_13*s_24)/(2*m_e**2 + reg_prop - 2*s_23)**2

                                                              mul                                                                              
 ┌───────┬─────────────┬───────────────────────────────────────┴──────────────────────────────────┐                                             
 │       │            pow                                                                        add                                           
 │       │       ┌─────┴──────────────┐                        ┌────────────────┬─────────────┬───┴────────────────┐                            
 │       │       │                   add                       │                │             │                   mul                          
 │       │       │     ┌──────────┬───┴────────┐               │                │             │            ┌───────┴────────────┐               
 │       │       │     │          │           mul             mul               │             │            │                   add   

4*e**4*(2*m_e**4 + m_e**2*(-s_14 - s_23) + s_12*s_34 + s_13*s_24)/(2*m_e**2 + reg_prop - 2*s_23)**2

In [107]:
# Show the first two unique amplitudes of group 0, then their squared amplitudes.
for collection in (amplitudes_unique, sqamplitudes_corresponding):
    for idx in (0, 1):
        print(collection[0][idx])

['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_0', 'alpha_1', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_3)', 'Prod', 'mu', 'i_0', 'alpha_1', '(p_1)_u', 'mu^(*)', 'i_1', 'alpha_0', '(p_2)_u']
['Prod', '-1', 'Prod', 'i', 'Prod', 'e', 'Prod', 'gamma', 'alpha_2', 'alpha_1', 'alpha_0', 'Prod', 'A^(*)', 'i_2', 'alpha_2', '(p_2)', 'Prod', 'mu^(*)', 'i_0', 'alpha_1', '(p_1)_u', 'mu', 'i_1', 'alpha_0', '(p_3)_u']
-4*e**2*(2*m_mu**2 - s_12)
-4*e**2*(2*m_mu**2 - s_13)
