In [268]:
from unsloth import FastLanguageModel
from datasets import load_from_disk
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import DataCollatorForLanguageModeling
from trl import SFTTrainer, SFTConfig
import torch

SEED = 42

In [269]:
model_name = "Qwen/Qwen3-0.6B"
MAX_LENGTH = 512

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = MAX_LENGTH,
    load_in_4bit = False,
    load_in_8bit = False,
)
tokenizer.pad_token = tokenizer.eos_token
RANK = 128
model = FastLanguageModel.get_peft_model(
    model,
    r = RANK,           # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = RANK*2,  # Best to choose alpha = rank or rank*2
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = SEED,
    use_rslora = False,   # We support rank stabilized LoRA
    loftq_config = None,  # And LoftQ
)

==((====))==  Unsloth 2025.9.8: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.10.2.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.988 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


## Datasets

### Simple dataset loading example

In [270]:
import pandas as pd
from langchain.schema import Document
dataset = pd.read_csv("../notebooks/data/contacts_docs.csv")
documents = []
for index, row in dataset.iterrows():
    doc = f"Nombre: {row['name']}\nTeléfono: {row['phone']}"
    documents.append(Document(page_content=doc, metadata={"id": f"{row['id']}" } ))
print(f"Loaded {len(documents)} documents.")
print(f"First document: {documents[0]}")


Loaded 400 documents.
First document: page_content='Nombre: Alba Alonso
Teléfono: 632 322 183' metadata={'id': '7500_1'}


In [271]:
query_dataset_train = pd.read_csv("../notebooks/data/contacts_queries_train.csv")
query_dataset_val = pd.read_csv("../notebooks/data/contacts_queries_val.csv")
query_dataset_test = pd.read_csv("../notebooks/data/contacts_queries_test.csv")


In [272]:
all_data = {
    "train": query_dataset_train,
    "validation": query_dataset_val,
    "test": query_dataset_test,
}

#to hugginface dataset
from datasets import Dataset, DatasetDict
dataset = {}
for split in all_data:
    dataset[split] = Dataset.from_pandas(all_data[split])
dataset = DatasetDict(dataset)

In [273]:
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'id'],
        num_rows: 1400
    })
    validation: Dataset({
        features: ['question', 'id'],
        num_rows: 300
    })
    test: Dataset({
        features: ['question', 'id'],
        num_rows: 300
    })
})

### Real dataset loading example

In [145]:
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    model_kwargs={"device": "cuda"},
)

db = FAISS.load_local(
    "../data/db/parliament_db/parliament_all_docs_embeddings_sentence-transformers_paraphrase-multilingual-mpnet-base-v2",
    embedding_model,
    allow_dangerous_deserialization=True,
)

In [152]:
#quiero la lista de documentos
docs = db.docstore._dict.values()
documents = list(docs)
print(f"Number of documents: {len(documents)}")

Number of documents: 11162


In [144]:
FOLDER_AUTORE = "../data/processed/parliament_qa"
dataset = load_from_disk(FOLDER_AUTORE)
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'response', 'cost', 'documents', 'type', 'retrieved_pks', 'oracle_context', 'formatted_context'],
        num_rows: 614
    })
    validation: Dataset({
        features: ['id', 'question', 'response', 'cost', 'documents', 'type', 'retrieved_pks', 'oracle_context', 'formatted_context'],
        num_rows: 161
    })
    test: Dataset({
        features: ['question', 'id', 'response', 'type', 'retrieved_pks', 'oracle_context', 'injected_oracle', 'formatted_context', 'documents'],
        num_rows: 205
    })
})

## Data preparation

In [274]:
def build_prompt_it(tokenizer, system_prompt: str, prompt: str, response: str) -> str:
    """Builds the chat prompt for a single example using the tokenizer chat template."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": prompt},
        {"role": "assistant", "content": response}
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
    )

In [275]:
def prepare_prompt_for_indexing(documents: list):
    system_prompt = """
    Eres un módulo de almacenamiento. Tu única tarea es almacenar el documento dado y devolver su identificador único.
    """
    prompt = """
    Almacena el siguiente documento y dale un identificador único:
    {doc}
    """
    response = "DOCID:{doc_id}"
    for doc in documents:
        document = doc.page_content
        doc_id = doc.metadata.get("id", "unknown")
        prompt_ = prompt.format(doc=document, doc_id=doc_id)
        response_ = response.format(doc_id=doc_id)
        yield build_prompt_it(tokenizer,system_prompt, prompt_, response_)

In [276]:
def prepare_prompts_for_retrieval(dataset, tokenizer):
    system_prompt = """
    Eres un módulo de recuperación. Tu única tarea es devolver el identificador del documento correspondiente a la consulta dada.
    """
    prompts = []
    for item in dataset:
        prompt = """
        Dada la siguiente consulta, recupera los identificadores de los documentos relevantes. 
        Consulta: {QUERY}
        """
        response = "DOCID:{docid}"
        question = item["question"]
        prompt = prompt.format(QUERY=question)
        prompts.append(build_prompt_it(tokenizer, system_prompt, prompt, response.format(docid=item["id"])))
    return prompts

In [277]:
prompts = list(prepare_prompt_for_indexing(documents))
print(f"Number of prompts: {len(prompts)}")

Number of prompts: 400


In [278]:
prompts[0]

'<|im_start|>system\n\n    Eres un módulo de almacenamiento. Tu única tarea es almacenar el documento dado y devolver su identificador único.\n    <|im_end|>\n<|im_start|>user\n\n    Almacena el siguiente documento y dale un identificador único:\n    Nombre: Alba Alonso\nTeléfono: 632 322 183\n    <|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nDOCID:7500_1<|im_end|>\n'

In [279]:
# create dataset from prompts
from datasets import Dataset
indexing_dataset = Dataset.from_dict({"text": prompts})
indexing_dataset

Dataset({
    features: ['text'],
    num_rows: 400
})

In [280]:
indexing_dataset["text"][0]

'<|im_start|>system\n\n    Eres un módulo de almacenamiento. Tu única tarea es almacenar el documento dado y devolver su identificador único.\n    <|im_end|>\n<|im_start|>user\n\n    Almacena el siguiente documento y dale un identificador único:\n    Nombre: Alba Alonso\nTeléfono: 632 322 183\n    <|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nDOCID:7500_1<|im_end|>\n'

In [281]:
prompts_retrieval_train = prepare_prompts_for_retrieval(dataset["train"], tokenizer)
prompts_retrieval_val = prepare_prompts_for_retrieval(dataset["validation"], tokenizer)

print(f"Number of retrieval prompts: {len(prompts_retrieval_train)}")
print(f"Number of retrieval prompts: {len(prompts_retrieval_val)}")

Number of retrieval prompts: 1400
Number of retrieval prompts: 300


In [282]:
print(prompts_retrieval_train[0], sep="\n")

<|im_start|>system

    Eres un módulo de recuperación. Tu única tarea es devolver el identificador del documento correspondiente a la consulta dada.
    <|im_end|>
<|im_start|>user

        Dada la siguiente consulta, recupera los identificadores de los documentos relevantes. 
        Consulta: Necesito el contacto asociado al 620 152 344. —consulta interna—
        <|im_end|>
<|im_start|>assistant
<think>

</think>

DOCID:7503_3<|im_end|>



In [283]:
# create dataset from prompts train, val, test
retrieval_train_dataset = Dataset.from_dict({"text": prompts_retrieval_train})
retrieval_val_dataset = Dataset.from_dict({"text": prompts_retrieval_val})

retrieval_dataset = {
    "train": retrieval_train_dataset,
    "validation": retrieval_val_dataset,
}

In [284]:
def tokenize_function_autoregressive(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH)

In [285]:
indexing_dataset_tokenizer = indexing_dataset.map(tokenize_function_autoregressive, batched=True)

Map: 100%|██████████| 400/400 [00:00<00:00, 2693.83 examples/s]


In [286]:
retrieval_train_dataset_tokenizer = retrieval_dataset["train"].map(tokenize_function_autoregressive, batched=True)
retrieval_val_dataset_tokenizer = retrieval_dataset["validation"].map(tokenize_function_autoregressive, batched=True)

Map: 100%|██████████| 1400/1400 [00:00<00:00, 2941.92 examples/s]
Map: 100%|██████████| 300/300 [00:00<00:00, 3082.44 examples/s]


## Train

In [293]:
# sft training
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
auto_config = SFTConfig(
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 8, # Use GA to mimic batch size!
    save_steps=5,
    warmup_steps = 5,
    num_train_epochs = 1, # Set this for 1 full training run.
    #max_steps = 60,
    learning_rate = 1e-4, # Reduce to 2e-5 for long training runs
    logging_steps = 5,
    # 32 bits
    optim = "paged_adamw_32bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = SEED,
    report_to = "none", # Use this for WandB etc
    output_dir="../models/qwen3-0.6b-rag-indexer",
)

it_config = SFTConfig(
    dataset_text_field="text",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,         # <-- añade eval batch size
    gradient_accumulation_steps=16,
    warmup_steps=25,
    save_steps=25,
    eval_steps=25,
    eval_strategy="steps",         # <-- activa evaluación periódica
    num_train_epochs=1,             # <-- opcional: usa epochs en lugar de max_steps
    #max_steps=30,
    learning_rate=1e-4,
    logging_steps=1,
    optim = "paged_adamw_32bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=SEED,
    report_to="none",
    output_dir="../models/qwen3-0.6b-rag-retriever",
    load_best_model_at_end=True,          # <-- opcional
    metric_for_best_model="eval_loss",    # <-- opcional
    greater_is_better=False,              # <-- opcional
)

trainer_auto = SFTTrainer(
    model=model,
    train_dataset=indexing_dataset_tokenizer,
    tokenizer=tokenizer,
    args=auto_config,
)

trainer_it = SFTTrainer(
    model=model,
    train_dataset=retrieval_train_dataset_tokenizer,
    eval_dataset=retrieval_val_dataset_tokenizer,
    tokenizer=tokenizer,
    args=it_config,
)

In [290]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA GeForce RTX 4090. Max memory = 23.988 GB.
15.178 GB of memory reserved.


In [291]:
model.print_trainable_parameters()

trainable params: 80,740,352 || all params: 676,790,272 || trainable%: 11.9299


In [294]:
EPOCHS = 10
for _ in range(EPOCHS):
    trainer_sft_stats = trainer_auto.train() # (context, id)
    trainer_it_stats = trainer_it.train() # (query, id)
    # GUARDAR MODELOS CADA SUPER EPOCH

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.7647


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.1273


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0782


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0692


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0658


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0647


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0635


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0649


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0626


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 400 | Num Epochs = 1 | Total steps = 7
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss
5,0.0623


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,400 | Num Epochs = 1 | Total steps = 6
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 16 x 1) = 256
 "-____-"     Trainable parameters = 80,740,352 of 676,790,272 (11.93% trained)


Step,Training Loss,Validation Loss


In [295]:
model.print_trainable_parameters()

trainable params: 80,740,352 || all params: 676,790,272 || trainable%: 11.9299


In [296]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

Peak reserved memory = 17.605 GB.
Peak reserved memory for training = 2.427 GB.
Peak reserved memory % of max memory = 73.391 %.
Peak reserved memory for training % of max memory = 10.118 %.


In [297]:
def prepare_prompts_for_testing(dataset, tokenizer):
    system_prompt = """
    Eres un módulo de recuperación. Tu única tarea es devolver el identificador del documento correspondiente a la consulta dada.
    """
    def build_prompt_it(tokenizer, system_prompt: str, prompt: str) -> str:
        """Builds the chat prompt for a single example using the tokenizer chat template."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user",   "content": prompt},
        ]
        return tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
    prompts = []
    for item in dataset:
        prompt = """
        Dada la siguiente consulta, recupera los identificadores de los documentos relevantes. 
        Consulta: {QUERY}
        """
        question = item["question"]
        prompt = prompt.format(QUERY=question)
        prompts.append(
            ( 
                build_prompt_it(tokenizer, system_prompt, prompt),
                item["id"],
            )
        )
    return prompts

In [298]:
prompts_retrieval_test = prepare_prompts_for_testing(dataset["test"], tokenizer)


In [303]:
i = 1
text = prompts_retrieval_test[i][0]
doc_id_targets = prompts_retrieval_test[i][1]
print(doc_id_targets)

7538_1


In [300]:
print(text)

<|im_start|>system

    Eres un módulo de recuperación. Tu única tarea es devolver el identificador del documento correspondiente a la consulta dada.
    <|im_end|>
<|im_start|>user

        Dada la siguiente consulta, recupera los identificadores de los documentos relevantes. 
        Consulta: Dame el número de Manuel Sánchez.
        <|im_end|>
<|im_start|>assistant



In [302]:
# test the model in streaming mode
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = 64, # Increase for longer outputs!
    do_sample = False,
    top_p = 0.1,
    temperature = 0.,
    streamer = streamer,
)

<think>

</think>

DOCID:7517_2<|im_end|>


In [267]:
# test the model in non-streaming mode
import re
import tqdm

acc = 0
total = 0

for text, doc_id_target in tqdm.tqdm(prompts_retrieval_test, desc="Testing"):
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,  # Increase for longer outputs!
        do_sample=False, temperature=0.0, top_p=1.0
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    response = generated_text.split("</think>")[-1]
    # extract DOCID number using regex

    doc_id = re.search(r"DOCID:(\d+_\d+)", response).group(1)
    
    print("Correct:", doc_id, "==", "Predicted:", doc_id_target)
    if doc_id == doc_id_target:
        acc += 1
    total += 1
print(f"Accuracy: {acc}/{total} = {acc/total*100:.2f} %")

Testing:   0%|          | 1/300 [00:00<02:59,  1.67it/s]

Correct: 7528_3 == Predicted: 7566_1


Testing:   1%|          | 2/300 [00:01<02:36,  1.91it/s]

Correct: 7518_3 == Predicted: 7538_1


Testing:   1%|          | 3/300 [00:01<02:30,  1.97it/s]

Correct: 7528_3 == Predicted: 7572_5


Testing:   1%|▏         | 4/300 [00:02<02:26,  2.02it/s]

Correct: 7570_3 == Predicted: 7508_4


Testing:   2%|▏         | 5/300 [00:02<02:23,  2.06it/s]

Correct: 7527_3 == Predicted: 7544_2


Testing:   2%|▏         | 6/300 [00:02<02:22,  2.07it/s]

Correct: 7528_3 == Predicted: 7504_4


Testing:   2%|▏         | 7/300 [00:03<02:21,  2.07it/s]

Correct: 7527_3 == Predicted: 7536_3


Testing:   3%|▎         | 8/300 [00:03<02:21,  2.06it/s]

Correct: 7570_3 == Predicted: 7529_3


Testing:   3%|▎         | 9/300 [00:04<02:22,  2.05it/s]

Correct: 7570_3 == Predicted: 7548_4


Testing:   3%|▎         | 10/300 [00:04<02:18,  2.09it/s]

Correct: 7528_3 == Predicted: 7532_5


Testing:   4%|▎         | 11/300 [00:05<02:17,  2.10it/s]

Correct: 7528_3 == Predicted: 7530_4


Testing:   4%|▍         | 12/300 [00:05<02:15,  2.13it/s]

Correct: 7518_1 == Predicted: 7575_3


Testing:   4%|▍         | 13/300 [00:06<02:17,  2.08it/s]

Correct: 7528_3 == Predicted: 7572_4


Testing:   5%|▍         | 14/300 [00:06<02:16,  2.09it/s]

Correct: 7570_3 == Predicted: 7504_1


Testing:   5%|▌         | 15/300 [00:07<02:16,  2.10it/s]

Correct: 7528_3 == Predicted: 7504_5


Testing:   5%|▌         | 16/300 [00:07<02:17,  2.06it/s]

Correct: 7528_3 == Predicted: 7571_1


Testing:   6%|▌         | 17/300 [00:08<02:16,  2.08it/s]

Correct: 7528_3 == Predicted: 7576_3


Testing:   6%|▌         | 18/300 [00:08<02:17,  2.05it/s]

Correct: 7528_3 == Predicted: 7527_4


Testing:   6%|▋         | 19/300 [00:09<02:14,  2.08it/s]

Correct: 7528_3 == Predicted: 7534_5


Testing:   7%|▋         | 20/300 [00:09<02:13,  2.09it/s]

Correct: 7528_3 == Predicted: 7548_1


Testing:   7%|▋         | 21/300 [00:10<02:11,  2.13it/s]

Correct: 7528_3 == Predicted: 7568_2


Testing:   7%|▋         | 22/300 [00:10<02:08,  2.17it/s]

Correct: 7570_3 == Predicted: 7503_2


Testing:   8%|▊         | 23/300 [00:11<02:07,  2.17it/s]

Correct: 7570_3 == Predicted: 7541_2


Testing:   8%|▊         | 24/300 [00:11<02:09,  2.14it/s]

Correct: 7570_3 == Predicted: 7519_2


Testing:   8%|▊         | 25/300 [00:12<02:09,  2.12it/s]

Correct: 7527_3 == Predicted: 7551_4


Testing:   9%|▊         | 26/300 [00:12<02:10,  2.10it/s]

Correct: 7570_3 == Predicted: 7540_4


Testing:   9%|▉         | 27/300 [00:12<02:10,  2.09it/s]

Correct: 7570_3 == Predicted: 7539_4


Testing:   9%|▉         | 28/300 [00:13<02:08,  2.11it/s]

Correct: 7518_3 == Predicted: 7544_5


Testing:  10%|▉         | 29/300 [00:13<02:09,  2.10it/s]

Correct: 7518_1 == Predicted: 7532_1


Testing:  10%|█         | 30/300 [00:14<02:07,  2.12it/s]

Correct: 7570_3 == Predicted: 7578_3


Testing:  10%|█         | 31/300 [00:14<02:06,  2.13it/s]

Correct: 7528_3 == Predicted: 7520_5


Testing:  11%|█         | 32/300 [00:15<02:05,  2.13it/s]

Correct: 7570_3 == Predicted: 7542_1


Testing:  11%|█         | 33/300 [00:15<02:07,  2.10it/s]

Correct: 7528_3 == Predicted: 7513_3


Testing:  11%|█▏        | 34/300 [00:16<02:06,  2.10it/s]

Correct: 7570_1 == Predicted: 7546_3


Testing:  12%|█▏        | 35/300 [00:16<02:06,  2.10it/s]

Correct: 7528_3 == Predicted: 7516_5


Testing:  12%|█▏        | 36/300 [00:17<02:06,  2.09it/s]

Correct: 7516_1 == Predicted: 7554_2


Testing:  12%|█▏        | 37/300 [00:17<02:07,  2.07it/s]

Correct: 7528_3 == Predicted: 7510_3


Testing:  13%|█▎        | 38/300 [00:18<02:08,  2.04it/s]

Correct: 7528_3 == Predicted: 7554_5


Testing:  13%|█▎        | 39/300 [00:18<02:06,  2.07it/s]

Correct: 7570_3 == Predicted: 7502_2


Testing:  13%|█▎        | 40/300 [00:19<02:05,  2.07it/s]

Correct: 7570_3 == Predicted: 7523_3


Testing:  14%|█▎        | 41/300 [00:19<02:05,  2.07it/s]

Correct: 7570_3 == Predicted: 7537_3


Testing:  14%|█▍        | 42/300 [00:20<02:04,  2.08it/s]

Correct: 7570_3 == Predicted: 7565_3


Testing:  14%|█▍        | 43/300 [00:20<02:04,  2.07it/s]

Correct: 7518_1 == Predicted: 7549_1


Testing:  15%|█▍        | 44/300 [00:21<02:04,  2.05it/s]

Correct: 7528_3 == Predicted: 7536_1


Testing:  15%|█▌        | 45/300 [00:21<02:02,  2.08it/s]

Correct: 7528_3 == Predicted: 7524_1


Testing:  15%|█▌        | 46/300 [00:22<02:01,  2.09it/s]

Correct: 7518_1 == Predicted: 7504_5


Testing:  16%|█▌        | 47/300 [00:22<02:00,  2.10it/s]

Correct: 7518_3 == Predicted: 7503_4


Testing:  16%|█▌        | 48/300 [00:23<02:01,  2.08it/s]

Correct: 7528_3 == Predicted: 7534_1


Testing:  16%|█▋        | 49/300 [00:23<02:01,  2.07it/s]

Correct: 7570_3 == Predicted: 7575_1


Testing:  17%|█▋        | 50/300 [00:24<03:08,  1.32it/s]

Correct: 7570_3 == Predicted: 7510_3


Testing:  17%|█▋        | 51/300 [00:25<02:47,  1.49it/s]

Correct: 7528_3 == Predicted: 7507_3


Testing:  17%|█▋        | 52/300 [00:25<02:32,  1.62it/s]

Correct: 7528_3 == Predicted: 7559_5


Testing:  18%|█▊        | 53/300 [00:26<02:21,  1.74it/s]

Correct: 7570_3 == Predicted: 7556_1


Testing:  18%|█▊        | 54/300 [00:26<02:14,  1.83it/s]

Correct: 7528_3 == Predicted: 7508_3


Testing:  18%|█▊        | 55/300 [00:27<02:11,  1.87it/s]

Correct: 7570_3 == Predicted: 7571_1


Testing:  19%|█▊        | 56/300 [00:27<02:06,  1.93it/s]

Correct: 7570_3 == Predicted: 7509_2


Testing:  19%|█▉        | 57/300 [00:28<02:04,  1.95it/s]

Correct: 7570_3 == Predicted: 7563_2


Testing:  19%|█▉        | 58/300 [00:28<02:02,  1.98it/s]

Correct: 7570_3 == Predicted: 7503_1


Testing:  20%|█▉        | 59/300 [00:29<02:00,  2.00it/s]

Correct: 7528_3 == Predicted: 7528_3


Testing:  20%|██        | 60/300 [00:29<01:58,  2.03it/s]

Correct: 7528_3 == Predicted: 7514_3


Testing:  20%|██        | 61/300 [00:30<01:55,  2.07it/s]

Correct: 7516_1 == Predicted: 7574_4


Testing:  21%|██        | 62/300 [00:30<01:54,  2.08it/s]

Correct: 7528_3 == Predicted: 7500_2


Testing:  21%|██        | 63/300 [00:31<01:54,  2.07it/s]

Correct: 7570_3 == Predicted: 7558_5


Testing:  21%|██▏       | 64/300 [00:31<01:54,  2.06it/s]

Correct: 7528_3 == Predicted: 7573_3


Testing:  22%|██▏       | 65/300 [00:32<01:54,  2.05it/s]

Correct: 7527_3 == Predicted: 7553_2


Testing:  22%|██▏       | 66/300 [00:32<01:51,  2.10it/s]

Correct: 7570_3 == Predicted: 7560_5


Testing:  22%|██▏       | 67/300 [00:33<01:52,  2.07it/s]

Correct: 7528_3 == Predicted: 7512_4


Testing:  23%|██▎       | 68/300 [00:33<01:51,  2.08it/s]

Correct: 7570_3 == Predicted: 7536_5


Testing:  23%|██▎       | 69/300 [00:34<01:49,  2.10it/s]

Correct: 7528_3 == Predicted: 7524_3


Testing:  23%|██▎       | 70/300 [00:34<01:50,  2.08it/s]

Correct: 7528_3 == Predicted: 7530_2


Testing:  24%|██▎       | 71/300 [00:35<01:49,  2.09it/s]

Correct: 7569_3 == Predicted: 7551_3


Testing:  24%|██▍       | 72/300 [00:35<01:48,  2.10it/s]

Correct: 7518_3 == Predicted: 7535_1


Testing:  24%|██▍       | 73/300 [00:35<01:47,  2.11it/s]

Correct: 7570_3 == Predicted: 7569_5


Testing:  25%|██▍       | 74/300 [00:36<01:47,  2.11it/s]

Correct: 7570_3 == Predicted: 7523_2


Testing:  25%|██▌       | 75/300 [00:36<01:50,  2.03it/s]

Correct: 7528_3 == Predicted: 7511_4





KeyboardInterrupt: 