In [None]:
import gc
import math
import os
import time
import warnings
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
from peft import PeftModel
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    set_seed,
)

from src.utils import data as data_utils
from src.utils import io as io_utils
from src.utils import models as model_utils

In [None]:
# Request the "expandable_segments" CUDA allocator mode unless the caller has
# already exported PYTORCH_CUDA_ALLOC_CONF (setdefault never overwrites it).
# NOTE(review): presumably this has to be in place before the first CUDA
# allocation, hence its position ahead of any model loading — confirm.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Permit TF32 math in CUDA matmuls and in cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Best-effort: any failure of this call is deliberately ignored
# (presumably to tolerate torch builds without this API — TODO confirm).
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# NOTE(review): this hides every warning for the rest of the session,
# including deprecation warnings raised by the libraries imported above.
warnings.filterwarnings("ignore")
# Notebook magics: inline matplotlib output, plus autoreload in mode 2 so that
# edits to imported modules are picked up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

# EXTERNAL = Path(os.getenv("EXTERNAL_STORAGE_DIR"))
# Project layout; every path below is derived from the repository root.
ROOT = io_utils.repo_root()
CONFIG_DIR = ROOT / "config"
METRIC_DIR = ROOT / "metrics"
# Adapter directories that load_adapter hands to PeftModel.from_pretrained.
SFT_MODEL_DIR = ROOT / "models/sft_qlora"
DPO_MODEL_DIR = ROOT / "models/dpo_qlora"
RANDOM_STATE = 42
# Drives the choice of config section, dtype and CUDA cache handling below.
use_cuda = torch.cuda.is_available()


# Seed the random number generators through transformers.set_seed.
set_seed(RANDOM_STATE)

In [None]:
# Echo the resolved repository root as the cell output for a quick sanity check.
ROOT

In [None]:
# Read the model parameters and keep only the section that matches the
# hardware this kernel is running on.
MODEL_CFG_PATH = CONFIG_DIR / "models.params.yml"
model_cfg = io_utils.load_yaml(MODEL_CFG_PATH)[
    "cuda_model" if use_cuda else "cpu_model"
]

In [None]:
# Unpack the settings used when loading the base model.
model_id, use_4bit, device_map = (
    model_cfg[key] for key in ("model_id", "use_4bit", "device_map")
)

# Compute dtype: float32 on CPU; on CUDA prefer bfloat16 when the device
# reports support for it and fall back to float16 otherwise.
if not use_cuda:
    torch_dtype = torch.float32
elif torch.cuda.is_bf16_supported():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

In [None]:
# Build the 4-bit NF4 quantization config only when the model config asks for
# it; if constructing it fails, report the error and continue unquantized.
try:
    quantization_config = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        if use_4bit
        else None
    )
except Exception as exc:
    print("bitsandbytes не готов, продолжаем без 4-бит:", exc)
    quantization_config = None

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Pad and truncate on the left so the end of the prompt (where generation
# starts) is what survives.
tokenizer.padding_side = tokenizer.truncation_side = "left"

# Reuse EOS as the padding token when the tokenizer ships without one.
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Single-slot cache: at most one base model + adapter pair is tracked at a time.
_loaded = dict.fromkeys(("which", "base", "policy"))


def _unload_current():
    """Forget the cached model pair and reset the cache to its empty state.

    Garbage collection and, when CUDA is in use, emptying the CUDA cache only
    happen if a policy was actually cached; the cache dict is reset either way.
    """
    had_policy = _loaded["policy"] is not None
    if had_policy:
        # Drop the cache's own references before asking for collection.
        for slot in ("policy", "base"):
            _loaded.pop(slot, None)
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()
    _loaded.update({"which": None, "base": None, "policy": None})

In [None]:
def load_adapter(which: str):
    """Return the PEFT policy for the requested adapter, loading it on demand.

    Args:
        which: Either ``"SFT"`` or ``"DPO"``.

    Returns:
        The adapter-wrapped model in eval mode. If the same adapter is already
        cached it is returned as is; otherwise the cached pair is unloaded, a
        fresh base model is loaded and the adapter is attached to it.

    Raises:
        AssertionError: If ``which`` is not one of the two supported names.
    """
    assert which in {"SFT", "DPO"}, "which должен быть 'SFT' или 'DPO'"

    cached = _loaded["policy"]
    if cached is not None and _loaded["which"] == which:
        return cached

    # Only one model pair is kept around, so release the previous one first.
    _unload_current()

    base = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map=device_map,
        quantization_config=quantization_config,
    )

    if which == "SFT":
        adapter_dir = SFT_MODEL_DIR
    else:
        adapter_dir = DPO_MODEL_DIR
    policy = PeftModel.from_pretrained(base, adapter_dir)

    # Keep the model's padding/EOS ids in sync with the tokenizer.
    pad_id = tokenizer.pad_token_id
    policy.config.pad_token_id = pad_id
    gen_config = getattr(policy, "generation_config", None)
    if gen_config is not None:
        gen_config.pad_token_id = pad_id
        gen_config.eos_token_id = tokenizer.eos_token_id
    policy.eval()
    policy.config.use_cache = True

    _loaded.update({"which": which, "base": base, "policy": policy})
    return policy

In [None]:
# Smoke test: check that an adapter can be loaded and report where it landed.
# Prefer the DPO adapter and fall back to SFT when it cannot be loaded.
# "DPO" always passes load_adapter's argument assertion, so the fallback has to
# react to loading failures instead: PEFT reports a missing adapter config as
# ValueError, and missing or unreadable files surface as OSError.
try:
    m = load_adapter("DPO")
except (OSError, ValueError) as exc:
    print("Не удалось загрузить DPO-адаптер, пробуем SFT:", exc)
    m = load_adapter("SFT")

device = next(m.parameters()).device
print("device: ", device, "adapter: ", _loaded["which"])

# Drop the notebook-level reference first; otherwise `m` keeps the model alive
# and _unload_current() cannot actually release its memory.
del m
_unload_current()

In [None]:
# System instruction sent with every summarization request.
SYSTEM_PROMPT = (
    "Ты помощник по резюмированию русскоязычных новостей. "
    "Сделай краткое, нейтральное резюме исходного текста (3–5 предложений). "
    "Не добавляй фактов, которых нет в тексте."
)


def build_chat(text: str) -> str:
    """Render the summarization prompt for ``text`` with the chat template.

    The conversation consists of the system instruction followed by a single
    user turn that embeds the article. The template is rendered to a string
    (not token ids) with the generation prompt appended.
    """
    user_turn = f"Задача: кратко резюмируй.\n\nТекст статьи:\n{text}"
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_turn},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    return rendered

In [None]:
def summarize_with(
    model, text, max_new_tokens=200, do_sample=False, temperature=0.7, top_p=0.9
):
    """Generate a summary of ``text`` with ``model``.

    Args:
        model: Model exposing ``generate`` and ``parameters`` (the policy
            returned by ``load_adapter``).
        text: Article text to summarize.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        do_sample: Whether to sample instead of decoding greedily.
        temperature: Sampling temperature; only applied when ``do_sample``.
        top_p: Nucleus sampling threshold; only applied when ``do_sample``.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    sampling = bool(do_sample)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sampling,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if sampling:
        # Sampling-only knobs: passing them together with do_sample=False
        # yields an inconsistent config that GenerationConfig validation
        # complains about, and they have no effect on greedy decoding.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
    gen_cfg = GenerationConfig(**gen_kwargs)

    prompt = build_chat(text)

    # Input budget for the prompt, derived by the project helper.
    max_inp = model_utils.get_max_input_tokens(tokenizer, gen_cfg)
    # NOTE(review): the prompt is already rendered by the chat template; for
    # tokenizers that prepend BOS this call may add a second one. Confirm
    # whether add_special_tokens=False is needed for the model in use.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        pad_to_multiple_of=8,
        max_length=max_inp,
    )

    # Move the encoded prompt to wherever the model's parameters live.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        out_ids = model.generate(
            **inputs,
            generation_config=gen_cfg,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    gen_ids = out_ids[0, prompt_len:]
    text_out = tokenizer.decode(gen_ids, skip_special_tokens=True)
    return text_out.strip()

In [None]:
def ui_summarize(
    article, system_choice, max_new_tokens, do_sample, temperature, top_p, reference
):
    """Gradio callback: summarize ``article`` with the selected system.

    Args:
        article: Article text from the UI; empty or whitespace-only input
            short-circuits without touching any model.
        system_choice: ``"Lead-3"``, ``"SFT QLoRA"``, or anything else, which
            selects the DPO adapter.
        max_new_tokens, do_sample, temperature, top_p: Generation settings
            forwarded to ``summarize_with`` for the model-based systems.
        reference: Optional gold summary; when non-blank, ROUGE F1 is computed
            against it.

    Returns:
        ``(prediction, reference, elapsed, rouge1, rouge2, rougeL, rougeLsum)``.
        The reference is always echoed back unchanged (``""`` for ``None``)
        because it is wired as an output component as well as an input; the
        ROUGE values are ``0.0`` when no reference was supplied.
    """
    t0 = time.perf_counter()
    reference_text = reference or ""

    if not article or not article.strip():
        # Echo the reference back so an accidental click on an empty article
        # does not wipe what the user typed into the reference box.
        return "", reference_text, "0.00 s", 0.0, 0.0, 0.0, 0.0

    if system_choice == "Lead-3":
        pred = model_utils.lead3(article)
    else:
        which = "SFT" if system_choice == "SFT QLoRA" else "DPO"
        model = load_adapter(which)
        pred = summarize_with(
            model,
            article,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )

    # Elapsed time covers adapter loading (if any) plus generation.
    dt = f"{time.perf_counter() - t0:.2f} s"

    # ROUGE is optional: it is only computed when a reference was provided.
    r1 = r2 = rL = rLsum = 0.0
    if reference_text.strip():
        score = data_utils.get_rouge_f1([pred], [reference_text])
        r1, r2, rL, rLsum = (
            score["rouge1"],
            score["rouge2"],
            score["rougeL"],
            score["rougeLsum"],
        )

    return pred, reference_text, dt, r1, r2, rL, rLsum

In [None]:
# Gradio UI. Components are attached to the layout in the order they are
# created inside the nested Row/Column contexts, so statement order here is
# the page layout.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("### Мини-интерфейс суммаризации новостей (Lead-3 / SFT / SFT+DPO)")

    # Top row: text inputs on the left (wider), generation controls on the right.
    with gr.Row():
        with gr.Column(scale=2):
            article = gr.Textbox(
                label="Текст статьи", lines=18, placeholder="Вставьте текст новости…"
            )
            reference = gr.Textbox(label="(Опционально) Эталон для ROUGE", lines=4)
        with gr.Column(scale=1):
            # Choice labels are matched by string in ui_summarize.
            system_choice = gr.Radio(
                choices=["Lead-3", "SFT QLoRA", "SFT+DPO QLoRA"],
                value="SFT+DPO QLoRA",
                label="Система",
            )
            max_new = gr.Slider(64, 512, value=200, step=8, label="max_new_tokens")
            do_sample = gr.Checkbox(value=False, label="do_sample")
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
            btn = gr.Button("Суммаризировать", variant="primary")

    # Bottom row: generated summary and timing on the left, ROUGE on the right.
    with gr.Row():
        with gr.Column():
            pred = gr.Textbox(label="Результат", lines=12)
            tspent = gr.Textbox(label="Время")
        with gr.Column():
            gr.Markdown("**ROUGE (если указан эталон)**")
            r1 = gr.Number(label="rouge1")
            r2 = gr.Number(label="rouge2")
            rL = gr.Number(label="rougeL")
            rLsum = gr.Number(label="rougeLsum")

    # The input order must match ui_summarize's parameter order and the output
    # order must match its return tuple. `reference` is both an input and an
    # output, so whatever the callback returns for it overwrites the text box.
    btn.click(
        ui_summarize,
        inputs=[
            article,
            system_choice,
            max_new,
            do_sample,
            temperature,
            top_p,
            reference,
        ],
        outputs=[pred, reference, tspent, r1, r2, rL, rLsum],
    )

# demo.queue()
# Start the local server; no public share link is requested.
demo.launch(debug=False, share=False)

In [None]:
# Teardown: stop the Gradio server, then drop the cached model pair so its
# memory can be reclaimed.
demo.close()
_unload_current()

In [None]:
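# Optional Google Colab bootstrap, kept commented out. Uncomment and run these
# steps before the rest of the notebook when working on Colab.
# !nvidia-smi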

# import torch

# print("torch:", torch.__version__, "| CUDA доступна:", torch.cuda.is_available())

# # ----------------------------------------------------------------------------------

# from google.colab import drive

# drive.mount("/content/drive", force_remount=True)

# # ----------------------------------------------------------------------------------

# import subprocess
# import sys
# import os

# REPO_URL = "https://github.com/mdayssi/llm-news-summarizer-ru.git"
# REPO_DIR = "/content/llm-news"

# if not os.path.exists(REPO_DIR):
#     !git clone {REPO_URL} {REPO_DIR}
# else:
#     print("Репозиторий уже есть:", REPO_DIR)


# %cd {REPO_DIR}
# !git rev-parse --short HEAD

# # ----------------------------------------------------------------------------------
# %pip -q install --upgrade \
#   evaluate rouge-score bert_score\
#   razdel bitsandbytes accelerate\
#   python-dotenv pyyaml peft trl

# import accelerate
# import bert_score
# import bitsandbytes
# import datasets
# import dotenv
# import evaluate
# import razdel
# import rouge_score
# import sentencepiece
# import torch
# import tqdm
# import transformers
# import yaml

# print("torch:", torch.__version__, "| cuda avail:", torch.cuda.is_available())
# print("transformers:", transformers.__version__)
# print("datasets:", datasets.__version__)
# print("evaluate:", evaluate.__version__)

# # ----------------------------------------------------------------------------------

# repo_src = "/content/llm-news/src"
# if repo_src not in sys.path:
#     sys.path.insert(0, repo_src)
# print("sys.path ok")