In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import numpy as np

In [259]:
MAX_TOKENS = 200

In [240]:
class Bert(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoding = tiktoken.get_encoding("r50k_base")
        self.emb_size = self.encoding.n_vocab
        self.emb_channels = 128
        self.max_token_length = MAX_TOKENS
        self.emb = nn.Embedding(self.emb_size, self.emb_channels)
        self.positional_encoding = nn.Parameter(torch.randn(self.max_token_length, self.emb_channels))

    def forward(self):
        x = self.emb()
        

In [213]:
from datasets import load_dataset

dataset = load_dataset('wikipedia', '20220301.en', split='train[:1%]')

In [214]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [215]:
import spacy

nlp = spacy.load("en_core_web_sm")

In [216]:
sentence_pairs = []

In [234]:

for i in range(len(dataset[:10])):
    text = dataset[i]['text']
    doc = nlp(text)
    sentences = [sent.text for sent in doc.sents]
    for s1, s2 in zip(sentences, sentences[1:]):
        sentence_pairs.append((s1, s2))

In [235]:
import random

def mask(sentence):
    for i in range(len(sentence)):
        prob = random.random()
        if prob <= 0.15:
            sub_prob = random.random()
            if sub_prob <= 0.8:
                sentence[i] = tokenizer.mask_token_id
            elif sub_prob > 0.8 or sub_prob <= 0.9:
                sentence[i] = int(random.random() * (tokenizer.vocab_size-1))
    return sentence

In [260]:
import random

def generate_bert_input():
    FIRST_SENT_LIMIT = int(0.8 * MAX_TOKENS)
    SECOND_SENT_LIMIT = int(0.2 * MAX_TOKENS)
    prob = random.random()
    first_sent_idx = int(random.random() * len(sentence_pairs))
    isNext = True
    if prob > 0.5:
        second_sent_idx = first_sent_idx+1
    else:
        second_sent_idx = int(random.random() * len(sentence_pairs))
        isNext = False
    
    first_sent = tokenizer(sentence_pairs[first_sent_idx][0])['input_ids'][:FIRST_SENT_LIMIT]
    second_sent = tokenizer(sentence_pairs[second_sent_idx][1])['input_ids'][1:SECOND_SENT_LIMIT]
    pad = MAX_TOKENS - len(first_sent) - len(second_sent)

    masked_first_sent = mask(first_sent)
    masked_second_sent = mask(second_sent)

    masked_val = masked_first_sent + masked_second_sent + [tokenizer.pad_token_id] * pad
    val = first_sent + second_sent + [tokenizer.pad_token_id] * pad

    return {'masked_val': masked_val, 'unmasked_val': val, 'isNext': isNext}

In [262]:
tokenizer.decode(generate_bert_input()['masked_val'])

'[MASK] ins clarifiedtion effects the tombstone [MASK] al [MASK] temperature effects bakery on the amount of albedo and the level of local insolation ( solar irrad [MASK] ) ; high albedo areas in the arctic and antarctic regions are cold due to low insolation, whereas areas such [MASK] the sahara desert, which alsoroving a relatively [MASK] albedo, [MASK] [MASK] conversely due to high insolation. [SEP] because insola [MASK] plays such a big role in the [MASK] [MASK] cooling effects of al aka, [MASK] insolation areas like the trop [MASK] will tend to show a more pronounced fluctuation [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [