In [1]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_scheduler,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding
)
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import random
import numpy as np
from huggingface_hub import login
from datasets import load_dataset, DatasetDict
import os

login(token="<<token here>>")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MAX_ARTICLE_CHAR_LENGTH = 800
MAX_TOKEN_LENGTH = 512
BATCH_SIZE = 1
LEARNING_RATE = 2e-5
NUM_EPOCHS = 2
OUTPUT_DIR = "./llama_race_finetuned_streamlined"

MODEL_NAME = "meta-llama/Llama-3.2-1B"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#os.makedirs(OUTPUT_DIR, exist_ok=True)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token 

In [5]:
# --- 4. Cargar y Preparar Datos ---
ds_full = load_dataset("ehovy/race", "high", trust_remote_code=True)

def transform_and_tokenize_example(example):
    options_str = "\n".join([f"{chr(65+i)}) {opt}" for i, opt in enumerate(example['options'])])
    prompt_text = (
        f"Context: {example['article']}\n\n"
        f"Question: {example['question']}\n\n"
        f"Options:\n{options_str}\n\n"
        f"Answer: {example['answer']}"
    )
    tokenized = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_TOKEN_LENGTH, 
        padding="max_length"
    )
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

processed_ds = DatasetDict()
for split, data in ds_full.items():
    filtered_data = data.filter(
        lambda example: len(example['article']) < MAX_ARTICLE_CHAR_LENGTH,
        desc=f"Filtrando artículos en {split}"
    )

    mapped_data = filtered_data.map(
        transform_and_tokenize_example,
        batched=False,
        remove_columns=filtered_data.column_names,
        desc=f"Tokenizando {split}"
    )
    processed_ds[split] = mapped_data
    processed_ds[split].set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Filtrando artículos en test: 100%|██████████| 3498/3498 [00:00<00:00, 48424.57 examples/s]
Tokenizando test: 100%|██████████| 56/56 [00:00<00:00, 1004.06 examples/s]
Filtrando artículos en train: 100%|██████████| 62445/62445 [00:01<00:00, 53880.26 examples/s]
Tokenizando train: 100%|██████████| 803/803 [00:00<00:00, 1175.04 examples/s]
Filtrando artículos en validation: 100%|██████████| 3451/3451 [00:00<00:00, 50454.34 examples/s]
Tokenizando validation: 100%|██████████| 34/34 [00:00<00:00, 1017.43 examples/s]


In [14]:
print(processed_ds["train"][0])

{'input_ids': tensor([128000,   2014,     25,    578,   3805,   3552,    434,    256,    574,
           304,    264,   2678,   9979,    520,    279,   1203,    315,    279,
         11277,     11,  20646,    279,  25485,    369,  16163,     11,    994,
           264,   2697,   2362,  17240,   3782,    323,  12570,    311,   1077,
            11,    330,  13191,    499,   4587,   3371,    757,   1359,   1364,
          4691,     11,    330,   2940,    374,    279,  23628,      6,  30583,
          5382,    256,    304,    279,  11277,  48469,  58841,     11,  13088,
           309,   1359,   1071,    279,   3805,   3552,    434,    323,  31645,
            13,    330,   2181,    374,   1314,    520,    279,   1023,    842,
           315,    279,  11277,   4521,    266,    279,   4156,  10246,    791,
          2697,  17240,   4024,   2288,   3117,     13,   3005,  15203,    682,
           279,   1648,    311,    279,   4156,    315,    279,  11277,     11,
          9107,    279,   

In [None]:

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

train_dataloader = DataLoader(
    processed_ds.get("train"),
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator
)

eval_dataloader = DataLoader(
    processed_ds.get("validation"),
    batch_size=BATCH_SIZE,
    collate_fn=data_collator
)

In [6]:
import math

class LoraLinear(torch.nn.Module):
    def __init__(self, linear_layer, alpha = 1, r = 1):
        super().__init__()
        self.linear_layer = linear_layer.to(torch.float32) # Se cambia el tipo de la capa a float32 para evitar errores durante el entrenamiento
        self.r = r
        fan_in = self.linear_layer.in_features
        fan_out = self.linear_layer.out_features
        self.lora_A = torch.nn.Parameter(torch.zeros((fan_in, r), device=model.device))
        self.lora_B = torch.nn.Parameter(torch.zeros((r, fan_out), device=model.device))
        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.linear_layer.weight.requires_grad = False

    def train(self, mode=True):
        self.training = mode
        if not mode:
            self.merged_weight = (self.linear_layer.weight.transpose(0,1) + self.lora_A @ self.lora_B).to(torch.float16) # Se cambia el tipo de la matriz a float16

    def forward(self, x):
        if self.training:
            x = x.to(torch.float32)
            output = self.linear_layer(x)
            output += x @ self.lora_A @ self.lora_B
            output = output.to(torch.float16)
        else:
            output = x @ self.merged_weight
        return output

In [7]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    trust_remote_code=True
)

In [8]:
for param in model.parameters():
    param.requires_grad = False

In [9]:
# Reemplazamos las capas lineales del mecanismo de atención por capas LoRA
for layer in model.model.layers:
    if hasattr(layer, 'self_attn'):
        layer.self_attn.q_proj = LoraLinear(layer.self_attn.q_proj, r=16)
        layer.self_attn.k_proj = LoraLinear(layer.self_attn.k_proj, r=16)
        layer.self_attn.v_proj = LoraLinear(layer.self_attn.v_proj, r=16)
        layer.self_attn.o_proj = LoraLinear(layer.self_attn.o_proj, r=16)

In [10]:
params_without_lora = 0
params_with_lora = 0
for name, param in model.named_parameters():
    if 'self_attn' in name and 'linear_layer' in name:
        params_without_lora += param.numel()
    if param.requires_grad:
        params_with_lora += param.numel()
    
print(f'Parámetros sin LoRA: {params_without_lora:,} || Parámetros con LoRA: {params_with_lora:,} || Porcentaje de parámetros con LoRA: {100 * params_with_lora / (params_without_lora + params_with_lora):.2f}%')

Parámetros sin LoRA: 167,772,160 || Parámetros con LoRA: 3,407,872 || Porcentaje de parámetros con LoRA: 1.99%


In [11]:
model.config.pad_token_id = tokenizer.pad_token_id
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [12]:
print("Iniciando entrenamiento...")
num_training_steps = NUM_EPOCHS * len(train_dataloader)
progress_bar = tqdm(range(num_training_steps))

for epoch in range(NUM_EPOCHS):
    model.train()
    total_train_loss = 0
    print(f"\n--- Época {epoch + 1}/{NUM_EPOCHS} ---")

    for batch_idx, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_train_loss += loss.item()
        progress_bar.update(1)
        progress_bar.set_description(f"Época {epoch+1}, Batch {batch_idx+1}, Loss: {loss.item():.4f}")

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Fin de Época {epoch + 1}: Pérdida de Entrenamiento Promedio = {avg_train_loss:.4f}")

    model.eval()
    total_eval_loss = 0
    print(f"\nEvaluando al final de la época {epoch + 1}...")
    with torch.no_grad():
        for eval_batch in tqdm(eval_dataloader, desc="Evaluación"):
            eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
            outputs = model(**eval_batch)
            total_eval_loss += outputs.loss.item()
    avg_eval_loss = total_eval_loss / len(eval_dataloader)
    print(f"Fin de Época {epoch + 1}: Pérdida de Validación Promedio = {avg_eval_loss:.4f}")
    
progress_bar.close()
print("Entrenamiento completado.")

Iniciando entrenamiento...


  0%|          | 0/1606 [00:00<?, ?it/s]


--- Época 1/2 ---


Época 1, Batch 199, Loss: 2.0674:  12%|█▏        | 199/1606 [00:20<02:24,  9.71it/s]

KeyboardInterrupt: 

In [14]:
all_predictions = []
all_true_answers = []
total_correct = 0
total_examples = 0

ds_test = ds_full.get('test')
filtered_data = ds_test.filter(
    lambda example: len(example['article']) < MAX_ARTICLE_CHAR_LENGTH,
    desc=f"Filtrando artículos en {split}"
)
with torch.no_grad(): 
    for example in filtered_data:
        options_str = "\n".join([f"{chr(65+i)}) {opt}" for i, opt in enumerate(example['options'])])
        prompt = (
                f"Context: {example['article']}\n\n"
                f"Questions: {example['question']}\n\n"
                f"Options:\n{options_str}\n\n"
                f"Answer:"
            )
        current_true_answer = example['answer']
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=2, # Suficiente para generar la letra de la opción
                pad_token_id=tokenizer.pad_token_id
            )
        generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
        predicted_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        all_predictions.append(predicted_text)
        all_true_answers.append(current_true_answer)
        
        if predicted_text == current_true_answer:
            total_correct += 1
        total_examples += 1

In [11]:
accuracy = total_correct / total_examples
print(f"\n--- Resultados de la Evaluación en Test ---")
print(f"Ejemplos Totales: {total_examples}")
print(f"Predicciones Correctas: {total_correct}")
print(f"Exactitud (Accuracy): {accuracy * 100:.2f}%")


--- Resultados de la Evaluación en Test ---
Ejemplos Totales: 56
Predicciones Correctas: 19
Exactitud (Accuracy): 33.93%
