In [13]:
# Standard library imports
import json
import os
import warnings

from copy import deepcopy
from datetime import datetime
from timeit import default_timer as timer
from typing import Iterable, List

# Third party imports
import pandas as pd
import torch
import torch.nn as nn
import tqdm

from nltk.tokenize import word_tokenize

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k

# Local application imports
import src.constants as const
from src.dataset_utils import yield_tokens, sentence_to_tensor, load_files, build_vocab_transformation, tokenize_source, tokenize_target
from src.training_utils import train_epoch, evaluate, sequential_transforms, tensor_transform
from src.transcription_dataset import TranscriptionDataset
from src.transformer_model import Seq2SeqTransformer, generate_square_subsequent_mask, create_mask
from src.syllable_splitter import split_word

In [2]:
import warnings
warnings.filterwarnings('ignore')

# Dataset preparation

In [3]:
# Root directory of the thesis data. The default is the original local path;
# set the THESIS_DATA_DIR environment variable to run on another machine
# without editing the notebook.
DATA_DIR = os.environ.get('THESIS_DATA_DIR', '/mnt/d/Projects/masters-thesis/data')

# Find all the transcription files
files = load_files(os.path.join(DATA_DIR, 'transcriptions'))

# Sum the 'size' column of the CSV to get the total number of lines
filepaths_to_size = pd.read_csv(os.path.join(DATA_DIR, 'filepath_to_size.csv'))
lines_count = filepaths_to_size['size'].sum()

# Displayed summary (Bulgarian): "Number of files: ..., number of lines: ..."
f'Брой файлове: {len(files)}, брой редове: {lines_count}'

'Брой файлове: 47825, брой редове: 80615534'

In [4]:
# Boundaries of the train / validation / test partitions.
# Only the first `sentences_to_use` sentences are used; the two indices below
# are later passed as start_index / end_index when the datasets are built.
sentences_to_use = 50_000

train_fraction = const.TRAIN_TEST_SPLIT
validation_end_fraction = const.TRAIN_TEST_SPLIT + const.TRAIN_VALIDATION_SPLIT

train_split = int(train_fraction * sentences_to_use)
validation_split = int(validation_end_fraction * sentences_to_use)

In [5]:
# All three partitions read the same files with the same tokenizers; they
# differ only in the [start_index, end_index] window they cover.
tokenizer_kwargs = {
    'tokenization_src': tokenize_source,
    'tokenization_tgt': tokenize_target,
}

train_dataset = TranscriptionDataset(
    files, start_index=0, end_index=train_split, **tokenizer_kwargs)
validation_dataset = TranscriptionDataset(
    files, start_index=train_split, end_index=validation_split, **tokenizer_kwargs)
test_dataset = TranscriptionDataset(
    files, start_index=validation_split, end_index=sentences_to_use, **tokenizer_kwargs)

In [6]:
# Build one vocabulary transformation per language from the training
# partition and register it in the shared const.vocab_transform mapping.
for language in (const.SRC_LANGUAGE, const.TGT_LANGUAGE):
    const.vocab_transform[language] = build_vocab_transformation(train_dataset, language)

In [7]:
# Training-time pipelines for already tokenized samples, one per language:
# numericalize with the language's vocabulary, then add BOS/EOS and create
# the tensor.
text_transform_src, text_transform_tgt = (
    sequential_transforms(const.vocab_transform[language], tensor_transform)
    for language in (const.SRC_LANGUAGE, const.TGT_LANGUAGE)
)

# Vowel symbols of the transcription alphabet (not referenced in this cell).
vowels_transcription = ['a', 'ʌ', 'ɤ̞',  'ɐ', 'ɔ', 'o', 'u', 'ɛ', 'i']


In [8]:
# Collate a list of (source, target) samples into two padded batch tensors
def collate_fn(batch):
    """Turn a list of (source, target) samples into padded tensors.

    Each sample is converted with the module-level ``text_transform_src`` /
    ``text_transform_tgt`` pipelines (looked up at call time), then the
    per-sample tensors are padded with ``const.PAD_IDX`` to a common length.

    Args:
        batch: iterable of ``(src_sample, tgt_sample)`` pairs.

    Returns:
        ``(src_batch, tgt_batch)`` as produced by ``pad_sequence`` with its
        default ``batch_first=False``.
    """
    # Transform source first, then target, for each sample in turn.
    encoded_pairs = [
        (text_transform_src(source_sample), text_transform_tgt(target_sample))
        for source_sample, target_sample in batch
    ]
    source_tensors = [pair[0] for pair in encoded_pairs]
    target_tensors = [pair[1] for pair in encoded_pairs]

    padded_sources = pad_sequence(source_tensors, padding_value=const.PAD_IDX)
    padded_targets = pad_sequence(target_tensors, padding_value=const.PAD_IDX)
    return padded_sources, padded_targets

In [9]:
# One loader per partition. Each loader gets its own deep copy of the dataset,
# so iterating a loader does not touch the original dataset objects.
BATCH_SIZE = 128

train_dataloader, validation_dataloader, test_dataloader = (
    DataLoader(deepcopy(partition), batch_size=BATCH_SIZE, collate_fn=collate_fn)
    for partition in (train_dataset, validation_dataset, test_dataset)
)

# Model setup and training

In [12]:

# Seed torch's RNG before the model is built so the weight initialisation
# below is repeatable.
torch.manual_seed(0)

# Vocabulary sizes are passed to the model constructor.
SRC_VOCAB_SIZE = len(const.vocab_transform[const.SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(const.vocab_transform[const.TGT_LANGUAGE])

transformer = Seq2SeqTransformer(
    const.NUM_ENCODER_LAYERS,
    const.NUM_DECODER_LAYERS,
    const.EMB_SIZE,
    const.NHEAD,
    SRC_VOCAB_SIZE,
    TGT_VOCAB_SIZE,
    const.FFN_HID_DIM,
)

# Xavier-uniform initialisation for every parameter with more than one
# dimension; one-dimensional parameters keep their default values.
for weight in filter(lambda tensor: tensor.dim() > 1, transformer.parameters()):
    nn.init.xavier_uniform_(weight)

transformer = transformer.to(const.device)

# Padding positions are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=const.PAD_IDX)

optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)


In [None]:
# `timer` (timeit.default_timer) comes from the import cell at the top of the
# notebook, so this cell does not import it again.
NUM_EPOCHS = 25

for epoch in range(1, NUM_EPOCHS + 1):
    # Only the training pass is timed; validation runs after the timer stops.
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer, train_dataloader, loss_fn)
    end_time = timer()

    val_loss = evaluate(transformer, validation_dataloader, loss_fn)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s")

In [None]:
# The checkpoint name encodes the date, the number of sentences used and the
# number of epochs, so different runs do not overwrite each other.
today_date = datetime.today().strftime('%Y-%m-%d')

# torch.save does not create missing parent directories; create the folder
# first so a finished training run is not lost to a missing 'models' folder.
os.makedirs('models', exist_ok=True)

torch.save(transformer.state_dict(), f'models/transformer-{today_date}-{sentences_to_use}-{NUM_EPOCHS}.pth')

In [14]:
# map_location remaps the stored tensors onto the device used in this session,
# so a checkpoint saved on a GPU can also be restored on a machine without one.
transformer.load_state_dict(
    torch.load('models/transformer-2023-10-08-50000-25.pth', map_location=const.device)
)

<All keys matched successfully>

In [15]:
# Inference-time source pipeline: unlike the training pipeline it starts from
# a raw sentence, so tokenize_source runs before numericalization.
# NOTE(review): this rebinds the module-level name `text_transform_src`, which
# collate_fn also looks up at call time. Any dataloader iterated after this
# cell will therefore run tokenize_source on its samples too - confirm that
# this is intended before re-running training or evaluation.
source_steps = [
    tokenize_source,                              # raw sentence -> tokens
    const.vocab_transform[const.SRC_LANGUAGE],    # numericalization
    tensor_transform,                             # add BOS/EOS and create tensor
]
text_transform_src = sequential_transforms(*source_steps)

# Vowel symbols of the transcription alphabet (not referenced in this cell).
vowels_transcription = ['a', 'ʌ', 'ɤ̞',  'ɐ', 'ɔ', 'o', 'u', 'ɛ', 'i']

target_steps = [
    const.vocab_transform[const.TGT_LANGUAGE],    # numericalization
    tensor_transform,                             # add BOS/EOS and create tensor
]
text_transform_tgt = sequential_transforms(*target_steps)

In [17]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate a target sequence one token at a time, always taking the best token.

    The source is encoded once; the decoder is then called repeatedly on the
    sequence generated so far, and the highest-scoring token of the last
    position is appended until ``const.EOS_IDX`` is produced or the length
    limit is reached.

    Args:
        model: object exposing ``encode(src, src_mask)``,
            ``decode(ys, memory, tgt_mask)`` and ``generator(features)``.
        src: source token tensor (``translate`` passes shape ``(num_tokens, 1)``).
        src_mask: mask forwarded to ``model.encode``.
        max_len: maximum length of the returned sequence, including the
            start symbol (at most ``max_len - 1`` tokens are generated).
        start_symbol: vocabulary index used as the first target token.

    Returns:
        Long tensor of shape ``(generated_len, 1)`` on ``const.device``,
        starting with ``start_symbol`` and ending with ``const.EOS_IDX`` if
        the model emitted it before the length limit.
    """
    src = src.to(const.device)
    src_mask = src_mask.to(const.device)

    # Decoding is pure inference: disabling autograd avoids building a graph
    # for every decoder call. The returned indices are unaffected.
    with torch.no_grad():
        # The encoder output does not change between steps, so it is computed
        # and moved to the device once instead of on every iteration.
        memory = model.encode(src, src_mask).to(const.device)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(const.device)

        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(const.device)
            out = model.decode(ys, memory, tgt_mask)
            # assumes decode returns (tgt_len, batch, dim) - TODO confirm;
            # after the transpose, out[:, -1] is the last generated position.
            out = out.transpose(0, 1)
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([ys,
                            torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == const.EOS_IDX:
                break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str) -> str:
    """Transcribe a raw sentence with greedy decoding.

    The sentence is lower-cased and converted with the module-level
    ``text_transform_src``, which must be the variant that accepts a raw
    string (i.e. the pipeline that starts with ``tokenize_source``).

    Args:
        model: trained model; it is switched to eval mode as a side effect.
        src_sentence: raw input sentence.

    Returns:
        The generated target tokens joined by single spaces. The ``<bos>`` and
        ``<eos>`` markers are removed by plain string replacement, so the
        spaces that separated them from the other tokens remain at the start
        and end of the result.
    """
    src_sentence = src_sentence.lower()
    model.eval()
    src = text_transform_src(src_sentence).view(-1, 1)

    num_tokens = src.shape[0]
    # All-False mask: no source position is hidden from the encoder.
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

    # Inference only: no autograd graph is needed while decoding.
    with torch.no_grad():
        # Allow the output to be up to 5 tokens longer than the input.
        tgt_tokens = greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5, start_symbol=const.BOS_IDX).flatten()

    tgt_vocab = const.vocab_transform[const.TGT_LANGUAGE]
    decoded = " ".join(tgt_vocab.lookup_tokens(list(tgt_tokens.cpu().numpy())))
    return decoded.replace("<bos>", "").replace("<eos>", "")

In [18]:
# Sample Bulgarian sentences, transcribed and printed in this order.
demo_sentences = [
    "здравей, как сте?",
    "Добре съм, благодаря.",
    "Айрян",
    "Български език",
    "език Български език",
    "Лятото е моето любимо време на годината.",
    "В парка разцъфтяха невероятни цветя.",
    "Музиката успокоява душата ми след дълъг работен ден.",
    "Вчера се срещнах със стар приятел, когото не бях виждал години.",
    "Четенето на книги разширява хоризонтите и обогатява речника.",
]

for sentence in demo_sentences:
    print(translate(transformer, sentence))

 zdrʌ vɛj , kʌk stɛ ? 
 dob rɛ sɐm , blʌ go dʌr jɐ . 
 xvɐr ljɐ 
 bɐl gʌr ski ɛ zik 
 ɛ zik bɐl gʌr ski ɛ zik 
 ljɐ to to ɛ mo ɛ to ljo bi mo vrɛ mɛ nʌ go di nʌ tʌ . 
 v pʌr kʌ rʌz lit tjɐ xʌ nɛ vɛ ro jɐt ni tsvɛt jɐ . 
 mo zi kʌ tʌ os po ko jɐ vʌ do ʃʌ tʌ mi slɛd dɐ lɐg rʌ bo tɛn dɛn . 
 vtʃɛ rʌ sɛ srɛʃ tnʌx sɐs stʌr pri jɐ tɛl , ko go to nɛ bjɐx viʒ dʌl go di ni . 
 tʃɛ tɛ nɛ to nʌ kni gi rʌz ʃir jɐ vʌ xo ri zon ti tɛ i o bo gʌt jɐ vʌ rɛt ʃni kʌ . 
