In [1]:
import random
import re
import pickle
import os.path

import numpy as np
import torch
import torchtext
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
import torch.nn.functional as F
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from torchtext.vocab import build_vocab_from_iterator
from lightning.pytorch.loggers import TensorBoardLogger
from torchinfo import summary
from gensim.models.word2vec import LineSentence, Word2Vec
from tqdm import tqdm

In [2]:
class TextTrainDataset(IterableDataset):
    """Iterable dataset of randomly cropped, left-padded token windows.

    Every pass over the stored texts picks, for each text, a random context
    window followed by the ``target_length`` tokens that come right after it.
    The context is left-padded with ``pad_token_idx`` up to ``seq_length``.

    Args:
        dataset_path: Path to a pickle file holding a sequence of tokenised
            texts (each text is a sequence of token indices).
        pad_token_idx: Index used to left-pad contexts shorter than
            ``seq_length``.
        seq_length: Length of every yielded context.
        target_length: Number of tokens yielded as the target.
        padding_factor: Upper bound of the randomly drawn context length.
        padding_limit: Lower bound of the randomly drawn context length.
        repeats: How many passes over the stored texts make up one iteration
            of the dataset (each pass draws fresh random crops).
    """

    def __init__(self, dataset_path, pad_token_idx, seq_length=10, target_length=1, padding_factor=20, padding_limit=3, repeats=10):
        self.dataset_path = dataset_path
        self.pad_token_idx = pad_token_idx
        self.seq_length = seq_length
        self.target_length = target_length
        self.padding_factor = padding_factor
        self.padding_limit = padding_limit
        self.repeats = repeats

        # SECURITY: pickle.load can execute arbitrary code; only point this
        # at files produced by a trusted preprocessing step.
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)

    def __len__(self):
        """Return the number of samples produced by one full iteration."""
        usable_texts = sum(1 for text in self.dataset if self.__is_usable(text))
        return usable_texts * self.repeats

    def __iter__(self):
        """Yield ``(context, target)`` pairs as numpy arrays."""
        for _ in range(self.repeats):
            for text in self.dataset:
                # A text shorter than the target cannot provide a full target;
                # yielding it would produce a ragged batch.
                if not self.__is_usable(text):
                    continue

                max_context = len(text) - self.target_length
                length = min(
                    random.randint(self.padding_limit, self.padding_factor),
                    self.seq_length,
                    max_context,
                )
                start_idx = random.randint(0, max_context - length)
                context_end = start_idx + length

                cropped_text = self.__padd(text[start_idx:context_end])
                target = text[context_end:context_end + self.target_length]
                yield np.array(cropped_text), np.array(target)

    def __is_usable(self, text):
        """Tell whether ``text`` is long enough to supply a complete target."""
        return len(text) >= self.target_length

    def __padd(self, text):
        """Left-pad ``text`` with the pad index up to ``seq_length`` tokens."""
        # list() keeps the concatenation well-defined for tuples/arrays too.
        text = list(text)
        missing = self.seq_length - len(text)
        if missing > 0:
            text = [self.pad_token_idx] * missing + text
        return text

In [3]:
class TextValidationDataset(IterableDataset):
    """Iterable dataset that streams a text file as space-joined sentences.

    Args:
        text_file_path: Path handed to ``LineSentence`` on every iteration.
    """

    def __init__(self, text_file_path):
        self.text_file_path = text_file_path

    def __iter__(self):
        """Yield each item produced by ``LineSentence`` joined with single spaces."""
        sentences = LineSentence(self.text_file_path)
        yield from (' '.join(tokens) for tokens in sentences)

In [4]:
class LstmTextGenerator(LightningModule):
    """Word-level LSTM language model that predicts the next token of a window.

    The network is ``Embedding -> LSTM -> Dropout -> Linear``; the linear
    layer reads the LSTM output of the last time step and produces one logit
    per vocabulary entry.

    Args:
        vocabulary_path: Path to a vocabulary object saved with ``torch.save``.
        train_file_path: Path to the pickled training texts consumed by
            ``TextTrainDataset``.
        embedding_dim: Size of the token embeddings (also the LSTM input size).
        lstm_layers: Number of stacked LSTM layers.
        lstm_dropout: Dropout applied between LSTM layers.
        lstm_hidden_size: Hidden size of each LSTM direction.
        dropout: Dropout applied to the LSTM output before the linear layer.
        bidirectional: Whether the LSTM is bidirectional.
        batch_size: Batch size of the training dataloader.
        seq_length: Length of the (left-padded) input window.
        target_length: Number of consecutive tokens predicted per training step.
        target_weight_decrease: Multiplicative factor applied to the loss
            weight after each predicted target token.
        padding_factor: Upper bound of the random context length (see
            ``TextTrainDataset``).
        padding_limit: Lower bound of the random context length (see
            ``TextTrainDataset``).
    """

    def __init__(
        self,

        # files
        vocabulary_path,
        train_file_path,

        # architecture
        embedding_dim=100,
        lstm_layers=1,
        lstm_dropout=0,
        lstm_hidden_size=100,
        dropout=0.2,
        bidirectional=False,

        # training process
        batch_size=64,
        seq_length=10,
        target_length=10,
        target_weight_decrease=1.0,
        padding_factor=20,
        padding_limit=3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.vocabulary = torch.load(self.hparams.vocabulary_path)
        # append_token raises when the token already exists, so only add the
        # padding token if the stored vocabulary does not have it yet.
        if '<pad>' not in self.vocabulary:
            self.vocabulary.append_token('<pad>')

        self.embedding = nn.Embedding(
            len(self.vocabulary),
            self.hparams.embedding_dim
        )

        self.lstm = nn.LSTM(
            # Must match the embedding width; a literal 100 here broke every
            # configuration with embedding_dim != 100.
            input_size=self.hparams.embedding_dim,
            hidden_size=self.hparams.lstm_hidden_size,
            batch_first=True,
            num_layers=self.hparams.lstm_layers,
            dropout=self.hparams.lstm_dropout,
            bidirectional=self.hparams.bidirectional
        )

        # A bidirectional LSTM concatenates both directions on the feature axis.
        directions = 2 if self.hparams.bidirectional else 1
        self.fc = nn.Linear(directions * self.hparams.lstm_hidden_size, len(self.vocabulary))

        self.dropout = nn.Dropout(self.hparams.dropout)

        self.loss = nn.CrossEntropyLoss()

    def generate(self, prompt, length=50, temperature=0.5):
        """Extend ``prompt`` with ``length`` sampled tokens.

        Args:
            prompt: Text to continue.
            length: Number of tokens to generate.
            temperature: Softmax temperature used for sampling; values <= 0
                select the most likely token greedily.

        Returns:
            The original prompt followed by the generated tokens. Tokens are
            separated by a space, except that ``.``, ``!``, ``?`` and ``,``
            are attached directly to the preceding text.
        """
        generated = prompt
        window = self.__preprocess_prompt(prompt)

        # Generate in eval mode and without autograd so dropout does not
        # perturb the predictions and no graph is built; the previous
        # train/eval mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(length):
                    input_tensor = torch.unsqueeze(torch.tensor(window, device=self.device), dim=0)
                    next_word_logits = self(input_tensor)[0]
                    word_idx = self.__get_word_from_logits(next_word_logits, temperature)
                    # Slide the window: drop the oldest token, append the new one.
                    window = window[1:] + [word_idx]

                    word = self.vocabulary.lookup_token(word_idx)
                    if word not in ('.', '!', '?', ','):
                        generated += ' '
                    generated += word
        finally:
            self.train(was_training)

        return generated

    def __get_word_from_logits(self, next_word_logits, temperature=0.5):
        """Pick a token index from a vector of logits.

        Samples from the temperature-scaled softmax distribution; for a
        non-positive temperature (where the scaling is undefined) the argmax
        is returned instead.
        """
        if temperature <= 0:
            return next_word_logits.argmax(dim=-1).item()

        scaled_logits = next_word_logits / temperature
        adjusted_probs = F.softmax(scaled_logits, dim=-1)
        next_word_index = torch.multinomial(adjusted_probs, num_samples=1).item()
        return next_word_index

    def forward(self, x):
        """Return next-token logits for a batch of token-index windows.

        Args:
            x: Integer tensor of token indices, batch dimension first.

        Returns:
            Tensor with one row of vocabulary logits per batch element,
            computed from the LSTM output at the last time step.
        """
        out = self.embedding(x)
        out, _ = self.lstm(out)
        out = self.dropout(out)
        out = self.fc(out[:, -1, :])
        return out

    def training_step(self, batch, batch_no):
        """Compute the weighted multi-token loss for one batch.

        For each of the ``target_length`` target positions the model predicts
        the next token, the cross-entropy against the true token is added to
        the loss, and the model's own argmax prediction is appended to the
        window used for the following position.
        """
        texts, targets = batch
        loss = 0
        weight = 1.0

        for i in range(self.hparams.target_length):
            predicted = self(texts)
            loss_part = self.loss(predicted, targets[:, i])
            loss += weight * loss_part
            weight *= self.hparams.target_weight_decrease

            # Feed the prediction (not the ground truth) back into the window.
            words_idx = predicted.argmax(dim=-1)
            texts = torch.cat((texts[:, 1:], words_idx.unsqueeze(1)), axis=1)

        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.001."""
        optimizer = Adam(self.parameters(), lr=0.001)
        return optimizer

    def train_dataloader(self):
        """Build the training dataloader from the stored hyperparameters."""
        dataset = TextTrainDataset(
            self.hparams.train_file_path,
            pad_token_idx=self.vocabulary['<pad>'],
            seq_length=self.hparams.seq_length,
            target_length=self.hparams.target_length,
            padding_factor=self.hparams.padding_factor,
            padding_limit=self.hparams.padding_limit,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
        )

    def __preprocess_prompt(self, prompt):
        """Turn a text prompt into a left-padded list of token indices.

        The prompt is tokenised and mapped through the vocabulary; indices
        equal to -1 are dropped. If fewer than ``seq_length`` tokens remain,
        the list is left-padded with the ``<pad>`` index, mirroring the
        padding applied to the training windows. Longer prompts are returned
        unpadded and untruncated.
        """
        tokenized = self.__tokenize(prompt)
        words_idx = [idx for idx in self.vocabulary(tokenized) if idx != -1]

        # Pad by the number of tokens (not the character length of the raw
        # string) and with the same index the training data is padded with.
        pad_idx = self.vocabulary['<pad>']
        missing = max(self.hparams.seq_length - len(words_idx), 0)
        return [pad_idx] * missing + words_idx

    def __tokenize(self, text):
        """Lower-case ``text`` and split it into word and punctuation tokens.

        Characters other than Latin/Polish lowercase letters, ``. , ! ? -``
        and spaces are replaced by spaces; each of ``, - . ! ?`` becomes a
        separate token.
        """
        text = text.lower()
        text = re.sub(r'[^a-ząćęłńóśźż.,!?\- ]', ' ', text)
        # The hyphen is escaped so it is matched literally rather than through
        # an accidental ',' .. '.' character range (which covers the same set).
        text = re.sub(r'([,\-.!?])', ' \\1 ', text)
        text = [word for word in text.split(' ') if word]
        return text

In [5]:
# TensorBoard logger; save_dir and name together select the ../../logs directory.
logger = TensorBoardLogger(name='logs', save_dir='../..')

# max_epochs=-1 sets no epoch limit (see the Lightning Trainer docs), so
# training runs until it is stopped manually.
trainer = Trainer(
    logger=logger,
    accelerator='cuda',
    max_epochs=-1,
    enable_progress_bar=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Fresh, untrained model; keyword arguments are grouped as in the constructor signature.
generator = LstmTextGenerator(
    # files
    vocabulary_path='../../models/vocabulary.pth',
    train_file_path='../../data/binary_texts/fairytales.pickle',

    # architecture
    lstm_layers=3,
    lstm_dropout=0.2,
    lstm_hidden_size=100,
    dropout=0.2,
    bidirectional=True,

    # training
    batch_size=128,
    seq_length=35,
    target_length=2,
    target_weight_decrease=1.0,
    padding_factor=80,
    padding_limit=4,
)

In [7]:
# Rebind `generator` to a model restored from a saved checkpoint; this replaces
# the instance built by the constructor cell above.
# NOTE(review): this cell and the constructor cell share the same execution
# count, so presumably only one of them ran in the recorded session - confirm
# which one is intended before re-running the notebook top to bottom.
generator = LstmTextGenerator.load_from_checkpoint('../../logs/version_7/checkpoints/epoch=342-step=93639.ckpt')

In [8]:
# Layer-by-layer summary for a dummy batch of 64 windows of 20 token indices;
# the input dtype is declared as long because the first layer is an embedding.
summary(
    generator,
    input_size=(64, 20),
    dtypes=[torch.LongTensor],
    device='cpu',
    col_names=['input_size', 'output_size', 'num_params', 'params_percent'],
)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %
LstmTextGenerator                        [64, 20]                  [64, 150001]              --                             --
├─Embedding: 1-1                         [64, 20]                  [64, 20, 100]             15,000,100                 32.75%
├─LSTM: 1-2                              [64, 20, 100]             [64, 20, 200]             644,800                     1.41%
├─Dropout: 1-3                           [64, 20, 200]             [64, 20, 200]             --                             --
├─Linear: 1-4                            [64, 200]                 [64, 150001]              30,150,201                 65.84%
Total params: 45,795,101
Trainable params: 45,795,101
Non-trainable params: 0
Total mult-adds (G): 3.71
Input size (MB): 0.01
Forward/backward pass size (MB): 79.87
Params size (MB): 183.18
Estimated Total Size (MB): 263.06

In [9]:
# Start training; no dataloader is passed because LstmTextGenerator defines
# its own train_dataloader().
trainer.fit(generator)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | vocabulary | Vocab            | 0     
1 | embedding  | Embedding        | 15.0 M
2 | lstm       | LSTM             | 644 K 
3 | fc         | Linear           | 30.2 M
4 | dropout    | Dropout          | 0     
5 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
45.8 M    Trainable params
0         Non-trainable params
45.8 M    Total params
183.180   Total estimated model params size (MB)
  rank_zero_warn(


Epoch 19:   8%|▊         | 21/269 [00:01<00:12, 19.28it/s, v_num=13] 

In [21]:
# Sample a continuation of a fairy-tale opening (default length of 50 tokens) at temperature 1.
generator.generate('dawno, dawno temu, za siedmioma górami i siedmioma', temperature=1)

'dawno, dawno temu, za siedmioma górami i siedmioma zimy, wiadro były balony i w książkach na fachu przystrojone się aż lekko ptak się spotkały przez by węgiel nie złożę wszystkie zwierzęta na stałe zdrowie cicho przepisane znaczy fabryczne ich głośne groszy dzwonek co się wziął wziął dba z, ten dzień miesiąc szybko zaproszę pisać diety pięknie'

In [24]:
# Sample a second continuation (default length of 50 tokens) at temperature 1 from a different prompt.
generator.generate('Pewnego słonecznego dnia czerwony kapturek szedł do swojej babci z koszyczkiem', temperature=1)

'Pewnego słonecznego dnia czerwony kapturek szedł do swojej babci z koszyczkiem powrotem. - cóż to jest. kundel aż z dala. - super mały góra, albo zachować ładnie piskiem orzech, autorka l. mróz - cieślik wierszyk z obrazkiem - bajeczki - pręgi, uwaga, sio. ja gotowy - - wnuczek coś złoży! mamo'