In [None]:
!pip install transformers
!pip install spacy
!python -m spacy download en_core_web_sm
!mkdir ./data
!wget https://nlp.stanford.edu/data/coqa/coqa-train-v1.0.json -P ./data
!wget https://nlp.stanford.edu/data/coqa/coqa-dev-v1.0.json -P ./data

In [4]:
import numpy as np
import collections
import json
import logging
import os
import re
import string
from collections import Counter
from functools import partial
from multiprocessing import Pool, cpu_count
import spacy
import torch
from torch.utils.data import TensorDataset
from tqdm import tqdm
import toml

# Conversational question answering with BERT-base

In this module we will explore the Conversational Question Answering task (CoQA) with the BERT-base model. The dataset contains a context paragraph, and several questions that can be asked based on the context paragraph. The answer to one question may depend on previous questions or answers, thus making it conversational.

## Making a few objects
In the preliminary steps we are making three objects to store a CoQA example, CoQA features (vectors) and Results (outputs).

In [12]:
class CoqaExample(object):
    """Single CoQA example"""
    def __init__( self, qas_id, question_text, doc_tokens, orig_answer_text=None, start_position=None, end_position=None,
                 rational_start_position=None, rational_end_position=None, additional_answers=None,):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.additional_answers = additional_answers
        self.rational_start_position = rational_start_position
        self.rational_end_position = rational_end_position
    def __repr__(self):
        return f"""Question : {self.question_text}
                   Document : {self.doc_tokens}
                   Ground Truth : {self.orig_answer_text}
                   Answer Span : {self.start_position} -> {self.end_position}"""

class CoqaFeatures(object):
    """Single CoQA feature"""
    def __init__(self, unique_id, example_index, doc_span_index, tokens, token_to_orig_map, token_is_max_context, input_ids,
                 input_mask, segment_ids, start_position=None, end_position=None, cls_idx=None, rational_mask=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.cls_idx = cls_idx
        self.rational_mask = rational_mask

class Result(object):
    """Single result """
    def __init__(self, unique_id, start_logits, end_logits, yes_logits, no_logits, unk_logits):
        self.unique_id = unique_id
        self.start_logits = start_logits
        self.end_logits = end_logits
        self.yes_logits = yes_logits
        self.no_logits = no_logits
        self.unk_logits = unk_logits

## Next we define the preprocessing step

In [13]:
class Processor:
    train_file = "/content/data/coqa-train-v1.0.json"
    dev_file = "/content/data/coqa-dev-v1.0.json"
    def is_whitespace(self, c):
        """ Clean whitespaces """
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False
    def _str(self, s):
        """ Convert PTB tokens to normal tokens """
        if (s.lower() == '-lrb-'):
            s = '('
        elif (s.lower() == '-rrb-'):
            s = ')'
        elif (s.lower() == '-lsb-'):
            s = '['
        elif (s.lower() == '-rsb-'):
            s = ']'
        elif (s.lower() == '-lcb-'):
            s = '{'
        elif (s.lower() == '-rcb-'):
            s = '}'
        return s
    def space_extend(self, matchobj):
        return ' ' + matchobj.group(0) + ' '
    def pre_proc(self, text):
        text = re.sub(u'-|\u2010|\u2011|\u2012|\u2013|\u2014|\u2015|%|\[|\]|:|\(|\)|/|\t', self.space_extend, text)
        text = text.strip(' \n')
        text = re.sub('\s+', ' ', text)
        return text
    def process(self, parsed_text):
        output = {'word': [], 'offsets': [], 'sentences': []}
        for token in parsed_text:
            output['word'].append(self._str(token.text))
            output['offsets'].append((token.idx, token.idx + len(token.text)))
        word_idx = 0
        for sent in parsed_text.sents:
            output['sentences'].append((word_idx, word_idx + len(sent)))
            word_idx += len(sent)
        assert word_idx == len(output['word'])
        return output
    def get_raw_context_offsets(self, words, raw_text):
        """ Find word boundaries for the context """
        raw_context_offsets = []
        p = 0
        for token in words:
            while p < len(raw_text) and re.match('\s', raw_text[p]):
                p += 1
            if raw_text[p:p + len(token)] != token:
                print('something is wrong! token', token, 'raw_text:', raw_text)
            raw_context_offsets.append((p, p + len(token)))
            p += len(token)
        return raw_context_offsets
    def find_span(self, offsets, start, end):
        """ Find the span in word start/end from character start/end"""
        start_index = -1
        end_index = -1
        for i, offset in enumerate(offsets):
            if (start_index < 0) or (start >= offset[0]):
                start_index = i
            if (end_index < 0) and (end <= offset[1]):
                end_index = i
        return (start_index, end_index)
    def normalize_answer(self, s):
        """ some answers may only be different by articles like a, an, the punctuation, etc but otherwise be similar. Fix that"""
        def remove_articles(text):
            regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
            return re.sub(regex, ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))
    def find_span_with_gt(self, context, offsets, ground_truth):
        """ Given the ground truth answer, find the best matching span in the context that matches the ground truth"""
        best_f1 = 0.0
        best_span = (len(offsets) - 1, len(offsets) - 1)
        gt = self.normalize_answer(self.pre_proc(ground_truth)).split()
        ls = [ i for i in range(len(offsets)) if context[offsets[i][0]:offsets[i][1]].lower() in gt ]
        """ Check every possible span and check F1 score w.r.t. ground truth to pick best span"""
        for i in range(len(ls)):
            for j in range(i, len(ls)):
                pred = self.normalize_answer(
                    self.pre_proc( context[offsets[ls[i]][0]:offsets[ls[j]][1]])).split()
                common = Counter(pred) & Counter(gt)
                num_same = sum(common.values())
                if num_same > 0:
                    precision = 1.0 * num_same / len(pred)
                    recall = 1.0 * num_same / len(gt)
                    f1 = (2 * precision * recall) / (precision + recall)
                    if f1 > best_f1:
                        best_f1 = f1
                        best_span = (ls[i], ls[j])
        return best_span
    def get_examples(self, evaluate, history_len = 2, threads=16):
        """ Returns the training examples from the data directory.
        This method basically wraps the next method and implements multi-processing."""
        filename = self.dev_file if evaluate else self.train_file
        with open(filename, "r", encoding="utf-8") as reader:
            input_data = json.load(reader)["data"]
        threads = min(threads, cpu_count())
        with Pool(threads) as p:
            annotate_ = partial(self._create_examples, history_len=history_len)
            examples = list(tqdm( p.imap(annotate_, input_data), total=len(input_data), desc="Preprocessing examples",))
        examples = [item for sublist in examples for item in sublist]
        return examples
    def _create_examples(self, input_data, history_len):
        """ Primary pre-processing procedure """
        nlp = spacy.load('en_core_web_sm')
        examples = []
        datum = input_data
        context_str = datum['story']
        _datum = {'context': context_str, 'source': datum['source'], 'id': datum['id'], 'filename': datum['filename']}
        nlp_context = nlp(self.pre_proc(context_str))
        _datum['annotated_context'] = self.process(nlp_context)
        _datum['raw_context_offsets'] = self.get_raw_context_offsets(_datum['annotated_context']['word'], context_str)
        assert len(datum['questions']) == len(datum['answers'])
        additional_answers = {}
        if 'additional_answers' in datum:
            for k, answer in datum['additional_answers'].items():
                if len(answer) == len(datum['answers']):
                    for ex in answer:
                        idx = ex['turn_id']
                        if idx not in additional_answers:
                            additional_answers[idx] = []
                        additional_answers[idx].append(ex['input_text'])
        for i in range(len(datum['questions'])):
            question, answer = datum['questions'][i], datum['answers'][i]
            assert question['turn_id'] == answer['turn_id']
            idx = question['turn_id']
            _qas = {
                'turn_id': idx,
                'question': question['input_text'],
                'answer': answer['input_text']
            }
            if idx in additional_answers:
                _qas['additional_answers'] = additional_answers[idx]
            _qas['raw_answer'] = answer['input_text']
            if _qas['raw_answer'].lower() in ['yes', 'yes.']:
                _qas['raw_answer'] = 'yes'
            if _qas['raw_answer'].lower() in ['no', 'no.']:
                _qas['raw_answer'] = 'no'
            if _qas['raw_answer'].lower() in ['unknown', 'unknown.']:
                _qas['raw_answer'] = 'unknown'
            _qas['answer_span_start'] = answer['span_start']
            _qas['answer_span_end'] = answer['span_end']
            start = answer['span_start']
            end = answer['span_end']
            chosen_text = _datum['context'][start:end].lower()
            while len(chosen_text) > 0 and self.is_whitespace(chosen_text[0]):
                chosen_text = chosen_text[1:]
                start += 1
            while len(chosen_text) > 0 and self.is_whitespace(chosen_text[-1]):
                chosen_text = chosen_text[:-1]
                end -= 1
            r_start, r_end = self.find_span(_datum['raw_context_offsets'], start, end)
            input_text = _qas['answer'].strip().lower()
            if input_text in chosen_text:
                p = chosen_text.find(input_text)
                _qas['answer_span'] = self.find_span(_datum['raw_context_offsets'], start + p, start + p + len(input_text))
            else:
                _qas['answer_span'] = self.find_span_with_gt(_datum['context'], _datum['raw_context_offsets'], input_text)
            long_questions = []
            for j in range(i - history_len, i + 1):
                long_question = ''
                if j < 0:
                    continue
                long_question += ' ' + datum['questions'][j]['input_text']
                if j < i:
                    long_question += ' ' + datum['answers'][j]['input_text'] + ' '
                long_question = long_question.strip()
                long_questions.append(long_question)
            doc_tok = _datum['annotated_context']['word']
            if len(doc_tok) == 0:
                continue
            example = CoqaExample(
                qas_id = _datum['id'] + ' ' + str(_qas['turn_id']),
                question_text = long_questions,
                doc_tokens = doc_tok,
                orig_answer_text = _qas['raw_answer'],
                start_position = _qas['answer_span'][0],
                end_position = _qas['answer_span'][1],
                rational_start_position = r_start,
                rational_end_position = r_end,
                additional_answers=_qas['additional_answers'] if 'additional_answers' in _qas else None,
            )
            examples.append(example)
        return examples

train_examples = Processor().get_examples(evaluate = False)
test_examples = Processor().get_examples(evaluate = True)

Preprocessing examples: 100%|██████████| 7199/7199 [18:43<00:00,  6.41it/s]
Preprocessing examples: 100%|██████████| 500/500 [01:18<00:00,  6.38it/s]


## Exploring the processed examples

Now that our example questions have been processed we can explore what we have achieved.

Let us take an example from the dataset -

```
Story : Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton.
Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept.
But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters.
All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch.
The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy.
Being different made Cotton quite sad. She often wished she looked like the rest of her family.
So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them.
When her mommy and sisters found her they started laughing.

"What are you doing, Cotton?!"

"I only wanted to be more like you".

Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be any other way". And with that, Cotton's mommy picked her up and dropped her into a big bucket of water. When Cotton came out she was herself again. Her sisters licked her face until Cotton's fur was all all dry.

"Don't ever do that again, Cotton!" they all cried. "Next time you might mess up that pretty white fur of yours and we wouldn't want that!"

Then Cotton thought, "I change my mind. I like being special"

Question : What color was Cotton?

Answer : {SPAN - (59 -> 93),
          RATIONALE - a little white kitten named Cotton,
          GROUND TRUTH - white }

Question : Did she live alone?

Answer : {SPAN - (196 -> 215),
          RATIONALE - Cotton wasn't alone,
          GROUND TRUTH - no, }

```

**And see what our pre-processing has done to it...**


In [17]:
print(test_examples[0])
print(test_examples[0].doc_tokens[test_examples[0].start_position:test_examples[0].end_position+1])

print(test_examples[2])
print(test_examples[2].doc_tokens[test_examples[2].start_position:test_examples[2].end_position+1])

Question : ['What color was Cotton?']
                   Document : ['Once', 'upon', 'a', 'time', ',', 'in', 'a', 'barn', 'near', 'a', 'farm', 'house', ',', 'there', 'lived', 'a', 'little', 'white', 'kitten', 'named', 'Cotton', '.', 'Cotton', 'lived', 'high', 'up', 'in', 'a', 'nice', 'warm', 'place', 'above', 'the', 'barn', 'where', 'all', 'of', 'the', 'farmer', "'s", 'horses', 'slept', '.', 'But', 'Cotton', 'was', "n't", 'alone', 'in', 'her', 'little', 'home', 'above', 'the', 'barn', ',', 'oh', 'no', '.', 'She', 'shared', 'her', 'hay', 'bed', 'with', 'her', 'mommy', 'and', '5', 'other', 'sisters', '.', 'All', 'of', 'her', 'sisters', 'were', 'cute', 'and', 'fluffy', ',', 'like', 'Cotton', '.', 'But', 'she', 'was', 'the', 'only', 'white', 'one', 'in', 'the', 'bunch', '.', 'The', 'rest', 'of', 'her', 'sisters', 'were', 'all', 'orange', 'with', 'beautiful', 'white', 'tiger', 'stripes', 'like', 'Cotton', "'s", 'mommy', '.', 'Being', 'different', 'made', 'Cotton', 'quite', 'sad', '.', 'She'

Notice how the document **has been turned into an array of words** and the ground truth answer has been turned into **spans** from the document.

Play around with other examples from the dataset to see this in action.

# Now that we have the examples processed we can convert them to vectors


In [18]:
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)
    return (input_start, input_end)

def _check_is_max_context(doc_spans, cur_span_index, position):
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index
    return cur_span_index == best_span_index

def Extract_Feature_init(tokenizer_for_convert):
    global tokenizer
    tokenizer = tokenizer_for_convert

def Extract_Feature(example, tokenizer, max_seq_length = 512, doc_stride = 128, max_query_length = 64):
    """ Extract features for a single turn of QA """
    features = []
    query_tokens = []
    for question_answer in example.question_text:
        query_tokens.extend(tokenizer.tokenize(question_answer))
    cls_idx = 3
    if example.orig_answer_text == 'yes':
        cls_idx = 0  # yes
    elif example.orig_answer_text == 'no':
        cls_idx = 1  # no
    elif example.orig_answer_text == 'unknown':
        cls_idx = 2  # unknown
    if len(query_tokens) > max_query_length:
        # keep tail
        query_tokens = query_tokens[-max_query_length:]
    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    for (i, token) in enumerate(example.doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = tokenizer.tokenize(token)
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)
    tok_r_start_position = orig_to_tok_index[example.rational_start_position]
    if example.rational_end_position < len(example.doc_tokens) - 1:
        tok_r_end_position = orig_to_tok_index[example.rational_end_position + 1] - 1
    else:
        tok_r_end_position = len(all_doc_tokens) - 1
    if cls_idx < 3:
        tok_start_position, tok_end_position = 0, 0
    else:
        tok_start_position = orig_to_tok_index[example.start_position]
        if example.end_position < len(example.doc_tokens) - 1:
            tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
        else:
            tok_end_position = len(all_doc_tokens) - 1
        (tok_start_position, tok_end_position) = _improve_answer_span(
            all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
            example.orig_answer_text)
    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
        length = len(all_doc_tokens) - start_offset
        if length > max_tokens_for_doc:
            length = max_tokens_for_doc
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == len(all_doc_tokens):
            break
        start_offset += min(length, doc_stride)
    for (doc_span_index, doc_span) in enumerate(doc_spans):
        slice_cls_idx = cls_idx
        tokens = []
        token_to_orig_map = {}
        token_is_max_context = {}
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in query_tokens:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        for i in range(doc_span.length):
            split_token_index = doc_span.start + i
            token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

            is_max_context = _check_is_max_context(doc_spans,
                                                   doc_span_index,
                                                   split_token_index)
            token_is_max_context[len(tokens)] = is_max_context
            tokens.append(all_doc_tokens[split_token_index])
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens.
        input_mask = [1] * len(input_ids)

        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False
        if example.rational_start_position == -1 or not (
                tok_r_start_position >= doc_start and tok_r_end_position <= doc_end):
            out_of_span = True
        if out_of_span:
            rational_start_position = 0
            rational_end_position = 0
        else:
            doc_offset = len(query_tokens) + 2
            rational_start_position = tok_r_start_position - doc_start + doc_offset
            rational_end_position = tok_r_end_position - doc_start + doc_offset

        rational_mask = [0] * len(input_ids)
        if not out_of_span:
            rational_mask[rational_start_position:rational_end_position + 1] = [1] * (
                        rational_end_position - rational_start_position + 1)

        if cls_idx >= 3:
            # For training, if our document chunk does not contain an annotation we remove it
            doc_start = doc_span.start
            doc_end = doc_span.start + doc_span.length - 1
            out_of_span = False
            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                out_of_span = True
            if out_of_span:
                start_position = 0
                end_position = 0
                slice_cls_idx = 2
            else:
                doc_offset = len(query_tokens) + 2
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset
        else:
            start_position = 0
            end_position = 0

        features.append(
            CoqaFeatures(example_index=0,
                         unique_id=0,
                         doc_span_index=doc_span_index,
                         tokens=tokens,
                         token_to_orig_map=token_to_orig_map,
                         token_is_max_context=token_is_max_context,
                         input_ids=input_ids,
                         input_mask=input_mask,
                         segment_ids=segment_ids,
                         start_position=start_position,
                         end_position=end_position,
                         cls_idx=slice_cls_idx,
                         rational_mask=rational_mask))
    return features


def Extract_Features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training,threads=1):
    """ Convert to vectors """
    features = []
    threads = min(threads, cpu_count())
    with Pool(threads, initializer=Extract_Feature_init, initargs=(tokenizer,)) as p:
        annotate_ = partial( Extract_Feature, tokenizer=tokenizer, max_seq_length=max_seq_length,
                doc_stride=doc_stride, max_query_length=max_query_length,)
        features = list( tqdm( p.imap(annotate_, examples, chunksize=32),
                              total=len(examples), desc="Extracting features from dataset",))
    new_features = []
    unique_id = 1000000000
    example_index = 0
    for example_features in tqdm(features, total=len(features), desc="Tag unique id to each example"):
        if not example_features:
            continue
        for example_feature in example_features:
            example_feature.example_index = example_index
            example_feature.unique_id = unique_id
            new_features.append(example_feature)
            unique_id += 1
        example_index += 1
    features = new_features
    del new_features
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_tokentype_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    if not is_training:
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_tokentype_ids, all_input_mask, all_example_index)
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        all_rational_mask = torch.tensor([f.rational_mask for f in features], dtype=torch.long)
        all_cls_idx = torch.tensor([f.cls_idx for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_tokentype_ids, all_input_mask, all_start_positions,
                                all_end_positions, all_rational_mask, all_cls_idx)
    return features, dataset

In [22]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def load_dataset(examples, tokenizer, evaluate=False):
    features, dataset = Extract_Features(examples=examples, tokenizer=tokenizer,
                    max_seq_length=512, doc_stride=128, max_query_length=64, is_training=not evaluate, threads=12)
    return dataset, features

train_dataset, train_features = load_dataset(train_examples, tokenizer)
test_dataset, test_features = load_dataset(test_examples, tokenizer, evaluate = True)

Extracting features from dataset: 100%|██████████| 108647/108647 [09:43<00:00, 186.25it/s]
Tag unique id to each example: 100%|██████████| 108647/108647 [00:00<00:00, 779599.55it/s]
Extracting features from dataset: 100%|██████████| 7983/7983 [00:43<00:00, 183.47it/s]
Tag unique id to each example: 100%|██████████| 7983/7983 [00:00<00:00, 661604.24it/s]


## Let us explore the extracted features with a few examples

Note how we transform the input with addition of special tokens and all our pre-processing.

What started as a phrase from the context paragraph, now becomes a span of tokens.

In [28]:
ex = test_features[0]

print(ex.tokens)
print(f"Answer : {ex.tokens[ex.start_position:ex.end_position+1]} which is position ({ex.start_position}, {ex.end_position})")

['[CLS]', 'what', 'color', 'was', 'cotton', '?', '[SEP]', 'once', 'upon', 'a', 'time', ',', 'in', 'a', 'barn', 'near', 'a', 'farm', 'house', ',', 'there', 'lived', 'a', 'little', 'white', 'kitten', 'named', 'cotton', '.', 'cotton', 'lived', 'high', 'up', 'in', 'a', 'nice', 'warm', 'place', 'above', 'the', 'barn', 'where', 'all', 'of', 'the', 'farmer', "'", 's', 'horses', 'slept', '.', 'but', 'cotton', 'was', 'n', "'", 't', 'alone', 'in', 'her', 'little', 'home', 'above', 'the', 'barn', ',', 'oh', 'no', '.', 'she', 'shared', 'her', 'hay', 'bed', 'with', 'her', 'mommy', 'and', '5', 'other', 'sisters', '.', 'all', 'of', 'her', 'sisters', 'were', 'cute', 'and', 'fluffy', ',', 'like', 'cotton', '.', 'but', 'she', 'was', 'the', 'only', 'white', 'one', 'in', 'the', 'bunch', '.', 'the', 'rest', 'of', 'her', 'sisters', 'were', 'all', 'orange', 'with', 'beautiful', 'white', 'tiger', 'stripes', 'like', 'cotton', "'", 's', 'mommy', '.', 'being', 'different', 'made', 'cotton', 'quite', 'sad', '.', 

However, the transformer needs "one-hot-embeddings", or the index of each token. The input to the transformer looks like the following. Where [CLS] has token id 101, "what" has token id 2054, etc

In [29]:
print(ex.input_ids)

[101, 2054, 3609, 2001, 6557, 1029, 102, 2320, 2588, 1037, 2051, 1010, 1999, 1037, 8659, 2379, 1037, 3888, 2160, 1010, 2045, 2973, 1037, 2210, 2317, 18401, 2315, 6557, 1012, 6557, 2973, 2152, 2039, 1999, 1037, 3835, 4010, 2173, 2682, 1996, 8659, 2073, 2035, 1997, 1996, 7500, 1005, 1055, 5194, 7771, 1012, 2021, 6557, 2001, 1050, 1005, 1056, 2894, 1999, 2014, 2210, 2188, 2682, 1996, 8659, 1010, 2821, 2053, 1012, 2016, 4207, 2014, 10974, 2793, 2007, 2014, 20565, 1998, 1019, 2060, 5208, 1012, 2035, 1997, 2014, 5208, 2020, 10140, 1998, 27036, 1010, 2066, 6557, 1012, 2021, 2016, 2001, 1996, 2069, 2317, 2028, 1999, 1996, 9129, 1012, 1996, 2717, 1997, 2014, 5208, 2020, 2035, 4589, 2007, 3376, 2317, 6816, 13560, 2066, 6557, 1005, 1055, 20565, 1012, 2108, 2367, 2081, 6557, 3243, 6517, 1012, 2016, 2411, 6257, 2016, 2246, 2066, 1996, 2717, 1997, 2014, 2155, 1012, 2061, 2028, 2154, 1010, 2043, 6557, 2179, 1037, 2064, 1997, 1996, 2214, 7500, 1005, 1055, 4589, 6773, 1010, 2016, 2109, 2009, 2000, 6773

## Now we can define our model and train it.

In the following code blocks we will first define our transformer model, followed by training.

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import BertModel


class Bert(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden_size = self.bert.config.hidden_size
        self.span_modelling = nn.Linear(hidden_size, 2,bias = False)

        self.fc = nn.Linear(hidden_size,hidden_size, bias = False)
        self.fc2 = nn.Linear(hidden_size,hidden_size, bias = False)
        self.rationale_modelling = nn.Linear(hidden_size, 1, bias = False)
        self.attention_modelling = nn.Linear(hidden_size,1, bias = False)
        self.unk_modelling = nn.Linear(2*hidden_size,1, bias = False)
        self.yes_no_modelling = nn.Linear(2*hidden_size,2, bias = False)
        self.relu = nn.ReLU()
        self.beta = 5.0
    def forward(self, input_ids, segment_ids=None, input_masks=None,
                start_positions=None, end_positions=None, rationale_mask=None, cls_idx=None):
        #   Bert-base outputs
        output_vector, bert_pooled_output = self.bert(input_ids,token_type_ids=segment_ids,attention_mask=input_masks,
                                                      head_mask = None, return_dict=False) # output from BERT model
        start_end_logits = self.span_modelling(output_vector)  #predict start and end positions
        start_logits, end_logits = start_end_logits.split(1, dim=-1)
        start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)

        rationale_logits = self.relu(self.fc(output_vector))
        rationale_logits = self.rationale_modelling(rationale_logits)
        rationale_logits = torch.sigmoid(rationale_logits) #predict rationale positions

        output_vector = output_vector * rationale_logits
        attention  = self.relu(self.fc2(output_vector))
        attention  = (self.attention_modelling(attention)).squeeze(-1)
        input_masks = input_masks.type(attention.dtype)
        attention = attention*input_masks - (1-input_masks)*1e30
        attention = F.softmax(attention, dim=-1)
        attention_pooled_output = (attention.unsqueeze(-1) * output_vector).sum(dim=-2)
        cls_output = torch.cat((attention_pooled_output,bert_pooled_output),dim = -1) #predict yes/no/unk
        rationale_logits = rationale_logits.squeeze(-1)
        unk_logits = self.unk_modelling(cls_output)
        yes_no_logits = self.yes_no_modelling(cls_output)
        yes_logits, no_logits = yes_no_logits.split(1, dim=-1)
        if self.training:
            start_positions, end_positions = start_positions + cls_idx, end_positions + cls_idx #predict start and end positions
            start = torch.cat((yes_logits, no_logits, unk_logits, start_logits), dim=-1)
            end = torch.cat((yes_logits, no_logits, unk_logits, end_logits), dim=-1)
            Entropy_loss = CrossEntropyLoss()
            start_loss = Entropy_loss(start, start_positions)
            end_loss = Entropy_loss(end, end_positions)
            rationale_positions = rationale_mask.type(attention.dtype) #predict rationale positions
            rationale_loss = (-rationale_positions*torch.log(rationale_logits + 1e-8)
                             - (1-rationale_positions)*torch.log(1-rationale_logits + 1e-8))
            rationale_loss = torch.mean(rationale_loss)
            total_loss = (start_loss + end_loss) / 2.0 + rationale_loss * self.beta #calculate total loss (multi-task)
            return total_loss
        return start_logits, end_logits, yes_logits, no_logits, unk_logits

model = Bert()

In [34]:
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

!mkdir "./weights"
device = torch.device("cuda")
model.to(device)

def train(model):
    epochs = 1
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle = True)
    optimizer_parameters = [{"params": [p for n, p in model.named_parameters()  #Which parameters to optimize
                                if not any(nd in n for nd in ["bias", "LayerNorm.weight"])],"weight_decay": 0.01,},
                            {"params": [p for n, p in model.named_parameters()
                                if any(nd in n for nd in ["bias", "LayerNorm.weight"])], "weight_decay": 0.0}]
    optimizer = AdamW(optimizer_parameters, lr = 3e-5, eps = 1e-8) #The optimizer (AdamW)
    scheduler = get_linear_schedule_with_warmup(optimizer,  num_warmup_steps=500, num_training_steps=(len(train_dataloader)//epochs))
    # Learning rate scheduler
    counter,train_loss, loss = 1, 0.0, 0.0
    model.zero_grad()
    for ep in range(epochs): #Epoch iteration
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for i,batch in enumerate(epoch_iterator): #Iterating over mini-batches
            model.train()
            batch = tuple(t.to(device) for t in batch)
            inputs = { "input_ids": batch[0],"segment_ids": batch[1],
                  "input_masks": batch[2],"start_positions": batch[3],
                  "end_positions": batch[4],"rationale_mask": batch[5],"cls_idx": batch[6]}
            loss = model(**inputs)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            counter += 1
            epoch_iterator.set_description(f"Epoch {ep+1}/{epochs} | Loss : {(train_loss/counter)}")
            epoch_iterator.refresh()
            if counter % 1000 == 0:
                save_path = os.path.join("./weights", "weights.pth")
                torch.save(model.state_dict(), save_path)
    return model

trained_model = train(model)

mkdir: cannot create directory ‘./weights’: File exists


Epoch 1/1 | Loss : 2.452767640136831: 100%|██████████| 7193/7193 [55:26<00:00,  2.16it/s]


# Now we need our inference logic

We need to convert the span predictions, yes/no/unk predictions, etc to a textual answer.

The logic is roughly as follows.
  - Get the best yes, no, unknown prediction scores.
  - Get top 20 start positions and end positions (400 spans).
  - Out of these choose top 20 again.
  - Convert these to text (yes, no, unknown or from the span).
  - Softmax over the scores of these candidates.
  - Pick best answer.

In [39]:
from transformers.models.bert.tokenization_bert import BasicTokenizer
import math

def get_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, do_lower_case, output_prediction_file, verbose_logging, tokenizer):
    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)
    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result
    _PrelimPrediction = collections.namedtuple("PrelimPrediction", ["feature_index", "start_index", "end_index", "score", "cls_idx",])
    all_predictions = []
    all_nbest_json = collections.OrderedDict()
    for (example_index, example) in enumerate(tqdm(all_examples, desc="Writing preditions")):
        features = example_index_to_features[example_index]
        prelim_predictions = []
        score_yes, score_no, score_span, score_unk = -float('INF'), -float('INF'), -float('INF'), -float('INF')
        min_unk_feature_index, max_yes_feature_index, max_no_feature_index, max_span_feature_index = -1, -1, -1, -1
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            feature_yes_score, feature_no_score, feature_unk_score = \
                result.yes_logits[0] * 2, result.no_logits[0] * 2, result.unk_logits[0] * 2
            # Compute 20 best span starts/span ends
            start_indexes, end_indexes = _get_best_indexes(result.start_logits, n_best_size), \
                                         _get_best_indexes(result.end_logits, n_best_size)
            # Compute top (20x20) span scores
            for start_index in start_indexes:
                for end_index in end_indexes:
                    """ Perform validation """
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    feature_span_score = result.start_logits[start_index] + result.end_logits[end_index]
                    prelim_predictions.append(_PrelimPrediction(feature_index=feature_index,start_index=start_index,end_index=end_index,score=feature_span_score,cls_idx=3))
            if feature_unk_score < score_unk:
                score_unk = feature_unk_score
                min_unk_feature_index = feature_index
            if feature_yes_score > score_yes:
                score_yes = feature_yes_score
                max_yes_feature_index = feature_index
            if feature_no_score > score_no:
                score_no = feature_no_score
                max_no_feature_index = feature_index
        #including yes/no/unknown answers in preliminary predictions.
        prelim_predictions.append(_PrelimPrediction(feature_index=min_unk_feature_index,start_index=0,end_index=0,score=score_unk,cls_idx=2))
        prelim_predictions.append(_PrelimPrediction(feature_index=max_yes_feature_index,start_index=0,end_index=0,score=score_yes,cls_idx=0))
        prelim_predictions.append(_PrelimPrediction(feature_index=max_no_feature_index,start_index=0,end_index=0,score=score_no,cls_idx=1))
        # Get best yes/no/unk/span
        prelim_predictions = sorted(prelim_predictions,key=lambda p: p.score,reverse=True)
        _NbestPrediction = collections.namedtuple("NbestPrediction", ["text", "score", "cls_idx"])
        seen_predictions = {}
        nbest = []
        # From all preliminary predictions choose top 20
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            # free-form answers (ie span answers)
            if pred.cls_idx == 3:
                tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
                # removing whitespaces
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)
                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
                if final_text in seen_predictions:
                    continue
                seen_predictions[final_text] = True
                nbest.append(_NbestPrediction(text=final_text,score=pred.score,cls_idx=pred.cls_idx))
            # 'yes'/'no'/'unknown' answers
            else:
                text = ['yes', 'no', 'unknown']
                nbest.append(_NbestPrediction(text=text[pred.cls_idx], score=pred.score, cls_idx=pred.cls_idx))
        if len(nbest) < 1:
            nbest.append(_NbestPrediction(text='unknown', score=-float('inf'), cls_idx=2))
        assert len(nbest) >= 1
        probs = _compute_softmax([p.score for p in nbest])
        nbest_json = []
        # Get final best predictions
        for i, entry in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["score"] = entry.score
            nbest_json.append(output)
        assert len(nbest_json) >= 1
        _id, _turn_id = example.qas_id.split()
        all_predictions.append({
            'id': _id,
            'turn_id': int(_turn_id),
            'answer': confirm_preds(nbest_json)})
        all_nbest_json[example.qas_id] = nbest_json
    #   Writing all the predictions in the predictions.json file in the BERT directory
    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")
    return all_predictions

def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)
    tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    tok_text = " ".join(tokenizer.tokenize(orig_text))
    start_position = tok_text.find(pred_text)
    if start_position == -1:
        return orig_text
    end_position = start_position + len(pred_text) - 1
    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
    if len(orig_ns_text) != len(tok_ns_text):
        return orig_text
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i
    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]
    if orig_start_position is None:
        return orig_text
    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]
    if orig_end_position is None:
        return orig_text
    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text

def confirm_preds(nbest_json):
    #unsuccessful attempt at trying to predict for how many and True or false type of questions
    subs = [ 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine','ten', 'eleven', 'twelve', 'true', 'false']
    ori = nbest_json[0]['text']
    if len(ori) < 2:
        for e in nbest_json[1:]:
            if _normalize_answer(e['text']) in subs:
                return e['text']
        return 'unknown'
    return ori

def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []
    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score
    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x
    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs

def _normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()

## Now we can evaluate our trained model.

In [43]:
def convert_to_list(tensor):
    """ Helper function to convert predictions to python list"""
    return tensor.detach().cpu().tolist()

def write_predictions(model):
    #dataset, examples, features = load_dataset(tokenizer, evaluate=True)
    evaluation_dataloader = DataLoader(test_dataset, batch_size=32, shuffle = False)
    mod_results = []
    for batch in tqdm(evaluation_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = {"input_ids": batch[0],"segment_ids": batch[1],"input_masks": batch[2]}
            example_indices = batch[3]
            outputs = model(**inputs)
        for i, example_index in enumerate(example_indices):
            eval_feature = test_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            output = [convert_to_list(output[i]) for output in outputs]
            start_logits, end_logits, yes_logits, no_logits, unk_logits = output
            result = Result(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits, yes_logits=yes_logits, no_logits=no_logits, unk_logits=unk_logits)
            mod_results.append(result)

    output_prediction_file = os.path.join("./weights", "predictions.json")
    return get_predictions(test_examples, test_features, mod_results, 20, 30, True, output_prediction_file, False, tokenizer), mod_results

final_predictions, model_predictions = write_predictions(trained_model)

Evaluating: 100%|██████████| 268/268 [01:19<00:00,  3.36it/s]
Writing preditions: 100%|██████████| 7983/7983 [00:29<00:00, 269.25it/s]


## What does a prediction look like?

Notice how we are converting the logits into text (inference logic above). In this example the start,end span wins out over yes, no, unk.

In [49]:
logits = model_predictions[0]
print(len(logits.start_logits), len(logits.end_logits))
print(f"Start logits : {logits.start_logits}")
print(f"Best start position : {np.argmax(logits.start_logits)}")
print(f"End logits : {logits.end_logits}")
print(f"Best end position : {np.argmax(logits.end_logits)}")
print(f"Yes score : {logits.yes_logits}")
print(f"No score : {logits.no_logits}")
print(f"Unknown score : {logits.unk_logits}")

print(final_predictions[0]['answer'])

512 512
Start logits : [-4.226983547210693, -5.639468193054199, -6.5320820808410645, -6.843592643737793, -6.946135520935059, -6.551481246948242, -7.345547676086426, -3.3142457008361816, -4.395535469055176, -5.318740367889404, -5.1656599044799805, -5.001951217651367, -2.6636996269226074, -3.4633851051330566, -4.686534881591797, -4.249269008636475, -4.287463188171387, -5.1372456550598145, -6.298287868499756, -6.521264553070068, -2.968651294708252, -4.071929931640625, -0.21001997590065002, 0.3773552179336548, 5.7509565353393555, 1.3557193279266357, -2.802481174468994, 0.12395046651363373, -3.8300960063934326, -0.7991817593574524, -3.9897940158843994, -4.150891304016113, -5.782223224639893, -4.312005043029785, -4.730404376983643, -4.716391086578369, -3.2292027473449707, -6.046236991882324, -4.6751298904418945, -5.898762226104736, -5.575925827026367, -5.49281120300293, -3.6707310676574707, -5.429015636444092, -5.542548179626465, -4.938834190368652, -6.9178571701049805, -6.691030025482178, -

## Now we can evaluate our answers with the official evaluation script.

This is stored in processors/eval.py

We should have an exact match (em) score of around 65% and an F1 score of around 75%.

In [None]:
from processors.eval import CoQAEvaluator

def print_results():
    evaluator = CoQAEvaluator("./data/coqa-dev-v1.0.json")
    prediction_file = os.path.join("./weights", "predictions.json")
    with open(prediction_file) as f:
        pred_data = CoQAEvaluator.preds_to_dict(prediction_file)
    print(json.dumps(evaluator.model_performance(pred_data), indent=2))

print_results()


## Now we can use our very own story and questions to get desired answers.

We have to go through the whole pipeline to convert our story, question to tensors, run inference and decode output.

In [67]:
Story = """In the hushed confines of a cutting-edge particle physics laboratory, Dr. Samuel Reynolds and his team embarked on a journey to uncover the mysteries of the universe.
They sought to unravel the enigma of dark matter, an invisible substance that eluded comprehension for centuries.
One fateful day, as the team conducted an experiment involving the behavior of subatomic particles, they noticed an anomaly.
Particles interacted in a peculiar manner, exhibiting an inexplicable attraction to one another.
Dr. Reynolds named this uncharted force the "Cryptic Force," symbolizing the hidden secrets it promised to unveil.

The discovery sent shockwaves through the scientific community.
Skepticism abounded initially, but subsequent experiments replicated the results.
The Cryptic Force, distinct from the four known fundamental forces, held the potential to bridge gaps in our understanding of the cosmos.
Cryptic Force seemed to be intertwined with dark matter, offering a glimmer of hope in solving the puzzle of the universe's invisible enigma.
The revelation ignited new lines of inquiry into the nature of dark matter, as scientists worldwide raced to unlock its secrets."""

Question = "Who discovered the cryptic force?"

input = {'story':Story,
         'questions':[{'turn_id':0, 'input_text':Question}],
         'answers':[{'turn_id':0, 'span_start':0, 'span_end':0, 'span_text':"", 'input_text':""}],
         'source':"Me", 'filename':"Me", 'id':"Me"}

my_example = Processor()._create_examples(input, 1)
print(my_example)

[Question : ['Who discovered the cryptic force?']
                   Document : ['In', 'the', 'hushed', 'confines', 'of', 'a', 'cutting', '-', 'edge', 'particle', 'physics', 'laboratory', ',', 'Dr.', 'Samuel', 'Reynolds', 'and', 'his', 'team', 'embarked', 'on', 'a', 'journey', 'to', 'uncover', 'the', 'mysteries', 'of', 'the', 'universe', '.', 'They', 'sought', 'to', 'unravel', 'the', 'enigma', 'of', 'dark', 'matter', ',', 'an', 'invisible', 'substance', 'that', 'eluded', 'comprehension', 'for', 'centuries', '.', 'One', 'fateful', 'day', ',', 'as', 'the', 'team', 'conducted', 'an', 'experiment', 'involving', 'the', 'behavior', 'of', 'subatomic', 'particles', ',', 'they', 'noticed', 'an', 'anomaly', '.', 'Particles', 'interacted', 'in', 'a', 'peculiar', 'manner', ',', 'exhibiting', 'an', 'inexplicable', 'attraction', 'to', 'one', 'another', '.', 'Dr.', 'Reynolds', 'named', 'this', 'uncharted', 'force', 'the', '"', 'Cryptic', 'Force', ',', '"', 'symbolizing', 'the', 'hidden', 'secrets', '

## Convert to tensors

In [71]:
my_dataset, my_features = load_dataset(my_example, tokenizer, evaluate = True)
my_features = my_features[0]
print(my_features.tokens)

inp = {'input_ids':torch.tensor([my_features.input_ids]).to(device),
       "segment_ids": torch.tensor([my_features.segment_ids]).to(device),
       "input_masks": torch.tensor([my_features.input_mask]).to(device)}

Extracting features from dataset: 100%|██████████| 1/1 [00:00<00:00, 22.96it/s]
Tag unique id to each example: 100%|██████████| 1/1 [00:00<00:00, 13315.25it/s]

['[CLS]', 'who', 'discovered', 'the', 'cryptic', 'force', '?', '[SEP]', 'in', 'the', 'hushed', 'confines', 'of', 'a', 'cutting', '-', 'edge', 'particle', 'physics', 'laboratory', ',', 'dr', '.', 'samuel', 'reynolds', 'and', 'his', 'team', 'embarked', 'on', 'a', 'journey', 'to', 'uncover', 'the', 'mysteries', 'of', 'the', 'universe', '.', 'they', 'sought', 'to', 'un', '##rave', '##l', 'the', 'enigma', 'of', 'dark', 'matter', ',', 'an', 'invisible', 'substance', 'that', 'el', '##uded', 'comprehension', 'for', 'centuries', '.', 'one', 'fate', '##ful', 'day', ',', 'as', 'the', 'team', 'conducted', 'an', 'experiment', 'involving', 'the', 'behavior', 'of', 'sub', '##ato', '##mic', 'particles', ',', 'they', 'noticed', 'an', 'anomaly', '.', 'particles', 'interact', '##ed', 'in', 'a', 'peculiar', 'manner', ',', 'exhibiting', 'an', 'in', '##ex', '##pl', '##ica', '##ble', 'attraction', 'to', 'one', 'another', '.', 'dr', '.', 'reynolds', 'named', 'this', 'un', '##cha', '##rted', 'force', 'the', '"




## Now we can pass it through the model and get logits.



In [83]:
with torch.no_grad():
    outp = trained_model(**inp)
outp = [o.detach().cpu().tolist()[0] for o in outp]
start_logits, end_logits, yes_logits, no_logits, unk_logits = outp
print(start_logits)
print(end_logits)
print(yes_logits)
print(no_logits)
print(unk_logits)

[-3.0385701656341553, -4.43826961517334, -6.55425500869751, -7.164093017578125, -7.3818206787109375, -7.840551376342773, -4.996629238128662, -6.9712443351745605, -0.9702745676040649, -3.4593305587768555, -4.290556907653809, -5.0918169021606445, -4.2428483963012695, -0.8664641380310059, -3.257561683654785, -5.441644668579102, -6.145437717437744, -1.9303851127624512, -3.901210308074951, -3.3312127590179443, -2.6177828311920166, 7.235795021057129, -0.4374772012233734, 4.472531318664551, 1.6176892518997192, 1.6589255332946777, 2.968040943145752, 2.080980062484741, -1.4893442392349243, -3.5463972091674805, -1.8986639976501465, -3.786787509918213, -2.7399799823760986, -3.836890459060669, -2.2709193229675293, -4.152251720428467, -4.979352951049805, -2.576104164123535, -4.519848823547363, -6.162850856781006, -0.3449029326438904, -4.240826606750488, -3.511794090270996, -4.128286838531494, -6.265950679779053, -6.257771968841553, -2.7426671981811523, -3.521883964538574, -5.332711219787598, -1.795

# Simplified inference

Since, the span is the best answer, our prediction should be the span.

In [86]:
result = [np.amax(start_logits)+np.amax(end_logits), yes_logits, no_logits, unk_logits]
print(_compute_softmax(result))
best_span_start = np.argmax(start_logits)
best_span_end = np.argmax(end_logits)

print("The answer is : ")
print(" ".join(my_features.tokens[best_span_start:best_span_end+1]))

[0.9999981877234422, 1.6693900411397654e-07, 5.934713645617118e-08, 1.5859904173789526e-06]
The answer is : 
dr . samuel reynolds


Our system seems to be working perfectly and produces the right answer. Trace back your steps and try with your own questions, or stories.