In [1]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face dataset id of the Dolly 15k instruction dataset.
name = "databricks/databricks-dolly-15k"

In [3]:
# Load the dataset (indexed by split name; only "train" is used below).
dataset = load_dataset(name)

In [4]:
# Source pool for the closed-QA subset.
training_set = dataset["train"]

In [5]:
# Peek at one raw Dolly record.
# NOTE(review): this `sample` is never used before a later cell reassigns
# `sample` to a SQuAD example — consider removing it or renaming it.
sample = training_set[0]

In [6]:
# Keep only rows whose 'category' is 'closed_qa'.
closed = training_set.filter(lambda x: x["category"] == "closed_qa")

In [7]:
# Display the filtered subset's columns and row count.
closed

Dataset({
    features: ['instruction', 'context', 'response', 'category'],
    num_rows: 1773
})

In [8]:
# Load SQuAD; the output below shows it has train and validation splits.
squad = load_dataset("squad")

In [9]:
# Show the split sizes and column names.
squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [10]:
# Unpack the two SQuAD splits.
# NOTE(review): squad_val is not used in the cells shown here.
squad_train, squad_val = squad["train"], squad["validation"]

In [11]:
# Display the training split's columns and row count.
squad_train

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})

In [12]:
# First SQuAD training example.
# NOTE(review): reuses the name `sample` that an earlier cell bound to a Dolly record.
sample = squad_train[0]

In [13]:
# Show the raw record: strings for context/question, answers as a dict of lists.
sample

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [14]:
# SQuAD stores answers as parallel lists ('text', 'answer_start'); take the
# first answer's text.
sample_answer = sample["answers"]["text"][0]

In [15]:
sample_answer

'Saint Bernadette Soubirous'

In [16]:
import tokenmonster

In [17]:
# Name of the pretrained tokenmonster vocabulary (an 8000-token English vocab).
tokenizer_file = "english-8000-balanced-v1"

In [18]:
# NOTE(review): load_multiprocess_safe is presumably the loader variant meant
# for use from worker processes (e.g. Dataset.map with num_proc) — confirm
# against the tokenmonster docs.
vocab = tokenmonster.load_multiprocess_safe(tokenizer_file)

In [19]:
# Smoke test: token ids come back as a numpy uint16 array (see output).
vocab.tokenize("the quick brown fox jumps over the lazy dog")

array([1154, 3453, 3109,  707,   68, 2463,   63, 5980,  745,  538, 1398],
      dtype=uint16)

In [20]:
# tokenmonster hands back numpy.uint16 ids, but the downstream torch code
# expects int32, so widen the dtype immediately after tokenizing.
result = vocab.tokenize("the quick brown fox jumps over the lazy dog").astype("int32")

In [21]:
# Same ids as the smoke test above, now with dtype int32.
result

array([1154, 3453, 3109,  707,   68, 2463,   63, 5980,  745,  538, 1398],
      dtype=int32)

In [22]:
# Disabled experiment: decoding an id that lies outside the vocabulary.
# vocab.decode([326, 1642, 33, 8001, 45])
# The vocab size is only 8000, so id 8001 is out of range; the author notes
# that decode skips it (not re-verified here).

In [23]:
def tokenize(sample, tokenizer=None):
    """Tokenize one SQuAD example into int32 token-id arrays.

    Intended for ``Dataset.map``: the returned dict replaces the values of the
    ``context``, ``question`` and ``answers`` columns of the example.

    Args:
        sample: SQuAD-style record with ``"context"`` and ``"question"``
            strings and an ``"answers"`` dict whose ``"text"`` entry is a list
            of answer strings. Only the first answer is kept.
        tokenizer: object with a ``tokenize(str)`` method returning a numpy
            array of ids. Defaults to the notebook-level ``vocab``.

    Returns:
        dict mapping ``"context"``, ``"question"`` and ``"answers"`` to int32
        numpy arrays. ``"answers"`` is an empty int32 array when the example
        has no answer text (e.g. unanswerable SQuAD v2 questions), instead of
        raising ``IndexError``.
    """
    import numpy as np

    if tokenizer is None:
        tokenizer = vocab  # notebook-level tokenmonster vocabulary

    def encode(text):
        # tokenmonster returns uint16 ids; widen to int32 for torch.
        return tokenizer.tokenize(text).astype("int32")

    answer_texts = sample["answers"]["text"]
    return {
        "context": encode(sample["context"]),
        "question": encode(sample["question"]),
        "answers": (
            encode(answer_texts[0])
            if len(answer_texts) > 0
            else np.zeros(0, dtype=np.int32)
        ),
    }

In [24]:
tokenized_squad_train = (
    squad_train
    .map(tokenize)          # per-example tokenmonster encoding
    .with_format("torch")   # numeric columns are returned as torch tensors
)

Map: 100%|██████████| 87599/87599 [01:09<00:00, 1267.76 examples/s]


In [25]:
# Check which columns the tokenized dataset exposes after map.
tokenized_squad_train[0].keys()

dict_keys(['id', 'title', 'context', 'question', 'answers'])

In [26]:
# One raw Dolly closed_qa record, for comparison with the SQuAD schema.
closed[0]

{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [27]:
def tokenize_closed(sample, tokenizer=None):
    """Tokenize one Dolly ``closed_qa`` record into int32 token-id arrays.

    Intended for ``Dataset.map``: the returned dict replaces the values of the
    ``context``, ``instruction`` and ``response`` columns of the record.

    Args:
        sample: record with ``"context"``, ``"instruction"`` and
            ``"response"`` strings.
        tokenizer: object with a ``tokenize(str)`` method returning a numpy
            array of ids. Defaults to the notebook-level ``vocab``.

    Returns:
        dict mapping ``"context"``, ``"instruction"`` and ``"response"`` (in
        that order) to int32 numpy arrays.
    """
    if tokenizer is None:
        tokenizer = vocab  # notebook-level tokenmonster vocabulary
    # tokenmonster returns uint16 ids; widen to int32 for torch.
    return {
        key: tokenizer.tokenize(sample[key]).astype("int32")
        for key in ("context", "instruction", "response")
    }

In [28]:
tokenized_closed = (
    closed
    .map(tokenize_closed)   # per-record tokenmonster encoding
    .with_format("torch")   # numeric columns are returned as torch tensors
)

Map:   7%|▋         | 119/1773 [00:00<00:01, 1152.91 examples/s]

Map: 100%|██████████| 1773/1773 [00:01<00:00, 899.31 examples/s] 


In [29]:
# Output directory for the tokenized datasets, relative to the working directory.
data_dir = "data"

In [30]:
# Persist both tokenized datasets under data_dir so later cells can reload them.
tokenized_squad_train.save_to_disk(f"{data_dir}/squad_train")
tokenized_closed.save_to_disk(f"{data_dir}/closed")

Saving the dataset (1/1 shards): 100%|██████████| 87599/87599 [00:00<00:00, 445742.75 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1773/1773 [00:00<00:00, 137010.17 examples/s]


In [31]:
# load dataset to see if it works
from datasets import load_from_disk

In [32]:
# Round-trip check: reload the SQuAD split that was just saved.
loaded_squad_train = load_from_disk(f"{data_dir}/squad_train")

In [33]:
# Columns and row count of the reloaded dataset; compare with tokenized_squad_train.
loaded_squad_train

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})

In [34]:
# First reloaded record: tokenized columns come back as torch tensors.
loaded_squad_train[0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': tensor([ 332, 6414,  519, 2031, 2829, 4381, 3267, 6871, 6450,  300,  228,   37,
         1800, 1788,   36, 2509,   36, 5691,  275, 2389,  683,  445, 2452, 4092,
         2695,  515, 4280,   36, 1828,  398,   58, 3755,  300, 7409, 7891,   36,
         2509,   36, 5691, 1263, 1440, 3794,   15, 2452, 1358, 1113, 2695,  515,
          769, 5501, 2788, 2204,  822,   37, 4341, 6232, 1559, 1006, 1212, 1822,
         1061,   36,  645, 1953,  921,   57, 1094,  850, 5132, 1788,   36, 2509,
           36, 5691, 4165,   36,  660,   37, 1737, 1038, 4280, 1966,   37, 1363,
           48, 1941,   37, 1274,  300, 7409, 6987,  660,   37, 1737, 1038, 4165,
           36, 1486,  512,  462,  228, 3753,  349, 3433,  769, 1670,   69, 3787,
         7169,  300, 5060, 1697, 1080,   45, 4280, 1486,  512,   59,  655,   36,
          239,   37, 1638, 1002,   15, 5506, 6815,   36, 1828,  398,   58, 3755,
          789,   37, 1677

In [36]:
# First tokenized Dolly record.
# NOTE(review): this inspects the in-memory dataset, not the copy saved to disk.
tokenized_closed[0]

{'instruction': tensor([ 332, 2778, 1389,   36, 1828,  398,   58, 7280, 3537, 6645,   34]),
 'context': tensor([ 332, 1828,  398,   58, 7280, 2829, 1802, 2048, 2543,  769,   36, 1828,
          398,   58, 7280,   36, 1252, 2953,  922,  514, 2886,   15, 3305, 7280,
           58, 5479, 1252, 2079, 5484, 1788, 5075, 1252, 2079,  668, 1454,  389,
         2677,  815, 5390,   36, 1828,  398,   58, 3102, 2856, 2271, 2054,   48,
         6078,  772,  618, 5497, 2170,  654,   36, 1828,  398,   58,   36, 2230,
         3699, 1810, 1252,   37, 3153, 2561, 4419, 3474, 2856, 6111, 3229, 4178,
         2207, 3354, 1252, 2079,  727, 7280,  275, 5764, 4220, 6391, 1355,   37,
         1551,  498,  769, 1926, 1138,   64, 7280,  727, 7291, 2171, 3723, 1252,
         2079, 1495, 3507, 3257,  815, 5755, 3485,  619, 3961,  727, 7280, 3689,
          721,  358,  727,   36, 1324,   63,   37, 1292,   49,  290, 1589,   37,
         1318,  488, 2925,   36,  809,   48, 1096,   17]),
 'response': tensor([ 332, 1