# Custom Transformer
This notebook builds a custom summarization transformer and compares its summaries with those produced by the previous methods and models.

## Transformer Model #1

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to sequence-first embeddings.

    The precomputed table ``pe`` has shape ``[max_len, 1, d_model]``. It
    broadcasts over the batch axis of inputs shaped ``[seq_len, batch, d_model]``.

    Args:
        d_model: Embedding width. Odd widths are supported: the last channel
            then carries a sine term with no matching cosine channel.
        max_len: Longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd channel than even channel.
        # Trim div_term so the cosine half matches the 1::2 slice width.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Raises:
            ValueError: if ``x.size(0)`` exceeds the ``max_len`` of the table.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

class TransformerEncoderLayer(nn.Module):
    """One post-LayerNorm Transformer encoder layer (sequence-first).

    The inner ``nn.MultiheadAttention`` uses the default ``batch_first=False``,
    so inputs are ``[seq_len, batch, d_model]``. Each sub-layer computes
    ``x = LayerNorm(x + Dropout(sublayer(x)))``, so the residual stream itself
    is normalised after every sub-layer.

    Args:
        d_model: Model width; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the position-wise feed-forward net.
        dropout: Dropout probability for attention, feed-forward and residuals.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        if d_model % nhead != 0:
            raise ValueError("d_model must be divisible by nhead")

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Residual-path dropouts. They hold no parameters, so the state_dict is unchanged.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode ``src`` of shape ``[seq_len, batch, d_model]``.

        Args:
            src: Input sequence.
            src_mask: Optional attention mask forwarded to ``nn.MultiheadAttention``.
            src_key_padding_mask: Optional ``[batch, seq_len]`` mask of key
                positions to ignore, forwarded to ``nn.MultiheadAttention``.
        """
        attn_out = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,  # only the attended values are used
        )[0]
        src = self.norm1(src + self.dropout1(attn_out))
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))
        return src


# Transformer Decoder Layer
class TransformerDecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer (sequence-first).

    Both attention modules use the default ``batch_first=False``, so ``tgt`` is
    ``[tgt_len, batch, d_model]`` and ``memory`` is ``[src_len, batch, d_model]``.
    Each sub-layer computes ``x = LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Model width; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the position-wise feed-forward net.
        dropout: Dropout probability for attention, feed-forward and residuals.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()
        if d_model % nhead != 0:
            raise ValueError("d_model must be divisible by nhead")
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Residual-path dropouts. They hold no parameters, so the state_dict is unchanged.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, causal=True):
        """Decode ``tgt`` while attending to encoder output ``memory``.

        Args:
            tgt: Target sequence ``[tgt_len, batch, d_model]``.
            memory: Encoder output ``[src_len, batch, d_model]``.
            tgt_mask: Optional self-attention mask. If None and ``causal`` is
                True, a causal mask is built so position i only attends to
                positions <= i.
            memory_mask: Optional cross-attention mask.
            tgt_key_padding_mask: Optional ``[batch, tgt_len]`` padding mask.
            memory_key_padding_mask: Optional ``[batch, src_len]`` padding mask.
            causal: Build the default causal mask when ``tgt_mask`` is None.
        """
        if tgt_mask is None and causal:
            tgt_len = tgt.size(0)
            # Boolean mask: True above the diagonal means "may not attend" (future positions).
            tgt_mask = torch.triu(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device), diagonal=1
            )
        self_out = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm1(tgt + self.dropout1(self_out))
        cross_out = self.multihead_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm2(tgt + self.dropout2(cross_out))
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = self.norm3(tgt + self.dropout3(ff_out))
        return tgt

# Full Transformer Model for Summarization
class CustomTransformerSummarizer(nn.Module):
    """Sequence-first encoder-decoder summarizer (Model #1).

    Built from ``PositionalEncoding``, ``TransformerEncoderLayer`` and
    ``TransformerDecoderLayer``. Those components index the sequence on
    dimension 0, so ``src`` and ``tgt`` are token-id tensors shaped
    ``[seq_len, batch]``.

    Args:
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size (also the output logit width).
        d_model: Model width.
        nhead: Attention heads per layer.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Feed-forward hidden width.
        dropout: Dropout probability passed to every layer.
        max_len: Length of the positional-encoding tables.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout=0.1, max_len=500):
        super(CustomTransformerSummarizer, self).__init__()
        self.d_model = d_model
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        # The previous code used sqrt(ids.size(-1)), i.e. a batch/sequence
        # dimension of the token-id tensor, which is unrelated to the model width.
        self.embed_scale = math.sqrt(d_model)
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.pos_decoder = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList([TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_decoder_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def encode(self, src):
        """Embed ``src`` ids ``[src_len, batch]`` and run the encoder stack."""
        src = self.encoder_embedding(src) * self.embed_scale
        src = self.pos_encoder(src)
        for layer in self.encoder_layers:
            src = layer(src)
        return src

    def decode(self, tgt, memory):
        """Embed ``tgt`` ids ``[tgt_len, batch]`` and run the decoder stack over ``memory``."""
        tgt = self.decoder_embedding(tgt) * self.embed_scale
        tgt = self.pos_decoder(tgt)
        for layer in self.decoder_layers:
            tgt = layer(tgt, memory)
        return tgt

    def forward(self, src, tgt):
        """Return logits ``[tgt_len, batch, tgt_vocab_size]``."""
        memory = self.encode(src)
        output = self.decode(tgt, memory)
        return self.fc_out(output)


In [None]:
## Instantiate the CustomTransformerSummarizer class
# custom_transformer_model = CustomTransformerSummarizer(...)  # supply the constructor arguments

## Import datasets to train the custom transformer model

In [8]:
# Custom training loop for the custom transformer model
num_epochs = 5

# Assumes `custom_transformer_model` is an instance of the encode/decode
# CustomTransformerSummarizer (Model #1), which works on sequence-first tensors
# ([seq_len, batch]). Set this to False for a batch-first model such as the
# nn.Transformer-based Model #2.
model_is_seq_first = True

for epoch in range(num_epochs):
    custom_transformer_model.train()
    total_loss = 0

    for batch in dataloader:
        # Batches may be (input_ids, target_ids) or
        # (input_ids, attention_mask, target_ids). Neither model's forward()
        # accepts an attention mask, so only the first and last items are used.
        input_ids, target_ids = batch[0], batch[-1]  # assumed [batch, seq_len]; TODO confirm for the real dataloader

        # Teacher forcing: the decoder reads tokens 0..T-2 and is scored on
        # predicting tokens 1..T-1. Using the unshifted targets as both input
        # and label would only teach the model to copy its input.
        decoder_input = target_ids[:, :-1]
        labels = target_ids[:, 1:]

        if model_is_seq_first:
            # Transpose into the model's layout, then back so logits are [batch, T-1, vocab].
            logits = custom_transformer_model(input_ids.t(), decoder_input.t()).transpose(0, 1)
        else:
            logits = custom_transformer_model(input_ids, decoder_input)

        # reshape (not view): the sliced labels and transposed logits are non-contiguous.
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}")


In [None]:
# Once the model has been trained, we can define a function that uses it to generate summaries.
# Define a summarization function
def summarize_text(model, tokenizer, text, src_vocab_size, tgt_vocab_size, device, max_len=128, batch_first=False):
    """Greedily generate a summary of ``text`` with an encoder-decoder model.

    Args:
        model: Module called as ``model(src, tgt)`` that returns next-token
            logits for every target position.
        tokenizer: Tokenizer exposing ``encode``, ``decode``, ``cls_token_id``
            (used as the start token) and ``sep_token_id`` (used as the stop token).
        text: Source document to summarize.
        src_vocab_size: Unused; kept for backward compatibility.
        tgt_vocab_size: Unused; kept for backward compatibility.
        device: Device on which the input tensors are created.
        max_len: Source truncation length and maximum number of generated tokens.
        batch_first: False (default) for sequence-first models such as the
            encode/decode ``CustomTransformerSummarizer`` (ids ``[seq, batch]``,
            logits ``[seq, batch, vocab]``). True for batch-first models.

    Returns:
        The decoded summary with special tokens skipped.
    """
    was_training = model.training
    model.eval()  # dropout must be off for deterministic greedy decoding

    # Tokenizer output is [1, src_len]
    src_ids = tokenizer.encode(text, return_tensors='pt', max_length=max_len, truncation=True).to(device)
    tgt_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)  # [1, 1], start with [CLS]
    src = src_ids if batch_first else src_ids.t()

    try:
        with torch.no_grad():
            for _ in range(max_len):
                tgt = tgt_ids if batch_first else tgt_ids.t()
                logits = model(src, tgt)
                last_logits = logits[:, -1] if batch_first else logits[-1]  # [1, vocab]
                next_token_id = last_logits.argmax(dim=-1, keepdim=True)  # [1, 1]
                tgt_ids = torch.cat((tgt_ids, next_token_id), dim=1)
                if next_token_id.item() == tokenizer.sep_token_id:  # End on [SEP]
                    break
    finally:
        model.train(was_training)  # leave the model in the mode we found it in

    # Decode the output tokens
    return tokenizer.decode(tgt_ids[0], skip_special_tokens=True)


In [5]:
# # Example usage
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward = 30522, 30522, 512, 8, 4, 4, 2048

# model = CustomTransformerSummarizer(src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward).to(device)

# # Test text to summarize
# sample_text = "Artificial intelligence has become a transformative force in multiple domains, including healthcare, finance, and autonomous systems..."

# # Assuming you have a tokenizer compatible with BERT or similar (for example, BERT tokenizer)
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# summary = summarize_text(model, tokenizer, sample_text, src_vocab_size, tgt_vocab_size, device)
# print("Generated Summary:\n", summary)

## Training Transformer Model #2

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# ---- Model hyperparameters ----
d_model = 512           # embedding / hidden width
nhead = 8               # attention heads (d_model must be divisible by this)
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048  # inner width of each position-wise feed-forward net
dropout = 0.1

# ---- Vocabulary / sequence settings ----
vocab_size = 10000      # example vocabulary size
max_seq_length = 512    # rows in the learned positional table

class CustomTransformerSummarizer(nn.Module):
    """Batch-first encoder-decoder summarizer built on ``nn.Transformer`` (Model #2).

    Source and target tokens share one embedding table and one learned,
    zero-initialised positional table of ``max_seq_length`` rows.

    Args:
        vocab_size: Shared vocabulary size (also the output logit width).
        d_model: Model width.
        nhead: Attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Feed-forward hidden width.
        max_seq_length: Longest source or target sequence supported.
        dropout: Dropout probability inside ``nn.Transformer``.
    """

    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout=0.1):
        super(CustomTransformerSummarizer, self).__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))

        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                          dropout=dropout, batch_first=True)

        self.fc_out = nn.Linear(d_model, vocab_size)

    def _embed(self, ids):
        """Embed ``[batch, seq]`` token ids and add the learned positions."""
        seq_len = ids.size(1)
        if seq_len > self.positional_encoding.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.positional_encoding.size(1)}"
            )
        return self.embedding(ids) + self.positional_encoding[:, :seq_len, :]

    def forward(self, input_ids, target_ids, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Return logits ``[batch, tgt_len, vocab_size]``.

        Args:
            input_ids: Source ids ``[batch, src_len]``.
            target_ids: Decoder input ids ``[batch, tgt_len]``.
            src_key_padding_mask: Optional ``[batch, src_len]`` padding mask.
                It is also applied to the decoder's cross-attention over memory.
            tgt_key_padding_mask: Optional ``[batch, tgt_len]`` padding mask.

        Raises:
            ValueError: if either sequence is longer than ``max_seq_length``.
        """
        input_embedding = self._embed(input_ids)
        target_embedding = self._embed(target_ids)

        # Causal mask (True = blocked) so each target position only attends to
        # itself and earlier positions. Without it, teacher-forced training
        # lets the decoder read the very token it is asked to predict.
        tgt_len = target_ids.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=target_ids.device), diagonal=1
        )

        # Transformer expects (batch, seq, feature) because batch_first=True
        output = self.transformer(
            src=input_embedding,
            tgt=target_embedding,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        # Map output to vocabulary size
        return self.fc_out(output)

# Initialize the custom model
model = CustomTransformerSummarizer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout)

PAD_ID = 0                # ignored by the loss (assumed padding id); the dummy data never contains it
BOS_ID = 1                # decoder start token at inference, adjust as needed
EOS_ID = 2                # generation stops after this token, adjust as needed
MAX_SUMMARY_TOKENS = 100  # hard cap on generated summary length

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Dummy dataset: uniformly random ids in [1, vocab_size), so there is no real
# signal to learn. Swap in tokenized articles/summaries for actual training.
input_ids = torch.randint(1, vocab_size, (32, max_seq_length))  # Example input tokens
target_ids = torch.randint(1, vocab_size, (32, max_seq_length))  # Example target tokens
dataset = TensorDataset(input_ids, target_ids)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Training loop
num_epochs = 3

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    # Distinct loop names so the dataset tensors `input_ids` / `target_ids`
    # are not overwritten by the last shuffled batch.
    for src_batch, tgt_batch in dataloader:
        # Teacher forcing: decoder input is targets[:-1], labels are targets[1:]
        decoder_input = tgt_batch[:, :-1]
        labels = tgt_batch[:, 1:]

        logits = model(input_ids=src_batch, target_ids=decoder_input)

        # reshape (not view): `labels` is a non-contiguous slice
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}")

# Testing / Inference: greedy decoding of the first dataset example
model.eval()
with torch.no_grad():
    test_input = input_ids[:1]  # [1, src_len]
    test_target_input = torch.tensor([[BOS_ID]])  # [1, 1]
    summary = []

    for _ in range(MAX_SUMMARY_TOKENS):
        output = model(test_input, test_target_input)  # [1, cur_len, vocab]
        next_token = output[:, -1].argmax(dim=-1, keepdim=True)  # [1, 1]
        summary.append(next_token.item())

        # Update the target sequence with the new token
        test_target_input = torch.cat([test_target_input, next_token], dim=1)

        if next_token.item() == EOS_ID:
            break

    # Decode summary to text if you have a tokenizer
    print("Generated Summary:", summary)  # Replace with tokenizer.decode if using a tokenizer


Epoch 1/3, Loss: 9.381876230239868
Epoch 2/3, Loss: 9.252301692962646
Epoch 3/3, Loss: 9.157200932502747
Generated Summary: [1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133, 1133]
