In [2]:
"""
Student: Dinh Khac Tuyen
ID:20214182
"""

from platform import python_version
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm
import numpy as np
from datasets import load_metric
import os.path
from os import path
from datasets import load_dataset
from transformers import AutoTokenizer

# Record the interpreter and framework versions this notebook was run with.
print(f"python {python_version()}")
print(f"torch {torch.__version__}")

python 3.7.11
torch 1.10.0+cu102


In [2]:
# Stanford Sentiment Treebank from the Hugging Face hub; as the log below shows, no config
# is given so the "default" config is used.
sst_dataset = load_dataset('sst')

No config specified, defaulting to: sst/default
Reusing dataset sst (/home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Cased tokenizer for Problem 2.1; uncased tokenizer for the Problem 2.4 comparison.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_uncased = AutoTokenizer.from_pretrained("bert-base-uncased")

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [4]:
"""
Problem 2.1: Sentence classification with bert-base-cased
"""
def tokenize_function(examples):
    """Tokenize a batch of SST sentences with the cased BERT tokenizer (padding="max_length", truncation on)."""
    sentences = examples["sentence"]
    return tokenizer(sentences, truncation=True, padding="max_length")
def binarize_label(examples):
    """Replace each SST sentiment score with "neg" (score < 0.5) or "pos" (otherwise); mutates and returns the batch."""
    examples["label"] = ["neg" if score < 0.5 else "pos" for score in examples["label"]]
    return examples
# Tokenize, binarize the labels, then let class_encode_column turn the "neg"/"pos" strings into class ids.
tokenized_datasets = (
    sst_dataset
    .map(tokenize_function, batched=True)
    .map(binarize_label, batched=True)
    .class_encode_column("label")
)

Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-d3ad3ac7efffe173.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-266df31685d7a46c.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-2a8242adb1933312.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-d77eff5bb802257e.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-6736b1661a261c97.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datas

In [5]:
# Drop the raw text columns and rename "label" -> "labels" so model(**batch) receives labels and returns a loss.
tokenized_datasets = (
    tokenized_datasets
    .remove_columns(["sentence", "tokens", "tree"])
    .rename_column("label", "labels")
)
tokenized_datasets.set_format("torch")

# Fixed-seed shuffles of the train split and of the test split (the latter is used for evaluation).
train_dataset, eval_dataset = (
    tokenized_datasets[split].shuffle(seed=42) for split in ("train", "test")
)

Loading cached shuffled indices for dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-77efcecc8c13ae60.arrow
Loading cached shuffled indices for dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-01ae045a0919fcd6.arrow


In [6]:
# Training batches are reshuffled every epoch; evaluation batches keep a fixed order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=6)

In [9]:
# First GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [7]:
# Reuse the fine-tuned weights saved by an earlier run when that directory exists;
# otherwise start from pretrained cased BERT with a 2-label classification head.
model = (
    AutoModelForSequenceClassification.from_pretrained("trained_model_21")
    if path.exists("trained_model_21")
    else AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
)
model.to(device)

In [8]:
# AdamW at lr 5e-5 with a "linear" learning-rate schedule spanning every training step, no warmup.
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [12]:
progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Move every tensor of the batch (inputs and labels) to the training device.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        # Update weights, advance the LR schedule, then clear gradients for the next step.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

  0%|          | 0/4272 [00:00<?, ?it/s]

In [15]:
metric = load_metric("accuracy")
model.eval()
for batch in tqdm(eval_dataloader):
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    # The predicted class is the index of the highest logit.
    predictions = outputs.logits.argmax(dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

  0%|          | 0/369 [00:00<?, ?it/s]

{'accuracy': 0.8502262443438914}

In [15]:
# Persist the fine-tuned sentiment model; the loading cell checks for this directory.
model.save_pretrained(save_directory="trained_model_21")

In [42]:
"""
Problem 2.2: Find 3 wrong predictions
"""
# Expected labels if the model understood them: positive, positive, negative.
my_sentences = ["The movie is not only interesting",
                "Only stupid people will think it's a bad movie",
                "The idea of the movie is as new as the earth "]

# One dict of device tensors per sentence, encoded the same way as the training data.
tokenized_sentences = [
    {
        key: torch.LongTensor(ids).to(device)
        for key, ids in tokenizer(sentence, padding="max_length", truncation=True).items()
    }
    for sentence in my_sentences
]

In [43]:
# Classify each custom sentence on its own (batch_size=1).
# NOTE(review): assumes label id 0 is "neg" (class_encode_column on the "neg"/"pos"
# strings) -- confirm via tokenized_datasets["train"].features["labels"].
test_loader = DataLoader(tokenized_sentences, batch_size=1)

# Inference only: set eval mode explicitly (do not rely on an earlier cell) and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        out = model(**batch)
        predictions = torch.argmax(out.logits, dim=-1)
        print("Sentence {}:{}".format(i, "negative" if predictions.item() == 0 else "positive"))

Sentence 0:negative
Sentence 1:negative
Sentence 2:positive


In [5]:
"""
Problem 2.4: Use bert-base-uncased
"""
def tokenize_uncased_function(examples):
    """Tokenize a batch of SST sentences with the uncased BERT tokenizer (padding="max_length", truncation on)."""
    sentences = examples["sentence"]
    return tokenizer_uncased(sentences, truncation=True, padding="max_length")
def binarize_label(examples):
    """Replace each SST sentiment score with "neg" (score < 0.5) or "pos" (otherwise).

    Identical to the Problem 2.1 helper of the same name (presumably redefined so this
    section also runs on a fresh kernel); mutates and returns the batch.
    """
    examples["label"] = ["neg" if score < 0.5 else "pos" for score in examples["label"]]
    return examples
uncased_datasets = (
    sst_dataset
    .map(tokenize_uncased_function, batched=True)
    .map(binarize_label, batched=True)
    .class_encode_column("label")
)

# Reuse previously fine-tuned uncased weights when a saved checkpoint exists.
model_uncased = (
    AutoModelForSequenceClassification.from_pretrained("trained_model_24")
    if path.exists("trained_model_24")
    else AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_uncased.to(device)

  0%|          | 0/9 [00:00<?, ?ba/s]

Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-7262702fea9cafb5.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-989951e0688394fe.arrow


  0%|          | 0/9 [00:00<?, ?ba/s]

Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-ab923db49e7154da.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-1636fd4fefe8c466.arrow


Casting to class labels:   0%|          | 0/9 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-d62b889f3706a9f0.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-2099a7ab09c2933e.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-524a61662fec2663.arrow
Loading cached processed dataset at /home/ncl/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff/cache-33cb597dc89ca48b.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.tr

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [6]:
# Same preparation as the cased run: drop raw text columns and expose "labels" as torch tensors.
uncased_datasets = (
    uncased_datasets
    .remove_columns(["sentence", "tokens", "tree"])
    .rename_column("label", "labels")
)
uncased_datasets.set_format("torch")

train_set, eval_set = (
    uncased_datasets[split].shuffle(seed=42) for split in ("train", "test")
)

# Same batch size as the cased run so the two models are compared like-for-like.
train_loader = DataLoader(dataset=train_set, batch_size=6, shuffle=True)
eval_loader = DataLoader(dataset=eval_set, batch_size=6)

In [9]:
# Same optimiser and schedule as the cased run, applied to the uncased model.
optimizer = AdamW(model_uncased.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = len(train_loader) * num_epochs
lr_scheduler2 = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

model_uncased.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        loss = model_uncased(**batch).loss
        loss.backward()

        optimizer.step()
        lr_scheduler2.step()
        optimizer.zero_grad()
        progress_bar.update(1)

  0%|          | 0/4272 [00:00<?, ?it/s]

In [11]:
metric2 = load_metric("accuracy")
model_uncased.eval()
for batch in tqdm(eval_loader):
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        outputs = model_uncased(**batch)
    # The predicted class is the index of the highest logit.
    predictions = outputs.logits.argmax(dim=-1)
    metric2.add_batch(predictions=predictions, references=batch["labels"])

metric2.compute()

  0%|          | 0/369 [00:00<?, ?it/s]

{'accuracy': 0.860633484162896}

In [12]:
# Persist the fine-tuned uncased model; the Problem 2.4 loading code checks for this directory.
model_uncased.save_pretrained(save_directory="trained_model_24")

In [95]:
"""
Problem 3.1: Finetuning bert for question answering
The solution here is based on the tutorial https://huggingface.co/transformers/custom_datasets.html#question-answering-with-squad-2-0
"""
# SQuAD from the Hugging Face hub; the train and validation splits are used below.
squad_dataset=load_dataset("squad")

Reusing dataset squad (/home/ncl/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

In [96]:
# Inspect the schema: 'answers' holds parallel 'text' / 'answer_start' lists.
squad_dataset['train'].features

{'id': Value(dtype='string', id=None),
 'title': Value(dtype='string', id=None),
 'context': Value(dtype='string', id=None),
 'question': Value(dtype='string', id=None),
 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}

In [97]:
# A validation example: one question can carry several reference answers.
squad_dataset['validation'][10]

{'id': '56bea9923aeaaa14008c91bb',
 'title': 'Super_Bowl_50',
 'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.',
 'question': 'What day was the Super Bowl played on?',
 'answers': {'text': ['February 7, 2016', 'February 7', 'February 7, 2016'],
  'answer_star

In [99]:
# Cased BERT tokenizer for the QA section; the encodings and answer decoding below use this name.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def read_dataset(dataset, max_shift=2):
    """Flatten a SQuAD split into parallel lists of contexts, questions and answers.

    Each returned answer dict has equal-length ``text``, ``answer_start`` and
    ``answer_end`` lists, where ``answer_end`` is the exclusive character end of
    the answer inside its context. Some SQuAD ``answer_start`` offsets are off by
    one or two characters, so each offset is shifted left by up to ``max_shift``
    characters until the context slice equals the answer text. If no shift
    matches, the original offsets are kept so that ``answer_start`` and
    ``answer_end`` stay aligned index-for-index.

    Args:
        dataset: iterable of records with ``context``, ``question`` and
            ``answers`` (``{"text": [...], "answer_start": [...]}``).
        max_shift: largest left shift (in characters) tried when realigning.

    Returns:
        ``(contexts, questions, answers)``: three lists with one entry per record.
    """
    contexts = []
    questions = []
    answers = []
    for data in tqdm(dataset):
        context = data['context']
        contexts.append(context)
        questions.append(data['question'])
        texts = list(data['answers']['text'])
        # Copy so realignment never mutates the caller's record.
        starts = list(data['answers']['answer_start'])
        ends = []
        for i in range(len(starts)):
            start, text = starts[i], texts[i]
            end = start + len(text)
            for shift in range(max_shift + 1):
                # A negative slice start would wrap around to the end of the context.
                if start - shift >= 0 and context[start - shift:end - shift] == text:
                    start, end = start - shift, end - shift
                    break
            starts[i] = start
            # Always append, even without a match, so starts and ends stay paired.
            ends.append(end)
        answers.append({'text': texts, 'answer_start': starts, 'answer_end': ends})
    return contexts, questions, answers

# Flatten both SQuAD splits into parallel context / question / answer lists.
train_contexts, train_questions, train_answers = read_dataset(squad_dataset['train'])
val_contexts, val_questions, val_answers = read_dataset(squad_dataset['validation'])

  0%|          | 0/87599 [00:00<?, ?it/s]

  0%|          | 0/10570 [00:00<?, ?it/s]

In [100]:
# Spot-check that answer_end was added alongside answer_start.
print(val_answers[1])

{'text': ['Carolina Panthers', 'Carolina Panthers', 'Carolina Panthers'], 'answer_start': [249, 249, 249], 'answer_end': [266, 266, 266]}


In [101]:
# Encode (context, question) pairs with a maximum length of 256 tokens;
# padding=True pads each split to its longest encoded pair.
train_encodings = tokenizer(
    train_contexts, train_questions, truncation=True, max_length=256, padding=True
)
val_encodings = tokenizer(
    val_contexts, val_questions, truncation=True, max_length=256, padding=True
)

In [102]:
def add_token_positions(encodings, answers, ignore_position=None):
    """Attach token-level start/end positions for every answer to ``encodings``.

    Character offsets from ``answers`` are mapped to token indices of sequence 0
    (the context) with ``encodings.char_to_token``. Each example gets one
    position per answer, so ``start_positions[i]`` and ``end_positions[i]`` are
    lists.

    When an answer was cut off by truncation (``char_to_token`` returns ``None``),
    its position is replaced by a sentinel. By default the sentinel is the
    example's own token length, which lies just past the last valid index. It
    therefore can never equal an argmax over that sequence's logits. Previously
    the sentinel was the global tokenizer's ``model_max_length``, which could be
    far larger than the 256-token encodings.

    Args:
        encodings: tokenizer output for (context, question) pairs; must support
            ``char_to_token(i, char_idx)`` and ``['input_ids']``. Updated in place
            with ``start_positions`` and ``end_positions``.
        answers: list of dicts with ``answer_start`` and ``answer_end`` (exclusive)
            character-offset lists, one dict per encoded example.
        ignore_position: optional fixed sentinel to use instead of the sequence length.
    """
    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        if ignore_position is None:
            fallback = len(encodings['input_ids'][i])
        else:
            fallback = ignore_position
        starts = [encodings.char_to_token(i, s) for s in answer['answer_start']]
        # answer_end is exclusive, so the answer's last character sits at answer_end - 1.
        ends = [encodings.char_to_token(i, e - 1) for e in answer['answer_end']]
        # None means the answer was truncated away.
        start_positions.append([fallback if p is None else p for p in starts])
        end_positions.append([fallback if p is None else p for p in ends])

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# Adds 'start_positions' / 'end_positions' keys to both encodings in place.
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)

In [103]:
class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset over tokenizer encodings.

    Item ``idx`` is a dict with one tensor per encoding key (e.g. ``input_ids``,
    ``attention_mask``, ``start_positions``), built from entry ``idx`` of each list.

    Args:
        encodings: a ``BatchEncoding`` or any mapping of equal-length lists that
            contains ``input_ids``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        # Subscript instead of attribute access so plain dicts work, not only BatchEncoding.
        return len(self.encodings['input_ids'])

# Wrap the encodings (now including answer positions) as torch datasets.
# NOTE(review): this rebinds train_dataset, previously the SST training split.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

In [104]:
from transformers import AutoModelForQuestionAnswering

# Cased BERT with a span-prediction head; the log below reports some weights as newly initialised.
qa_model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased")
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
qa_model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForQuestionAnswering: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and a

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [105]:
from torch.utils.data import DataLoader
from transformers import AdamW

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

qa_model.to(device)
qa_model.train()

num_epochs=3
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
num_training_steps = num_epochs * len(train_loader)

progress_bar=tqdm(range(num_training_steps))
optim = AdamW(qa_model.parameters(), lr=5e-5)

# Iterate num_epochs (not a hard-coded 3) so the loop and the progress-bar total stay in sync.
for epoch in range(num_epochs):
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        # NOTE(review): batch['token_type_ids'] (context/question segment ids) is not passed,
        # so the model is fine-tuned without segment information; evaluation and inference
        # must stay consistent with whatever is chosen here.
        outputs = qa_model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs.loss
        loss.backward()
        optim.step()
        progress_bar.update(1)
progress_bar.close()


  0%|          | 0/16425 [00:00<?, ?it/s]

In [110]:
# Persist the fine-tuned QA weights to the local "qa_model" directory.
qa_model.save_pretrained("qa_model")

In [106]:
# Exact-match accuracy on token positions: a prediction counts as correct when its
# (start, end) argmax pair equals any reference span for the question. Answers that
# were truncated away carry an out-of-range sentinel position, so they never match.
# No shuffling: order is irrelevant to the count, and a fixed order keeps runs reproducible.
val_loader = DataLoader(val_dataset, batch_size=1)  # batch_size=1 to score each sample on its own
qa_model.eval()
correct_count=0
for batch in tqdm(val_loader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    # One position per reference answer for this question.
    start_positions = batch['start_positions'][0].tolist()
    end_positions = batch['end_positions'][0].tolist()
    # token_type_ids are omitted, matching how qa_model was fine-tuned.
    with torch.no_grad():
        outputs = qa_model(input_ids, attention_mask=attention_mask)
    start_ = torch.argmax(outputs['start_logits'], dim=-1)
    end_ = torch.argmax(outputs['end_logits'], dim=-1)
    ans = (start_.item(), end_.item())
    if ans in zip(start_positions, end_positions):
        correct_count += 1

  0%|          | 0/10570 [00:00<?, ?it/s]

In [108]:
# Fraction of validation questions whose predicted token span matched a reference span.
print(f"Accuracy:{correct_count / len(val_dataset)}")

Accuracy:0.7263954588457899


In [214]:
"""
Problem 3.2: Find 3 failure cases 
"""    
# context is a random passage I took from BBC news, has 129 words in total
# Context: a passage taken from BBC news (the author counted 129 words).
# NOTE(review): the "At a press briefing on Thursday..." sentence appears twice, likely a
# copy-paste slip; it is left in place so the recorded outputs (token indices) stay valid.
context="""Two men in the southern Indian state of Karnataka have tested positive for the Omicron coronavirus variant.
        One of them, a 66-year-old South African national, had travelled from there and has already left India, officials said.
        The second - a 46-year-old doctor in the southern Indian city of Bengaluru - has no travel history. 
        These are the first cases of the new Omicron variant to be reported in India.
        The World Health Organization (WHO) has warned that Omicron poses a high infection risk.
        At a press briefing on Thursday, health officials said the two patients with the new strain had shown mild symptoms.
        At a press briefing on Thursday, health officials said the two patients with the new strain had shown mild symptoms.
        All their primary contacts and secondary contacts have been traced and are being tested.
        According to an official release, five contacts of the 46-year old man have tested positive so far. The patients have been isolated and their samples have been sent for genome testing
        """
question=["What is the gender of both infected cases?",
            "How many cases in total in India so far?",
          "Where are the first infected person now?"
          ]

qa_model.eval()
for i, q in enumerate(question):
    # Encode exactly as during training: (context, question), at most 256 tokens.
    sample_encodings = tokenizer(context, q, truncation=True, max_length=256, padding=True)
    input_ids = torch.LongTensor(sample_encodings['input_ids']).unsqueeze(0).to(device)
    attention_mask = torch.LongTensor(sample_encodings['attention_mask']).unsqueeze(0).to(device)
    # token_type_ids are deliberately not passed: qa_model was fine-tuned without them.
    with torch.no_grad():
        outputs = qa_model(input_ids, attention_mask=attention_mask)
    start_ = torch.argmax(outputs['start_logits'], dim=-1)
    end_ = torch.argmax(outputs['end_logits'], dim=-1)
    # Decode the whole predicted span at once so the tokenizer merges "##" WordPiece
    # continuations itself; stripping every "#" by hand also removed real "#" characters.
    # An end index before the start yields an empty answer.
    answer = tokenizer.decode(sample_encodings['input_ids'][start_.item():end_.item() + 1])
    print("Answer to question {}: {}(word_start_idx:{})".format(i + 1, answer, start_.item()))
        
    
    


Answer to question 1: Omicron coronavirus variant(word_start_idx:15)
Answer to question 2: five(word_start_idx:178)
Answer to question 3: India(word_start_idx:90)
