In [None]:
import math
import warnings
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    set_seed,
)

import src.utils.data as data_utils
import src.utils.io as io_utils
import src.utils.models as model_utils

In [None]:
load_dotenv()

warnings.filterwarnings("ignore")
%matplotlib inline
%load_ext autoreload
%autoreload 2

# EXTERNAL = Path(os.getenv("EXTERNAL_STORAGE_DIR"))
ROOT = io_utils.repo_root()
SPLIT_DIR = ROOT / "data/splits"
CONFIG_DIR = ROOT / "config"
METRIC_DIR = ROOT / "metrics"
RANDOM_STATE = 42
FEWSHOT_POOL_SIZE = 2000
FEWSHOT_K = 2
OVERHEAD_BUDGET = 128

set_seed(RANDOM_STATE)

In [None]:
# Show the resolved repository root as a quick sanity check.
ROOT

In [None]:
# The YAML maps each split to a CSV (path relative to the repo root) listing its row ids.
IDS_PATH = io_utils.load_yaml(CONFIG_DIR / "dataset.ids.yml")["splits_ids"]
TRAIN_IDS_PATH, VAL_IDS_PATH = IDS_PATH["train_ids"], IDS_PATH["val_ids"]

# Header-less id files; presumably a single column of row labels (consumed below).
train_ids, val_ids = (
    pd.read_csv(ROOT / ids_file, header=None)
    for ids_file in (TRAIN_IDS_PATH, VAL_IDS_PATH)
)

In [None]:
# Load the Gazeta dataset once and convert the two splits used here to pandas.
# (Calling load_dataset twice re-resolves and re-prepares the same dataset.)
gazeta = load_dataset("IlyaGusev/gazeta")
raw_train = gazeta["train"].to_pandas()
raw_val = gazeta["validation"].to_pandas()

print("raw train shape:", raw_train.shape, "raw val shape:", raw_val.shape)
raw_val.head()

In [None]:
# Restrict both splits to the pre-computed ids and the two columns used downstream,
# then normalize the text with the project's cleaner.
columns = ["text", "summary"]
# `.iloc[:, 0]` always yields a Series of ids. `.squeeze()` would collapse a one-row
# id file to a scalar, making `.loc` return a Series instead of a DataFrame.
# `.copy()` makes the per-column assignments below operate on an owned frame
# rather than on a slice of the raw data.
train = raw_train.loc[train_ids.iloc[:, 0], columns].copy()
val = raw_val.loc[val_ids.iloc[:, 0], columns].copy()
for col in columns:
    train[col] = data_utils.clean(train[col])
    val[col] = data_utils.clean(val[col])
val.head(2)

In [None]:
# Pick the model profile matching the available hardware.
MODEL_CFG_PATH = CONFIG_DIR / "models.params.yml"
model_cfg = io_utils.load_yaml(MODEL_CFG_PATH)[
    "cuda_model" if torch.cuda.is_available() else "cpu_model"
]

model_cfg

In [None]:
# Unpack the selected hardware profile into module-level settings.
device, model_id, n_eval, use_4bit, device_map = (
    model_cfg[key]
    for key in ("device", "model_id", "n_eval", "use_4bit", "device_map")
)

# Precision: fp32 off-GPU; on CUDA prefer bf16 when supported, otherwise fp16.
if device != "cuda":
    torch_dtype = torch.float32
elif torch.cuda.is_bf16_supported():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Evaluation subset: the whole validation split when n_eval is None,
# otherwise a seeded sample of at most n_eval rows (re-indexed from 0).
subset_val = (
    val
    if n_eval is None
    else val.sample(
        n=min(n_eval, val.shape[0]), random_state=RANDOM_STATE
    ).reset_index(drop=True)
)

# Seeded pool of at most FEWSHOT_POOL_SIZE train rows to draw exemplars from.
fewshot_pool = train.sample(
    n=min(FEWSHOT_POOL_SIZE, len(train)), random_state=RANDOM_STATE
).reset_index(drop=True)

subset_val.head(2)

In [None]:
# Optional NF4 4-bit quantization with double quantization, computing in torch_dtype.
# Best-effort: if the config cannot be constructed, fall back to unquantized weights.
if not use_4bit:
    quantization_config = None
else:
    try:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch_dtype,
        )
    except Exception as err:
        print("bitsandbytes не готов, продолжаем без 4-бит:", err)
        quantization_config = None

In [None]:
# --- Tokenizer -----------------------------------------------------------------
# Left padding right-aligns every prompt in a batch, so the generated tokens start
# at the same column for all rows. Left truncation drops the *start* of an
# over-long prompt.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
# Tokenizers without a dedicated pad token reuse EOS for padding.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

# --- Model ---------------------------------------------------------------------
# torch_dtype is only passed when not quantizing; the 4-bit config carries its own
# compute dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device_map,
    quantization_config=quantization_config,
    torch_dtype=(None if quantization_config else torch_dtype),
)

# Keep the model's pad id in sync with the tokenizer's.
model.config.pad_token_id = tokenizer.pad_token_id
if getattr(model, "generation_config", None) is not None:
    model.generation_config.pad_token_id = tokenizer.pad_token_id

# On CUDA, placement is left to device_map; otherwise move the model explicitly.
if device != "cuda":
    model.to(device)

In [None]:
# Context limit advertised by the tokenizer (None if the attribute is absent).
getattr(tokenizer, "model_max_length", None)

In [None]:
# Longest text / summary in the few-shot pool as reported by data_utils.max_len
# (presumably in tokens, since the tokenizer is passed in).
(
    data_utils.max_len(fewshot_pool, tokenizer, "text"),
    data_utils.max_len(fewshot_pool, tokenizer, "summary"),
)

In [None]:
# Token counts (without special tokens) of the first evaluation article and its reference.
tuple(
    len(tokenizer.encode(subset_val.iloc[0][column], add_special_tokens=False))
    for column in ("text", "summary")
)

In [None]:
# System prompt (English: "You are an assistant for summarizing Russian-language news.
# Write a short, neutral summary of the source text (3-5 sentences). Do not add facts
# that are not in the text.")
SYSTEM_PROMPT = (
    "Ты помощник по резюмированию русскоязычных новостей. "
    "Сделай краткое, нейтральное резюме исходного текста (3–5 предложений). "
    "Не добавляй фактов, которых нет в тексте."
)

# A freshly constructed GenerationConfig has no EOS/pad ids. Depending on the
# transformers version, passing it to `generate` does not inherit them from the
# model. Without an EOS id, greedy decoding only stops at max_new_tokens and keeps
# generating past the end of the answer. Set both ids explicitly.
_eos_token_id = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
if _eos_token_id is None:
    _eos_token_id = tokenizer.eos_token_id

# Deterministic (greedy) decoding for evaluation.
GEN_EVAL = GenerationConfig(
    max_new_tokens=200,
    do_sample=False,
    eos_token_id=_eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# Prompt-length ceiling derived by the project helper from the tokenizer and GEN_EVAL
# (presumably context size minus max_new_tokens — confirm in src.utils.models).
MAX_INPUT_TOKENS = model_utils.get_max_input_tokens(tokenizer, GEN_EVAL)

In [None]:
def _zero_shot_msgs(text: str) -> list:
    """Build the plain zero-shot chat (system prompt + one user turn) for ``text``."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Задача: кратко резюмируй.\n\nТекст статьи:\n{text}",
        },
    ]


def _prompt_token_len(msgs) -> int:
    """Token length of ``msgs`` rendered through the chat template (with generation prompt)."""
    prompt = tokenizer.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True
    )
    return len(tokenizer(prompt, add_special_tokens=False)["input_ids"])


def build_chat_fewshot(target_text: str, pool_df, k=FEWSHOT_K):
    """Build chat messages asking the model to summarize ``target_text``.

    Starts from ``k`` exemplars drawn from ``pool_df``. Exemplars are dropped one
    at a time until the rendered prompt, plus OVERHEAD_BUDGET tokens of headroom,
    fits in MAX_INPUT_TOKENS.

    If even the exemplar-free prompt is too long, a zero-shot prompt is returned
    whose article is trimmed (keeping its beginning) so that the prompt fits.
    Otherwise the batch tokenizer, which truncates from the left, would cut the
    system prompt and chat header off the front of the prompt instead.

    Args:
        target_text: article to summarize.
        pool_df: DataFrame of candidate exemplars (passed to model_utils.sample_exemplars).
        k: maximum number of exemplars to include.

    Returns:
        list of ``{"role", "content"}`` message dicts for ``tokenizer.apply_chat_template``.
    """
    budget = MAX_INPUT_TOKENS - OVERHEAD_BUDGET
    exemplars = model_utils.sample_exemplars(pool_df, k=k, random_state=RANDOM_STATE)

    # Try len(exemplars), ..., 1, 0 exemplars; return the richest prompt that fits.
    for n_shots in range(len(exemplars), -1, -1):
        msgs = model_utils.assemble_msgs(exemplars[:n_shots], target_text, SYSTEM_PROMPT)
        if _prompt_token_len(msgs) <= budget:
            return msgs

    # Last resort: the article alone is too long, so trim the article itself.
    # Decode/re-encode may shift the count by a token or two; OVERHEAD_BUDGET absorbs that.
    fallback = _zero_shot_msgs(target_text)
    overflow = _prompt_token_len(fallback) - budget
    if overflow > 0:
        text_ids = tokenizer.encode(target_text, add_special_tokens=False)
        kept_ids = text_ids[: max(len(text_ids) - overflow, 0)]
        fallback = _zero_shot_msgs(tokenizer.decode(kept_ids, skip_special_tokens=True))
    return fallback

In [None]:
def generate_batch_fewshot(texts, pool_df, batch_size=3, show_progress=True, k=FEWSHOT_K):
    """Summarize ``texts`` with few-shot prompts whose exemplars come from ``pool_df``.

    Args:
        texts: sequence of article strings.
        pool_df: DataFrame that few-shot exemplars are drawn from.
        batch_size: number of prompts per ``model.generate`` call; must be >= 1.
        show_progress: display a tqdm progress bar over batches.
        k: maximum number of exemplars per prompt.

    Returns:
        list[str]: one stripped, decoded generation per input text, in input order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Render every prompt up front so the progress bar measures generation only.
    prompt_strs = [
        tokenizer.apply_chat_template(
            build_chat_fewshot(text, pool_df, k=k),
            tokenize=False,
            add_generation_prompt=True,
        )
        for text in texts
    ]

    starts = range(0, len(prompt_strs), batch_size)
    if show_progress:
        starts = tqdm(
            starts,
            total=math.ceil(len(prompt_strs) / batch_size),
            desc="Generating (few-shot)",
            leave=False,
        )

    out = []
    for i in starts:
        chunk = prompt_strs[i : i + batch_size]
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            pad_to_multiple_of=8,
            max_length=MAX_INPUT_TOKENS,
        ).to(device)

        with torch.no_grad():
            output_ids = model.generate(**inputs, generation_config=GEN_EVAL)

        # With left padding every prompt ends at the same column, so slicing off the
        # padded prompt width leaves only the newly generated tokens.
        gen_ids = output_ids[:, inputs["input_ids"].shape[1] :]
        decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        out.extend(d.strip() for d in decoded)

    return out

In [None]:
# Generate few-shot summaries for the evaluation subset (batched only on CUDA).
BATCH = 3 if device == "cuda" else 1
texts, refs = (subset_val[column].tolist() for column in ("text", "summary"))
preds_few = generate_batch_fewshot(
    texts,
    pool_df=fewshot_pool,
    batch_size=BATCH,
    show_progress=True,
)

In [None]:
# First two generated summaries, for a quick qualitative check.
preds_few[:2]

In [None]:
# Reference summaries for the same first two articles.
refs[:2]

In [None]:
# Score predictions against references with the project helper; the result is read
# with `.get(...)` when exporting metrics, i.e. it is a dict-like of metric values.
scores = data_utils.get_all_scores(preds_few, refs, device=device)
scores

In [None]:
# Persist aggregate metrics and a few qualitative examples for the FEW-SHOT run.
# File names use the `llm_few_shot_` prefix so they do not overwrite zero-shot results.
Path(METRIC_DIR).mkdir(parents=True, exist_ok=True)

METRIC_KEYS = (
    "rouge1",
    "rouge2",
    "rougeL",
    "rougeLsum",
    "bertscore_precision",
    "bertscore_recall",
    "bertscore_f1",
    "avg_len_pred",
    "avg_len_ref",
    "len_ratio_pred_to_ref",
)

df_metrics = pd.DataFrame(
    [
        {
            "system": "few_shot",
            "split": "validation_full",
            **{key: scores.get(key, 0.0) for key in METRIC_KEYS},
            # Configured maximum; build_chat_fewshot may use fewer exemplars for long articles.
            "k": FEWSHOT_K,
            # Actual number of evaluated examples (n_eval is None for full runs).
            "n_examples": len(subset_val),
        }
    ]
)
df_metrics.to_csv(
    METRIC_DIR / f"llm_few_shot_validation_{device}_{n_eval}.csv", index=False
)

# One row per example (dict of equal-length lists). Capped at 3 and at the number
# of available predictions.
n_samples = min(3, len(refs), len(preds_few))
df_samples = pd.DataFrame(
    {
        "title": (
            subset_val["title"].head(n_samples).tolist()
            if "title" in subset_val.columns
            else [""] * n_samples
        ),
        "reference": refs[:n_samples],
        "prediction": preds_few[:n_samples],
    }
)
df_samples.to_csv(
    METRIC_DIR / f"llm_few_shot_examples_{device}.tsv", sep="\t", index=False
)