In [109]:
from collections.abc import Callable
import torch, torch.nn as nn
torch.manual_seed(123)

def read_dataset(filename: str, visualize: bool = False) -> str:
    ''' Read a text file and return the contents as a string. 
    
        Args:
            filename - Path of the text file to be read.
            visualize - Whether to visualize the statistics of the file contents.
            
        Returns
            Content of the file as a string.
    '''
    with open(filename, 'r') as reader:
        data = reader.read()
        
    if visualize:
        print(f'Visualizing dataset at path {filename}.')
        print(f'First 100 characters:\n{data[0:100]}.')
        print(f'Length: {len(data)}.')
    return data
        
def create_vocabulary(data: str, visualize: bool = False) -> (list[str], dict[str, int], dict[int, str], Callable[str, list[int]], Callable[list[int], str]):
    vocabulary = sorted(list(set(data)))
    token_to_index_map = {token:index for (index, token) in enumerate(vocabulary)}
    index_to_token_map = {index:token for (index, token) in enumerate(vocabulary)}

    if visualize:
        print(f'Visualizing vocabulary.')
        print(f'Length of vocabulary: {len(vocabulary)}.')
        print(f'Vocabulary is {"".join(vocabulary)}.')
        print(f'Token to index map sorted is {token_to_index_map}')
        print(f'Index to token map sorted is {index_to_token_map}')              
        
    def encoder(input: str) -> (list[int]):
        ''' Encodes the input string. 
            
            Args:
                input: string of text to be encoded.
                
            Returns:
                List of indices of the tokens in the input string.
        '''
        return [token_to_index_map[token] for token in input]
    
    def decoder(input: list[int]) -> str:
        ''' Decodes the input token index into text.
        
            Args:
                input: List of indices of tokens in the text to be decoded.
                
            Returns:
                String corresponding to the decoded text.
        '''
        return ''.join([index_to_token_map[index] for index in input])
        
    return (vocabulary, token_to_index_map, index_to_token_map, encoder, decoder)

def run_tokenizer_example(run: bool = False) -> None:
    ''' Run example text using character level tokenizer.'''
    if run:
        print('Running Tokenizer example.')
        input_text = 'Hello, how are you?'
        tokenized_text = encoder(input_text)
        decoded_text = decoder(tokenized_text)
        print(f'{input_text=}, {tokenized_text=}, {decoded_text=}.')
        
def visualize_batch(x, y, skip_visualization: bool = True):
    if not skip_visualization:
        for sample in range(x.shape[0]):
            for context in range(x.shape[1]):
                print(f' Context: {x[sample, :context+1]}. Target: {y[sample, context]}.')

def create_batch(split, block_size, batch_size):
    data = train_set if split == 'train' else val_set
    batch_start_index = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in batch_start_index])
    y = torch.stack([data[i+1: i+1+block_size] for i in batch_start_index])
    return (x,y)
            
@torch.no_grad()
def evaluate_loss(batch_index, model, batch_size, block_size):
    model.eval()

    (x, y) = create_batch('train', block_size, batch_size)
    predictions = model(x)
    (B, T, C) = predictions.shape
    predictions = predictions.view(B*T, C)
    y = y.view(-1)
    train_loss = nn.functional.cross_entropy(predictions, y)
    
    (x, y) = create_batch('val', block_size, batch_size)
    predictions = model(x)
    (B, T, C) = predictions.shape
    predictions = predictions.view(B*T, C)
    y = y.view(-1)
    val_loss = nn.functional.cross_entropy(predictions, y)
    model.train()
    
    print(f'Step: {batch_index}. Train loss: {train_loss.item()}. Validation loss: {val_loss.item()}')
    

# Read input file.
filename = 'data/tinyshakespeare.txt'
data = read_dataset(filename)
(vocabulary, token_to_index_map, index_to_token_map, encoder, decoder) = create_vocabulary(data, False)
run_tokenizer_example(False)

# Tokenize dataset.
input_sequence = torch.tensor(encoder(data), dtype=torch.long)
#print(f'Tokenized input sequence is {input_sequence}.')
print(f'Shape: {input_sequence.shape}, Type: {input_sequence.dtype}.')

dataset_split_fraction = 0.9
num_train_samples = int(dataset_split_fraction * len(data))
train_set = input_sequence[:num_train_samples]
val_set = input_sequence[num_train_samples:]
print(f'Number of train samples is {len(train_set)}. Number of validation samples is {len(val_set)}.')



max_block_size = 24
batch_size = 32
num_batches = 1000
num_decoder_blocks = 10
vocabulary_size = len(vocabulary)
embedding_dimension = 64
num_heads = 8
head_dimension = 16


#model = BigramLanguageModel(len(vocabulary))
model = GPT(num_decoder_blocks, vocabulary_size, embedding_dimension, num_heads, head_dimension, max_block_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch_index in range(num_batches):
    (x, y) = create_batch('train', max_block_size, batch_size)
    visualize_batch(x, y, True)
    predictions = model(x)

    (B, T, C) = predictions.shape
    predictions = predictions.view(B*T, C)
    y = y.view(-1)
    loss = nn.functional.cross_entropy(predictions, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if batch_index % 10 == 0:
        evaluate_loss(batch_index, model, batch_size, max_block_size)
        generated_tokens = model.generate(torch.zeros((1, 1), dtype = torch.long), 10).tolist()[0]
        generated_text = decoder(generated_tokens)
        print(f'Generated text is \n {generated_text}.')

    
        


Shape: torch.Size([1115394]), Type: torch.int64.
Number of train samples is 1003854. Number of validation samples is 111540.
Step: 0. Train loss: 4.122864723205566. Validation loss: 4.122955799102783
Generated text is 
 
SqaVBGzyea.
Step: 10. Train loss: 4.019296169281006. Validation loss: 4.034889221191406
Generated text is 
 
3VbgkLrVVX.
Step: 20. Train loss: 3.9486913681030273. Validation loss: 3.9343836307525635
Generated text is 
 
d?jSr:ta!3.
Step: 30. Train loss: 3.845771074295044. Validation loss: 3.8109352588653564
Generated text is 
 
e&as'Ca,wg.
Step: 40. Train loss: 3.6778366565704346. Validation loss: 3.726238489151001
Generated text is 
 
ozRd
I$n: .
Step: 50. Train loss: 3.540225028991699. Validation loss: 3.664419412612915
Generated text is 
 
kaZUri OOo.
Step: 60. Train loss: 3.469433069229126. Validation loss: 3.4516074657440186
Generated text is 
 
a  AaXual .
Step: 70. Train loss: 3.41318416595459. Validation loss: 3.416707754135132
Generated text is 
 
 Y& SlvgaG.


Step: 730. Train loss: 2.636291027069092. Validation loss: 2.6672399044036865
Generated text is 
 
te ganir w.
Step: 740. Train loss: 2.8083455562591553. Validation loss: 2.704103469848633
Generated text is 
 
erthiyar;
.
Step: 750. Train loss: 2.663700819015503. Validation loss: 2.6706902980804443
Generated text is 
 
oe wedor,
.
Step: 760. Train loss: 2.7075910568237305. Validation loss: 2.7369544506073
Generated text is 
 
rshakochov.
Step: 770. Train loss: 2.7195827960968018. Validation loss: 2.6510214805603027
Generated text is 
 
l la. t ce.
Step: 780. Train loss: 2.621515989303589. Validation loss: 2.7820844650268555
Generated text is 
 
i HDe n-i .
Step: 790. Train loss: 2.6632611751556396. Validation loss: 2.766740083694458
Generated text is 
 
le icoard:.
Step: 800. Train loss: 2.7473552227020264. Validation loss: 2.641324996948242
Generated text is 
 
i flt y hy.
Step: 810. Train loss: 2.7233469486236572. Validation loss: 2.7877390384674072
Generated text is 
 
er g bi ag.
S

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[0.2745,   -inf,   -inf],
        [0.8573, 0.8993,   -inf],
        [0.9268, 0.7388, 0.7179]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.4895, 0.5105, 0.0000],
        [0.3788, 0.3138, 0.3074]])


tensor([[-0.1115,  0.1204, -0.3696, -0.2404],
        [-1.1969,  0.2093, -0.9724, -0.7550],
        [ 0.3239, -0.1085,  0.2103, -0.3908]])
tensor([-0.6012, -2.7151,  0.0349])


In [104]:
print(model)

GPT(
  (token_embedding_layer): Embedding(65, 64)
  (positional_encoding_layer): Embedding(24, 64)
  (transformer_decoders): ModuleList(
    (0-9): 10 x TransformerDecoderBlock(
      (attention_layer): MultiHeadMaskedAttention(
        (Wq): Linear(in_features=64, out_features=64, bias=True)
        (Wk): Linear(in_features=64, out_features=64, bias=True)
        (Wv): Linear(in_features=64, out_features=64, bias=True)
        (head_merge_layer): Linear(in_features=64, out_features=64, bias=True)
      )
      (mlp_layer): MLP(
        (layer): Sequential(
          (0): Linear(in_features=64, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=64, bias=True)
        )
      )
      (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (head_layer): Linear(in_features=64, out_features=65, bias=True)
)
