# In [None]:
# %% imports
from pathlib import Path
import pandas as pd

from local_llm.training.text_finetune import (
    FineTuneConfig,
    set_seed,
    prepare_label_mapping,
    stratified_split_indices,
    encode_splits,
    build_dataloaders,
    build_bert_text_classifier_from_assets,
    train_text_classifier,
    evaluate_on_split,
    save_finetuned_classifier,
    export_predictions_csv,
)

# %% 1–3. Load data, specify columns
DATA_PATH = Path("./data/wbs_data.csv")
df = pd.read_csv(DATA_PATH)

cfg = FineTuneConfig(
    text_cols=("wbs_name", "wbs_title_hierarchy", "keyword"),
    label_col="level_1",
    assets_dir=Path("./assets/bert-base-local"),
    output_dir=Path("./artifacts/finetune_bert"),
    train_frac=0.7,
    val_frac=0.15,
    epochs=5,
    base_lr=2e-5,
    finetune_policy="last_n",
    finetune_last_n=2,
)

# Fail fast with a readable message if the CSV lacks any column the pipeline
# reads (text inputs + label). Otherwise the problem only surfaces later as a
# bare KeyError deep inside label mapping or encoding.
_required_cols = [*cfg.text_cols, cfg.label_col]
_missing_cols = [col for col in _required_cols if col not in df.columns]
if _missing_cols:
    raise ValueError(
        f"{DATA_PATH} is missing required column(s) {_missing_cols}; "
        f"available columns: {list(df.columns)}"
    )
if df.empty:
    raise ValueError(f"{DATA_PATH} contains no rows to train on")

set_seed(cfg.seed)

# %% label mapping
# Returns the frame (which must carry a "label_id" column, read by the split
# step below) plus both directions of the label <-> integer-id mapping.
df, label_to_id, id_to_label = prepare_label_mapping(df, cfg.label_col)

# %% 4. split
# Split on the integer label ids. The indices are positional into `df`, which
# is why later steps select rows with `df.iloc`.
# NOTE(review): test share is presumably 1 - train_frac - val_frac — confirm
# in stratified_split_indices.
train_idx, val_idx, test_idx = stratified_split_indices(
    df["label_id"].to_numpy(),
    train_frac=cfg.train_frac,
    val_frac=cfg.val_frac,
    seed=cfg.seed,
)

# An empty split (tiny dataset or extreme fractions) would otherwise break
# training/evaluation much later with a confusing error.
for _split_name, _split_idx in (
    ("train", train_idx),
    ("val", val_idx),
    ("test", test_idx),
):
    if len(_split_idx) == 0:
        raise ValueError(
            f"{_split_name} split is empty; check dataset size and "
            f"train_frac={cfg.train_frac}, val_frac={cfg.val_frac}"
        )
print(f"Split sizes: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")

# %% 5. encode
splits = encode_splits(df, train_idx, val_idx, test_idx, cfg)

# %% 6. build datasets/loaders
loaders = build_dataloaders(splits, cfg)

# %% 7. build model
# Size the classifier output by the number of distinct labels in the mapping.
num_labels = len(label_to_id)
model = build_bert_text_classifier_from_assets(
    cfg,
    num_labels=num_labels,
)

# %% 8. train
# The trainer hands back two things: a training history and a best-weights
# snapshot (kept separately from `model`).
training_outcome = train_text_classifier(model, loaders, cfg)
history, best_state = training_outcome

# %% 9. save model
# Give the saver the model, the best-weights snapshot, and both label mappings.
save_finetuned_classifier(
    model,
    best_state,
    cfg,
    label_to_id,
    id_to_label,
)

# %% 10–11. inference on test + export CSV
# train_text_classifier returns the best weights separately from the model
# object, so the in-memory model may still hold the final-epoch weights.
# Restore the best weights explicitly so the reported test metrics describe
# the same weights that were handed to save_finetuned_classifier. Reloading
# is harmless if they were already restored.
# NOTE(review): assumes best_state is a state_dict accepted by
# model.load_state_dict (None if no checkpoint was kept) — confirm.
if best_state is not None:
    model.load_state_dict(best_state)

test_metrics, test_preds = evaluate_on_split(model, splits["test"], cfg)
print("Test metrics:", test_metrics)

# Split indices are positions into `df` (computed from a positional array),
# so select the raw test rows with iloc rather than loc.
test_raw_df = df.iloc[test_idx]
csv_path = export_predictions_csv(test_preds, test_raw_df, id_to_label, cfg, split_name="test")
print("Test predictions saved to:", csv_path)
