# In [5]:
import polars as pl
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from collections import Counter
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence


# Paths to the training split of the corpus: ".ja" holds the Japanese side and
# ".en" the English side.
# NOTE(review): the two files are presumably line-aligned (line i of one is the
# translation of line i of the other) -- confirm against the dataset docs.
JP_TRAIN_FILE_PATH = "./kftt-data-1.0/data/orig/kyoto-train.ja"
EN_TRAIN_FILE_PATH = "./kftt-data-1.0/data/orig/kyoto-train.en"

# spaCy-backed tokenizers obtained through torchtext: the Japanese model is used
# for the source side and the English model for the target side.
tokenizer_src = get_tokenizer('spacy', language='ja_core_news_sm')
tokenizer_tgt = get_tokenizer('spacy', language='en_core_web_sm')

class datasets(Dataset):
    """Map-style dataset that pairs Japanese entries with English entries by index."""

    def __init__(self, text, label):
        """Keep references to both sides of the corpus.

        Args:
            text: Indexable collection holding the Japanese side.
            label: Indexable collection holding the English side.
        """
        self.jp_datas, self.en_datas = text, label

    def __len__(self):
        """Return the number of entries, measured on the Japanese side only."""
        japanese_side = self.jp_datas
        return len(japanese_side)

    def __getitem__(self, index):
        """Return the ``(japanese, english)`` pair stored at ``index``."""
        return self.jp_datas[index], self.en_datas[index]

class DataLoaderCreater:
    """Build a ``DataLoader`` of index sequences from a parallel text corpus.

    The source tokenizer is applied to the Japanese file and the target
    tokenizer to the English file, both when counting tokens for the
    vocabularies and when converting sentences to index sequences.

    Attributes:
        src_tokenizer: Callable turning a source sentence into tokens.
        tgt_tokenizer: Callable turning a target sentence into tokens.
        vocab_src: Source vocabulary built by the most recent
            ``create_dataloader`` call, or ``None`` before the first call.
        vocab_tgt: Target vocabulary built by the most recent
            ``create_dataloader`` call, or ``None`` before the first call.
    """

    def __init__(self, src_tokenizer, tgt_tokenizer):
        """Store the tokenizers used for the source and target sides.

        Args:
            src_tokenizer: Callable mapping a source sentence (``str``) to an
                iterable of tokens.
            tgt_tokenizer: Callable mapping a target sentence (``str``) to an
                iterable of tokens.
        """
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        # Kept so callers can read vocabulary sizes or decode indices after
        # the dataloader has been created.
        self.vocab_src = None
        self.vocab_tgt = None

    def build_vocab(self, texts, tokenizer):
        """Build a torchtext vocabulary from the tokens found in ``texts``.

        Args:
            texts: Iterable of sentences (``str``).
            tokenizer: Callable mapping a sentence to an iterable of tokens.

        Returns:
            A vocabulary containing the special tokens ``<unk>``, ``<pad>``,
            ``<start>`` and ``<end>`` plus every token seen at least once.
            Unknown tokens resolve to the index of ``<unk>``.
        """
        counter = Counter()
        for text in texts:
            counter.update(tokenizer(text))
        specials = ['<unk>', '<pad>', '<start>', '<end>']
        v = vocab(counter, specials=specials, min_freq=1)
        v.set_default_index(v['<unk>'])
        return v

    def convert_text_to_indexes(self, text, vocab, tokenizer):
        """Convert one sentence into a list of vocabulary indices.

        Args:
            text: The sentence; leading/trailing newline characters are
                stripped before tokenizing.
            vocab: Token-to-index mapping supporting ``in`` and ``[]`` that
                contains ``<unk>``, ``<start>`` and ``<end>``.
            tokenizer: Callable mapping a sentence to an iterable of tokens.

        Returns:
            ``[<start>, token indices..., <end>]`` where tokens missing from
            ``vocab`` are replaced by the index of ``<unk>``.
        """
        # Look the fallback index up once instead of once per unknown token.
        unk_index = vocab['<unk>']
        token_indexes = [
            vocab[token] if token in vocab else unk_index
            for token in tokenizer(text.strip("\n"))
        ]
        return [vocab['<start>'], *token_indexes, vocab['<end>']]

    @staticmethod
    def _read_lines(file_path):
        """Read a UTF-8 text file and return its lines without newline characters."""
        with open(file_path, "r", encoding="utf-8") as f:
            return [line.strip("\n") for line in f]

    def create_dataloader(self, jp_file_path, en_file_path, collate_fn,
                          batch_size=64, shuffle=False):
        """Read both corpus files and wrap the indexed sentences in a DataLoader.

        Args:
            jp_file_path: Path to the UTF-8 source (Japanese) file, one
                sentence per line.
            en_file_path: Path to the UTF-8 target (English) file, one
                sentence per line.
            collate_fn: Callable passed to ``DataLoader`` to merge a list of
                ``(source indices, target indices)`` pairs into a batch.
            batch_size: Number of sentence pairs per batch.
            shuffle: Whether the DataLoader reshuffles the data every epoch.

        Returns:
            A ``DataLoader`` over the ``datasets`` instance built from the
            two files.

        Raises:
            ValueError: If the two files do not have the same number of lines.
        """
        jp_list = self._read_lines(jp_file_path)
        en_list = self._read_lines(en_file_path)
        if len(jp_list) != len(en_list):
            # Sentences are paired by line number, so a mismatch would either
            # fail later with an IndexError or silently drop trailing lines.
            raise ValueError(
                f"line count mismatch: {jp_file_path} has {len(jp_list)} lines, "
                f"{en_file_path} has {len(en_list)} lines"
            )

        # Use the tokenizers given to this instance (not module-level globals)
        # so vocabulary building and index conversion always agree.
        self.vocab_src = self.build_vocab(jp_list, self.src_tokenizer)
        self.vocab_tgt = self.build_vocab(en_list, self.tgt_tokenizer)
        vocab_src_stoi = self.vocab_src.get_stoi()
        vocab_tgt_stoi = self.vocab_tgt.get_stoi()

        src_data = [
            self.convert_text_to_indexes(jp_data, vocab_src_stoi, self.src_tokenizer)
            for jp_data in jp_list
        ]
        tgt_data = [
            self.convert_text_to_indexes(en_data, vocab_tgt_stoi, self.tgt_tokenizer)
            for en_data in en_list
        ]
        dataset = datasets(src_data, tgt_data)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

# Value used to pad shorter sequences in a batch. It equals the position of
# '<pad>' in the specials list used by DataLoaderCreater.build_vocab.
# NOTE(review): this relies on torchtext placing specials first -- confirm.
PADDING_ID = 1


def collate_fn(batch):
    """Merge a list of ``(source, target)`` index sequences into padded tensors.

    Args:
        batch: List of ``(src, tgt)`` pairs, each side being a sequence of
            integer indices (a list or an existing tensor).

    Returns:
        A tuple ``(src_batch, tgt_batch)`` of ``torch.long`` tensors. Each is
        the output of ``pad_sequence`` with its default layout, i.e. shaped
        ``(longest sequence length, batch size)``, padded with ``PADDING_ID``.
    """
    # pad_sequence only accepts tensors, so convert each sequence first.
    src_batch = [torch.as_tensor(src, dtype=torch.long) for src, _ in batch]
    tgt_batch = [torch.as_tensor(tgt, dtype=torch.long) for _, tgt in batch]

    src_batch = pad_sequence(src_batch, padding_value=PADDING_ID)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PADDING_ID)

    # Without this return the DataLoader would yield None for every batch.
    return src_batch, tgt_batch

# Build the training DataLoader from the Japanese/English training files,
# using the module-level tokenizers and collate function.
train_dataloader = DataLoaderCreater(tokenizer_src, tokenizer_tgt).create_dataloader(
    jp_file_path=JP_TRAIN_FILE_PATH,
    en_file_path=EN_TRAIN_FILE_PATH,
    collate_fn=collate_fn,
)