In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import numpy as np

In [3]:
MAX_TOKENS = 200

In [151]:
class Bert(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.emb_channels = 128
        self.vocab_size = vocab_size
        self.max_token_length = MAX_TOKENS
        self.emb = nn.Embedding(self.vocab_size, self.emb_channels)
        self.positional_encoding = nn.Parameter(torch.randn(self.max_token_length, self.emb_channels))

    def forward(self, x, unmasked_x, is_next):
        x = self.emb(x)
        print(x.shape)
        print(self.positional_encoding.shape)
        x += self.positional_encoding

In [116]:
from datasets import load_dataset

dataset = load_dataset('wikipedia', '20220301.en', split='train[:1%]')

In [117]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [118]:
import spacy

nlp = spacy.load("en_core_web_sm")

In [119]:
sentence_pairs = []

In [120]:

for i in range(len(dataset[:10])):
    text = dataset[i]['text']
    doc = nlp(text)
    sentences = [sent.text for sent in doc.sents]
    for s1, s2 in zip(sentences, sentences[1:]):
        sentence_pairs.append((s1, s2))

In [121]:
import random

def mask(sentence):
    for i in range(len(sentence)):
        if sentence[i] == tokenizer.sep_token_id:
            continue
        prob = random.random()
        if prob <= 0.15:
            sub_prob = random.random()
            if sub_prob <= 0.8:
                sentence[i] = tokenizer.mask_token_id
            elif sub_prob > 0.8 or sub_prob <= 0.9:
                sentence[i] = int(random.random() * (tokenizer.vocab_size-1))
    return sentence

In [138]:
import random

def generate_bert_input():
    FIRST_SENT_LIMIT = int(0.8 * MAX_TOKENS)
    SECOND_SENT_LIMIT = int(0.2 * MAX_TOKENS)
    prob = random.random()
    first_sent_idx = int(random.random() * len(sentence_pairs))
    is_next = True
    if prob > 0.5:
        second_sent_idx = first_sent_idx+1
    else:
        second_sent_idx = int(random.random() * len(sentence_pairs))
        is_next = False
    
    first_sent = tokenizer(sentence_pairs[first_sent_idx][0])['input_ids'][:FIRST_SENT_LIMIT]
    second_sent = tokenizer(sentence_pairs[second_sent_idx][1])['input_ids'][1:SECOND_SENT_LIMIT]
    pad = MAX_TOKENS - len(first_sent) - len(second_sent)

    masked_first_sent = mask(first_sent)
    masked_second_sent = mask(second_sent)

    masked_val = masked_first_sent + masked_second_sent + [tokenizer.pad_token_id] * pad
    val = first_sent + second_sent + [tokenizer.pad_token_id] * pad

    return {'masked_val': masked_val, 'unmasked_val': val, 'is_next': is_next}

In [152]:
model = Bert(tokenizer.vocab_size)
x = generate_bert_input()
model(torch.LongTensor(x['masked_val']), torch.LongTensor(x['unmasked_val']), x['is_next'])

torch.Size([200, 128])
torch.Size([200, 128])
