In [1]:
# Pin the release this notebook was executed with so a fresh run resolves the same API;
# %pip installs into the environment of the running kernel.
%pip install -q transformers==4.6.1

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d5/43/cfe4ee779bbd6a678ac6a97c5a5cdeb03c35f9eaebbb9720b036680f9a2d/transformers-4.6.1-py3-none-any.whl (2.2MB)
[K     |████████████████████████████████| 2.3MB 7.8MB/s 
Collecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/d4/e2/df3543e8ffdab68f5acc73f613de9c2b155ac47f162e725dcac87c521c11/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 48.3MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |█

### Import modules

In [2]:
import json
from pathlib import Path
from transformers import DistilBertTokenizerFast
import torch
from transformers import DistilBertForQuestionAnswering
from torch.utils.data import DataLoader
from transformers import AdamW
import numpy as np

### Dataset download

In [3]:
# -p keeps the cell re-runnable: no error when the directory already exists.
!mkdir -p squad

In [4]:
# Download the SQuAD v2.0 train and dev JSON files into the local squad/ directory.
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O squad/train-v2.0.json
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad/dev-v2.0.json

--2021-06-07 10:01:39--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42123633 (40M) [application/json]
Saving to: ‘squad/train-v2.0.json’


2021-06-07 10:01:39 (73.6 MB/s) - ‘squad/train-v2.0.json’ saved [42123633/42123633]

--2021-06-07 10:01:39--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4370528 (4.2M) [application/json]
Saving to: ‘squad/dev-v2.0.json’


2021-06-07 10:01:40 (31.5 MB/s) - ‘squad/dev-v2.0.json’ saved [4370528/4370528]



### Utility function to read the dataset

In [5]:
def read_squad(path):
    """Load a SQuAD-format JSON file and flatten it into three parallel lists.

    One entry is produced per answer: a question with several answers
    contributes several entries, and a question whose 'answers' list is
    empty contributes none.

    Args:
        path: Location of the JSON file (string or path-like).

    Returns:
        A tuple ``(contexts, questions, answers)`` of equal-length lists. The
        answers are the dicts taken straight from the file; the context and
        question strings are repeated alongside each of them.
    """
    with Path(path).open('rb') as handle:
        articles = json.load(handle)['data']

    contexts, questions, answers = [], [], []
    for article in articles:
        for paragraph in article['paragraphs']:
            passage_text = paragraph['context']
            for qa in paragraph['qas']:
                prompt = qa['question']
                gold_answers = qa['answers']
                # Repeat the context/question once per answer to keep the lists aligned.
                contexts.extend([passage_text] * len(gold_answers))
                questions.extend([prompt] * len(gold_answers))
                answers.extend(gold_answers)

    return contexts, questions, answers


# Flatten both downloaded splits into parallel context / question / answer lists.
train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')
val_contexts, val_questions, val_answers = read_squad('squad/dev-v2.0.json')

### Tokenize and compute answer token positions

In [6]:
def add_end_idx(answers, contexts):
    """Add an 'answer_end' character offset to every answer dict, in place.

    The recorded 'answer_start' is checked against the context. If the answer
    text is found at the recorded offset, or one or two characters to its
    left, 'answer_start' is corrected to that offset and 'answer_end' is set
    to the matching exclusive end offset.

    If no alignment is found, 'answer_start' is left untouched and
    'answer_end' falls back to the nominal end offset, clamped to the context
    length. Every answer therefore always carries an 'answer_end' key, so
    later steps never hit a missing key.

    Args:
        answers: List of dicts with 'text' and 'answer_start' keys; mutated.
        contexts: List of context strings, aligned one-to-one with ``answers``.

    Returns:
        None.
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # SQuAD offsets are sometimes off by a character or two, so try shifts
        # of 0, 1 and 2 to the left. The range bound stops the shifted start
        # from going negative, which would silently slice from the end.
        for shift in range(min(2, start_idx) + 1):
            if context[start_idx - shift:end_idx - shift] == gold_text:
                answer['answer_start'] = start_idx - shift
                answer['answer_end'] = end_idx - shift
                break
        else:
            # No alignment matched: keep the nominal span (clamped to the
            # context) rather than leaving 'answer_end' undefined.
            answer['answer_end'] = min(end_idx, len(context))


# Add 'answer_end' (and correct 'answer_start' where needed) on the answer dicts, in place.
add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)

# Fast tokenizer loaded from the 'distilbert-base-uncased' checkpoint.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Encode each pair with the context as the first sequence and the question as the second,
# with truncation and padding enabled.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)


def add_token_positions(encodings, answers):
    """Translate answer character offsets into token indices and store them.

    For sample ``i`` the start token is looked up from ``'answer_start'`` and
    the end token from ``'answer_end' - 1`` via ``encodings.char_to_token``.
    When a lookup returns ``None`` (the answer passage has been truncated),
    ``tokenizer.model_max_length`` is stored instead.

    Args:
        encodings: Tokenizer output providing ``char_to_token`` and ``update``;
            it gains 'start_positions' and 'end_positions' entries in place.
        answers: List of dicts with 'answer_start' and 'answer_end' keys,
            aligned one-to-one with the encoded samples.

    Returns:
        None.
    """
    def _token_index(sample_idx, char_idx):
        # Map one character offset to a token index, substituting the
        # max-length marker when the character is not in the encoding.
        token_idx = encodings.char_to_token(sample_idx, char_idx)
        return tokenizer.model_max_length if token_idx is None else token_idx

    starts, ends = [], []
    for sample_idx, answer in enumerate(answers):
        starts.append(_token_index(sample_idx, answer['answer_start']))
        ends.append(_token_index(sample_idx, answer['answer_end'] - 1))

    encodings.update({'start_positions': starts, 'end_positions': ends})


# Add 'start_positions' / 'end_positions' to both sets of encodings (updated in place).
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




### Dataset

In [7]:
class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves tokenizer encodings one sample at a time.

    ``encodings`` must provide ``items()`` yielding ``(name, per-sample
    values)`` pairs and an ``input_ids`` attribute whose length is the number
    of samples.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # The sample count is taken from the input_ids field.
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        # Build one sample: every field's idx-th entry, converted to a tensor.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        return sample

### Model - Distil BERT

In [None]:
# Load the DistilBERT question-answering model from the 'distilbert-base-uncased' checkpoint.
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

### Optimizer 

In [9]:
# AdamW over all model parameters with a learning rate of 5e-5.
# NOTE(review): this AdamW is imported from transformers; newer transformers releases are
# understood to deprecate it in favour of torch.optim.AdamW — confirm against the installed version.
optim = AdamW(model.parameters(), lr=5e-5)

### Train

In [10]:
# Wrap the tokenized splits so they can be indexed sample by sample.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

In [12]:
# Shuffled mini-batches of 30 training samples.
train_loader = DataLoader(train_dataset, batch_size=30, shuffle=True)
# Number of batches per epoch; printed in the training loop's progress lines.
total_step = len(train_loader)

In [13]:
EPOCH = 10

# from_pretrained() returns the model in evaluation mode (dropout disabled), so
# training mode has to be switched on explicitly before fine-tuning.
model.train()
try:
    for epoch in range(EPOCH):
        for i, batch in enumerate(train_loader):
            optim.zero_grad()

            # Move this batch's tensors to the same device as the model.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)

            # Because the gold positions are supplied, the first output is the loss.
            outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions,
                            end_positions=end_positions)
            loss = outputs[0]
            loss.backward()
            optim.step()

            # Progress line every 200 steps; exp(loss) is reported under the "Perplexity" label.
            if i % 200 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch, EPOCH, i, total_step, loss.item(), np.exp(loss.item())))
finally:
    # Restore evaluation mode even when training is interrupted, so later
    # inference cells run without dropout.
    model.eval()


Epoch [0/10], Step [0/2895], Loss: 6.2070, Perplexity: 496.1866
Epoch [0/10], Step [200/2895], Loss: 2.1159, Perplexity: 8.2972
Epoch [0/10], Step [400/2895], Loss: 1.8002, Perplexity: 6.0507
Epoch [0/10], Step [600/2895], Loss: 1.3843, Perplexity: 3.9919
Epoch [0/10], Step [800/2895], Loss: 1.0585, Perplexity: 2.8820
Epoch [0/10], Step [1000/2895], Loss: 1.0887, Perplexity: 2.9705
Epoch [0/10], Step [1200/2895], Loss: 0.9458, Perplexity: 2.5748
Epoch [0/10], Step [1400/2895], Loss: 1.1783, Perplexity: 3.2489
Epoch [0/10], Step [1600/2895], Loss: 1.1451, Perplexity: 3.1427
Epoch [0/10], Step [1800/2895], Loss: 0.9678, Perplexity: 2.6321
Epoch [0/10], Step [2000/2895], Loss: 1.0657, Perplexity: 2.9028
Epoch [0/10], Step [2200/2895], Loss: 1.0411, Perplexity: 2.8325


KeyboardInterrupt: ignored

**The training run above was stopped during the first epoch; please train for the full 10 epochs.**

### Testing

In [27]:
# Context passage in which the model should locate the answer.
text = "Asia is the largest and most populated continent. It has nearly one-third of the world’s total land area and " \
       "is home to more than half of Earth’s people. It also has impressive geographical features. It has Earth's " \
       "highest point is Mount Everest and lowest point is the Dead Sea. Asia also includes some of the world’s " \
       "wettest, driest, hottest, and coldest places. The continent was the home of the great early civilizations of " \
       "Mesopotamia and the Indus River valley. The world’s major religions Buddhism, Christianity, Hinduism, Islam, " \
       "and Judaism—all began in Asia as well. Today, though many people are farmers and live in small villages, " \
       "Asia also has enormous cities, including some of the world’s largest: Beijing, China; Tokyo, Japan; Seoul, " \
       "South Korea; and Delhi, India. "

# Question to be answered from the passage.
ques = "what is the earths highest point?"

# Encode the context first and the question second, the same order used when the
# training encodings were built; truncate so long passages still fit the model.
encodings = tokenizer.encode_plus(text, ques, truncation=True)

inputIds, attentionMask = encodings["input_ids"], encodings["attention_mask"]

# Inference only: disable dropout and gradient tracking, and run one forward
# pass that yields both the start and the end scores.
model.eval()
with torch.no_grad():
    outputs = model(input_ids=torch.tensor([inputIds]).to(device),
                    attention_mask=torch.tensor([attentionMask]).to(device))
start_scores, end_scores = outputs[0], outputs[1]

# The answer is the inclusive token span between the highest-scoring start
# and end positions (empty when the predicted end precedes the start).
start_idx = torch.argmax(start_scores).item()
end_idx = torch.argmax(end_scores).item()
tokens = inputIds[start_idx: end_idx + 1]

answerTokens = tokenizer.convert_ids_to_tokens(tokens, skip_special_tokens=True)

ans = tokenizer.convert_tokens_to_string(answerTokens)

print(ans)


mount everest
