# Chapter 15.9: The Dataset for Pretraining BERT
https://d2l.ai/chapter_natural-language-processing-pretraining/bert-dataset.html




In [1]:
!pip install setuptools==66
!pip install matplotlib_inline
!pip install d2l==1.0.0b

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting setuptools==66
  Downloading setuptools-66.0.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 67.6.1
    Uninstalling setuptools-67.6.1:
      Successfully uninstalled setuptools-67.6.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.[0m[31m
[0mSuccessfully installed setuptools-66.0.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/pub

In [2]:
import collections
import math
import os
import random
import torch
from d2l import torch as d2l
from torch import nn

In [3]:
# Register the WikiText-2 archive under the key that load_data_wiki later
# passes to d2l.download_extract. The second element is presumably the
# archive's checksum used for verification — confirm against d2l's DATA_HUB.
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip',
    '3c914d17d80b1459be871a5039ac23e752a53cbe',
)

def _read_wiki(data_dir, debug = False):
  file_name = os.path.join(data_dir, 'wiki.train.tokens')
  with open(file_name, 'r') as f:
    lines = f.readlines()

  # each paragraph is seperated by dot. This is the format of wikitext dataset
  # so each element in pargraphs is an array of sentences
  paragraphs = [line.strip().lower().split(' . ') 
  for line in lines if len(line.split(' . ')) >= 2]

  #shuffle the paragrams
  if debug:
    print('total paragraphs ', len(paragraphs))
    print('sample paragraph')
    for i in range(5):
      print(paragraphs[random.randint(1,len(paragraphs))])

  random.shuffle(paragraphs)

  return paragraphs

# Build the BERT input sequence (and segment ids) for one or two sentences.
def get_tokens_and_segments(tokens_a, tokens_b = None):
  """Wrap one or two token lists with BERT special tokens.

  Single sentence:  <cls> a_1 ... a_n <sep>
  Sentence pair:    <cls> a_1 ... a_n <sep> b_1 ... b_m <sep>

  Returns:
    (tokens, segments). segments holds 0 for every position of the first
    part (including <cls> and its <sep>). It holds 1 for every position of
    the second part (including the trailing <sep>).
  """
  first_part = ['<cls>'] + tokens_a + ['<sep>']
  first_segments = [0] * len(first_part)
  if tokens_b is None:
    return first_part, first_segments
  second_part = tokens_b + ['<sep>']
  return first_part + second_part, first_segments + [1] * len(second_part)

def _get_next_sentence(sentence, next_sentence, paragraphs):
  if random.random() < 0.5:
    is_next = True
  
  else:
    next_sentence = random.choice(random.choice(paragraphs))
    is_next = False
  
  return sentence, next_sentence, is_next

# Prepare next-sentence-prediction examples from one paragraph.
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
  """Build NSP examples from every adjacent sentence pair in ``paragraph``.

  Args:
    paragraph: list of sentences, each sentence a list of tokens.
    paragraphs: all paragraphs (same structure); used to draw random
      negative sentences.
    vocab: unused here; kept so the caller's interface is unchanged.
    max_len: maximum BERT input length. A pair is skipped when its token
      count plus the three special tokens (<cls>, <sep>, <sep>) exceeds it.

  Returns:
    List of (tokens, segments, is_next) tuples.
  """
  nsp_data_from_paragraph = []
  for first, second in zip(paragraph, paragraph[1:]):
    tokens_a, tokens_b, is_next = _get_next_sentence(first, second, paragraphs)
    # +3 accounts for one <cls> and two <sep> tokens.
    if len(tokens_a) + len(tokens_b) + 3 > max_len:
      continue
    # Use this notebook's own helper so the special-token layout is defined
    # in a single place rather than split between the notebook and d2l.
    tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
    nsp_data_from_paragraph.append((tokens, segments, is_next))
  return nsp_data_from_paragraph

def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab):
  #something like [token1, token2, ...]
  mlm_input_tokens = [token for token in tokens]
  pred_positions_and_labels = []

  #shuffle the index of to be predicted token
  random.shuffle(candidate_pred_positions)

  for mlm_pred_position in candidate_pred_positions:
    #if pred positions must not exceed num_mlm_preds 
    if len(pred_positions_and_labels) >= num_mlm_preds:
      break
    
    masked_token = None

    #80% of the time, replace tokens with <mask>
    if random.random() > 0.8:
      masked_token = '<mask>'
    else:
      # because this falls in 20% (0.2), 
      #10% (or 0.5 of 20%) keep the word unchange
      if random.random() < 0.5:
        masked_token = tokens[mlm_pred_position]
      # the remaining 10% (other 0.5 of that that 20%), pick a random word from vocab
      else:
        masked_token = random.choice(vocab.idx_to_token)
      
    #change that mlm_pred_position to masked token
    mlm_input_tokens[mlm_pred_position] = masked_token
    #put the ground truth at that position
    pred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))

  return mlm_input_tokens, pred_positions_and_labels

def _get_mlm_data_from_tokens(tokens, vocab):
  """Create masked-language-model inputs and targets for one example.

  Every position except '<cls>' and '<sep>' is a masking candidate. The
  number of predictions is 15% of the sequence length, rounded, with a
  minimum of one.

  Returns:
    (masked input token ids, prediction positions sorted ascending,
     label token ids aligned with those positions)
  """
  special_tokens = ('<cls>', '<sep>')
  candidates = [idx for idx, tok in enumerate(tokens) if tok not in special_tokens]
  n_preds = max(1, round(len(tokens) * 0.15))

  masked_tokens, picks = _replace_mlm_tokens(tokens, candidates, n_preds, vocab)
  # Order predictions by position so positions and labels line up left to right.
  picks.sort(key=lambda pair: pair[0])

  positions = [pos for pos, _ in picks]
  labels = [label for _, label in picks]
  return vocab[masked_tokens], positions, vocab[labels]

In [4]:
#Padding the BERT input
#return: 
#   all_token_ids,
#   all_segments, 
#   valid_lens, 
#   all_pred_positions, 
#   all_mlm_weights, 
#   all_mlm_labels, 
#   nsp_labels

def _pad_bert_inputs(examples, max_len, vocab):
  max_num_mlm_preds = round(len(examples) * 0.15)
  all_token_ids, all_segments, valid_lens = [], [], []  
  all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
  nsp_labels = []

  for token_ids, pred_positions, mlm_pred_label_ids, segments, is_next in examples:
    #pad token ids with vocab['<pad>']
    all_token_ids.append(
        torch.tensor(token_ids + [vocab['<pad>']] * (max_len - len(token_ids)), dtype=torch.long)
    )
    #pad segments with 0s
    all_segments.append(
        torch.tensor(segments + [0] * (max_len - len(segments)), dtype=torch.long)
    )
    valid_lens.append(torch.tensor(len(token_ids), dtype=torch.long))
    #pad pad_predictions with 0s
    all_pred_positions.append(
      torch.tensor([pred_positions + [0] * (max_num_mlm_preds - len(pred_positions))], dtype=torch.long)
    )
    # use 1.0 to the length of mlm_pred_label_ids, the rest uses 0s, max len is max_num_mlm_nums
    all_mlm_weights.append(
      torch.tensor([1.0] * len(mlm_pred_label_ids) + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.float32)
    )
    #pad mlm_labels with 0s
    all_mlm_labels.append(
      torch.tensor(mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long)
    )
    #nsp labels 
    nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
  
  return (all_token_ids, all_segments, valid_lens, all_pred_positions, all_mlm_weights, all_mlm_labels, nsp_labels)


In [5]:
class _WikiTextDataset(torch.utils.data.Dataset):
  def __init__(self, paragraphs, max_len, debug = False):
    #d2l.tokenize seperates each word in each paragraph --> [word] for each paragraph
    #format[[para1], [para2], ...]
    paragraphs = [d2l.tokenize(p, token = 'word') for p in paragraphs]
    #sentences get all the words in paragraphs
    #     ['word1','word2', ...]
    sentences = [sentence for p in paragraphs for sentence in p]

    self.vocab = d2l.Vocab(sentences, min_freq = 5, reserved_tokens=['<cls>', '<sep>', '<mask>', '<pad>'])

    examples = []
    if debug:
      print()
      print('len vocab =', len(self.vocab))
      print('getting nsp data from paragraphs')

    for index, p in enumerate(paragraphs):
      #get pretraining for next sentence prediction
      # output of _get_nsp_data_from_paragraph: [(tokens, segments, is_next), ...]
      # reminder that tokens = <cls> token a + <sep> + tokenb + <sep>
      examples.extend(_get_nsp_data_from_paragraph(p, paragraphs, self.vocab, max_len))
  
    #get data for masked language modeling
    # combine tuples in python (a,b,c) + (d,e) = (a,b,c,d,e)
    examples = [(_get_mlm_data_from_tokens(tokens, self.vocab) + (segments, is_next))\
                for tokens, segments, is_next in examples]

    if debug:
      print('sample of example')
      print('masked tokens ', examples[0][0])
      print('pred positions ', examples[0][1])
      print('mlm pred labels', examples[0][2])
      print('segments ', examples[0][3])
      print('is_next ', examples[0][4])
    
    #pad input
    (self.all_token_ids, 
      self.all_segments, 
      self.valid_lens, 
      self.all_pred_positions, 
      self.mlm_weights, 
      self.all_mlm_labels, 
      self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)

  def __getitem__(self, i):
    #getter
    return 
    (self.all_token_ids[i], 
      self.all_segments[i], 
      self.valid_lens[i], 
      self.all_pred_positions[i], 
      self.mlm_weights[i], 
      self.all_mlm_labels[i], 
      self.nsp_labels[i])
  
  def __len__(self):
    return len(self.all_token_ids)

# Quick sanity check of d2l.tokenize on a few WikiText sentences: each input
# string becomes one list of word tokens.
p = ["in 1881 the observatory 's director , charles <unk> , suggested adding a high @-@ quality telescope to the observatory",
     'he felt that direct solar observations would lead to a better understanding of sunspot effects on weather ( as late as 1910 the observatory \'s then @-@ director , r. f. <unk> , noted that " sun spots have more to do with our weather conditions than have the rings around the moon',
     '" )',
     '<unk> , the canadian government ( having formed in 1867 ) was interested in taking part in the major international effort to accurately record the december 1882 transit of venus .']

sample1 = d2l.tokenize(p, token='word')
print(sample1)
# Tokens of the first sentence only.
print(list(sample1[0]))

[['in', '1881', 'the', 'observatory', "'s", 'director', ',', 'charles', '<unk>', ',', 'suggested', 'adding', 'a', 'high', '@-@', 'quality', 'telescope', 'to', 'the', 'observatory'], ['he', 'felt', 'that', 'direct', 'solar', 'observations', 'would', 'lead', 'to', 'a', 'better', 'understanding', 'of', 'sunspot', 'effects', 'on', 'weather', '(', 'as', 'late', 'as', '1910', 'the', 'observatory', "'s", 'then', '@-@', 'director', ',', 'r.', 'f.', '<unk>', ',', 'noted', 'that', '"', 'sun', 'spots', 'have', 'more', 'to', 'do', 'with', 'our', 'weather', 'conditions', 'than', 'have', 'the', 'rings', 'around', 'the', 'moon'], ['"', ')'], ['<unk>', ',', 'the', 'canadian', 'government', '(', 'having', 'formed', 'in', '1867', ')', 'was', 'interested', 'in', 'taking', 'part', 'in', 'the', 'major', 'international', 'effort', 'to', 'accurately', 'record', 'the', 'december', '1882', 'transit', 'of', 'venus', '.']]
['in', '1881', 'the', 'observatory', "'s", 'director', ',', 'charles', '<unk>', ',', 'sugg

In [None]:
#@title Load pretraining data for BERT
# Colab form parameters: examples per minibatch, maximum BERT input length
# (including the <cls>/<sep> special tokens), and whether to print debug output.
batch_size = 512 #@param {type:"number"}
max_len = 64 #@param {type:"number"}
debug = True #@param {type:"boolean"}

# Download WikiText-2 and wrap it in a DataLoader of BERT pretraining examples.
def load_data_wiki(batch_size, max_len, debug = False):
  """Load WikiText-2 as padded BERT pretraining examples.

  Args:
    batch_size: minibatch size for the returned DataLoader.
    max_len: maximum BERT input length.
    debug: if True, print diagnostics while reading and building the dataset.

  Returns:
    (train_set, train_iter, vocab): the dataset, a shuffling DataLoader over
    it, and the dataset's vocabulary.
  """
  num_workers = d2l.get_dataloader_workers()
  data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')

  paragraphs = _read_wiki(data_dir, debug)
  # Forward `debug` so the dataset's diagnostics honour the caller's flag
  # (it was previously accepted here but never passed on).
  train_set = _WikiTextDataset(paragraphs, max_len, debug)

  train_iter = torch.utils.data.DataLoader(
      train_set, batch_size, shuffle=True, num_workers=num_workers)

  return train_set, train_iter, train_set.vocab

# Build the dataset, then show one padded example and the vocabulary size.
train_set, train_iter, vocab = load_data_wiki(
    batch_size=batch_size, max_len=max_len, debug=debug)
print(train_set[0])
print(len(vocab))



total paragraphs  15496
sample paragraph
['wang \'s tale portrays zhou as an aging itinerant <unk> with " a fame <unk> like thunder " throughout the underworld society of <unk>', 'he is made the sworn brother of the outlaw " <unk> <unk> " lu zhishen , a military officer @-@ turned @-@ fighting monk , who is , according to hsia , first among the most popular protagonists of the water margin', 'he is also given the nickname " iron arm " ( <unk> ) , which carried over into the title of his fictional biography iron arm , golden sabre', 'while the tale fails to explain the reason for the moniker , it does mention zhou \'s ability to direct his <unk> to any part of his body to make it hard enough to <unk> the " iron shirt " technique of another martial artist', 'furthermore , zhou shares the same nickname with cai fu , an executioner @-@ turned @-@ outlaw known for his ease in wielding a heavy sword .']
["silver bullet 's layout passes through three of the park 's themed areas : ghost town ,

In [7]:
# Fetch and display the first (shuffled) minibatch from the DataLoader.
first_batch = next(iter(train_iter))
print(first_batch)


TypeError: ignored