In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import random
import spacy
import re
import os
import json

# Location of the OpenSubtitles en-ko parallel corpus (relative to the working directory).
data_dir = "../data/opensubtitles_en_ko"

# spaCy pipelines for both languages; in this cell only their `.tokenizer` is called.
# English is loaded first, then Korean.
spacy_en, spacy_ko = (spacy.load(model_name) for model_name in ("en_core_web_sm", "ko_core_news_sm"))

# Tokenizer functions
def tokenize_en(text):
    """Split English *text* with spaCy's tokenizer and return the lowercased token strings."""
    doc = spacy_en.tokenizer(text)
    return [token.text.lower() for token in doc]

def tokenize_ko(text):
    """Split Korean *text* with spaCy's tokenizer and return the token strings unchanged."""
    doc = spacy_ko.tokenizer(text)
    return [token.text for token in doc]

# Load the parallel corpus; line i of the English file is paired with line i of the Korean file.
src_file = os.path.join(data_dir, "OpenSubtitles.en-ko.en")
trg_file = os.path.join(data_dir, "OpenSubtitles.en-ko.ko")

src_texts, trg_texts = [], []
with open(src_file, "r", encoding="utf-8") as f_en, open(trg_file, "r", encoding="utf-8") as f_ko:
    # zip() stops at the shorter file, so unpaired trailing lines are silently dropped.
    for en_line, ko_line in zip(f_en, f_ko):
        src_texts.append(en_line.strip())
        trg_texts.append(ko_line.strip())


def save_data(data, path):
    """Write *data* to *path* as indented UTF-8 JSON, keeping non-ASCII characters verbatim.

    The ``with`` block guarantees the file is flushed and closed before
    anything reads it back.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def _load_json(path):
    """Return the JSON document stored in the UTF-8 file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _build_vocab(texts, tokenizer):
    """Return a token -> id dict covering every token *tokenizer* yields over *texts*.

    ``<PAD>``, ``<UNK>``, ``<SOS>`` and ``<EOS>`` always receive ids 0-3.
    The remaining tokens are sorted so the ids are reproducible across runs.
    Iteration order of a ``set`` of strings varies with hash randomization.
    Corpus tokens that equal a special token are skipped, so they cannot
    overwrite that special token's id.
    """
    specials = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"]
    corpus_tokens = {tok for text in texts for tok in tokenizer(text)}
    ordered = specials + sorted(corpus_tokens.difference(specials))
    return {word: i for i, word in enumerate(ordered)}


# Save dataset
save_data(src_texts, os.path.join(data_dir, "src_texts.json"))
save_data(trg_texts, os.path.join(data_dir, "trg_texts.json"))

# Build vocabularies (deterministic id assignment)
src_vocab = _build_vocab(src_texts, tokenize_en)
trg_vocab = _build_vocab(trg_texts, tokenize_ko)

# Save vocab
save_data(src_vocab, os.path.join(data_dir, "src_vocab.json"))
save_data(trg_vocab, os.path.join(data_dir, "trg_vocab.json"))

# Reload the texts from the JSON files just written (a list of str round-trips unchanged)
src_texts = _load_json(os.path.join(data_dir, "src_texts.json"))
trg_texts = _load_json(os.path.join(data_dir, "trg_texts.json"))

# Custom Dataset Class
class TranslationDataset(Dataset):
    """Map-style dataset yielding (source, target) token-id tensors for seq2seq training.

    Each item is a pair of 1-D ``torch.long`` tensors of length ``max_len``.
    Each text is tokenized and mapped through its vocab, with unknown tokens
    becoming ``<UNK>``. The result is truncated to ``max_len`` and
    right-padded with ``<PAD>``.

    Args:
        src_texts: Source sentences, indexed in parallel with ``trg_texts``.
        trg_texts: Target sentences.
        src_vocab: token -> id mapping for the source side; must contain
            ``<PAD>`` and ``<UNK>``.
        trg_vocab: token -> id mapping for the target side; same requirements.
        max_len: Fixed length of every returned sequence.
        src_tokenizer: Callable ``str -> list[str]`` for source text;
            defaults to the module-level ``tokenize_en``.
        trg_tokenizer: Callable ``str -> list[str]`` for target text;
            defaults to the module-level ``tokenize_ko``.

    NOTE(review): ``<SOS>``/``<EOS>`` exist in the vocabularies but are not
    inserted here. Confirm the model or training loop adds them where needed.
    """

    def __init__(self, src_texts, trg_texts, src_vocab, trg_vocab, max_len=50,
                 src_tokenizer=None, trg_tokenizer=None):
        self.src_texts = src_texts
        self.trg_texts = trg_texts
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.max_len = max_len
        # Defaults are resolved here rather than in the signature, so the
        # module-level tokenizers are only needed when none is passed.
        self.src_tokenizer = tokenize_en if src_tokenizer is None else src_tokenizer
        self.trg_tokenizer = tokenize_ko if trg_tokenizer is None else trg_tokenizer

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        # Tokenizers are passed explicitly instead of being inferred from the vocab.
        src_seq = self.text_to_tensor(self.src_texts[idx], self.src_vocab, self.src_tokenizer)
        trg_seq = self.text_to_tensor(self.trg_texts[idx], self.trg_vocab, self.trg_tokenizer)
        return src_seq, trg_seq

    def text_to_tensor(self, text, vocab, tokenizer=None):
        """Convert *text* into a fixed-length ``torch.long`` tensor of ids from *vocab*.

        If *tokenizer* is omitted, the source tokenizer is used when *vocab*
        is (or equals) this dataset's source vocab, and the target tokenizer
        otherwise. This preserves the previous call signature.
        """
        if tokenizer is None:
            # The identity check short-circuits the O(|V|) dict comparison on the common path.
            is_src = vocab is self.src_vocab or vocab == self.src_vocab
            tokenizer = self.src_tokenizer if is_src else self.trg_tokenizer
        unk_id = vocab["<UNK>"]
        token_ids = [vocab.get(token, unk_id) for token in tokenizer(text)][: self.max_len]
        # Right-pad up to max_len; a no-op when the sequence was truncated.
        token_ids += [vocab["<PAD>"]] * (self.max_len - len(token_ids))
        return torch.tensor(token_ids, dtype=torch.long)

# Wrap the corpus in a dataset (default max_len) and a shuffling loader with batch_size=32.
dataset = TranslationDataset(
    src_texts=src_texts,
    trg_texts=trg_texts,
    src_vocab=src_vocab,
    trg_vocab=trg_vocab,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)