<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/TRANSFORMER_REASONING_MODEL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets -q
!pip install transformers -q
!pip install torch -q

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
from tqdm import tqdm

# --- Hyperparameters ---
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
NUM_EPOCHS = 5
D_MODEL = 256  # Reduced for faster training on a smaller scale
NUM_HEADS = 8
NUM_LAYERS = 3  # Reduced for faster training
D_FF = 512
DROPOUT = 0.1
MAX_LEN = 128  # Maximum sequence length
WARMUP_STEPS = 1000
GRADIENT_CLIPPING = 1.0

# --- Load the GSM8k Dataset ---
gsm8k_dataset = load_dataset("gsm8k", "main")
train_dataset = gsm8k_dataset['train']
test_dataset = gsm8k_dataset['test']

# --- Vocabulary Creation ---
def build_vocabulary(examples):
    tokenizer = set()
    for example in examples:
        text = example['question'] + " " + example['answer']
        tokenizer.update(text.lower().split())
    return sorted(list(tokenizer))

vocabulary = build_vocabulary(train_dataset)
vocab_size = len(vocabulary)
word_to_index = {word: i for i, word in enumerate(vocabulary)}
index_to_word = {i: word for word, i in word_to_index.items()}

# Add special tokens
PAD_TOKEN = "<pad>"
START_TOKEN = "<start>"
END_TOKEN = "<end>"
UNK_TOKEN = "<unk>"
PAD_INDEX = 0
START_INDEX = vocab_size
END_INDEX = vocab_size + 1
UNK_INDEX = vocab_size + 2

word_to_index[PAD_TOKEN] = PAD_INDEX
word_to_index[START_TOKEN] = START_INDEX
word_to_index[END_TOKEN] = END_INDEX
word_to_index[UNK_TOKEN] = UNK_INDEX

index_to_word[PAD_INDEX] = PAD_TOKEN
index_to_word[START_INDEX] = START_TOKEN
index_to_word[END_INDEX] = END_TOKEN
index_to_word[UNK_INDEX] = UNK_TOKEN

updated_vocab_size = len(word_to_index)

# --- Data Processing Function ---
def process_example(example, max_len, word_to_index):
    question = example['question'].lower().split()
    answer = example['answer'].lower().split()

    question_tokens = [word_to_index.get(word, UNK_INDEX) for word in question]
    answer_tokens = [word_to_index.get(word, UNK_INDEX) for word in answer]

    src_tokens = [START_INDEX] + question_tokens + [END_INDEX]
    tgt_tokens = [START_INDEX] + answer_tokens + [END_INDEX]

    src_tokens = src_tokens[:max_len]
    tgt_tokens = tgt_tokens[:max_len]

    src_padding = [PAD_INDEX] * (max_len - len(src_tokens))
    tgt_padding = [PAD_INDEX] * (max_len - len(tgt_tokens))

    src_tensor = torch.tensor(src_tokens + src_padding)
    tgt_input_tensor = torch.tensor([START_INDEX] + answer_tokens[:max_len-1] + tgt_padding[:1]) # Input to decoder
    tgt_output_tensor = torch.tensor(answer_tokens[:max_len-1] + [END_INDEX] + tgt_padding[:1]) # Target for decoder

    return src_tensor, tgt_input_tensor, tgt_output_tensor

# --- Custom Dataset Class ---
class MathDataset(Dataset):
    def __init__(self, dataset, max_len, word_to_index):
        self.dataset = dataset
        self.max_len = max_len
        self.word_to_index = word_to_index

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        return process_example(example, self.max_len, self.word_to_index)

# --- Create DataLoaders ---
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    src_tensors, tgt_in_tensors, tgt_out_tensors = zip(*batch)
    # pad the sequences within the batch
    src_tensors = pad_sequence(src_tensors, batch_first=True, padding_value=PAD_INDEX)
    tgt_in_tensors = pad_sequence(tgt_in_tensors, batch_first=True, padding_value=PAD_INDEX)
    tgt_out_tensors = pad_sequence(tgt_out_tensors, batch_first=True, padding_value=PAD_INDEX)
    return src_tensors, tgt_in_tensors, tgt_out_tensors

train_dataloader = DataLoader(MathDataset(train_dataset, MAX_LEN, word_to_index), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(MathDataset(test_dataset, MAX_LEN, word_to_index), batch_size=BATCH_SIZE, collate_fn=collate_fn)

# --- Transformer Model Definition ---
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output, attn_probs

    def split_heads(self, x):
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q_ = self.split_heads(self.W_q(Q))
        K_ = self.split_heads(self.W_k(K))
        V_ = self.split_heads(self.W_v(V))

        output, attn_probs = self.scaled_dot_product_attention(Q_, K_, V_, mask)
        output = self.combine_heads(output)
        output = self.W_o(output)
        return output, attn_probs

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output, _ = self.mha(x, x, x, mask)
        norm1_output = self.norm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(norm1_output)
        output = self.norm2(norm1_output + self.dropout(ffn_output))
        return output

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.masked_mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.enc_dec_mha = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        masked_attn_output, _ = self.masked_mha(x, x, x, tgt_mask)
        norm1_output = self.norm1(x + self.dropout(masked_attn_output))
        enc_dec_attn_output, _ = self.enc_dec_mha(norm1_output, enc_output, enc_output, src_mask)
        norm2_output = self.norm2(norm1_output + self.dropout(enc_dec_attn_output))
        ffn_output = self.ffn(norm2_output)
        output = self.norm3(norm2_output + self.dropout(ffn_output))
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(1), :].transpose(0, 1)

class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)
                                     for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask):
        embedded = self.dropout(self.pos_encoding(self.embedding(src)))
        for layer in self.layers:
            embedded = layer(embedded, mask)
        return embedded

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout)
                                     for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_output, src_mask, tgt_mask):
        embedded = self.dropout(self.pos_encoding(self.embedding(tgt)))
        for layer in self.layers:
            embedded = layer(embedded, enc_output, src_mask, tgt_mask)
        output = self.fc(embedded)
        return output

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)

    def make_src_mask(self, src):
        return (src != PAD_INDEX).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        tgt_len = tgt.size(1)
        attn_shape = (1, tgt_len, tgt_len)
        subsequent_mask = torch.tril(torch.ones(attn_shape, device=tgt.device)).type(torch.uint8) # Create subsequent_mask on the same device as tgt
        padding_mask = (tgt != PAD_INDEX).unsqueeze(1).unsqueeze(2)
        return subsequent_mask & padding_mask.bool()

    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        enc_output = self.encoder(src, src_mask)
        output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        return output

# --- Initialize Model, Optimizer, and Scheduler ---
model = Transformer(updated_vocab_size, updated_vocab_size, D_MODEL, NUM_LAYERS, NUM_HEADS, D_FF, DROPOUT, MAX_LEN)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_INDEX)

# --- Training Loop ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}")
    for batch_idx, (src, tgt_in, tgt_out) in progress_bar:
        src = src.to(device)
        tgt_in = tgt_in.to(device)
        tgt_out = tgt_out.to(device)

        optimizer.zero_grad()
        output = model(src, tgt_in)  # (batch_size, tgt_len, vocab_size)

        # Reshape for loss calculation
        output = output.view(-1, output.size(-1)) # (batch_size * tgt_len, vocab_size)
        tgt_out = tgt_out.view(-1) # (batch_size * tgt_len)

        loss = criterion(output, tgt_out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIPPING)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

    # --- Evaluation Loop ---
    model.eval()
    eval_loss = 0

    # Add tqdm to the evaluation loop
    eval_progress_bar = tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc=f"Epoch {epoch+1} Evaluation")

    with torch.no_grad():
        for batch_idx, (src, tgt_in, tgt_out) in eval_progress_bar:
            src = src.to(device)
            tgt_in = tgt_in.to(device)
            tgt_out = tgt_out.to(device)

            output = model(src, tgt_in)
            output = output.view(-1, output.size(-1))
            tgt_out = tgt_out.view(-1)
            loss = criterion(output, tgt_out)
            eval_loss += loss.item()

    avg_eval_loss = eval_loss / len(test_dataloader)
    print(f"Epoch {epoch+1} Evaluation Loss: {avg_eval_loss:.4f}")

# --- Inference Function (Basic) ---
def translate_sentence(model, sentence, word_to_index, index_to_word, max_len, device):
    model.eval()
    tokens = [word_to_index.get(word.lower(), UNK_INDEX) for word in sentence.lower().split()]
    src_tokens = [START_INDEX] + tokens + [END_INDEX]
    src_tokens = src_tokens[:max_len]
    src_padding = [PAD_INDEX] * (max_len - len(src_tokens))
    src_tensor = torch.tensor(src_tokens + src_padding).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    memory = model.encoder(src_tensor, src_mask)
    tgt_tokens = [START_INDEX]
    for _ in range(max_len - 1):
        tgt_tensor = torch.tensor(tgt_tokens).unsqueeze(0).to(device)
        tgt_mask = model.make_tgt_mask(tgt_tensor)
        output = model.decoder(tgt_tensor, memory, src_mask, tgt_mask)
        pred_token = output.argmax(2)[:, -1].item()
        if pred_token == END_INDEX:
            break
        tgt_tokens.append(pred_token)

    translated_words = [index_to_word[token] for token in tgt_tokens if token not in [START_INDEX, END_INDEX, PAD_INDEX]]
    return " ".join(translated_words)

# --- Example Inference ---
if __name__ == '__main__':
    # Example question from the dataset
    sample_question = test_dataset[10]['question']
    actual_answer = test_dataset[10]['answer']

    translated_answer = translate_sentence(model, sample_question, word_to_index, index_to_word, MAX_LEN, device)

    print("\n--- Example Inference ---")
    print(f"Question: {sample_question}")
    print(f"Actual Answer: {actual_answer}")
    print(f"Predicted Answer: {translated_answer}")

    # Note: The model is likely not well-trained with these hyperparameters and few epochs.
    # The predicted answer will likely be poor without significant training.

Epoch 1: 100%|██████████| 234/234 [00:23<00:00,  9.94it/s, loss=9.87]


Epoch 1 Training Loss: 10.5158


Epoch 1 Evaluation: 100%|██████████| 42/42 [00:01<00:00, 32.12it/s]


Epoch 1 Evaluation Loss: 9.8002


Epoch 2: 100%|██████████| 234/234 [00:22<00:00, 10.31it/s, loss=8.79]


Epoch 2 Training Loss: 9.2945


Epoch 2 Evaluation: 100%|██████████| 42/42 [00:01<00:00, 32.61it/s]


Epoch 2 Evaluation Loss: 8.7138


Epoch 3: 100%|██████████| 234/234 [00:22<00:00, 10.39it/s, loss=7.25]


Epoch 3 Training Loss: 8.0429


Epoch 3 Evaluation: 100%|██████████| 42/42 [00:01<00:00, 32.70it/s]


Epoch 3 Evaluation Loss: 7.5185


Epoch 4: 100%|██████████| 234/234 [00:22<00:00, 10.42it/s, loss=6.94]


Epoch 4 Training Loss: 7.1583


Epoch 4 Evaluation: 100%|██████████| 42/42 [00:01<00:00, 32.44it/s]


Epoch 4 Evaluation Loss: 7.0628


Epoch 5: 100%|██████████| 234/234 [00:22<00:00, 10.31it/s, loss=6.65]


Epoch 5 Training Loss: 6.7686


Epoch 5 Evaluation: 100%|██████████| 42/42 [00:01<00:00, 32.39it/s]


Epoch 5 Evaluation Loss: 6.8315

--- Example Inference ---
Question: A new program had 60 downloads in the first month. The number of downloads in the second month was three times as many as the downloads in the first month, but then reduced by 30% in the third month. How many downloads did the program have total over the three months?
Actual Answer: The number of downloads of the program in the second month increased to 3*60 = <<3*60=180>>180
In the first two months, the total number of downloads of the program was 180+60 = <<180+60=240>>240
In the third month, the number of downloads of the program reduced by 30/100*180 = <<30/100*180=54>>54
There were 180-54 = <<180-54=126>>126 downloads in the third month.
In the three months, the total number of downloads of the program was 126+240 = <<126+240=366>>366
#### 366
Predicted Answer: the total of the number of the number of the number of the number of the number of the number of the number of the number of the number of the number of t