In [160]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

import torchtext
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from copy import deepcopy
import random
import math
import time
from tqdm.notebook import tqdm

from typing import Tuple

In [3]:
SEED = 1234

# Seed Python's `random`, NumPy's global RNG and PyTorch (CPU + current CUDA
# device) in one pass so that every run of the notebook draws the same numbers.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Restrict cuDNN to deterministic kernels (trades some speed for repeatability).
torch.backends.cudnn.deterministic = True

In [149]:
class TxtProcessor:
    """Tokenize raw text and map tokens to integer ids from a pickled vocabulary.

    The vocabulary (``lemma_dict``) is a mapping loaded with ``pd.read_pickle``.
    It must contain the keys ``<unk>``, ``<mask>``, ``<pad>``, ``<sep>`` and
    ``<cls>``, because the id properties below look them up directly.
    """

    def __init__(self, lemma_dict_path='lemma_dict.pickle'):
        # NOTE(review): read_pickle unpickles arbitrary objects (code execution
        # risk) -- only load vocabulary files from a trusted source.
        self._lemma_dict = pd.read_pickle(lemma_dict_path)
        self.tokenizer = get_tokenizer('basic_english')
        # Class labels used for next-sentence prediction.
        self._nsp_label = {'IsNext': 1, 'NotNext': 0}

    def _wi_map(self, word):
        """Return the id of ``word``; out-of-vocabulary words map to the ``<unk>`` id."""
        return self._lemma_dict.get(word, self._lemma_dict['<unk>'])

    def sent_tokenize(self, txt):
        """Split ``txt`` into a list of tokens with the configured tokenizer."""
        return self.tokenizer(txt)

    def word_indexing(self, word):
        """Public wrapper around ``_wi_map``: id of a single token."""
        return self._wi_map(word)

    def preprocess(self, txt):
        """Tokenize ``txt`` and return a 1-D ``int64`` numpy array of token ids.

        Text that produces no tokens yields an empty array. ``np.vectorize``
        (used previously) raises ``ValueError`` on size-0 input when ``otypes``
        is not given.
        """
        tokens = self.sent_tokenize(txt)
        return np.fromiter((self._wi_map(tok) for tok in tokens),
                           dtype=np.int64, count=len(tokens))

    @property
    def lemma_dict(self):
        """The full token -> id mapping."""
        return self._lemma_dict

    @property
    def mask_id(self):
        """Id of the ``<mask>`` token."""
        return self._lemma_dict['<mask>']

    @property
    def unk_id(self):
        """Id of the ``<unk>`` token."""
        return self._lemma_dict['<unk>']

    @property
    def pad_id(self):
        """Id of the ``<pad>`` token."""
        return self._lemma_dict['<pad>']

    @property
    def sep_id(self):
        """Id of the ``<sep>`` token."""
        return self._lemma_dict['<sep>']

    @property
    def cls_id(self):
        """Id of the ``<cls>`` token."""
        return self._lemma_dict['<cls>']

    @property
    def nsp_label(self):
        """Mapping ``{'IsNext': 1, 'NotNext': 0}`` of NSP class labels."""
        return self._nsp_label

In [156]:
class BertIterator(Dataset):
    """In-memory BERT pre-training dataset producing MLM + NSP examples.

    Each line of ``filename`` holds a tab-separated sentence pair
    (sentence A, sentence B). ``__getitem__`` masks both sentences. With
    probability ~0.5 it replaces sentence B by sentence B of a random line
    (NotNext). It returns tensors of fixed length ``seq_len``.
    """

    # NOTE(review): lowest id drawn when a selected token is replaced by a
    # random token; assumes ids below this are special tokens -- confirm
    # against the layout of lemma_dict.
    _MIN_RANDOM_TOKEN_ID = 4

    def __init__(self,
                 filename='prep_docs.txt',
                 lemma_dict_path='lemma_dict.pickle',
                 seq_len=256,
                 in_memory=True):
        if in_memory is False:
            raise NotImplementedError("Only in-memory version is supported")

        self.docs = self._load_txt_in_memory(filename)
        self.prep = TxtProcessor(lemma_dict_path)
        self.seq_len = seq_len

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        """Build one training example from line ``item``.

        Returns a dict with:
          * ``'text'``: LongTensor ``(seq_len,)``, ``[CLS] A [SEP] B [SEP]``,
            truncated or right-padded with ``pad_id``.
          * ``'mlm'``: LongTensor ``(seq_len,)``, the original id at positions
            selected for masking and ``pad_id`` everywhere else (including the
            ``[CLS]``/``[SEP]``/padding slots).
          * ``'nsp'``: ``nsp_label['IsNext']`` or ``nsp_label['NotNext']``.
        """
        txt1, txt2 = self._sample_txt_from_line(item)

        wi1, mask_l1 = self._generate_mask(txt1)
        wi2, mask_l2 = self._generate_mask(txt2)
        # Rebind the labels together with the text: when sentence B is swapped,
        # the labels must come from the new sentence, not the original one.
        wi2, mask_l2, nsp_l = self._generate_nsp(wi2, mask_l2)

        wi = self._concat_sequences(wi1, wi2)
        mask_l = self._concat_sequences(mask_l1, mask_l2, is_label=True)

        return {
            'text': wi,
            'mlm': mask_l,
            'nsp': nsp_l,
        }

    def _load_txt_in_memory(self, fname):
        """Read ``fname`` into a list of lines and record the count in ``self.length``."""
        with open(fname) as fh:
            docs = fh.read().splitlines()
        self.length = len(docs)
        return docs

    def _sample_txt_from_line(self, idx, get_pair=True):
        """Return ``(sentence_a, sentence_b)`` of line ``idx``, or only sentence B if ``get_pair`` is False."""
        txt1, txt2 = self.docs[idx].split("\t")
        if get_pair:
            return txt1, txt2
        return txt2

    def _generate_mask(self, txt):
        """Apply BERT-style corruption to ``txt``.

        ``int(0.15 * n_tokens)`` positions are selected at random. Of those, the
        first 80% become ``mask_id``, the next 10% become a random token id, and
        the rest keep their original id.

        Returns ``(wi, mask_label)``: the corrupted id array, and an array of
        the same length holding the original id at selected positions and
        ``pad_id`` elsewhere.
        """
        wi = self.prep.preprocess(txt)
        # NOTE(review): pad_id marks "no prediction"; assumes the MLM loss
        # ignores pad_id targets (e.g. ignore_index=pad_id) -- confirm.
        mask_label = np.full_like(wi, self.prep.pad_id)

        n_tokens = len(wi)
        index_arr = np.random.permutation(n_tokens)[:int(n_tokens * 0.15)]
        mask_label[index_arr] = wi[index_arr]

        # Split the selected positions into the three BERT corruption groups;
        # positions after the 90% cut stay unchanged.
        n_sel = len(index_arr)
        mask_idx_arr = index_arr[:int(n_sel * 0.8)]
        replace_idx_arr = index_arr[int(n_sel * 0.8):int(n_sel * 0.9)]

        wi[mask_idx_arr] = self.prep.mask_id
        # randint samples only the ids needed (high is exclusive), instead of
        # permuting the whole vocabulary as choice(..., replace=False) does.
        wi[replace_idx_arr] = np.random.randint(
            self._MIN_RANDOM_TOKEN_ID, len(self.prep.lemma_dict),
            size=len(replace_idx_arr))

        return wi, mask_label

    def _generate_nsp(self, wi, label):
        """With probability ~0.5, replace sentence B by sentence B of a random line.

        Returns ``(wi, label, nsp)``. On NotNext, ``wi``/``label`` are freshly
        masked ids/labels of the random sentence. Otherwise the inputs are
        returned unchanged with the IsNext label.
        """
        if random.random() > 0.5:  # NotNext
            # Any line may be drawn. The former low=5 skipped the first lines
            # and failed for datasets with <= 5 lines.
            # NOTE(review): the drawn line can be the current one, which makes
            # the "NotNext" pair actually consecutive.
            rand_sample_idx = int(np.random.randint(self.length))
            txt = self._sample_txt_from_line(idx=rand_sample_idx, get_pair=False)
            wi, label = self._generate_mask(txt)
            return wi, label, self.prep.nsp_label['NotNext']

        return wi, label, self.prep.nsp_label['IsNext']

    def _concat_sequences(self, wi1, wi2, is_label=False):
        """Return LongTensor ``(seq_len,)`` = ``[CLS] wi1 [SEP] wi2 [SEP]``.

        The result is truncated to ``seq_len`` or right-padded with ``pad_id``.
        With ``is_label=True``, the ``[CLS]``/``[SEP]`` slots are filled with
        ``pad_id`` so that special tokens are not MLM targets.
        """
        pad = self.prep.pad_id
        cls_id = pad if is_label else self.prep.cls_id
        sep_id = pad if is_label else self.prep.sep_id

        cated = [cls_id] + wi1.tolist() + [sep_id] + wi2.tolist() + [sep_id]
        cated = cated[:self.seq_len]
        cated += [pad] * (self.seq_len - len(cated))
        return torch.tensor(cated, dtype=torch.long)

In [157]:
# Build the in-memory pre-training dataset (defaults spelled out for readability).
iterator = BertIterator(
    filename='prep_docs.txt',
    lemma_dict_path='lemma_dict.pickle',
    seq_len=256,
    in_memory=True,
)

In [None]:
# Batch the dataset. The bare `*,` copied from DataLoader's signature was a
# SyntaxError, and the loader must be bound to a name to be used later.
train_loader = DataLoader(
    iterator,
    batch_size=128,
    shuffle=False,
    num_workers=10,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    prefetch_factor=2,
    persistent_workers=False,
)

In [158]:
# Smoke test: build the first example directly (fails loudly on an empty
# dataset) and check both tensors have the configured sequence length.
i = iterator[0]
assert i['text'].shape == (iterator.seq_len,)
assert i['mlm'].shape == (iterator.seq_len,)