In [164]:
from pathlib import Path
import random

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
import torchmetrics
from lightning.pytorch.loggers import TensorBoardLogger
from torchinfo import summary
from gensim.models.word2vec import LineSentence, Word2Vec
from tqdm import tqdm

In [169]:
w2v = Word2Vec.load('../../models/word2vec/punctuation/word2vec')

In [170]:
w2v.wv.key_to_index['kot']

3387

In [111]:
words = ['ala', 'ma', 'kota']
w2v.wv[words].shape

(3, 100)

In [171]:
class TextTrainDataset(IterableDataset):
    
    def __init__(self, text_file_path, w2v, seq_length=10):
        self.text_file_path = text_file_path
        self.seq_length = seq_length
        self.w2v = w2v
        
    def __iter__(self):
        for text in LineSentence(self.text_file_path):
            text = [word for word in text if word in self.w2v.wv]
            start_idx = random.randint(-self.seq_length+1, len(text)-self.seq_length-1)
            cropped_text = text[max(start_idx, 0) : start_idx+self.seq_length]
            self.__padd(cropped_text)
            target = text[start_idx+self.seq_length]
            yield self.w2v.wv[cropped_text], w2v.wv.key_to_index[target]
            
    def __padd(self, text):
        if len(text) < self.seq_length:
            text.extend(['<pad>']*(self.seq_length-len(text)))
            
    

In [152]:
dataset = TextTrainDataset('../../data/line_sentence/texts_punctuation.txt', w2v, 8)

In [153]:
for text, target in dataset:
    print(text.shape, target.shape)
    break

(8, 100) (100,)


In [180]:
class LstmTextGenerator(LightningModule):
    
    def __init__(self, text_file_path, word2vec_path, seq_length=10, batch_size=64, seed=42):
        super().__init__()
        self.save_hyperparameters()
        
        self.w2v = Word2Vec.load(self.hparams.word2vec_path)
        np.random.seed(seed)
        self.w2v.wv['<pad>'] = np.random.rand(100)
        
        self.lstm = nn.LSTM(input_size=100, hidden_size=100, batch_first=True)
        self.fc = nn.Linear(100, len(self.w2v.wv))
        
        self.loss = nn.CrossEntropyLoss()
        
    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out
        
    def training_step(self, batch, batch_no):
        text, target = batch
        predicted = self.forward(text)
        return self.loss(predicted, target)
        
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=0.001)
        return optimizer
    
    def train_dataloader(self):
        dataset = TextTrainDataset(
            self.hparams.text_file_path,
            self.w2v,
            self.hparams.seq_length
        )
        
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=24
        )

In [177]:
logger = TensorBoardLogger(
    save_dir='../..',
    name='logs'
)

trainer = Trainer(
    accelerator='cuda',
    max_epochs=1,
    enable_progress_bar=True,
    logger = logger,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [181]:
generator = LstmTextGenerator(
    text_file_path='../../data/line_sentence/texts_punctuation.txt',
    word2vec_path='../../models/word2vec/punctuation/word2vec',
    batch_size=64
)

In [167]:
summary(
    generator,
    input_size=(64, 20, 100),
    col_names=['input_size', 'output_size', 'num_params', 'params_percent']
)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %
LstmTextGenerator                        [64, 20, 100]             [64, 663962]              --                             --
├─LSTM: 1-1                              [64, 20, 100]             [64, 20, 100]             80,800                      0.12%
├─Linear: 1-2                            [64, 100]                 [64, 663962]              67,060,162                 99.88%
Total params: 67,140,962
Trainable params: 67,140,962
Non-trainable params: 0
Total mult-adds (G): 4.40
Input size (MB): 0.51
Forward/backward pass size (MB): 340.97
Params size (MB): 268.56
Estimated Total Size (MB): 610.05

In [182]:
trainer.fit(generator)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type             | Params
------------------------------------------
0 | lstm | LSTM             | 80.8 K
1 | fc   | Linear           | 67.1 M
2 | loss | CrossEntropyLoss | 0     
------------------------------------------
67.1 M    Trainable params
0         Non-trainable params
67.1 M    Total params
268.564   Total estimated model params size (MB)


Epoch 0: : 177it [00:57,  3.07it/s, v_num=2]
Epoch 0: : 519it [00:19, 26.94it/s, v_num=2]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
