<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/TRANSFOMER_APRIL2025_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch tqdm -q

In [24]:
!pip install transformers torch tqdm nltk -q
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
import math
import numpy as np
import os
import json
from tqdm.notebook import tqdm
from warnings import filterwarnings
import torch.optim as optim
import random
import nltk
from nltk.corpus import wordnet
from sklearn.model_selection import train_test_split

# Fetch the WordNet corpus through NLTK. `wordnet` is imported above, but no
# use of it is visible in this part of the file.
nltk.download('wordnet')
# Install an "ignore" filter so Python warnings are suppressed from here on.
filterwarnings("ignore")

# --- Configuration ---
# Tokenizer and sequence lengths
tokenizer_name = "gpt2"
max_input_len = 128
max_output_len = 128  # Not used in this version

# Optimisation schedule
batch_size = 4
learning_rate = 5e-5
num_epochs = 10

# Model dimensions
emb_size = 512
nhead = 8
num_encoder_layers = 2
num_decoder_layers = 2  # Not used in this version
dim_feedforward = 512
dropout = 0.1

# Checkpoint location for the best model seen during training
best_model_save_path = "./best_general_transformer.pth"

# Use the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- 1. Tokenizer Setup ---
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Register a dedicated padding token when the loaded tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Sizes are read after the optional pad token has been added.
src_vocab_size = len(tokenizer)
tgt_vocab_size = len(tokenizer)  # Not used in this version
pad_token_id = tokenizer.pad_token_id

# Beginning-of-sequence id, falling back to the CLS id when BOS is undefined.
bos_token_id = tokenizer.bos_token_id
if bos_token_id is None:
    bos_token_id = tokenizer.cls_token_id

# End-of-sequence id, falling back to the SEP id when EOS is undefined.
eos_token_id = tokenizer.eos_token_id
if eos_token_id is None:
    eos_token_id = tokenizer.sep_token_id

# --- 2. Dataset Preparation ---
class SimpleTextDataset(Dataset):
    """Multiple-choice QA records served as token-id tensors.

    Each record is a mapping with "question", "choices" and "answer" keys and
    an optional "context" key. The prompt is the context, question and
    formatted choices joined by single spaces; the answer label is encoded
    separately with a maximum length of one token.
    """

    def __init__(self, data, tokenizer, max_input_len, augment=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.augment = augment  # Not used in this version

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def _encode(self, text, max_length):
        """Encode `text` padded/truncated to `max_length`, returning "pt" tensors."""
        return self.tokenizer.encode_plus(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the encoded prompt, its attention mask and the encoded answer label.

        The leading dimension of every encoded tensor is removed with
        ``squeeze(0)`` before it is returned.
        """
        record = self.data[idx]
        context = record.get("context", "")  # a missing context becomes an empty string
        question = record["question"]
        choices = record["choices"]
        answer_label = record["answer"]

        prompt = f"{context} {question} {self.format_choices(choices)}"

        encoded_prompt = self._encode(prompt, self.max_input_len)
        # The answer label is encoded with max_length=1.
        encoded_answer = self._encode(answer_label, 1)

        return {
            "input_ids": encoded_prompt["input_ids"].squeeze(0),
            "attention_mask": encoded_prompt["attention_mask"].squeeze(0),  # Not used in this version
            "output_ids": encoded_answer["input_ids"].squeeze(0),
        }

    def format_choices(self, choices):
        """Render choices as "LABEL: text " segments, in mapping order.

        Every segment, including the last, ends with a space.
        """
        return "".join(f"{label}: {choice_text} " for label, choice_text in choices.items())

# --- Data ---
# In-memory multiple-choice question set. Every record carries:
#   "id"       - record identifier ("q1" .. "q11")
#   "question" - the question text
#   "choices"  - mapping of choice label ("A"/"B"/"C") to choice text
#   "answer"   - label of the correct choice
#   "task"     - task tag, "multiple_choice" for every record here
# "context" is optional: q3, q5, q7, q9 and q11 omit the key entirely.
data = [
    {
        "id": "q1",
        "context": "Paris is known for its iconic Eiffel Tower and Louvre Museum.",
        "question": "What is the capital of France?",
        "choices": {
            "A": "Berlin",
            "B": "Paris",
            "C": "Rome"
        },
        "answer": "B",
        "task": "multiple_choice"
    },
    {
        "id": "q2",
        "context": "Mount Everest is the highest mountain above sea level.",
        "question": "Which mountain is the tallest in the world?",
        "choices": {
            "A": "K2",
            "B": "Mount Kilimanjaro",
            "C": "Mount Everest"
        },
        "answer": "C",
        "task": "multiple_choice"
    },
    {
        "id": "q3",
        "question": "What is the chemical symbol for gold?",
        "choices": {
            "A": "Ag",
            "B": "Au",
            "C": "Fe"
        },
        "answer": "B",
        "task": "multiple_choice"
    },
    {
        "id": "q4",
        "context": "The Amazon rainforest is the largest rainforest on Earth.",
        "question": "Where is the Amazon rainforest primarily located?",
        "choices": {
            "A": "Africa",
            "B": "South America",
            "C": "Asia"
        },
        "answer": "B",
        "task": "multiple_choice"
    },
    {
        "id": "q5",
        "question": "What is the smallest planet in our solar system?",
        "choices": {
            "A": "Mercury",
            "B": "Earth",
            "C": "Mars"
        },
        "answer": "A",
        "task": "multiple_choice"
    },
    {
        "id": "q6",
        "context": "Light travels faster than sound.",
        "question": "Which travels faster, light or sound?",
        "choices": {
            "A": "Sound",
            "B": "Light",
            "C": "They travel at the same speed"
        },
        "answer": "B",
        "task": "multiple_choice"
    },
    {
        "id": "q7",
        "question": "What is the largest ocean on Earth?",
        "choices": {
            "A": "Atlantic Ocean",
            "B": "Indian Ocean",
            "C": "Pacific Ocean"
        },
        "answer": "C",
        "task": "multiple_choice"
    },
    {
        "id": "q8",
        "context": "Shakespeare wrote many famous plays, including Hamlet and Romeo and Juliet.",
        "question": "Who wrote the play Hamlet?",
        "choices": {
            "A": "Charles Dickens",
            "B": "William Shakespeare",
            "C": "Jane Austen"
        },
        "answer": "B",
        "task": "multiple_choice"
    },
    {
        "id": "q9",
        "question": "What is the capital of Japan?",
        "choices": {
            "A": "Beijing",
            "B": "Seoul",
            "C": "Tokyo"
        },
        "answer": "C",
        "task": "multiple_choice"
    },
    {
        "id": "q10",
        "context": "The human heart is a vital organ that pumps blood.",
        "question": "What is the main function of the human heart?",
        "choices": {
            "A": "Digestion",
            "B": "Pumping blood",
            "C": "Respiration"
        },
        "answer": "B",
        "task": "multiple_choice"
    },
    {
        "id": "q11",
        "question": "What is the largest country in the world by land area?",
        "choices": {
            "A": "China",
            "B": "Russia",
            "C": "United States"
        },
        "answer": "B",
        "task": "multiple_choice"
    }
]



# --- Split Data ---
# Hold out 20% of the records for validation; the fixed seed keeps the split repeatable.
train_data, val_data = train_test_split(data, random_state=42, test_size=0.2)

# --- Create Datasets and DataLoaders ---
train_dataset, val_dataset = (
    SimpleTextDataset(split, tokenizer, max_input_len)
    for split in (train_data, val_data)
)
# Only the training loader shuffles; validation keeps dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# --- 3. Model Definition ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an input, then apply dropout.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as the
    non-trainable buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature dimension. Both even and odd values are
            supported.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the frequency vector is trimmed to match; for an even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :x.size(1), :])``.

        The table is sliced along its position axis using ``x.size(1)``, so
        dimension 1 of ``x`` is treated as the sequence dimension.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class GeneralTransformer(nn.Module):
    """Sequence classifier: embed tokens, add positions, encode, score from position 0.

    Args:
        num_encoder_layers: Number of encoder layers passed to ``nn.Transformer``.
        emb_size: Embedding width, also used as the transformer ``d_model``.
        nhead: Number of attention heads.
        src_vocab_size: Number of rows in the token embedding table.
        num_choices: Number of output scores produced per sequence.
        dim_feedforward: Feed-forward width passed to ``nn.Transformer``.
        dropout: Dropout used by the positional encoder and the transformer.
        max_text_len: Length of the positional-encoding table.
    """

    def __init__(self, num_encoder_layers: int, emb_size: int, nhead: int, src_vocab_size: int,
                 num_choices: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 max_text_len: int = 128):
        super().__init__()
        self.emb_size = emb_size
        self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
        self.pos_encoder = PositionalEncoding(emb_size, dropout=dropout, max_len=max_text_len)

        # NOTE(review): only `self.transformer.encoder` is used in forward();
        # nn.Transformer presumably also builds a decoder stack that is never
        # called - confirm, and consider nn.TransformerEncoder if checkpoint
        # key compatibility is not a concern.
        transformer_kwargs = dict(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.Transformer(**transformer_kwargs)
        self.output_linear = nn.Linear(emb_size, num_choices)

    def forward(self, src_input_ids: torch.Tensor, src_padding_mask: torch.Tensor):
        """Return the scores computed from the encoder output at position 0.

        Args:
            src_input_ids: Token ids fed to the embedding table.
            src_padding_mask: Mask forwarded to the encoder as
                ``src_key_padding_mask``.
        """
        token_embeddings = self.src_tok_emb(src_input_ids)
        positioned = self.pos_encoder(token_embeddings)
        memory = self.transformer.encoder(positioned, src_key_padding_mask=src_padding_mask)

        # Summarise each sequence by the encoder state at its first position.
        first_position_state = memory[:, 0, :]
        return self.output_linear(first_position_state)

def create_mask(src_input_ids, pad_idx, device):
    """Return the element-wise comparison ``src_input_ids == pad_idx``.

    Positions holding the padding id compare equal. `device` is accepted but
    not read by this function.
    """
    return src_input_ids == pad_idx

# --- 4. Model, Loss, Optimizer, Scheduler ---
model = GeneralTransformer(
    num_encoder_layers=num_encoder_layers,
    emb_size=emb_size,
    nhead=nhead,
    src_vocab_size=src_vocab_size,
    num_choices=3,  # Number of choices (A, B, C)
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    max_text_len=max_input_len
)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=0.01)

# The scheduler is stepped once per batch, so its horizon is counted in batches;
# the first 10% of those steps are warm-up.
steps_per_epoch = len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_epochs * steps_per_epoch),
    num_training_steps=num_epochs * steps_per_epoch,
)

# --- 5. Training Loop ---

# --- Helper Functions for Label Conversion ---
def choice_label_to_index(label):
    """Return the offset of the single character `label` from 'A' ('A' -> 0, 'B' -> 1, ...)."""
    base_code = ord('A')
    return ord(label) - base_code

def index_to_choice_label(index):
    """Return the character `index` code points after 'A' (0 -> 'A', 1 -> 'B', ...)."""
    code_point = index + ord('A')
    return chr(code_point)



best_val_loss = float('inf')
patience = 3
epochs_without_improvement = 0


def _labels_to_class_indices(label_tokens):
    """Decode each answer token id to its letter and map it to a class index tensor on `device`."""
    class_indices = [
        choice_label_to_index(tokenizer.decode(token.item()))
        for token in label_tokens
    ]
    return torch.tensor(class_indices, device=device)


for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    total_loss = 0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for batch in progress:
        input_ids = batch["input_ids"].to(device)
        output_ids = batch["output_ids"].to(device)
        src_padding_mask = create_mask(input_ids, tokenizer.pad_token_id, device)

        optimizer.zero_grad()
        predicted_labels = model(input_ids, src_padding_mask)

        # Turn the answer token ids into class indices (0, 1, 2); each row of
        # output_ids holds a single token id.
        output_ids = _labels_to_class_indices(output_ids)

        loss = loss_fn(predicted_labels, output_ids)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    # ---- Validation pass ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch["input_ids"].to(device)
            output_ids = batch["output_ids"].to(device)
            src_padding_mask = create_mask(input_ids, tokenizer.pad_token_id, device)

            predicted_labels = model(input_ids, src_padding_mask)

            # Same conversion as in training, applied to the first column.
            output_ids = _labels_to_class_indices(output_ids[:, 0])

            val_loss += loss_fn(predicted_labels, output_ids).item()

    avg_val_loss = val_loss / len(val_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}")

    # ---- Early stopping / checkpointing ----
    if not (avg_val_loss < best_val_loss):
        epochs_without_improvement += 1
    else:
        # New best validation loss: reset the counter and checkpoint the weights.
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), best_model_save_path)
        print(f"Model saved to: {best_model_save_path}")

    if epochs_without_improvement >= patience:
        print("Early stopping triggered.")
        break

# --- 6. Inference ---
def generate_response(model, tokenizer, input_text, device):
    """Predict the choice label for one formatted question string.

    The text is encoded to the module-level `max_input_len`, moved to `device`
    and scored by `model` in eval mode with gradients disabled.

    Args:
        model: Callable taking ``(input_ids, padding_mask)`` and returning scores.
        tokenizer: Tokenizer providing ``encode_plus`` and ``pad_token_id``.
        input_text: The already formatted prompt.
        device: Device the encoded inputs are moved to.

    Returns:
        The letter for the highest-scoring index (0 -> 'A', 1 -> 'B', ...).
    """
    model.eval()
    encoded = tokenizer.encode_plus(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=max_input_len,
        return_tensors="pt"
    )
    encoded = encoded.to(device)
    input_ids = encoded["input_ids"]

    pad_mask = create_mask(input_ids, tokenizer.pad_token_id, device)

    with torch.no_grad():
        scores = model(input_ids, pad_mask)

    # .item() requires the argmax result to hold exactly one element.
    best_index = scores.argmax(dim=1).item()
    return index_to_choice_label(best_index)

# --- Helper Functions for Label Conversion ---
def choice_label_to_index(label):
    """Return the offset of the single character `label` from 'A' ('A' -> 0, 'B' -> 1, ...).

    NOTE: this repeats the definition that appears before the training loop.
    """
    base_code = ord('A')
    return ord(label) - base_code

def index_to_choice_label(index):
    """Return the character `index` code points after 'A' (0 -> 'A', 1 -> 'B', ...).

    NOTE: this repeats the definition that appears before the training loop.
    """
    code_point = index + ord('A')
    return chr(code_point)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Epoch 1/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1/10, Average Loss: 1.3292
Epoch 1/10, Validation Loss: 0.9927
Model saved to: ./best_general_transformer.pth


Epoch 2/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 2/10, Average Loss: 1.0501
Epoch 2/10, Validation Loss: 0.7630
Model saved to: ./best_general_transformer.pth


Epoch 3/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 3/10, Average Loss: 0.8251
Epoch 3/10, Validation Loss: 0.7152
Model saved to: ./best_general_transformer.pth


Epoch 4/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 4/10, Average Loss: 0.8646
Epoch 4/10, Validation Loss: 0.6970
Model saved to: ./best_general_transformer.pth


Epoch 5/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 5/10, Average Loss: 0.8148
Epoch 5/10, Validation Loss: 0.6426
Model saved to: ./best_general_transformer.pth


Epoch 6/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 6/10, Average Loss: 0.9034
Epoch 6/10, Validation Loss: 0.6275
Model saved to: ./best_general_transformer.pth


Epoch 7/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 7/10, Average Loss: 0.7576
Epoch 7/10, Validation Loss: 0.6222
Model saved to: ./best_general_transformer.pth


Epoch 8/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 8/10, Average Loss: 0.6467
Epoch 8/10, Validation Loss: 0.6305


Epoch 9/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 9/10, Average Loss: 0.6926
Epoch 9/10, Validation Loss: 0.6255


Epoch 10/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 10/10, Average Loss: 0.6710
Epoch 10/10, Validation Loss: 0.6218
Model saved to: ./best_general_transformer.pth


In [27]:
# --- Load and Use the Model ---
# map_location remaps the stored tensors onto the currently selected device, so
# a checkpoint written during a GPU session still loads in a CPU-only session.
model.load_state_dict(torch.load(best_model_save_path, map_location=device))
model.eval()

# Example inference on a new question
new_question = {
    "context": "The sun is a star.",
    "question": "What is the sun?",
    "choices": {
        "A": "A planet",
        "B": "A star",
        "C": "A moon"
    }
}

# Build the prompt with the same "context question choices" layout the dataset
# uses, reusing the dataset's own choice formatter.
input_text = f"{new_question['context']} {new_question['question']} {train_dataset.format_choices(new_question['choices'])}"

predicted_label = generate_response(model, tokenizer, input_text, device)

print("\n--- Inference Results ---")
print('\n')
print(f"Input Text: {input_text}")
print(f"Predicted answer: {predicted_label}")


--- Inference Results ---


Input Text: The sun is a star. What is the sun? A: A planet B: A star C: A moon 
Predicted answer: B
