In [241]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from reformer_pytorch import Reformer, ReformerLM
from transformers import BertTokenizer, AdamW

import re
import os
from tqdm import tqdm, tqdm_notebook
from glob import glob

import json

## Shorter max length to test faster on current hardware

In [351]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
tokenizer.max_len = 128

In [352]:
model = ReformerLM(
    num_tokens = tokenizer.vocab_size,
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = tokenizer.max_len,
    causal = True
)

In [379]:
test = 'Hello, my dog is cute'

In [380]:
tok = tokenizer.encode(test, max_length=tokenizer.max_len, add_special_tokens=True)
tok = torch.tensor(tok, dtype=torch.long)
tok.shape

torch.Size([8])

In [381]:
tokenizer.decode(tok)

'[CLS] Hello, my dog is cute [SEP]'

In [382]:
def mask_tokens(inputs: torch.Tensor, tokenizer, mlm_probability=0.15, pad=True):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
    labels = inputs.clone()
    # mlm_probability defaults to 0.15 in Bert
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    if pad:
        input_pads = tokenizer.max_len - inputs.shape[-1]
        label_pads = tokenizer.max_len - labels.shape[-1]
        
        inputs = F.pad(inputs, pad=(0,input_pads), value=tokenizer.pad_token_id)
        labels = F.pad(labels, pad=(0,label_pads), value=tokenizer.pad_token_id)
    
    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels


In [383]:
inputs, labels = mask_tokens(tok.unsqueeze(0), tokenizer, pad=True)

In [384]:
tokenizer.decode(inputs.squeeze(0))

'[CLS] Hello, my dog [MASK] cute [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [385]:
tokenizer.decode(labels.squeeze(0))

'[UNK] [UNK] [UNK] [UNK] [UNK] is [UNK] [UNK] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

# Predictions in shape [batch_size, max_seq_len, vocab_size]

In [386]:
pred = model(inputs)
pred.shape

torch.Size([1, 128, 28996])

In [387]:
tokenizer.decode(torch.argmax(pred, dim=-1).squeeze(0))

'Tokyo [unused94] conference Chemical Pope rush sniper archaeological families এ Application uprisingnical Alleyrna Corps mankind youth Dick twelfth Exeter Moines Poznań thrill twelfthrna Runner Ā stock thirteen Consequentlyrna gazingdar stock Kennedy⊆ Ingram compiler twelfth 1964 shipped pub Programs⊆ crouched twelfthhalt twelfth stomach Poznań twelfth uprising bust Ingramficiency Alley twelfth compilerza level 1888halt summon uprising winrnarna Ingram Ā analogous beings Programs jailed Reed al foyer Ingram caring examplesficiency Ingram stockvin turbineswasorium win specialised postponedrkin Runnerficiency miniseries 」 Ingram reinforcements caring tolerate stock supremacy hence madnessstone lighthouse ventralwas twelfth laboratorystonewas Ingram Moines road nativesstone Thursday ventral remnants• Alternative Poznańficiency Poznań Reed mechanism natives Ingram'

In [388]:
loss_fn = nn.CrossEntropyLoss()  # -100 index = padding token

In [389]:
masked_lm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size), labels.view(-1))
masked_lm_loss

tensor(11.8019, grad_fn=<NllLossBackward>)

In [390]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
total_loss = 0.0
model.train()

model.to(device)
inputs = inputs.to(device)
labels = labels.to(device)

loss = []
optimizer = AdamW(params=model.parameters())

for _ in tqdm(range(100)):
    
    pred = model(inputs)
    mlm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size), labels.view(-1))
    
    total_loss += mlm_loss.item()
    loss.append(mlm_loss.item())
    
    mlm_loss.backward()
    optimizer.step()
    model.zero_grad()




  0%|                                                                                          | 0/100 [00:00<?, ?it/s][A[A[A


  1%|▊                                                                                 | 1/100 [00:05<08:20,  5.05s/it][A[A[A


KeyboardInterrupt: 

In [27]:
pred = model(inputs.cuda())

In [28]:
pred = torch.argmax(pred.detach().cpu(), dim=-1)
tokenizer.decode(pred.squeeze(0))

'my my [PAD] my [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [29]:
tokenizer.decode(labels.squeeze(0))

'[UNK] [UNK] [UNK] my [UNK] [UNK] [UNK] [UNK] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

# Number of Wiki Files

In [391]:
wikifiles = []

for root, dirs, files in os.walk('D:/Data/enwiki'):
    for file in files:
        wikifiles.append(os.path.join(root, file))
print(f'Total Files: {len(wikifiles)}')

Total Files: 8173


In [330]:
jsons = []

with open(wikifiles[0], 'r', encoding='utf-8') as file:
    data = file.readlines()
    for d in data:
        jsons.append(json.loads(d))
    file.close()

In [331]:
len(jsons)

37

In [334]:
print(jsons[1]['text'])

Autism

Autism is a developmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some children with autism experience worsening in their communication and social skills after reaching developmental milestones at a normal pace.
Autism is associated with a combination of genetic and environmental factors. Risk factors during pregnancy include certain infections, such as rubella, toxins including valproic acid, alcohol, cocaine, pesticides, lead, and air pollution, fetal growth restriction, and autoimmune diseases. Controversies surround other proposed environmental causes; for example, the vaccine hypothesis, which has been disproven. Autism affects information processing in the brain and how nerve cells and their synapses connect and organize; how this occurs is not well understood. The Di

In [321]:
re.sub('\\n+', ' ', jsons[0]['text'])

'Anarchism Anarchism is an anti-authoritarian political and social philosophy that rejects hierarchies as unjust and advocates their replacement with self-managed, self-governed societies based on voluntary, cooperative institutions. These institutions are often described as stateless societies, although several authors have defined them more specifically as distinct institutions based on non-hierarchical or free associations. Anarchism\'s central disagreement with other ideologies is that it holds the state to be undesirable, unnecessary, and harmful. Anarchism is usually placed on the far-left of the political spectrum, and much of its economics and legal philosophy reflect anti-authoritarian interpretations of communism, collectivism, syndicalism, mutualism, or participatory economics. As anarchism does not offer a fixed body of doctrine from a single particular worldview, many anarchist types and traditions exist and varieties of anarchy diverge widely. Anarchist schools of thought