In [57]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import numpy as np

In [928]:
MAX_TOKENS = 100
from datasets import load_dataset

dataset = load_dataset('wikipedia', '20220301.en', split='train[:1%]')

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

import spacy

nlp = spacy.load("en_core_web_sm")

In [929]:
class LayerNorm(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + 1e-6) + self.beta

In [1127]:
class Bert(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.n = 3
        self.emb_channels = 128
        self.vocab_size = vocab_size
        self.max_token_length = MAX_TOKENS
        self.emb = nn.Embedding(self.vocab_size, self.emb_channels)
        self.segmented_encoding = nn.Parameter(torch.randn(2, self.emb_channels))
        self.positional_encoding = nn.Parameter(torch.randn(self.max_token_length, self.emb_channels))
        self.ln1 = LayerNorm(self.emb_channels)
        self.ln2 = LayerNorm(self.emb_channels)
        self.l1 = nn.Linear(self.emb_channels, 250)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(250, self.emb_channels)
        self.linear = nn.Linear(self.emb_channels, self.vocab_size)
        self.train = True

        # transformer block
        self.head_count = 8
        self.qW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))
        self.kW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))
        self.vW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))
        self.oW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))

        self.nsp_classifier = nn.Linear(self.emb_channels, 1)

    def self_attention(self, x):
        q = (self.qW.view(self.head_count, self.emb_channels, -1).transpose(-1, -2) @ x.transpose(-1,-2)).transpose(-1, -2)        
        k = (self.kW.view(self.head_count, self.emb_channels, -1).transpose(-1, -2) @ x.transpose(-1,-2)).transpose(-1, -2)        
        v = (self.vW.view(self.head_count, self.emb_channels, -1).transpose(-1, -2) @ x.transpose(-1,-2)).transpose(-1, -2)        
        qK = q @ k.transpose(-1, -2)
        attention_weights = (qK @ v).view(-1, self.emb_channels)
        out = attention_weights @ self.oW
        out += x
        return out

    def feed_forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        return x
        

    def forward(self, x, unmasked_x, is_next):
        masked_pos = ((x == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]).tolist()
        mpos_tokid = [(masked_pos[i], unmasked_x[masked_pos[i]]) for i in range(len(masked_pos))]
        sep_pos = (x == tokenizer.sep_token_id).nonzero(as_tuple=True)[0][0].item()
        try:
            sep_pos2 = (x == tokenizer.pad_token_id).nonzero(as_tuple=True)[0][0].item()
        except:
            sep_pos2 = MAX_TOKENS
        x = self.emb(x)
        sent1 = x[:sep_pos]
        sent2 = x[sep_pos:sep_pos2]
        pad = x[sep_pos2:]
        sent1_with_segment = sent1 + self.segmented_encoding[0].unsqueeze(0)
        sent2_with_segment = sent2 + self.segmented_encoding[1].unsqueeze(0)
        if pad.size(0) > 0:
            x = torch.cat([sent1_with_segment, sent2_with_segment, pad], dim=0)
        else:
            x = torch.cat([sent1_with_segment, sent2_with_segment], dim=0)
        x += self.positional_encoding
        for _ in range(self.n):
            x = self.self_attention(x)
            x = self.ln1(x)
            x = self.feed_forward(x)
            x = self.ln2(x)
        if self.train:
            nsp_loss = self.nsp(x[0], is_next)
            x = self.linear(x)
            x = self.mlm(x, mpos_tokid) + nsp_loss
        if not self.train:
            x = self.linear(x)
        return x
    
    def mlm(self, x, mpos_tokid):
        mpos = torch.Tensor([(x[mpos_tokid[i][0]]).tolist() for i in range(len(mpos_tokid))])
        tok_id = torch.LongTensor([mpos_tokid[i][1] for i in range(len(mpos_tokid))])
        lossval = F.cross_entropy(mpos, tok_id)
        return lossval
    
    def nsp(self, x, is_next):
        logits = self.nsp_classifier(x)
        target = torch.tensor([is_next], dtype=torch.float)
        loss = F.binary_cross_entropy_with_logits(logits, target)
        return loss


In [938]:
sentence_pairs = []

In [939]:

for i in range(len(dataset[:10])):
    text = dataset[i]['text']
    doc = nlp(text)
    sentences = [sent.text for sent in doc.sents]
    for s1, s2 in zip(sentences, sentences[1:]):
        sentence_pairs.append((s1, s2))

In [940]:
import random

def mask(sentence):
    sentence = sentence.copy()
    for i in range(len(sentence)):
        if sentence[i] == tokenizer.sep_token_id:
            continue
        prob = random.random()
        if prob <= 0.15:
            sub_prob = random.random()
            if sub_prob <= 0.8:
                sentence[i] = tokenizer.mask_token_id
            elif sub_prob > 0.8 or sub_prob <= 0.9:
                sentence[i] = int(random.random() * (tokenizer.vocab_size-1))
    return sentence

In [941]:
import random

def generate_bert_input():
    FIRST_SENT_LIMIT = int(0.8 * MAX_TOKENS)
    SECOND_SENT_LIMIT = int(0.2 * MAX_TOKENS)
    prob = random.random()
    first_sent_idx = int(random.random() * len(sentence_pairs))
    is_next = True
    if prob > 0.5:
        second_sent_idx = first_sent_idx+1
    else:
        second_sent_idx = int(random.random() * len(sentence_pairs))
        is_next = False
    
    first_sent = tokenizer(sentence_pairs[first_sent_idx][0])['input_ids'][:FIRST_SENT_LIMIT]
    second_sent = tokenizer(sentence_pairs[second_sent_idx][1])['input_ids'][1:SECOND_SENT_LIMIT]
    pad = MAX_TOKENS - len(first_sent) - len(second_sent)

    masked_first_sent = mask(first_sent)
    masked_second_sent = mask(second_sent)

    masked_val = masked_first_sent + masked_second_sent + [tokenizer.pad_token_id] * pad
    val = first_sent + second_sent + [tokenizer.pad_token_id] * pad

    return {'masked_val': masked_val, 'unmasked_val': val, 'is_next': is_next}

In [1129]:
model = Bert(tokenizer.vocab_size)
optim = torch.optim.Adam(model.parameters())

In [1130]:
from tqdm import trange

In [1131]:
model.train = True
for _ in (t:=trange(100)):
    x = generate_bert_input()
    
    loss = model(torch.LongTensor(x['masked_val']), x['unmasked_val'], x['is_next'])
    
    optim.zero_grad()

    loss.backward()

    optim.step()

    t.set_description(f'loss: {loss.item()}')


loss: 11.92384147644043:  60%|██████    | 60/100 [00:01<00:00, 42.74it/s] 


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn