In [1]:
from pathlib import Path

import torch
import torch.nn as nn
from datasets import load_dataset
from datasets import DatasetDict
from tokenizers import Tokenizer

In [34]:
from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.models import BPE
from tokenizers.normalizers import Lowercase, NFKC, Sequence
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
from tokenizers import SentencePieceBPETokenizer

In [3]:
raw_dataset = load_dataset('opus100', name='ar-en')
raw_dataset

DatasetDict({
    test: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
    train: Dataset({
        features: ['translation'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
})

In [4]:
raw_dataset['train']

Dataset({
    features: ['translation'],
    num_rows: 1000000
})

In [5]:
raw_dataset['train'][1], raw_dataset['train'][0]['translation']

({'translation': {'ar': '...لقد كان', 'en': 'It was, um...'}},
 {'ar': 'و هذه؟', 'en': 'And this?'})

In [44]:
def extract_texts(dataset, lang="en"):
    return (item["translation"][lang] for item in dataset)

def build_tokenizer(dataset, lang):
    special_tokens = ['[SOS]', '[EOS]', '[PAD]', '[UNK]']
    tokenizer = SentencePieceBPETokenizer(unk_token='[UNK]')
    tokenizer.train_from_iterator(extract_texts(dataset, lang), special_tokens=special_tokens,
                         min_frequency=2,
                         show_progress=True,)
    
    return tokenizer

def get_or_build_tokenizer(config, dataset, lang):
    tokenizer_path = Path(config.tokenizer_file.format(lang))
    
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))
    
    tokenizer = build_tokenizer(dataset, lang)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [45]:
from datasets import concatenate_datasets
combined_dataset = concatenate_datasets([raw_dataset['train'], raw_dataset['validation'], raw_dataset['test']])
texts = extract_texts(combined_dataset)

In [46]:
src_tokenizer = get_or_build_tokenizer(config, combined_dataset, config.source_lang)
target_tokenizer = get_or_build_tokenizer(config, combined_dataset, config.target_lang)









### Create the Dataset and DataLoaders

In [87]:
from torch.utils.data import Dataset, DataLoader

class EnArDataset(Dataset):

    def __init__(self,
                 dataset: DatasetDict,
                 src_tokenizer: Tokenizer,
                 target_tokenizer: Tokenizer,
                 src_lang: str,
                 target_lang: str,
                 model_max_length: int) -> None:
        super().__init__()
        self.dataset = dataset
        self.src_tokenizer = src_tokenizer
        self.target_tokenizer = target_tokenizer
        self.src_lang = src_lang
        self.target_lang = target_lang
        self.model_max_length = model_max_length

        self.sos_token = torch.tensor([src_tokenizer.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: any) -> torch.Tensor:
        src_target = self.dataset[idx]
        src_text = src_target['translation'][self.src_lang]
        tgt_text = src_target['translation'][self.target_lang]

        encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
        decoder_input_tokens = self.target_tokenizer.encode(tgt_text).ids

        num_pad_tokens_src = self.model_max_length - len(encoder_input_tokens) - 2
        num_pad_tokens_tgt = self.model_max_length - len(decoder_input_tokens) - 1

        assert num_pad_tokens_src < 0 and num_pad_tokens_tgt < 0, "Sentence too long!"

        encoder_inputs = torch.cat([self.sos_token,
                                    torch.tensor(encoder_input_tokens, dtype=torch.int64),
                                    self.eos_token,
                                    torch.tensor([self.pad_token] * num_pad_tokens_src)])

        decoder_inputs = torch.cat([self.sos_token,
                                    torch.tensor(decoder_input_tokens, dtype=torch.int64),
                                    torch.tensor([self.pad_token] * num_pad_tokens_src)])
        
        label = torch.cat([torch.tensor(decoder_input_tokens, dtype=torch.int64),
                           self.eos_token,
                           torch.tensor([self.pad_token] * num_pad_tokens_src)])

        assert encoder_inputs.size(0) == self.model_max_length
        assert decoder_inputs.size(0) == self.model_max_length
        assert label.size(0) == self.model_max_length

        return {'encoder_inputs': encoder_inputs,
                'decoder_inputs': decoder_inputs,
                'label': label,
                'encoder_mask': ((encoder_inputs != self.pad_token)[None:, ...][None:, ...]).int(),
                'decoder_mask': (decoder_inputs != self.pad_token[None:, ...][None:, ...]).int() & causal_mask(sel.model_max_length),
                'src_text': src_text,
                'tgt_text': tgt_text
               }

def causal_mask(size):
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0

In [68]:
from datasets import concatenate_datasets
from typing import Tuple

def get_datasets(config, src_tokenizer, target_tokenizer) -> Tuple[EnArDataset, EnArDataset, EnArDataset]:

    train_ds = EnArDataset(
        raw_dataset['train'],
        src_tokenizer,
        target_tokenizer,
        src_lang=config.source_lang, 
        target_lang=config.target_lang,
        model_max_length=config.model_max_length)

    validation_ds = EnArDataset(
        raw_dataset['validation'],
        src_tokenizer,
        target_tokenizer,
        src_lang=config.source_lang, 
        target_lang=config.target_lang,
        model_max_length=config.model_max_length)

    test_ds = EnArDataset(
        raw_dataset['test'],
        src_tokenizer,
        target_tokenizer,
        src_lang=config.source_lang, 
        target_lang=config.target_lang,
        model_max_length=config.model_max_length)
    return (train_ds, validation_ds, test_ds)

In [83]:
def get_loaders_and_tokenizers(config) -> Tuple[DataLoader, DataLoader, DataLoader]:
    raw_dataset = load_dataset('opus100', name='ar-en')
    combined_dataset = concatenate_datasets([raw_dataset['train'], raw_dataset['validation'], raw_dataset['test']])
    texts = extract_texts(combined_dataset)
    src_tokenizer = get_or_build_tokenizer(config, combined_dataset, config.source_lang)
    target_tokenizer = get_or_build_tokenizer(config, combined_dataset, config.target_lang)
    
    train_ds, validation_ds, test_ds = get_datasets(config, src_tokenizer, target_tokenizer)
    train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
    valid_dl = DataLoader(validation_ds, batch_size=1, shuffle=False)
    test_dl = DataLoader(test_ds, batch_size=8, shuffle=False)
    return {'train_dl':train_dl,
            'valid_dl': valid_dl,
            'test_dl': test_dl,
            'src_tokenizer': src_tokenizer,
            'target_tokenizer': target_tokenizer
           }

### Define The tokenizer config dataclass

In [84]:
from dataclasses import dataclass
from typing import Dict

@dataclass
class TokenizerConfig:
    tokenizer_file: str = 'tokenizer_config_{0}.json'
    eos_token: str = '[SOS]'
    model_max_length: int = 512
    pad_token: str = '[PAD]'
    return_tensors: str = 'pt'
    separate_vocabs: bool = False
    source_lang: str = 'en'
    target_lang: str = 'ar'
    unk_token: str = '[UNK]'

In [85]:
config = TokenizerConfig()
config

TokenizerConfig(tokenizer_file='tokenizer_config_{0}.json', eos_token='[SOS]', model_max_length=512, pad_token='[PAD]', return_tensors='pt', separate_vocabs=False, source_lang='en', target_lang='ar', unk_token='[UNK]')

In [86]:
get_loaders_and_tokenizers(config, raw_dataset)

{'train_dl': <torch.utils.data.dataloader.DataLoader at 0x105c60cd0>,
 'valid_dl': <torch.utils.data.dataloader.DataLoader at 0x167fc8e50>,
 'test_dl': <torch.utils.data.dataloader.DataLoader at 0x168e65790>,
 'src_tokenizer': <tokenizers.Tokenizer at 0x1602c2600>,
 'target_tokenizer': <tokenizers.Tokenizer at 0x15c9e7400>}