In [None]:
import polars as pl
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from collections import Counter
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

# Path to the lexicon CSV consumed by DataLoaderCreater.create_dataloader().
FILE_PATH = "./wiki_corpus_2.01/kyoto_lexicon.csv"
# spaCy-backed tokenizers: Japanese model for the source side, English model for the target side.
tokenizer_src = get_tokenizer('spacy', language='ja_core_news_sm')
tokenizer_tgt = get_tokenizer('spacy', language='en_core_web_sm')
# One of the rows could not be read correctly, so the CSV is loaded with
# truncate_ragged_lines=True to tolerate rows whose field count does not match
# the header (original data: 51983 rows).

class datasets(Dataset):
    """Paired dataset of Japanese (source) and English (target) items.

    The two containers are indexed in parallel: item ``i`` of the dataset is
    the pair ``(text[i], label[i])``.
    """

    def __init__(self, text, label):
        # Keep references to the two parallel containers as given (no copy).
        self.jp_datas, self.en_datas = text, label

    def __len__(self):
        """Return the number of pairs, measured on the Japanese side."""
        return len(self.jp_datas)

    def __getitem__(self, index):
        """Return the ``(jp, en)`` pair stored at ``index``."""
        return self.jp_datas[index], self.en_datas[index]

class DataLoaderCreater:
    """Build train/validation/test DataLoaders from a two-column parallel CSV.

    The first CSV column is treated as the source text and the second as the
    target text.  After ``create_dataloader`` has run, the instance exposes:

    * ``vocab_src`` / ``vocab_tgt``: token -> index dicts,
    * ``len_src_vocab`` / ``len_tgt_vocab``: the sizes of those dicts.
    """

    def __init__(self, file_path, src_tokenizer, tgt_tokenizer):
        """Store the CSV path and the tokenizers for each side.

        Args:
            file_path: Path of the CSV file to read.
            src_tokenizer: Callable turning a source string into tokens.
            tgt_tokenizer: Callable turning a target string into tokens.
        """
        self.file_path = file_path
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

    def build_vocab(self, texts, tokenizer):
        """Count the tokens of ``texts`` and build a torchtext vocabulary.

        Args:
            texts: Iterable of raw strings.
            tokenizer: Callable mapping a string to an iterable of tokens.

        Returns:
            The vocabulary object produced by torchtext's ``vocab`` factory,
            with ``<unk>`` registered as the default index.
        """
        counter = Counter()
        for text in texts:
            counter.update(tokenizer(text))
        specials = ['<unk>', '<pad>', '<start>', '<end>']
        v = vocab(counter, specials=specials, min_freq=1)
        v.set_default_index(v['<unk>'])
        return v

    def convert_text_to_indexes(self, text, vocab, tokenizer):
        """Encode ``text`` as ``[<start>, token indexes..., <end>]``.

        Args:
            text: Raw string; leading/trailing newlines are stripped first.
            vocab: Token -> index mapping supporting ``in`` and ``[]``; it
                must contain ``<start>``, ``<end>`` and ``<unk>``.
            tokenizer: Callable mapping a string to an iterable of tokens.

        Returns:
            A list of ints; tokens missing from ``vocab`` map to ``<unk>``.
        """
        # Look the fallback index up once instead of once per unknown token.
        unk_index = vocab['<unk>']
        body = [
            vocab[token] if token in vocab else unk_index
            for token in tokenizer(text.strip("\n"))
        ]
        return [vocab['<start>']] + body + [vocab['<end>']]

    def _encode_and_pad(self, texts, stoi, tokenizer):
        """Encode every text and pad the results into one batch-first tensor."""
        encoded = [
            torch.tensor(self.convert_text_to_indexes(text, stoi, tokenizer=tokenizer))
            for text in texts
        ]
        return pad_sequence(encoded, batch_first=True, padding_value=stoi["<pad>"])

    def create_dataloader(self, *, batch_size=128, train_ratio=0.8, val_ratio=0.1):
        """Read the CSV, build both vocabularies and return the DataLoaders.

        Args:
            batch_size: Batch size used for all three loaders.
            train_ratio: Fraction of the data used for training.
            val_ratio: Fraction of the data used for validation; whatever
                remains after train and validation becomes the test split.

        Returns:
            ``(train_dataloader, test_dataloader, valid_dataloader)``.
            NOTE: test comes *before* validation in the returned tuple.

        Raises:
            ValueError: If the ratios are negative or sum to more than 1.
        """
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
            raise ValueError(
                "train_ratio and val_ratio must be non-negative and sum to at most 1"
            )

        df = pl.read_csv(
            self.file_path,
            separator=",",
            encoding="utf-8",
            has_header=True,
            truncate_ragged_lines=True,
        )
        # Column 0 holds the source text, column 1 the target text.
        jp_list = df[:, 0].to_list()
        en_list = df[:, 1].to_list()

        # Use the tokenizers injected through the constructor (not module
        # globals) so that vocabulary building and encoding always agree.
        self.vocab_src = self.build_vocab(jp_list, self.src_tokenizer).get_stoi()
        self.vocab_tgt = self.build_vocab(en_list, self.tgt_tokenizer).get_stoi()
        self.len_src_vocab = len(self.vocab_src)
        self.len_tgt_vocab = len(self.vocab_tgt)

        src_data = self._encode_and_pad(jp_list, self.vocab_src, self.src_tokenizer)
        tgt_data = self._encode_and_pad(en_list, self.vocab_tgt, self.tgt_tokenizer)

        dataset = datasets(src_data, tgt_data)

        # Split sizes; the test split absorbs the rounding remainder so the
        # three sizes always add up to the dataset length.
        dataset_length = len(dataset)
        train_size = int(train_ratio * dataset_length)
        val_size = int(val_ratio * dataset_length)
        test_size = dataset_length - train_size - val_size

        # Randomly partition the dataset.
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size]
        )
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
        valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

        return train_dataloader, test_dataloader, valid_dataloader

# Build the DataLoaders from the lexicon CSV with the tokenizers defined above.
dataloader_creater = DataLoaderCreater(FILE_PATH, tokenizer_src, tokenizer_tgt)
# NOTE: create_dataloader() returns the loaders in (train, test, valid) order.
train_dataloader, test_dataloader, valid_dataloader = dataloader_creater.create_dataloader()
# Vocabulary sizes and token -> index mappings, set on the instance by create_dataloader().
src_vocab_size = dataloader_creater.len_src_vocab
tgt_vocab_size = dataloader_creater.len_tgt_vocab
vocab_src = dataloader_creater.vocab_src
vocab_tgt = dataloader_creater.vocab_tgt

In [None]:
# def evaluate(model, dataloader, criterion):
#     model.eval()
#     epoch_loss = 0
#     with torch.no_grad():
#         for src, tgt in dataloader:
#             tgt_input = tgt[:, :-1]
#             tgt_output = tgt[:, 1:]
#             output = model(src, tgt_input)
#             output = output.view(-1, output.shape[-1])
#             tgt_output = tgt_output.contiguous().view(-1)
#             loss = criterion(output, tgt_output)
#             epoch_loss += loss.item()
#     return epoch_loss / len(dataloader)