<a href="https://colab.research.google.com/github/bkim9/Resume/blob/main/15_9_The_Dataset_for_Pretraining_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
# Pinned package installs for this notebook (shell commands run through IPython's `!`).
!pip install d2l==1.0.3
# NOTE(review): the cells visible here only import `d2l.torch`; confirm that the
# MXNet CUDA build is actually required.
!pip install -U mxnet-cu112==1.9.1



In [22]:
import os
import random
import torch
from d2l import torch as d2l

In [270]:
# Register the WikiText-2 archive in d2l's DATA_HUB under the key 'wikitext-2'
# as a (URL, checksum) pair.
# NOTE(review): the 40-hex-digit string is presumably the SHA-1 that d2l checks
# the download against -- confirm.
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

# nltk _read_wiki
import nltk
nltk.download('punkt')

def _read_wiki_nltk(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [nltk.tokenize.sent_tokenize(line.strip().lower())
                  for line in lines if len(nltk.tokenize.sent_tokenize(line.strip())) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [271]:
def _get_next_sentence(sentence, next_sentence, paragraphs):
  # if random.random() <  .5:
    if random.random() < 0.5:
        is_next = True
    else:
        # `paragraphs` is a list of lists of lists
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

In [272]:
# random.choice(random.choice([[1,4,7],[2,5,8],[3,6,9]]))

In [273]:
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    nsp_data_from_paragraph = []
    #             (      7        - 1)
    for i in range(len(paragraph) - 1):
        tokens_a, tokens_b, is_next = _get_next_sentence(
            paragraph[i],
            paragraph[i + 1],
            paragraphs)

        # Consider 1 '<cls>' token and 2 '<sep>' tokens
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))
    return nsp_data_from_paragraph

In [274]:
# d2l.get_tokens_and_segments(paragraph[0], paragraph[1])

In [64]:
# Peek at the second sentence of `paragraph`.
# NOTE(review): in the cells visible here `paragraph` is only assigned further
# down (`paragraph = paragraphs[0]`), so this fails on a fresh top-to-bottom run.
paragraph[1]

['enraged',
 ',',
 'veeru',
 'attacks',
 'gabbar',
 "'s",
 'den',
 'and',
 'catches',
 'the',
 'dacoit']

In [275]:
# Try the NSP pair sampler on the first two sentences of `paragraph`.
# NOTE(review): `paragraph` and `paragraphs` are only assigned in later cells of
# this notebook as shown here; confirm the intended cell order.
_get_next_sentence(
    paragraph[0], paragraph[1], paragraphs)


(['veeru', 'returns', ',', 'and', 'jai', 'dies', 'in', 'his', 'arms'],
 ['the',
  'pdr',
  'has',
  'been',
  'constructed',
  'in',
  'separate',
  'link',
  'roads',
  'of',
  'between',
  '1',
  '@.@',
  '61',
  'km',
  '(',
  '1',
  '@.@',
  '00',
  'mi',
  ')',
  'and',
  '5',
  '@.@',
  '47',
  'km',
  '(',
  '3',
  '@.@',
  '40',
  'mi',
  ')',
  'around',
  'cardiff',
  'and',
  'to',
  'date',
  '22',
  'kilometres',
  '(',
  '14',
  'mi',
  ')',
  'including',
  'spurs',
  'have',
  'been',
  'opened',
  'to',
  'traffic',
  ',',
  'with',
  'plans',
  'for',
  'a',
  'further',
  '5',
  '@.@',
  '53',
  'km',
  '(',
  '3',
  '@.@',
  '44',
  'mi',
  ')'],
 False)

In [276]:
# Scratch setup: take the first paragraph and build a vocabulary over
# `sentences`, dropping tokens seen fewer than 5 times and reserving the four
# special tokens.
# NOTE(review): `paragraphs` and `sentences` are assigned in a cell further
# down in the notebook as shown here -- confirm the intended cell order.
paragraph = paragraphs[0]
vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
                    '<pad>', '<mask>', '<cls>', '<sep>'])


In [277]:
# range(len(paragraph)-1)

In [279]:
# _get_nsp_data_from_paragraph(
#                 paragraphs[0],
#                 paragraphs,
#                 d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
#                     '<pad>', '<mask>', '<cls>', '<sep>']),
#                 max_len)

In [280]:
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                        vocab):
    # For the input of a masked language model, make a new copy of tokens and
    # replace some of them by '<mask>' or random tokens
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    # Shuffle for getting 15% random tokens for prediction in the masked
    # language modeling task
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        # 80% of the time: replace the word with the '<mask>' token
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            # 10% of the time: keep the word unchanged
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            # 10% of the time: replace the word with a random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels

In [281]:
# _get_mlm_data_from_tokens  (examples[0][0], vocab)
def _get_mlm_data_from_tokens(tokens, vocab):
    """Produce masked-language-model inputs and labels for one token list.

    Args:
        tokens: List of token strings.
        vocab: Object that is indexed with a list of tokens to obtain ids and
            is passed on to `_replace_mlm_tokens`.

    Returns:
        ``(input_ids, pred_positions, label_ids)``: the ids of the corrupted
        tokens, the selected positions in ascending order, and the ids of the
        original tokens at those positions.
    """
    # '<cls>' and '<sep>' are never predicted in the masked language modeling
    # task, so they are excluded from the candidates.
    special_tokens = ('<cls>', '<sep>')
    candidate_pred_positions = [
        position for position, token in enumerate(tokens)
        if token not in special_tokens]
    # 15% of the tokens are predicted, but always at least one.
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    # Order the selected (position, label) pairs by position.
    pred_positions_and_labels.sort(key=lambda pair: pair[0])
    pred_positions = [position for position, _ in pred_positions_and_labels]
    mlm_pred_labels = [label for _, label in pred_positions_and_labels]
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]

In [283]:
def _pad_bert_inputs(examples, max_len, vocab):
    max_num_mlm_preds = round(max_len * 0.15)
    all_token_ids, all_segments, valid_lens,  = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments,
         is_next) in examples:
        all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (
            max_len - len(token_ids)), dtype=torch.long))
        all_segments.append(torch.tensor(segments + [0] * (
            max_len - len(segments)), dtype=torch.long))
        # `valid_lens` excludes count of '<pad>' tokens
        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions + [0] * (
            max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
        # Predictions of padded tokens will be filtered out in the loss via
        # multiplication of 0 weights
        all_mlm_weights.append(
            torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (
                max_num_mlm_preds - len(pred_positions)),
                dtype=torch.float32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (
            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions,
            all_mlm_weights, all_mlm_labels, nsp_labels)

In [284]:

class _WikiTextDataset(torch.utils.data.Dataset):
#_WikiTextDataset('../data/wikitext-2',  64    )
    def __init__ (self, paragraphs    , max_len):
        # Input `paragraphs[i]` is a list of sentence strings representing a
        # paragraph; while output `paragraphs[i]` is a list of sentences
        # representing a paragraph, where each sentence is a list of tokens
        paragraphs = [d2l.tokenize(
            paragraph, token='word') for paragraph in paragraphs]
        sentences = [sentence        for paragraph in paragraphs
                 for sentence         in paragraph]
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])

        # Get data for nsp
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))

        # Get data for mlm
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                      + (segments, is_next))
                     for tokens, segments, is_next in examples]
        # Pad inputs
        (self.all_token_ids,      self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels,     self.nsp_labels) = _pad_bert_inputs(
            examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)

In [286]:
# examples[0][0]

In [287]:
# _get_mlm_data_from_tokens(examples[0][0], vocab)

In [288]:
#  examples = []
#  examples.extend(_get_nsp_data_from_paragraph(
#                 paragraphs[0],
#                 paragraphs,
#                 d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
#             '<pad>', '<mask>', '<cls>', '<sep>']),
#                 max_len))

In [50]:
# Read the corpus with the ' . ' splitter and tokenize every sentence into
# words; `paragraphs` ends up as a list of paragraphs of token lists.
paragraphs = [
    d2l.tokenize(raw_paragraph, token='word')
    for raw_paragraph in _read_wiki('../data/wikitext-2')
]
# Flatten the paragraphs into one list of tokenized sentences.
sentences = [tokens for tokenized in paragraphs for tokens in tokenized]

In [None]:
# Same pipeline as for the ' . ' splitter, but with NLTK sentence splitting:
# tokenize every sentence into words, paragraph by paragraph.
paragraphs_nltk = [
    d2l.tokenize(raw_paragraph, token='word')
    for raw_paragraph in _read_wiki_nltk('../data/wikitext-2')
]
# Flatten the paragraphs into one list of tokenized sentences.
sentences_nltk = [tokens for tokenized in paragraphs_nltk
                  for tokens in tokenized]

In [289]:
# paragraphs[0]

In [290]:
# d2l.Vocab(['Hello world nice to meet you.', 'Always!','www.'], min_freq=1, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])

In [291]:
# sentences[:3]

In [34]:
# Display the value returned by d2l's dataloader-worker helper (4 in the
# recorded output).
d2l.get_dataloader_workers()

4

In [35]:
# d2l.download_extract('wikitext-2', 'wikitext-2')

'../data/wikitext-2'

In [264]:
def _load_wiki_dataset(read_fn, batch_size, max_len):
    """Build a shuffled DataLoader over WikiText-2 using `read_fn` to read it.

    Args:
        read_fn: Callable taking the data directory and returning a list of
            paragraphs (each a list of sentence strings).
        batch_size: Minibatch size for the DataLoader.
        max_len: Passed to `_WikiTextDataset`.

    Returns:
        ``(train_iter, vocab)``: the DataLoader and the dataset's vocabulary.
    """
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    train_set = _WikiTextDataset(read_fn(data_dir), max_len)
    train_iter = torch.utils.data.DataLoader(
        train_set, batch_size, shuffle=True, num_workers=num_workers)
    return train_iter, train_set.vocab


def load_data_wiki(batch_size, max_len):
    """Load WikiText-2 with sentences split on ' . ' (see `_read_wiki`)."""
    return _load_wiki_dataset(_read_wiki, batch_size, max_len)


def load_data_wiki_nltk(batch_size, max_len):
    """Load WikiText-2 with sentences split by NLTK (see `_read_wiki_nltk`)."""
    return _load_wiki_dataset(_read_wiki_nltk, batch_size, max_len)


# These loaders were called throughout the notebook but never defined in it,
# so a fresh kernel raised NameError; the settings below are likewise assigned
# here so this cell no longer relies on state from a later cell.
batch_size, max_len = 512, 64
load_data_wiki(batch_size, max_len)
load_data_wiki_nltk(batch_size, max_len)




(<torch.utils.data.dataloader.DataLoader at 0x7d1a60998880>,
 <d2l.torch.Vocab at 0x7d1a6eeecaf0>)

In [31]:
# Demo settings, then build the training iterator and vocabulary.
batch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)

# Print the shape of every tensor in the first minibatch only, then stop.
for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) in train_iter:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    break



torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])


In [292]:
# Same demo settings as for the ' . ' splitter.
batch_size, max_len = 512, 64
# nltk
train_iter_nltk, vocab_nltk = load_data_wiki_nltk(batch_size, max_len)

# Print the shape of every tensor in the first minibatch only, then stop.
for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) in train_iter_nltk:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    break



torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])


In [293]:
# Size of `vocab` as last assigned by the `load_data_wiki` cell above
# (20256 in the recorded output).
len(vocab)

20256

In [294]:
# Size of the vocabulary returned by `load_data_wiki_nltk`
# (20074 in the recorded output).
len(vocab_nltk)

20074

15.9.4. Exercises
1. Try other sentence splitting techniques
using nltk.

  
  Answer: splitting sentences with NLTK's `sent_tokenize` instead of splitting on `' . '` decreases the vocabulary size from 20256 to 20074.

2. What is the vocabulary size if we do not filter out any infrequent token?

In [142]:
# NOTE(review): this install is unpinned and sits below cells that already
# `import nltk`; confirm whether it should be pinned and run earlier.
!pip install nltk



In [144]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [146]:
# Quick sanity check of NLTK's sentence splitter on a toy string.
# A dedicated name is used so the corpus-level `sentences` list built in
# another cell (and consumed by `d2l.Vocab(sentences, ...)`) is not replaced
# by a plain string.
demo_text = 'This is great ! Why not ?'
nltk.tokenize.sent_tokenize(demo_text)


['This is great !', 'Why not ?']

In [267]:
# Register the WikiText-2 archive in d2l's DATA_HUB as a (URL, checksum) pair.
# NOTE(review): this repeats the identical registration made in an earlier
# cell of this notebook.
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

# nltk _read_wiki
import nltk
nltk.download('punkt')

def _read_wiki_nltk(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [nltk.tokenize.sent_tokenize(line.strip().lower())
                  for line in lines if len(nltk.tokenize.sent_tokenize(line.strip())) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [268]:
# Read the corpus with the ' . ' splitter; the result is kept for manual
# inspection (see the commented-out slice below).
readWiki = _read_wiki('../data/wikitext-2')
# readWiki[:2]

In [269]:
 # Read the corpus with the NLTK splitter; the result is kept for manual
 # inspection (see the commented-out slice below).
 nltk_res = _read_wiki_nltk('../data/wikitext-2')
#  nltk_res[:2]

In [262]:
# Read the corpus with the NLTK splitter and tokenize every sentence into
# words, paragraph by paragraph.
nltk_paragraphs = [
    d2l.tokenize(raw_paragraph, token='word')
    for raw_paragraph in _read_wiki_nltk('../data/wikitext-2')
]
# Flatten the paragraphs into one list of tokenized sentences.
nltk_sentences = [tokens for tokenized in nltk_paragraphs
                  for tokens in tokenized]