In [1]:
import random
import re

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
import torch.nn.functional as F
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from torchinfo import summary
from gensim.models.word2vec import LineSentence, Word2Vec
from tqdm import tqdm

In [2]:
class TextTrainDataset(IterableDataset):
    
    def __init__(self, text_file_path, w2v, seq_length=10):
        self.text_file_path = text_file_path
        self.seq_length = seq_length
        self.w2v = w2v
        self.__len = self.__count_lines_in_file()
        
    def __iter__(self):
        for text in LineSentence(self.text_file_path):
            if len(text) < 2: continue
            for start_idx in range(-self.seq_length+1, len(text)-self.seq_length-1):
                cropped_text = text[max(start_idx, 0) : start_idx+self.seq_length]
                cropped_text = self.__padd(cropped_text)
                target = text[start_idx+self.seq_length]
                yield self.w2v.wv[cropped_text], self.w2v.wv.key_to_index[target]
            
    def __padd(self, text):
        if len(text) < self.seq_length:
            padding = ['<pad>']*(self.seq_length-len(text))
            text = padding + text
        return text
            
    def __count_lines_in_file(self):
        with open(self.text_file_path) as f:
            return sum(1 for _ in f)

In [3]:
class TextValidationDataset(IterableDataset):
    
    def __init__(self, text_file_path):
        self.text_file_path = text_file_path
        
    def __iter__(self):
        for text in LineSentence(self.text_file_path):
            yield ' '.join(text)

In [4]:
class LstmTextGenerator(LightningModule):
    
    def __init__(self, text_file_path, val_file_path, word2vec_path, seq_length=10, 
                 batch_size=64, seed=42, lstm_layers=1, lstm_dropout=0, lstm_hidden_size=100,
                 dropout=0.2, bidirectional=False):
        super().__init__()
        self.save_hyperparameters()
        
        self.w2v = Word2Vec.load(self.hparams.word2vec_path)
        np.random.seed(seed)
        self.w2v.wv['<pad>'] = np.random.rand(100)
        
        self.lstm = nn.LSTM(
            input_size=100,
            hidden_size=self.hparams.lstm_hidden_size,
            batch_first=True,
            num_layers=self.hparams.lstm_layers,
            dropout=self.hparams.lstm_dropout,
            bidirectional=self.hparams.bidirectional
        )
        self.fc = nn.Linear(2*self.hparams.lstm_hidden_size, len(self.w2v.wv))
        
        self.dropout = nn.Dropout(self.hparams.dropout)
        
        self.loss = nn.CrossEntropyLoss()
        
    def generate(self, prompt, length=50, temperature=0.5):
        generated = prompt
        prompt = self.__preprocess_prompt(prompt)
        
        for _ in range(length):
            embedded_prompt = self.w2v.wv[prompt]
            embedded_prompt = torch.tensor(embedded_prompt, device=self.device)
            next_word_logits = self(torch.unsqueeze(embedded_prompt, dim=0))[0]
            word = self.__get_word_from_logits(next_word_logits, temperature)
            prompt = prompt[1:] + [word]
            
            if word not in list('.!?,'):
                generated += ' '
            generated += word
        
        return generated
    
    def __get_word_from_logits(self, next_word_logits, temperature=0.5):
        scaled_logits = next_word_logits / temperature
        adjusted_probs = F.softmax(scaled_logits, dim=-1)
        next_word_index = torch.multinomial(adjusted_probs, num_samples=1).item()
        next_word = self.w2v.wv.index_to_key[next_word_index]
        return next_word
        
    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.dropout(out)
        out = self.fc(out[:, -1, :])
        return out
        
    def training_step(self, batch, batch_no):
        text, target = batch
        predicted = self.forward(text)
        loss = self.loss(predicted, target)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_no):
        prompt = batch[0]
        tensorboard = self.logger.experiment
        for temperature in [1, 0.5, 0.2, 0.1, 0.01]:
            generated = self.generate(prompt, length=100, temperature=temperature)
            tensorboard.add_text(f'val_generated_{temperature}_{batch_no}', generated, global_step=self.current_epoch)
        
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=0.001)
        return optimizer
    
    def train_dataloader(self):
        dataset = TextTrainDataset(
            self.hparams.text_file_path,
            self.w2v,
            self.hparams.seq_length,
        )
        
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=24
        )
        
    def val_dataloader(self):
        dataset = TextValidationDataset(
            self.hparams.val_file_path,
        )
        
        return DataLoader(
            dataset=dataset,
            batch_size=1
        )
        
    def __preprocess_prompt(self, prompt):
        prompt = prompt.lower().strip()
        prompt = re.sub(r'[^a-ząćęłńóśźż.,!? ]', '', prompt)
        prompt = prompt.replace('.', ' . ').replace('!', ' ! ').replace('?', ' ? ').replace(',', ' , ')
        prompt = prompt.split()
        prompt = [word for word in prompt if word in self.w2v.wv]
        padding = ['<pad>']*(max(self.hparams.seq_length-len(prompt), 0))
        prompt = padding + prompt
        return prompt

In [5]:
logger = TensorBoardLogger(
    save_dir='../..',
    name='logs'
)

trainer = Trainer(
    accelerator='cuda',
    max_epochs=-1,
    enable_progress_bar=True,
    logger = logger,
    check_val_every_n_epoch=10,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
generator = LstmTextGenerator(
    text_file_path='../../data/line_sentence/100k.txt',
    val_file_path='../../data/line_sentence/texts_validation.txt',
    word2vec_path='../../models/word2vec/100k/word2vec',
    seq_length=25,
    lstm_layers=3,
    lstm_dropout=0.2,
    lstm_hidden_size=100,
    dropout=0.2,
    bidirectional=True,
    batch_size=128,
)

In [7]:
summary(
    generator,
    input_size=(64, 20, 100),
    col_names=['input_size', 'output_size', 'num_params', 'params_percent']
)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %
LstmTextGenerator                        [64, 20, 100]             [64, 36968]               --                             --
├─LSTM: 1-1                              [64, 20, 100]             [64, 20, 200]             644,800                     7.98%
├─Dropout: 1-2                           [64, 20, 200]             [64, 20, 200]             --                             --
├─Linear: 1-3                            [64, 200]                 [64, 36968]               7,430,568                  92.02%
Total params: 8,075,368
Trainable params: 8,075,368
Non-trainable params: 0
Total mult-adds (G): 1.30
Input size (MB): 0.51
Forward/backward pass size (MB): 20.98
Params size (MB): 32.30
Estimated Total Size (MB): 53.79

In [16]:
trainer.fit(generator)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | lstm    | LSTM             | 644 K 
1 | fc      | Linear           | 7.4 M 
2 | dropout | Dropout          | 0     
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
8.1 M     Trainable params
0         Non-trainable params
8.1 M     Total params
32.301    Total estimated model params size (MB)


Epoch 0: : 109583it [25:50, 70.68it/s, v_num=22]
Epoch 0: : 8891it [02:37, 56.48it/s, v_num=22]                             

In [15]:
generator.generate('dawno, dawno temu, za siedmioma górami i siedmioma lasami', temperature=1)

'dawno, dawno temu, za siedmioma górami i siedmioma lasami jednego dłonie życie. stanie mi więc? chyba kiedy pozostać odparł, sir i się od wszystkich czasu tę całą łaskę krew. ślad jak wobec mnie cię dostał. leży od. gdzie nie pozostanie obecny przy stanie, lecz i pies panu nie o zachód, ale'

In [12]:
generator.generate('dawno, dawno temu, za siedmioma górami i siedmioma lasami', temperature=0.3)

'dawno, dawno temu, za siedmioma górami i siedmioma lasami. a więc z tego będę, a nie jestem spotkamy. bądź zdrów, gdyż nie wiadomo, kim nie się. a więc nie wiadomo, kim nie powinienem, aby nie. nie mogę mi sobie, gdyż nie wiadomo o twoją. nie ma zamiaru mi'