In [1]:
%load_ext autoreload

In [2]:
%autoreload

In [3]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from tqdm import trange
from torch.utils.data import DataLoader
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from transformers import XLMTokenizer, AutoTokenizer
from torchinfo import summary

from transformer import *
from callback import GenerateCallback

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# tokenizer = XLMTokenizer.from_pretrained("allegro/herbert-klej-cased-tokenizer-v1")
tokenizer = XLMTokenizer.from_pretrained("allegro/herbert-klej-cased-tokenizer-v1")

In [5]:
len(tokenizer)

50560

In [6]:
# encoded = tokenizer.encode("witaj świecie")
# print(encoded)
# decoded = tokenizer.decode(encoded)
# print(decoded)

In [7]:
import re
import glob
import random
import pickle
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

from utils import pad


class TextTrainDataset(Dataset):
    
    def __init__(self, dataset_path, tokenizer, min_src_length, max_src_length, tgt_length, remove_dialogs=True, remove_special_chars=False, lowercase=False, tqdm=False, cache_path=None, cache_ignore=False, min_line_length=0):
        self.samplesset_path = dataset_path
        self.tokenizer = tokenizer
        self.min_src_length = min_src_length
        self.max_src_length = max_src_length
        self.tgt_length = tgt_length
        self.remove_dialogs = remove_dialogs
        self.remove_special_chars = remove_special_chars
        self.lowercase = lowercase
        self.tqdm = tqdm
        self.cache_path = cache_path
        self.cache_ignore = cache_ignore
        self.min_line_length = min_line_length
        
        self.samples = self.__get_samples()
        
    def __len__(self):
        return len(self.samples)
        
    def __getitem__(self, idx):
        text = self.samples[idx]
        src, tgt = self.__split_into_src_tgt(text)
        return np.array(src, dtype=np.int32), np.array(tgt, dtype=np.int32)
    
    def __get_samples(self):
        if self.cache_path is None or self.cache_ignore or not Path(self.cache_path).exists():
            samples = self.__create_samples()
            self.__save_samples_to_cache(samples)
            return samples
        else:
            return self.__load_samples_from_cache()
        
    def __load_samples_from_cache(self):
        with open(self.cache_path, 'rb') as f:
            return pickle.load(f)
        
    def __save_samples_to_cache(self, samples):
        Path(self.cache_path).parent.mkdir(parents=True, exist_ok=True) 
        with open(self.cache_path, 'wb') as f:
            return pickle.dump(samples, f)
        
    def __create_samples(self):
        paths = list(glob.glob(f'{self.samplesset_path}/**/*.txt', recursive=True))
        random.shuffle(paths)
        data = []
        
        if self.tqdm:
            paths = tqdm(paths)
        
        for path in paths:
            text = self.__read_text_from_file(path)
            samples = self.__get_samples_from_text(text)
            data.extend(samples)
                
        return data
                
    def __get_samples_from_text(self, text):
        samples = []
        tokenized = self.tokenizer.encode(text)
        
        min_length = self.min_src_length + 1
        
        start_idx = 0
        end_idx = len(tokenized) - min_length - 1
        
        for idx in range(start_idx, end_idx):
            sequence = tokenized[idx : idx+self.max_src_length+self.tgt_length]
            samples.append(sequence)
            
        return samples
    
    def __split_into_src_tgt(self, text):
        max_src_length = min(len(text)-1, self.max_src_length)
        src_length = random.randint(self.min_src_length, max_src_length) if self.min_src_length != max_src_length else max_src_length
        
        src = text[:src_length]
        tgt = text[src_length:src_length+self.tgt_length-1]
        tgt.insert(0, self.tokenizer.bos_token_id)
        
        src = pad(src, self.tgt_length, pad_token=self.tokenizer.pad_token_id)
        tgt = pad(tgt, self.tgt_length, pad_token=self.tokenizer.pad_token_id)
        
        return src, tgt

    def __read_text_from_file(self, path):
        with open(path, encoding='utf-8') as f:
            lines = f.readlines()
            lines = map(self.__preprocess_line, lines)
            lines = filter(lambda line: len(line) > self.min_line_length, lines)
            if self.remove_dialogs:
                lines = self.__remove_dialogs(lines)
            text = '\n'.join(lines)
            if self.remove_special_chars:
                text = re.sub(r'[^a-ząćęłńóśźż.,!? \n]', ' ', text, flags=re.IGNORECASE)
            return text

    def __remove_dialogs(self, lines):
        return filter(lambda line: not self.__is_dialog_line(line), lines)
    
    def __preprocess_line(self, line):
        line = line.strip()
        if self.lowercase:
            line = line.lower()
        return line
        
    @staticmethod
    def __is_dialog_line(line):
        return '—' in line or '–' in line or '-' in line or '„' in line or '"' in line

## Testing dataset

In [8]:
dataset = TextTrainDataset(
    '../../data/training_trans',
    tokenizer,
    min_src_length=6,
    max_src_length=20,
    tgt_length=100,
    lowercase=True,
    tqdm=True,
    cache_path='.cache/dataset',
    # cache_ignore=True,
    remove_special_chars=True,
    min_line_length=25
)

In [10]:
for x in tqdm(dataset):
    pass

  0%|          | 0/174852 [00:00<?, ?it/s]


ValueError: empty range for randrange() (6, 3, -3)

: 

In [None]:
src, tgt = dataset[random.randint(0, len(dataset)-1)]
print(len(src), len(tgt))
print(src)
print(tgt)

100 100
[    2     2     2     2     2     2     2     2     2     2     2     2
     2     2     2     2     2     2     2     2     2     2     2     2
     2     2     2     2     2     2     2     2     2     2     2     2
     2     2     2     2     2     2     2     2     2     2     2     2
     2     2     2     2     2     2     2     2     2     2     2     2
     2     2     2     2     2     2     2     2     2     2     2     2
     2     2     2     2     2     2     2     2     2     2 27873  1074
   990  2313   281 20728  1336    14  2313    14  2480    20    22  8782
  1157    15   256   824]
[    0   756  6652    18    37   407 48181  1846    15  3027  1063 18547
    20    14    26 23446  4590 22231  2262    91 14295  2382    15    25
  2893  1423  1234 44052   955    15   601    39  2277    14  2442    20
  8095   146    19   254    16   358 10202  2105  6154 15136  1619   429
    20    14    25 11334   260  2382   990    32   470  5849 17705   937
   281 20728  133

# Transformer

In [None]:
from torch.optim import Adam
from lightning.pytorch import LightningModule


class TransformerLightning(LightningModule):
    
    def __init__(self, seq_length, lr=0.001):
        super().__init__()
        self.save_hyperparameters()
        
        self.tokenizer = XLMTokenizer.from_pretrained("allegro/herbert-klej-cased-tokenizer-v1")
        # self.tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
        
        self.transformer = Transformer(
            src_vocab_size=len(self.tokenizer),
            tgt_vocab_size=len(self.tokenizer),
            d_model=512,
            num_heads=8,
            num_layers=6,
            d_ff=2048,
            max_seq_length=100,
            dropout=0.1,
            mask_token=self.tokenizer.pad_token_id,
        )
        
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        
    def forward(self, src, tgt):
        return self.transformer(src, tgt)
        
    def training_step(self, batch, batch_no):
        src_data, tgt_data = batch
        output = self(src_data, tgt_data[:, :-1])
        predicted = output.contiguous().view(-1, len(self.tokenizer))
        target = tgt_data[:, 1:].contiguous().view(-1)
        loss = self.criterion(predicted, target.long())
        self.log('train_loss', loss)
        return loss
        
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.98), eps=1e-9)
        return optimizer
    
    def generate(self, prompt, length=50, temperature=0.5):
        src_ids = self.tokenizer.encode(prompt)[1:-1]
        generated_ids = self.__generate_ids(src_ids, length, temperature)
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        return generated_text
    
    def __generate_ids(self, src_ids, length=200, temperature=0.5):
        src_ids = pad(src_ids, self.hparams.seq_length, self.tokenizer.pad_token_id)
        tgt_ids = [self.tokenizer.bos_token_id]
        
        src_tensor = torch.unsqueeze(torch.tensor(src_ids, device=self.device), dim=0)
        
        self.eval()
        
        with torch.no_grad():
            for _ in range(length):
                tgt_padded = pad(tgt_ids, self.hparams.seq_length, self.tokenizer.pad_token_id)
                tgt_tensor = torch.unsqueeze(torch.tensor(tgt_padded, device=self.device), dim=0)
                
                # pos = self.hparams.seq_length - len(tgt_ids)
                output = self(src_tensor, tgt_tensor).squeeze(0)[-1]
                word_idx = self.__sample_word_idx(output, temperature)
                tgt_ids.append(word_idx)
            
        self.train()
        return src_ids + tgt_ids[1:]
        
    @staticmethod
    def __sample_word_idx(outputs, temperature=1.0):
        scaled_logits = torch.log_softmax(outputs, dim=0) / temperature
        adjusted_probs = F.softmax(scaled_logits, dim=-1)
        next_word_index = torch.multinomial(adjusted_probs, num_samples=1).item()
        return next_word_index

: 

# Training

In [None]:
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0
)

In [None]:
transformer = TransformerLightning(seq_length=100)

In [None]:
logger = TensorBoardLogger(
    save_dir='.',
    name='logs'
)

generate_callback = GenerateCallback(
    'Pewnego dnia czerwony kapturek szedł przez las z koszyczkiem jedzenia do swojej babci, która mieszkała w lesie. Śledził go jednak zły wilk, który chciał zjeść dziewczynkę.',
    temperatures=[0.01, 0.1, 0.2, 0.3, 0.5, 0.7],
    length=100,
    interval=100
)

trainer = Trainer(
    accelerator='cuda',
    precision='16-mixed',
    max_epochs=-1,
    enable_progress_bar=True,
    logger = logger,
    callbacks=[generate_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(transformer, train_dataloaders=train_dataloader)

In [None]:
transformer.generate('Pewnego słonecznego dnia czerwony kapturek szedł do swojej babci z koszyczkiem. Kapturek był koloru', temperature=0.2)

torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])
torch.Size([50560])


KeyboardInterrupt: 

# Training old

In [None]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    print('output', output.shape)
    predicted = output.contiguous().view(-1, tgt_vocab_size)
    target = tgt_data[:, 1:].contiguous().view(-1)
    print(predicted.shape, target.shape)
    break
    loss = criterion(predicted, target)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

output torch.Size([64, 99, 5000])
torch.Size([6336, 5000]) torch.Size([6336])
