In [2]:
import random

import torch
import tqdm
from torch.utils.data import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class BERTDataset(Dataset):
    """Sentence-pair dataset that produces BERT pre-training samples.

    Every non-blank line of the corpus file holds two tab-separated
    sentences. Indexing the dataset builds one sample combining the two
    objectives implemented below: token masking (``random_word``) and
    next-sentence selection (``random_sent``).

    ``vocab`` must support ``len()`` and expose ``stoi`` (a token -> index
    mapping) plus the attributes ``pad_index``, ``unk_index``, ``sos_index``,
    ``eos_index`` and ``mask_index``.
    """

    def __init__(self, corpus_path='./data/eng-fra.txt', vocab=None, seq_len=20, corpus_lines=None):
        """Read all sentence pairs from ``corpus_path`` into memory.

        Args:
            corpus_path: UTF-8 text file with one ``sentence<TAB>sentence``
                pair per line. Blank lines are skipped; fields after the
                second one are ignored.
            vocab: Vocabulary object (see the class docstring).
            seq_len: Fixed length each sample is truncated or padded to.
            corpus_lines: Optional line count, used only as the total of the
                loading progress bar.

        Raises:
            ValueError: If a non-blank line contains no tab separator.
        """
        self.vocab = vocab
        self.seq_len = seq_len
        self.corpus_path = corpus_path
        self.lines = []

        # Honour the caller-supplied path instead of a hard-coded file name.
        with open(corpus_path, "r", encoding="utf-8") as f:
            progress = tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)
            for line_no, line in enumerate(progress, start=1):
                if not line.strip():
                    continue

                # Each entry must be a (sentence, sentence) pair: indexing a
                # raw string with [0]/[1] would yield single characters.
                fields = line.rstrip("\r\n").split("\t")
                if len(fields) < 2:
                    raise ValueError(
                        f"{corpus_path}:{line_no}: expected two tab-separated "
                        f"sentences, got {line!r}"
                    )
                self.lines.append((fields[0].strip(), fields[1].strip()))

        # Number of usable pairs actually loaded (blank lines excluded).
        self.corpus_lines = len(self.lines)

    def get_corpus_line(self, item):
        """Return the ``(first, second)`` sentence pair stored at index ``item``."""
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        """Return the second sentence of a uniformly chosen corpus pair."""
        return self.lines[random.randrange(self.corpus_lines)][1]

    def random_sent(self, index):
        """Return ``(t1, t2, label)`` for the pair at ``index``.

        With probability 0.5 the true second sentence is kept (label 1);
        otherwise it is replaced by the second sentence of a random pair
        (label 0).
        """
        t1, t2 = self.get_corpus_line(index)

        # output_text, label (isNotNext: 0, isNext: 1)
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def random_word(self, sentence):
        """Convert ``sentence`` to vocabulary indices with random masking.

        Each whitespace-separated token is selected with probability 0.15.
        A selected token becomes ``mask_index`` (80%), a random vocabulary
        index (10%) or its own index (10%), and its label is its own index.
        Unselected tokens become their own index with label 0. Tokens missing
        from ``vocab.stoi`` map to ``unk_index``.

        Returns:
            ``(tokens, output_label)``: two lists of ints of equal length.
        """
        tokens = sentence.split()
        output_label = []

        for i, token in enumerate(tokens):
            prob = random.random()
            if prob < 0.15:
                # Rescale to [0, 1) so the 80/10/10 split can be read directly.
                prob /= 0.15

                # 80% randomly change token to mask token
                if prob < 0.8:
                    tokens[i] = self.vocab.mask_index

                # 10% randomly change token to random token
                elif prob < 0.9:
                    tokens[i] = random.randrange(len(self.vocab))

                # 10% keep the current token
                else:
                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)

                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))

            else:
                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(0)

        return tokens, output_label

    def __getitem__(self, item):
        """Build the training sample for index ``item``.

        Returns:
            dict with keys ``"bert_input"``, ``"bert_label"`` and
            ``"segment_label"`` (integer tensors of length ``seq_len``) and
            ``"is_next"`` (scalar integer tensor, 0 or 1).
        """
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
        t2 = t2_random + [self.vocab.eos_index]

        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
        t2_label = t2_label + [self.vocab.pad_index]

        # Segment ids: 1 for the first sentence, 2 for the second.
        segment_label = ([1] * len(t1) + [2] * len(t2))[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        # Right-pad all three sequences to exactly seq_len entries.
        padding = [self.vocab.pad_index] * (self.seq_len - len(bert_input))
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        # The sample was previously built and then dropped; return it as
        # tensors so batches of samples stack into (batch, seq_len) tensors.
        return {key: torch.tensor(value) for key, value in output.items()}

    def __len__(self):
        """Return the number of sentence pairs loaded from the corpus."""
        return self.corpus_lines


In [7]:
# Start from empty results so the names exist even if the file cannot be opened.
corpus_lines, lines = 0, []

# First pass: count the lines so the progress bar below can show a total.
with open('./data/eng-fra.txt', "r", encoding="utf-8") as corpus_file:
    corpus_lines = sum(1 for _row in corpus_file)

# Second pass: collect every line with surrounding whitespace removed.
with open('./data/eng-fra.txt', "r", encoding="utf-8") as corpus_file:
    progress = tqdm.tqdm(corpus_file, desc="Loading Dataset", total=corpus_lines)
    lines.extend(raw_line.strip() for raw_line in progress)


Loading Dataset: 100%|██████████| 135842/135842 [00:00<00:00, 1818306.30it/s]


In [4]:
corpus_lines

135842

 train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
                                corpus_lines=args.corpus_lines, on_memory=args.on_memory)

In [8]:
from collections import Counter

# Word-frequency table over every whitespace-separated token in the corpus.
counter = Counter()
with open('./data/eng-fra.txt', "r", encoding="utf-8") as f:
    for line in tqdm.tqdm(f, desc="Loading Dataset"):
        # Iterating a text file always yields str, so no list case is needed.
        # split() without a separator breaks on any run of whitespace, which
        # covers both the tab between the two sentences and the trailing
        # newline.
        counter.update(line.split())

Loading Dataset: 135842it [00:00, 153362.44it/s]
