In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import string
from tqdm import tqdm

In [2]:
class TextDataset(Dataset):
    def __init__(self, file_path, sequence_length, min_word_freq=8):
        self.sequence_length = sequence_length
        self.tokens, self.vocab = self.load_and_preprocess(
            file_path, min_word_freq=min_word_freq
        )
        self.data = self.create_sequences()

    def load_and_preprocess(self, file_path, min_word_freq=8):
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read().lower()

        # Remove punctuation
        text = text.translate(str.maketrans("", "", string.punctuation))

        # Tokenize the text using spaces
        tokens = text.split()

        # Count word frequencies
        word_freq = Counter(tokens)

        # Build vocabulary with words that occur more than `min_word_freq`
        vocab = {"<UNK>": 0}
        for word, freq in word_freq.items():
            if freq >= min_word_freq:
                vocab[word] = len(vocab)

        # Replace rare words with <UNK>
        tokens = [
            word if word_freq[word] >= min_word_freq else "<UNK>" for word in tokens
        ]

        return tokens, vocab

    def create_sequences(self):
        sequences = []
        for i in range(0, len(self.tokens) - self.sequence_length):
            seq = self.tokens[i : i + self.sequence_length]
            seq_idx = [self.vocab.get(word, 0) for word in seq]
            sequences.append(seq_idx)
        return sequences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = torch.tensor(self.data[idx])
        return sequence[:-1], sequence[1:]

In [3]:
file_path = "input.txt"
sequence_length = 20
dataset = TextDataset(file_path, sequence_length)

for x, y in dataset:
    print("Input Sequence: ", x)
    print("Target Sequence: ", y)
    break

Input Sequence:  tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 10, 10,  1,  2, 12, 13, 11,
        14])
Target Sequence:  tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 10, 10,  1,  2, 12, 13, 11, 14,
        15])


In [4]:
len(dataset.vocab)

2233

In [5]:
dataset.tokens[:30]

['first',
 'citizen',
 'before',
 'we',
 'proceed',
 'any',
 'further',
 'hear',
 'me',
 'speak',
 'all',
 'speak',
 'speak',
 'first',
 'citizen',
 'you',
 'are',
 'all',
 'resolved',
 'rather',
 'to',
 'die',
 'than',
 'to',
 '<UNK>',
 'all',
 'resolved',
 'resolved',
 'first',
 'citizen']

In [7]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does matrix multiplication for query*keys for each training example
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

In [8]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

In [9]:
class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Transformer, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Handling the case when mask is None
        if mask is None:
            mask = torch.zeros((seq_length, seq_length), device=self.device).type(
                torch.bool
            )

        for layer in self.layers:
            out = layer(out, out, out, mask)

        out = self.fc_out(out)
        return out

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [12]:
vocab = dataset.vocab
vocab_size = len(vocab)  # Replace with your vocabulary size
embed_size = 128  # Embedding size
num_layers = 2  # Number of Transformer layers
heads = 4  # Number of heads in multi-head attention
forward_expansion = 4
dropout = 0.2  # Dropout rate
max_length = sequence_length  # Maximum length of a sequence
idx_to_word = {v: k for k, v in vocab.items()}

In [13]:
# Hyperparameters
num_epochs = 3
learning_rate = 0.01

In [14]:
# Initialize the model
model = Transformer(
    vocab_size=vocab_size,
    embed_size=embed_size,
    num_layers=num_layers,
    heads=heads,
    device=device,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_length=max_length,
)

In [15]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [16]:
# Training loop
model.to(device)
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    # Wrap the data_loader with tqdm for a progress bar
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader))
    for batch_idx, (data, targets) in progress_bar:
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass
        scores = model(data, None)
        scores = scores.view(-1, scores.size(-1))
        targets = targets.view(-1)
        loss = criterion(scores, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Update progress bar
        progress_bar.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        progress_bar.set_postfix(loss=total_loss / (batch_idx + 1))

Epoch [1/3]: 100%|██████████| 6333/6333 [01:29<00:00, 70.65it/s, loss=4.77]
Epoch [2/3]: 100%|██████████| 6333/6333 [01:27<00:00, 72.56it/s, loss=4.05]
Epoch [3/3]: 100%|██████████| 6333/6333 [01:27<00:00, 72.60it/s, loss=3.81]


In [17]:
# Save the model checkpoint
model.to(device)
torch.save(model.state_dict(), "model.pth")

In [None]:
# Load the saved model's state dictionary
model.to(device)
model.load_state_dict(torch.load("model.pth", map_location=device))

In [18]:
# Set the model to evaluation mode
model.eval()

Transformer(
  (word_embedding): Embedding(2233, 128)
  (position_embedding): Embedding(20, 128)
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attention): SelfAttention(
        (values): Linear(in_features=32, out_features=32, bias=False)
        (keys): Linear(in_features=32, out_features=32, bias=False)
        (queries): Linear(in_features=32, out_features=32, bias=False)
        (fc_out): Linear(in_features=128, out_features=128, bias=True)
      )
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
  (fc_out): Linear(in_features=128, out_features=2233, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [27]:
input_text = "what he hath done famously he"
input_text = input_text.lower()
input_text = input_text.translate(str.maketrans("", "", string.punctuation))
words = input_text.split()
words

['what', 'he', 'hath', 'done', 'famously', 'he']

In [28]:
vocab = dataset.vocab
sequence = [vocab.get(word, vocab["<UNK>"]) for word in words]
sequence

[53, 93, 108, 44, 0, 93]

In [29]:
input_tensor = torch.tensor([sequence], dtype=torch.long)
input_tensor

tensor([[ 53,  93, 108,  44,   0,  93]])

In [30]:
input_tensor = input_tensor.to(device)

In [31]:
# Disable gradient calculations
with torch.no_grad():
    # Feed the tensor to the model
    output = model(input_tensor, None)

In [32]:
# Get the last word logits and apply softmax
last_word_logits = output[0, -1, :]
probabilities = F.softmax(last_word_logits, dim=0)
probabilities

tensor([1.3974e-02, 2.9498e-05, 3.6062e-06,  ..., 2.5875e-09, 1.9804e-08,
        4.2472e-07], device='cuda:0')

In [33]:
# Get the most likely next word index
next_word_idx = torch.argmax(probabilities).item()
next_word_idx

108

In [34]:
idx_to_word[next_word_idx]

'hath'

In [35]:
def generate_text(input_text):
    input_text = input_text.translate(str.maketrans("", "", string.punctuation))
    words = input_text.lower().split()
    sequence = [vocab.get(word, vocab["<UNK>"]) for word in words]

    for _ in range(10):
        # Tokenize the current sequence of words and convert to a tensor
        input_tensor = torch.tensor([sequence], dtype=torch.long)
        input_tensor = input_tensor.to(device)

        # Get the model's prediction for the next word
        with torch.no_grad():
            output = model(input_tensor, None)

        # Extract the last word (the next word prediction)
        last_word_logits = output[0, -1, :]
        probabilities = F.softmax(last_word_logits, dim=0)
        next_word_idx = torch.argmax(probabilities).item()

        # Find the predicted word and append to the current sequence
        predicted_word = idx_to_word.get(next_word_idx)

        words.append(predicted_word)
        sequence.append(next_word_idx)

    return " ".join(words)

In [36]:
input_text = "what he hath done famously he"

In [37]:
generate_text(input_text)

'what he hath done famously he hath he hath he hath he hath he hath he'