In [1]:
# Import libraries
import torch
import torch.nn as nn

from src.data.data_loader import create_dataloaders
from src.model.transformer import build_transformer
from src.model.transformer import Transformer
from src.train.training import train_model
from src.utils.utils import get_device
from nltk.tokenize import word_tokenize

from src.utils.constants import PADDING, UNKNOWN, START_OF_SENTENCE, END_OF_SENTENCE

  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /opt/anaconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [2]:
# Initialize model and training parameters

# Size of embedding vector
d_model = 512
# Max sequence length for input words/tokens
seq_len = 100
# Dropout rate
dropout = 0.1
# number of encoder blocks
num_layers = 1
# number of attention heads
num_heads = 8
# Number of hidden nodes for feed-forward layer
d_ff = 4*d_model

# Number of epochs
epochs = 5
# Batch size for training
batch_size = 128

# Train file
train_file = './data/train/poems.txt'

In [3]:
# Get a device to use for training/inference
device = get_device()

# Create training and testing data loaders
train_dataloader, vocab = create_dataloaders(batch_size, seq_len, train_file)

print(f'Training data size: {len(train_dataloader) * batch_size}')

Number of tokenized words:  194655
Number of tokenized words after adding <eos>:  194755
Training data size: 194688


In [4]:
# Create encoder only transformer model
encoder_only_transformer_model = build_transformer(d_model, len(vocab), seq_len, dropout,
                                                   num_layers, num_heads, d_ff).to(device)

print(encoder_only_transformer_model)

Transformer(
  (embed): InputEmbedding(
    (embedding): Embedding(6993, 512)
  )
  (pos): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderBlock(
        (self_attention): MultiHeadAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (query_linear_layer): Linear(in_features=512, out_features=512, bias=True)
          (key_linear_layer): Linear(in_features=512, out_features=512, bias=True)
          (value_linear_layer): Linear(in_features=512, out_features=512, bias=True)
          (output_linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForward(
          (linear_1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear_2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (norm): LayerN

In [5]:
# Train model

# Create optimizer and loss function
optimizer = torch.optim.Adam(encoder_only_transformer_model.parameters())
loss_fn = nn.CrossEntropyLoss()

# Start training the model
train_model(epochs, encoder_only_transformer_model, train_dataloader,
            loss_fn, optimizer, device)

Epoch 1, Train Loss 0.8270629644393921
Epoch 2, Train Loss 0.23350630700588226
Epoch 3, Train Loss 0.20912890136241913
Epoch 4, Train Loss 0.19805137813091278
Epoch 5, Train Loss 0.19233782589435577


In [6]:
# Save model
from pathlib import Path

# Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# Create model save path
MODEL_NAME = "07_text_generation.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

In [7]:
# Save the model state dict
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj=encoder_only_transformer_model.state_dict(), f=MODEL_SAVE_PATH)

Saving model to: models/07_text_generation.pth


In [8]:
# Create new instance of model and load saved state dict
loaded_model = build_transformer(d_model, len(vocab), seq_len, dropout,
                                num_layers, num_heads, d_ff)
loaded_model.load_state_dict(torch.load(MODEL_SAVE_PATH))
loaded_model.to(device)

UNK = 1

def generate_text(input, max_tokens_to_generate):
    with torch.inference_mode():
        
        output = input.clone()
        for _ in range(max_tokens_to_generate):
            
            curr_seq_len = input.size(1)
            
            if curr_seq_len > seq_len:
                input = input[:, -seq_len:]
            
            encoder_output = loaded_model.encode(input)
            y_logits = loaded_model.project(encoder_output)
            
            # for all the batches, get the embeds for last predicted sequence
            y_logits = y_logits[:, -1, :] 
            
            # for all the batches, get the embeds for last predicted sequence
            probs = y_logits.softmax(dim=1)            
            # get the probable token based on the input probs
            idx_next = torch.multinomial(probs, num_samples=1) 

            input = torch.cat([input, idx_next], dim=1)
            output = torch.cat([output, idx_next], dim=1)
            
        return output

In [13]:
text = 'Love'
input = [vocab.get(token, UNKNOWN) for token in word_tokenize(text)]
input_tensor = torch.tensor(input, dtype=torch.long).unsqueeze(0).to(device)

output_tensor = generate_text(input_tensor, 200).squeeze()
output_array_tokens = output_tensor.cpu().numpy()

sorted_items = sorted(vocab.items(), key=lambda item: item[1])
sorted_keys = [item[0] for item in sorted_items]

output_array_words = [sorted_keys[token] for token in output_array_tokens]
print(' '.join(output_array_words))

Love the way you conquer your fear , You know hearts do n't break around here , Oh yeah , yeah , Yeah-yeah , yeah-yeah She is the river flowin ' nowhere , And tin wind chimes used for doorbells , Fields and trees and her smell fill my lungs , Spend my summertime beside her , And the rest of the year the same , She is the flint that sparks the lighter , And the fuel that will hold the flame , oh Roses , roses laid upon your bed spread , oh my , All this , all this , all this I know But every night I 'll kiss you , you 'll say in my ear , Oh we 're in love , are n't we ? Hands in your hair Fingers and thumbs , baby I feel safe when you 're holding me near , Love the way that you conquer your fear , You know hearts do n't break around here , Oh yeah , yeah , yeah , yeah , Yeah-yeah , yeah-yeah Well , I 've found love inside , The arms of the river flowin ' nowhere And tin wind chimes


In [14]:
text = 'conquer'
input = [vocab.get(token, UNKNOWN) for token in word_tokenize(text)]
input_tensor = torch.tensor(input, dtype=torch.long).unsqueeze(0).to(device)

output_tensor = generate_text(input_tensor, 200).squeeze()
output_array_tokens = output_tensor.cpu().numpy()

sorted_items = sorted(vocab.items(), key=lambda item: item[1])
sorted_keys = [item[0] for item in sorted_items]

output_array_words = [sorted_keys[token] for token in output_array_tokens]
print(' '.join(output_array_words))

conquer , You know maybe These people that hate me But you , I knew you were n't around here And you wanted to make me and it 's real , When we watched the sunset over the castle on the hill , Over the castle on the hill , Over the castle on the hill '' When I was six years old , I broke my leg , I was running from my brother and his friends , And tasted the sweet perfume of the mountain grass I rolled down , I was younger then , Take me back that you said , take me back now Now we got down , `` I 'll never ring '' He 's in the rain '' He said , `` no do n't get '' Have ever known That 's been and I would 've never leave me I thought , Back when we had you figured out Something 's We were the nerve To touch my hand It 's nice to have to have a friend ( Ooh ) It was so nice being midnight It 's nice to leave It 's stayed true I kind of knew you He 's
