# Libraries/Imports

In [None]:
# Install the third-party packages used by this notebook into the kernel's
# environment (%pip targets the running kernel).
# NOTE(review): no versions are pinned, so a fresh run may resolve different
# releases than the ones this notebook was developed against.
%pip install numpy
%pip install pandas
%pip install matplotlib
%pip install datasets
%pip install kagglehub
%pip install transformers
%pip install accelerate
%pip install latex2sympy
%pip install --upgrade torch torchvision torchaudio
%pip install --upgrade torchtext

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [None]:
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt

In [None]:
def print_rows(dataset, start=0, end=10, split="", feature=""):
  """Print a slice of rows from one split, or from every split, of a dataset.

  Args:
    dataset: Mapping from split name to a sized, indexable collection of rows
      (each row is itself indexable by feature name).
    start: Zero-based index of the first row to print.
    end: Index one past the last row to print. It is clamped to the length of
      the split, so a short split prints what it has instead of raising
      IndexError half-way through.
    split: Name of the split to print; the empty string prints every split.
    feature: Name of a single feature to print; the empty string prints the
      whole row.
  """
  # A single code path serves both the "one split" and the "all splits" case.
  split_names = list(dataset) if split == "" else [split]
  for name in split_names:
    rows = dataset[name]
    stop = min(end, len(rows))
    print(f"Entries {start+1} - {stop} of the {name} data:")
    for i in range(start, stop):
      if feature:  # Only the requested feature of the row
        print(f"{feature}: {rows[i][feature]}")
      else:  # The entire row
        print(rows[i])
    print("-" * 20)

# Preprocessing

In [None]:
# Preprocess COT dataset
from datasets import load_dataset

cot_ds = load_dataset("AI-MO/NuminaMath-CoT")

# Only 'problem' and 'solution' are kept: drop the 'messages' column and then
# the 'source' column from both splits.
for _column in ('messages', 'source'):
    for _split in ('train', 'test'):
        cot_ds[_split] = cot_ds[_split].remove_columns([_column])

# Remove chinese characters from COT dataset
import re

def contains_chinese(text):
    """Return True when `text` has at least one character in the ranges below.

    Ranges checked: U+4E00-U+9FFF, U+2E80-U+2EFF, U+31C0-U+31EF and
    U+FF00-U+FFEF (the last range also covers full-width Latin letters and
    punctuation, so such text is reported as well).
    """
    cjk_char_class = r'[\u4e00-\u9fff\u2e80-\u2eff\u31c0-\u31ef\uff00-\uffef]'
    return re.search(cjk_char_class, text) is not None

def filter_entries(dataset, fields):
    """Return `dataset` without the examples that contain Chinese characters.

    Args:
        dataset: Object exposing ``filter(predicate)``; the predicate is called
            with one example at a time.
        fields: Names of the example fields whose text is inspected.

    Returns:
        Whatever ``dataset.filter`` returns: the examples for which none of
        `fields` is flagged by `contains_chinese`.
    """
    def _is_clean(example):
        # Stop at the first offending field, exactly like `not any(...)`.
        for field in fields:
            if contains_chinese(example[field]):
                return False
        return True

    return dataset.filter(_is_clean)

# Drop every example whose problem or solution text contains Chinese characters
fields_to_check = ['problem', 'solution']
for _split in ('train', 'test'):
    cot_ds[_split] = filter_entries(cot_ds[_split], fields_to_check)
print(cot_ds)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/2.68k [00:00<?, ?B/s]

train-00000-of-00005.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

train-00001-of-00005.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

train-00002-of-00005.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

train-00003-of-00005.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

train-00004-of-00005.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/166k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/859494 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Filter:   0%|          | 0/859494 [00:00<?, ? examples/s]

Filter:   0%|          | 0/100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['problem', 'solution'],
        num_rows: 850151
    })
    test: Dataset({
        features: ['problem', 'solution'],
        num_rows: 100
    })
})


In [None]:
# Preprocess MATH dataset (load all JSON files into a Dataset object)

import kagglehub

# Download the latest version of the dataset. `path` is the value returned by
# kagglehub.dataset_download and is used below as the root directory of the
# downloaded files.
path = kagglehub.dataset_download("mathurinache/math-dataset")

print("Path to dataset files:", path)

import os
from datasets import Dataset, DatasetDict

def load_json_files(data_dir):
    """Load every ``*.json`` file found one directory level below `data_dir`.

    `data_dir` is expected to contain sub-directories; each ``.json`` file
    inside those sub-directories is parsed and collected. Entries of
    `data_dir` that are not directories, and files without the ``.json``
    suffix, are skipped.

    Args:
        data_dir: Path of the directory whose sub-directories hold the files.

    Returns:
        list: The parsed content of each JSON file, ordered by sub-directory
        name and then by file name.
    """
    all_data = []
    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order reproducible from one run / machine to the next.
    for subdir in sorted(os.listdir(data_dir)):
        subdir_path = os.path.join(data_dir, subdir)
        if not os.path.isdir(subdir_path):
            # A stray file next to the sub-directories would make the inner
            # os.listdir call raise NotADirectoryError.
            continue
        for filename in sorted(os.listdir(subdir_path)):
            if not filename.endswith(".json"):
                continue
            filepath = os.path.join(subdir_path, filename)
            # Read as UTF-8 explicitly instead of relying on the locale default.
            with open(filepath, "r", encoding="utf-8") as f:
                all_data.append(json.load(f))
    print(f"Loaded {len(all_data)} problems.")
    return all_data

# Assuming 'path' is from kagglehub.dataset_download
math_dir = os.path.join(path, "MATH")
train_dir, test_dir = (os.path.join(math_dir, name) for name in ("train", "test"))

train_data = load_json_files(train_dir)
test_data = load_json_files(test_dir)


def _problem_solution_dataset(records):
    """Build a Dataset holding only the 'problem' and 'solution' of each record."""
    # Any other field of the records (e.g. "level", "type") is not carried over.
    return Dataset.from_dict({
        "problem": [item["problem"] for item in records],
        "solution": [item["solution"] for item in records],
    })


# Convert the train and test data into Dataset objects
train_dataset = _problem_solution_dataset(train_data)
test_dataset = _problem_solution_dataset(test_data)

math_ds = DatasetDict({"train": train_dataset, "test": test_dataset})

print(math_ds)


Downloading from https://www.kaggle.com/api/v1/datasets/download/mathurinache/math-dataset?dataset_version_number=1...


100%|██████████| 7.07M/7.07M [00:00<00:00, 60.7MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/mathurinache/math-dataset/versions/1
Loaded 7500 problems.
Loaded 5000 problems.
DatasetDict({
    train: Dataset({
        features: ['problem', 'solution'],
        num_rows: 7500
    })
    test: Dataset({
        features: ['problem', 'solution'],
        num_rows: 5000
    })
})


In [None]:
# Concatenate and split datasets
from datasets import concatenate_datasets

# Fixed seed so the shuffled train / validation / test partition is the same on
# every run; without it train_test_split produces a different partition each time.
SPLIT_SEED = 42

# Hold out 10% of the CoT training data as the CoT part of the test split.
# This overwrites the 'test' split that was loaded with the dataset.
train_valid_split = cot_ds['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
cot_ds['train'] = train_valid_split['train']
cot_ds['test'] = train_valid_split['test']

# Make validation dataset: 12% of what is left of the training data.
train_valid_split = cot_ds['train'].train_test_split(test_size=0.12, seed=SPLIT_SEED)
cot_ds['train'] = train_valid_split['train']
cot_ds['validation'] = train_valid_split['test']

# Add MATH dataset as test dataset (both of its splits are used for testing)
merged_math = concatenate_datasets([math_ds['train'], math_ds['test']])
cot_ds['test'] = concatenate_datasets([cot_ds['test'], merged_math])

ds = cot_ds
print(ds)

# Only `ds` is used from here on. NOTE: after these deletions the cell cannot be
# re-run without re-running the preprocessing cells above first.
del cot_ds
del math_ds

print()
print("Split")
# Share of each split in the combined dataset; the total is computed once.
total_rows = sum(len(ds[name]) for name in ('train', 'validation', 'test'))
for name in ('train', 'test', 'validation'):
    print(f"{name}:", len(ds[name]) / total_rows)

DatasetDict({
    train: Dataset({
        features: ['problem', 'solution'],
        num_rows: 673318
    })
    test: Dataset({
        features: ['problem', 'solution'],
        num_rows: 97516
    })
    validation: Dataset({
        features: ['problem', 'solution'],
        num_rows: 91817
    })
})

Split
train: 0.7805219028320839
test: 0.11304223840232029
validation: 0.10643585876559582


# Embedding

In [None]:
# Number of rows kept per split for quick experiments.
SUBSET_SIZE = 200

# Keep only the first SUBSET_SIZE rows of each split. The size is clamped to the
# split length so that a split with fewer rows is kept whole instead of making
# `select` fail on out-of-range indices.
ds200 = DatasetDict({
    split: dataset.select(range(min(SUBSET_SIZE, len(dataset))))
    for split, dataset in ds.items()
})
print(ds200)

DatasetDict({
    train: Dataset({
        features: ['problem', 'solution'],
        num_rows: 200
    })
    test: Dataset({
        features: ['problem', 'solution'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['problem', 'solution'],
        num_rows: 200
    })
})


In [None]:
# Generate Embeddings

from transformers import AutoTokenizer, AutoModel
import torch

# Load MathBERT: the tokenizer and the model are both fetched from the
# "tbs17/MathBERT" checkpoint. They are module-level globals that
# compute_embeddings (defined below) reads directly.
model_name = "tbs17/MathBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Function to compute embeddings
def compute_embeddings(batch):
    """Embed a batch of problems with the module-level `tokenizer` and `model`.

    Args:
        batch: Mapping with a "problem" entry holding the texts of the batch.

    Returns:
        dict: ``{"embeddings": array}`` where the array holds, for each text,
        the last hidden state at sequence position 0 (the [CLS] slot).
    """
    # Every text is padded or truncated to exactly 128 tokens.
    encoded = tokenizer(
        batch["problem"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    # Use the GPU when one is available; model and inputs must share a device.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)
    encoded = {name: tensor.to(target) for name, tensor in encoded.items()}

    # Inference only, so no gradients are tracked.
    with torch.no_grad():
        model_out = model(**encoded)

    first_token_states = model_out.last_hidden_state[:, 0, :]
    return {"embeddings": first_token_states.cpu().numpy()}

# Apply the function to each split of the dataset
def process_dataset(dataset_dict):
    """Add the output of `compute_embeddings` to every split of `dataset_dict`.

    Each split is replaced in place by its mapped version (batches of 64) and
    the same `dataset_dict` object is returned.
    """
    for split_name in list(dataset_dict.keys()):
        with_embeddings = dataset_dict[split_name].map(
            compute_embeddings, batched=True, batch_size=64
        )
        dataset_dict[split_name] = with_embeddings
    return dataset_dict


# Add embeddings to the dataset
ds200 = process_dataset(ds200)

# Save the MathBERT weights (the model's state_dict) to a local file.
# Note that this does NOT save the embeddings: those live in the "embeddings"
# column that process_dataset added to ds200.
torch.save(model.state_dict(), 'mathbert_weights.pth')
# model.load_state_dict(torch.load('mathbert_weights.pth'))

print(ds200)


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/441M [00:00<?, ?B/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['problem', 'solution', 'embeddings'],
        num_rows: 200
    })
    test: Dataset({
        features: ['problem', 'solution', 'embeddings'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['problem', 'solution', 'embeddings'],
        num_rows: 200
    })
})


# Encoder


In [None]:
# Inspect the first training example, including the newly added "embeddings" field
print(ds200['train'][0])

{'problem': 'Given the sequence $0,1,1,2,3,5,8, \\cdots$, where starting from the 3rd term each term is the sum of the two preceding ones, is there any term among the first 100,000,001 terms of this sequence whose last 4 digits are all zeros?', 'solution': '\n1. We start with the Fibonacci sequence defined as:\n   \\[\n   0, 1, 1, 2, 3, 5, 8, \\ldots\n   \\]\n   where each term from the third term onward is the sum of the two preceding terms.\n\n2. We denote this sequence by \\(a_{1}, a_{2}, \\ldots, a_{n}, \\ldots\\).\n\n3. For \\(n = 1, 2, \\ldots\\), we define a new sequence \\( b_n \\) as follows:\n   \\[\n   b_n = \n   \\begin{cases} \n   a_n & \\text{if } a_n < 10^3 \\\\\n   \\text{the last four digits of } a_n & \\text{if } a_n \\geq 10^3 \n   \\end{cases}\n   \\]\n\n4. Clearly, \\(0 \\leq b_n < 10^4\\).\n\n5. Using modular arithmetic, this sequence satisfies:\n   \\[\n   b_{n} \\equiv b_{n-1} + b_{n-2} \\pmod{10^4} \\text{ for } n = 3, 4, \\ldots\n   \\]\n\n6. Considering the s

In [None]:
# Pull the whole "embeddings" column of the training split
train_embeddings = ds200["train"]["embeddings"]
# Embedding of the second example (index 1) ...
print(train_embeddings[1])

# ... and the number of values it contains
print(len(train_embeddings[1]))


[-1.949852705001831, 0.35216882824897766, 1.208216667175293, -1.0962738990783691, -0.5237234830856323, -0.8230811357498169, 0.7726143002510071, 0.26973316073417664, -0.8318825960159302, -0.7510274648666382, 1.2476102113723755, 2.9874863624572754, -0.6495546698570251, -1.4871567487716675, -1.617428183555603, -0.6537545323371887, -0.19805064797401428, 1.1435359716415405, 0.3574502766132355, 1.137886881828308, -0.6283408999443054, 1.2387630939483643, -0.6004272103309631, 1.5073366165161133, -0.2719491720199585, -0.33186376094818115, 1.932119369506836, 0.28131726384162903, 0.9268313050270081, 0.13414983451366425, -0.6114510297775269, 2.004671573638916, 0.6277562379837036, -0.3652656078338623, 0.3929366171360016, -1.7907633781433105, -0.05376959964632988, -1.0249866247177124, 0.6098984479904175, -0.1702447235584259, 0.49612218141555786, -3.0862550735473633, -0.6346583366394043, -0.5617294907569885, 0.549155056476593, 0.5042513012886047, -1.2283825874328613, 0.11118858307600021, -1.291460633

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Union, Tuple


In [None]:
class CustomConfig:
    """Hyper-parameters shared by the Transformer layers defined in this notebook.

    Attributes:
        d_model: Hidden size.
        num_heads: Number of attention heads; must divide `d_model`.
        d_ff: Feedforward network size.
        num_encoder_layers: Number of encoder layers.
        dropout_rate: Dropout rate.
        attention_dropout_rate: Attention dropout rate.
        layer_norm_eps: Layer norm epsilon for stability.
    """

    def __init__(
        self,
        d_model=768,
        num_heads=8,
        d_ff=3072,
        num_encoder_layers=6,
        dropout_rate=0.1,
        attention_dropout_rate=0.1,
        layer_norm_eps=1e-6,
    ):
        # Every attention head must receive an equal slice of the hidden size.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model, self.num_heads, self.d_ff = d_model, num_heads, d_ff
        self.num_encoder_layers = num_encoder_layers
        self.dropout_rate, self.attention_dropout_rate = dropout_rate, attention_dropout_rate
        self.layer_norm_eps = layer_norm_eps


In [None]:
# Embedding of the second training example (index 1), with two leading axes of
# size 1 added in front of the embedding axis.
train_embedding = torch.tensor(ds200['train'][1]["embeddings"])[None, None, :]

print(train_embedding)
print(train_embedding.shape)
# TODO: the embeddings have to be converted into tensors; maybe add nn.Embedding as the first layer

tensor([[[-1.9499e+00,  3.5217e-01,  1.2082e+00, -1.0963e+00, -5.2372e-01,
          -8.2308e-01,  7.7261e-01,  2.6973e-01, -8.3188e-01, -7.5103e-01,
           1.2476e+00,  2.9875e+00, -6.4955e-01, -1.4872e+00, -1.6174e+00,
          -6.5375e-01, -1.9805e-01,  1.1435e+00,  3.5745e-01,  1.1379e+00,
          -6.2834e-01,  1.2388e+00, -6.0043e-01,  1.5073e+00, -2.7195e-01,
          -3.3186e-01,  1.9321e+00,  2.8132e-01,  9.2683e-01,  1.3415e-01,
          -6.1145e-01,  2.0047e+00,  6.2776e-01, -3.6527e-01,  3.9294e-01,
          -1.7908e+00, -5.3770e-02, -1.0250e+00,  6.0990e-01, -1.7024e-01,
           4.9612e-01, -3.0863e+00, -6.3466e-01, -5.6173e-01,  5.4916e-01,
           5.0425e-01, -1.2284e+00,  1.1119e-01, -1.2915e+00,  2.7266e-01,
          -6.3272e-01,  1.8733e+00, -3.5719e-01, -2.8305e-01, -7.9757e-01,
           5.7095e-01,  2.4766e-01, -6.6385e-01,  5.2976e-01, -1.1068e+00,
           1.6573e+00,  1.3192e+00, -3.3305e-01,  9.1924e-01,  6.7070e-01,
           1.0128e+00, -9

In [None]:
# Build one tensor per training example, each with two leading axes of size 1 ...
batch_embeddings = []
for i in range(len(ds200['train'])):
  p= torch.tensor(ds200['train'][i]["embeddings"]).unsqueeze(0).unsqueeze(1)
  batch_embeddings.append(p)

print(len(batch_embeddings))
# ... then join them along axis 0, so the name is rebound from a list of
# tensors to a single tensor.
batch_embeddings = torch.cat(batch_embeddings, dim=0)
print(batch_embeddings.shape)

200
torch.Size([200, 1, 768])


In [None]:
class FFNLayer(nn.Module):
    """Feed-forward sub-layer with a residual connection and layer normalization.

    Computes ``LayerNorm(x + Linear2(Dropout(GELU(Linear1(x)))))``.
    Modelled on T5LayerFF.
    """

    def __init__(self, config):
        """
        Args:
            config: Object providing ``d_model``, ``d_ff``, ``dropout_rate``
                and ``layer_norm_eps``.
        """
        super().__init__()
        self.linear1 = nn.Linear(config.d_model, config.d_ff)  # First projection
        self.activation = nn.GELU()  # Non-linear activation
        self.linear2 = nn.Linear(config.d_ff, config.d_model)  # Second projection
        self.dropout = nn.Dropout(config.dropout_rate)  # Dropout regularization
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)  # Layer normalization

        # Xavier initialization improves stability during training, but it must
        # run exactly once, at construction. Running it inside forward() would
        # overwrite the weights with fresh random values on every call and so
        # discard everything the optimizer (or load_state_dict) has put there.
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)

    def forward(self, hidden_states):
        """Apply the feed-forward network, then the residual connection and LayerNorm.

        Args:
            hidden_states: Tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the same shape as `hidden_states`.
        """
        # Input goes through the FFN
        forwarded_states = self.linear1(hidden_states)
        forwarded_states = self.activation(forwarded_states)
        forwarded_states = self.dropout(forwarded_states)
        forwarded_states = self.linear2(forwarded_states)

        # Add residual connection and layer norm
        hidden_states = self.layer_norm(hidden_states + forwarded_states)
        return hidden_states


In [None]:
class AttentionLayer(nn.Module):
    """Multi-head self-attention sub-layer with residual connection and LayerNorm.

    Modelled on T5Attention, without the relative positional bias.
    """

    def __init__(self, config):
        """
        Args:
            config: Object providing ``num_heads``, ``d_model``,
                ``attention_dropout_rate``, ``dropout_rate`` and ``layer_norm_eps``.
        """
        super().__init__()
        self.num_heads = config.num_heads  # Number of attention heads
        self.d_model = config.d_model

        # The attention-probability dropout rate is handed straight to
        # nn.MultiheadAttention; `self.dropout` below is the separate dropout
        # applied to the sub-layer output.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=self.num_heads,
            dropout=config.attention_dropout_rate
        )

        # skipping Relative positional bias

        self.layer_norm = nn.LayerNorm(self.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, attention_mask=None, layer_head_mask=None, output_attentions=False):
        """
        Forward pass for the attention layer.

        Args:
            hidden_states (torch.FloatTensor): Input tensor of shape
                [batch_size, seq_len, d_model].
            attention_mask (torch.Tensor, optional): Either
                * a padding mask of shape [batch_size, seq_len] in which non-zero
                  (or True) marks a real token and zero (or False) marks padding
                  that must not be attended to, or
                * any mask accepted by ``nn.MultiheadAttention`` as ``attn_mask``
                  (e.g. shape [seq_len, seq_len]), which is forwarded unchanged.
                When ``batch_size == seq_len`` a 2-D mask is ambiguous; it is
                then forwarded as ``attn_mask``.
            layer_head_mask (torch.FloatTensor, optional): Accepted for interface
                compatibility; currently unused.
            output_attentions (bool, optional): Whether to return attention scores.

        Returns:
            torch.FloatTensor: Updated hidden states of shape
            [batch_size, seq_len, d_model].
            Optional[torch.FloatTensor]: Attention weights as returned by
            ``nn.MultiheadAttention`` if `output_attentions=True`.
        """
        batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]

        # Decide how the mask has to be handed to nn.MultiheadAttention.
        # A [batch_size, seq_len] padding mask is not a valid `attn_mask` (a 2-D
        # `attn_mask` must be [seq_len, seq_len]); it has to be converted into a
        # boolean `key_padding_mask` where True means "ignore this key".
        attn_mask = None
        key_padding_mask = None
        if attention_mask is not None:
            is_padding_mask = (
                attention_mask.dim() == 2
                and tuple(attention_mask.shape) == (batch_size, seq_len)
                and batch_size != seq_len
            )
            if is_padding_mask:
                key_padding_mask = attention_mask == 0
            else:
                attn_mask = attention_mask

        # Transpose hidden_states to [seq_len, batch_size, d_model] for nn.MultiheadAttention
        hidden_states = hidden_states.transpose(0, 1)

        # Apply multi-head self-attention
        attention_output, attention_weights = self.self_attention(
            query=hidden_states,
            key=hidden_states,
            value=hidden_states,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask
        )

        # Residual connection and layer normalization
        hidden_states = self.layer_norm(hidden_states + self.dropout(attention_output))

        # Transpose back to [batch_size, seq_len, d_model]
        hidden_states = hidden_states.transpose(0, 1)

        if output_attentions:
            return hidden_states, attention_weights
        return hidden_states

In [None]:
class EncoderBlock(nn.Module):
    """A single Transformer encoder layer: self-attention, then feed-forward.

    Similar to a T5 block. ``self.layer[0]`` is the attention sub-layer and
    ``self.layer[1]`` the feed-forward sub-layer.
    """

    def __init__(self, config):
        super(EncoderBlock, self).__init__()
        self.layer = nn.ModuleList()
        self.layer.append(AttentionLayer(config))
        self.layer.append(FFNLayer(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        """
        Forward pass for the encoder block.

        Args:
            hidden_states: Input tensor, passed to the attention sub-layer.
            attention_mask: Mask to prevent attention to certain positions;
                forwarded to the attention sub-layer.
            layer_head_mask: Mask for specific attention heads; forwarded to the
                attention sub-layer.
            past_key_value: Accepted for interface compatibility; unused.
            use_cache: Accepted for interface compatibility; unused.
            output_attentions: Whether to also return the attention weights.
            return_dict: Accepted for interface compatibility; unused.
            cache_position: Accepted for interface compatibility; unused.

        Returns:
            The updated hidden states after attention and FFN. When
            `output_attentions` is True, a tuple
            ``(hidden_states, attention_weights)`` is returned instead.
        """
        # `output_attentions` has to reach the attention sub-layer: only then
        # does it return a (hidden_states, attention_weights) pair. Unpacking a
        # bare tensor as if it were that pair would silently index into the
        # batch dimension.
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            hidden_states, attention_weights = self_attention_outputs
        else:
            hidden_states = self_attention_outputs

        hidden_states = self.layer[1](hidden_states)

        # Clamp to keep FP16 activations finite
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        if output_attentions:
            return hidden_states, attention_weights
        return hidden_states

# Initialize EncoderBlock
# (d_ff is set to 2048 here, overriding the CustomConfig default of 3072)
config = CustomConfig(d_model=768, num_heads=8, d_ff=2048)
encoder_block = EncoderBlock(config)

# Pass through EncoderBlock
# NOTE(review): no .eval() call is made, so the block runs with dropout active
# and the printed values are expected to differ from run to run.
block_output = encoder_block(hidden_states=train_embedding)
hidden_states= block_output  # With the default arguments the block returns the hidden states directly
print(f"Output Shape from EncoderBlock: {hidden_states.shape}")  # Expected: [1, 1, 768]
print(f"Output from EncoderBlock: {hidden_states}")


Output Shape from EncoderBlock: torch.Size([1, 1, 768])
Output from EncoderBlock: tensor([[[-7.4082e-01,  6.6813e-01,  1.3104e+00, -3.2134e-01,  7.3477e-02,
          -1.9585e+00,  1.3226e-01, -6.7048e-01, -3.4565e-01, -8.7595e-01,
           5.6327e-01,  3.4616e+00, -6.9094e-01, -1.0358e-01, -9.3360e-01,
          -6.7619e-01, -1.9341e-01,  1.2703e+00, -9.1196e-01,  3.9362e-01,
          -8.9678e-01,  1.4997e+00,  1.1863e-01,  1.6597e+00, -2.7383e-01,
          -2.7056e-02,  1.8806e+00,  5.0705e-01,  6.1383e-01, -3.4459e-01,
          -1.1504e-01,  1.0463e+00, -2.0782e-01, -4.6925e-01,  6.2978e-01,
          -1.0400e+00, -3.3170e-01, -5.1406e-01, -2.2220e-02,  1.8377e-02,
           1.1438e+00, -3.1855e+00, -1.1736e+00, -6.3322e-01, -3.7460e-01,
           1.1376e+00, -8.0432e-01, -6.0983e-01, -1.3477e+00,  8.3937e-01,
          -7.6238e-02,  7.1173e-01,  3.5535e-01, -6.2831e-01, -3.6547e-01,
           1.8018e+00, -2.1718e-01,  2.6126e-01,  1.0592e+00, -1.6702e-01,
           8.7669e

In [None]:
class TransformerEncoder(nn.Module):
    """Stack of `EncoderBlock` layers followed by a final LayerNorm and dropout.

    Similar to T5Stack, but it consumes already-embedded inputs.
    """

    def __init__(self, config):
        """
        Args:
            config: Object providing ``num_encoder_layers``, ``d_model``,
                ``layer_norm_eps`` and ``dropout_rate`` (plus whatever
                `EncoderBlock` reads from it).
        """
        super(TransformerEncoder, self).__init__()

        self.block = nn.ModuleList(
            [EncoderBlock(config) for _ in range(config.num_encoder_layers)]
        )
        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Forward pass for the TransformerEncoder.

        Args:
            hidden_states (Tensor): Input embeddings of shape [batch_size, seq_len, d_model].
            attention_mask (Tensor, optional): Mask to prevent attention to certain
                positions (e.g., padding); forwarded unchanged to every block.
            head_mask (optional): Indexable by layer number; entry ``i`` is
                forwarded to block ``i`` as ``layer_head_mask``.
            output_attentions (bool, optional): Whether to return attention weights.
            output_hidden_states (bool, optional): Whether to return intermediate hidden states.
            return_dict (bool, optional): Whether to return outputs as a dictionary.

        Returns:
            dict or tuple: With `return_dict` True, a dict with the keys
            ``"last_hidden_state"``, ``"hidden_states"`` and ``"attentions"``
            (the last two are None unless requested). Otherwise a tuple holding
            those values in the same order with the None entries left out.
        """
        # Prepare outputs if needed
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Apply dropout to input hidden states
        hidden_states = self.dropout(hidden_states)

        # Pass through each layer in the encoder block
        for i, layer_module in enumerate(self.block):
            if output_hidden_states:
                # Record the input of every block; the final output is added below.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=head_mask[i] if head_mask is not None else None,
                output_attentions=output_attentions,
            )

            # Collect attention weights if requested.
            # NOTE(review): this branch expects the block to return a
            # (hidden_states, attention_weights) pair when output_attentions is
            # True - confirm that EncoderBlock honours the flag.
            if output_attentions:
                hidden_states = layer_outputs[0]
                all_attentions = all_attentions + (layer_outputs[1],)
            else:
                hidden_states = layer_outputs

        # Apply final layer normalization
        hidden_states = self.final_layer_norm(hidden_states)

        # Apply final dropout
        hidden_states = self.dropout(hidden_states)

        # Add the final hidden state to outputs if requested
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Prepare return values
        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None
            )

        return {
            "last_hidden_state": hidden_states,
            "hidden_states": all_hidden_states,
            "attentions": all_attentions,
        }


# Initialize TransformerEncoder
config = CustomConfig(d_model=768, num_heads=8, d_ff=2048, num_encoder_layers=6)
transformer_encoder = TransformerEncoder(config)
# All-ones mask of shape (1, 1), matching the single position of train_embedding
attention_mask = torch.ones(1, 1)

# Pass through TransformerEncoder
# NOTE(review): the module is never switched to .eval(), so dropout is active:
# some entries of the output are zeroed and the values differ between runs.
encoder_output = transformer_encoder(hidden_states=train_embedding, attention_mask=attention_mask)

# Print the last hidden state
print(f"Output Shape from TransformerEncoder: {encoder_output['last_hidden_state'].shape}")  # Expected: [1, 1, 768]
print(f"Last Hidden State: {encoder_output['last_hidden_state']}")



torch.Size([1, 1, 768])
torch.Size([1, 1, 768])
torch.Size([1, 1, 768])
torch.Size([1, 1, 768])
torch.Size([1, 1, 768])
torch.Size([1, 1, 768])
Output Shape from TransformerEncoder: torch.Size([1, 1, 768])
Last Hidden State: tensor([[[ 0.0000, -0.6940, -0.6486, -0.6629, -0.6787,  0.0561,  0.0110,
           0.7718, -0.7073, -1.4147,  0.5785,  0.0000, -0.0419, -2.8395,
          -1.5165,  0.9448,  1.8933, -0.1472,  1.2650, -1.7181,  0.7480,
          -1.1613,  0.2203,  0.9507, -1.7259,  0.4316, -0.7485,  0.5906,
          -2.8265, -1.3478,  2.0990, -0.9457,  0.0616, -0.4538, -1.2644,
           1.3285,  0.0000, -0.6078, -2.1234, -0.2391, -0.2414, -1.4904,
          -0.6453, -0.2166,  2.0104, -0.4464, -0.9705, -1.8663,  0.8522,
          -0.2578, -1.0448,  2.5201,  0.0155, -0.6360, -0.2230,  0.3811,
           1.2844,  1.0443, -0.0000, -1.1150,  0.2060,  2.3618,  1.2245,
           1.2458, -0.5566,  0.6972, -0.5094, -1.3318, -1.8702,  0.2512,
           0.8127,  0.2488, -0.9235,  0.4103,

In [None]:
class EncoderModel(nn.Module):
    """Thin wrapper around `TransformerEncoder` for precomputed embeddings.

    Similar to T5EncoderModel, but without token or positional embedding
    layers: the caller passes embeddings directly.
    """

    def __init__(self, config):
        """
        Args:
            config: Configuration object forwarded to `TransformerEncoder`.
        """
        super(EncoderModel, self).__init__()
        # TODO: token / positional embedding layers are not implemented here.
        self.encoder = TransformerEncoder(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,  # Precomputed embeddings
        attention_mask = None,
        output_attentions = False,
        output_hidden_states = False,
        return_dict = True,
    ) -> Union[Tuple[torch.FloatTensor], dict]:
        """
        Forward pass for the EncoderModel.

        Args:
            hidden_states (torch.FloatTensor): Precomputed embeddings of shape [batch_size, seq_len, d_model].
            attention_mask (Optional[torch.FloatTensor]): Mask of shape [batch_size, seq_len] to prevent attention to padding tokens.
            output_attentions (Optional[bool]): Whether to return attention weights.
            output_hidden_states (Optional[bool]): Whether to return hidden states of all layers.
            return_dict (Optional[bool]): Whether to return outputs as a dictionary.

        Returns:
            Union[Tuple, dict]: Whatever the wrapped `TransformerEncoder`
            returns: final hidden states and optional outputs (attention
            weights, hidden states).
        """
        # Every argument is forwarded unchanged to the transformer encoder;
        # nothing is printed, so the model can be called in a loop without
        # flooding the output.
        return self.encoder(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

      # if not return_dict:
      #     return (hidden_states, all_hidden_states, all_attentions)
      # return {
      #     "last_hidden_state": hidden_states,
      #     "hidden_states": all_hidden_states,
      #     "attentions": all_attentions,
      # }




    # def forward(self, input_ids, attention_mask=None):



    #   return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    #   encoder_outputs = self.encoder(
    #         input_ids=input_ids,
    #         attention_mask=attention_mask,
    #         inputs_embeds=inputs_embeds,
    #         head_mask=head_mask,
    #         output_attentions=output_attentions,
    #         output_hidden_states=output_hidden_states,
    #         return_dict=return_dict,
    #     )

    #   return encoder_outputs


      # seq_len = input_ids.size(1)
      # # Create position IDs
      # position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
      # position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

      # # Combine token and position embeddings
      # embeddings = self.token_embedding(input_ids) + self.position_embedding(position_ids)
      # embeddings = self.dropout(embeddings)

      # # Pass through TransformerEncoder
      # encoded_output = self.encoder(embeddings.transpose(0, 1), attention_mask)  # Transpose for attention

      # return encoded_output.transpose(0, 1)  # Transpose back to original format

    # def parallelize(self, device_map=None):
    #   warnings.warn(
    #       "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
    #       " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
    #       " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
    #       " 'block.1': 1, ...}",
    #       FutureWarning,
    #   )
    #   self.device_map = (
    #       get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
    #       if device_map is None
    #       else device_map
    #   )
    #   assert_device_map(self.device_map, len(self.encoder.block))
    #   self.encoder.parallelize(self.device_map)
    #   self.model_parallel = True

    # @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    # def deparallelize(self):
    #   warnings.warn(
    #       "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
    #       FutureWarning,
    #   )
    #   self.encoder.deparallelize()
    #   self.encoder = self.encoder.to("cpu")
    #   self.model_parallel = False
    #   self.device_map = None
    #   torch.cuda.empty_cache()

    # def get_input_embeddings(self):
    #   return self.shared

    # def set_input_embeddings(self, new_embeddings):
    #   self.shared = new_embeddings
    #   self.encoder.set_input_embeddings(new_embeddings)

    # def _tie_weights(self):
    #   if self.config.tie_word_embeddings:
    #       self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)


In [None]:
# Build the full encoder (6 layers, d_ff overridden to 2048)
config = CustomConfig(d_model=768, num_heads=8, d_ff=2048, num_encoder_layers=6)
encoder_model = EncoderModel(config)

# Pass the embedding through the encoder
# (no attention mask is given, and the default return_dict=True makes the
# result a dict)
output = encoder_model(batch_embeddings)
print(output)

torch.Size([200, 1, 768])
torch.Size([200, 1, 768])
torch.Size([200, 1, 768])
torch.Size([200, 1, 768])
torch.Size([200, 1, 768])
torch.Size([200, 1, 768])
torch.Size([200, 1, 768])
{'last_hidden_state': tensor([[[-1.1197, -0.5308,  0.7889,  ..., -0.3037,  1.7757,  1.6209]],

        [[-2.2962, -0.3699,  1.3987,  ..., -0.0000,  2.4222,  0.9119]],

        [[-1.4379, -0.1271,  0.6503,  ..., -1.4084,  1.5974,  1.1099]],

        ...,

        [[-2.2072, -0.8395, -0.1178,  ..., -1.8743,  1.3957,  0.3964]],

        [[-0.0000, -0.8226, -0.8716,  ..., -0.6011,  2.2185,  0.7359]],

        [[-1.5118,  0.0000, -0.7643,  ...,  1.2085,  0.0000,  1.1711]]],
       grad_fn=<MulBackward0>), 'hidden_states': None, 'attentions': None}


# Decoder



In [None]:
class CrossAttentionLayer(nn.Module):
    """Cross-attention sub-layer: decoder states attend over encoder states.

    Runs ``nn.MultiheadAttention`` (sequence-first layout), applies dropout to
    the attention output, adds the residual and normalises with ``LayerNorm``.

    Args:
        config: object exposing ``num_heads``, ``d_model``,
            ``attention_dropout_rate``, ``dropout_rate`` and ``layer_norm_eps``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.d_model = config.d_model

        # Dropout on the attention probabilities is handled inside
        # nn.MultiheadAttention; `self.dropout` below is the residual dropout.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.d_model,
            num_heads=self.num_heads,
            dropout=config.attention_dropout_rate,
        )

        self.layer_norm = nn.LayerNorm(self.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
        """Attend from ``hidden_states`` (queries) to ``encoder_hidden_states``.

        Args:
            hidden_states: decoder states, batch-first ``(batch, tgt_len, d_model)``.
            encoder_hidden_states: encoder states, batch-first
                ``(batch, src_len, d_model)``.
            attention_mask: optional mask. A 2-D mask of shape
                ``(batch, src_len)`` is treated as a padding mask in which
                non-zero means "attend" and zero means "padding"; it is
                converted to ``key_padding_mask``. Any other mask is forwarded
                unchanged as ``attn_mask`` (``(tgt_len, src_len)`` or
                ``(batch * num_heads, tgt_len, src_len)``).

        Returns:
            Tensor of shape ``(batch, tgt_len, d_model)``.
        """
        batch_size, target_len = hidden_states.shape[0], hidden_states.shape[1]
        source_len = encoder_hidden_states.shape[1]

        attn_mask = None
        key_padding_mask = None
        if attention_mask is not None:
            mask_shape = tuple(attention_mask.shape)
            # A mask that already has the (tgt_len, src_len) shape keeps its
            # previous meaning, so the ambiguous case batch == tgt_len behaves
            # exactly as before.
            is_padding_mask = (
                attention_mask.dim() == 2
                and mask_shape == (batch_size, source_len)
                and mask_shape != (target_len, source_len)
            )
            if is_padding_mask:
                # nn.MultiheadAttention expects True at positions to ignore.
                key_padding_mask = attention_mask == 0
            else:
                attn_mask = attention_mask

        # nn.MultiheadAttention defaults to (seq, batch, dim) layout.
        hidden_states = hidden_states.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        attention_output, _ = self.cross_attention(
            query=hidden_states,
            key=encoder_hidden_states,
            value=encoder_hidden_states,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )

        # Residual connection followed by layer norm, then back to batch-first.
        hidden_states = self.layer_norm(hidden_states + self.dropout(attention_output))
        hidden_states = hidden_states.transpose(0, 1)

        return hidden_states

class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, then cross-attention, then feed-forward.

    Args:
        config: configuration object handed unchanged to ``AttentionLayer``,
            ``CrossAttentionLayer`` and ``FFNLayer``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention = AttentionLayer(config)
        self.cross_attention = CrossAttentionLayer(config)
        self.ffn = FFNLayer(config)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        encoder_attention_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run the three sub-layers in sequence and return the new hidden states.

        ``cross_attn_layer_head_mask`` and ``past_key_value`` are accepted but
        not used by this implementation. Attention weights are not returned,
        even when ``output_attentions`` is true.
        """
        # 1) Self-attention over the decoder states.
        attended = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            # With attentions requested, only the first element of the
            # self-attention result is carried forward.
            attended = attended[0]

        # 2) Cross-attention over the encoder output.
        attended = self.cross_attention(
            attended,
            encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )

        # 3) Position-wise feed-forward.
        return self.ffn(attended)

In [None]:
class TransformerDecoder(nn.Module):
    """Stack of ``DecoderBlock`` layers followed by a vocabulary projection.

    Args:
        config: object exposing ``num_decoder_layers``, ``d_model``,
            ``layer_norm_eps``, ``dropout_rate`` and ``vocab_size``; it is also
            passed to every ``DecoderBlock``.
    """

    def __init__(self, config):
        super().__init__()
        decoder_layers = [DecoderBlock(config) for _ in range(config.num_decoder_layers)]
        self.block = nn.ModuleList(decoder_layers)
        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.output_projection = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Decode ``hidden_states`` against ``encoder_hidden_states``.

        Returns a dict with keys ``logits``, ``hidden_states``,
        ``self_attentions`` and ``cross_attentions`` when ``return_dict`` is
        true; otherwise a tuple of the non-``None`` values in that order.

        Note: the attention collections are created when ``output_attentions``
        is true but nothing is ever appended to them, so they come back as
        empty tuples. ``past_key_values`` is accepted but unused.
        """
        collected_states = [] if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None

        current = self.dropout(hidden_states)

        for index, layer in enumerate(self.block):
            if collected_states is not None:
                # Record the input to each layer.
                collected_states.append(current)

            self_mask_i = None if head_mask is None else head_mask[index]
            cross_mask_i = None if cross_attn_head_mask is None else cross_attn_head_mask[index]

            current = layer(
                current,
                encoder_hidden_states,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=self_mask_i,
                cross_attn_layer_head_mask=cross_mask_i,
                output_attentions=output_attentions,
            )

        # Final normalisation and dropout, then projection onto the vocabulary.
        current = self.dropout(self.final_layer_norm(current))
        logits = self.output_projection(current)

        if collected_states is not None:
            collected_states.append(current)
            all_hidden_states = tuple(collected_states)
        else:
            all_hidden_states = None

        results = {
            "logits": logits,
            "hidden_states": all_hidden_states,
            "self_attentions": all_self_attns,
            "cross_attentions": all_cross_attns,
        }
        if return_dict:
            return results
        return tuple(value for value in results.values() if value is not None)

In [None]:
class T5Model(nn.Module):
    """Encoder-decoder wrapper that chains ``TransformerEncoder`` and ``TransformerDecoder``.

    Args:
        config: configuration object passed to both sub-modules and kept on
            the instance as ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.decoder = TransformerDecoder(config)
        self.config = config

    def forward(
        self,
        encoder_input,
        decoder_input,
        attention_mask=None,
        decoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode ``encoder_input`` and decode ``decoder_input`` against it.

        Args:
            encoder_input: hidden states fed to the encoder.
            decoder_input: hidden states fed to the decoder.
            attention_mask: mask for the encoder; it is also handed to the
                decoder as ``encoder_attention_mask``.
            decoder_attention_mask: mask for decoder self-attention.
            output_attentions: forwarded to both sub-modules.
            output_hidden_states: forwarded to both sub-modules.
            return_dict: when true, return a dict; otherwise a 1-tuple
                ``(logits,)``.

        Returns:
            Dict with keys ``logits``, ``encoder_hidden_states``,
            ``decoder_hidden_states``, ``encoder_attentions``,
            ``decoder_attentions`` and ``cross_attentions``, or ``(logits,)``.
        """
        # The sub-modules are always asked for dicts: the code below looks
        # results up by key, which would raise TypeError on the tuples they
        # return for return_dict=False. `return_dict` only affects the final
        # return value of this method.
        encoder_outputs = self.encoder(
            hidden_states=encoder_input,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        decoder_outputs = self.decoder(
            hidden_states=decoder_input,
            encoder_hidden_states=encoder_outputs["last_hidden_state"],
            attention_mask=decoder_attention_mask,
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if not return_dict:
            return (decoder_outputs["logits"],)

        return {
            "logits": decoder_outputs["logits"],
            "encoder_hidden_states": encoder_outputs["hidden_states"],
            "decoder_hidden_states": decoder_outputs["hidden_states"],
            "encoder_attentions": encoder_outputs["attentions"],
            "decoder_attentions": decoder_outputs["self_attentions"],
            "cross_attentions": decoder_outputs["cross_attentions"],
        }

In [None]:
def create_sample_config():
    """Return a ``CustomConfig`` with the example hyper-parameters used by the demo below."""
    settings = dict(
        d_model=768,
        num_heads=8,
        d_ff=2048,
        num_encoder_layers=6,
        num_decoder_layers=6,
        vocab_size=32128,  # Example vocabulary size
        dropout_rate=0.1,
        attention_dropout_rate=0.1,
        layer_norm_eps=1e-6,
    )
    return CustomConfig(**settings)

# Build the model from the sample configuration.
config = create_sample_config()
model = T5Model(config)

# Random stand-in inputs: two sequences of ten positions each, already in
# d_model-dimensional embedding space, with all-ones masks.
batch_size, seq_length = 2, 10
input_shape = (batch_size, seq_length, config.d_model)
mask_shape = (batch_size, seq_length)
encoder_input = torch.randn(*input_shape)
decoder_input = torch.randn(*input_shape)
attention_mask = torch.ones(mask_shape)
decoder_attention_mask = torch.ones(mask_shape)

# Single forward pass, requesting every optional output.
outputs = model(
    encoder_input=encoder_input,
    decoder_input=decoder_input,
    attention_mask=attention_mask,
    decoder_attention_mask=decoder_attention_mask,
    output_attentions=True,
    output_hidden_states=True,
)

# Report the shapes that came back (None when a collection is missing or empty).
encoder_states = outputs["encoder_hidden_states"]
decoder_states = outputs["decoder_hidden_states"]
print("Logits shape:", outputs["logits"].shape)
print("Encoder hidden states shape:",
      [state.shape for state in encoder_states] if encoder_states else None)
print("Decoder hidden states shape:",
      [state.shape for state in decoder_states] if decoder_states else None)