In [1]:
# Standard library imports
import json
import os

from collections import Counter
from copy import deepcopy
from typing import Iterable, List

# Third party imports
import pandas as pd
import torch
import torch.nn as nn
import tqdm

from multiset import Multiset
from nltk.tokenize import word_tokenize

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k

# Local application imports
import src.constants as const
from src.dataset_utils import yield_tokens, sentence_to_tensor, load_files, build_vocab_transformation, tokenize_source, tokenize_target
from src.training_utils import train_epoch, evaluate, sequential_transforms, tensor_transform
from src.transcription_dataset_single_word import TranscriptionDataset
from src.transformer_model import Seq2SeqTransformer, generate_square_subsequent_mask, create_mask
from src.syllable_splitter import split_word

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
words_filepath = '/mnt/d/Projects/masters-thesis/data/single_words.txt'

with open(words_filepath, 'r') as f:
    words = f.readlines()
amount_of_words = len(words)

f'{amount_of_words} words loaded'

'606102 words loaded'

In [4]:
vowels_transcription = ['a', 'ʌ', 'ɤ̞',  'ɐ', 'ɔ', 'o', 'u', 'ɛ', 'i']

In [8]:
# Train & test split
sentences_to_use = amount_of_words
train_split = int(const.TRAIN_TEST_SPLIT * sentences_to_use)
validation_split = int((const.TRAIN_TEST_SPLIT + const.TRAIN_VALIDATION_SPLIT) * sentences_to_use)

In [6]:
train_dataset = TranscriptionDataset(words_filepath, tokenization_src=split_word, 
                                     tokenization_tgt=lambda x: split_word(x, vowels_transcription),
                                     start_index=0, end_index=train_split)
validation_dataset = TranscriptionDataset(words_filepath, tokenization_src=split_word,
                                          tokenization_tgt=lambda x: split_word(x, vowels_transcription),
                                          start_index=train_split, end_index=validation_split)
test_dataset = TranscriptionDataset(words_filepath, tokenization_src=split_word,
                                    tokenization_tgt=lambda x: split_word(x, vowels_transcription),
                                    start_index=validation_split, end_index=amount_of_words)

In [9]:
for ln in [const.SRC_LANGUAGE, const.TGT_LANGUAGE]:
    # Create torchtext's Vocab object
    const.vocab_transform[ln] = build_vocab_transformation(train_dataset, ln)

In [10]:
text_transform_src = sequential_transforms(const.vocab_transform[const.SRC_LANGUAGE], #Numericalization
                                                tensor_transform) # Add BOS/EOS and create tensor

vowels_transcription = ['a', 'ʌ', 'ɤ̞',  'ɐ', 'ɔ', 'o', 'u', 'ɛ', 'i']
text_transform_tgt = sequential_transforms(const.vocab_transform[const.TGT_LANGUAGE], #Numericalization
                                                tensor_transform) # Add BOS/EOS and create tensor

In [11]:
# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform_src(src_sample))
        tgt_batch.append(text_transform_tgt(tgt_sample))

    src_batch = pad_sequence(src_batch, padding_value=const.PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=const.PAD_IDX)
    return src_batch, tgt_batch

In [12]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(const.vocab_transform[const.SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(const.vocab_transform[const.TGT_LANGUAGE])

transformer = Seq2SeqTransformer(const.NUM_ENCODER_LAYERS, const.NUM_DECODER_LAYERS, const.EMB_SIZE,
                                 const.NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, const.FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(const.device)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=const.PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [13]:
transformer.load_state_dict(torch.load('models/transformer-single-word-2023-11-10-606102-25.pth'))

<All keys matched successfully>

In [14]:
text_transform_src = sequential_transforms(tokenize_source,
    const.vocab_transform[const.SRC_LANGUAGE], #Numericalization
                                                tensor_transform) # Add BOS/EOS and create tensor

vowels_transcription = ['a', 'ʌ', 'ɤ̞',  'ɐ', 'ɔ', 'o', 'u', 'ɛ', 'i']
text_transform_tgt = sequential_transforms(const.vocab_transform[const.TGT_LANGUAGE], #Numericalization
                                                tensor_transform) # Add BOS/EOS and create tensor

In [15]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(const.device)
    src_mask = src_mask.to(const.device)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(const.device)
    for i in range(max_len-1):
        memory = memory.to(const.device)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(const.device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == const.EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    src_sentence = src_sentence.lower()
    model.eval()
    src = text_transform_src(src_sentence).view(-1, 1)

    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=const.BOS_IDX).flatten()
    # print(list(tgt_tokens.cpu().numpy()))
    return " ".join(const.vocab_transform[const.TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

In [16]:
def jaccard_index(predicted: list[str], actual: list[str]) -> float:
    predicted_multiset = Multiset(predicted)
    actual_multiset = Multiset(actual)

    multiset_intersection = predicted_multiset & actual_multiset
    
    length_intersection = len(multiset_intersection)
    length_predicted = len(predicted_multiset)
    length_actual = len(actual_multiset)

    return length_intersection / (length_predicted + length_actual)



In [19]:
jaccard = 0
direct_compare = 0
for word, actual_transcription in tqdm.tqdm(test_dataset):
    word = ''.join(word)
    predicted_transcription = translate(transformer, word).split()

    jaccard += jaccard_index(predicted_transcription, actual_transcription) * 2  # Jaccard index on multiset is 0 <= J <= 1/2

    if actual_transcription == predicted_transcription:
        direct_compare += 1

jaccard /= len(test_dataset)
direct_compare /= len(test_dataset)


f'Jaccard: {jaccard}, direct_compare: {direct_compare}'

100%|█████████▉| 60610/60611 [13:27<00:00, 75.08it/s]


'Jaccard: 0.995714985492419, direct_compare: 0.9928560822292983'