In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from collections import OrderedDict
import zstandard as zstd
import io
import json

import torch
from torch.utils.data import DataLoader, Sampler
import numpy as np
import pandas as pd
import fairseq
from fairseq.data.encoders.gpt2_bpe import GPT2BPE, GPT2BPEConfig
from fairseq import options
from fairseq_cli.preprocess import main as preprocess
from fairseq.data import Dictionary, TokenBlockDataset, MonolingualDataset, PrependTokenDataset, MaskTokensDataset, RightPadDataset,  IdDataset, NestedDictionaryDataset, NumSamplesDataset, \
    NestedDictionaryDataset, NumelDataset, SortDataset
from fairseq.data.encoders import BPE_REGISTRY, register_bpe, build_bpe

## Preprocess Dataset OpenWebText

In [2]:
dataset_path = '/mnt/dl/fairseq/Masked_Language_Model/openwebtext/'

In [3]:
arr = np.load(os.path.join(dataset_path, 'owt0.npz'))

In [4]:
arr['arr_0'].shape

(23055709,)

In [5]:
arr['arr_0']

array([13749, 12409,   716, ...,  5239,    91,    29])

In [6]:
arr['arr_0'].astype(np.uint16)

array([13749, 12409,   716, ...,  5239,    91,    29], dtype=uint16)

In [7]:
os.path.join(dataset_path, 'encoder.json')

'/mnt/dl/fairseq/Masked_Language_Model/openwebtext/encoder.json'

In [8]:
gpt2_bpe_cfg = GPT2BPEConfig(gpt2_encoder_json=os.path.join(dataset_path, 'encoder.json'),
                            gpt2_vocab_bpe=os.path.join(dataset_path, 'vocab.bpe')
                    )

In [9]:
gpt2_bpe = GPT2BPE(gpt2_bpe_cfg)

In [10]:
gpt2_bpe.bpe.decode(arr['arr_0'][:100]).strip().split('\n')

['Historical amnesia is at once the most endearing and the most frustrating of American qualities. On the one hand, it means that -- F. Scott Fitzgerald notwithstanding -- there really are second acts in American lives. People can move somewhere else, reinvent themselves, start again.',
 '',
 "On the other hand, our inability to remember what our policy was last week, never mind last decade, drives outsiders crazy. We forget that we supported the dictator before we decided to destroy him. Then we can't"]

In [11]:
def write_train_file(dataset_size = 1000):
    rng = np.random.RandomState(seed=0)
    filename = os.path.join(dataset_path, '2020-01.jsonl.zst')
    data = []
    with open(filename, 'rb') as fh, open(os.path.join(dataset_path, 'train.raw.txt'), 'w', encoding='utf-8') as oh:
        dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
        stream_reader = dctx.stream_reader(fh)
        text_stream = io.TextIOWrapper(stream_reader, encoding='utf-8')
        for i, line in enumerate(text_stream):
            text = json.loads(line)['text'].strip()
            if not text.strip():
                continue
            if len(data) == dataset_size:
                break
            # if rng.uniform(0, 1, (1, ))[0] < 100000:
            data.append(text + '\n')
            oh.write(data[-1])
    
    return data
            

In [12]:
data = write_train_file(10000)

In [13]:
len(data)

10000

In [14]:
data[:10]

['Advertising Read more\n\nParis (AFP)\n\nMore than 16,000 desalination plants scattered across the globe produce far more toxic sludge than fresh water, according to a first global assessment of the sector\'s industrial waste, published Monday.\n\nFor every litre of fresh water extracted from the sea or brackish waterways, a litre-and-a-half of salty slurry, called brine, is dumped directly back into the ocean or the ground.\n\nThe super-salty substance is made even more toxic by the chemicals used in the desalination process, researchers reported in the journal Science of the Total Environment.\n\nCopper and chlorine, for example, are both commonly used.\n\nThe amount of brine produced worldwide every year -- more than 50 billion cubic metres -- is enough to cover the state of Florida, or England and Wales combined, in a 30-centimetre (one-foot) layer of salty slime, they calculated.\n\n"The world produces less desalinated water than brine," co-author Manzoor Qadir, a scientist at th

In [15]:
bpe_data = []
with open(os.path.join(dataset_path, 'train.bpe.txt'), 'w', encoding='utf-8') as oh:
    for line in data:
        bpe_data.append(gpt2_bpe.encode(line).strip())
        oh.write(bpe_data[-1] + '\n')

In [16]:
len(bpe_data)

10000

In [17]:
bpe_data[-1]

'21478 8981 504 351 257 4301 1700 783 423 257 10595 3108 329 13925 257 1597 355 257 649 1099 12850 8733 319 11524 329 4708 16625 13 198 198 9012 1432 13 10923 756 12652 7504 1539 360 12 25705 11 5495 2097 3941 2608 2154 938 614 11 340 3804 262 3611 10006 287 1737 11 10964 13 449 33 350 29574 6122 4488 340 287 2932 290 340 1718 1245 2365 13 352 13 198'

In [18]:
gpt2_bpe.decode(bpe_data[-1])

'Illinoisans with a criminal record now have a wider path for launching a business as a new law reduces restrictions on applying for professional licenses.\n\nState Rep. Lamont Robinson Jr., D-Chicago, introduced House Bill 2670 last year, it passed the General Assembly in May, Gov. JB Pritzker signed it in August and it took effect Jan. 1.\n'

In [19]:
parser = options.get_preprocessing_parser()

In [20]:
args = parser.parse_args([])

In [21]:
lm_dataset = os.path.join(dataset_path, 'roberta-data-bin')
os.makedirs(lm_dataset, exist_ok=True)

In [22]:
args.only_source = True
args.trainpref = os.path.join(dataset_path, 'train.bpe.txt')
args.destdir = lm_dataset
args.srcdict = os.path.join(dataset_path, 'dict.txt')

In [23]:
preprocess(args)

2023-06-25 09:35:35 | INFO | fairseq_cli.preprocess | Namespace(no_progress_bar=False, log_interval=100, log_format=None, log_file=None, aim_repo=None, aim_run_hash=None, tensorboard_logdir=None, wandb_project=None, azureml_logging=False, seed=1, cpu=False, tpu=False, bf16=False, memory_efficient_bf16=False, fp16=False, memory_efficient_fp16=False, fp16_no_flatten_grads=False, fp16_init_scale=128, fp16_scale_window=None, fp16_scale_tolerance=0.0, on_cpu_convert_precision=False, min_loss_scale=0.0001, threshold_loss_scale=None, amp=False, amp_batch_retries=2, amp_init_scale=128, amp_scale_window=None, user_dir=None, empty_cache_freq=0, all_gather_list_size=16384, model_parallel_size=1, quantization_config_path=None, profile=False, reset_logging=False, suppress_crashes=False, use_plasma_view=False, plasma_path='/tmp/plasma', criterion='cross_entropy', tokenizer=None, bpe=None, optimizer=None, lr_scheduler='fixed', scoring='bleu', task='translation', source_lang=None, target_lang=None, tr

In [24]:
dictionary = Dictionary.load(os.path.join(lm_dataset, 'dict.txt'))

In [25]:
dictionary.symbols[:10]

['<s>', '<pad>', '</s>', '<unk>', '13', '262', '11', '284', '290', '286']

In [26]:
dataset = fairseq.data.data_utils.load_indexed_dataset(os.path.join(lm_dataset, 'train'), dictionary, args.dataset_impl)

2023-06-25 09:36:38 | INFO | fairseq.data.data_utils | loaded 10,000 examples from: /mnt/dl/fairseq/Masked_Language_Model/openwebtext/roberta-data-bin/train


In [27]:
dataset.sizes

array([ 849,  751,  426, ...,  684, 3770,   79], dtype=int32)

In [28]:
len(dataset.sizes)

10000

In [29]:
def get_text(idx):
    return gpt2_bpe.decode(dictionary.string(idx))

In [30]:
get_text(dataset[8].unsqueeze(0))

'Rumsfeld Memos Won by NSArchive Play Key Role in “The Afghanistan Papers”: FRINFORMSUM 12/13/19\n\nRumsfeld Memos Play Key Role in “The Afghanistan Papers”\n\nDonald Rumsfeld’s “snowflakes” – memos that the former Secretary of Defense was as fond of sending subordinates as President Trump is of tweeting – play an important role in the Washington Post’s massive exposé on the Afghanistan war, The Afghanistan Papers. The series draws on both “lesson learned” interviews conducted by the Special Inspector General for Afghanistan Reconstruction, as well as Rumsfeld’s “snowflakes” that were obtained by the National Security Archive and provided to the Post (both the interviews and the snowflakes were obtained through FOIA lawsuits).\n\nSeveral of the snowflake highlights include:\n\nAn April 17, 2002 snowflake, Subject: Afghanistan, in which Rumsfeld states “We are never going to get the U.S. military out of Afghanistan unless we take care to see there is something going on that will provide

In [31]:
gpt2_bpe.decode(bpe_data[8])

'Rumsfeld Memos Won by NSArchive Play Key Role in “The Afghanistan Papers”: FRINFORMSUM 12/13/19\n\nRumsfeld Memos Play Key Role in “The Afghanistan Papers”\n\nDonald Rumsfeld’s “snowflakes” – memos that the former Secretary of Defense was as fond of sending subordinates as President Trump is of tweeting – play an important role in the Washington Post’s massive exposé on the Afghanistan war, The Afghanistan Papers. The series draws on both “lesson learned” interviews conducted by the Special Inspector General for Afghanistan Reconstruction, as well as Rumsfeld’s “snowflakes” that were obtained by the National Security Archive and provided to the Post (both the interviews and the snowflakes were obtained through FOIA lawsuits).\n\nSeveral of the snowflake highlights include:\n\nAn April 17, 2002 snowflake, Subject: Afghanistan, in which Rumsfeld states “We are never going to get the U.S. military out of Afghanistan unless we take care to see there is something going on that will provide

In [32]:
tokens_per_sample = 512
max_tokens = 2048
shorten_method = "none"
shorten_data_split_list = ""
seed = 0
split = "train"
sample_break_mode = "complete"

# split_path = os.path.join(os.path.join(lm_dataset, 'train'))

In [33]:
dataset = fairseq.data.shorten_dataset.maybe_shorten_dataset(
            dataset,
            split,
            shorten_data_split_list,
            shorten_method,
            tokens_per_sample,
            seed,
        )

In [34]:
dataset

<fairseq.data.indexed_dataset.MMapIndexedDataset at 0x7f4e5f6036a0>

In [35]:
import copy
_dataset = copy.deepcopy(dataset)

In [36]:
dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            tokens_per_sample - 1,  # one less for <s>
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            break_mode=sample_break_mode,
        )

In [37]:
len(dataset)

9381

In [38]:
(dataset[0] == _dataset[0]).all()

tensor(True)

In [39]:
dataset = PrependTokenDataset(dataset, dictionary.bos())

In [40]:
dataset[0][:10]

tensor([    0,  9167, 45781,  1163,    55, 50118, 50118, 32826,    36, 11528])

In [41]:
_dataset[0][:10]

tensor([ 9167, 45781,  1163,    55, 50118, 50118, 32826,    36, 11528,    43])

In [42]:
_dataset[0].size(), dataset[0].size()

(torch.Size([849]), torch.Size([850]))

In [43]:
gpt2_bpe_cfg.bpe = 'gpt2'

In [44]:
bpe = fairseq.data.encoders.build_bpe(gpt2_bpe_cfg)

In [45]:
bpe

<fairseq.data.encoders.gpt2_bpe.GPT2BPE at 0x7f4f8c4cff10>

In [46]:
mask_whole_words = False

In [47]:
mask_whole_words = fairseq.data.encoders.utils.get_whole_word_mask(gpt2_bpe_cfg, dictionary)
no_mask_whole_words = None

In [48]:
mask_whole_words.size()

torch.Size([50264])

In [49]:
mask_symbol = "<mask>"

In [50]:
mask_symbol in dictionary.indices

False

In [51]:
dictionary.add_symbol(mask_symbol)

50264

In [52]:
mask_symbol in dictionary.indices

True

In [53]:
dictionary.index(mask_symbol)

50264

In [54]:
mask_whole_words

tensor([1, 1, 1,  ..., 1, 1, 1], dtype=torch.uint8)

In [55]:
torch.where(mask_whole_words != 1), torch.where(mask_whole_words != 1)[0].size()

((tensor([    4,     6,    12,  ..., 50256, 50258, 50260]),),
 torch.Size([17122]))

In [56]:
# [i.item() for i in torch.where(mask_whole_words != 1)[0]]

[bpe.decode(dictionary[i]) for i in torch.where(mask_whole_words != 1)[0][:]]

['.',
 ',',
 '-',
 '�',
 "'s",
 '�',
 's',
 ':',
 ')',
 '�',
 '�',
 ',"',
 '."',
 '/',
 "'t",
 't',
 'I',
 'a',
 'S',
 'ë',
 "'",
 '"',
 '?',
 'i',
 'm',
 ';',
 'The',
 '1',
 'o',
 '000',
 'ing',
 'We',
 'com',
 '2',
 'in',
 'year',
 'ed',
 '%',
 'th',
 "'re",
 'y',
 'en',
 '),',
 're',
 'e',
 'It',
 '5',
 '3',
 'A',
 'er',
 'u',
 'an',
 'on',
 'j',
 'ers',
 'ar',
 'old',
 'as',
 'n',
 '0',
 'es',
 'h',
 '4',
 ').',
 'ie',
 '!',
 'z',
 'k',
 'al',
 'r',
 'C',
 "'ve",
 'ly',
 'is',
 'os',
 'or',
 'B',
 'man',
 '8',
 '6',
 'it',
 '7',
 'at',
 'd',
 'am',
 'b',
 "'m",
 'c',
 'M',
 'le',
 'l',
 'and',
 '9',
 '.,',
 'N',
 'ia',
 'D',
 'R',
 'f',
 'P',
 'el',
 'K',
 'G',
 '30',
 've',
 'to',
 'T',
 'g',
 'L',
 '—',
 "'ll",
 'AP',
 'et',
 'F',
 'w',
 'ley',
 'ch',
 '00',
 'st',
 'ad',
 'the',
 'ic',
 '://',
 'p',
 '�',
 'up',
 'ist',
 'O',
 'ak',
 'us',
 '10',
 'he',
 'v',
 'ur',
 '�',
 'This',
 'E',
 'il',
 'H',
 '...',
 ']',
 'im',
 'ra',
 'W',
 'um',
 'U',
 'based',
 'id',
 'ian',
 'ine',


In [57]:
mask_whole_words[torch.where(mask_whole_words != 1)]

tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8)

In [58]:
mask_idx = dictionary.index(mask_symbol)
seed = 1
mask_prob = 0.15 # probability of replacing a token with mask
leave_unmasked_prob = 0.1 # probability that a masked token is unmasked
random_token_prob = 0.1 # probability of replacing a token with a random token
freq_weighted_replacement = False
mask_multiple_length = 1
mask_stdev = 0.0
skip_masking = False

In [59]:
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            dictionary,
            pad_idx=dictionary.pad(),
            mask_idx= mask_idx,
            seed=seed,
            mask_prob=mask_prob,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
            mask_multiple_length=mask_multiple_length,
            mask_stdev=mask_stdev,
        )

In [60]:
src_dataset

<fairseq.data.lru_cache_dataset.LRUCacheDataset at 0x7f4e5f5f9280>

In [61]:
tgt_dataset

<fairseq.data.lru_cache_dataset.LRUCacheDataset at 0x7f4e6fc8b280>

In [62]:
src_dataset[0].size()

torch.Size([850])

In [63]:
tgt_dataset[0].size()

torch.Size([850])

In [64]:
torch.argwhere(dataset[0] == mask_idx)

tensor([], size=(0, 1), dtype=torch.int64)

In [65]:
torch.argwhere(src_dataset[0] == mask_idx).squeeze()

tensor([  4,   5,   6,   7,  38,  46,  50,  51,  52,  53,  54,  59,  95, 110,
        131, 153, 161, 167, 170, 176, 177, 181, 201, 208, 216, 217, 218, 219,
        220, 221, 243, 246, 247, 281, 282, 283, 286, 295, 296, 297, 298, 299,
        300, 302, 303, 306, 311, 314, 315, 328, 330, 339, 347, 383, 438, 443,
        445, 467, 484, 495, 496, 522, 523, 531, 540, 542, 545, 546, 557, 562,
        576, 582, 590, 603, 611, 619, 636, 644, 645, 646, 664, 677, 678, 682,
        688, 692, 699, 702, 703, 704, 705, 706, 724, 725, 741, 759, 788, 793,
        796, 826])

In [66]:
def mask_fn(index):
    np.random.seed(seed)
    item = dataset[index]
    sz = len(item)
    assert (
        mask_idx not in item
    ), "Dataset contains mask_idx (={}), this is not expected!".format(
        mask_idx,
    )

    if mask_whole_words is not None:
        word_begins_mask = mask_whole_words.gather(0, item)
        word_begins_idx = word_begins_mask.nonzero().view(-1)
        sz = len(word_begins_idx)
        words = np.split(word_begins_mask, word_begins_idx)[1:]
        assert len(words) == sz
        word_lens = list(map(len, words))

    # decide elements to mask
    mask = np.full(sz, False)
    num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * sz / float(mask_multiple_length)
        + np.random.rand()
    )

    # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
    mask_idc = np.random.choice(sz, num_mask, replace=False)
    if mask_stdev > 0.0:
        lengths = np.random.normal(
            mask_multiple_length, mask_stdev, size=num_mask
        )
        lengths = [max(0, int(round(x))) for x in lengths]
        mask_idc = np.asarray(
            [
                mask_idc[j] + offset
                for j in range(len(mask_idc))
                for offset in range(lengths[j])
            ],
            dtype=np.int64,
        )
    else:
        mask_idc = np.concatenate(
            [mask_idc + i for i in range(mask_multiple_length)]
        )
    mask_idc = mask_idc[mask_idc < len(mask)]
    try:
        mask[mask_idc] = True
    except:  # something wrong
        print(
            "Assigning mask indexes {} to mask {} failed!".format(
                mask_idc, mask
            )
        )
        raise
    # decide unmasking and random replacement
    rand_or_unmask_prob = random_token_prob + leave_unmasked_prob
    if rand_or_unmask_prob > 0.0:
        rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
        if random_token_prob == 0.0:
            unmask = rand_or_unmask
            rand_mask = None
        elif leave_unmasked_prob == 0.0:
            unmask = None
            rand_mask = rand_or_unmask
        else:
            unmask_prob = leave_unmasked_prob / rand_or_unmask_prob
            decision = np.random.rand(sz) < unmask_prob
            unmask = rand_or_unmask & decision
            rand_mask = rand_or_unmask & (~decision)
    else:
        unmask = rand_mask = None

    if unmask is not None:
        mask = mask ^ unmask
    print(mask.shape, len(word_lens))
    if mask_whole_words is not None:
        mask = np.repeat(mask, word_lens)
    new_item = np.copy(item)
    new_item[mask] = mask_idx
    if rand_mask is not None:
        if freq_weighted_replacement:
            weights = np.array(dictionary.count)
        else:
            weights = np.ones(len(dictionary))
            
        weights[:dictionary.nspecial] = 0
        weights = weights / weights.sum()
        num_rand = rand_mask.sum()
        if num_rand > 0:
            if mask_whole_words is not None:
                rand_mask = np.repeat(rand_mask, word_lens)
                num_rand = rand_mask.sum()

            new_item[rand_mask] = np.random.choice(
                len(dictionary),
                num_rand,
                p=weights,
            )

    return torch.from_numpy(new_item)

In [67]:
mask_fn(0)

(579,) 579


tensor([50264, 50264, 50264,  1163,    55, 50118, 50118, 32826, 50264, 50264,
        50264, 50264, 50264, 50264,    87,   545,     6,   151,  2694,   337,
         8111,  3451, 12827, 41649,     5,  7183,  2592,   444,    55,  8422,
         3369, 14328, 50264,  2310,   514,     6,   309,     7,    10,    78,
          720,  4990,     9,     5,  1293,    18, 50264,  3844,     6,  1027,
          302,     4, 50118, 50118,  2709,   358,  6474,   241,     9,  2310,
          514, 27380,    31, 50264, 50264,    50,  5378,  2990,  1173, 24416,
            6,    10,  6474,   241,    12,   463,    12,   102,    12,  4809,
        50264, 31924, 50264, 50264, 50264,   373,  5378,   833,     6,    16,
        12961,  2024,   124,    88,     5,  6444,    50,     5,  1255,     4,
        50118, 50118,   133,  2422,    12,    29, 12107,  6572,    16, 50264,
        50264,    55,  8422, 50264,     5,  8321,   341,    11,     5,  2694,
          337,  8111,   609,     6,  2634,   431,    11,     5, 

In [68]:
include_target_tokens = False

In [69]:
target_dataset = RightPadDataset(
            tgt_dataset,
            pad_idx=dictionary.pad(),
        )

input_dict = {
    "src_tokens": RightPadDataset(
        src_dataset,
        pad_idx=dictionary.pad(),
    ),
    "src_lengths": NumelDataset(src_dataset, reduce=False),
}
if include_target_tokens:
    input_dict["target_tokens"] = target_dataset
np.random.seed(seed)
shuffle = np.random.permutation(len(src_dataset))
dataset = SortDataset(
    NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": input_dict,
            "target": target_dataset,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    ),
    sort_order=[
        shuffle,
        src_dataset.sizes,
    ],
)


In [70]:
dataset.sizes

[array([ 850,  752,  465, ...,  685, 3771,   80], dtype=uint16)]

In [71]:
dataset[1]

OrderedDict([('id', 1),
             ('net_input.src_tokens',
              tensor([    0,  5320,  1851,  1672,  8114, 27773,   154, 50264, 50264, 50264,
                      50264, 50264, 50264, 50264, 27773,   154,   208,  7396, 50118, 50118,
                        246, 10468, 23486,  3689, 11397,   654,  5471, 23486,  3689, 16875,
                         73, 30131, 50264, 23486,  3689, 40924,   820,  5471, 23486,  3689,
                      16875,    73, 30131,   112, 10468, 50264,  3689, 40924,   564,  5471,
                      23486,  3689, 50264, 50264, 50264,   132, 10468, 23486,  3689, 11397,
                        158,  5471, 23486,  3689, 16875,    73, 30131, 50264, 10468, 23486,
                       3689, 11397,   654,  5471, 23486,  3689, 16875,    73, 30131,   112,
                      10468, 23486,  3689, 11397,   971,  5471, 23486,  3689, 16875,    73,
                      30131,   204, 23486,  3689, 40924,  2107,  5471, 23486, 50264, 16875,
                  

In [72]:
dictionary.pad()

1

In [73]:
_dataset[0]

tensor([ 9167, 45781,  1163,    55, 50118, 50118, 32826,    36, 11528,    43,
        50118, 50118,  9690,    87,   545,     6,   151,  2694,   337,  8111,
         3451, 12827,   420,     5,  7183,  2592,   444,    55,  8422,  3369,
        14328,    87,  2310,   514,     6,   309,     7,    10,    78,   720,
         4990,     9,     5,  1293,    18,  2683,  3844,     6,  1027,   302,
            4, 50118, 50118,  2709,   358,  6474,   241,     9,  2310,   514,
        27380,    31,     5,  3342,    50,  5378,  2990,  1173, 24416,     6,
           10,  6474,   241,    12,   463,    12,   102,    12,  4809,     9,
        31924,  3369, 27358,     6,   373,  5378,   833,     6,    16, 12961,
         2024,   124,    88,     5,  6444,    50,     5,  1255,     4, 50118,
        50118,   133,  2422,    12,    29, 12107,  6572,    16,   156,   190,
           55,  8422,    30,     5,  8321,   341,    11,     5,  2694,   337,
         8111,   609,     6,  2634,   431,    11,     5,  8812, 

In [74]:
shuffle

array([7460, 9116, 3627, ...,  905, 5192,  235])

In [75]:
indices = np.lexsort([shuffle, dataset.sizes[0]])

In [76]:
indices

array([6058, 5953, 3843, ..., 7316, 7439, 6659])

In [77]:
len(indices)

9381

In [78]:
sizes = dataset.sizes[0]

In [79]:
sizes[indices]

array([   10,    12,    13, ..., 26119, 32135, 39828], dtype=uint16)

In [80]:
indices = indices[sizes[indices] <= tokens_per_sample]

In [81]:
indices

array([6058, 5953, 3843, ..., 3656, 3521, 1275])

In [82]:
len(indices)

3294

In [83]:
sizes[indices]

array([ 10,  12,  13, ..., 512, 512, 512], dtype=uint16)

In [84]:
def batchify(dataset, 
            indices, 
            max_tokens=-1, 
            max_sentences=8, 
            bsz_mult=8):
    num_tokens_vec = np.array([dataset.num_tokens(idx) for idx in indices]) # The  number of max tokens in the batch position
    indices_len = indices.shape[0]
    
    batches_ends = np.zeros(indices_len, dtype=np.int32)
    batches_ends_view = batches_ends[:]
    num_tokens_view = num_tokens_vec

    pos = 0
    new_batch_end = 0

    new_batch_max_tokens = 0
    new_batch_sentences = 0
    new_batch_num_tokens = 0

    overflow = False
    size_matches_with_bsz_mult = False

    batches_count = 0
    batch_start = 0
    tail_max_tokens = 0
    batch_max_tokens = 0

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch [batch_start:batch_end),
        #      and tail [batch_end:pos].
        # 1) Every time when (batch + tail) forms a valid batch
        #      (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
        # 2) When (batch+tail) violates max_tokens or max_sentences constraints
        #      we finalize running batch, and tail becomes a new batch.
        # 3) There is a corner case when tail also violates constraints.
        #      In that situation [batch_end:pos-1] (tail without the current pos)
        #      gets added to the finalized batches, while [pos:pos] becomes a new tail.
        #
        # Important: For the sake of performance try to avoid using function calls within this loop.

        tail_max_tokens = tail_max_tokens \
                            if tail_max_tokens > num_tokens_view[pos] \
                            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
                                if batch_max_tokens > tail_max_tokens \
                                else tail_max_tokens
        new_batch_sentences = new_batch_end - batch_start
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                    (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0
            # In case of a tail overflow finalize two batches
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])

In [85]:
batch_size = 32

In [86]:
batch_indices = batchify(dataset, indices, max_sentences=batch_size, bsz_mult=batch_size)

In [87]:
indices[:20]

array([6058, 5953, 3843, 4885, 4548, 2248, 4967,  144, 3506, 6045, 6165,
       9239,  205, 3108, 7472, 9243, 2159, 3307, 2025,  971])

In [88]:
batch_indices[:4]

[array([6058, 5953, 3843, 4885, 4548, 2248, 4967,  144, 3506, 6045, 6165,
        9239,  205, 3108, 7472, 9243, 2159, 3307, 2025,  971, 7130, 1777,
        2164, 6368, 8819, 1480, 1759, 2192, 1432, 4939, 2373, 2988]),
 array([8687, 4821, 3761, 1143, 1749, 3224, 3662, 1382, 7798, 5165, 4091,
        2087, 6222, 5552, 8658, 3092,  496, 1615, 5353, 2148, 7442, 8023,
        5097, 1473,  667, 1096, 1744, 8065, 7626, 1502, 6701, 7783]),
 array([ 954, 7094,  748, 6795, 1038, 2279, 3051, 6140, 3014, 4779, 1427,
         285, 2311, 5875, 8725, 6338, 4903, 3898, 7144, 8858, 4184, 2047,
        9360, 6620, 4789, 8927, 6862, 2648, 7128, 5391, 1829, 8482]),
 array([1404, 7113, 3969,  142,  862,  555, 5765, 3527, 2466, 6388, 1415,
        9269, 7501, 3474, 2329, 1177,  122, 2512, 8640, 7426, 2873, 8651,
        1691, 1146, 7859, 3262, 8716, 3386, 9214, 9056, 7274, 6136])]

In [89]:
list(len(s) for s in batch_indices[:2])

[32, 32]

In [90]:
for idx in batch_indices:
    if 321 in idx or 6058 in idx:
        print(idx)

[6058 5953 3843 4885 4548 2248 4967  144 3506 6045 6165 9239  205 3108
 7472 9243 2159 3307 2025  971 7130 1777 2164 6368 8819 1480 1759 2192
 1432 4939 2373 2988]
[7368 7387 6144 6758 6880 1740 7065 6558  321 9064 6328 2020 8969 6883
 8197 2608 5338 5625  902 8268 6262 7734 4848 8547 4670 9039 2728 6886
 1547 9350 8883 1999]


In [91]:
sizes[[ 321, 9064, 6328, 2020, 8969, 6883, 8197, 2608]]

array([381, 381, 381, 381, 381, 381, 381, 381], dtype=uint16)

In [92]:
class DataSampler(Sampler):
    def __init__(self, dataset, batch_size, batch_indices):
        super().__init__(dataset)
        self.indices = batch_indices
        self.batch_size = batch_size
    
    def __iter__(self):
        for idx in self.indices:
            for i in range(self.batch_size):
                try:
                    yield idx[i]
                except Exception:
                    return
        
        

In [93]:
def collater(batches):
    sample = OrderedDict([(k, []) for k in batches[0].keys()] )
    max_length = max(batch['ntokens'] for batch in batches)
    for k in sample:
        is_tensor = False
        for i in range(len(batches)):
            is_tensor = torch.is_tensor(batches[i][k])
            item = batches[i][k]
            if k in ['net_input.src_tokens', 'target']:
                item = torch.cat((item, torch.tensor([dictionary.pad()]*(max_length - batches[i]['ntokens'])).to(torch.int64)) )
            sample[k].append(item)                
            
        if is_tensor:
            sample[k] = torch.stack(sample[k])
        else:
            sample[k] = torch.tensor(sample[k])
    return sample
    
    
    

In [94]:
dataloader = DataLoader(dataset, sampler=DataSampler(dataset, batch_size, batch_indices), batch_size=batch_size, shuffle=False, collate_fn=collater)

In [95]:
for i, sample in enumerate(dataloader):
    src_tokens, tgt_tokens = sample['net_input.src_tokens'], sample['target']
    print(src_tokens.size(), tgt_tokens.size())

torch.Size([32, 25]) torch.Size([32, 25])
torch.Size([32, 34]) torch.Size([32, 34])
torch.Size([32, 41]) torch.Size([32, 41])
torch.Size([32, 52]) torch.Size([32, 52])
torch.Size([32, 61]) torch.Size([32, 61])
torch.Size([32, 74]) torch.Size([32, 74])
torch.Size([32, 87]) torch.Size([32, 87])
torch.Size([32, 96]) torch.Size([32, 96])
torch.Size([32, 106]) torch.Size([32, 106])
torch.Size([32, 114]) torch.Size([32, 114])
torch.Size([32, 123]) torch.Size([32, 123])
torch.Size([32, 133]) torch.Size([32, 133])
torch.Size([32, 144]) torch.Size([32, 144])
torch.Size([32, 152]) torch.Size([32, 152])
torch.Size([32, 158]) torch.Size([32, 158])
torch.Size([32, 167]) torch.Size([32, 167])
torch.Size([32, 174]) torch.Size([32, 174])
torch.Size([32, 182]) torch.Size([32, 182])
torch.Size([32, 188]) torch.Size([32, 188])
torch.Size([32, 196]) torch.Size([32, 196])
torch.Size([32, 204]) torch.Size([32, 204])
torch.Size([32, 210]) torch.Size([32, 210])
torch.Size([32, 216]) torch.Size([32, 216])
torc

In [96]:
sample['id']

tensor([3733, 4785, 8557, 8415, 7367, 7083, 9144, 5081, 5711, 2688, 3401, 2990,
        8305, 6438, 3718,  626, 1980, 2503, 7457, 3203, 8070, 5592, 4959, 6476,
        3949, 5487, 7198, 3656, 3521, 1275])

In [97]:
sample['net_input.src_tokens'].shape

torch.Size([30, 512])

In [98]:
sample['net_input.src_tokens']

tensor([[50264, 50264,   692,  ...,     1,     1,     1],
        [    0,   170,    32,  ...,     1,     1,     1],
        [    0, 15852, 48853,  ...,     1,     1,     1],
        ...,
        [    0,  5762,   578,  ...,   494, 50118,     2],
        [    0, 19993,   324,  ...,  1090, 50118, 50264],
        [    0,  3632,  3293,  ..., 50264, 50264,     2]])

In [99]:
sample['target']

tensor([[    0, 28061,     1,  ...,     1,     1,     1],
        [    1,     1,     1,  ...,     1,     1,     1],
        [    1,     1,   682,  ...,     1,     1,     1],
        ...,
        [    1,     1,     1,  ...,     1,     1,     1],
        [    1,     1,     1,  ...,     1,     1,     2],
        [    1,     1,     1,  ...,     4, 50118,     1]])

In [150]:
torch.save({'src_tokens': src_tokens, 'tgt_tokens': tgt_tokens, 'dictionary': dictionary,
            'tokens_per_sample': tokens_per_sample}, 
           '/mnt/dl/fairseq/Masked_Language_Model/openwebtext/model.inp.pkl')

## Model Training

In [24]:
%load_ext autoreload
%autoreload 2
import torch
from fairseq.models.roberta import RobertaModel
from fairseq import options
from torch.nn import functional as F

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# state = torch.load(  '/mnt/dl/fairseq/Masked_Language_Model/openwebtext/model.inp.pkl')

In [3]:
# dictionary = state['dictionary']
# src_tokens = state['src_tokens']
# tgt_tokens = state['tgt_tokens']
# tokens_per_sample = state['tokens_per_sample']

In [4]:
class MaskedLMTask:
    
    def __init__(self, dictionary) -> None:
        self.source_dictionary = dictionary

In [5]:
task = MaskedLMTask(dictionary)

In [6]:
model_args = options.get_parser("Model", default_task='masked_lm')

In [7]:
model_args.max_positions = tokens_per_sample

In [8]:
model = RobertaModel.build_model(model_args, task)

In [9]:
model.encoder

RobertaEncoder(
  (sentence_encoder): TransformerEncoder(
    (dropout_module): FairseqDropout()
    (embed_tokens): Embedding(50265, 768, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(514, 768, padding_idx=1)
    (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayerBase(
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout_module): FairseqDropout()
        (activation_dropout_module): FairseqDropout()
        (fc1): Linear(in_features=768, out_

In [10]:
model.cuda()

RobertaModel(
  (encoder): RobertaEncoder(
    (sentence_encoder): TransformerEncoder(
      (dropout_module): FairseqDropout()
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(514, 768, padding_idx=1)
      (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x TransformerEncoderLayerBase(
          (self_attn): MultiheadAttention(
            (dropout_module): FairseqDropout()
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout_module): FairseqDropout()
          (activation_dropout_module):

In [11]:
def forward(model, src_tokens, tgt_tokens):
    print(src_tokens.size())
    print(tgt_tokens.size())
    masked_tokens = tgt_tokens.ne(dictionary.pad())
    with torch.no_grad():
        logits = model(src_tokens.cuda(), masked_tokens=masked_tokens.cuda())
    return logits

In [15]:
logits, *_ = forward(model, src_tokens, tgt_tokens)

torch.Size([30, 512])
torch.Size([30, 512])


In [16]:
logits.size()

torch.Size([2319, 50265])

In [21]:
labels = tgt_tokens[torch.where(tgt_tokens.ne(dictionary.pad()))]

In [22]:
labels

tensor([    0, 28061,     5,  ...,   353,     4, 50118])

In [23]:
labels.size()

torch.Size([2319])

In [26]:
logits

tensor([[-0.5911,  0.0000, -0.1709,  ..., -0.7963, -0.6580, -0.3671],
        [-0.5502,  0.0000,  0.8713,  ...,  0.6290, -0.0139, -0.3900],
        [-0.6489,  0.0000, -0.2749,  ...,  0.8538,  0.2589, -0.2691],
        ...,
        [ 0.8187,  0.0000,  0.3288,  ..., -0.6236, -0.0106,  0.0048],
        [ 0.3306,  0.0000,  0.0701,  ...,  0.2340, -0.3663, -1.1104],
        [ 0.0511,  0.0000, -0.1746,  ..., -0.1997,  0.3947, -0.3209]],
       device='cuda:0')

In [32]:
loss = F.cross_entropy( logits, labels.to(torch.int64).cuda(), reduction='mean')

In [33]:
loss

tensor(10.9865, device='cuda:0')

fairseq-hydra-train -m --config-dir /env_nlp/lib/python3.9/site-packages/fairseq/examples/roberta/config/pretraining --config-name base2

 fairseq-preprocess     --only-source     --srcdict /mnt/dl/fairseq/Masked_Language_Model/openwebtext2/dict.txt     --trainpref /mnt/dl/fairseq/Masked_Language_Model/openwebtext2/train.raw.txt     --validpref /mnt/dl/fairseq/Masked_Language_Model/openwebtext2/valid.raw.txt     --testpref /mnt/dl/fairseq/Masked_Language_Model/openwebtext2/test.raw.txt     --destdir /mnt/dl/fairseq/Masked_Language_Model/openwebtext2/roberta-data-bin     --workers  2

In [9]:
!pip freeze

aiofiles==22.1.0
aiohttp==3.8.4
aiosignal==1.3.1
aiosqlite==0.19.0
antlr4-python3-runtime==4.8
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
astroid==2.15.4
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
autopep8==2.0.2
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bitarray==2.7.3
bleach==6.0.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.3
colorama==0.4.6
comm==0.1.3
contourpy==1.0.7
cycler==0.11.0
Cython==0.29.34
datasets==2.12.0
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
docopt==0.6.2
docstring-to-markdown==0.12
executing==1.2.0
fairseq==0.12.2
fastBPE==0.1.1
fastjsonschema==2.16.3
filelock==3.12.0
flake8==6.0.0
fonttools==4.39.4
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.5.0
huggingface-hub==0.15.1
hydra-core==1.0.7
idna==3.4
importlib-metadata==6.6.0
importlib-resources==5.12.0
ipykernel==6.22.0
ipython==8.13.1
ipython-genutils==0.2.0
isoduration==20.11.0
isort==5.12.0
jedi==0.18.2
Ji