## Load and split data

In [None]:
from omegaconf import OmegaConf
from pathlib import Path
import sys
from hydra import initialize, compose
import os

# Make modules in the parent directory importable (needed for `data_handler`).
# Path() is the current working directory, so this resolves to the parent of
# wherever the kernel was started — NOTE(review): assumes the notebook is run
# from its own directory; confirm if launched from elsewhere.
sys.path.append(str(Path().resolve().parent))
from data_handler import DataHandler

In [21]:
def load_hydra_config_with_params(
    model,
    datapack,
    probe,
    config_name,
    *,
    config_path="../configs",
    extra_overrides=None,
):
    """Compose a Hydra config for one model/datapack/probe combination.

    The primary config ``config_name`` is composed from ``config_path`` with
    the ``model``, ``datapack`` and ``probe`` config groups overridden. A
    run-specific ``trial_name`` and ``output_dir`` are then derived from it.

    Args:
        model: Option to select for the ``model`` config group.
        datapack: Option to select for the ``datapack`` config group.
        probe: Option to select for the ``probe`` config group.
        config_name: Name of the primary config to compose.
        config_path: Config directory passed unchanged to
            ``hydra.initialize`` (which resolves relative paths relative to
            the caller). Defaults to ``"../configs"``.
        extra_overrides: Optional iterable of additional Hydra override
            strings (e.g. ``["task=2"]``). They are appended after the
            model/datapack/probe overrides. Defaults to no extra overrides.

    Returns:
        A plain ``dict`` from ``OmegaConf.to_container`` with interpolations
        resolved, in which:
          * ``search`` is forced to ``False`` when the probe name is
            ``"mean_diff"``;
          * ``trial_name`` gains a ``_search`` suffix when ``search`` is
            truthy, followed by a ``_task-<task>`` suffix;
          * ``output_dir`` becomes ``os.path.join(<output_dir>, trial_name)``.
    """
    overrides = [f"model={model}", f"datapack={datapack}", f"probe={probe}"]
    if extra_overrides:
        overrides.extend(extra_overrides)

    with initialize(version_base="1.1", config_path=config_path):
        cfg = compose(config_name=config_name, overrides=overrides)

    # Lift struct mode so derived keys can be assigned below.
    OmegaConf.set_struct(cfg, False)
    trial_name = cfg.trial_name
    # `search` is always disabled for the mean_diff probe.
    if cfg.probe['name'] == 'mean_diff':
        cfg.search = False
    if cfg.search:
        trial_name += "_search"
    trial_name += f'_task-{cfg.task}'
    cfg["trial_name"] = trial_name
    # NOTE(review): an earlier, now-disabled override set
    # probe.assume_known_positives = False when task == 2; pass it through
    # `extra_overrides` if that behaviour is needed again.
    cfg["output_dir"] = os.path.join(cfg.output_dir, trial_name)
    OmegaConf.set_struct(cfg, True)
    return OmegaConf.to_container(cfg, resolve=True)

In [None]:
# Experiment selection: which model / datapack / probe config groups to compose.
model_name = 'llama-3-8b'
datapack_name = 'city_locations'
probe = 'sawmil'

cfg = load_hydra_config_with_params(
    model=model_name,
    datapack=datapack_name,
    probe=probe,
    config_name="probe_mil",
)
datapack_params = cfg["datapack"]

# Data handler over the datapack's datasets, with a calibration split enabled.
dh = DataHandler(
    model=model_name,
    datasets=datapack_params['datasets'],
    dataset_path='../datasets/',
    activation_type=datapack_params["agg"],
    with_calibration=True,
    load_scores='default',
    verbose=True,
)

# Assemble the splits using the datapack's split settings (sizes and seed).
dh.assemble(
    exclusive_split=datapack_params["exclusive_split"],
    test_size=datapack_params["test_size"],
    calibration_size=datapack_params["cal_size"],
    seed=datapack_params["random_seed"],
)

# Cell output: shapes of the train, test and calibration frames and the full frame.
(
    dh.get_train_df().shape,
    dh.get_test_df().shape,
    dh.get_cal_df().shape,
    dh.get_dataframe().shape,
)