# In [24]:
from pathlib import Path

import pandas

from btokstmumu_ml_helpers.datasets.constants import Names_of_Levels


def shuffle_dataframe(dataframe, random_state=None):
    """Return a copy of ``dataframe`` with its rows in random order.

    Every row is kept exactly once (``frac=1`` sampled without
    replacement); index labels travel with their rows.

    Args:
        dataframe: The pandas DataFrame to shuffle.
        random_state: Optional seed (or random generator) forwarded to
            ``DataFrame.sample``. The default ``None`` keeps the previous
            non-reproducible behavior; pass an int for a repeatable shuffle.

    Returns:
        A new DataFrame holding the same rows in shuffled order.
    """
    return dataframe.sample(frac=1, replace=False, random_state=random_state)


def load_generic_charge_or_mix_detector_level_dataframe(charge_or_mix:str):
    """Load the detector-level part of the aggregated generic sample.

    Args:
        charge_or_mix: Which sample to load; must be "charge" or "mix".

    Returns:
        The result of selecting ``Names_of_Levels().detector`` with ``.loc``
        from the unpickled object.
        NOTE(review): presumably the pickled dataframe is indexed by
        simulation level -- confirm against the code that writes it.

    Raises:
        AssertionError: If ``charge_or_mix`` is neither "charge" nor "mix".
    """
    assert charge_or_mix in ("charge", "mix")

    # NOTE: unpickling executes arbitrary code; only load trusted files.
    path_to_pickle = f"../../common/state/data/processed/aggregated_generic/{charge_or_mix}.pkl"
    all_levels_dataframe = pandas.read_pickle(path_to_pickle)

    detector_level_name = Names_of_Levels().detector
    return all_levels_dataframe.loc[detector_level_name]


def normalize_number_of_charge_events(generic_charge_detector_level_dataframe):
    """Randomly down-sample the charge dataframe by the mix-to-charge ratio.

    The rows are shuffled and only the first
    ``int(len(dataframe) * ratio)`` of them are kept, where ``ratio`` is
    the hard-coded mix event count divided by the hard-coded charge event
    count (roughly 0.748).

    Args:
        generic_charge_detector_level_dataframe: Charge dataframe to shrink.

    Returns:
        A shuffled subset of the input dataframe.
    """
    # Hard-coded event counts of the original generic samples.
    # NOTE(review): the provenance of these numbers is not visible here.
    num_generic_charge_events = 2424628566
    num_generic_mix_events = 1813405232
    mix_to_charge_ratio = num_generic_mix_events / num_generic_charge_events

    # Shuffle first so that taking the leading rows is a random selection.
    shuffled_dataframe = shuffle_dataframe(generic_charge_detector_level_dataframe)
    num_events_to_keep = int(len(shuffled_dataframe) * mix_to_charge_ratio)
    return shuffled_dataframe.iloc[:num_events_to_keep]


def load_normalized_shuffled_generic_charge_and_mix_detector_level_dataframes():
    """Load the charge and mix detector-level dataframes, ready for splitting.

    The charge dataframe is down-sampled with
    ``normalize_number_of_charge_events``; both dataframes are then shuffled.

    Returns:
        A ``(charge_dataframe, mix_dataframe)`` tuple.
    """
    # Mix is loaded before charge, matching the order of the file reads.
    mix_dataframe = load_generic_charge_or_mix_detector_level_dataframe("mix")
    charge_dataframe = normalize_number_of_charge_events(
        load_generic_charge_or_mix_detector_level_dataframe("charge")
    )

    # Tuple elements are evaluated left to right: charge is shuffled first.
    return shuffle_dataframe(charge_dataframe), shuffle_dataframe(mix_dataframe)


def cut_to_signal_region(dataframe):
    """Keep only the rows that fall inside the signal region.

    A row passes when ``-0.05 < deltaE < 0.05`` and ``Mbc > 5.27``
    (all bounds exclusive).

    Args:
        dataframe: DataFrame with "deltaE" and "Mbc" columns.

    Returns:
        The filtered DataFrame.
    """
    # Apply the deltaE window first, then the Mbc threshold on the survivors.
    inside_delta_e_window = (dataframe["deltaE"] < 0.05) & (dataframe["deltaE"] > -0.05)
    delta_e_selected = dataframe[inside_delta_e_window]

    above_mbc_threshold = delta_e_selected["Mbc"] > 5.27
    return delta_e_selected[above_mbc_threshold]


def cut_to_sideband(dataframe):
    """Keep only the rows that fall inside the sideband.

    A row passes when ``-0.05 < deltaE < 0.05`` and ``5.0 < Mbc < 5.26``
    (all bounds exclusive).

    Args:
        dataframe: DataFrame with "deltaE" and "Mbc" columns.

    Returns:
        The filtered DataFrame.
    """
    # deltaE window: same as signal region for now
    inside_delta_e_window = (dataframe["deltaE"] < 0.05) & (dataframe["deltaE"] > -0.05)
    delta_e_selected = dataframe[inside_delta_e_window]

    inside_mbc_band = (delta_e_selected["Mbc"] > 5.0) & (delta_e_selected["Mbc"] < 5.26)
    return delta_e_selected[inside_mbc_band]


def cut_to_signal_events(dataframe):
    """Return the rows of ``dataframe`` whose "isSignal" value equals 1."""
    is_signal = dataframe["isSignal"] == 1
    return dataframe[is_signal]


def cut_to_bkg_events(dataframe):
    """Return the rows of ``dataframe`` whose "isSignal" value is not 1.

    Rows with a missing (NaN) "isSignal" value are kept, because
    ``NaN != 1`` evaluates to True.
    """
    is_not_signal = dataframe["isSignal"] != 1
    return dataframe[is_not_signal]


def split_dataset_into_splits(dataframe, train_fraction, validation_fraction):
    """Split ``dataframe`` row-wise into train, validation and test parts.

    The rows are taken in their existing order: the first
    ``int(train_fraction * len(dataframe))`` rows form the train split, the
    next ``int(validation_fraction * len(dataframe))`` rows form the
    validation split, and everything left over is the test split. No
    shuffling happens here.

    Args:
        dataframe: The DataFrame to split.
        train_fraction: Fraction of rows for the train split, in [0, 1].
        validation_fraction: Fraction of rows for the validation split,
            in [0, 1].

    Returns:
        A ``(train_dataframe, validation_dataframe, test_dataframe)`` tuple.

    Raises:
        ValueError: If a fraction lies outside [0, 1] or the two fractions
            sum to more than 1.
    """
    # A negative fraction would become a negative iloc bound and silently
    # slice from the end of the dataframe, so reject it up front.
    for name, fraction in (
        ("train_fraction", train_fraction),
        ("validation_fraction", validation_fraction),
    ):
        if not 0 <= fraction <= 1:
            raise ValueError(f"{name} must be between 0 and 1, got {fraction!r}")

    # Small tolerance so that float sums such as 0.7 + 0.3 are not rejected.
    if train_fraction + validation_fraction > 1 + 1e-9:
        raise ValueError(
            "train_fraction + validation_fraction must not exceed 1, got "
            f"{train_fraction!r} + {validation_fraction!r}"
        )

    num_examples = len(dataframe)
    num_train_examples = int(train_fraction * num_examples)
    num_validation_examples = int(validation_fraction * num_examples)
    validation_end = num_train_examples + num_validation_examples

    train_dataframe = dataframe.iloc[:num_train_examples]
    validation_dataframe = dataframe.iloc[num_train_examples:validation_end]
    test_dataframe = dataframe.iloc[validation_end:]

    return train_dataframe, validation_dataframe, test_dataframe


def save_dataframe_to_parquet(dataframe, columns, path_to_save_dir, filename):
    """Write the selected columns of ``dataframe`` to a parquet file.

    Args:
        dataframe: Source DataFrame.
        columns: Column labels to keep in the written file.
        path_to_save_dir: Directory (str or Path) that receives the file.
        filename: Name of the parquet file inside ``path_to_save_dir``.
    """
    save_dir = Path(path_to_save_dir)
    selected_columns = dataframe[columns]
    selected_columns.to_parquet(save_dir / filename)


def split_dataset_into_splits_and_save(dataframe, train_fraction, validation_fraction, columns, base_filename, path_to_save_dir):
    """Split ``dataframe`` and save each non-empty split as a parquet file.

    The splits are written as ``{base_filename}_train.parquet``,
    ``{base_filename}_val.parquet`` and ``{base_filename}_test.parquet``.
    A split with zero rows is skipped, so no file is written for it.

    Args:
        dataframe: The DataFrame to split and save.
        train_fraction: Fraction of rows for the train split.
        validation_fraction: Fraction of rows for the validation split.
        columns: Column labels to keep in the written files.
        base_filename: Filename prefix shared by the three files.
        path_to_save_dir: Directory that receives the files.
    """
    split_dataframes = split_dataset_into_splits(
        dataframe=dataframe,
        train_fraction=train_fraction,
        validation_fraction=validation_fraction,
    )

    # split_dataset_into_splits returns (train, validation, test) in this order.
    for split_name, split_dataframe in zip(("train", "val", "test"), split_dataframes):

        # len() rather than truthiness: a DataFrame has no unambiguous bool value.
        if len(split_dataframe) == 0:
            continue

        save_dataframe_to_parquet(
            dataframe=split_dataframe,
            columns=columns,
            path_to_save_dir=path_to_save_dir,
            filename=f"{base_filename}_{split_name}.parquet",
        )

def save_signal_region_datasets(columns, validation_fraction, path_to_save_dir):
    """Build the signal-region datasets and save them as parquet files.

    Three datasets are produced from the signal-region cut: charge
    background ("charge_sr_bkg"), mix background ("mix_sr_bkg") and mix
    signal ("mix_sr_signal"). Each is split with ``train_fraction=0``, so
    only validation and test splits can contain rows.

    Args:
        columns: Column labels to keep in the written files.
        validation_fraction: Fraction of rows for the validation split;
            must lie strictly between 0 and 1.
        path_to_save_dir: Existing directory that receives the files.

    Raises:
        AssertionError: If ``validation_fraction`` is not strictly between
            0 and 1, or ``path_to_save_dir`` is not an existing directory.
    """
    assert 0 < validation_fraction < 1
    assert Path(path_to_save_dir).is_dir()

    charge_df, mix_df = load_normalized_shuffled_generic_charge_and_mix_detector_level_dataframes()

    charge_signal_region_df = cut_to_signal_region(charge_df)
    # The mix signal-region selection feeds both the background and signal sets.
    mix_signal_region_df = cut_to_signal_region(mix_df)

    datasets_by_base_filename = {
        "charge_sr_bkg": cut_to_bkg_events(charge_signal_region_df),
        "mix_sr_bkg": cut_to_bkg_events(mix_signal_region_df),
        "mix_sr_signal": cut_to_signal_events(mix_signal_region_df),
    }

    for base_filename, dataset in datasets_by_base_filename.items():
        split_dataset_into_splits_and_save(
            dataframe=dataset,
            train_fraction=0,
            validation_fraction=validation_fraction,
            columns=columns,
            base_filename=base_filename,
            path_to_save_dir=path_to_save_dir,
        )


def save_sideband_datasets(columns, train_fraction, validation_fraction, path_to_save_dir):
    """Build the sideband background datasets and save them as parquet files.

    Two datasets are produced from the sideband cut: charge background
    ("charge_sb_bkg") and mix background ("mix_sb_bkg"). Each is split
    into train, validation and test parts before saving.

    Args:
        columns: Column labels to keep in the written files.
        train_fraction: Fraction of rows for the train split; must lie
            strictly between 0 and 1.
        validation_fraction: Fraction of rows for the validation split;
            must lie strictly between 0 and 1.
        path_to_save_dir: Existing directory that receives the files.

    Raises:
        AssertionError: If either fraction is not strictly between 0 and
            1, or ``path_to_save_dir`` is not an existing directory.
    """
    assert 0 < validation_fraction < 1
    assert 0 < train_fraction < 1
    assert Path(path_to_save_dir).is_dir()

    charge_df, mix_df = load_normalized_shuffled_generic_charge_and_mix_detector_level_dataframes()

    datasets_by_base_filename = {
        "charge_sb_bkg": cut_to_bkg_events(cut_to_sideband(charge_df)),
        "mix_sb_bkg": cut_to_bkg_events(cut_to_sideband(mix_df)),
    }

    for base_filename, dataset in datasets_by_base_filename.items():
        split_dataset_into_splits_and_save(
            dataframe=dataset,
            train_fraction=train_fraction,
            validation_fraction=validation_fraction,
            columns=columns,
            base_filename=base_filename,
            path_to_save_dir=path_to_save_dir,
        )


# Script section: runs when this file (or notebook cell) is executed.
#
# Columns written to every parquet file. The commented-out entries are
# currently excluded from the output; uncomment one to include it again.
columns = [
    # "isSignal",
    # "tfRedChiSqB0",
    # "deltaE",
    # "invM_K_pi_shifted",
    # "K_p_kaonID",
    # "K_p_dr",
    # "K_p_dz",
    # "pi_m_dr",
    # "pi_m_dz",
    # "mu_p_dr",
    # "mu_p_dz",
    # "mu_p_muonID",
    # "mu_m_dr",
    # "mu_m_dz",
    # "mu_m_muonID",
    "q_squared",
    "costheta_mu",
    "costheta_K",
    "chi"
]

# Output directory, relative to the current working directory. It must
# already exist: both save functions assert that it is a directory.
path_to_save_dir = "../../common/state/data/processed/aggregated_generic"

# Signal region: no train split (train_fraction is fixed to 0 inside the
# function); 50% of the rows go to validation and the rest to test.
save_signal_region_datasets(columns=columns, validation_fraction=0.5, path_to_save_dir=path_to_save_dir)
# Sideband: 50% train, 20% validation, remaining rows test.
save_sideband_datasets(columns=columns, train_fraction=0.5, validation_fraction=0.2, path_to_save_dir=path_to_save_dir)



