# Custom Transformer
Building a custom Transformer for abstractive summarization and comparing its summaries with those produced by the previous methods and models.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Even feature indices hold ``sin(pos / 10000^(2i/d_model))`` and odd
    indices the matching ``cos``. The table is precomputed once and stored as
    a non-trainable buffer of shape ``(max_len, 1, d_model)`` so it broadcasts
    over the batch dimension.

    Args:
        d_model: Feature size of the inputs. Odd sizes are supported; the
            last (even-indexed) feature then simply has no ``cos`` partner.
        max_len: Longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns, so
        # only the first d_model // 2 frequencies get a cosine partner.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model): broadcasts over batch.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Args:
            x: Tensor of shape ``(seq_len, batch, d_model)``. Dim 0 must be the
                sequence axis and ``seq_len`` must not exceed ``max_len``.
        """
        return x + self.pe[: x.size(0)]

class TransformerEncoderLayer(nn.Module):
    """One post-LayerNorm Transformer encoder layer (self-attention + FFN).

    Each sub-layer is wrapped as ``x = LayerNorm(x + Dropout(sublayer(x)))``,
    the arrangement used by ``nn.TransformerEncoderLayer`` with its default
    ``norm_first=False``. Normalising the residual sum (rather than only the
    sub-layer output) keeps activations from growing with depth.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads; must divide ``d_model``.
        dim_feedforward: Hidden size of the position-wise feed-forward net.
        dropout: Dropout probability used inside attention, inside the FFN,
            and on both residual branches.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        # Explicit check (not ``assert``) so it is not stripped under ``python -O``.
        if d_model % nhead != 0:
            raise ValueError("d_model must be divisible by nhead")

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        # Stateless, so one module is reused for every dropout site.
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode a sequence-first batch.

        Args:
            src: ``(seq_len, batch, d_model)`` input; ``nn.MultiheadAttention``
                is built with its default ``batch_first=False``.
            src_mask: Optional ``(seq_len, seq_len)`` attention mask in
                ``nn.MultiheadAttention`` ``attn_mask`` format.
            src_key_padding_mask: Optional ``(batch, seq_len)`` bool mask,
                ``True`` marking padded positions that must not be attended to.

        Returns:
            ``(seq_len, batch, d_model)`` tensor.
        """
        attn_out = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,
        )[0]
        src = self.norm1(src + self.dropout(attn_out))
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(src))))
        return self.norm2(src + self.dropout(ff_out))


# Transformer Decoder Layer
class TransformerDecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer.

    The layer has three sub-layers, each wrapped as
    ``x = LayerNorm(x + Dropout(sublayer(x)))``:

    - masked (causal) self-attention over ``tgt``,
    - cross-attention from ``tgt`` to ``memory``,
    - a position-wise ReLU feed-forward network.

    Args:
        d_model: Feature size of ``tgt``, ``memory`` and the output.
        nhead: Number of attention heads; must divide ``d_model``.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used inside attention, inside the FFN,
            and on the residual branches.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()
        # Explicit check (not ``assert``) so it is not stripped under ``python -O``.
        if d_model % nhead != 0:
            raise ValueError("d_model must be divisible by nhead")
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        # Stateless, so one module is reused for every dropout site.
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tgt, memory, tgt_mask=None, memory_key_padding_mask=None, causal=True):
        """Decode a sequence-first batch against encoder memory.

        Args:
            tgt: ``(tgt_len, batch, d_model)`` decoder input.
            memory: ``(src_len, batch, d_model)`` encoder output.
            tgt_mask: Optional ``(tgt_len, tgt_len)`` self-attention mask in
                ``nn.MultiheadAttention`` ``attn_mask`` format; overrides
                ``causal`` when given.
            memory_key_padding_mask: Optional ``(batch, src_len)`` bool mask,
                ``True`` marking padded memory positions.
            causal: When ``tgt_mask`` is ``None``, prevent each target position
                from attending to later positions, as autoregressive training
                and decoding require.

        Returns:
            ``(tgt_len, batch, d_model)`` tensor.
        """
        if tgt_mask is None and causal:
            tgt_len = tgt.size(0)
            # Additive mask: -inf strictly above the diagonal blocks the future.
            tgt_mask = torch.triu(
                torch.full((tgt_len, tgt_len), float('-inf'), device=tgt.device, dtype=tgt.dtype),
                diagonal=1,
            )
        attn_out = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, need_weights=False)[0]
        tgt = self.norm1(tgt + self.dropout(attn_out))
        attn_out = self.multihead_attn(
            tgt, memory, memory,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm2(tgt + self.dropout(attn_out))
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        return self.norm3(tgt + self.dropout(ff_out))

# Full Transformer Model for Summarization
class CustomTransformerSummarizer(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target logits.

    The public tensors are batch-first:

    - ``src`` is ``(batch, src_len)`` token ids, as produced by a tokenizer
      called with ``return_tensors='pt'``.
    - ``tgt`` is ``(batch, tgt_len)`` token ids.
    - ``forward`` returns ``(batch, tgt_len, tgt_vocab_size)`` logits.

    Internally the activations are transposed to ``(seq_len, batch, d_model)``.
    ``PositionalEncoding`` indexes positions along dim 0, and the layers use
    ``nn.MultiheadAttention`` with its default ``batch_first=False``. Feeding
    them batch-first tensors would give every token the same position
    encoding and make attention mix the wrong axis.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout=0.1, max_len=500):
        super(CustomTransformerSummarizer, self).__init__()
        self.d_model = d_model
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.pos_decoder = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList([TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_decoder_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def encode(self, src):
        """Encode ``(batch, src_len)`` ids into ``(batch, src_len, d_model)`` memory."""
        # Scale embeddings by sqrt(d_model); the old code used the length of
        # the id tensor's last axis (the sequence length) by mistake.
        x = self.encoder_embedding(src) * math.sqrt(self.d_model)
        x = x.transpose(0, 1)  # -> (src_len, batch, d_model)
        x = self.pos_encoder(x)
        for layer in self.encoder_layers:
            x = layer(x)
        return x.transpose(0, 1)  # back to (batch, src_len, d_model)

    def decode(self, tgt, memory):
        """Decode ``(batch, tgt_len)`` ids against ``(batch, src_len, d_model)`` memory.

        Returns:
            ``(batch, tgt_len, d_model)`` decoder states.
        """
        x = self.decoder_embedding(tgt) * math.sqrt(self.d_model)
        x = x.transpose(0, 1)  # -> (tgt_len, batch, d_model)
        x = self.pos_decoder(x)
        mem = memory.transpose(0, 1)  # -> (src_len, batch, d_model)
        for layer in self.decoder_layers:
            x = layer(x, mem)
        return x.transpose(0, 1)  # back to (batch, tgt_len, d_model)

    def forward(self, src, tgt):
        """Return ``(batch, tgt_len, tgt_vocab_size)`` logits for ``tgt`` given ``src``."""
        memory = self.encode(src)
        output = self.decode(tgt, memory)
        return self.fc_out(output)

# Define a summarization function
def summarize_text(model, tokenizer, text, src_vocab_size, tgt_vocab_size, device, max_len=128):
    """Greedily decode a summary of ``text`` with ``model``.

    Args:
        model: ``nn.Module`` called as ``model(src_ids, tgt_ids)`` that returns
            ``(batch, tgt_len, vocab)`` logits. It is switched to eval mode
            (dropout off) for decoding; its previous mode is restored afterwards.
        tokenizer: Tokenizer exposing ``encode``/``decode`` with the Hugging Face
            signature used below. The start and end tokens are
            ``cls_token_id``/``sep_token_id``, falling back to
            ``bos_token_id``/``eos_token_id`` when those are ``None``.
        text: Input document.
        src_vocab_size: Unused; kept for backward compatibility.
        tgt_vocab_size: Unused; kept for backward compatibility.
        device: Device the token ids are placed on.
        max_len: Maximum number of source tokens kept, and maximum number of
            tokens generated.

    Returns:
        The decoded summary with special tokens removed.

    Raises:
        ValueError: If the tokenizer defines no usable start token.
    """
    start_id = getattr(tokenizer, "cls_token_id", None)
    if start_id is None:
        start_id = getattr(tokenizer, "bos_token_id", None)
    if start_id is None:
        raise ValueError("tokenizer defines neither cls_token_id nor bos_token_id")
    end_id = getattr(tokenizer, "sep_token_id", None)
    if end_id is None:
        end_id = getattr(tokenizer, "eos_token_id", None)

    src_ids = tokenizer.encode(text, return_tensors='pt', max_length=max_len, truncation=True).to(device)
    tgt_ids = torch.tensor([[start_id]], device=device)

    # Dropout must be off at inference time, otherwise greedy decoding is
    # nondeterministic; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_len):
                logits = model(src_ids, tgt_ids)
                # Most likely next token from the last target position: (batch, 1).
                next_token_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                tgt_ids = torch.cat((tgt_ids, next_token_id), dim=1)
                if end_id is not None and next_token_id.item() == end_id:
                    break
    finally:
        model.train(was_training)

    return tokenizer.decode(tgt_ids[0], skip_special_tokens=True)




In [5]:
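# # Example usage (the model below is freshly initialised and untrained, so its "summary" is arbitrary tokens until it is trained)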
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward = 30522, 30522, 512, 8, 4, 4, 2048

# model = CustomTransformerSummarizer(src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward).to(device)

# # Test text to summarize
# sample_text = "Artificial intelligence has become a transformative force in multiple domains, including healthcare, finance, and autonomous systems..."

# # Any tokenizer with [CLS]/[SEP] special tokens works; here the BERT tokenizer is used
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# summary = summarize_text(model, tokenizer, sample_text, src_vocab_size, tgt_vocab_size, device)
# print("Generated Summary:\n", summary)