<a href="https://colab.research.google.com/github/newmantic/T5/blob/main/T5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class TokenEmbedding(nn.Module):
    """Look up a learned vector for every token id.

    Args:
        vocab_size: Number of distinct token ids (rows of the table).
        embed_size: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)

    def forward(self, x):
        """Map integer ids ``x`` to vectors; adds a trailing ``embed_size`` dim."""
        table = self.embedding
        vectors = table(x)
        return vectors

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns receive ``sin(pos * r_i)`` and odd columns receive
    ``cos(pos * r_i)``, with ``r_i = 1 / 10000 ** (2 * i / embed_size)``.

    Args:
        embed_size: Feature width of the embeddings. May be odd; in that case
            the last column is a sine with no matching cosine.
        max_len: Longest sequence length the precomputed table covers. Inputs
            longer than this fail to broadcast in ``forward``.
    """

    def __init__(self, embed_size, max_len=512):
        super(PositionalEncoding, self).__init__()
        pos = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # ceil(embed_size / 2) frequencies so that an odd embed_size still
        # fills every even (sine) column; identical to embed_size // 2 when even.
        i = torch.arange(0, (embed_size + 1) // 2).float()
        angle_rates = 1 / (10000 ** (2 * i / embed_size))
        angles = pos * angle_rates  # (max_len, ceil(embed_size / 2))

        encoding = torch.zeros(max_len, embed_size)
        encoding[:, 0::2] = torch.sin(angles)
        # There are only embed_size // 2 odd (cosine) columns.
        encoding[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Register as a buffer so the table follows the module through
        # .to(device) / .cuda() / .half(); non-persistent keeps state_dict keys
        # unchanged (the table is deterministic and need not be saved).
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, seq_len, embed_size) by the slice below.
        """
        return x + self.encoding[:, : x.size(1), :]

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The value, key and query inputs are each projected by a learned
    ``embed_size -> embed_size`` linear map and split into ``heads`` slices of
    ``head_dim``. Each head computes ``softmax(q k^T / sqrt(head_dim)) v``. The
    heads are concatenated and passed through ``fc_out``.

    Args:
        embed_size: Model width; the last dimension of every input and of the output.
        heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"

        # Full-width input projections, applied before splitting into heads.
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: (N, value_len, embed_size).
            keys: (N, key_len, embed_size); key_len must equal value_len.
            query: (N, query_len, embed_size).
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len); positions where ``mask == 0``
                receive no attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into self.heads pieces of head_dim.
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # Most negative finite value of the dtype: a literal -1e20 cannot
            # be represented in float16 and makes masked_fill raise.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Scaled dot-product attention divides by sqrt(d_k), the per-head width.
        attention = torch.softmax(energy / math.sqrt(self.head_dim), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )
        return self.fc_out(out)

In [5]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        embed_size: Input and output width.
        ff_hidden_size: Width of the hidden layer.
        dropout: Drop probability applied to the hidden activations.
    """

    def __init__(self, embed_size, ff_hidden_size, dropout):
        super().__init__()
        # Registration order (fc1, fc2, dropout) kept so parameter init and
        # state_dict ordering are unchanged.
        self.fc1 = nn.Linear(embed_size, ff_hidden_size)
        self.fc2 = nn.Linear(ff_hidden_size, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [6]:
class TransformerBlock(nn.Module):
    """Post-norm encoder layer: multi-head attention followed by a feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(residual + Dropout(sublayer))``.
    Dropout therefore hits only the sub-layer output, never the skip path or
    the normalized result handed to the next layer.

    Args:
        embed_size: Model width.
        heads: Number of attention heads.
        ff_hidden_size: Hidden width of the feed-forward sub-layer.
        dropout: Drop probability for the residual and feed-forward dropout.
    """

    def __init__(self, embed_size, heads, ff_hidden_size, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden_size, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Attend with ``query`` over ``key``/``value``; ``query`` is also the residual."""
        attention = self.attention(value, key, query, mask)
        # Residual dropout: drop the sub-layer output before the add, then
        # normalize; the block output is left un-dropped.
        x = self.norm1(query + self.dropout(attention))
        ff_out = self.ff(x)
        out = self.norm2(x + self.dropout(ff_out))
        return out

In [7]:
class Encoder(nn.Module):
    """Token embedding + sinusoidal positions, then a stack of self-attention blocks.

    Args:
        vocab_size: Size of the source vocabulary.
        embed_size: Model width.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        heads: Attention heads per layer.
        ff_hidden_size: Feed-forward hidden width per layer.
        dropout: Dropout probability passed to every layer.
        max_len: Longest source length supported by the position table.
    """

    def __init__(self, vocab_size, embed_size, num_layers, heads, ff_hidden_size, dropout, max_len):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.position_encoding = PositionalEncoding(embed_size, max_len)
        blocks = [
            TransformerBlock(embed_size, heads, ff_hidden_size, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask):
        """Encode integer token ids ``x`` (dim 1 is the sequence axis).

        ``mask`` is forwarded unchanged to every layer's attention.
        """
        hidden = self.position_encoding(self.token_embedding(x))
        for block in self.layers:
            # Self-attention: one tensor serves as value, key and query.
            hidden = block(hidden, hidden, hidden, mask)
        return hidden

In [8]:
class DecoderBlock(nn.Module):
    """Post-norm decoder layer: masked self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as
    ``LayerNorm(residual + Dropout(sublayer))``, so dropout never hits the skip
    path or the normalized output.

    Args:
        embed_size: Model width.
        heads: Number of attention heads (self- and cross-attention).
        ff_hidden_size: Hidden width of the feed-forward sub-layer.
        dropout: Drop probability for residual and feed-forward dropout.
    """

    def __init__(self, embed_size, heads, ff_hidden_size, dropout):
        super(DecoderBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.cross_attention = MultiHeadAttention(embed_size, heads)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, ff_hidden_size, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Decode ``x`` given encoder ``value``/``key``.

        ``trg_mask`` gates the self-attention and ``src_mask`` gates the
        cross-attention.
        """
        attention = self.attention(x, x, x, trg_mask)
        query = self.norm1(x + self.dropout(attention))
        cross_attention = self.cross_attention(value, key, query, src_mask)
        x = self.norm2(query + self.dropout(cross_attention))
        ff_out = self.ff(x)
        out = self.norm3(x + self.dropout(ff_out))
        return out

class Decoder(nn.Module):
    """Token embedding + positions, a stack of decoder blocks, and a vocab projection.

    Args:
        vocab_size: Size of the target vocabulary (also the output width).
        embed_size: Model width.
        num_layers: Number of stacked ``DecoderBlock`` layers.
        heads: Attention heads per layer.
        ff_hidden_size: Feed-forward hidden width per layer.
        dropout: Dropout probability passed to every layer.
        max_len: Longest target length supported by the position table.
    """

    def __init__(self, vocab_size, embed_size, num_layers, heads, ff_hidden_size, dropout, max_len):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, embed_size)
        self.position_encoding = PositionalEncoding(embed_size, max_len)
        blocks = [
            DecoderBlock(embed_size, heads, ff_hidden_size, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Return unnormalized scores over the vocabulary for each target position.

        ``enc_out`` is used as both the cross-attention values and keys.
        """
        hidden = self.position_encoding(self.token_embedding(x))
        for block in self.layers:
            hidden = block(hidden, enc_out, enc_out, src_mask, trg_mask)
        logits = self.fc_out(hidden)
        return logits

In [9]:
class T5(nn.Module):
    """Encoder–decoder sequence model built from ``Encoder`` and ``Decoder``.

    The encoder and decoder each have their own token embedding and share
    all size hyper-parameters.

    NOTE(review): despite the name, this is a standard encoder–decoder
    Transformer (sinusoidal absolute positions, LayerNorm); it does not
    reproduce T5-specific components. Confirm before comparing against T5.
    """

    def __init__(self, vocab_size, embed_size, num_layers, heads, ff_hidden_size, dropout, max_len):
        super().__init__()
        hparams = (vocab_size, embed_size, num_layers, heads, ff_hidden_size, dropout, max_len)
        self.encoder = Encoder(*hparams)
        self.decoder = Decoder(*hparams)

    def forward(self, src, trg, src_mask, trg_mask):
        """Encode ``src`` and return the decoder's vocabulary scores for ``trg``."""
        memory = self.encoder(src, src_mask)
        return self.decoder(trg, memory, src_mask, trg_mask)

In [12]:
# Define a small vocab size and model parameters for the example
vocab_size = 10000
embed_size = 128
num_layers = 2
heads = 8
ff_hidden_size = 512
dropout = 0.3
max_len = 512

# Instantiate the T5 model
model = T5(vocab_size, embed_size, num_layers, heads, ff_hidden_size, dropout, max_len)
# This is an inference demo: disable dropout so the forward pass is deterministic.
model.eval()

# Example input: batch size of 2, sequence length of 10 for both source and target sequences
src = torch.randint(0, vocab_size, (2, 10))
trg = torch.randint(0, vocab_size, (2, 10))

# No padding in this example, so the source needs no mask.
src_mask = None
# Causal mask: target position i may attend only to target positions <= i.
# Shape (trg_len, trg_len) broadcasts over batch and heads; zeros are masked.
trg_len = trg.shape[1]
trg_mask = torch.tril(torch.ones(trg_len, trg_len))

# Forward pass through the model (no gradients needed for a demo)
with torch.no_grad():
    output = model(src, trg, src_mask, trg_mask)

# Print the input and output tensors, plus the output shape to verify
print(f"Source Input (src): \n{src}")
print(f"Target Input (trg): \n{trg}")
print(f"Model Output: \n{output}")
print(f"Output shape: {tuple(output.shape)}")  # (batch, trg_len, vocab_size)

Source Input (src): 
tensor([[5860, 6047, 2119, 7529,   56, 8197, 2247, 2751, 7140, 1277],
        [2096, 1653, 6168, 1225,  371, 5284,  365, 7736, 3973, 7943]])
Target Input (trg): 
tensor([[9406, 7932, 3639, 5834, 9317, 3421, 2432, 7695, 8233, 3720],
        [ 946, 2875, 8101,  963, 3326, 5471, 9853, 5283, 7959, 9894]])
Model Output: 
tensor([[[ 0.1689, -0.0824, -0.0536,  ..., -0.3140,  0.9741, -0.1810],
         [ 0.2126,  0.4521,  2.0198,  ..., -1.7455,  0.0716,  1.1459],
         [-0.8085, -0.0210, -1.4434,  ...,  0.3958,  0.4651,  0.8200],
         ...,
         [ 0.6445,  0.4750,  0.5597,  ..., -0.6545, -0.3813,  0.9749],
         [ 0.0367, -0.8057,  1.6405,  ..., -1.4660,  0.7535,  1.3746],
         [-0.4985,  0.3318, -0.8476,  ..., -0.1576, -0.4301,  0.2718]],

        [[ 0.9819, -0.6797, -0.7677,  ..., -0.2764,  1.0822, -0.5082],
         [-0.6578,  0.0683,  0.7709,  ..., -0.1163,  0.1826,  0.9121],
         [-0.1463, -1.5678, -0.3360,  ...,  0.1002,  0.9829,  0.6202],
      