In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported

from trl import SFTTrainer
from transformers import TrainingArguments
from tuning.data.train_dataset import get_train_dataset
from tuning.training.config_training import ModelLoadConfig, LoraConfig, SFTRunConfig, PTRunConfig, DPOTrainingConfig, TrainingArgumentsConfig, PassAtKConfig, sft_batch_size, effective_batch_size
from tuning.training.perplexity_callback import PerplexityStoppingCallback
from tuning.training.passk_callback import PassAtKStoppingCallback
from tuning.utils.utils import apply_chat_template, chat_template_func
import json
import sys
from datasets import load_from_disk
from typing import List, Optional, Union
from pathlib import Path
from tuning.config import DATASETS_DIR, HF_MODEL_MAP
import os
from tuning.training.config_training import DatasetConfig, SFTRunConfig
from tuning.config import MODELS_DIR
from tuning.training.sft_training import train_model_sft
from tuning.training.dpo_training import train_model_dpo
import subprocess
import importlib
import torch
import json

In [None]:
# Development helper: re-import the two training modules so that edits made to
# them on disk are picked up without restarting the kernel.
import importlib
import tuning.training.passk_callback
import tuning.training.dpo_training

# Reload first, then re-bind the names, so the notebook-level names point at
# the freshly reloaded objects rather than the stale ones.
importlib.reload(tuning.training.passk_callback)
from tuning.training.passk_callback import PassAtKStoppingCallback

importlib.reload(tuning.training.dpo_training)
from tuning.training.dpo_training import train_model_dpo

In [None]:
# Key of the base model: used to index HF_MODEL_MAP and passed as `model_name`
# to the run configs built further down.
MODEL = "llama3-1B"

# Overall example budget. The SFT stage trains on this many examples and the
# DPO loop gives each checkpoint `total_train_size - data_points_seen`.
total_train_size = 4096  # 29980

# Full sweep of perplexity thresholds, kept for reference. It used to be
# assigned to `perplexity_thresholds` and then immediately overwritten, which
# made it look active when it was not.
FULL_PERPLEXITY_THRESHOLDS = [
    7.0, 6.0, 5.75, 5.5, 5.25, 5.0, 4.75, 4.5, 4.25, 4.0,
    3.9, 3.8, 3.7, 3.6,
    3.55, 3.5, 3.45, 3.4, 3.35, 3.3,
    3.25, 3.2, 3.15, 3.1,
]

# Thresholds actually in effect for this notebook run.
# NOTE(review): no cell visible in this notebook reads this variable — confirm
# whether it is still needed.
perplexity_thresholds = [6.0, 5.0]

In [None]:
# ---- SFT stage configuration ------------------------------------------------

# Dataset: the "sft" variant of "tuluif", capped at the notebook-wide budget.
dataset_config = DatasetConfig(
    dataset="tuluif",
    dataset_type="sft",
    train_size=total_train_size,  # 29980
)

# Run settings: training only; inference and evaluation are switched off.
run_config = SFTRunConfig(
    dataset_config=dataset_config,
    model_name_hf=HF_MODEL_MAP[MODEL],  # HuggingFace model name, not a local path
    model_name=MODEL,  # base model name used when building the output directory
    do_training=True,
    do_inference=False,
    do_evaluation=False,
)

# Pass@k settings. Per the original author's note, these exist only to watch
# the ifeval pass@1 value while training is running.
passk_config = PassAtKConfig(
    target_pass_at_k=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    k_values=[1],
    n_samples=1,
    num_prompts=200,
    temperature=0.7,
    strict=False,
    enabled=True,
)

# LoRA settings are left at the project defaults.
lora_config = LoraConfig()

# Model loading: defaults, except for the sequence length.
model_load_config = ModelLoadConfig()
model_load_config.max_seq_length = 4096

# Trainer arguments: defaults, except for the evaluation interval.
training_args = TrainingArgumentsConfig()
training_args.eval_steps = 100

In [None]:
import wandb

# Open the Weights & Biases run as a context manager around the SFT training
# call. The run configuration is passed at init time so that it is logged even
# if training crashes early; objects without a `__dict__` fall back to an
# empty config.
with wandb.init(
    name=run_config.run_name,
    project="tuning",
    reinit=True,
    config=getattr(run_config, "__dict__", {}),
) as run:
    model, tokenizer, trainer, callbacks = train_model_sft(
        run_config=run_config,
        lora_config=lora_config,
        model_load_config=model_load_config,
        training_args=training_args,
        passk_config=passk_config,
    )

In [None]:
# Alternative source for the metadata path (left disabled by the author):
# ppl_callback = callbacks[-1]
# metadata_file = ppl_callback.metadata_path
#
# NOTE(review): this absolute path is machine-specific, and its file name
# mentions "llama3-8B" while MODEL is set to "llama3-1B" earlier in the
# notebook — confirm the checkpoints listed here belong to the model being
# tuned.
metadata_file = "/home/shougan/projects/aip-fredashi/shougan/balance-budget/tuning/models_metadata/llama3-8B_passatk-0201_2257.json"

# The file is parsed as JSON Lines: one JSON document per line.
checkpoints = []
# JSON text is UTF-8; set the encoding explicitly instead of relying on the
# platform's locale default.
with open(metadata_file, "r", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. a trailing empty line), which would otherwise
        # make json.loads raise JSONDecodeError.
        if not line.strip():
            continue
        checkpoints.append(json.loads(line))
print(checkpoints)


In [None]:
!nvidia-smi
import gc
# del model, tokenizer, trainer, callbacks # this deletes the references to such objects
gc.collect() # then we force the GC
torch.cuda.empty_cache() # and we release the GPU CUDA cache
!nvidia-smi
# we could also use reset_cuda_context() for a clean state

In [None]:
# Imported here so this cell does not depend on an earlier cell having run.
import gc

# For every saved SFT checkpoint (the first two metadata entries are skipped),
# spend whatever is left of the example budget on DPO and then free the GPU.
#
# NOTE(review): this loop rebinds `model_load_config`, `training_args`,
# `dataset_config`, `run_config` and `passk_config`, replacing the SFT-stage
# objects of the same names. Re-run the SFT configuration cell before
# re-running SFT training.
for checkpoint in checkpoints[2:]:
    model_name = Path(checkpoint["checkpoint_path"]).name

    # Examples left for the preference stage after this checkpoint's SFT share.
    data = total_train_size - checkpoint["data_points_seen"]
    if data <= 0:
        # Happens when the checkpoint already consumed the whole budget, or
        # when the metadata comes from a run with a larger budget. A zero or
        # negative train_size must not reach DatasetConfig.
        print(
            f"Skipping {model_name}: no budget left for DPO "
            f"(data_points_seen={checkpoint['data_points_seen']}, "
            f"total_train_size={total_train_size})"
        )
        continue

    # NOTE(review): default ModelLoadConfig here, whereas the SFT stage set
    # max_seq_length = 4096 — confirm the difference is intended.
    model_load_config = ModelLoadConfig()

    training_args = DPOTrainingConfig()
    training_args.eval_strategy = "epoch"  # Evaluates at end of each epoch
    training_args.load_best_model_at_end = False  # Keep latest model
    training_args.per_device_train_batch_size = 8
    training_args.gradient_accumulation_steps = 16
    # training_args.eval_steps = 250

    # Preference ("pt") dataset sized to the remaining budget.
    dataset_config = DatasetConfig(
        dataset = "tuluif",
        dataset_type = "pt",
        train_size = data,
    )

    # Describes the SFT run that produced this checkpoint; `dynamic_path`
    # carries the checkpoint directory name.
    sft_run_config = SFTRunConfig(
        dataset_config = DatasetConfig(
            dataset = "tuluif",
            dataset_type = "sft",
            train_size = checkpoint["data_points_seen"],
            dynamic_path = model_name
        ),
        model_name = MODEL,
        model_name_hf = HF_MODEL_MAP[MODEL],
        task_name = "ifeval"
    )
    run_config = PTRunConfig(
        dataset_config = dataset_config,
        # model_name_hf = HF_MODEL_MAP[MODEL],
        model_name = MODEL,
        sft_run_config = sft_run_config,
        task_name = "ifeval",
        pft_method = "dpo",
        do_training = True
    )

    # Currently unused: the matching argument of train_model_dpo below is
    # commented out. Kept so it can be re-enabled without rebuilding it.
    passk_config = PassAtKConfig( # this is just to dynamically view the pass@1 of ifeval
        target_pass_at_k=[1.2],
        k_values=[1],
        n_samples=1,
        num_prompts=100,
        temperature=0.7,
        strict=False,
        enabled=True,
    )

    model, tokenizer, trainer = train_model_dpo(
        run_config = run_config,
        lora_config = lora_config,
        model_load_config = model_load_config,
        training_args = training_args,
        # passk_config = passk_config,
        # perplexity_thresholds= [0.1] # dummy value to periodically check perplexities too
    )

    # Release this iteration's objects before the next checkpoint is loaded.
    del model, tokenizer, trainer
    gc.collect()
    torch.cuda.empty_cache()




In [None]:
# Shell escape: print the GPU status once the preceding cells have run.
!nvidia-smi