In [1]:
from config_file import get_config
config = get_config()

In [2]:
5

4369334640

In [2]:
import datasets

In [3]:
def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]

In [4]:
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

def get_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config["tokenizer_file"].format(lang))
    if tokenizer_path.exists():
        # load
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    else:
        # build
        tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(min_frequency=2, show_progress=True, 
                                   special_tokens=["[UNK]", "[SOS]", "[EOS]", "[PAD]"])
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    return tokenizer

In [37]:
# Hmm why use token_to_id? Why not use encode? And if so, why not use token_to_id when max_len? ids?
# what shape does maked_fill_ in attention expect? (batch, h, seq, seq) is attention matrix. based on that we are making masks to be (1, 1, seq_len)...
# shape of causal mask also
# cant i use torch.cat with lists? - Nope. Also tensor scalars not allowed - hence, [tokencoded SOS]
# whats the result type of encode, token_to_id etc etc - not tensors? can we make it tensors? - Nope
# difference between dtype=torch.int64 and .int()

from torch.utils.data import Dataset
import torch

class BilingualDataset(Dataset):
    def __init__(self, ds, config, tokenizer_src, tokenizer_tgt):
        super().__init__()
        self.ds = ds
        self.seq_len = config["seq_len"]
        self.lang_src = config["lang_src"]
        self.lang_tgt = config["lang_tgt"]
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.sos_token = torch.tensor([self.tokenizer_src.token_to_id('[SOS]')]).int()
        self.eos_token = torch.tensor([self.tokenizer_src.token_to_id('[EOS]')]).int()
        self.pad_token = torch.tensor([self.tokenizer_src.token_to_id('[PAD]')]).int()

    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, index):
        src_tgt_pair = self.ds[index]['translation']
        src_text = src_tgt_pair.get(self.lang_src)
        tgt_text = src_tgt_pair.get(self.lang_tgt)

        src_ids = self.tokenizer_src.encode(src_text).ids
        tgt_ids = self.tokenizer_tgt.encode(tgt_text).ids

        enc_pad_count = self.seq_len - len(src_ids) - 2
        dec_pad_count = self.seq_len - len(tgt_ids) - 1
        if enc_pad_count < 0 or dec_pad_count < 0:
            raise ValueError("Sentence too long")

        encoder_input = torch.cat([
            self.sos_token,
            torch.tensor(src_ids).int(), 
            torch.tensor([self.pad_token] * enc_pad_count).int()
        ])
        decoder_input = torch.cat([
            self.sos_token, 
            torch.tensor(tgt_ids).int(),
            torch.tensor([self.pad_token] * dec_pad_count).int()
        ])

        label = torch.cat([
            torch.tensor(tgt_ids).int(),
            self.eos_token,
            torch.tensor([self.pad_token] * dec_pad_count).int()
        ])
        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), 
            # encoder mask: (1, 1, seq_len) -> Has 1 when there is text and 0 when there is pad (no text)
            
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            # (1, seq_len) and (1, seq_len, seq_len)
            # Will get 0 for all pads. And 0 for earlier text.
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
            }
    
def causal_mask(size):
    return torch.tril(torch.ones(1, size, size)).bool()

In [25]:
# Why not use batch encoding to speed things up? > NEED TO SORT THIS OUT

from torch.utils.data import DataLoader, random_split

def get_ds(config):
    ds_raw = datasets.load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')
    tokenizer_src = get_build_tokenizer(config, ds_raw, config["lang_src"])
    tokenizer_tgt = get_build_tokenizer(config, ds_raw, config["lang_tgt"])        
    # max_length_src = 0
    # max_length_tgt = 0
    # for item in ds_raw['translation']:
    #     src_ids = tokenizer_src.encode(item[config["lang_src"]]).ids
    #     tgt_ids = tokenizer_tgt.encode(item[config["lang_tgt"]]).ids
    #     max_length_src = max(max_length_src, len(src_ids))
    #     max_length_tgt = max(max_length_tgt, len(tgt_ids))
    # print(f'[INFO] Longest sentence in src languagee  - {config["lang_src"]} - contains {max_length_src} words')
    # print(f'[INFO] Longest sequence in tgt language - {config["lang_tgt"]} - contains {max_length_tgt} words')
    train_ds_raw, val_ds_raw = random_split(ds_raw, lengths=[0.9, 0.1])
    train_ds = BilingualDataset(train_ds_raw, config, tokenizer_src, tokenizer_tgt)
    val_ds = BilingualDataset(val_ds_raw, config, tokenizer_src, tokenizer_tgt)
    train_dl = DataLoader(train_ds, shuffle=True, batch_size=config["batch_size"])
    val_dl = DataLoader(val_ds, shuffle=True, batch_size=1)
    return train_dl, val_dl, tokenizer_src, tokenizer_tgt


In [40]:
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)

In [42]:
values = next(iter(train_dataloader))

In [43]:
values.keys()

dict_keys(['encoder_input', 'decoder_input', 'encoder_mask', 'decoder_mask', 'label', 'src_text', 'tgt_text'])

In [48]:
values['encoder_input'].shape

torch.Size([16, 349])