In [1]:
import random
import re

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from torchinfo import summary
from gensim.models.word2vec import LineSentence, Word2Vec
from tqdm import tqdm

In [2]:
class TextTrainDataset(IterableDataset):
    """Streams (context embeddings, target word index) training pairs from a text file.

    Each sentence produced by ``LineSentence(text_file_path)`` contributes one
    sample: a randomly positioned window of up to ``seq_length`` tokens is the
    context (left-padded with ``'<pad>'`` when the window starts before the
    beginning of the sentence) and the token right after the window is the
    prediction target.

    Args:
        text_file_path: Path of the text file handed to ``LineSentence``.
        w2v: Object exposing ``wv`` so that ``wv[list_of_words]`` returns the
            embeddings and ``wv.key_to_index[word]`` the vocabulary index.
            ``'<pad>'`` must be a known key because it is used for padding.
        seq_length: Number of context tokens per sample.
    """

    def __init__(self, text_file_path, w2v, seq_length=10):
        self.text_file_path = text_file_path
        self.seq_length = seq_length
        self.w2v = w2v
        # Counted once up front; the file is read in full here.
        self.__len = self.__count_lines_in_file()

    def __len__(self):
        """Return the number of lines in the file.

        This is an upper bound on the number of samples, because sentences
        that are too short to provide a target are skipped while iterating.
        """
        return self.__len

    def __iter__(self):
        """Yield one ``(embedded_context, target_index)`` pair per usable sentence."""
        for text in LineSentence(self.text_file_path):
            # A sample needs at least one context token plus a target token.
            # With fewer than two tokens the randint range below is empty and
            # random.randint raises ValueError, so such sentences are skipped.
            if len(text) < 2:
                continue

            # Negative start indices mean the window begins before the sentence;
            # the missing positions are filled with '<pad>' further down. The
            # lower bound keeps at least one real token in the context, the
            # upper bound keeps the target inside the sentence.
            start_idx = random.randint(1 - self.seq_length, len(text) - self.seq_length - 1)
            window_end = start_idx + self.seq_length

            cropped_text = self.__padd(text[max(start_idx, 0):window_end])
            target = text[window_end]
            yield self.w2v.wv[cropped_text], self.w2v.wv.key_to_index[target]

    def __padd(self, text):
        """Left-pad ``text`` with ``'<pad>'`` tokens up to ``seq_length`` items."""
        if len(text) < self.seq_length:
            padding = ['<pad>'] * (self.seq_length - len(text))
            text = padding + text
        return text

    def __count_lines_in_file(self):
        """Return the number of lines in ``text_file_path``."""
        with open(self.text_file_path) as f:
            return sum(1 for _ in f)

In [3]:
class TextValidationDataset(IterableDataset):
    """Iterable dataset yielding each validation sentence as one space-joined string."""

    def __init__(self, text_file_path):
        # Path handed to LineSentence every time the dataset is iterated.
        self.text_file_path = text_file_path

    def __iter__(self):
        """Lazily yield ``' '.join(tokens)`` for every token list LineSentence produces."""
        yield from map(' '.join, LineSentence(self.text_file_path))

In [4]:
class LstmTextGenerator(LightningModule):
    """Next-word LSTM language model on top of pre-trained word2vec embeddings.

    The module embeds a window of words with a loaded ``Word2Vec`` model, runs
    the embeddings through an LSTM and maps the last hidden state to one score
    per vocabulary entry. It also builds its own train/validation dataloaders.

    Args:
        text_file_path: Training text file, passed to ``TextTrainDataset``.
        val_file_path: Validation text file, passed to ``TextValidationDataset``.
        word2vec_path: Path given to ``Word2Vec.load``.
        seq_length: Context length used for training samples and as the minimum
            (padded) prompt length during generation.
        batch_size: Batch size of the training dataloader.
        seed: Seed passed to ``np.random.seed`` before the ``'<pad>'`` vector is drawn.
        lstm_layers: ``num_layers`` of the LSTM.
        lstm_dropout: ``dropout`` of the LSTM.
        lstm_hidden_size: ``hidden_size`` of the LSTM.
    """

    # Generated tokens that are appended without a leading space.
    _NO_SPACE_BEFORE = ('.', '!', '?', ',')

    def __init__(self, text_file_path, val_file_path, word2vec_path, seq_length=10,
                 batch_size=64, seed=42, lstm_layers=1, lstm_dropout=0, lstm_hidden_size=100):
        super().__init__()
        self.save_hyperparameters()

        self.w2v = Word2Vec.load(self.hparams.word2vec_path)
        # Take the embedding width from the loaded vectors instead of assuming 100,
        # so models trained with another vector size work as well.
        embedding_size = self.w2v.wv.vector_size
        np.random.seed(seed)
        self.w2v.wv['<pad>'] = np.random.rand(embedding_size)

        self.lstm = nn.LSTM(
            input_size=embedding_size,
            hidden_size=self.hparams.lstm_hidden_size,
            batch_first=True,
            num_layers=self.hparams.lstm_layers,
            dropout=self.hparams.lstm_dropout
        )
        # One output score per vocabulary entry (including '<pad>' added above).
        self.fc = nn.Linear(self.hparams.lstm_hidden_size, len(self.w2v.wv))

        self.loss = nn.CrossEntropyLoss()

    def generate(self, prompt, length=50):
        """Greedily extend ``prompt`` by ``length`` generated tokens.

        Args:
            prompt: Start text; it is normalised and tokenised before use.
            length: Number of tokens to generate.

        Returns:
            The original ``prompt`` string followed by the generated tokens.
            Tokens are separated by a space except for ``.``, ``!``, ``?`` and
            ``,``, which are attached directly to the preceding text.
        """
        generated = prompt
        prompt = self.__preprocess_prompt(prompt)

        # Inference must not track gradients and must not apply LSTM dropout;
        # the previous train/eval mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(length):
                    embedded_prompt = torch.tensor(self.w2v.wv[prompt], device=self.device)
                    prediction = self(torch.unsqueeze(embedded_prompt, dim=0))
                    # .item() yields a plain int for indexing the vocabulary list.
                    word_index = int(prediction.argmax(dim=-1).item())
                    word = self.w2v.wv.index_to_key[word_index]
                    # Slide the window: drop the oldest token, append the new one.
                    prompt = prompt[1:] + [word]

                    if word not in self._NO_SPACE_BEFORE:
                        generated += ' '
                    generated += word
        finally:
            self.train(was_training)

        return generated

    def _logits(self, x):
        """Return unnormalised vocabulary scores for the last time step of ``x``."""
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])

    def forward(self, x):
        """Return softmax probabilities over the vocabulary for the last time step.

        Args:
            x: Batch-first tensor of embedded tokens fed to the LSTM.

        Returns:
            Tensor with one row per batch element and one column per vocabulary entry.
        """
        return torch.softmax(self._logits(x), dim=-1)

    def training_step(self, batch, batch_no):
        """Compute and log the cross-entropy loss of one training batch."""
        text, target = batch
        # nn.CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw logits; feeding it softmax probabilities flattens the gradients.
        logits = self._logits(text)
        loss = self.loss(logits, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_no):
        """Generate a continuation of the validation prompt and log it as text."""
        if self.logger is None:
            # Nothing to write the generated sample to.
            return
        prompt = batch[0]
        # NOTE(review): assumes the logger's experiment offers add_text
        # (as a TensorBoard writer does) - confirm when using another logger.
        tensorboard = self.logger.experiment
        generated = self.generate(prompt, length=100)
        tensorboard.add_text(f'val_generated_{batch_no}', generated, global_step=self.current_epoch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 0.001."""
        optimizer = Adam(self.parameters(), lr=0.001)
        return optimizer

    def train_dataloader(self):
        """Build the training dataloader from ``TextTrainDataset``."""
        dataset = TextTrainDataset(
            self.hparams.text_file_path,
            self.w2v,
            self.hparams.seq_length
        )

        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        """Build the validation dataloader; each batch holds a single sentence."""
        dataset = TextValidationDataset(
            self.hparams.val_file_path,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=1
        )

    def __preprocess_prompt(self, prompt):
        """Turn a raw prompt string into a padded list of known tokens.

        The text is lower-cased, stripped of every character outside the
        allowed set, punctuation marks are split into separate tokens, words
        missing from the word2vec vocabulary are dropped and the result is
        left-padded with ``'<pad>'`` to at least ``seq_length`` tokens. Longer
        prompts are not truncated.
        """
        prompt = prompt.lower().strip()
        prompt = re.sub(r'[^a-ząćęłńóśźż.,!? ]', '', prompt)
        prompt = prompt.replace('.', ' . ').replace('!', ' ! ').replace('?', ' ? ').replace(',', ' , ')
        prompt = prompt.split()
        prompt = [word for word in prompt if word in self.w2v.wv]
        padding = ['<pad>']*(max(self.hparams.seq_length-len(prompt), 0))
        prompt = padding + prompt
        return prompt

In [5]:
# TensorBoard logger configured with save_dir '../..' and run name 'logs'.
logger = TensorBoardLogger(save_dir='../..', name='logs')

# NOTE(review): max_epochs=-1 sets no finite epoch limit here, so training
# presumably runs until it is interrupted - confirm this is intended.
trainer = Trainer(
    accelerator='cuda', max_epochs=-1,
    enable_progress_bar=True, logger=logger,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Model under training: two LSTM layers with dropout 0.2 and batches of 128;
# every argument not listed here keeps the constructor default.
generator = LstmTextGenerator(
    text_file_path='../../data/line_sentence/texts_punctuation.txt',
    val_file_path='../../data/line_sentence/texts_validation.txt',
    word2vec_path='../../models/word2vec/punctuation/word2vec',
    batch_size=128, lstm_layers=2, lstm_dropout=0.2,
)

In [7]:
# Layer-by-layer overview of the model for a dummy input of shape (64, 20, 100).
summary(generator, input_size=(64, 20, 100),
        col_names=['input_size', 'output_size', 'num_params', 'params_percent'])

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %
LstmTextGenerator                        [64, 20, 100]             [64, 663962]              --                             --
├─LSTM: 1-1                              [64, 20, 100]             [64, 20, 100]             161,600                     0.24%
├─Linear: 1-2                            [64, 100]                 [64, 663962]              67,060,162                 99.76%
Total params: 67,221,762
Trainable params: 67,221,762
Non-trainable params: 0
Total mult-adds (G): 4.50
Input size (MB): 0.51
Forward/backward pass size (MB): 340.97
Params size (MB): 268.89
Estimated Total Size (MB): 610.37

In [8]:
# Start training; the dataloaders come from the module itself.
trainer.fit(model=generator)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type             | Params
------------------------------------------
0 | lstm | LSTM             | 161 K 
1 | fc   | Linear           | 67.1 M
2 | loss | CrossEntropyLoss | 0     
------------------------------------------
67.2 M    Trainable params
0         Non-trainable params
67.2 M    Total params
268.887   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Epoch 5:  68%|██████▊   | 45/66 [00:04<00:02,  9.01it/s, v_num=2]          

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Sample a continuation of a fairy-tale opening; the returned string is the cell output.
generator.generate(prompt='dawno, dawno temu, za siedmioma górami i siedmioma lasami')

['dawno', ',', 'dawno', 'temu', ',', 'za', 'siedmioma', 'górami', 'i', 'siedmioma', 'lasami']


'dawno, dawno temu, za siedmioma górami i siedmioma lasami, a w tym, że nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie nie'