# Autoregresión con RNN (“Mini-GPT”) en PyTorch

Este notebook ilustra un modelo autoregresivo básico: toma un contexto de n palabras, predice la siguiente, la añade al contexto y repite el proceso para generar texto.

## 1. Instalación

In [None]:
# Install the Hugging Face `datasets` and `tokenizers` packages (quiet output).
!pip install datasets tokenizers --quiet
# Install torch, torchvision and torchaudio from the PyTorch `cu128` wheel index.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --quiet

## 2. Preparación del corpus

In [None]:
import requests
from collections import Counter

# Download the Tiny Shakespeare corpus. The timeout keeps the cell from hanging
# forever on a stalled connection, and raise_for_status() stops an HTTP error
# page from being silently tokenized as if it were the corpus.
response = requests.get(
    "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
    timeout=30,
)
response.raise_for_status()
text = response.text.lower()

# Whitespace tokenization: each whitespace-separated chunk counts as one word.
words = text.split()

# Vocabulary = the 5000 most frequent words; a word's id is its position in
# that frequency-ordered list.
vocab_list = [w for w, _ in Counter(words).most_common(5000)]
vocab = {w: i for i, w in enumerate(vocab_list)}

# All out-of-vocabulary words share one extra id placed right after the vocab.
unk = len(vocab_list)
data = [vocab.get(w, unk) for w in words]
print(f"Corpus length: {len(words)} words, Vocab size: {len(vocab_list)+1}")

In [None]:
from datasets import load_dataset
import torch
from collections import Counter

# Fetch WikiText-2 (raw variant) and take the text column of the train split.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
texts = dataset["train"]["text"]

# Lowercase every line and split it on whitespace. Blank or whitespace-only
# lines produce no tokens from split(), so they contribute nothing.
words = []
for text in texts:
    words += text.lower().split()

# Vocabulary: the 8000 most frequent words, with ids assigned in frequency order.
vocab_list = [entry[0] for entry in Counter(words).most_common(8000)]
vocab = dict(zip(vocab_list, range(len(vocab_list))))

# Out-of-vocabulary words all map to a single id placed right after the vocab.
unk = len(vocab_list)
data = [vocab.get(word, unk) for word in words]
print(f"Corpus length: {len(words)} words, Vocab size: {len(vocab_list)+1}")

## 3. Dataset y DataLoader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class NextWordDataset(Dataset):
    """Sliding-window dataset for next-word prediction.

    Sample ``idx`` is the pair ``(x, y)``: ``x`` holds ``seq_len`` consecutive
    token ids starting at position ``idx`` and ``y`` is the same window shifted
    one position to the right, so ``y[t]`` is the token that follows ``x[t]``.

    Args:
        data: Sequence of integer token ids (anything sliceable with ``len``).
        seq_len: Number of tokens in each input window.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, data, seq_len=15):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.data, self.seq_len = data, seq_len

    def __len__(self):
        # Clamp at zero: a corpus shorter than seq_len has no full window, and
        # a negative value would make len() itself raise.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` tensors for window ``idx``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError`` instead of silently returning truncated tensors; this
        is also what lets plain ``for x, y in dataset`` iteration terminate.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of size {size}")
        x = torch.tensor(self.data[idx : idx+self.seq_len])
        y = torch.tensor(self.data[idx+1 : idx+self.seq_len+1])
        return x, y

# Wrap the encoded corpus in 15-token windows and serve them through a
# shuffled loader in mini-batches of 128.
ds = NextWordDataset(data=data, seq_len=15)
dl = DataLoader(dataset=ds, shuffle=True, batch_size=128)
print(f"Dataset size: {len(ds)} samples")

## 4. Definición de los modelos (RNN y LSTM)

In [None]:
import torch.nn as nn

class MiniRNN(nn.Module):
    """Two-layer RNN language model: embedding -> RNN -> linear output layer.

    Args:
        vocab_size: Number of in-vocabulary words. The embedding table and the
            output layer both have ``vocab_size + 1`` entries, leaving room for
            one extra id (``vocab_size``) beyond the vocabulary.
        emb_dim: Embedding dimension (default raised from 32 to 128).
        hid_dim: Hidden-state dimension (default raised from 64 to 256).
    """

    def __init__(self, vocab_size, emb_dim=128, hid_dim=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size+1, emb_dim)
        # batch_first=True is required because forward() feeds (B, L, E)
        # tensors. With the default (batch_first=False) dim 0 is read as time,
        # so the recurrence would run across the samples of a batch instead of
        # along each sequence. Two stacked layers add capacity.
        self.rnn = nn.RNN(emb_dim, hid_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hid_dim, vocab_size+1)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, L) to logits of shape (B, L, V)."""
        e = self.emb(x)           # (B, L, E)
        out, _ = self.rnn(e)      # (B, L, H)
        return self.fc(out)       # (B, L, V)

# Print the module summary of a small MiniRNN (emb_dim=32, hid_dim=64).
print(MiniRNN(len(vocab_list), 32, 64))

# 1. Use an LSTM instead of a plain RNN
class MiniLSTM(nn.Module):
    """Three-layer LSTM language model: embedding -> LSTM -> linear output layer.

    Args:
        vocab_size: Number of in-vocabulary words. The embedding table and the
            output layer both have ``vocab_size + 1`` entries, leaving room for
            one extra id (``vocab_size``) beyond the vocabulary.
        emb_dim: Embedding dimension.
        hid_dim: Hidden-state dimension of every LSTM layer.
    """

    def __init__(self, vocab_size, emb_dim=256, hid_dim=512):
        super().__init__()
        self.emb = nn.Embedding(vocab_size+1, emb_dim)
        # batch_first=True is required because forward() feeds (B, L, E)
        # tensors. With the default (batch_first=False) dim 0 is read as time,
        # so the recurrence would run across the samples of a batch instead of
        # along each sequence. Dropout is applied between the stacked layers.
        self.lstm = nn.LSTM(emb_dim, hid_dim,
                            num_layers=3,
                            dropout=0.2,
                            batch_first=True)
        self.fc = nn.Linear(hid_dim, vocab_size+1)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, L) to logits of shape (B, L, V)."""
        e = self.emb(x)           # (B, L, E)
        out, _ = self.lstm(e)     # (B, L, H)
        return self.fc(out)       # (B, L, V)

## 5. Entrenamiento (hasta 20 épocas, con early stopping)

In [None]:
import torch
# Ask PyTorch whether a CUDA device is usable; as the last expression of the
# cell, the boolean result is what the notebook displays.
torch.cuda.is_available()

In [None]:
import torch.optim as optim

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MiniLSTM(len(vocab_list), emb_dim=256, hid_dim=512).to(device)
opt = optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

# Early-stopping bookkeeping: stop once the mean training loss has failed to
# improve for `patience` consecutive epochs (at most `epochs` epochs overall).
epochs = 20
best_loss = float('inf')
patience = 2
no_improve = 0

for epoch in range(epochs):
    model.train()
    total = 0
    for x, y in dl:
        x = x.to(device)
        y = y.to(device)
        opt.zero_grad()
        logits = model(x)
        # Collapse the batch and sequence dimensions so every position is
        # scored as an independent classification over the vocabulary.
        loss = loss_fn(logits.view(-1, logits.shape[-1]), y.view(-1))
        loss.backward()
        opt.step()
        total += loss.item()

    # Mean loss over the batches of this epoch.
    avg_loss = total/len(dl)
    print(f"Época {epoch+1}, pérdida: {avg_loss:.3f}")

    if avg_loss < best_loss:
        # New best: remember it and reset the stagnation counter.
        best_loss, no_improve = avg_loss, 0
        continue

    no_improve += 1
    if no_improve >= patience:
        print("Early stopping!")
        break

## 6. Función de generación

In [None]:
def generate(model, seed_words, length=20, context_len=5):
    """Greedily extend ``seed_words`` by ``length`` words.

    At every step the last ``context_len`` token ids are fed to ``model`` as a
    (1, context) batch and the highest-scoring id at the final position is
    appended. Relies on the module-level ``vocab``, ``unk``, ``vocab_list``
    and ``device``.

    Args:
        model: Module mapping token ids (B, L) to logits (B, L, V).
        seed_words: Whitespace-separated prompt; unknown words map to ``unk``.
        length: Number of words to generate.
        context_len: Number of most recent tokens fed to the model each step.

    Returns:
        The prompt words followed by the generated words, joined by spaces.
        Ids outside ``vocab_list`` are rendered as ``"<unk>"``.

    Raises:
        ValueError: If ``context_len`` is smaller than 1, or if words are
            requested from an empty prompt.
    """
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")
    out_words = seed_words.split()
    # Track ids directly instead of re-encoding the output words every step.
    ids = [vocab.get(w, unk) for w in out_words]
    if length > 0 and not ids:
        raise ValueError("seed_words must contain at least one word")

    model.eval()
    # Inference only: no_grad avoids building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(length):
            ctx = torch.tensor([ids[-context_len:]], device=device)
            logits = model(ctx)
            next_id = logits[0, -1].argmax().item()
            ids.append(next_id)
            out_words.append(vocab_list[next_id] if next_id < len(vocab_list) else "<unk>")
    return " ".join(out_words)

# Demo: continue a short prompt with 15 generated words.
print(generate(model, "to be or not to be", length=15))

In [None]:
def generate(model, seed_words, length=30, temperature=0.8, context_len=15):
    """Extend ``seed_words`` by ``length`` words using temperature sampling.

    At every step the last ``context_len`` token ids are fed to ``model`` as a
    (1, context) batch; the logits at the final position are divided by
    ``temperature``, turned into probabilities with softmax, and one id is
    sampled. Relies on the module-level ``vocab``, ``unk``, ``vocab_list``
    and ``device``.

    Args:
        model: Module mapping token ids (B, L) to logits (B, L, V).
        seed_words: Whitespace-separated prompt; unknown words map to ``unk``.
        length: Number of words to generate.
        temperature: Softmax temperature. Values below 1 sharpen the
            distribution, values above 1 flatten it, and exactly 0 selects
            the arg-max (greedy decoding) instead of dividing by zero.
        context_len: Number of most recent tokens fed to the model each step.

    Returns:
        The prompt words followed by the generated words, joined by spaces.
        Ids outside ``vocab_list`` are rendered as ``"<unk>"``.

    Raises:
        ValueError: If ``temperature`` is negative, ``context_len`` is smaller
            than 1, or words are requested from an empty prompt.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")
    out_words = seed_words.split()
    # Track ids directly instead of re-encoding the output words every step.
    ids = [vocab.get(w, unk) for w in out_words]
    if length > 0 and not ids:
        raise ValueError("seed_words must contain at least one word")

    model.eval()
    # Inference only: no_grad avoids building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(length):
            ctx = torch.tensor([ids[-context_len:]], device=device)
            last_logits = model(ctx)[0, -1]
            if temperature == 0:
                # Limit of temperature -> 0: always take the most likely id.
                next_id = last_logits.argmax().item()
            else:
                probs = torch.softmax(last_logits / temperature, dim=0)
                next_id = torch.multinomial(probs, 1).item()
            ids.append(next_id)
            out_words.append(vocab_list[next_id] if next_id < len(vocab_list) else "<unk>")

    return " ".join(out_words)

# Try several prompts at different sampling temperatures.
for label, prompt, temp in (
    ("T=0.7:", "the government announced that", 0.7),
    ("T=1.0:", "in the middle of", 1.0),
    ("T=1.2:", "scientists have discovered", 1.2),
    ("T=1.2:", "to be or not to be", 1.2),
):
    print(label, generate(model, prompt, temperature=temp))

## 7. Experimenta en clase

1. Cambia `seq_len` en el dataset (por ejemplo, 3 o 10) y observa cómo cambia la generación.
2. Prueba muestreo en lugar de `argmax` (sample sobre softmax).
3. Ajusta la tasa de aprendizaje o `emb_dim`/`hid_dim` y comenta brevemente las diferencias.