In [None]:
import copy
import gc
import math
import warnings
from pathlib import Path

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    set_seed,
)
from trl import DPOConfig, DPOTrainer

from src.utils import data as data_utils
from src.utils import io as io_utils
from src.utils import models as model_utils

In [None]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore")
%matplotlib inline
%load_ext autoreload
%autoreload 2

# EXTERNAL = Path(os.getenv("EXTERNAL_STORAGE_DIR"))
ROOT = io_utils.repo_root()
SPLIT_DIR = ROOT / "data/splits"
DPO_DATA = ROOT / "data/dpo/exp_1"
CONFIG_DIR = ROOT / "config"
METRIC_DIR = ROOT / "metrics"
SFT_MODEL_DIR = ROOT / "models/sft_qlora"
RANDOM_STATE = 42
N_DPO = 2000
N_EVAL = 800

set_seed(RANDOM_STATE)

In [None]:
# Repository root from io_utils.repo_root(); all data/config/model paths below are built from it.
ROOT

In [None]:
# Paths (relative to ROOT) of the CSV files listing the row ids of each split.
IDS_PATH = io_utils.load_yaml(CONFIG_DIR / "dataset.ids.yml")["splits_ids"]
TRAIN_IDS_PATH, VAL_IDS_PATH = IDS_PATH["train_ids"], IDS_PATH["val_ids"]

# Header-less, single-column id files -> one DataFrame per split.
train_ids, val_ids = (
    pd.read_csv(ROOT / rel_path, header=None)
    for rel_path in (TRAIN_IDS_PATH, VAL_IDS_PATH)
)

In [None]:
# Load the Gazeta DatasetDict once and take both splits from it; the original called
# load_dataset twice for the same dataset.
gazeta_ds = load_dataset("IlyaGusev/gazeta")
raw_train = gazeta_ds["train"].to_pandas()
raw_val = gazeta_ds["validation"].to_pandas()

print("raw train shape:", raw_train.shape, "raw val shape:", raw_val.shape)
raw_val.head()

In [None]:
columns = ["text", "summary"]
# The id files are read as single-column frames; `.iloc[:, 0]` always yields a Series of ids,
# whereas `.squeeze()` would collapse a one-row frame into a scalar and change `.loc`'s result type.
# Ids are used as row labels of the raw frames (presumably positions in the HF splits — confirm).
# `.copy()` makes the column assignments below act on an independent frame, not a view.
train = raw_train.loc[train_ids.iloc[:, 0], columns].copy()
val = raw_val.loc[val_ids.iloc[:, 0], columns].copy()
for col in columns:
    train[col] = data_utils.clean(train[col])
    val[col] = data_utils.clean(val[col])
val.head(2)

In [None]:
MODEL_CFG_PATH = CONFIG_DIR / "models.params.yml"

# Choose the GPU or CPU profile of the models config depending on what the runtime offers.
model_profile = "cuda_model" if torch.cuda.is_available() else "cpu_model"
model_cfg = io_utils.load_yaml(MODEL_CFG_PATH)[model_profile]

model_cfg

In [None]:
device = model_cfg["device"]
model_id = model_cfg["model_id"]
use_4bit = model_cfg["use_4bit"]
device_map = model_cfg["device_map"]

# Weight dtype: fp32 off-GPU; on CUDA prefer bf16 and fall back to fp16 when bf16 is unsupported.
if device != "cuda":
    torch_dtype = torch.float32
elif torch.cuda.is_bf16_supported():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Reproducible random subsets (capped at N_EVAL / N_DPO rows) with a fresh 0..n-1 index.
subset_val, subset_dpo = (
    frame.sample(n=min(limit, len(frame)), random_state=RANDOM_STATE).reset_index(drop=True)
    for frame, limit in ((val, N_EVAL), (train, N_DPO))
)

subset_val.head(2)

In [None]:
# 4-bit NF4 config with double quantization, computing in torch_dtype. If the config cannot be
# built, the model is loaded without quantization (quantization_config stays None).
quantization_config = None
if use_4bit:
    try:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch_dtype,
        )
    except Exception as bnb_error:
        # Message: "bitsandbytes is not ready, continuing without 4-bit:"
        print("bitsandbytes не готов, продолжаем без 4-бит:", bnb_error)
        quantization_config = None

In [None]:
# Tokenizer: batched prompts are left-padded so every row ends where generation starts.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

tokenizer.padding_side = "left"
# NOTE(review): left truncation removes the *start* of an over-long rendered chat prompt, i.e. the
# system prompt and the beginning of the article — confirm this trade-off is intended.
tokenizer.truncation_side = "left"

# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

# Base causal LM; quantization_config is None when 4-bit loading is disabled or unavailable.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map=device_map,
    quantization_config=quantization_config,
)

# Prepare for k-bit training with non-reentrant gradient checkpointing.
# NOTE(review): called even when the model was not quantized — verify this is intended for
# the non-4-bit path.
base_model = prepare_model_for_kbit_training(
    base_model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# DPO policy = base model + the SFT LoRA adapter, loaded as trainable.
policy_model = PeftModel.from_pretrained(base_model, SFT_MODEL_DIR, is_trainable=True)
policy_model.config.pad_token_id = tokenizer.pad_token_id
if getattr(policy_model, "generation_config", None) is not None:
    policy_model.generation_config.pad_token_id = tokenizer.pad_token_id
    policy_model.generation_config.eos_token_id = tokenizer.eos_token_id

# Frozen copy of the SFT policy (reference model).
# NOTE(review): this copy is deleted before candidate generation and re-created before training,
# so here it only costs memory.
ref_model = copy.deepcopy(policy_model)
for p in ref_model.parameters():
    p.requires_grad_(False)
ref_model.eval()

# NOTE(review): rebinds `device` from the config string to a torch.device (e.g. cuda:0). Any later
# code that compares `device` to the string "cuda" or formats it into file names sees this object.
device = next(policy_model.parameters()).device
policy_model.print_trainable_parameters()

In [None]:
# System instruction (Russian): "You are an assistant for summarizing Russian-language news.
# Write a short, neutral summary of the source text (3–5 sentences). Do not add facts that are
# not in the text."
SYSTEM_PROMPT = (
    "Ты помощник по резюмированию русскоязычных новостей. "
    "Сделай краткое, нейтральное резюме исходного текста (3–5 предложений). "
    "Не добавляй фактов, которых нет в тексте."
)


def build_chat(text: str) -> str:
    """Render one article as a generation-ready chat prompt string.

    A system message (SYSTEM_PROMPT) and a user message ("Task: summarize briefly. Article text:
    <text>") are passed through the global tokenizer's chat template, untokenized and with the
    assistant generation prompt appended.
    """
    user_message = f"Задача: кратко резюмируй.\n\nТекст статьи:\n{text}"
    conversation = [
        dict(role="system", content=SYSTEM_PROMPT),
        dict(role="user", content=user_message),
    ]
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )

In [None]:
# Settings shared by both decoding modes: at most 200 new tokens, pad/eos ids from the tokenizer.
_shared_gen_kwargs = dict(
    max_new_tokens=200,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Deterministic decoding (evaluation and the first DPO candidate).
GEN_GREEDY = GenerationConfig(do_sample=False, **_shared_gen_kwargs)

# Nucleus sampling (the second, more diverse DPO candidate).
GEN_SAMPLED = GenerationConfig(
    do_sample=True, temperature=0.7, top_p=0.9, **_shared_gen_kwargs
)

# Prompt-length budget from model_utils (presumably context size minus max_new_tokens — confirm).
MAX_INPUT_TOKENS = model_utils.get_max_input_tokens(tokenizer, GEN_GREEDY)

In [None]:
def generate_with_cfg(texts, model, gen_cfg, batch_size=4):
    """Generate one completion per text with the given generation config.

    Each text is wrapped with ``build_chat``, tokenized in batches (padded, truncated to
    ``MAX_INPUT_TOKENS``), moved to the global ``device`` and passed to ``model.generate``.
    Only the tokens produced after the prompt are decoded.

    Args:
        texts: sliceable sequence of article texts (e.g. a list).
        model: model exposing ``generate(**inputs, generation_config=...)``.
        gen_cfg: generation config; its ``do_sample`` flag labels the progress bar.
        batch_size: prompts per ``generate`` call; must be >= 1.

    Returns:
        List of decoded, whitespace-stripped completions in the order of ``texts``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step gives an empty range and would silently return no predictions.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    out = []
    it = tqdm(
        range(0, len(texts), batch_size),
        total=math.ceil(len(texts) / batch_size),
        desc=f"Generating (do_sample={gen_cfg.do_sample})",
        leave=False,
    )

    for i in it:
        chunk = [build_chat(t) for t in texts[i : i + batch_size]]
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            pad_to_multiple_of=8,
            max_length=MAX_INPUT_TOKENS,
        ).to(device)

        with torch.no_grad():
            output_ids = model.generate(**inputs, generation_config=gen_cfg)

        # The returned sequences start with the (padded) prompt; drop that prefix so only the
        # newly generated tokens are decoded.
        gen_ids = output_ids[:, inputs["input_ids"].shape[1] :]
        decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        out.extend([d.strip() for d in decoded])

    return out

In [None]:
# The frozen reference copy is not used for candidate generation: drop it and release cached
# GPU memory before generating.
del ref_model
gc.collect()
torch.cuda.empty_cache()

# Inference mode with KV cache for generation (the training cell below switches back).
policy_model.eval()
policy_model.config.use_cache = True

In [None]:
BATCH = 6
texts_dpo, refs_dpo = subset_dpo["text"].tolist(), subset_dpo["summary"].tolist()

# Two candidates per article: greedy first, then sampled (same order as before, so the RNG
# state consumed by sampling is unchanged).
cand_greedy, cand_sampled = [
    generate_with_cfg(texts_dpo, policy_model, cfg, batch_size=BATCH)
    for cfg in (GEN_GREEDY, GEN_SAMPLED)
]

In [None]:
# Per-article scores of each candidate against the reference summary (via dpo_rouge_lsum).
rouge_greedy, rouge_sampled = (
    data_utils.dpo_rouge_lsum(candidates, refs_dpo)
    for candidates in (cand_greedy, cand_sampled)
)

# Higher-scoring candidate -> "chosen", the other -> "rejected". Pairs whose score gap is below
# MIN_DELTA are dropped (ties go to the sampled branch with gap 0, so they are always dropped).
MIN_DELTA = 0.02
pairs = []
for idx in range(len(subset_dpo)):
    score_g, score_s = rouge_greedy[idx], rouge_sampled[idx]
    if score_g > score_s:
        winner, loser, margin = cand_greedy[idx], cand_sampled[idx], score_g - score_s
    else:
        winner, loser, margin = cand_sampled[idx], cand_greedy[idx], score_s - score_g
    if margin >= MIN_DELTA:
        pairs.append((build_chat(texts_dpo[idx]), winner, loser))

len(pairs)

In [None]:
# Preference pairs -> HF dataset with the prompt / chosen / rejected columns used by DPOTrainer.
dpo_df = pd.DataFrame(data=pairs, columns=["prompt", "chosen", "rejected"])
dpo_ds = Dataset.from_pandas(dpo_df)
dpo_ds.save_to_disk(str(DPO_DATA))

dpo_ds[0]  # preview of the first pair

In [None]:
# from datasets import load_from_disk
# _reload = load_from_disk(str(DPO_DATA))
# print(_reload, _reload[0])

In [None]:
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# Back to training: KV cache off, modules in train mode.
policy_model.config.use_cache = False
policy_model.train()

# Frozen snapshot of the current (SFT) policy, used as the DPO reference model.
ref_model = copy.deepcopy(policy_model)
ref_model.requires_grad_(False)  # Module-level call: disables grad on every parameter
ref_model.eval()
ref_model.config.use_cache = False

In [None]:
check_dir = ROOT / "models/dpo_qlora"

# Mixed-precision flags follow the dtype the model was loaded with (torch_dtype): bf16 where the
# GPU supports it, fp16 on GPUs without bf16, fp32 otherwise. A hard-coded bf16=True is rejected
# by TrainingArguments on setups without bf16 support.
use_bf16 = torch_dtype == torch.bfloat16
use_fp16 = torch_dtype == torch.float16

# `device` may be the config string ("cuda"/"cpu") or a torch.device such as cuda:0, which never
# equals the plain string "cuda"; compare on the device type instead.
on_cuda = str(device).split(":")[0] == "cuda"

# The bitsandbytes 8-bit optimizer only when the model is actually quantized (quantization_config
# is None if the 4-bit config could not be built, even with use_4bit=True).
if quantization_config is not None:
    optim_name = "adamw_bnb_8bit"
elif on_cuda:
    optim_name = "adamw_torch_fused"
else:
    optim_name = "adamw_torch"

dpo_cfg = DPOConfig(
    output_dir=str(check_dir),
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,  # 16 preference pairs per optimizer step per device
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    warmup_steps=50,
    num_train_epochs=1,
    logging_steps=20,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    report_to=[],
    beta=0.1,  # DPO beta: larger values keep the policy closer to ref_model
    max_prompt_length=2048,
    max_length=2048 + 256,
    bf16=use_bf16,
    fp16=use_fp16,
    optim=optim_name,
    remove_unused_columns=False,
)

dpo_trainer = DPOTrainer(
    model=policy_model,
    ref_model=ref_model,
    args=dpo_cfg,
    processing_class=tokenizer,
    train_dataset=dpo_ds,
)

print("active adapters:", getattr(dpo_trainer.model, "active_adapters", None))
dpo_trainer.model.print_trainable_parameters()

In [None]:
# Run DPO training; the trainer's returned training summary is kept and displayed.
train_out = dpo_trainer.train()
train_out

In [None]:
# Same directory as DPOConfig.output_dir: the trained adapter and tokenizer are saved side by side
# so the evaluation cells can reload them.
out_dir = ROOT / "models" / "dpo_qlora"
policy_model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)

In [None]:
import shutil

from google.colab import files

zip_path = "/content/dpo_qwen2_qlora_adapter"
shutil.make_archive(zip_path, "zip", out_dir)
files.download(zip_path + ".zip")

!cp -r {Path("/content/llm-news/src/checkpoints/dpo_qlora")} "/content/drive/MyDrive/llm-news"

In [None]:
# Drop every handle to the training-time models before reloading for evaluation. dpo_trainer was
# built with model=policy_model and ref_model=ref_model and still references them (and presumably
# its optimizer state), so deleting only the notebook names would not free the GPU memory.
del dpo_trainer, policy_model, ref_model, base_model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Fresh base model with the same dtype / device map / quantization as for training; no
# prepare_model_for_kbit_training here because this copy is only used for inference.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map=device_map,
    quantization_config=quantization_config,
)

# Attach the DPO-trained adapter saved in out_dir (is_trainable not set).
policy_model = PeftModel.from_pretrained(base_model, out_dir)
# Same pad/eos wiring as for the training-time policy.
policy_model.config.pad_token_id = tokenizer.pad_token_id
if getattr(policy_model, "generation_config", None) is not None:
    policy_model.generation_config.pad_token_id = tokenizer.pad_token_id
    policy_model.generation_config.eos_token_id = tokenizer.eos_token_id

# Inference mode with KV cache for generation.
policy_model.eval()
policy_model.config.use_cache = True

In [None]:
BATCH = 6
texts, refs = (subset_val[column].tolist() for column in ("text", "summary"))

# Evaluate the DPO policy with deterministic (greedy) decoding.
preds_dpo = generate_with_cfg(texts, policy_model, GEN_GREEDY, batch_size=BATCH)

scores_dpo = data_utils.get_rouge_f1(preds_dpo, refs)
scores_dpo

In [None]:
# Sanity check: the first two generated summaries.
preds_dpo[:2]

In [None]:
# Matching reference summaries for comparison with the predictions above.
refs[:2]

In [None]:
# Full metric dict from data_utils (keys read by the export cell: ROUGE, BERTScore and length
# statistics). NOTE(review): `device` may be a torch.device here rather than the config string.
scores = data_utils.get_all_scores(preds_dpo, refs, device=device)
scores

In [None]:
Path(METRIC_DIR).mkdir(parents=True, exist_ok=True)

# `device` may be a torch.device such as cuda:0; file names use only its type ("cuda"/"cpu")
# so they stay stable and contain no ":".
device_tag = str(device).split(":")[0]

# Metrics exported to the CSV; anything missing from `scores` is written as 0.0.
METRIC_KEYS = [
    "rouge1",
    "rouge2",
    "rougeL",
    "rougeLsum",
    "bertscore_precision",
    "bertscore_recall",
    "bertscore_f1",
    "avg_len_pred",
    "avg_len_ref",
    "len_ratio_pred_to_ref",
]

df_metrics = pd.DataFrame(
    [
        {
            # The policy starts from the SFT adapter, hence "SFT+DPO".
            "system": "SFT+DPO QLoRA",
            # NOTE(review): labelled "full", but only subset_val (at most N_EVAL rows) is evaluated.
            "split": "validation_full",
            **{key: scores.get(key, 0.0) for key in METRIC_KEYS},
            "k": None,
            # Actual number of evaluated examples (smaller than N_EVAL if the split is smaller).
            "n_examples": len(refs),
        }
    ]
)
df_metrics.to_csv(
    METRIC_DIR / f"llm_dpo_qlora_validation_{device_tag}_{N_EVAL}.csv", index=False
)

# A few qualitative examples, one row per example (a dict of equal-length lists, not a single
# row holding whole lists).
n_show = min(3, len(refs), len(preds_dpo))
titles = (
    subset_val["title"].head(n_show).tolist()
    if "title" in subset_val
    else [""] * n_show
)
df_samples = pd.DataFrame(
    {
        "title": titles,
        "reference": refs[:n_show],
        "prediction": preds_dpo[:n_show],
    }
)
df_samples.to_csv(
    METRIC_DIR / f"llm_dpo_qlora_examples_{device_tag}.tsv", sep="\t", index=False
)

In [None]:
# !nvidia-smi

# import torch

# print("torch:", torch.__version__, "| CUDA доступна:", torch.cuda.is_available())

# # ----------------------------------------------------------------------------------

# from google.colab import drive

# drive.mount("/content/drive", force_remount=True)

# # ----------------------------------------------------------------------------------

# import subprocess
# import sys
# import os

# REPO_URL = "https://github.com/mdayssi/llm-news-summarizer-ru.git"
# REPO_DIR = "/content/llm-news"

# if not os.path.exists(REPO_DIR):
#     !git clone {REPO_URL} {REPO_DIR}
# else:
#     print("Репозиторий уже есть:", REPO_DIR)


# %cd {REPO_DIR}
# !git rev-parse --short HEAD

# # ----------------------------------------------------------------------------------
# %pip -q install --upgrade \
#   evaluate rouge-score bert_score\
#   razdel bitsandbytes accelerate\
#   python-dotenv pyyaml peft trl

# import accelerate
# import bert_score
# import bitsandbytes
# import datasets
# import dotenv
# import evaluate
# import razdel
# import rouge_score
# import sentencepiece
# import torch
# import tqdm
# import transformers
# import yaml

# print("torch:", torch.__version__, "| cuda avail:", torch.cuda.is_available())
# print("transformers:", transformers.__version__)
# print("datasets:", datasets.__version__)
# print("evaluate:", evaluate.__version__)

# # ----------------------------------------------------------------------------------

# repo_src = "/content/llm-news/src"
# if repo_src not in sys.path:
#     sys.path.insert(0, repo_src)
# print("sys.path ok")