In [3]:
import sys
import datasets
from transformers import AutoTokenizer
sys.path.append("..")
from babilong_utils import TaskDataset, SentenceSampler, NoiseInjectionDataset


libgomp: Invalid value for environment variable OMP_NUM_THREADS

libgomp: Invalid value for environment variable OMP_NUM_THREADS


In [5]:
# Extract the dataset archive (uncomment and run once):
# !unzip ../data/tasks_1-20_v1-2.zip -d ../data/

In [4]:
# Train/test files of the qa1 (single-supporting-fact) task, en-10k variant,
# located under the extracted tasks_1-20_v1-2 archive.
train_path = "../data/tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_train.txt"
test_path = "../data/tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_test.txt"

# wikitext-2 (raw) is loaded here; its splits are later handed to the sentence
# samplers as the source of background text.
noise_dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1")

### Load task datasets

In [4]:
# Task datasets: one TaskDataset per split, built from the task files above.
task_dataset_train = TaskDataset(train_path)
task_dataset_test = TaskDataset(test_path)

In [5]:
# Background text: sentence samplers over the noise dataset, sharing one
# GPT-2 tokenizer with the rest of the notebook.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Separate samplers for the 'train' and 'test' splits of the noise dataset.
noise_sampler_train = SentenceSampler(noise_dataset['train'], tokenizer=tokenizer)
noise_sampler_test = SentenceSampler(noise_dataset['test'], tokenizer=tokenizer)

In [6]:
# max number of tokens in sample
sample_size = 512

# Train split: task samples combined with background text from the train sampler.
dataset_train = NoiseInjectionDataset(
    task_dataset=task_dataset_train,
    noise_sampler=noise_sampler_train,
    tokenizer=tokenizer,
    sample_size=sample_size,
)

# Test split: same configuration, but with the test task data and test sampler.
dataset_test = NoiseInjectionDataset(
    task_dataset=task_dataset_test,
    noise_sampler=noise_sampler_test,
    tokenizer=tokenizer,
    sample_size=sample_size,
)

In [7]:
# Fetch the first training sample and list the fields it provides.
sample = dataset_train[0]
sample.keys()

dict_keys(['facts', 'question', 'answer', 'references', 'background_text', 'fact_positions', 'input_tokens', 'target_tokens'])

### Visualize one sample

In [8]:
# Text fields are used as-is; token-id fields are decoded back into strings.
facts, question = sample['facts'], sample['question']
answer = tokenizer.decode(sample['target_tokens'])
background_text = tokenizer.batch_decode(sample['background_text'])
input_tokens = tokenizer.decode(sample['input_tokens'])

# The task itself: facts, question, expected answer and supporting references.
print(f"Facts: {' '.join(facts)}")
print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"References: {' '.join(sample['references'])}")
print()

# The background text, the recorded fact positions, and the combined input.
print('Background text: ', ' '.join(background_text))
print('Fact positions: ', sample['fact_positions'])
print('Combined input: ', input_tokens)

# The target is the decoded 'target_tokens', i.e. the same string as the answer.
print(f"Target: {answer}")


Facts: Mary moved to the bathroom. John went to the hallway.
Question: Where is Mary? 
Answer: bathroom
References: Mary moved to the bathroom.

Background text:   = Valkyria Chronicles III =  Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3, lit. Valkyria of the Battlefield 3 ), commonly referred to as Valkyria Chronicles III outside Japan, is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable. Released in January 2011 in Japan, it is the third game in the Valkyria series. Employing the same fusion of tactical and real @-@ time gameplay as its predecessors, the story runs parallel to the first game and follows the " Nameless ", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven ".  The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. Wh

### Collate function

In [9]:
import torch
from torch.nn.utils.rnn import pad_sequence

# Padding id: use the tokenizer's pad token when it has one, otherwise fall
# back to the EOS id (so padding and EOS may share the same id).
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
# Marker placed between the input and the target. Only the FIRST token id of
# the encoding of 'GEN' is kept.
gen_token = tokenizer.encode('GEN')[0]
# Appended after the target tokens to terminate each training sequence.
eos_token = tokenizer.eos_token_id

def collate_fn(batch):
    """Collate dataset samples into right-padded tensors for training and generation.

    Relies on the module-level ``gen_token``, ``eos_token`` and ``id_pad_value``.

    Args:
        batch: non-empty list of sample dicts. Each sample must provide
            ``'input_tokens'`` and ``'target_tokens'`` (lists of token ids)
            and ``'answer'``.

    Returns:
        dict with the following entries:
            input_ids, labels: the SAME tensor (one object under two keys);
                each row is ``input + [gen_token] + target + [eos_token]``,
                padded with ``id_pad_value``. Labels are not padded with -100;
                use ``labels_mask`` to select the positions to train on.
            input_ids_generate: each row is ``input + [gen_token]``, padded
                with ``id_pad_value``.
            labels_mask: bool tensor, True on the last ``len(target) + 2``
                real positions of every row (gen token, target, eos).
            attention_mask: bool tensor, True on real tokens of ``input_ids``.
            attention_mask_generate: bool tensor, True on real tokens of
                ``input_ids_generate``.
            target_text: list with the ``'answer'`` of every sample.
    """
    train_seqs = [
        torch.tensor(b['input_tokens'] + [gen_token] + b['target_tokens'] + [eos_token])
        for b in batch
    ]
    gen_seqs = [torch.tensor(b['input_tokens'] + [gen_token]) for b in batch]

    # Attention masks are derived from the true sequence lengths rather than by
    # comparing ids with id_pad_value: the pad id may fall back to the EOS id,
    # and a genuine EOS token inside the context must not be masked as padding.
    train_attn = [torch.ones_like(seq, dtype=torch.bool) for seq in train_seqs]
    gen_attn = [torch.ones_like(seq, dtype=torch.bool) for seq in gen_seqs]

    # Mark gen token + target tokens + eos (the tail of each unpadded sequence).
    labels_mask = [torch.zeros_like(seq, dtype=torch.bool) for seq in train_seqs]
    for mask, sample in zip(labels_mask, batch):
        mask[-len(sample['target_tokens']) - 2:] = True

    input_ids = pad_sequence(train_seqs, padding_value=id_pad_value, batch_first=True)
    gen_inputs = pad_sequence(gen_seqs, padding_value=id_pad_value, batch_first=True)
    attention_mask = pad_sequence(train_attn, padding_value=0, batch_first=True)
    attention_mask_generate = pad_sequence(gen_attn, padding_value=0, batch_first=True)
    labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

    collated = {}
    collated['input_ids'] = collated['labels'] = input_ids
    collated['input_ids_generate'] = gen_inputs
    collated['labels_mask'] = labels_mask
    collated['attention_mask'] = attention_mask.bool()
    collated['attention_mask_generate'] = attention_mask_generate.bool()

    collated['target_text'] = [b['answer'] for b in batch]
    return collated

In [10]:
# Collate the first ten training samples by hand to inspect the result.
batch = [dataset_train[i] for i in range(10)]
collated = collate_fn(batch)
collated.keys()

dict_keys(['input_ids', 'labels', 'input_ids_generate', 'labels_mask', 'attention_mask', 'attention_mask_generate', 'target_text'])

In [11]:
# Labels are marked with labels_mask: selecting input_ids with it leaves the
# GEN token, the target tokens and the EOS token of every row.
tokenizer.batch_decode([c[m] for c, m in zip(collated['input_ids'], collated['labels_mask'])])

['GENbathroom<|endoftext|>',
 'GENhallway<|endoftext|>',
 'GENhallway<|endoftext|>',
 'GENoffice<|endoftext|>',
 'GENbathroom<|endoftext|>',
 'GENbathroom<|endoftext|>',
 'GENbathroom<|endoftext|>',
 'GENbathroom<|endoftext|>',
 'GENoffice<|endoftext|>',
 'GENhallway<|endoftext|>']

In [12]:
# .forward() and .generate() get different input_ids. This cell decodes the
# .forward() inputs (input + GEN + target + EOS), with padding removed via
# attention_mask.
tokenizer.batch_decode([c[m] for c, m in zip(collated['input_ids'], collated['attention_mask'])])

['The player progresses through a series of linear missions, gradually unlocked as maps that can be freely scanned through and replayed as they are unlocked.The route to each story location on the map varies depending on an individual player\'s approach : when one option is selected, the other is sealed off to the player.Outside missions, the player characters rest in a camp, where units can be customized and character growth occurs.Alongside the main story missions are character @-@ specific sub missions relating to different squad members.Mary moved to the bathroom.After the game\'s completion, additional episodes are unlocked, some of them having a higher difficulty than those found in the rest of the game.There are also love simulation elements related to the game\'s two main heroines, although they take a very minor role. The game\'s battle system, the BliTZ system, is carried over directly from Valkyira Chronicles.During missions, players select each unit using a top @-@ down per

In [13]:
# The .generate() inputs (input + GEN only), with padding removed via
# attention_mask_generate.
tokenizer.batch_decode([c[m] for c, m in zip(collated['input_ids_generate'], collated['attention_mask_generate'])])

['The player progresses through a series of linear missions, gradually unlocked as maps that can be freely scanned through and replayed as they are unlocked.The route to each story location on the map varies depending on an individual player\'s approach : when one option is selected, the other is sealed off to the player.Outside missions, the player characters rest in a camp, where units can be customized and character growth occurs.Alongside the main story missions are character @-@ specific sub missions relating to different squad members.Mary moved to the bathroom.After the game\'s completion, additional episodes are unlocked, some of them having a higher difficulty than those found in the rest of the game.There are also love simulation elements related to the game\'s two main heroines, although they take a very minor role. The game\'s battle system, the BliTZ system, is carried over directly from Valkyira Chronicles.During missions, players select each unit using a top @-@ down per

### Create a dataloader

In [14]:
from torch.utils.data import DataLoader

# Wrap the training dataset in a DataLoader that batches with collate_fn.
dl = DataLoader(batch_size=2, dataset=dataset_train, collate_fn=collate_fn)
gen = iter(dl)
# NOTE: `batch` is rebound here from the earlier list of raw samples to a
# collated dict produced by the DataLoader.
batch = next(gen)
batch.keys()

dict_keys(['input_ids', 'labels', 'input_ids_generate', 'labels_mask', 'attention_mask', 'attention_mask_generate', 'target_text'])