Load the SQuAD (Stanford Question Answering Dataset) dataset

In [None]:
from datasets import load_dataset

# Load SQuAD through the `datasets` library; the "train" and "validation"
# splits of the returned object are indexed in the cells below.
raw_datasets = load_dataset("squad")

Print some values from the dataset

In [None]:
# Overview of the loaded dataset object
print("Summary:", raw_datasets)

# A notebook only echoes the LAST expression of a cell, so an intermediate
# value has to be printed explicitly or it is silently discarded.
print("Train sample:", raw_datasets["train"][1])

# Validation rows can carry several reference answers for one question
raw_datasets["validation"][2]["answers"]


Start Training Process

In [None]:
from transformers import AutoTokenizer

# Checkpoint whose tokenizer is used to prepare the data.
# NOTE: `tokenizer` is a notebook-level global that tokenize_fn_train and
# tokenize_fn_validation read when they are defined further down.
model_checkpoint = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# Try the tokenizer on a single training row
sample_row = raw_datasets["train"][1]
question = sample_row["question"]
context = sample_row["context"]

# One (question, context) pair goes in, one encoded sequence comes out
inputs = tokenizer(question, context)
print("Raw token ids:",inputs)

# Turn the ids back into text to see how the pair was joined
decoded_text = tokenizer.decode(inputs["input_ids"])
print("Decoded tokens:",decoded_text)

In [None]:
# Tokenize the same pair again, but split the long context into several
# overlapping windows instead of producing one long sequence.
chunking_options = dict(
  max_length=100,                  # upper bound on the length of each window
  truncation="only_second",        # only the second text (the context) is cut
  stride=50,                       # overlap between consecutive windows
  return_overflowing_tokens=True,  # keep the windows beyond the first one
  return_offsets_mapping=True,     # also return the per-token character spans
)
inputs = tokenizer(question, context, **chunking_options)

# One list of ids per window; every window is laid out as
# [CLS] question [SEP] context-window [SEP]
print("Chunked context ids:",inputs["input_ids"])

# Decode each window; the question part is repeated in every one of them
for window_ids in inputs["input_ids"]:
  print("Decoded and chunked tokens:",tokenizer.decode(window_ids))

Note the new keys overflow_to_sample_mapping and offset_mapping.

overflow_to_sample_mapping: records which input row each chunk came from. For example, the value [0, 0, 0, 1, 1, 2, 2] shows how a batch of three input rows was split into seven chunks:

row 0 was split into 3 chunks, row 1 was split into 2 chunks and row 2 was
split into 2 chunks.

offset_mapping: Shows offset/character positions of tokens in the mapping. For example, if the decoded tokens are:

[CLS] What is in front of the Notre Dame Main Building?

The offset mapping will be like: [(0, 0), (0, 4), (5, 7), (8, 10), ...

(0,0) : for CLS
(0,4) : For 'What', start from 0 and go till offset 4 as length is 4
(5,7) : For 'is', start from 5 and go till offset 7 as length is 2
(8,10): For 'in', start from 8 and go till offset 10 as length is 2

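Note that this list restarts from (0, 0) for every new chunk (each chunk begins with [CLS] and the question again), but the offsets of the context tokens still refer to character positions in the original context string.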

In [None]:
print(inputs.keys())
inputs['overflow_to_sample_mapping']

# Expected: four zeros, i.e. all 4 chunks come from the same (and only) input row

In [48]:
# Character span (start, end) of every token, one list per chunk.
# Special tokens such as [CLS] and [SEP] show up as (0, 0).
inputs['offset_mapping']

[[(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [49]:
# For our example there are 4 input chunks, each holding the question followed by a context window.
# sequence_ids(i) labels every token of chunk i: 0 for the question, 1 for the context
# (similar in spirit to token type ids).

inputs.sequence_ids(0)  # sequence id of each token in chunk 0

# NOTE: valid chunk indices are 0-3, so inputs.sequence_ids(4) raises an error
# Special tokens ([CLS], [SEP]) are reported as None

[None,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 None,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 None]

In [None]:
# Reference answer of the sample row; its 'text' and 'answer_start' lists are read in the next cells
answer = raw_datasets["train"][1]["answers"]
answer

In [None]:
# Locate the context inside chunk 0, measured in TOKEN positions.
# sequence_ids holds 0 for question tokens, 1 for context tokens and None for special tokens.
sequence_ids = inputs.sequence_ids(0)

# First context token: the first position whose sequence id is 1
ctx_start = sequence_ids.index(1)

# Last context token: find the first 1 in the reversed list, then convert that
# distance-from-the-end back into a regular index.
# Example: sequence_ids = [0, 0, 1, 1, 1, 1, None]
#   reversed -> [None, 1, 1, 1, 1, 0, 0], first 1 sits at position 1 -> idx = 1
#   ctx_start = 2, ctx_end = 7 - 1 - 1 = 5
idx = list(reversed(sequence_ids)).index(1)
ctx_end = len(sequence_ids) - 1 - idx
ctx_start, ctx_end

In [None]:
# Character span of the answer inside the ORIGINAL context string.
# It is compared with the chunk's token offsets in the next cell to decide whether
# the answer lies fully inside this chunk; if not, the target is (start, end) = (0, 0).
print("answer:",answer)

ans_start_char = answer['answer_start'][0] # start character of the first listed answer (188 for this sample, per the original note)
ans_end_char = ans_start_char + len(answer['text'][0]) # exclusive end character

offset = inputs['offset_mapping'][0] # offsets of chunk 0: question tokens first, then the context window
print("offset:",offset)

In [None]:
# The answer is provided in terms of character positions in the context
# However for neurl network, we need to provide the answer in terms of token positions
# This function will find the token positions of the answer in the context
def find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset):
  """Convert an answer's character span into token positions within one chunk.

  Args:
    ctx_start: index of the first context token in the chunk.
    ctx_end: index of the last context token in the chunk (inclusive).
    ans_start_char: character position in the context where the answer starts.
    ans_end_char: character position just past the end of the answer.
    offset: the chunk's offset mapping, i.e. one (start_char, end_char) pair
      per token.

  Returns:
    (start_idx, end_idx): token indices of the first and last answer token
    (both inclusive), or (0, 0) when the answer is not fully contained in the
    chunk's context window.
  """
  # The answer must lie completely inside the characters covered by this window.
  if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
    return 0, 0

  # The 'trick' is knowing what is in units of tokens and what is in units of
  # characters: offset[i] gives the CHARACTER span of TOKEN i.

  # Start token: the last context token that begins at or before the answer
  # start. Unlike an exact-equality test, this also works when the answer
  # starts in the middle of a token.
  start_idx = ctx_start
  while start_idx <= ctx_end and offset[start_idx][0] <= ans_start_char:
    start_idx += 1
  start_idx -= 1

  # End token: the first context token that ends at or after the answer end.
  # The scan is limited to [ctx_start, ctx_end], so separator and padding
  # tokens after the context are never inspected.
  end_idx = ctx_end
  while end_idx >= ctx_start and offset[end_idx][1] >= ans_end_char:
    end_idx -= 1
  end_idx += 1

  return start_idx, end_idx

# Token positions (within chunk 0) where the answer starts and ends;
# (0, 0) means the answer is not fully inside this chunk
start_idx, end_idx = find_answer_token_idx(ctx_start, ctx_end, ans_start_char, ans_end_char, offset)

print (f"start_idx, end_idx: {start_idx, end_idx}")

In [None]:
# Sanity check: the token span found above should decode back to the answer text
input_ids = inputs['input_ids'][0]
answer_token_ids = input_ids[start_idx : end_idx + 1]  # end_idx is inclusive
print("Token ids for answer",answer_token_ids)

# Human-readable form of the same span
print("Decoded Values:",tokenizer.decode(answer_token_ids))

Start the process of tokenizing the entire data set.

In [15]:
# Settings read (as globals) by the batch tokenize functions that are passed to Dataset.map below

# Values chosen by the original author, who notes that Google used 384 for SQuAD
max_length = 384
stride =  128

# This function is used only for the train data
def tokenize_fn_train(batch):
  """Tokenize a batch of training rows into fixed-length chunks with answer labels.

  Meant to be passed to `Dataset.map(..., batched=True)`. Reads the
  notebook-level globals `tokenizer`, `max_length` and `stride`, and calls
  `find_answer_token_idx`.

  Args:
    batch: dict of columns with at least "question", "context" and "answers";
      every answers entry holds parallel "text" and "answer_start" lists.

  Returns:
    The tokenizer output (one entry per chunk, not per input row) with
    "offset_mapping" and "overflow_to_sample_mapping" removed and
    "start_positions" / "end_positions" added. A chunk is labelled (0, 0) when
    its row has no answer or the answer is not fully inside the chunk.
  """
  # some questions have leading and/or trailing whitespace
  questions = [q.strip() for q in batch["question"]]

  # Tokenize and pad every chunk to max_length; since most contexts are long,
  # we don't bother to pad per-minibatch.
  inputs = tokenizer(
    questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # Only needed to compute the labels, so take them out of the returned dict.
  offset_mapping = inputs.pop("offset_mapping")
  # Chunk -> input-row index, e.g. [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
  # => chunks 0-3 belong to row 0, chunks 4-7 to row 1, chunks 8-11 to row 2
  orig_sample_idxs = inputs.pop("overflow_to_sample_mapping")
  answers = batch['answers']
  start_idxs, end_idxs = [], []

  for i, offset in enumerate(offset_mapping):
    # Several chunks of the same row share one answer.
    answer = answers[orig_sample_idxs[i]]

    # Rows without any answer would raise IndexError below; label them (0, 0),
    # the same target that is used when the answer is outside the chunk.
    if not answer['answer_start'] or not answer['text']:
      start_idxs.append(0)
      end_idxs.append(0)
      continue

    # Training uses the first listed answer only.
    ans_start_char = answer['answer_start'][0]
    ans_end_char = ans_start_char + len(answer['text'][0])

    # Context token range of this chunk: first and last position whose
    # sequence id is 1.
    sequence_ids = inputs.sequence_ids(i)
    ctx_start = sequence_ids.index(1)
    ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1

    start_idx, end_idx = find_answer_token_idx(
      ctx_start,
      ctx_end,
      ans_start_char,
      ans_end_char,
      offset)

    # Because of the stride, the same answer can be labelled in more than one
    # chunk; start_idx = end_idx = 0 means it is not inside this chunk.
    start_idxs.append(start_idx)
    end_idxs.append(end_idx)

  # Add the label columns to the tokenizer output.
  inputs["start_positions"] = start_idxs
  inputs["end_positions"] = end_idxs
  return inputs

In [None]:
# Build the tokenized training set by running tokenize_fn_train over the raw rows in batches
train_dataset = raw_datasets["train"].map(
  tokenize_fn_train,
  batched=True,
  # Drop every original column: tokenize_fn_train returns one row per CHUNK,
  # and the raw per-sample columns are not used afterwards
  remove_columns=raw_datasets["train"].column_names, 
)

# Chunking yields more rows than there are raw samples; compare the two lengths
len(raw_datasets["train"]), len(train_dataset)

Do data prep for validation data set

In [None]:
# Inspect one raw validation row before tokenizing the split
raw_datasets["validation"][0]

In [18]:
# tokenize the validation set differently
# we won't need the targets since we will just compare with the original answer
# also: overwrite offset_mapping with Nones in place of question
# More details on this: https://huggingface.co/docs/transformers/tasks/question_answering
def tokenize_fn_validation(batch):
  """Tokenize a batch of validation rows into fixed-length chunks.

  No answer labels are produced. Each chunk keeps the id of the row it came
  from ("sample_id"), and its "offset_mapping" is rewritten so that only
  context tokens keep their character span; every other position is None.
  Reads the notebook-level globals `tokenizer`, `max_length` and `stride`.
  """
  # Questions may carry stray leading/trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # Every chunk is padded to max_length
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # chunk index -> index of the originating row in this batch;
  # removed from the output, while offset_mapping is kept
  chunk_to_row = inputs.pop("overflow_to_sample_mapping")

  masked_offsets = []
  chunk_sample_ids = []
  for chunk_idx in range(len(inputs["input_ids"])):
    chunk_sample_ids.append(batch['id'][chunk_to_row[chunk_idx]])

    # sequence id is 1 for context tokens; anything else is blanked out
    seq_ids = inputs.sequence_ids(chunk_idx)
    chunk_offsets = inputs["offset_mapping"][chunk_idx]
    masked_offsets.append([
      span if seq_ids[pos] == 1 else None
      for pos, span in enumerate(chunk_offsets)])

  inputs["offset_mapping"] = masked_offsets
  inputs['sample_id'] = chunk_sample_ids
  return inputs

In [None]:
# Build the tokenized validation set by running tokenize_fn_validation over the raw rows in batches
validation_dataset = raw_datasets["validation"].map(
  tokenize_fn_validation,
  batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
# The lengths differ because chunking turns one raw sample into one or more rows
len(raw_datasets["validation"]), len(validation_dataset)

Build code for the Metrics

In [20]:
# `datasets.load_metric` is no longer used here; the `evaluate` package replaces it
# (install it with `pip install evaluate` if it is missing)
import evaluate

# Most standard NLP datasets have an associated metric; load the one for SQuAD
metric = evaluate.load("squad")




This shows the sample structure of predicted and true answers, and how they are passed to the compute function.

In [None]:
# Toy predictions: one dict per question, keyed by the question id
predicted_answers = [
  {'id': '1', 'prediction_text': 'Albert Einstein'},
  {'id': '2', 'prediction_text': 'physicist'},
  {'id': '3', 'prediction_text': 'general relativity'},
]
# Toy references in the same layout as the dataset's "answers" column
true_answers = [
  {'id': '1', 'answers': {'text': ['Albert Einstein'], 'answer_start': [100]}},
  {'id': '2', 'answers': {'text': ['physicist'], 'answer_start': [100]}},
  {'id': '3', 'answers': {'text': ['special relativity'], 'answer_start': [100]}},
]

# id and answer_start seem superfluous but you'll get an error if they are not included
# metric.compute reports the exact-match and F1 scores
metric.compute(predictions=predicted_answers, references=true_answers)

Create a smaller validation data set

In [None]:
# next problem: how to go from logits to prediction text?
# Work on a small subset: the first 100 validation rows (indices 0-99)
small_validation_dataset = raw_datasets["validation"].select(range(100))

# Let's work on an already-trained question-answering model
# The same name is used for both AutoTokenizer and AutoModelForQuestionAnswering
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer2 = AutoTokenizer.from_pretrained(trained_checkpoint)

# tokenize_fn_validation reads `tokenizer` as a global, so temporarily point
# that name at tokenizer2. The try/finally guarantees the original tokenizer
# is put back even if the map call fails part-way.
old_tokenizer = tokenizer
tokenizer = tokenizer2
try:
  small_validation_processed = small_validation_dataset.map(
      tokenize_fn_validation,
      batched=True,
      remove_columns=raw_datasets["validation"].column_names,
  )
finally:
  # change it back
  tokenizer = old_tokenizer

Start the definition of the model here

In [None]:
# Run the already fine-tuned QA model on the small validation subset (inference only)

import torch
# AutoModelForQuestionAnswering picks the model class that matches the
# checkpoint name specified earlier
from transformers import AutoModelForQuestionAnswering

# the trained model doesn't use these bookkeeping columns
small_model_inputs = small_validation_processed.remove_columns(
  ["sample_id", "offset_mapping"])
small_model_inputs.set_format("torch")

# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")  # should be cuda if GPU is available

# move every input column to the chosen device using .to(device)
small_model_inputs_gpu = {
  k: small_model_inputs[k].to(device) for k in small_model_inputs.column_names
}

# load the model
# Note: This checkpoint is already trained on a QA task
trained_model = AutoModelForQuestionAnswering.from_pretrained(
  trained_checkpoint).to(device)

# forward pass without gradient tracking; the logits are post-processed in later cells
with torch.no_grad():
  outputs = trained_model(**small_model_inputs_gpu)

# Each logits tensor is N x T (chunks x tokens per chunk)
print(f"Entries in Output: {outputs.keys()}")

# outputs is of type QuestionAnsweringModelOutput
outputs


In [None]:
# Check the original sample IDs stored with every tokenized validation row
print("Sample IDs:",validation_dataset['sample_id'])

# A long context is split into several overlapping chunks, and every chunk keeps the ID of
# the sample it came from, so the same ID can appear more than once. The number of unique
# IDs is therefore lower than the number of rows in the tokenized dataset, as seen below:

len(set(validation_dataset['sample_id']))

 

Convert the outputs from the DNN back to strings, because the metric takes strings as input. Strings also help a human understand the answer.

In [None]:
# start_logits / end_logits hold, for every token of every chunk, the model's score for
# that token being the first / last token of the answer.
# Move them to the CPU and convert them to NumPy for the post-processing below.
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

# Shape is (number of chunks, 384); the original author recorded 100 x 384
# - the chunks come from the 100 samples selected from the main validation set
# - 384 is the max_length every chunk is padded to in the function used with map
# Logits are unnormalized scores rather than probabilities: the higher the value,
# the more likely the token is the start / end of the answer
start_logits.shape, end_logits.shape



In [39]:
# Naive approach: take the single highest logit for the start and for the end position.
# This method is not robust, as answer_end_index can come out smaller than answer_start_index;
# a different method is used in the following cells.
# With no `dim` argument, argmax() works on the FLATTENED tensor, so each result is one
# index across all (chunk, token) positions rather than one index per chunk.
answer_start_index = outputs.start_logits.argmax(keepdim=False)
answer_end_index = outputs.end_logits.argmax(keepdim=False)

# Flattened positions: chunk = index // tokens_per_chunk, token = index % tokens_per_chunk
print(f"Start, end: {answer_start_index}, {answer_end_index}")


Start, end: 22331, 22332


A different approach to convert the logits to text is given below

In [40]:
# Show the first few sample_ids; each is the id of the raw validation sample
# that the corresponding tokenized row (chunk) came from
small_validation_processed['sample_id'][:5]

['56be4db0acb8001400a502ec',
 '56be4db0acb8001400a502ed',
 '56be4db0acb8001400a502ee',
 '56be4db0acb8001400a502ef',
 '56be4db0acb8001400a502f0']

In [None]:
# Group the row positions of the processed dataset by the sample they came from,
# so all chunks of one sample can be looked up together.
# example: {'56be4db0acb8001400a502ec': [0, 1, 2, 3], ...}
sample_id2idxs = {}
for row_pos, sample_id in enumerate(small_validation_processed['sample_id']):
  # setdefault creates the empty list the first time a sample id is seen
  sample_id2idxs.setdefault(sample_id, []).append(row_pos)

In [45]:
# argsort() returns the indices that would sort the values in ASCENDING order
print("Indices in ascending order of value:",start_logits[0].argsort())
# Negating the values first yields the indices in DESCENDING order (best score first)
print("Indices in descending order of value:",(-start_logits[0]).argsort())

Values in descending order: [360 361 362 364 370 365 359 354 374 356 363 357 373 358 222 328 280 383
 355 198 216 369 375 206 381 219 242 366 371 220 380 231 376 221 353 379
 352 382 368 247 367 215 344 235 207 241 244 372 217 243 224 226 340 351
 180 184 349 202 199 330 193 223 262 203 347 204 236 282 211 225 200 205
 201 208 263 378 213 197 339 188 248 195 196 258 377 182 257 194 246 249
 240 261 336 209 348 212 277 324 210 228 178 275 341 260 189 183 335 181
 343 345 342 279 271 254 185 268 281 259 350 278 325 192 214 334 253 238
 329 267 172 332 239 187 173 273 285 179 232 218 230 191 346 233 272 276
 326 255 284 227 237 234 327 331 288 264 286 293 190 266 337 256 174 283
 250 245 186 297 302 177 229 296 252 290 322 323 292 338 175 176 298 265
 274 299 269 303 333 251 294 289 295 315 270 287 305 301 318 291 321 319
 320 317 306 314 308 316 307 300 310 313 309 304 312 311 122 129 149 147
 121 113 146 125 158 142 120  99 116 161 163 148 126 168 167 156 135 138
 162 119 145 133 155 15

In [46]:
# Get the actual logit values in descending order by indexing with the sorted indices
start_logits[0][(-start_logits[0]).argsort()]

array([10.694443  ,  9.803684  ,  4.459974  ,  4.400486  ,  2.9437783 ,
        2.7017372 ,  2.012644  ,  1.5780748 ,  0.5223746 ,  0.02073722,
       -0.02802708, -0.04971639, -0.38573125, -0.6945367 , -0.79795045,
       -0.8678042 , -0.87220824, -1.3516885 , -1.3703709 , -1.3878822 ,
       -1.5135099 , -1.735547  , -1.8827038 , -1.8932867 , -1.9078968 ,
       -1.9304981 , -2.2607315 , -2.2983897 , -2.306934  , -2.502741  ,
       -2.510062  , -2.5308425 , -2.539996  , -2.6718142 , -2.7323534 ,
       -2.7710214 , -2.7713675 , -2.9521344 , -3.0604677 , -3.1706069 ,
       -3.204545  , -3.569337  , -3.5798054 , -3.6668842 , -3.7250612 ,
       -3.7498565 , -3.7632205 , -3.996813  , -4.01133   , -4.0688014 ,
       -4.0944867 , -4.1954756 , -4.238311  , -4.332363  , -4.352416  ,
       -4.3879642 , -4.388612  , -4.396614  , -4.6790533 , -4.7030315 ,
       -4.775757  , -4.777815  , -4.7882195 , -4.788246  , -4.8221292 ,
       -4.8725405 , -4.8849354 , -4.8981485 , -5.072099  , -5.10

In [64]:
# Note: tokenize_fn_validation replaced every non-context entry of offset_mapping with None.
# Inside the context window each token keeps its pair:
# (start_character_position, end_character_position)
# The index used below (20) is a TOKEN position within the chunk, not a character position.
first_chunk_offsets = small_validation_processed[0]['offset_mapping']

print("Offset map for the first validation chunk (index 0)\n", first_chunk_offsets)

print("Offset pair of token 20 in the first validation chunk\n", first_chunk_offsets[20])

print("Start character of token 20 in the first validation chunk\n", first_chunk_offsets[20][0])

print("End character of token 20 in the first validation chunk\n", first_chunk_offsets[20][1])

Offset Map for single train data (in index 0)
 [None, None, None, None, None, None, None, None, None, None, None, None, None, [0, 5], [6, 10], [11, 13], [14, 17], [18, 20], [21, 29], [30, 38], [39, 43], [44, 46], [47, 56], [57, 60], [61, 69], [70, 72], [73, 76], [77, 85], [86, 94], [95, 101], [102, 103], [103, 106], [106, 107], [108, 111], [112, 115], [116, 120], [121, 127], [127, 128], [129, 132], [133, 141], [142, 150], [151, 161], [162, 163], [163, 166], [166, 167], [168, 176], [177, 183], [184, 191], [192, 200], [201, 204], [205, 213], [214, 222], [223, 233], [234, 235], [235, 238], [238, 239], [240, 248], [249, 257], [258, 266], [267, 269], [269, 270], [270, 272], [273, 275], [276, 280], [281, 286], [287, 292], [293, 298], [299, 303], [304, 309], [309, 310], [311, 314], [315, 319], [320, 323], [324, 330], [331, 333], [334, 342], [343, 344], [344, 345], [346, 350], [350, 351], [352, 354], [355, 359], [359, 360], [360, 361], [362, 369], [370, 372], [373, 376], [377, 380], [381, 390]