Data: OPUS-100 English–Polish parallel corpus (https://opus.nlpl.eu/opus-100.php)

In [1]:
import os
import string
import itertools
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset

In [2]:
@dataclass
class paths:
    """File locations of the OPUS-100 English-Polish split files."""

    # Root directory that holds every split of the corpus.
    data = os.path.join('data', 'en-pl')
    # Shared filename stem; each entry below appends '<split>.<lang>'.
    _stem = os.path.join(data, 'opus.en-pl-')

    pl_test = _stem + 'test.pl'
    en_test = _stem + 'test.en'
    pl_dev = _stem + 'dev.pl'
    en_dev = _stem + 'dev.en'
    pl_train = _stem + 'train.pl'
    en_train = _stem + 'train.en'

    # The stem is build-time scaffolding only; drop it so the class exposes
    # exactly the same attributes as before.
    del _stem

In [58]:
class Vocabulary:
    """Bidirectional word <-> integer-id mapping built from a flat token list.

    Ids 0-3 are reserved for the special tokens ``<unk>``, ``<pad>``,
    ``<sos>`` and ``<eos>``. Every other word receives the next consecutive
    id, in sorted order of the distinct input words. Only tokens for which
    ``str.isalpha()`` is true are added.
    """

    def __init__(self, data):
        """Build the vocabulary from ``data``, an iterable of word strings."""
        self.vocab = {
            '<unk>': 0,
            '<pad>': 1,
            '<sos>': 2,
            '<eos>': 3
        }
        # Reverse table (position == id) so id -> word lookups are O(1)
        # instead of a linear scan over the dict. It is kept in sync by
        # append_word; add words through that method, not by editing
        # self.vocab directly.
        self._itos = sorted(self.vocab, key=self.vocab.get)

        self.build_vocab(data)

    def __getitem__(self, index):
        """Look up a word's id (``str`` index) or an id's word (``int`` index).

        Unknown words map to the ``<unk>`` id (0). Ids outside
        ``range(len(self))``, including negative ones, map to ``'<unk>'``.

        Raises:
            TypeError: if ``index`` is not a ``str`` or an ``int``
                (``bool`` is rejected as well).
        """
        if isinstance(index, str):
            return self.vocab.get(index, self.vocab['<unk>'])

        # bool is an int subclass; reject it explicitly, as the original
        # exact-type check did.
        if isinstance(index, int) and not isinstance(index, bool):
            if 0 <= index < len(self._itos):
                return self._itos[index]
            return self._itos[self.vocab['<unk>']]

        # Raise instead of assert so the check survives `python -O`.
        raise TypeError('Index type must be string or int')

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.vocab)

    def append_word(self, word):
        """Add ``word`` with the next free id if it is new and alphabetic."""
        if word not in self.vocab and word.isalpha():
            self.vocab[word] = len(self)
            self._itos.append(word)

    def build_vocab(self, data):
        """Add every distinct word of ``data`` in sorted order."""
        for word in sorted(set(data)):
            self.append_word(word)

In [72]:
class PolEngDS(Dataset):
    """Parallel Polish-English sentence dataset read from two text files.

    Line ``i`` of the Polish file is paired with line ``i`` of the English
    file. Each item is a ``(polish_ids, english_ids)`` tuple of id lists
    produced by the corresponding vocabulary.
    """

    # Built once and reused for every sentence (previously rebuilt per call).
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self, pl_path, en_path, vocab_pl=None, vocab_en=None):
        """Load, clean and index both sides of the corpus.

        Args:
            pl_path: path to the Polish file, one sentence per line.
            en_path: path to the English file, paired line-by-line with
                ``pl_path``.
            vocab_pl: optional pre-built Polish vocabulary, e.g. the training
                split's, so dev/test ids match training ids. When ``None``,
                a ``Vocabulary`` is built from this split's own text.
            vocab_en: optional pre-built English vocabulary, same semantics.

        Raises:
            ValueError: if the two files have a different number of lines.
        """
        self.data = {
            'polish': self._load_data(pl_path),
            'english': self._load_data(en_path)
        }

        self.preprocessing()

        if vocab_pl is None:
            vocab_pl = Vocabulary(self.__flat_list(self.data['polish']))
        if vocab_en is None:
            vocab_en = Vocabulary(self.__flat_list(self.data['english']))
        self.vocab_pl = vocab_pl
        self.vocab_en = vocab_en

    def __getitem__(self, index):
        """Return ``(polish_ids, english_ids)`` for sentence pair ``index``."""
        row = self.data.iloc[index]
        pl_ids = [self.vocab_pl[word] for word in row['polish'].split()]
        en_ids = [self.vocab_en[word] for word in row['english'].split()]
        return pl_ids, en_ids

    def __len__(self):
        """Return the number of sentence pairs (needed by ``len`` and samplers)."""
        return len(self.data)

    @staticmethod
    def _load_data(path):
        """Read ``path`` as UTF-8 and return its lines without newline characters.

        Iterating the file splits on newlines only, so the final line is kept
        even when the file does not end with a newline. The previous
        ``split(...)[:-1]`` dropped it in that case.
        """
        with open(path, 'r', encoding='UTF-8') as f:
            return [line.rstrip('\n') for line in f]

    def preprocessing(self):
        """Clean every sentence and store the pairs as a DataFrame.

        Replaces ``self.data`` (a dict of two line lists) with a
        ``pandas.DataFrame`` that has columns ``'polish'`` and ``'english'``.

        Raises:
            ValueError: if the two sides differ in length. A plain ``zip``
                would otherwise silently drop the surplus lines.
        """
        polish, english = self.data['polish'], self.data['english']
        if len(polish) != len(english):
            raise ValueError(
                f'Parallel corpus is misaligned: {len(polish)} Polish lines '
                f'vs {len(english)} English lines'
            )

        self.data = pd.DataFrame({
            'polish': [self.__text_prep(text) for text in polish],
            'english': [self.__text_prep(text) for text in english],
        })

    @staticmethod
    def __text_prep(text):
        """Remove ASCII punctuation, trim surrounding whitespace and lowercase."""
        # NOTE(review): string.punctuation is ASCII-only. Typographic marks
        # (e.g. „ ” – …), if present in the corpus, are kept and stay attached
        # to words.
        text = text.translate(PolEngDS._PUNCT_TABLE)
        return text.strip().lower()

    @staticmethod
    def __flat_list(data):
        """Split each sentence on whitespace and concatenate all tokens."""
        return list(itertools.chain.from_iterable(text.split() for text in data))

In [73]:
# Build the dataset from the test split.
data = PolEngDS(pl_path=paths.pl_test, en_path=paths.en_test)