# In [None]:  (notebook cell marker left over from the export)
# =============================================================================
# TEMPLATE #2: LoRA + RAG via LangChain (tidy, chain-based, yet still fast and OOM-free)
# =============================================================================
from datasets import Dataset
from unsloth import FastLanguageModel
import torch
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from tqdm.auto import tqdm
import pandas as pd
import gc
import os

# =============================================================================
# 1. SETTINGS — same as in the previous template
# =============================================================================
# Directory that receives checkpoints, the final model and the tokenizer.
OUTPUT_DIR: str = "./qwen3-14b-langchain-rag"
# Checkpoint loaded (in 4-bit) and fine-tuned further down in this script.
MODEL_NAME: str = "unsloth/Qwen3-14B-unsloth-bnb-4bit"
# Token limit used both when loading the model and when tokenizing prompts.
MAX_SEQ_LENGTH: int = 8192

# Input tables; the test table is optional and skipped when the file is absent.
TRAIN_PATH: str = "/kaggle/input/your-dataset/train.csv"
TEST_PATH: str = "/kaggle/input/your-dataset/test.csv"
# Name of the column the model learns to predict.
TARGET_COLUMN: str = "target"
# Number of retrieved training examples placed into each prompt.
K_RETRIEVAL: int = 5

# =============================================================================
# 2. Model + LoRA (unchanged from the previous template)
# =============================================================================
# Keyword arguments for loading the base checkpoint through Unsloth.
base_model_kwargs = dict(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=torch.bfloat16,
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)

# Keyword arguments for wrapping the base model with LoRA adapters.
lora_kwargs = dict(
    r=64,
    lora_alpha=128,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)
model = FastLanguageModel.get_peft_model(model, **lora_kwargs)

# =============================================================================
# 3. Data
# =============================================================================
# The training table is always read; the test table only when its file exists.
train_df = pd.read_csv(TRAIN_PATH)
if TEST_PATH and os.path.exists(TEST_PATH):
    test_df = pd.read_csv(TEST_PATH)
else:
    test_df = None

def row_to_text(row, include_target=False):
    """Render one table row as a ``"col: value, col: value"`` string.

    Args:
        row: Object exposing ``.items()`` that yields ``(column, value)``
            pairs (e.g. a pandas Series).
        include_target: When false, the ``TARGET_COLUMN`` entry is left out.

    Returns:
        The comma-separated ``"column: value"`` pairs in iteration order.
        Floats are printed with at most six decimals and no trailing zeros;
        every other value goes through ``str``.
    """
    def render(value):
        # Fixed-point with 6 decimals, then drop trailing zeros and a bare dot.
        if isinstance(value, float):
            return f"{value:.6f}".rstrip("0").rstrip(".")
        return str(value)

    return ", ".join(
        f"{name}: {render(value)}"
        for name, value in row.items()
        if include_target or name != TARGET_COLUMN
    )

# =============================================================================
# 4. LangChain RAG — the whole point of this template
# =============================================================================
print("Греем эмбеддер и строим FAISS индекс (5-10 минут на 100k строк)...")
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# One LangChain Document per training row: the page content is the row text
# without the target; the target and the row index are kept in the metadata.
documents = [
    Document(
        page_content=row_to_text(row, include_target=False),
        metadata={"target": row[TARGET_COLUMN], "row_id": idx},
    )
    for idx, row in tqdm(train_df.iterrows(), total=len(train_df))
]

# In-memory FAISS index, nothing is written to disk. (The original author
# estimates roughly 6-8 GB of RAM for 200k rows; not verified here.)
vectorstore = FAISS.from_documents(documents, embedding)
retriever = vectorstore.as_retriever(search_kwargs=dict(k=K_RETRIEVAL))

# Prompt template expressed through LangChain. ``{context}`` receives the
# retrieved neighbours, ``{question}`` the row that has to be predicted.
template = """<|system|>
You are a helpful assistant. Predict the target using the current row and k nearest examples from training data.
<|end|>
<|user|>
Nearest examples:
{context}

Current row:
{question}

Predict only the number (or class), no explanations.
<|end|>
<|assistant|>
"""

prompt = ChatPromptTemplate.from_template(template)


def _format_retrieved_docs(docs):
    """Render retrieved documents as one ``- row → target`` bullet per line."""
    return "\n".join(
        f"- {d.page_content} → target: {d.metadata['target']}" for d in docs
    )


def _prompt_value_to_text(prompt_value):
    """Return the rendered prompt text held by a chat prompt value.

    ``ChatPromptTemplate.from_template`` produces a single message, so its
    content is the fully formatted prompt string.
    """
    return prompt_value.to_messages()[0].content


# Demonstration chain (the dataset formatting code calls the retriever
# directly): query text -> retrieved context + query -> rendered prompt string.
# The last step used to be ``StrOutputParser()``, but that parser accepts only
# strings or messages, not the prompt value emitted by ``prompt``, so invoking
# the chain failed. Plain callables are coerced to runnables by ``|``.
rag_chain = (
    {
        "context": retriever | _format_retrieved_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | _prompt_value_to_text
)

# =============================================================================
# 5. Dataset formatting with the LangChain retriever (in batches, no OOM)
# =============================================================================
def _retrieve_neighbours(query, exclude_row_id=None):
    """Return up to ``K_RETRIEVAL`` nearest training documents for *query*.

    When ``exclude_row_id`` is given, the document whose ``row_id`` metadata
    equals it is dropped: every training row is itself in the index, so it
    would otherwise come back as its own nearest neighbour together with its
    target, and the model would only learn to copy that target.
    """
    if exclude_row_id is None:
        return retriever.invoke(query)
    # Fetch one extra hit so K_RETRIEVAL neighbours remain after the filter.
    candidates = vectorstore.similarity_search(query, k=K_RETRIEVAL + 1)
    others = [d for d in candidates if d.metadata.get("row_id") != exclude_row_id]
    return others[:K_RETRIEVAL]


def _build_rag_prompt(context, query):
    """Fill the fixed prompt layout with the neighbour list and the query row."""
    return f"""<|system|>
You are a helpful assistant. Predict the target using features and nearest examples.
<|end|>
<|user|>
Nearest examples:
{context}

Current row:
{query}
Predict only the number.
<|end|>
<|assistant|>
"""


def apply_langchain_rag(examples, is_train=True):
    """Build retrieval-augmented prompts for a batch of rows and tokenize them.

    Each entry of ``examples["__index__"]`` is used as a positional row number
    into ``train_df`` (or ``test_df`` when ``is_train`` is false). The row is
    rendered with ``row_to_text`` and combined with its nearest training rows
    into a prompt.

    Args:
        examples: Batch dict as handed over by ``Dataset.map(batched=True)``;
            only the ``"__index__"`` column is read.
        is_train: When true, the answer is appended to the prompt and
            ``labels`` are produced.

    Returns:
        Dict with unpadded ``input_ids`` and ``attention_mask`` lists. In
        training mode ``input_ids`` is prompt + answer and ``labels`` has the
        same length, with ``-100`` on every prompt position so the loss covers
        only the answer. Padding is left to the data collator.
    """
    frame = train_df if is_train else test_df
    prompts = []
    answers = []

    for idx in examples["__index__"]:
        row = frame.iloc[idx]
        query = row_to_text(row, include_target=False)

        # Assumes ``row_id`` in the index metadata equals the positional row
        # number (true for the default RangeIndex produced by ``read_csv``).
        retrieved_docs = _retrieve_neighbours(
            query, exclude_row_id=idx if is_train else None
        )
        context = "\n".join(
            f"- {doc.page_content} → target: {doc.metadata['target']}"
            for doc in retrieved_docs
        )
        prompts.append(_build_rag_prompt(context, query))

        if is_train:
            answers.append(str(row[TARGET_COLUMN]))

    # Truncation is done by hand below: cutting on the right would drop the
    # current row and the trailing assistant marker, so the oldest prompt
    # tokens are dropped instead.
    prompt_ids_batch = tokenizer(prompts, padding=False, truncation=False)["input_ids"]

    if not is_train:
        input_ids = [ids[-MAX_SEQ_LENGTH:] for ids in prompt_ids_batch]
        return {
            "input_ids": input_ids,
            "attention_mask": [[1] * len(ids) for ids in input_ids],
        }

    # The answer is closed with the same "<|end|>" marker the prompt uses for
    # the other turns (inference cuts the generated text at that marker).
    answer_ids_batch = tokenizer(
        [answer + "<|end|>" for answer in answers],
        add_special_tokens=False,
        padding=False,
        truncation=True,
        max_length=256,
    )["input_ids"]
    eos_id = tokenizer.eos_token_id

    all_input_ids = []
    all_masks = []
    all_labels = []
    for prompt_ids, answer_ids in zip(prompt_ids_batch, answer_ids_batch):
        if eos_id is not None:
            # Teach the model to stop after the answer.
            answer_ids = answer_ids + [eos_id]
        budget = MAX_SEQ_LENGTH - len(answer_ids)
        prompt_ids = prompt_ids[-budget:] if budget > 0 else []

        # A causal LM needs the answer inside input_ids; labels mirror it 1:1.
        all_input_ids.append(prompt_ids + answer_ids)
        all_masks.append([1] * (len(prompt_ids) + len(answer_ids)))
        all_labels.append([-100] * len(prompt_ids) + answer_ids)

    return {
        "input_ids": all_input_ids,
        "attention_mask": all_masks,
        "labels": all_labels,
    }

# =============================================================================
# 6. Datasets
# =============================================================================
def _with_row_index(frame):
    """Wrap *frame* in a Dataset whose ``__index__`` column holds the old index."""
    indexed = frame.reset_index().rename(columns={"index": "__index__"})
    return Dataset.from_pandas(indexed)


def _format_train_batch(batch):
    """Format one training batch (prompts plus labels)."""
    return apply_langchain_rag(batch, is_train=True)


def _format_test_batch(batch):
    """Format one test batch (prompts only)."""
    return apply_langchain_rag(batch, is_train=False)


train_ds = _with_row_index(train_df)
test_ds = None if test_df is None else _with_row_index(test_df)

print("Форматируем трейн через LangChain RAG...")
train_formatted = train_ds.map(
    _format_train_batch,
    batched=True,
    batch_size=16,
    remove_columns=train_ds.column_names,
)

if test_ds is not None:
    test_formatted = test_ds.map(
        _format_test_batch,
        batched=True,
        batch_size=16,
        remove_columns=test_ds.column_names,
    )

# Hold out 10% of the formatted training rows for validation.
split = train_formatted.train_test_split(test_size=0.1, seed=42)
train_dataset, val_dataset = split["train"], split["test"]

# =============================================================================
# 7. Training
# =============================================================================
# The model is requested in bfloat16, so hard-coding fp16 mixed precision
# conflicts with it on GPUs that support bf16. Pick the mode from the hardware.
# NOTE(review): assumes the loader falls back to float16 on GPUs without bf16
# support, so that fp16=True matches there — confirm for the Unsloth version in use.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=3,
    logging_steps=10,
    # load_best_model_at_end requires evaluation to be switched on, the eval
    # and save strategies to match, and save_steps to be a multiple of
    # eval_steps. The original (no eval strategy, eval every 200 / save every
    # 500 steps) is rejected by TrainingArguments.
    eval_strategy="steps",
    eval_steps=250,
    save_strategy="steps",
    save_steps=500,
    warmup_steps=50,
    fp16=not use_bf16,
    bf16=use_bf16,
    optim="paged_adamw_8bit",
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# Pads each batch to its longest sequence, rounded up to a multiple of 8.
data_collator = DataCollatorForSeq2Seq(tokenizer, padding="longest", pad_to_multiple_of=8)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
# Persist the trained model and the tokenizer next to each other.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# =============================================================================
# 8. Inference (also through the LangChain retriever)
# =============================================================================
# Switch the fine-tuned model into Unsloth's inference mode before generating.
FastLanguageModel.for_inference(model)

@torch.no_grad()
def predict_langchain(batch):
    """Generate a prediction for every tokenized prompt in *batch*.

    Args:
        batch: Dict with ``"input_ids"`` and ``"attention_mask"``, each a list
            of per-example integer lists (they may differ in length).

    Returns:
        ``{"prediction": [...]}`` with one entry per example: a ``float`` when
        the generated text parses as a number (thousands separators removed),
        otherwise the generated text itself (e.g. a class label).
        NOTE(review): a batch mixing floats and strings may not fit a single
        column type when this is used with ``Dataset.map`` — confirm.
    """
    sequences = batch["input_ids"]
    masks = batch["attention_mask"]
    if not sequences:
        return {"prediction": []}

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # The prompts are unpadded and of different lengths, so they cannot be
    # stacked directly. Pad on the left so every prompt ends at the same
    # column and generation continues right after the real prompt tokens.
    width = max(len(ids) for ids in sequences)
    input_ids = torch.full((len(sequences), width), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((len(sequences), width), dtype=torch.long)
    for row, (ids, mask) in enumerate(zip(sequences, masks)):
        input_ids[row, width - len(ids):] = torch.tensor(ids, dtype=torch.long)
        attn_mask[row, width - len(mask):] = torch.tensor(mask, dtype=torch.long)

    # Greedy decoding; a temperature is not passed because it has no effect
    # when sampling is off.
    generated = model.generate(
        input_ids=input_ids.to(model.device),
        attention_mask=attn_mask.to(model.device),
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=pad_id,
    )

    # Decode only the newly generated tokens; special tokens (EOS/padding) are
    # dropped so they cannot leak into the parsed value.
    completions = generated[:, width:]
    texts = tokenizer.batch_decode(completions, skip_special_tokens=True)

    preds = []
    for text in texts:
        pred = text.split("<|end|>")[0].strip()
        try:
            pred = float(pred.replace(",", ""))
        except ValueError:
            # Not a number: keep the raw text.
            pass
        preds.append(pred)
    return {"prediction": preds}

# Predict the formatted test prompts in batches of 8 and write a one-column CSV.
if test_ds is not None:
    preds = test_formatted.map(
        predict_langchain,
        batched=True,
        batch_size=8,
    )
    submission_columns = {TARGET_COLUMN: preds["prediction"]}
    submission = pd.DataFrame(submission_columns)
    submission.to_csv("submission_langchain.csv", index=False)
    print("Готово, бери submission_langchain.csv и вали на лидерборд!")
