In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from collections import Counter

## Dataset Preparation

In [2]:
# Dataset Preparation
with open('data/alice_1.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# Tokenize the text into words
words = text.split()
word_counts = Counter(words)

vocab = list(word_counts.keys())
vocab_size = len(vocab)
word_to_int = {word: i for i, word in enumerate(vocab)}
int_to_word = {i: word for word, i in word_to_int.items()}

SEQUENCE_LENGTH = 64
samples = [words[i:i+SEQUENCE_LENGTH+1] for i in range(len(words)-SEQUENCE_LENGTH)]

print(word_to_int)

{'Alice': 0, 'was': 1, 'a': 2, 'curious': 3, 'and': 4, 'imaginative': 5, 'young': 6, 'girl': 7, 'who': 8, 'lived': 9, 'in': 10, 'quiet': 11, 'village.': 12, 'She': 13, 'had': 14, 'wild': 15, 'mop': 16, 'of': 17, 'blonde': 18, 'curls': 19, 'that': 20, 'seemed': 21, 'to': 22, 'match': 23, 'her': 24, 'adventurous': 25, 'spirit.': 26, 'One': 27, 'sunny': 28, 'afternoon,': 29, 'while': 30, 'chasing': 31, 'playful': 32, 'white': 33, 'rabbit': 34, 'through': 35, 'the': 36, 'meadow,': 37, 'stumbled': 38, 'upon': 39, 'hidden': 40, 'hole.': 41, 'Without': 42, 'second': 43, 'thought,': 44, 'she': 45, 'decided': 46, 'follow': 47, 'rabbit,': 48, 'tumbling': 49, 'headfirst': 50, 'into': 51, 'an': 52, 'enchanting': 53, 'world': 54, 'called': 55, 'Wonderland.': 56, 'As': 57, 'fell': 58, 'hole,': 59, 'around': 60, 'began': 61, 'twist': 62, 'distort.': 63, 'felt': 64, 'as': 65, 'if': 66, 'were': 67, 'floating': 68, 'kaleidoscope': 69, 'colors': 70, 'shapes.': 71, 'When': 72, 'finally': 73, 'landed,': 74

In [3]:
print(samples)

[['Alice', 'was', 'a', 'curious', 'and', 'imaginative', 'young', 'girl', 'who', 'lived', 'in', 'a', 'quiet', 'village.', 'She', 'had', 'a', 'wild', 'mop', 'of', 'blonde', 'curls', 'that', 'seemed', 'to', 'match', 'her', 'adventurous', 'spirit.', 'One', 'sunny', 'afternoon,', 'while', 'chasing', 'a', 'playful', 'white', 'rabbit', 'through', 'the', 'meadow,', 'Alice', 'stumbled', 'upon', 'a', 'hidden', 'rabbit', 'hole.', 'Without', 'a', 'second', 'thought,', 'she', 'decided', 'to', 'follow', 'the', 'rabbit,', 'tumbling', 'headfirst', 'into', 'an', 'enchanting', 'world', 'called'], ['was', 'a', 'curious', 'and', 'imaginative', 'young', 'girl', 'who', 'lived', 'in', 'a', 'quiet', 'village.', 'She', 'had', 'a', 'wild', 'mop', 'of', 'blonde', 'curls', 'that', 'seemed', 'to', 'match', 'her', 'adventurous', 'spirit.', 'One', 'sunny', 'afternoon,', 'while', 'chasing', 'a', 'playful', 'white', 'rabbit', 'through', 'the', 'meadow,', 'Alice', 'stumbled', 'upon', 'a', 'hidden', 'rabbit', 'hole.', '

In [4]:
class TextDataset(Dataset):
    def __init__(self, samples, word_to_int):
        self.samples = samples
        self.word_to_int = word_to_int

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        input_seq = torch.LongTensor([self.word_to_int[word] for word in sample[:-1]])
        target_seq = torch.LongTensor([self.word_to_int[word] for word in sample[1:]])
        return input_seq, target_seq

In [5]:
BATCH_SIZE = 32
dataset = TextDataset(samples, word_to_int)
dataloader = DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
)

print(dataset[1])

(tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  2, 11, 12, 13, 14,  2, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  2, 32, 33,
        34, 35, 36, 37,  0, 38, 39,  2, 40, 34, 41, 42,  2, 43, 44, 45, 46, 22,
        47, 36, 48, 49, 50, 51, 52, 53, 54, 55]), tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10,  2, 11, 12, 13, 14,  2, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  2, 32, 33, 34,
        35, 36, 37,  0, 38, 39,  2, 40, 34, 41, 42,  2, 43, 44, 45, 46, 22, 47,
        36, 48, 49, 50, 51, 52, 53, 54, 55, 56]))


## Transformer Model

In [6]:
def generate_square_subsequent_mask(sz):
    """
    Generate a square mask for the sequence. The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model, dropout=0.1):
        """
        :param max_len: Input length sequence.
        :param d_model: Embedding dimension.
        :param dropout: Dropout value (default=0.1)
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Inputs of forward function
        :param x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [8]:
class TextGen(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads):
        super(TextGen, self).__init__()
        self.pos_encoder = PositionalEncoding(max_len=SEQUENCE_LENGTH, d_model=embed_dim)
        self.emb = nn.Embedding(vocab_size, embed_dim)
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, 
            nhead=num_heads, 
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=self.decoder_layer,
            num_layers=num_layers,
        )
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)
        
    # Positional encoding is required. Else the model does not learn.
    def forward(self, x):
        emb = self.emb(x)
        
        # Generate input sequence mask with shape (SEQUENCE_LENGTH, SEQUENCE_LENGTH)
        input_mask = generate_square_subsequent_mask(x.size(1)).to(x.device)
        
        x = self.pos_encoder(emb)
        x = self.decoder(x, memory=x, tgt_mask=input_mask, memory_mask=input_mask)
        x = self.dropout(x)
        out = self.linear(x)
        return out

## Training

In [9]:
epochs = 100
learning_rate = 0.001

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextGen(
    vocab_size=vocab_size, 
    embed_dim=100,
    num_layers=2, 
    num_heads=2,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.\n")

TextGen(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (emb): Embedding(419, 100)
  (decoder_layer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
    (linear1): Linear(in_features=100, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=100, bias=True)
    (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (decoder): Transforme

In [12]:
# Training
def train(model, epochs, dataloader, criterion):
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for input_seq, target_seq in dataloader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            outputs = model(input_seq)
            target_seq = target_seq.contiguous().view(-1)
            outputs = outputs.view(-1, vocab_size)
            loss = criterion(outputs, target_seq)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.detach().cpu().numpy()
        epoch_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch} loss: {epoch_loss:.3f}")

train(model, epochs, dataloader, criterion)

Epoch 0 loss: 5.048
Epoch 1 loss: 3.241
Epoch 2 loss: 1.936
Epoch 3 loss: 1.178
Epoch 4 loss: 0.713
Epoch 5 loss: 0.437
Epoch 6 loss: 0.283
Epoch 7 loss: 0.202
Epoch 8 loss: 0.157
Epoch 9 loss: 0.127
Epoch 10 loss: 0.106
Epoch 11 loss: 0.094
Epoch 12 loss: 0.083
Epoch 13 loss: 0.074
Epoch 14 loss: 0.069
Epoch 15 loss: 0.064
Epoch 16 loss: 0.058
Epoch 17 loss: 0.056
Epoch 18 loss: 0.055
Epoch 19 loss: 0.053
Epoch 20 loss: 0.050
Epoch 21 loss: 0.049
Epoch 22 loss: 0.048
Epoch 23 loss: 0.046
Epoch 24 loss: 0.044
Epoch 25 loss: 0.043
Epoch 26 loss: 0.041
Epoch 27 loss: 0.043
Epoch 28 loss: 0.042
Epoch 29 loss: 0.039
Epoch 30 loss: 0.038
Epoch 31 loss: 0.038
Epoch 32 loss: 0.039
Epoch 33 loss: 0.037
Epoch 34 loss: 0.036
Epoch 35 loss: 0.037
Epoch 36 loss: 0.034
Epoch 37 loss: 0.035
Epoch 38 loss: 0.036
Epoch 39 loss: 0.035
Epoch 40 loss: 0.035
Epoch 41 loss: 0.033
Epoch 42 loss: 0.034
Epoch 43 loss: 0.033
Epoch 44 loss: 0.033
Epoch 45 loss: 0.034
Epoch 46 loss: 0.033
Epoch 47 loss: 0.033
Ep

## Inference

In [13]:
def return_int_vector(text):
    words = text.split()
    input_seq = torch.LongTensor([word_to_int[word] for word in words[-SEQUENCE_LENGTH:]]).unsqueeze(0)
    return input_seq

In [14]:
def sample_next(predictions):
    """
    Greedy sampling.
    """
    # Greedy approach.
    probabilities = F.softmax(predictions[:, -1, :], dim=-1).cpu()
    next_token = torch.argmax(probabilities)
    return int(next_token.cpu())

def text_generator(sentence, generate_length):
    model.eval()
    sample = sentence
    for i in range(generate_length):
        int_vector = return_int_vector(sample)
        if len(int_vector) >= SEQUENCE_LENGTH - 1:
            break
        input_tensor = int_vector.to(device)
        with torch.no_grad():
            predictions = model(input_tensor)
        next_token = sample_next(predictions)
        sample += ' ' + int_to_word[next_token]
    print(sample)
    print('\n')

In [15]:
sentences = [
    "Alice was a"
]

In [16]:
generate_length = 100

In [17]:
for sentence in sentences:
    print(f"PROMPT: {sentence}")
    text_generator(sentence, generate_length)

PROMPT: Alice was a
Alice was a curious and imaginative young girl who lived in a quiet village. She had a wild mop of blonde curls that seemed to match her adventurous spirit. One sunny afternoon, while chasing a playful white rabbit through the meadow, Alice stumbled upon a hidden rabbit hole. Without a second thought, she decided to follow the rabbit, tumbling headfirst into an enchanting world called Wonderland. As Alice fell through the rabbit hole, the world around her began to twist and distort. She felt as if she were floating in a kaleidoscope of colors and shapes. When she finally landed, she found herself


