<a href="https://colab.research.google.com/github/hemantwani/GPT2_test_case/blob/master/longformer_qa_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Longformer for Question Answering

In [None]:
!nvidia-smi

Wed Sep  2 16:08:40 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.66       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!git clone https://github.com/huggingface/transformers.git
!pip install -U ./transformers
!pip install git+https://github.com/huggingface/nlp.git

Cloning into 'transformers'...
remote: Enumerating objects: 87, done.
remote: Counting objects: 100% (87/87), done.
remote: Compressing objects: 100% (76/76), done.
remote: Total 40218 (delta 40), reused 28 (delta 3), pack-reused 40131
Receiving objects: 100% (40218/40218), 28.97 MiB | 32.17 MiB/s, done.
Resolving deltas: 100% (27882/27882), done.
Processing ./transformers
Collecting tokenizers==0.8.1.rc2
  Downloading https://files.pythonhosted.org/packages/80/83/8b9fccb9e48eeb575ee19179e2bdde0ee9a1904f97de5f02d19016b8804f/tokenizers-0.8.1rc2-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
Collecting sentencepiece!=0.1.92
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
Collecting sacremoses
  Downloading http

The Longformer model was presented in [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, and Arman Cohan. As the paper explains:

> `Longformer` is a BERT-like model for long documents.


Training Longformer for QA is similar to training BERT for QA, but there are a few things to keep in mind when using Longformer for the QA task.

Longformer uses sliding-window local attention, which scales linearly with sequence length. This is what allows Longformer to handle longer sequences. For more details on how sliding-window attention works, please refer to the paper. Along with local attention, Longformer also lets you use global attention for certain tokens. For the QA task, all question tokens should have global attention.
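To make the linear-scaling claim concrete, here is a rough counting sketch. It compares the number of query-key scores computed by full self-attention with the number computed by a sliding window. The function names and the window size of 512 are assumptions chosen for illustration, and the count ignores global attention.

```python
def full_attention_pairs(seq_len):
    # Every token attends to every token: quadratic in seq_len.
    return seq_len * seq_len


def sliding_window_pairs(seq_len, window):
    # Each token attends to itself and up to window // 2 neighbours on each side.
    half = window // 2
    total = 0
    for i in range(seq_len):
        lo = max(0, i - half)
        hi = min(seq_len - 1, i + half)
        total += hi - lo + 1
    return total


for n in (512, 1024, 2048, 4096):
    print(n, full_attention_pairs(n), sliding_window_pairs(n, window=512))
```

Doubling the sequence length roughly doubles the sliding-window count, while the full-attention count grows fourfold.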

Attention is configured through the `attention_mask` parameter of the `forward` method of `LongformerForQuestionAnswering`. Mask values are chosen from [0, 1, 2]: 0 for no attention (padding tokens), 1 for local (sliding-window) attention, and 2 for global attention (tokens that attend to all other tokens and are attended to by all other tokens). Newer Transformers releases keep `attention_mask` binary and take global attention through a separate `global_attention_mask` argument instead.
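As an illustration of the 0/1/2 convention described above, the following sketch builds such a mask by hand for one padded example. `build_attention_mask` is a hypothetical helper, not part of Transformers, and the model normally derives the global positions itself.

```python
def build_attention_mask(num_question_tokens, num_tokens, max_length):
    """Return a mask with 2 for question tokens, 1 for other real tokens, 0 for padding."""
    if not 0 <= num_question_tokens <= num_tokens <= max_length:
        raise ValueError("expected 0 <= num_question_tokens <= num_tokens <= max_length")
    return (
        [2] * num_question_tokens                   # global attention
        + [1] * (num_tokens - num_question_tokens)  # local sliding-window attention
        + [0] * (max_length - num_tokens)           # padding, no attention
    )


print(build_attention_mask(num_question_tokens=3, num_tokens=7, max_length=10))
```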

As stated above, all question tokens should be given global attention. The `LongformerForQuestionAnswering` model handles this automatically for you. To allow it to do that:
1. The input sequence must have three sep tokens, i.e. the sequence should be encoded like this: `<s> question</s></s> context</s>`. If you encode the question and context as an input pair, the tokenizer already takes care of that, so you don't need to worry about it.
2. `input_ids` should always be a batch of examples.
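The two requirements above can be sketched in plain Python. This is only an approximation of what the model is described as doing, not its actual implementation: `question_token_mask` is a hypothetical helper, and the separator id of 2 is assumed to match the RoBERTa vocabulary that Longformer reuses.

```python
SEP_TOKEN_ID = 2  # assumed id of </s>


def question_token_mask(input_ids, sep_token_id=SEP_TOKEN_ID):
    """For each sequence in the batch, mark tokens before the first separator with 1."""
    masks = []
    for ids in input_ids:  # input_ids is a batch: a list of token-id sequences
        sep_positions = [i for i, token_id in enumerate(ids) if token_id == sep_token_id]
        if len(sep_positions) != 3:
            raise ValueError("each sequence needs exactly three separator tokens")
        first_sep = sep_positions[0]
        masks.append([1 if i < first_sep else 0 for i in range(len(ids))])
    return masks


# <s> q1 q2 </s></s> c1 c2 c3 </s>, with made-up ids for the question and context tokens
batch = [[0, 11, 12, 2, 2, 21, 22, 23, 2]]
print(question_token_mask(batch))
```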

In [None]:
!pip install utils

Collecting utils
  Downloading https://files.pythonhosted.org/packages/55/e6/c2d2b2703e7debc8b501caae0e6f7ead148fd0faa3c8131292a599930029/utils-1.0.1-py2.py3-none-any.whl
Installing collected packages: utils
Successfully installed utils-1.0.1


## Load and process data

Here we are using the awesome new nlp library to load and process the dataset.
We will also use the alignment methods of the Transformers fast tokenizers to get the positions of the answer spans.
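The idea behind character-to-token alignment can be shown without the library. The sketch below uses whitespace tokens and their character offsets as a stand-in for a real tokenizer. `char_to_token_from_offsets` is a hypothetical helper that only mimics the behaviour of the fast tokenizers' `char_to_token` method.

```python
import re


def whitespace_offsets(text):
    # (start, end) character span of every whitespace-separated token
    return [match.span() for match in re.finditer(r"\S+", text)]


def char_to_token_from_offsets(offsets, char_index):
    """Return the index of the token covering char_index, or None if no token covers it."""
    for token_index, (start, end) in enumerate(offsets):
        if start <= char_index < end:
            return token_index
    return None


context = "Longformer handles long documents"
answer = "long documents"
start_char = context.index(answer)
end_char = start_char + len(answer)  # exclusive

offsets = whitespace_offsets(context)
start_token = char_to_token_from_offsets(offsets, start_char)
end_token = char_to_token_from_offsets(offsets, end_char - 1)
print(start_token, end_token)
```

The end position is looked up at `end_char - 1` because the character span is exclusive at its end, which is the same adjustment the notebook makes.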

In [None]:
import torch
import nlp
from transformers import LongformerTokenizerFast

PyTorch version 1.6.0+cu101 available.
TensorFlow version 2.3.0 available.


In [None]:
tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')

In [None]:
def get_correct_alignement(context, answer):
    """ Some original examples in SQuAD have indices wrong by 1 or 2 characters. We test and fix this here. """
    gold_text = answer['text'][0]
    start_idx = answer['answer_start'][0]
    end_idx = start_idx + len(gold_text)
    if context[start_idx:end_idx] == gold_text:
        return start_idx, end_idx       # When the gold label position is good
    elif context[start_idx-1:end_idx-1] == gold_text:
        return start_idx-1, end_idx-1   # When the gold label is off by one character
    elif context[start_idx-2:end_idx-2] == gold_text:
        return start_idx-2, end_idx-2   # When the gold label is off by two characters
    else:
        raise ValueError('Could not align the answer text with the context')

# Tokenize our training dataset
def convert_to_features(example):
    # Tokenize the question and context as an input pair, so the tokenizer
    # produces <s> question</s></s> context</s>
    encodings = tokenizer.encode_plus(example['question'], example['context'],
                                      padding='max_length', truncation=True, max_length=512)
    context_encodings = tokenizer.encode_plus(example['context'])

    # Compute start and end tokens for labels using the alignment methods of the fast tokenizers.
    # This gives us the position of the answer span in the context text
    start_idx, end_idx = get_correct_alignement(example['context'], example['answers'])
    start_positions_context = context_encodings.char_to_token(start_idx)
    end_positions_context = context_encodings.char_to_token(end_idx-1)

    # Here we compute the start and end position of the answer in the whole example.
    # The example is encoded like this: <s> question</s></s> context</s>
    # and we know the position of the answer in the context,
    # so we find the index of the first sep token and add it to the position, plus 1
    # (+1 because there are two sep tokens between the question and the context).
    # This gives us the position of the answer span in the whole example
    sep_idx = encodings['input_ids'].index(tokenizer.sep_token_id)
    start_positions = start_positions_context + sep_idx + 1
    end_positions = end_positions_context + sep_idx + 1

    # Valid token indices run from 0 to 511; if the answer was truncated away, point both labels at <s>
    if end_positions >= 512:
        start_positions, end_positions = 0, 0

    encodings.update({'start_positions': start_positions,
                      'end_positions': end_positions})
    return encodings
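
The index arithmetic in `convert_to_features` can be checked on a toy example. The token lists below are hand-written stand-ins for tokenizer output, and `shift_to_pair` is a hypothetical helper that isolates the `+ sep_idx + 1` step.

```python
# The pair is laid out as <s> question</s></s> context</s>.
pair_tokens = ["<s>", "Who", "wrote", "it", "?", "</s>", "</s>", "Ada", "wrote", "it", ".", "</s>"]
# The context encoded on its own still starts with <s>.
context_tokens = ["<s>", "Ada", "wrote", "it", ".", "</s>"]


def shift_to_pair(position_in_context, tokens, sep_token="</s>"):
    # Index of the first separator, i.e. the end of the question.
    sep_idx = tokens.index(sep_token)
    # The context-only position already counts its own leading <s>, so adding
    # sep_idx + 1 skips the question and lands after the second separator.
    return position_in_context + sep_idx + 1


answer_in_context = context_tokens.index("Ada")
answer_in_pair = shift_to_pair(answer_in_context, pair_tokens)
print(answer_in_pair, pair_tokens[answer_in_pair])
```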

In [None]:
# load train and validation split of squad
train_dataset  = nlp.load_dataset('squad', split=nlp.Split.TRAIN)
valid_dataset = nlp.load_dataset('squad', split=nlp.Split.VALIDATION)




https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/tmphhwmbpr9


storing https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py in cache at /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.8fee6e3d53a4d9e5483442c8ba26e06e4ef70eaca60ac7bebc8429fc64a5e86a.py
creating metadata file for /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.8fee6e3d53a4d9e5483442c8ba26e06e4ef70eaca60ac7bebc8429fc64a5e86a.py
https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/tmpio93i32x





storing https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json in cache at /root/.cache/huggingface/datasets/9ba53336b6bc977097b39b8527b06ec6ba3f60a44230f2a0a918735fcd8ad902.893fb39fe374e4c574667dd71a3017b7e2e1d196f3a34fb00b56bac805447f7c
creating metadata file for /root/.cache/huggingface/datasets/9ba53336b6bc977097b39b8527b06ec6ba3f60a44230f2a0a918735fcd8ad902.893fb39fe374e4c574667dd71a3017b7e2e1d196f3a34fb00b56bac805447f7c
Checking /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.8fee6e3d53a4d9e5483442c8ba26e06e4ef70eaca60ac7bebc8429fc64a5e86a.py for additional imports.
Creating main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad
Creating specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp


Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.75 MiB, post-processed: Unknown size, total: 119.27 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab...


https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/tmpc2ew2aaq


storing https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json in cache at /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.2260363226dda2e2e19f3c2c74ca92767f38f9de88a4afee8a94be122e8947fa
creating metadata file for /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.2260363226dda2e2e19f3c2c74ca92767f38f9de88a4afee8a94be122e8947fa





https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/tmpnu4tuy0w


storing https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json in cache at /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.dbff9c68072d51656de2421b66cd67935b545f16130f044db5a7ba1149ef0af3
creating metadata file for /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.dbff9c68072d51656de2421b66cd67935b545f16130f044db5a7ba1149ef0af3
All the checksums matched successfully for dataset source files
Generating split train





Done writing 87599 examples in 79317110 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab.incomplete/squad-train.arrow.
Generating split validation




Done writing 10570 examples in 10472653 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab.incomplete/squad-validation.arrow.
All the splits matched successfully.
Constructing Dataset for split train, from /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab
Unable to verify checksums.


Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab. Subsequent calls will reuse this data.


Checking /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.8fee6e3d53a4d9e5483442c8ba26e06e4ef70eaca60ac7bebc8429fc64a5e86a.py for additional imports.
Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad
Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab
Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/408a8fa46a1e2805445b793f1022e743428ca739a34809fce872f0c7f17b44ab/squad.py
Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json to /usr/local/lib/python3.6/dist-pa

In [None]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [None]:
# process the examples in input and target text format and the eos token at the end 
def add_eos_to_examples(example):
    example['input_text'] = 'question: %s  context: %s </s>' % (example['question'], example['context'])
    example['target_text'] = '%s </s>' % example['answers']['text'][0]
    return example

In [None]:

train_dataset = train_dataset.map(convert_to_features)

TypeError: ignored

In [None]:
valid_dataset = valid_dataset.map(convert_to_features, load_from_cache_file=False)


# set the tensor type and the columns which the dataset should return
columns = ['input_ids', 'attention_mask', 'start_positions', 'end_positions']
train_dataset.set_format(type='torch', columns=columns)
valid_dataset.set_format(type='torch', columns=columns)

In [None]:
import utils.utils
import utils.dataset_utils
import os
from tqdm import tqdm
import random
import nltk
import argparse


def get_text(qad, domain):
    local_file = os.path.join(args.web_dir, qad['Filename']) if domain == 'SearchResults' else os.path.join(args.wikipedia_dir, qad['Filename'])
    return utils.utils.get_file_contents(local_file, encoding='utf-8')


def select_relevant_portion(text):
    paras = text.split('\n')
    selected = []
    done = False
    for para in paras:
        sents = sent_tokenize.tokenize(para)
        for sent in sents:
            words = nltk.word_tokenize(sent)
            for word in words:
                selected.append(word)
                if len(selected) >= args.max_num_tokens:
                    done = True
                    break
            if done:
                break
        if done:
            break
        selected.append('\n')
    st = ' '.join(selected).strip()
    return st


def add_triple_data(datum, page, domain):
    qad = {'Source': domain}
    for key in ['QuestionId', 'Question', 'Answer']:
        qad[key] = datum[key]
    for key in page:
        qad[key] = page[key]
    return qad


def get_qad_triples(data):
    qad_triples = []
    for datum in data['Data']:
        for key in ['EntityPages', 'SearchResults']:
            for page in datum.get(key, []):
                qad = add_triple_data(datum, page, key)
                qad_triples.append(qad)
    return qad_triples


def convert_to_squad_format(qa_json_file, squad_file):
    qa_json = utils.dataset_utils.read_triviaqa_data(qa_json_file)
    qad_triples = get_qad_triples(qa_json)

    random.seed(args.seed)
    random.shuffle(qad_triples)

    data = []
    for qad in tqdm(qad_triples):
        qid = qad['QuestionId']

        text = get_text(qad, qad['Source'])
        selected_text = select_relevant_portion(text)

        question = qad['Question']
        para = {'context': selected_text, 'qas': [{'question': question, 'answers': []}]}
        data.append({'paragraphs': [para]})
        qa = para['qas'][0]
        qa['id'] = utils.dataset_utils.get_question_doc_string(qid, qad['Filename'])
        qa['qid'] = qid

        ans_string, index = utils.dataset_utils.answer_index_in_document(qad['Answer'], selected_text)
        if index == -1:
            if qa_json['Split'] == 'train':
                continue
        else:
            qa['answers'].append({'text': ans_string, 'answer_start': index})

        if qa_json['Split'] == 'train' and len(data) >= args.sample_size and qa_json['Domain'] == 'Web':
            break

    squad = {'data': data, 'version': qa_json['Version']}
    utils.utils.write_json_to_file(squad, squad_file)
    print('Added', len(data))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--triviaqa_file', help='Triviaqa file')
    parser.add_argument('--squad_file', help='Squad file')
    parser.add_argument('--wikipedia_dir', help='Wikipedia doc dir')
    parser.add_argument('--web_dir', help='Web doc dir')

    parser.add_argument('--seed', default=10, type=int, help='Random seed')
    parser.add_argument('--max_num_tokens', default=800, type=int, help='Maximum number of tokens from a document')
    parser.add_argument('--sample_size', default=80000, type=int, help='Maximum number of training examples to keep')
    parser.add_argument('--tokenizer', default='tokenizers/punkt/english.pickle', help='Sentence tokenizer')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()
    sent_tokenize = nltk.data.load(args.tokenizer)
    convert_to_squad_format(args.triviaqa_file, args.squad_file)

ModuleNotFoundError: ignored

In [None]:
# load train and validation split of squad
train_dataset  = nlp.load_dataset('squad', split=nlp.Split.TRAIN)
valid_dataset = nlp.load_dataset('squad', split=nlp.Split.VALIDATION)


train_dataset = train_dataset.map(convert_to_features)
valid_dataset = valid_dataset.map(convert_to_features, load_from_cache_file=False)


# set the tensor type and the columns which the dataset should return
columns = ['input_ids', 'attention_mask', 'start_positions', 'end_positions']
train_dataset.set_format(type='torch', columns=columns)
valid_dataset.set_format(type='torch', columns=columns)


Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


ValueError: ignored

In [None]:
len(train_dataset), len(valid_dataset)

(87599, 10570)

In [None]:
# cache the dataset, so we can load it directly for training

torch.save(train_dataset, 'train_data.pt')
torch.save(valid_dataset, 'valid_data.pt')

## Write training script

In [None]:
import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np
import torch

from transformers import LongformerForQuestionAnswering, LongformerTokenizerFast, EvalPrediction
from transformers import (
    HfArgumentParser,
    DataCollator,
    Trainer,
    TrainingArguments,
    set_seed,
)


logger = logging.getLogger(__name__)

# @dataclass
class DummyDataCollator(DataCollator):
    def collate_batch(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.
        Returns:
            A dictionary of tensors
        """
        input_ids = torch.stack([example['input_ids'] for example in batch])
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        start_positions = torch.stack([example['start_positions'] for example in batch])
        end_positions = torch.stack([example['end_positions'] for example in batch])

        return {
            'input_ids': input_ids, 
            'start_positions': start_positions, 
            'end_positions': end_positions,
            'attention_mask': attention_mask
        }


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    train_file_path: Optional[str] = field(
        default='train_data.pt',
        metadata={"help": "Path for cached train dataset"},
    )
    valid_file_path: Optional[str] = field(
        default='valid_data.pt',
        metadata={"help": "Path for cached valid dataset"},
    )
    max_len: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    # we will load the arguments from a json file,
    # make sure you save the arguments at ./args.json
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath('args.json'))

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    tokenizer = LongformerTokenizerFast.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = LongformerForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    print('loading data')
    train_dataset  = torch.load(data_args.train_file_path)
    valid_dataset = torch.load(data_args.valid_file_path)
    print('loading done')

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=DummyDataCollator(),
        prediction_loss_only=True,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(eval_output.keys()):
                logger.info("  %s = %s", key, str(eval_output[key]))
                writer.write("%s = %s\n" % (key, str(eval_output[key])))
    
        results.update(eval_output)
    
    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

## Train

In [None]:
import json

Let's write the arguments in a dict and store them in a JSON file. The above code will load this file and parse the arguments.
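For intuition, splitting one flat JSON object across several argument dataclasses can be sketched with the standard library alone. This is a simplified stand-in for what `HfArgumentParser.parse_json_file` does, not its implementation. `ModelArgs`, `DataArgs` and `split_args` are hypothetical names, and rejecting unknown keys is a choice made for this sketch.

```python
import json
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ModelArgs:
    model_name_or_path: str
    cache_dir: Optional[str] = None


@dataclass
class DataArgs:
    max_len: int = 512


def split_args(raw, *dataclass_types):
    """Route each key of a flat dict to the dataclass that declares a field of that name."""
    remaining = dict(raw)
    parsed = []
    for cls in dataclass_types:
        names = {f.name for f in fields(cls)}
        kwargs = {key: remaining.pop(key) for key in list(remaining) if key in names}
        parsed.append(cls(**kwargs))
    if remaining:
        raise ValueError(f"unknown arguments: {sorted(remaining)}")
    return tuple(parsed)


raw = json.loads('{"model_name_or_path": "allenai/longformer-base-4096", "max_len": 512}')
model_args, data_args = split_args(raw, ModelArgs, DataArgs)
print(model_args, data_args)
```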

In [None]:
args_dict = {
  "n_gpu": 1,
  "model_name_or_path": 'allenai/longformer-base-4096',
  "max_len": 512 ,
  "output_dir": './models',
  "overwrite_output_dir": True,
  "per_gpu_train_batch_size": 8,
  "per_gpu_eval_batch_size": 8,
  "gradient_accumulation_steps": 16,
  "learning_rate": 1e-4,
  "num_train_epochs": 3,
  "do_train": True
}
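
A quick back-of-the-envelope check of what these settings imply, assuming the 87,599 training examples reported earlier and a single GPU. The variable names are only for illustration.

```python
import math

train_examples = 87599
per_gpu_train_batch_size = 8
gradient_accumulation_steps = 16
n_gpu = 1

# Forward/backward passes per epoch.
iterations_per_epoch = math.ceil(train_examples / (per_gpu_train_batch_size * n_gpu))
# Examples contributing to each optimizer update.
effective_batch_size = per_gpu_train_batch_size * n_gpu * gradient_accumulation_steps
# Optimizer updates per epoch (one update every 16 iterations).
updates_per_epoch = iterations_per_epoch / gradient_accumulation_steps
# Fraction of an epoch completed after 500 optimizer updates.
epoch_at_step_500 = 500 / updates_per_epoch

print(iterations_per_epoch, effective_batch_size, round(epoch_at_step_500, 2))
```

With gradient accumulation, each optimizer update sees 128 examples even though only 8 fit on the GPU at once.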

In [None]:
with open('args.json', 'w') as f:
  json.dump(args_dict, f)

Start training!

In [None]:
main()



loading data
loading done


{"loss": 1.2000314732473343, "learning_rate": 7.563352826510721e-05, "epoch": 0.730593607305936, "step": 500}







{"loss": 0.7662997634811327, "learning_rate": 5.126705653021443e-05, "epoch": 1.4617351598173516, "step": 1000}



{"loss": 0.6072164938626811, "learning_rate": 2.6900584795321637e-05, "epoch": 2.192876712328767, "step": 1500}
{"loss": 0.46604699626751245, "learning_rate": 2.5341130604288498e-06, "epoch": 2.923470319634703, "step": 2000}




{}

## Eval

In [None]:
## SQuAD evaluation script. Modified slightly for this notebook

from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(gold_answers, predictions):
    f1 = exact_match = total = 0

    for ground_truths, prediction in zip(gold_answers, predictions):
      total += 1
      exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
      f1 += metric_max_over_ground_truths(
          f1_score, prediction, ground_truths)
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}
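
To see how the two metrics behave, here is a compact worked example. It re-implements the normalization and token-level F1 in a few lines so that it runs on its own. The names `normalize` and `token_f1` are illustrative, and the example strings are made up.

```python
import re
import string
from collections import Counter


def normalize(text):
    # Lowercase, drop punctuation, drop articles, collapse whitespace.
    text = text.lower()
    text = ''.join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())


def token_f1(prediction, ground_truth):
    pred_tokens = normalize(prediction).split()
    truth_tokens = normalize(ground_truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


# Exact match ignores case, punctuation and articles.
print(normalize("The Democratized NLP.") == normalize("democratized NLP"))
# One extra predicted token lowers precision to 2/3 while recall stays at 1.
print(token_f1("democratized NLP research", "democratized NLP"))
```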

In [None]:
import torch
from transformers import LongformerTokenizerFast, LongformerForQuestionAnswering
from tqdm.auto import tqdm

In [None]:
tokenizer = LongformerTokenizerFast.from_pretrained('models')
model = LongformerForQuestionAnswering.from_pretrained('models')
model = model.cuda()
model.eval()

LongformerForQuestionAnswering(
  (longformer): LongformerModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
              (value_global): Linear(in_feat

In [None]:
valid_dataset = torch.load('valid_data.pt')
dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=16)

In [None]:
answers = []
with torch.no_grad():
  for batch in tqdm(dataloader):
    start_scores, end_scores = model(input_ids=batch['input_ids'].cuda(),
                                  attention_mask=batch['attention_mask'].cuda())
    for i in range(start_scores.shape[0]):
      all_tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][i])
      answer = ' '.join(all_tokens[torch.argmax(start_scores[i]) : torch.argmax(end_scores[i])+1])
      ans_ids = tokenizer.convert_tokens_to_ids(answer.split())
      answer = tokenizer.decode(ans_ids)
      answers.append(answer)
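
The loop above takes the arg-max of the start and end scores independently, so the end index can land before the start index and produce an empty answer. One common alternative, which this notebook does not use, is to search for the best-scoring span with `start <= end`. The sketch below works on plain Python lists. `best_span` and the `max_answer_len` limit are assumptions for illustration.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (start, end) maximising start_scores[start] + end_scores[end] with start <= end."""
    best = (0, 0)
    best_score = float("-inf")
    for start, start_score in enumerate(start_scores):
        # Only consider ends at or after the start, within the length limit.
        last_end = min(start + max_answer_len, len(end_scores))
        for end in range(start, last_end):
            score = start_score + end_scores[end]
            if score > best_score:
                best_score = score
                best = (start, end)
    return best


start_scores = [0.1, 0.2, 5.0, 0.3]
end_scores = [4.0, 0.1, 0.2, 3.0]
# Independent arg-max would pick start=2 and end=0, which is not a valid span.
print(best_span(start_scores, end_scores))
```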

In [None]:
predictions = []
references = []
for ref, pred in zip(valid_dataset, answers):
  predictions.append(pred)
  references.append(ref['answers']['text'])

In [None]:
evaluate(references, predictions)

{'exact_match': 85.14664143803216, 'f1': 91.54157494727959}

## Model in action 🚀

The trained model is available on Huggingface hub if you want to play with it.
You can find the model [here](https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1) 

In [None]:
import torch
from transformers import LongformerTokenizer, LongformerForQuestionAnswering

tokenizer = LongformerTokenizer.from_pretrained("valhalla/longformer-base-4096-finetuned-squadv1")
model = LongformerForQuestionAnswering.from_pretrained("valhalla/longformer-base-4096-finetuned-squadv1")

text = "Huggingface has democratized NLP. Huge thanks to Huggingface for this."
question = "What has Huggingface done ?"
encoding = tokenizer.encode_plus(question, text, return_tensors="pt")
input_ids = encoding["input_ids"]

# default is local attention everywhere
# the forward method will automatically set global attention on question tokens
attention_mask = encoding["attention_mask"]

start_scores, end_scores = model(input_ids, attention_mask=attention_mask)
all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

answer_tokens = all_tokens[torch.argmax(start_scores) :torch.argmax(end_scores)+1]
answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))
# output => democratized NLP