In [None]:
#@title **1. Setup**

#@markdown ### Identification
huggingface_username = "Synthyra"  #@param {type:"string"}
huggingface_token = ""            #@param {type:"string"}
wandb_api_key = ""                #@param {type:"string"}
synthyra_api_key = ""             #@param {type:"string"}
github_token = ""                 #@param {type:"string"}


# Only embed the token in the clone URL when one was supplied; an empty token
# would otherwise produce the malformed URL "https://@github.com/...".
github_repo = "github.com/Synthyra/ProbePackageHolder.git"
github_clone_path = f"https://{github_token}@{github_repo}" if github_token else f"https://{github_repo}"
# !git clone {github_clone_path}
# `!cd` runs in a throwaway subshell, so the directory change would not persist
# for the install below; the `%cd` magic changes the kernel's working directory,
# and `%pip` installs into the kernel's own environment.
# %cd ProbePackageHolder
# %pip install -r requirements.txt


In [None]:
import torch
import argparse
from types import SimpleNamespace
from base_models.get_base_models import BaseModelArguments, standard_benchmark
from data.hf_data import HFDataArguments
from data.supported_datasets import supported_datasets, standard_data_benchmark
from embedder import EmbeddingArguments
from probes.get_probe import ProbeArguments
from probes.trainers import TrainerArguments
from main import MainProcess


# Create the notebook-wide session; later cells configure and drive this object.
main = MainProcess(argparse.Namespace(), GUI=True)

#@markdown **Paths**
log_dir = "logs"                            #@param {type:"string"}
results_dir = "results"                    #@param {type:"string"}
model_save_dir = "weights"                 #@param {type:"string"}
embedding_save_dir = "embeddings"          #@param {type:"string"}
download_dir = "Synthyra/mean_pooled_embeddings"  #@param {type:"string"}


# Credentials and paths come from the form fields of the Setup cell and this cell.
main.full_args.hf_token = huggingface_token
main.full_args.wandb_api_key = wandb_api_key
main.full_args.synthyra_api_key = synthyra_api_key
main.full_args.log_dir = log_dir
main.full_args.results_dir = results_dir
main.full_args.model_save_dir = model_save_dir
main.full_args.embedding_save_dir = embedding_save_dir
main.full_args.download_dir = download_dir
main.full_args.replay_path = None
# Strip credentials (any key containing "token" or "api") before handing the
# arguments to the logger, using the same filter as the later cells, so that
# secrets such as hf_token / wandb_api_key are not written to the logs.
main.logger_args = SimpleNamespace(**{
    k: v for k, v in main.full_args.__dict__.items()
    if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()
})
main.start_log_gui()

#@markdown ---
#@markdown Press play to setup the session:

In [None]:
#@title **2. Data Settings**

max_length = 1024          #@param {type:"integer"}
trim = False               #@param {type:"boolean"}
#@markdown Enter comma-separated dataset names from `supported_datasets`.
#@markdown If left empty, the code uses `standard_data_benchmark`.
dataset_names = ""  #@param {type:"string"}

selected_datasets = [name.strip() for name in dataset_names.split(",") if name.strip()]
if not selected_datasets:
    # Implement the documented fallback. Assumes standard_data_benchmark holds
    # keys of supported_datasets (the same convention as standard_benchmark for
    # models) — TODO confirm.
    selected_datasets = list(standard_data_benchmark)
unknown_datasets = [name for name in selected_datasets if name not in supported_datasets]
if unknown_datasets:
    raise ValueError(
        f"Unknown dataset name(s): {unknown_datasets}. Choose from: {sorted(supported_datasets)}"
    )
data_paths = [supported_datasets[name] for name in selected_datasets]

# HFDataArguments is built from main.full_args, so the settings must be stored
# there (assigning them only to `main` left full_args unchanged). They are also
# mirrored onto `main` for any code that reads them as attributes.
for key, value in {"data_paths": data_paths, "max_length": max_length, "trim": trim}.items():
    setattr(main.full_args, key, value)
    setattr(main, key, value)
main.data_args = HFDataArguments(**main.full_args.__dict__)
# Refresh the logger snapshot without sequences or credentials.
args_dict = {k: v for k, v in main.full_args.__dict__.items() if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()}
main.logger_args = SimpleNamespace(**args_dict)
main.get_datasets()

#@markdown ---
#@markdown Press play to load datasets:

In [None]:
#@title **3. Embedding Settings**

batch_size = 4                #@param {type:"integer"}
num_workers = 0               #@param {type:"integer"}
download_embeddings = False   #@param {type:"boolean"}
matrix_embed = False          #@param {type:"boolean"}
#@markdown Comma-separated pooling types: e.g. `mean,cls`
embedding_pooling_types = "mean"  #@param {type:"string"}
embed_dtype = "float32"       #@param ["float32","float16","bfloat16","float8_e4m3fn","float8_e5m2"]
sql = False                   #@param {type:"boolean"}

# Resolve the dtype name to a torch dtype. The lookup is whitelisted and lazy
# (getattr), so a torch build lacking the float8 types falls back to float32
# instead of raising AttributeError.
SUPPORTED_EMBED_DTYPES = ("float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2")
if embed_dtype in SUPPORTED_EMBED_DTYPES and hasattr(torch, embed_dtype):
    resolved_embed_dtype = getattr(torch, embed_dtype)
else:
    print(f"Invalid embedding dtype: {embed_dtype}. Using float32.")
    resolved_embed_dtype = torch.float32

main.full_args.all_seqs = main.all_seqs
embedding_settings = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "download_embeddings": download_embeddings,
    "matrix_embed": matrix_embed,
    "embedding_pooling_types": [p.strip() for p in embedding_pooling_types.split(",") if p.strip()],
    "embed_dtype": resolved_embed_dtype,
    "sql": sql,
}
# EmbeddingArguments is built from main.full_args, so the settings must be stored
# there (assigning them only to `main` left full_args unchanged). They are also
# mirrored onto `main` for any code that reads them as attributes.
for key, value in embedding_settings.items():
    setattr(main.full_args, key, value)
    setattr(main, key, value)

main.embedding_args = EmbeddingArguments(**main.full_args.__dict__)
# Refresh the logger snapshot without sequences or credentials.
args_dict = {k: v for k, v in main.full_args.__dict__.items() if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()}
main.logger_args = SimpleNamespace(**args_dict)
main.save_embeddings_to_disk()

#@markdown ---
#@markdown Press play to embed sequences:


In [None]:
#@title **4. Model Selection**

#@markdown Comma-separated model names.
#@markdown If empty, defaults to `standard_benchmark`.
model_names = ""  #@param {type:"string"}

def select_models(session=None):
    """Build BaseModelArguments from the `model_names` form field and attach them to a session.

    Args:
        session: the MainProcess to configure; defaults to the notebook-wide `main`.

    Returns:
        The constructed BaseModelArguments (previously they were built and discarded).
    """
    session = main if session is None else session
    selected = [m.strip() for m in model_names.split(",") if m.strip()]
    if not selected:
        # Copy so the selection never aliases (and could mutate) the shared default list.
        selected = list(standard_benchmark)

    full_args = argparse.Namespace(
        model_names=selected,
    )
    model_args = BaseModelArguments(**vars(full_args))
    # Record the selection on the session so later cells see it.
    # NOTE(review): the attribute name mirrors data_args / embedding_args used in
    # earlier cells — confirm against MainProcess.
    session.full_args.model_names = selected
    session.model_args = model_args
    print("Selected model(s):", selected)
    print("Model Args:", model_args)
    return model_args

#@markdown ---
#@markdown Press play to choose models:
select_models()


In [None]:
#@title **5. Probe Settings**

probe_type = "linear"     #@param ["linear", "transformer", "crossconv"]
tokenwise = False         #@param {type:"boolean"}
pre_ln = True             #@param {type:"boolean"}
n_layers = 1              #@param {type:"integer"}
hidden_dim = 8192         #@param {type:"integer"}
dropout = 0.2             #@param {type:"number"}

classifier_dim = 4096     #@param {type:"integer"}
classifier_dropout = 0.2  #@param {type:"number"}
n_heads = 4               #@param {type:"integer"}
rotary = True             #@param {type:"boolean"}
probe_pooling_types_str = "mean, cls"  #@param {type:"string"}

def create_probe_args(session=None):
    """Build ProbeArguments from the probe form fields and attach them to a session.

    Args:
        session: the MainProcess to configure; defaults to the notebook-wide `main`.

    Returns:
        The constructed ProbeArguments (previously they were built and discarded).
    """
    session = main if session is None else session
    probe_pooling_types = [p.strip() for p in probe_pooling_types_str.split(",") if p.strip()]

    full_args = argparse.Namespace(
        probe_type=probe_type,
        tokenwise=tokenwise,
        pre_ln=pre_ln,
        n_layers=n_layers,
        hidden_dim=hidden_dim,
        dropout=dropout,
        classifier_dim=classifier_dim,
        classifier_dropout=classifier_dropout,
        n_heads=n_heads,
        rotary=rotary,
        probe_pooling_types=probe_pooling_types,
    )
    probe_args = ProbeArguments(**vars(full_args))
    # Record the configuration on the session so training uses it.
    # NOTE(review): the attribute name mirrors data_args / embedding_args used in
    # earlier cells — confirm against MainProcess.
    for key, value in vars(full_args).items():
        setattr(session.full_args, key, value)
    session.probe_args = probe_args
    print("Probe Arguments:", probe_args)
    return probe_args

#@markdown ---
#@markdown Press play to configure the probe:
create_probe_args()


In [None]:
#@title **6. Training Settings**

use_lora = False               #@param {type:"boolean"}
hybrid_probe = False           #@param {type:"boolean"}
full_finetuning = False        #@param {type:"boolean"}

num_epochs = 200               #@param {type:"integer"}
trainer_batch_size = 64        #@param {type:"integer"}
gradient_accumulation_steps = 1  #@param {type:"integer"}
lr = 0.0001                    #@param {type:"number"}
weight_decay = 0.0             #@param {type:"number"}
patience = 3                   #@param {type:"integer"}

def run_trainer(session=None):
    """Build TrainerArguments from the form fields and train probes in a session.

    Only default probe training is wired up. The LoRA, full-finetuning and
    hybrid-probe modes are placeholders. When several mode flags are set, the
    first of use_lora, full_finetuning, hybrid_probe takes precedence.

    Args:
        session: the MainProcess to train in; defaults to the notebook-wide `main`,
            which already holds the configuration from the earlier cells.

    Returns:
        The constructed TrainerArguments.
    """
    session = main if session is None else session
    full_args = argparse.Namespace(
        use_lora=use_lora,
        hybrid_probe=hybrid_probe,
        full_finetuning=full_finetuning,
        num_epochs=num_epochs,
        trainer_batch_size=trainer_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        lr=lr,
        weight_decay=weight_decay,
        patience=patience,
    )
    trainer_args = TrainerArguments(**vars(full_args))
    print("Trainer Args:", trainer_args)

    # Train inside the configured session. A fresh MainProcess(full_args) would
    # know only these trainer fields and none of the data/embedding/model/probe
    # setup done by the earlier cells.
    # NOTE(review): the trainer_args attribute name is assumed — confirm against MainProcess.
    for key, value in vars(full_args).items():
        setattr(session.full_args, key, value)
    session.trainer_args = trainer_args

    if use_lora:
        print("Running LoRA training... (not implemented in example)")
    elif full_finetuning:
        print("Running full finetuning... (not implemented in example)")
    elif hybrid_probe:
        print("Running hybrid probe training... (not implemented in example)")
    else:
        print("Running default probe training...")
        session.run_probes()
        # Only report completion when training actually ran.
        print("Training complete.")
    return trainer_args

#@markdown ---
#@markdown Press play to run the trainer:
run_trainer()


In [None]:
#@title **7. Log Replay**

replay_path = ""  #@param {type:"string"}

def start_replay(target=None):
    """Replay the calls recorded in the log file at `replay_path`.

    Args:
        target: object the replayer drives. When None, a new MainProcess is
            built from the arguments parsed out of the log (previously None was
            passed through, leaving the replayer nothing to act on).

    Returns:
        The object the replay ran against, or None when there was nothing to replay.
    """
    import os

    path = replay_path.strip()
    if not path:
        print("No replay path provided.")
        return None
    if not os.path.exists(path):
        print(f"Replay path does not exist: {path}")
        return None

    from logger import LogReplayer  # imported lazily: only needed when replaying
    replayer = LogReplayer(path)
    replay_args = replayer.parse_log()
    print("Replaying from:", path)
    if target is None:
        # NOTE(review): assumes parse_log() returns a namespace MainProcess accepts,
        # as cell 2 constructs it from an argparse.Namespace — confirm.
        target = MainProcess(replay_args, GUI=True)
    replayer.run_replay(target)
    return target

#@markdown ---
#@markdown Press play to replay logs:
start_replay()
