In [1]:
from transformers import BertTokenizerFast, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd 
import os 
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: tokenizer parallelism and compute-device selection.

# The Hugging Face `tokenizers` library reads this setting from the process
# environment; a plain Python variable alone has no effect on it. The variable
# is kept so any later cell that references it keeps working.
TOKENIZERS_PARALLELISM = True
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA as device")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS as device")
else:
    # Report why MPS could not be used before falling back to the CPU.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device("cpu")
    print("Using CPU as device")

# Tensors created by factory functions without an explicit `device=` argument
# are placed on the selected device from here on.
torch.set_default_device(device)

Using MPS as device


In [3]:
# Preprocessed parquet splits are expected in ./preprocessed, resolved against
# the directory the notebook kernel was started from.
current_path = os.getcwd()
preprocessed_directory = os.path.join(current_path, "preprocessed")

In [4]:

# Load one fast BERT tokenizer per language from its pretrained checkpoint.
kr_tokenizer = BertTokenizerFast.from_pretrained(
    "kykim/bert-kor-base",
)
en_tokenizer = BertTokenizerFast.from_pretrained(
    "google-bert/bert-base-uncased",
)

In [5]:
# Smoke test: run each tokenizer once on a sample sentence.
tmp_kr_sentence = "오늘 하교길에 길고양이를 보았는데, 너무 귀여워서 집에 데려가고 싶었다. 하지만 그러지는 않았다."
tmp_en_sentence = "The cat I saw during heading home today was so cute, that I wanted to bring it to home."

# Both encodings include special tokens, are padded up to max_length=256 and
# are truncated if the sentence is longer than that.
tmp_kr_tokenized = kr_tokenizer(
    tmp_kr_sentence,
    add_special_tokens=True,
    padding="max_length",
    max_length=256,
    truncation=True,
)
tmp_en_tokenized = en_tokenizer(
    tmp_en_sentence,
    add_special_tokens=True,
    padding="max_length",
    max_length=256,
    truncation=True,
)

# Optional inspection helpers -- uncomment while debugging.
# Show the token strings behind the ids:
# print(kr_tokenizer.convert_ids_to_tokens(tmp_kr_tokenized.input_ids))
# print(en_tokenizer.convert_ids_to_tokens(tmp_en_tokenized.input_ids))

# Round-trip the Korean ids back to text:
# print(kr_tokenizer.decode(tmp_kr_tokenized.input_ids, skip_special_tokens=True))

# Check that both tokenizers define a pad token:
# print(kr_tokenizer.pad_token)
# print(en_tokenizer.pad_token)

In [6]:
# Read the three preprocessed splits from the preprocessed directory.
df_train = pd.read_parquet(
    os.path.join(preprocessed_directory, "train.parquet")
)
df_test = pd.read_parquet(
    os.path.join(preprocessed_directory, "test.parquet")
)
df_validation = pd.read_parquet(
    os.path.join(preprocessed_directory, "validation.parquet")
)



class _En2KrDataset(Dataset):
    """Shared implementation behind the train/test/validation datasets.

    Each item is a pair ``(kr_ids, en_ids)`` of int32 tensors holding the
    token ids of the row's "korean" and "english" sentences. Both sequences
    are tokenized with special tokens, padded to ``max_len`` and truncated
    if longer.

    Args:
        data: DataFrame with "english" and "korean" columns.
        max_len: Target length passed to both tokenizers as ``max_length``.
        kr_tok: Tokenizer for the "korean" column; defaults to the
            notebook-level ``kr_tokenizer``.
        en_tok: Tokenizer for the "english" column; defaults to the
            notebook-level ``en_tokenizer``.
    """

    def __init__(self, data, max_len, kr_tok=None, en_tok=None):
        self.data = data
        self.max_len = max_len
        # The notebook globals are only looked up when nothing was injected,
        # so the class can be used (and tested) without them.
        self.kr_tokenizer = kr_tokenizer if kr_tok is None else kr_tok
        self.en_tokenizer = en_tokenizer if en_tok is None else en_tok

    def __len__(self):
        """Return the number of sentence pairs (rows of ``self.data``)."""
        return len(self.data)

    def _encode(self, tokenizer, sentence):
        """Tokenize one sentence and return its ids as an int32 tensor."""
        token_ids = tokenizer(
            sentence,
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_len,
            truncation=True,
        ).input_ids
        # Same legacy constructor as before, so dtype (int32) is unchanged.
        return torch.IntTensor(token_ids)

    def __getitem__(self, idx):
        """Return ``(kr_ids, en_ids)`` for the row at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the frame's positional range.
        """
        # Read the two scalars straight from their columns instead of
        # materialising a one-row DataFrame for every sample.
        en_sentence = self.data["english"].iloc[idx]
        kr_sentence = self.data["korean"].iloc[idx]

        kr_tokenized_ids = self._encode(self.kr_tokenizer, kr_sentence)
        en_tokenized_ids = self._encode(self.en_tokenizer, en_sentence)
        return kr_tokenized_ids, en_tokenized_ids


class en2kr_Train_Dataset(_En2KrDataset):
    """Training split; reads the notebook-level ``df_train`` unless ``data`` is given."""

    def __init__(self, max_len, data=None, kr_tokenizer=None, en_tokenizer=None):
        super().__init__(
            df_train if data is None else data,
            max_len,
            kr_tok=kr_tokenizer,
            en_tok=en_tokenizer,
        )


class en2kr_Test_Dataset(_En2KrDataset):
    """Test split; reads the notebook-level ``df_test`` unless ``data`` is given."""

    def __init__(self, max_len, data=None, kr_tokenizer=None, en_tokenizer=None):
        super().__init__(
            df_test if data is None else data,
            max_len,
            kr_tok=kr_tokenizer,
            en_tok=en_tokenizer,
        )


class en2kr_Validation_Dataset(_En2KrDataset):
    """Validation split; reads the notebook-level ``df_validation`` unless ``data`` is given."""

    def __init__(self, max_len, data=None, kr_tokenizer=None, en_tokenizer=None):
        super().__init__(
            df_validation if data is None else data,
            max_len,
            kr_tok=kr_tokenizer,
            en_tok=en_tokenizer,
        )

In [7]:
# Mini-batch size shared by both loaders.
batch_size = 512

# Both datasets pad/truncate every sentence to 128 token ids.
train_dataset = en2kr_Train_Dataset(max_len=128)
test_dataset = en2kr_Test_Dataset(max_len=128)

# Each loader gets its own random generator created on `device`
# (presumably needed because the default device was changed earlier -- TODO confirm).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    generator=torch.Generator(device=device),
)