In [None]:
from torch.utils.data import Dataset, DataLoader
import random
import itertools
import torch

class Bookcorpus(Dataset):
    """Map-style BookCorpus dataset yielding BERT pre-training examples.

    Every item is a sentence pair prepared for masked-language-modelling (MLM)
    and next-sentence prediction (NSP): ``[CLS] s1 [SEP] s2 [SEP]`` with random
    word masking, padded/truncated to ``seq_len``.
    """

    def __init__(self, tokenizer, seq_len=64, split="train", n_rows=None):
        """Load the BookCorpus split and remember the tokenization settings.

        Args:
            tokenizer: callable returning ``{"input_ids": [...]}`` whose ids are
                wrapped in one leading and one trailing special token, and
                exposing a ``vocab`` mapping that contains ``[CLS]``, ``[SEP]``,
                ``[PAD]`` and ``[MASK]``.
            seq_len: fixed length every example is truncated/padded to.
            split: dataset split; BookCorpus only has ``"train"``.
            n_rows: take only the first ``n_rows`` rows; ``None`` means take the
                whole split.

        Raises:
            ValueError: if ``split`` is not ``"train"``.
        """
        if split != "train":
            raise ValueError("For Bookcorpus there is only a train split")

        # NOTE(review): ``load_dataset`` is not imported in this cell; presumably
        # it is ``datasets.load_dataset`` imported elsewhere in the notebook.
        if n_rows is not None:
            self.dataset = load_dataset("bookcorpus", split=f"{split}[0:{n_rows}]")
        else:
            # Passing ``split`` returns a single row-indexable Dataset instead of
            # a DatasetDict keyed by split name.
            self.dataset = load_dataset("bookcorpus", split=split)

        self.n_rows = len(self.dataset)
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        """Return the number of rows (one example per row)."""
        return self.n_rows

    def __getitem__(self, item):
        """Build one MLM + NSP training example for row ``item``.

        Returns:
            dict of tensors:
              - ``bert_input``: ``[CLS] s1 [SEP] s2 [SEP]`` token ids, padded with
                the ``[PAD]`` id to ``seq_len``.
              - ``bert_label``: original token id at positions selected for
                prediction, 0 at other sentence positions, ``[PAD]`` id at
                special-token and padding positions.
              - ``segment_label``: 1 over sentence 1 (incl. ``[CLS]``/``[SEP]``),
                2 over sentence 2, ``[PAD]`` id over padding.
              - ``is_next``: 1 if s2 is the row following s1, else 0.
        """
        vocab = self.tokenizer.vocab
        cls_id, sep_id, pad_id = vocab['[CLS]'], vocab['[SEP]'], vocab['[PAD]']

        # Step 1: sentence pair, either the true next row or a random one.
        s1, s2, is_next_label = self.get_sent(item)

        # Step 2: mask / randomly replace words in each sentence.
        t1_random, t1_label = self.random_word(s1)
        t2_random, t2_label = self.random_word(s2)

        # Step 3: trim the pair so it fits together with its 3 special tokens;
        # otherwise the tail slice below could drop the final [SEP] or all of s2.
        t1_random, t1_label, t2_random, t2_label = self._truncate_pair(
            t1_random, t1_label, t2_random, t2_label, self.seq_len - 3)

        # Step 4: add [CLS]/[SEP]; special-token positions get a [PAD] label.
        t1 = [cls_id] + t1_random + [sep_id]
        t2 = t2_random + [sep_id]
        t1_label = [pad_id] + t1_label + [pad_id]
        t2_label = t2_label + [pad_id]

        # Step 5: combine both sentences and pad to seq_len. The slice only has
        # an effect when seq_len < 3 (not even the special tokens fit).
        segment_label = ([1] * len(t1) + [2] * len(t2))[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]
        padding = [pad_id] * (self.seq_len - len(bert_input))
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    @staticmethod
    def _truncate_pair(tokens_a, labels_a, tokens_b, labels_b, max_tokens):
        """Trim two (tokens, labels) sentences until they fit in ``max_tokens``.

        Removes the last token of the currently longer sentence (sentence A on
        ties) one at a time, so both sentences keep as much content as possible.
        Returns new lists; the inputs are not modified.
        """
        max_tokens = max(max_tokens, 0)
        len_a, len_b = len(tokens_a), len(tokens_b)
        while len_a + len_b > max_tokens:
            if len_a >= len_b:
                len_a -= 1
            else:
                len_b -= 1
        return tokens_a[:len_a], labels_a[:len_a], tokens_b[:len_b], labels_b[:len_b]

    def get_sent(self, index):
        """Return an NSP sentence pair ``(s1, s2, is_next)`` for row ``index``.

        With probability 0.5, s2 is row ``index + 1`` and ``is_next`` is 1.
        Otherwise s2 is a random row other than ``index + 1`` and ``is_next`` is 0.
        The last row has no successor, so it always yields a negative pair.
        """
        next_index = index + 1
        is_next = random.random() > 0.5 and next_index < len(self)

        s1 = self.dataset[index]["text"]
        if is_next:
            return s1, self.dataset[next_index]["text"], 1
        return s1, self.get_random_line(next_index)["text"], 0

    def get_random_line(self, excludedIndex):
        """Return a uniformly random row whose index differs from ``excludedIndex``.

        ``excludedIndex`` may lie outside ``[0, len(self))`` (e.g. the would-be
        successor of the last row); then every row is eligible.

        Raises:
            ValueError: if no eligible row exists (empty dataset, or a one-row
                dataset whose only row is excluded).
        """
        n = len(self)
        if 0 <= excludedIndex < n:
            # Draw among the n-1 other rows and step over the excluded slot:
            # uniform, and no retry loop that could spin forever.
            rand_index = random.randrange(n - 1)
            if rand_index >= excludedIndex:
                rand_index += 1
        else:
            rand_index = random.randrange(n)
        return self.dataset[rand_index]

    def random_word(self, sentence):
        """Apply BERT-style masking to a whitespace-split sentence.

        Each word is tokenized on its own. With probability 0.15 a word is
        selected for prediction: its sub-tokens are replaced by ``[MASK]`` (80%),
        by random vocabulary ids (10%) or left unchanged (10%), and its original
        sub-token ids become the labels. Unselected words keep their ids and
        get label 0.

        Returns:
            ``(token_ids, labels)``: two flat lists of equal length.
        """
        vocab = self.tokenizer.vocab
        mask_id = vocab['[MASK]']
        vocab_size = len(vocab)
        output, output_label = [], []

        for word in sentence.split():
            prob = random.random()

            # remove cls and sep token added by the tokenizer
            token_ids = self.tokenizer(word)['input_ids'][1:-1]

            if prob < 0.15:
                prob /= 0.15
                if prob < 0.8:
                    # 80%: replace every sub-token with [MASK]
                    output.extend([mask_id] * len(token_ids))
                elif prob < 0.9:
                    # 10%: replace every sub-token with a random vocabulary id
                    output.extend(random.randrange(vocab_size) for _ in token_ids)
                else:
                    # 10%: keep the original sub-tokens
                    output.extend(token_ids)
                output_label.extend(token_ids)
            else:
                output.extend(token_ids)
                output_label.extend([0] * len(token_ids))

        return output, output_label