In [None]:
from torch.nn import MSELoss

from helpers.datasets.make_and_save.aggregated_signal import Aggregated_Signal_Dataframe_Handler
from helpers.datasets.constants import Names_of_Levels, Names_of_q_Squared_Vetos, Raw_Signal_Trial_Ranges
from helpers.experiment.experiment import CNN_Group, Deep_Sets_Group, Event_by_Event_Group
from helpers.experiment.results_table import Results_Table
from helpers.experiment.constants import Paths_to_Directories
from helpers.models.hardware_util import select_device

In [None]:
# Run Aggregated_Signal_Dataframe_Handler.make_and_save once for every
# (level, trial range) pair: the generator and detector levels, each combined
# with every entry of Raw_Signal_Trial_Ranges().tuple_.
level_trial_range_pairs = (
    (level, trial_range)
    for level in (Names_of_Levels().generator, Names_of_Levels().detector)
    for trial_range in Raw_Signal_Trial_Ranges().tuple_
)

for level, trial_range in level_trial_range_pairs:
    # The handler is built from the main datasets directory, and
    # make_and_save is handed the raw-signal directory.
    Aggregated_Signal_Dataframe_Handler(
        path_to_main_datasets_dir=Paths_to_Directories().path_to_main_datasets_dir,
        level=level,
        trial_range=trial_range,
    ).make_and_save(Paths_to_Directories().path_to_raw_signal_dir)

In [None]:
# Objects handed to every model group constructed below.
# NOTE(review): no random seed is set in the visible cells — confirm whether
# the helpers seed internally before relying on run-to-run reproducibility.
results_table, device = Results_Table(), select_device()

In [None]:
# Deep Sets model group: configure, then train and evaluate.
# Keyword arguments are arranged by apparent role, judged from their names
# only — see Deep_Sets_Group for the authoritative meaning of each one.
deep_sets_group = Deep_Sets_Group(
    # Dataset settings
    num_sets_per_label=50,
    num_sets_per_label_sensitivity=2_000,
    q_squared_veto=Names_of_q_Squared_Vetos().loose,
    std_scale=True,
    shuffle=True,
    bkg_fraction=0.5,
    bkg_charge_fraction=0.5,
    # Training settings
    loss_fn=MSELoss(),
    learning_rate=3e-4,
    learning_rate_scheduler_reduction_factor=0.9,
    size_of_training_batch=32,
    size_of_evaluation_batch=32,
    number_of_epochs=100,
    number_of_epochs_between_checkpoints=5,
    # Shared notebook objects
    results_table=results_table,
    device=device,
)

# Both stages are called with remake_datasets=True.
deep_sets_group.train_all(remake_datasets=True)
deep_sets_group.evaluate_all(remake_datasets=True)


In [None]:
# CNN model group: configure, then train, evaluate, and plot image examples.
# Keyword arguments are arranged by apparent role, judged from their names
# only — see CNN_Group for the authoritative meaning of each one.
cnn_group = CNN_Group(
    # Dataset settings
    num_sets_per_label=50,
    num_sets_per_label_sensitivity=2_000,
    num_bins_per_dimension=10,
    q_squared_veto=Names_of_q_Squared_Vetos().loose,
    std_scale=True,
    shuffle=True,
    bkg_fraction=0.5,
    bkg_charge_fraction=0.5,
    # Training settings
    loss_fn=MSELoss(),
    learning_rate=3e-4,
    learning_rate_scheduler_reduction_factor=0.9,
    size_of_training_batch=32,
    size_of_evaluation_batch=32,
    number_of_epochs=100,
    number_of_epochs_between_checkpoints=5,
    # Shared notebook objects
    results_table=results_table,
    device=device,
)

# Training and evaluation are called with remake_datasets=True; the plotting
# step is called with remake_datasets=False.
cnn_group.train_all(remake_datasets=True)
cnn_group.evaluate_all(remake_datasets=True)
cnn_group.plot_image_examples_all(remake_datasets=False)

In [None]:
# Event-by-event model group: configure, then train and evaluate.
# Keyword arguments are arranged by apparent role, judged from their names
# only — see Event_by_Event_Group for the authoritative meaning of each one.
event_by_event_group = Event_by_Event_Group(
    # Dataset settings (this constructor takes num_evaluation_sets_* names
    # and is given no bkg_* arguments here)
    num_evaluation_sets_per_label=50,
    num_evaluation_sets_per_label_sensitivity=2_000,
    q_squared_veto=Names_of_q_Squared_Vetos().loose,
    std_scale=True,
    shuffle=True,
    # Training settings
    loss_fn=MSELoss(),
    learning_rate=3e-4,
    learning_rate_scheduler_reduction_factor=0.95,
    size_of_training_batch=10_000,
    size_of_evaluation_batch=10_000,
    number_of_epochs=300,
    number_of_epochs_between_checkpoints=5,
    # Shared notebook objects
    results_table=results_table,
    device=device,
)

# Both stages are called with remake_datasets=True.
event_by_event_group.train_all(remake_datasets=True)
event_by_event_group.evaluate_all(remake_datasets=True)