In [1]:
from datasets import load_from_disk
from transformers import GPT2LMHeadModel, GPT2Config, GPT2TokenizerFast, DataCollatorForLanguageModeling
import torch

In [2]:
from datasets import disable_caching
# Turn off the `datasets` library's caching mechanism for the rest of this session.
disable_caching()

In [None]:
# model_dir = "models/script1/left_sentence/checkpoint-75047"
# # model_dir = "gpt2"
# tokenizer_name = "gpt2"

# model = load_pretrained_model(model_dir)
# tokenizer = load_pretrained_tokenizer(tokenizer_name)
# data_collator = init_data_collator(tokenizer, 'left')

In [5]:
# Load the already-tokenized test split from disk.
# NOTE: the path is relative, so it resolves against the notebook's working directory.
tokenized_testset_dir = "../data/coca_spoken/tokens_sentence/test"
test_set = load_from_disk(tokenized_testset_dir)
# The 'text' column is currently kept; uncomment the next line to drop it.
# test_set = test_set.remove_columns('text')

In [6]:
test_set

Dataset({
    features: ['text', 'input_ids', 'attention_mask'],
    num_rows: 600372
})

In [8]:
def load_pretrained_tokenizer(pretrained_model_name_or_path, context=None, add_prefix_space=False):
    """Load a `GPT2TokenizerFast`, configure its special tokens, and print a summary.

    Args:
        pretrained_model_name_or_path: Forwarded unchanged to
            `GPT2TokenizerFast.from_pretrained`.
        context: When equal to 'bigram', the BOS and EOS tokens are overridden
            with '<s>' and '</s>' respectively; any other value leaves them
            as loaded.
        add_prefix_space: Forwarded to `from_pretrained` under the same name.

    Returns:
        The configured tokenizer, whose pad token has been set to its EOS token.
    """
    print(f'Loading pretrained tokenizer from {pretrained_model_name_or_path}...')
    tok = GPT2TokenizerFast.from_pretrained(
        pretrained_model_name_or_path,
        add_prefix_space=add_prefix_space,
    )

    if context == 'bigram':
        tok.bos_token, tok.eos_token = '<s>', '</s>'

    # Padding reuses whatever the EOS token is at this point (after the
    # optional 'bigram' override above).
    tok.pad_token = tok.eos_token

    print("Vocabulary size:", tok.vocab_size)
    print("Max Model Input Sizes:", tok.model_max_length)

    # (caption, attribute name) pairs; the matching id is read from
    # `<attribute name>_id` on the tokenizer.
    special_token_report = (
        ("BOS token:", "bos_token"),
        ("EOS token:", "eos_token"),
        ("PAD token:", "pad_token"),
        ("SEP token:", "sep_token"),
        ("UNK token:", "unk_token"),
    )
    for caption, attr in special_token_report:
        print(caption, getattr(tok, attr), getattr(tok, attr + "_id"))

    print("Special tokens:", tok.all_special_tokens)
    print('...done')
    return tok

In [9]:
# Marker tokens used to build the fill-in-the-blank inputs:
#   BLANK  - placeholder for the removed token,
#   FILLER - final position, where the removed token is predicted,
#   SEP    - separator between the blanked sequence and the filler position.
BLANK, FILLER, SEP = '[BLANK]', '[FILLER]', '[SEP]'

In [11]:
# Start from the 'gpt2' tokenizer and register the three marker tokens with it.
tokenizer = load_pretrained_tokenizer('gpt2')
num_added_tokens = tokenizer.add_tokens([BLANK, FILLER, SEP])

# Resolve all three marker ids with a single batched lookup.
BLANK_id, FILLER_id, SEP_id = tokenizer.convert_tokens_to_ids([BLANK, FILLER, SEP])
print(BLANK_id, FILLER_id, SEP_id)

Loading pretrained tokenizer from gpt2...
Vocabulary size: 50257
Max Model Input Sizes: 1024
BOS token: <|endoftext|> 50256
EOS token: <|endoftext|> 50256
PAD token: <|endoftext|> 50256
SEP token: None None
UNK token: <|endoftext|> 50256
Special tokens: ['<|endoftext|>']
...done
50257 50258 50259


## Making an IterableDataset

In [7]:
from datasets import IterableDataset

In [13]:
def expand_inputs(example, *, blank_id=None, sep_id=None, filler_id=None):
    """Lazily yield one fill-in-the-blank example per token of `example`.

    For every position ``i`` in ``example['input_ids']`` a dict is yielded with:

    * ``'input_ids'``: the sequence with token ``i`` replaced by the blank id,
      followed by the separator id and then the filler id;
    * ``'attention_mask'``: all ones, same length as ``'input_ids'``;
    * ``'labels'``: ``-100`` at every position except the last (the filler
      slot), which holds the original token ``i``. ``-100`` is the default
      ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.

    Args:
        example: Mapping with an ``'input_ids'`` sequence of token ids. It is
            not modified.
        blank_id: Id substituted for the removed token. Defaults to the
            notebook-level ``BLANK_id``.
        sep_id: Id appended after the blanked sequence. Defaults to the
            notebook-level ``SEP_id``.
        filler_id: Id appended last. Defaults to the notebook-level
            ``FILLER_id``.

    Yields:
        dict with keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'``;
        each value is a freshly built list of length ``len(input_ids) + 2``.
        Nothing is yielded for an empty sequence.
    """
    # Fall back to the notebook globals only when no override was supplied, so
    # the globals need not exist if the caller passes explicit ids.
    blank_id = BLANK_id if blank_id is None else blank_id
    sep_id = SEP_id if sep_id is None else sep_id
    filler_id = FILLER_id if filler_id is None else filler_id

    # Copy into a plain list so slicing + concatenation below behaves the same
    # for any sequence type.
    input_ids = list(example['input_ids'])
    n_tokens = len(input_ids)
    suffix = [sep_id, filler_id]

    # Build each variant on demand; nothing is materialised ahead of time.
    for i, target in enumerate(input_ids):
        yield {
            'input_ids': input_ids[:i] + [blank_id] + input_ids[i + 1:] + suffix,
            'attention_mask': [1] * (n_tokens + 2),
            'labels': [-100] * (n_tokens + 1) + [target],
        }

In [14]:
def gen_bidi_inputs(dataset):
    """Stream the blanked-out variants of every example in `dataset`.

    Iterates `dataset` in order and, for each row, forwards every item that
    `expand_inputs` yields for that row.
    """
    for row in dataset:
        for bidi_input in expand_inputs(row):
            yield bidi_input

In [15]:
# Wrap the generator in an IterableDataset; `test_set` is handed to the
# generator through `gen_kwargs`.
my_iterable_dataset = IterableDataset.from_generator(
    gen_bidi_inputs,
    gen_kwargs={"dataset": test_set},
)

# Sanity check: print the first 40 generated examples, then stop.
i = 0  # stays 0 if the dataset yields nothing
for i, example in enumerate(my_iterable_dataset, start=1):
    print(example)
    if i == 40:
        break

{'input_ids': [50257, 705, 82, 1016, 284, 787, 340, 764, 50259, 50258], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1544]}
{'input_ids': [1544, 50257, 82, 1016, 284, 787, 340, 764, 50259, 50258], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, 705]}
{'input_ids': [1544, 705, 50257, 1016, 284, 787, 340, 764, 50259, 50258], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, 82]}
{'input_ids': [1544, 705, 82, 50257, 284, 787, 340, 764, 50259, 50258], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1016]}
{'input_ids': [1544, 705, 82, 1016, 50257, 787, 340, 764, 50259, 50258], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, 284]}
{'input_ids': [1544, 705, 82, 