<a href="https://colab.research.google.com/github/codewithdark-git/Titans_Paper_Implementation/blob/main/Titans_Paper_Implantation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision torchaudio



In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, Dataset

class TitansMemoryModule(nn.Module):
    def __init__(self, d_model, memory_size=512):
        super().__init__()
        self.memory_size = memory_size
        self.memory = nn.Parameter(torch.zeros(memory_size, d_model))
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.forgetting_gate = nn.Linear(d_model, 1)

    def forward(self, x):
        # x shape: [batch_size, seq_len, d_model]
        batch_size, seq_len, d_model = x.shape

        # Project input to keys and values
        keys = self.key_proj(x)  # [batch_size, seq_len, d_model]
        values = self.value_proj(x)  # [batch_size, seq_len, d_model]

        # Compute attention scores with memory
        attention_scores = torch.matmul(keys, self.memory.T)  # [batch_size, seq_len, memory_size]
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Retrieve from memory
        retrieved_memory = torch.matmul(attention_weights, self.memory)  # [batch_size, seq_len, d_model]

        # Update memory based on surprise
        surprise = torch.norm(values - retrieved_memory, dim=-1, keepdim=True)  # [batch_size, seq_len, 1]
        forgetting_weights = torch.sigmoid(self.forgetting_gate(values))  # [batch_size, seq_len, 1]

        # Update memory (during inference only)
        if not self.training:
            # Reduce batch and seq dimensions to match memory size
            avg_forgetting_weights = forgetting_weights.mean(dim=(0, 1))  # [1, d_model]
            avg_values = values.mean(dim=(0, 1))  # [1, d_model]

            # Expand or reshape to match memory shape
            avg_forgetting_weights = avg_forgetting_weights.unsqueeze(0).expand(self.memory.size(0), -1)  # [memory_size, d_model]
            avg_values = avg_values.unsqueeze(0).expand(self.memory.size(0), -1)  # [memory_size, d_model]

            # Update memory
            self.memory.data = avg_forgetting_weights * self.memory + (1 - avg_forgetting_weights) * avg_values

        return retrieved_memory


class TitansTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, memory_size=512):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.titans_memory = TitansMemoryModule(d_model, memory_size)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Self-attention
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                             key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Titans memory integration
        memory_output = self.titans_memory(src)
        src = src + self.dropout2(memory_output)
        src = self.norm2(src)

        # Feed-forward network
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm3(src)

        return src

class TitansTransformer(nn.Module):
    def __init__(self, num_tokens, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1, memory_size=512):
        super().__init__()

        self.embedding = nn.Embedding(num_tokens, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Create encoder layers with Titans memory
        self.layers = nn.ModuleList([
            TitansTransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                        dropout, memory_size)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, num_tokens)

        self.d_model = d_model
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        for layer in self.layers:
            src = layer(src, src_mask, src_key_padding_mask)

        src = self.norm(src)
        output = self.fc_out(src)
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Training utilities
def create_mask(size):
    mask = torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)
    return mask

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for batch in dataloader:
        optimizer.zero_grad()

        src = batch[:-1].to(device)
        tgt = batch[1:].to(device)

        mask = create_mask(src.size(1)).to(device)

        output = model(src, src_mask=mask)
        loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# Example usage:
def main():
    # Model parameters
    num_tokens = 50000  # Vocabulary size
    d_model = 512
    nhead = 8
    num_layers = 6
    memory_size = 512

    # Initialize model
    model = TitansTransformer(
        num_tokens=num_tokens,
        d_model=d_model,
        nhead=nhead,
        num_layers=num_layers,
        memory_size=memory_size
    )

    # Training setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    print("Model initialized and ready for training")

if __name__ == "__main__":
    main()

Model initialized and ready for training


In [5]:
!pip install datasets



In [10]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import time
import torch.nn as nn


def get_data(subset_size=1000):
    # Load dataset from Hugging Face
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)

    # Tokenize the dataset and limit the size
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset['train'].column_names
    )

    # Select a subset of the dataset
    train_data = tokenized_dataset['train'].select(range(subset_size))
    val_data = tokenized_dataset['validation'].select(range(subset_size // 10))  # Smaller validation set
    test_data = tokenized_dataset['test'].select(range(subset_size // 10))       # Smaller test set

    # Convert to PyTorch tensors
    train_data = torch.tensor(train_data['input_ids'], dtype=torch.long)
    val_data = torch.tensor(val_data['input_ids'], dtype=torch.long)
    test_data = torch.tensor(test_data['input_ids'], dtype=torch.long)

    vocab_size = tokenizer.vocab_size

    return train_data, val_data, test_data, vocab_size


def batchify(data, batch_size, device):
    # Divide data into batch_size parts
    nbatch = data.size(0) // batch_size
    data = data.narrow(0, 0, nbatch * batch_size)
    data = data.view(batch_size, -1).t().contiguous()
    return data.to(device)


def evaluate(model, data_source, criterion, batch_size, device):
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, batch_size):
            data = data_source[i:i + batch_size].to(device)
            targets = data_source[i + 1:i + 1 + batch_size].to(device)
            # Ensure input and target sizes match
            if data.size(0) != targets.size(0):
                break  # Skip incomplete batch
            output = model(data)
            total_loss += criterion(output.view(-1, output.size(-1)), targets.view(-1)).item()
    return total_loss / (data_source.size(0) - 1)


def train_model():
    # Hyperparameters
    batch_size = 16
    eval_batch_size = 10
    d_model = 512
    nhead = 8
    num_layers = 6
    memory_size = 512
    epochs = 3  # Fewer epochs for testing
    subset_size = 1000  # Limit dataset size for testing

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get data (subset for testing)
    train_data, val_data, test_data, vocab_size = get_data(subset_size=subset_size)

    # Initialize model
    model = TitansTransformer(
        num_tokens=vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=num_layers,
        memory_size=memory_size
    ).to(device)

    # Batchify data
    train_data = batchify(train_data, batch_size, device)
    val_data = batchify(val_data, eval_batch_size, device)
    test_data = batchify(test_data, eval_batch_size, device)

    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

    # Training loop
    best_val_loss = float('inf')

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # Train
        model.train()
        total_loss = 0.
        for batch, i in enumerate(range(0, train_data.size(0) - 1, batch_size)):
            data = train_data[i:i + batch_size].to(device)
            targets = train_data[i + 1:i + 1 + batch_size].to(device)

            # Ensure input and target sizes match
            if data.size(0) != targets.size(0):
                break  # Skip incomplete batch

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output.view(-1, output.size(-1)), targets.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss += loss.item()

            if batch % 100 == 0:
                curr_loss = total_loss / (batch + 1)
                print(f'| epoch {epoch + 1:3d} | batch {batch:3d} | '
                      f'loss {curr_loss:5.2f}')


        # Evaluate
        val_loss = evaluate(model, val_data, criterion, eval_batch_size, device)
        print('-' * 89)
        print(f'| end of epoch {epoch + 1:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
              f'valid loss {val_loss:5.2f}')
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'titans_transformer_model.pt')

        scheduler.step()

    # Test
    model.load_state_dict(torch.load('titans_transformer_model.pt'))
    test_loss = evaluate(model, test_data, criterion, eval_batch_size, device)
    print('=' * 89)
    print(f'| End of training | test loss {test_loss:5.2f}')
    print('=' * 89)


if __name__ == "__main__":
    train_model()


| epoch   1 | batch   0 | loss 11.05
| epoch   1 | batch 100 | loss  4.72
| epoch   1 | batch 200 | loss  3.27
| epoch   1 | batch 300 | loss  2.55
| epoch   1 | batch 400 | loss  2.26
| epoch   1 | batch 500 | loss  2.05
| epoch   1 | batch 600 | loss  1.89
| epoch   1 | batch 700 | loss  1.81
| epoch   1 | batch 800 | loss  1.75
| epoch   1 | batch 900 | loss  1.76
| epoch   1 | batch 1000 | loss  1.79
| epoch   1 | batch 1100 | loss  1.72
| epoch   1 | batch 1200 | loss  1.72
| epoch   1 | batch 1300 | loss  1.64
| epoch   1 | batch 1400 | loss  1.62
| epoch   1 | batch 1500 | loss  1.58
| epoch   1 | batch 1600 | loss  1.57
| epoch   1 | batch 1700 | loss  1.55
| epoch   1 | batch 1800 | loss  1.53
| epoch   1 | batch 1900 | loss  1.53
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 108.01s | valid loss  0.14
-----------------------------------------------------------------------------------------
| epoch   2 | ba

  model.load_state_dict(torch.load('titans_transformer_model.pt'))


| End of training | test loss  0.12


In [14]:
import torch
import torch.nn as nn
import time
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import math
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer


# Positional Encoding for Standard Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # The line below is changed to slice along dimension 0 up to x.size(0)
        x = x + self.pe[:x.size(0), :, :]  # Adjust slicing to match input sequence length
        return self.dropout(x)


# Define Standard Transformer
class StandardTransformer(nn.Module):
    def __init__(self, num_tokens, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(num_tokens, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=0,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(d_model, num_tokens)

    def forward(self, src, src_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer.encoder(src, src_mask)
        return self.fc_out(output)


# Function to load data
def get_data():
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    train_data = torch.tensor(tokenized_dataset["train"]["input_ids"], dtype=torch.long)
    val_data = torch.tensor(tokenized_dataset["validation"]["input_ids"], dtype=torch.long)
    test_data = torch.tensor(tokenized_dataset["test"]["input_ids"], dtype=torch.long)

    vocab_size = tokenizer.vocab_size
    return train_data, val_data, test_data, vocab_size


# Benchmarking function
def benchmark_models(sequence_lengths=[128, 256, 512, 1024], batch_size=8):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = {
        "standard": {"time": [], "memory": [], "perplexity": []},
        "titans": {"time": [], "memory": [], "perplexity": []},
    }

    train_data, val_data, test_data, vocab_size = get_data()
    criterion = nn.CrossEntropyLoss()

    for seq_len in sequence_lengths:
        print(f"\nTesting sequence length: {seq_len}")

        # Initialize models
        standard_transformer = StandardTransformer(num_tokens=vocab_size).to(device)
        titans_transformer = TitansTransformer(num_tokens=vocab_size, memory_size=512).to(device)

        # Prepare test batch
        test_batch = test_data[:batch_size].to(device)

        # Test Standard Transformer
        torch.cuda.empty_cache()
        start_time = time.time()
        with torch.no_grad():
            standard_transformer.eval()
            output_standard = standard_transformer(test_batch)
            perplexity_standard = torch.exp(
                criterion(output_standard.view(-1, output_standard.size(-1)), test_batch.view(-1))
            )
        end_time = time.time()

        results["standard"]["time"].append(end_time - start_time)
        results["standard"]["perplexity"].append(perplexity_standard.item())

        # Test Titans Transformer
        torch.cuda.empty_cache()
        start_time = time.time()
        with torch.no_grad():
            titans_transformer.eval()
            output_titans = titans_transformer(test_batch)
            perplexity_titans = torch.exp(
                criterion(output_titans.view(-1, output_titans.size(-1)), test_batch.view(-1))
            )
        end_time = time.time()

        results["titans"]["time"].append(end_time - start_time)
        results["titans"]["perplexity"].append(perplexity_titans.item())

        print(f"Standard Transformer - Time: {results['standard']['time'][-1]:.4f}s, "
              f"Perplexity: {results['standard']['perplexity'][-1]:.2f}")
        print(f"Titans Transformer - Time: {results['titans']['time'][-1]:.4f}s, "
              f"Perplexity: {results['titans']['perplexity'][-1]:.2f}")

    return results, sequence_lengths


# Plot benchmark results
def plot_benchmark_results(results, sequence_lengths):
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Inference time
    axes[0].plot(sequence_lengths, results["standard"]["time"], label="Standard Transformer", color="blue")
    axes[0].plot(sequence_lengths, results["titans"]["time"], label="Titans Transformer", color="red")
    axes[0].set_title("Inference Time")
    axes[0].set_xlabel("Sequence Length")
    axes[0].set_ylabel("Time (s)")
    axes[0].legend()

    # Perplexity
    axes[1].plot(sequence_lengths, results["standard"]["perplexity"], label="Standard Transformer", color="blue")
    axes[1].plot(sequence_lengths, results["titans"]["perplexity"], label="Titans Transformer", color="red")
    axes[1].set_title("Perplexity")
    axes[1].set_xlabel("Sequence Length")
    axes[1].set_ylabel("Perplexity")
    axes[1].legend()

    plt.tight_layout()
    plt.savefig("benchmark_results.png")
    plt.close()


# Main function
if __name__ == "__main__":
    print("\nRunning benchmarks...")
    results, sequence_lengths = benchmark_models()
    plot_benchmark_results(results, sequence_lengths)
    print("\nBenchmark results saved to 'benchmark_results.png'.")



Running benchmarks...

Testing sequence length: 128
Standard Transformer - Time: 0.1598s, Perplexity: 74703.84
Titans Transformer - Time: 0.0329s, Perplexity: 63075.21

Testing sequence length: 256
Standard Transformer - Time: 0.0040s, Perplexity: 63956.66
Titans Transformer - Time: 0.0078s, Perplexity: 56063.02

Testing sequence length: 512
Standard Transformer - Time: 0.0044s, Perplexity: 40087.42
Titans Transformer - Time: 0.0086s, Perplexity: 40837.51

Testing sequence length: 1024
Standard Transformer - Time: 0.0038s, Perplexity: 31984.99
Titans Transformer - Time: 0.0086s, Perplexity: 47995.84

Benchmark results saved to 'benchmark_results.png'.
