<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/TRANSFOMER_APRIL2025_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch tqdm -q

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import math
import numpy as np
import os
import json
from tqdm.notebook import tqdm
from warnings import filterwarnings
import torch.optim as optim

# Silence every Python warning for the rest of the session (keeps notebook output short,
# but also hides deprecation notices from torch/transformers).
filterwarnings("ignore")

# --- Configuration ---
tokenizer_name = "gpt2"  # Or a more suitable general-purpose tokenizer
max_input_len = 128  # token length every source text is padded/truncated to
max_output_len = 128  # token length every BOS+target+EOS text is padded/truncated to
batch_size = 4  # examples per DataLoader batch
learning_rate = 1e-4  # AdamW step size
num_epochs = 20  # passes over the training data
emb_size = 256  # token-embedding width, also the transformer's d_model
nhead = 8  # attention heads per transformer layer
num_encoder_layers = 4
num_decoder_layers = 4
dim_feedforward = 1024  # hidden width of each layer's feed-forward sub-block
dropout = 0.1  # shared by the positional encodings and the transformer
best_model_save_path = "./best_general_transformer.pth"  # where the trained state_dict is written
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when present

# --- 1. Tokenizer Setup ---
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Some tokenizers ship without a padding token; register one so that
# padding="max_length" works and padded positions get their own id.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Vocabulary sizes are read *after* the pad token may have been added, so the
# embedding tables cover the new id as well.
src_vocab_size = len(tokenizer)
tgt_vocab_size = len(tokenizer)  # Output vocabulary is the same
pad_token_id = tokenizer.pad_token_id
# Fall back to the CLS/SEP ids when the tokenizer defines no BOS/EOS token.
# NOTE(review): the visible code below reads tokenizer.bos_token_id / eos_token_id
# directly rather than these two names — confirm whether they are still needed.
bos_token_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id
eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id

# --- 2. Dataset Preparation ---
class SimpleTextDataset(Dataset):
    """Map-style dataset over ``{"input": str, "output": str}`` records.

    Each record is tokenized into fixed-length id tensors suitable for
    teacher-forced training of an encoder-decoder model.

    Args:
        data: Indexable collection of dicts with "input" and "output" keys.
        tokenizer: Object exposing ``encode_plus``, ``bos_token``,
            ``eos_token`` and ``pad_token_id``.
        max_input_len: Length the source text is padded/truncated to.
        max_output_len: Length the BOS + target + EOS text is
            padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_input_len, max_output_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to exactly ``max_length`` ids; return 1-D (ids, mask)."""
        encoded = self.tokenizer.encode_plus(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        # encode_plus returns a leading batch axis of size 1; drop it.
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        """Return the tensors for record ``idx``.

        Returns:
            dict with
              * "input_ids" / "attention_mask": source ids and mask
                (length ``max_input_len``),
              * "output_ids": decoder input, i.e. the target sequence minus
                its final position (length ``max_output_len - 1``),
              * "output_ids_y": the target sequence shifted left by one,
                with pad wherever the decoder input is pad,
              * "output_attention_mask": target mask minus its final position.
        """
        record = self.data[idx]
        source_text = record["input"]
        target_text = record["output"]
        pad_id = self.tokenizer.pad_token_id

        input_ids, attention_mask = self._encode(source_text, self.max_input_len)

        # The target is wrapped in BOS ... EOS before tokenization.
        wrapped_target = f"{self.tokenizer.bos_token}{target_text}{self.tokenizer.eos_token}"
        output_ids, output_attention_mask = self._encode(wrapped_target, self.max_output_len)

        # Teacher forcing: position i of the decoder input predicts position i + 1.
        decoder_input = output_ids[:-1]
        next_tokens = output_ids[1:]
        # Force the label to pad wherever the decoder is fed padding.
        output_ids_y = next_tokens.masked_fill(decoder_input == pad_id, pad_id)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_ids": decoder_input,
            "output_ids_y": output_ids_y,
            "output_attention_mask": output_attention_mask[:-1],
        }

# Toy training set: a handful of input/output text pairs, mixing JSON-style
# slot extraction with free-form answers. It is only a *conceptual* sample —
# far more (and more diverse) data is needed to train anything useful.
data = [
    {"input": "I want to visit Paris tomorrow morning.", "output": "{\"location\": \"Paris\", \"time\": \"tomorrow morning\"}"},
    {"input": "Explore London in the evening.", "output": "{\"location\": \"London\", \"time\": \"evening\"}"},
    {"input": "See New York today.", "output": "{\"location\": \"New York\", \"time\": \"today\"}"},
    {"input": "What is the weather like?", "output": "The weather is currently sunny."},
    # ... MANY more diverse examples ...
    {"input": "Translate 'hello' to French", "output": "Bonjour"},
    {"input": "Summarize this document...", "output": "...summary..."},
    {"input": "Write a short story about a cat", "output": "...story..."},

]

# Wrap the records in the dataset and batch them; shuffle=True reorders the
# examples every epoch.
dataset = SimpleTextDataset(data, tokenizer, max_input_len, max_output_len)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- 3. Model Definition ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A ``(1, max_len, d_model)`` table is precomputed once and stored as the
    non-trainable buffer ``pe``: even feature columns hold sines, odd columns
    hold cosines, with geometrically spaced frequencies.

    Args:
        d_model: Feature width of the embeddings (even or odd).
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions in the table; inputs to ``forward`` must
            not have more than this many positions along dim 1.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim div_term to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading axis of size 1 lets the table broadcast over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose dim 1 indexes positions and whose last dim is
                ``d_model``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class GeneralTransformer(nn.Module):
    """Text-to-text encoder-decoder built around ``nn.Transformer``.

    Source and target ids are embedded with separate tables, given sinusoidal
    position information, run through the transformer encoder and decoder
    (``batch_first=True``), and projected to target-vocabulary logits.

    Args:
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        emb_size: Embedding width; also the transformer's ``d_model``.
        nhead: Attention heads per layer.
        src_vocab_size: Rows in the source embedding table.
        tgt_vocab_size: Rows in the target embedding table and width of the
            output logits.
        dim_feedforward: Hidden width of each feed-forward sub-block.
        dropout: Dropout probability for the positional encodings and the
            transformer.
        max_text_len: Positions available to the source positional encoding.
        max_output_len: Positions available to the target positional encoding.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, nhead: int, src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 2048, dropout: float = 0.1,
                 max_text_len: int = 128, max_output_len: int = 128):
        super().__init__()
        self.emb_size = emb_size

        # Token embedding tables (source, then target).
        self.src_tok_emb = nn.Embedding(num_embeddings=src_vocab_size, embedding_dim=emb_size)
        self.tgt_tok_emb = nn.Embedding(num_embeddings=tgt_vocab_size, embedding_dim=emb_size)

        # Separate positional encodings so each side has its own length budget.
        self.pos_encoder = PositionalEncoding(emb_size, dropout=dropout, max_len=max_text_len)
        self.pos_encoder_dec = PositionalEncoding(emb_size, dropout=dropout, max_len=max_output_len)

        # batch_first=True: tensors are laid out (batch, sequence, feature).
        transformer_options = dict(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.Transformer(**transformer_options)

        # Projects decoder features to per-token vocabulary logits.
        self.output_linear = nn.Linear(in_features=emb_size, out_features=tgt_vocab_size)

    def forward(self, src_input_ids: torch.Tensor, tgt_input_ids: torch.Tensor,
                src_mask: torch.Tensor, tgt_mask: torch.Tensor,
                src_padding_mask: torch.Tensor, tgt_padding_mask: torch.Tensor,
                memory_key_padding_mask: torch.Tensor):
        """Compute target-vocabulary logits for every decoder position.

        Args:
            src_input_ids: Source token ids.
            tgt_input_ids: Decoder-input token ids.
            src_mask: Accepted for signature compatibility; this method never
                reads it.
            tgt_mask: Attention mask handed to the decoder as ``tgt_mask``.
            src_padding_mask: Passed to the encoder as
                ``src_key_padding_mask``.
            tgt_padding_mask: Passed to the decoder as
                ``tgt_key_padding_mask``.
            memory_key_padding_mask: Passed to the decoder for the encoder
                output.

        Returns:
            Logits with ``tgt_vocab_size`` entries in the last dimension.
        """
        encoder_in = self.pos_encoder(self.src_tok_emb(src_input_ids))
        decoder_in = self.pos_encoder_dec(self.tgt_tok_emb(tgt_input_ids))

        encoded = self.transformer.encoder(
            encoder_in,
            src_key_padding_mask=src_padding_mask,
        )
        decoded = self.transformer.decoder(
            decoder_in,
            encoded,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        return self.output_linear(decoded)

def create_mask(src_input_ids, tgt_input_ids, pad_idx, device):
    """Build the attention masks for one batch.

    Args:
        src_input_ids: Source ids; only compared against ``pad_idx``.
        tgt_input_ids: Decoder-input ids; ``shape[1]`` is taken as the target
            sequence length.
        pad_idx: Token id that marks padding.
        device: Device on which the causal mask is created.

    Returns:
        Tuple ``(tgt_mask, src_padding_mask, tgt_padding_mask,
        memory_key_padding_mask)``:
          * ``tgt_mask``: square causal mask from
            ``generate_square_subsequent_mask``,
          * ``src_padding_mask`` / ``tgt_padding_mask``: boolean tensors that
            are True where the ids equal ``pad_idx``,
          * ``memory_key_padding_mask``: the same tensor object as
            ``src_padding_mask``.
    """
    tgt_seq_len = tgt_input_ids.shape[1]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
    src_padding_mask = (src_input_ids == pad_idx)
    tgt_padding_mask = (tgt_input_ids == pad_idx)
    # The decoder attends over encoder outputs, so it reuses the source padding mask.
    memory_key_padding_mask = src_padding_mask
    return tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask

def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
    """Return an ``sz`` x ``sz`` float mask that hides future positions.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``.
    """
    all_blocked = torch.full((sz, sz), float('-inf'), device=device)
    # triu(diagonal=1) keeps only the strict upper triangle and zeroes the rest.
    return all_blocked.triu(diagonal=1)

# --- 4. Model, Loss, Optimizer ---
# Instantiate the network from the configuration values and move it to the
# selected device.
model = GeneralTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    emb_size=emb_size,
    nhead=nhead,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    max_text_len=max_input_len,
    max_output_len=max_output_len
).to(device)
# ignore_index: target positions equal to the pad id contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# --- 5. Training Loop ---
# The number of epochs comes from `num_epochs` in the configuration section;
# it is deliberately not reassigned here so that the configured value is honoured.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        input_ids = batch["input_ids"].to(device)
        output_ids = batch["output_ids"].to(device)  # decoder input
        output_ids_y = batch["output_ids_y"].to(device)  # shifted targets, prepared by the dataset

        # All-False placeholder; it is passed through to the model's forward signature.
        src_mask = torch.zeros((input_ids.size(0), input_ids.size(1)), device=device).bool()
        tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = create_mask(input_ids, output_ids, tokenizer.pad_token_id, device)

        optimizer.zero_grad()
        predicted_tokens = model(input_ids, output_ids, src_mask, tgt_mask,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        # Flatten (batch, seq, vocab) logits and (batch, seq) targets so that
        # cross-entropy is computed per token.
        loss = loss_fn(predicted_tokens.reshape(-1, predicted_tokens.size(-1)), output_ids_y.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# --- 6. Save Model ---
# Only the state_dict (parameters and buffers) is written, not the module object,
# so loading requires constructing the same architecture first.
torch.save(model.state_dict(), best_model_save_path)
print(f"Model saved to: {best_model_save_path}")

# --- 7. Inference (Example) ---
def generate_response(model, tokenizer, input_text, device, max_output_len=128):
    """Greedily decode a response to ``input_text``.

    The source text is tokenized once (padded/truncated to the module-level
    ``max_input_len``). Decoding starts from the tokenizer's BOS id and
    repeatedly appends the arg-max token until EOS is produced or
    ``max_output_len`` tokens have been generated.

    Args:
        model: Network called as ``model(src_ids, tgt_ids, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``
            and returning per-position vocabulary logits. It is switched to
            eval mode and left that way.
        tokenizer: Tokenizer providing ``encode_plus``, ``decode``,
            ``bos_token_id``, ``eos_token_id`` and ``pad_token_id``.
        input_text: A single source string.
        device: Device on which tensors are created.
        max_output_len: Maximum number of decoding steps.

    Returns:
        The decoded text with special tokens stripped.
    """
    model.eval()
    input_tokens = tokenizer.encode_plus(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=max_input_len,
        return_tensors="pt"
    ).to(device)
    input_ids = input_tokens["input_ids"]

    # Source-side masks do not change between decoding steps; build them once.
    src_mask = torch.zeros((input_ids.size(0), input_ids.size(1)), device=device).bool()
    src_padding_mask = (input_ids == tokenizer.pad_token_id)
    memory_key_padding_mask = src_padding_mask

    # Start with BOS token
    output_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(max_output_len):
            tgt_mask = generate_square_subsequent_mask(output_ids.size(1), device)
            tgt_padding_mask = (output_ids == tokenizer.pad_token_id)
            predicted_tokens = model(input_ids, output_ids, src_mask, tgt_mask,
                                    src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
            # Greedy choice from the logits at the last decoder position.
            next_token_id = torch.argmax(predicted_tokens[:, -1, :], dim=-1)
            # (batch,) -> (batch, 1): append along the sequence axis.
            output_ids = torch.cat([output_ids, next_token_id.unsqueeze(-1)], dim=1)
            # A single text is decoded, so there is exactly one new token to inspect.
            if next_token_id.item() == tokenizer.eos_token_id:
                break
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text

# Load the trained model
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine.
model.load_state_dict(torch.load(best_model_save_path, map_location=device))
model.eval()

# Example Usage
queries = [
    "I want to visit Paris tomorrow morning.",
    "Explore London in the evening.",
    "See New York today.",
    "What is the weather like?"
]

# Decode each query greedily (at most 64 generated tokens) and print the result.
for query in queries:
    response = generate_response(model, tokenizer, query, device, max_output_len=64)
    print(f"Query: {query}\nResponse: {response}\n")