In [1]:
!pip install torch
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)
[K     |████████████████████████████████| 573kB 4.7MB/s 
Collecting tokenizers==0.5.2
[?25l  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
[K     |████████████████████████████████| 3.7MB 14.9MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/98/2c/8df20f3ac6c22ac224fff307ebc102818206c53fc454ecd37d8ac2060df5/sentencepiece-0.1.86-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 37.1MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/99/50/93509f906a40bffd7d175f97fd75ea328ad9bd91f48f59c4bd084c94a25e/sacremoses-0.0.41.tar.gz (883kB)
[K     |██████████

In [0]:
import torch
from transformers import squad_convert_examples_to_features
from transformers.data.processors.squad import SquadResult, SquadV2Processor
from transformers import BertModel, BertConfig, BertTokenizer

In [4]:
# Use the GPU when one is available; otherwise fall back to the CPU so the notebook
# still runs (slowly) instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_processor = SquadV2Processor()
# Load the SQuAD v2 dev examples from the project data directory (default dev file
# name chosen by SquadV2Processor; presumably dev-v2.0.json — TODO confirm).
examples = feature_processor.get_dev_examples('/content/drive/My Drive/cis530project/cis530project/data')

100%|██████████| 35/35 [00:04<00:00,  8.71it/s]


HBox(children=(IntProgress(value=0, description='Downloading', max=231508, style=ProgressStyle(description_wid…




# BERT Model

In [0]:
# Tokenizer for the same checkpoint BERT_SQUAD's encoder loads ('bert-base-uncased');
# it is used to convert the SQuAD examples into model inputs.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [0]:
from torch import nn
from transformers import BertModel, BertConfig, BertTokenizer

class BERT_SQUAD(nn.Module):
    """Extractive QA model: a pretrained BERT encoder followed by a small MLP span head.

    The head maps every encoder hidden state to two scores per token: one for
    "the answer starts here" and one for "the answer ends here".
    """

    def __init__(self):
        super(BERT_SQUAD, self).__init__()

        # Attribute names (bert_model, fc_layers) are the state_dict key prefixes of
        # the saved checkpoint, so they must not be renamed.
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

        # 768 is assumed to be the encoder's hidden size (bert-base); the final layer
        # emits (start, end) logits for each token.
        self.fc_layers = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

        self.criterion = nn.CrossEntropyLoss()

    def _span_logits(self, c_q_pairs, attention_mask, token_type_ids):
        """Encode the context/question pairs and return (start_logits, end_logits).

        Both returned tensors have shape (batch, seq_len).
        """
        bert_encoded = self.bert_model(
            input_ids=c_q_pairs,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )[0]
        fc_output = self.fc_layers(bert_encoded)
        # Integer indexing already drops the last axis; an extra `.squeeze(-1)` would
        # also collapse the sequence axis when seq_len == 1.
        return fc_output[:, :, 0], fc_output[:, :, 1]

    def forward(self, c_q_pairs, attention_mask, token_type_ids, start_indices, end_indices):
        """Return the training loss for a batch.

        Args:
            c_q_pairs: token ids of the context/question pairs.
            attention_mask: attention mask for c_q_pairs.
            token_type_ids: segment ids for c_q_pairs.
            start_indices, end_indices: gold start/end token positions; values outside
                [0, seq_len - 1] are clamped into that range.

        Returns:
            2 * cross_entropy(start) + cross_entropy(end) (weighting kept from the
            original code).
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask, token_type_ids)

        last_position = start_outputs.shape[1] - 1
        start_indices = start_indices.clamp(0, last_position)
        end_indices = end_indices.clamp(0, last_position)

        start_loss = self.criterion(start_outputs, start_indices)
        end_loss = self.criterion(end_outputs, end_indices)

        return 2 * start_loss + end_loss

    @torch.no_grad()
    def predict(self, c_q_pairs, attention_mask, token_type_ids):
        """Return [[start, end], ...] token positions, one pair per batch row.

        Start and end are chosen independently by argmax, so end may precede start;
        callers must handle that case. Runs without autograd since it is inference only.
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask, token_type_ids)

        s_ind = start_outputs.max(dim=1)[1]
        e_ind = end_outputs.max(dim=1)[1]

        # argmax is always a valid position, so no clamping is needed.
        return torch.stack((s_ind, e_ind), dim=1).tolist()

In [148]:
bs1 = BERT_SQUAD()
# Load the fine-tuned weights onto the CPU: the model is still on the CPU here (it is
# moved to `device` in the prediction cell). map_location also lets a checkpoint that
# was saved from GPU tensors load on a CPU-only runtime.
bs1.load_state_dict(torch.load('/content/drive/My Drive/cis530project/bert-squad.pt', map_location='cpu'))

<All keys matched successfully>

In [149]:
# Convert the SQuAD v2 dev examples into fixed-length model inputs with this section's
# tokenizer. `features` keeps per-window metadata (example_index, token_to_orig_map)
# used later to map predicted positions back to words in the original context.
# `dataset` holds the tensors fed to the model; the loop reads batch[0..2] as ids,
# attention mask and token type ids.
# Per the transformers SQuAD utilities, contexts longer than max_seq_length are split
# into overlapping windows stepped by doc_stride tokens, so one example can yield
# several features.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=512,
    doc_stride=128,
    max_query_length=128,
    is_training=False,
    return_dataset='pt',
    threads=1
)



convert squad examples to features:   0%|          | 0/11873 [00:00<?, ?it/s][A[A

convert squad examples to features:   0%|          | 1/11873 [00:00<54:54,  3.60it/s][A[A

convert squad examples to features:   0%|          | 33/11873 [00:00<38:41,  5.10it/s][A[A

convert squad examples to features:   1%|          | 65/11873 [00:00<27:27,  7.17it/s][A[A

convert squad examples to features:   1%|          | 97/11873 [00:00<19:34, 10.02it/s][A[A

convert squad examples to features:   1%|          | 129/11873 [00:01<13:59, 14.00it/s][A[A

convert squad examples to features:   1%|▏         | 161/11873 [00:01<10:01, 19.48it/s][A[A

convert squad examples to features:   2%|▏         | 193/11873 [00:01<07:18, 26.64it/s][A[A

convert squad examples to features:   2%|▏         | 225/11873 [00:01<05:27, 35.57it/s][A[A

convert squad examples to features:   2%|▏         | 257/11873 [00:01<04:04, 47.45it/s][A[A

convert squad examples to features:   2%|▏         | 289/11873 

In [0]:
# batch_size=1 and shuffle=False are required by the prediction loop: it reads
# indices[0] from each batch and pairs batch i with features[i].
dev_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [151]:
import json
from datetime import datetime
from transformers.data.processors.squad import SquadFeatures

# Predict an answer string for every SQuAD v2 dev question and write
# {qas_id: answer} to bert-dev-preds.json, with "" meaning "no answer".
bs1 = bs1.to(device)
bs1.eval()

start_time = datetime.now()
outputs = dict()

# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for i, batch in enumerate(dev_loader):
        # Assumes dev_loader yields one feature per batch, in order (batch_size=1,
        # shuffle=False), so batch i corresponds to features[i].
        c_q_pairs = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        indices = bs1.predict(c_q_pairs, attention_mask, token_type_ids)
        start, end = indices[0][0], indices[0][1]

        # Only positions present in token_to_orig_map map back to context words; any
        # other predicted position (e.g. [CLS] at 0) yields the empty "no answer".
        token_to_orig_map = features[i].token_to_orig_map
        if start in token_to_orig_map and end in token_to_orig_map:
            start, end = token_to_orig_map[start], token_to_orig_map[end]
            tokens = examples[features[i].example_index].doc_tokens[start:end+1] if start <= end else ""
            answer = ' '.join(tokens)
        else:
            answer = ""

        # Defensive: the answer text comes from the example's doc_tokens, so this
        # only fires if the literal string appears in the context.
        if '[CLS]' in answer:
            answer = ""

        # NOTE(review): when an example is split into several windows, only the first
        # window's prediction is kept, even if it is "" and a later window found a span.
        # Aggregating by start/end score would be more faithful.
        q_id = examples[features[i].example_index].qas_id
        if q_id not in outputs:
            outputs[q_id] = answer

        if i % 100 == 0:
            print('done with example : {}'.format(i))

print('total time for dev predictions : {}'.format(datetime.now() - start_time))
with open('bert-dev-preds.json', 'w') as f:
    json.dump(outputs, f)

done with example : 0
done with example : 100
done with example : 200
done with example : 300
done with example : 400
done with example : 500
done with example : 600
done with example : 700
done with example : 800
done with example : 900
done with example : 1000
done with example : 1100
done with example : 1200
done with example : 1300
done with example : 1400
done with example : 1500
done with example : 1600
done with example : 1700
done with example : 1800
done with example : 1900
done with example : 2000
done with example : 2100
done with example : 2200
done with example : 2300
done with example : 2400
done with example : 2500
done with example : 2600
done with example : 2700
done with example : 2800
done with example : 2900
done with example : 3000
done with example : 3100
done with example : 3200
done with example : 3300
done with example : 3400
done with example : 3500
done with example : 3600
done with example : 3700
done with example : 3800
done with example : 3900
done with ex

# distilBERT Model

In [0]:
import torch

from torch import nn
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

class DISTILBERT_SQUAD(nn.Module):
    """Extractive QA model: a pretrained DistilBERT encoder followed by an MLP span head.

    The head maps every encoder hidden state to two scores per token: one for
    "the answer starts here" and one for "the answer ends here".
    """

    def __init__(self):
        super(DISTILBERT_SQUAD, self).__init__()

        # Attribute names (distilbert_model, fc_layers) are the state_dict key prefixes
        # of the saved checkpoint, so they must not be renamed.
        self.distilbert_model = DistilBertModel.from_pretrained('distilbert-base-uncased')

        # 768 is assumed to be the encoder's hidden size; the final layer emits
        # (start, end) logits for each token.
        self.fc_layers = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

        self.criterion = nn.CrossEntropyLoss()

    def _span_logits(self, c_q_pairs, attention_mask):
        """Encode the context/question pairs and return (start_logits, end_logits).

        Both returned tensors have shape (batch, seq_len). The encoder is called
        without token_type_ids.
        """
        bert_encoded = self.distilbert_model(
            input_ids=c_q_pairs,
            attention_mask=attention_mask
        )[0]
        fc_output = self.fc_layers(bert_encoded)
        # Integer indexing already drops the last axis; an extra `.squeeze(-1)` would
        # also collapse the sequence axis when seq_len == 1.
        return fc_output[:, :, 0], fc_output[:, :, 1]

    def forward(self, c_q_pairs, attention_mask, token_type_ids, start_indices, end_indices):
        """Return the training loss for a batch.

        token_type_ids is accepted for interface parity with BERT_SQUAD but unused.
        Gold start_indices/end_indices outside [0, seq_len - 1] are clamped into range.
        Returns 2 * cross_entropy(start) + cross_entropy(end) (weighting kept from the
        original code).
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask)

        last_position = start_outputs.shape[1] - 1
        start_indices = start_indices.clamp(0, last_position)
        end_indices = end_indices.clamp(0, last_position)

        start_loss = self.criterion(start_outputs, start_indices)
        end_loss = self.criterion(end_outputs, end_indices)

        return 2 * start_loss + end_loss

    @torch.no_grad()
    def predict(self, c_q_pairs, attention_mask, token_type_ids):
        """Return [[start, end], ...] token positions, one pair per batch row.

        Start and end are chosen independently by argmax, so end may precede start;
        callers must handle that case. token_type_ids is unused. Runs without autograd
        since it is inference only.
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask)

        s_ind = start_outputs.max(dim=1)[1]
        e_ind = end_outputs.max(dim=1)[1]

        # argmax is always a valid position, so no clamping is needed.
        return torch.stack((s_ind, e_ind), dim=1).tolist()


In [0]:
# Tokenizer for the same checkpoint DISTILBERT_SQUAD's encoder loads
# ('distilbert-base-uncased'); it replaces the BERT tokenizer for this section.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [154]:
bs1 = DISTILBERT_SQUAD()
# Load the fine-tuned weights onto the CPU: the model is still on the CPU here (it is
# moved to `device` in the prediction cell). map_location also lets a checkpoint that
# was saved from GPU tensors load on a CPU-only runtime.
bs1.load_state_dict(torch.load('/content/drive/My Drive/cis530project/distilbert-squad.pt', map_location='cpu'))

<All keys matched successfully>

In [155]:
# Convert the SQuAD v2 dev examples into model inputs with the DistilBERT tokenizer.
# `features` keeps per-window metadata (example_index, token_to_orig_map) used to map
# predictions back to context words; `dataset` holds the tensors fed to the model.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=512,
    doc_stride=128,
    max_query_length=128,
    is_training=False,
    return_dataset='pt',
    threads=1
)

# This section has no separate DataLoader cell, so rebuild the loader from the dataset
# just converted. Without this, the prediction loop iterates whatever `dev_loader` an
# earlier section left behind while indexing these new `features`. That only lines up
# if both tokenizers happen to produce identical inputs.
# batch_size=1 / shuffle=False keep batch i aligned with features[i].
dev_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False)



convert squad examples to features:   0%|          | 0/11873 [00:00<?, ?it/s][A[A

convert squad examples to features:   0%|          | 1/11873 [00:00<52:44,  3.75it/s][A[A

convert squad examples to features:   0%|          | 33/11873 [00:00<37:11,  5.31it/s][A[A

convert squad examples to features:   1%|          | 65/11873 [00:00<26:24,  7.45it/s][A[A

convert squad examples to features:   1%|          | 97/11873 [00:00<18:50, 10.42it/s][A[A

convert squad examples to features:   1%|          | 129/11873 [00:01<13:28, 14.52it/s][A[A

convert squad examples to features:   1%|▏         | 161/11873 [00:01<09:40, 20.18it/s][A[A

convert squad examples to features:   2%|▏         | 193/11873 [00:01<07:04, 27.53it/s][A[A

convert squad examples to features:   2%|▏         | 225/11873 [00:01<05:18, 36.52it/s][A[A

convert squad examples to features:   2%|▏         | 257/11873 [00:01<03:59, 48.53it/s][A[A

convert squad examples to features:   2%|▏         | 289/11873 

In [156]:
import json
from datetime import datetime
from transformers.data.processors.squad import SquadFeatures

# Predict an answer string for every SQuAD v2 dev question and write
# {qas_id: answer} to distilbert-dev-preds.json, with "" meaning "no answer".
bs1 = bs1.to(device)
bs1.eval()

start_time = datetime.now()
outputs = dict()

# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for i, batch in enumerate(dev_loader):
        # Assumes dev_loader was built from this section's dataset with one feature
        # per batch, in order (batch_size=1, shuffle=False), so batch i is features[i].
        c_q_pairs = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        indices = bs1.predict(c_q_pairs, attention_mask, token_type_ids)
        start, end = indices[0][0], indices[0][1]

        # Only positions present in token_to_orig_map map back to context words; any
        # other predicted position (e.g. [CLS] at 0) yields the empty "no answer".
        token_to_orig_map = features[i].token_to_orig_map
        if start in token_to_orig_map and end in token_to_orig_map:
            start, end = token_to_orig_map[start], token_to_orig_map[end]
            tokens = examples[features[i].example_index].doc_tokens[start:end+1] if start <= end else ""
            answer = ' '.join(tokens)
        else:
            answer = ""

        # Defensive: the answer text comes from the example's doc_tokens, so this
        # only fires if the literal string appears in the context.
        if '[CLS]' in answer:
            answer = ""

        # NOTE(review): when an example is split into several windows, only the first
        # window's prediction is kept, even if it is "" and a later window found a span.
        q_id = examples[features[i].example_index].qas_id
        if q_id not in outputs:
            outputs[q_id] = answer

        if i % 100 == 0:
            print('done with example : {}'.format(i))

print('total time for dev predictions : {}'.format(datetime.now() - start_time))
with open('distilbert-dev-preds.json', 'w') as f:
    json.dump(outputs, f)

done with example : 0
done with example : 100
done with example : 200
done with example : 300
done with example : 400
done with example : 500
done with example : 600
done with example : 700
done with example : 800
done with example : 900
done with example : 1000
done with example : 1100
done with example : 1200
done with example : 1300
done with example : 1400
done with example : 1500
done with example : 1600
done with example : 1700
done with example : 1800
done with example : 1900
done with example : 2000
done with example : 2100
done with example : 2200
done with example : 2300
done with example : 2400
done with example : 2500
done with example : 2600
done with example : 2700
done with example : 2800
done with example : 2900
done with example : 3000
done with example : 3100
done with example : 3200
done with example : 3300
done with example : 3400
done with example : 3500
done with example : 3600
done with example : 3700
done with example : 3800
done with example : 3900
done with ex

# RoBERTa Model

In [0]:
import torch

from torch import nn
from transformers import RobertaModel, RobertaConfig, RobertaTokenizer

class ROBERTA_SQUAD(nn.Module):
    """Extractive QA model: a pretrained RoBERTa encoder followed by an MLP span head.

    The head maps every encoder hidden state to two scores per token: one for
    "the answer starts here" and one for "the answer ends here".
    """

    def __init__(self):
        super(ROBERTA_SQUAD, self).__init__()

        # Attribute names (roberta_model, fc_layers) are the state_dict key prefixes of
        # the saved checkpoint, so they must not be renamed.
        self.roberta_model = RobertaModel.from_pretrained('roberta-base')

        # 768 is assumed to be the encoder's hidden size; the final layer emits
        # (start, end) logits for each token.
        self.fc_layers = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

        self.criterion = nn.CrossEntropyLoss()

    def _span_logits(self, c_q_pairs, attention_mask):
        """Encode the context/question pairs and return (start_logits, end_logits).

        Both returned tensors have shape (batch, seq_len). The encoder is called
        without token_type_ids.
        """
        bert_encoded = self.roberta_model(
            input_ids=c_q_pairs,
            attention_mask=attention_mask
        )[0]
        fc_output = self.fc_layers(bert_encoded)
        # Integer indexing already drops the last axis; an extra `.squeeze(-1)` would
        # also collapse the sequence axis when seq_len == 1.
        return fc_output[:, :, 0], fc_output[:, :, 1]

    def forward(self, c_q_pairs, attention_mask, token_type_ids, start_indices, end_indices):
        """Return the training loss for a batch.

        token_type_ids is accepted for interface parity with BERT_SQUAD but unused.
        Gold start_indices/end_indices outside [0, seq_len - 1] are clamped into range.
        Returns 2 * cross_entropy(start) + cross_entropy(end) (weighting kept from the
        original code).
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask)

        last_position = start_outputs.shape[1] - 1
        start_indices = start_indices.clamp(0, last_position)
        end_indices = end_indices.clamp(0, last_position)

        start_loss = self.criterion(start_outputs, start_indices)
        end_loss = self.criterion(end_outputs, end_indices)

        return 2 * start_loss + end_loss

    @torch.no_grad()
    def predict(self, c_q_pairs, attention_mask, token_type_ids):
        """Return [[start, end], ...] token positions, one pair per batch row.

        Start and end are chosen independently by argmax, so end may precede start;
        callers must handle that case. token_type_ids is unused. Runs without autograd
        since it is inference only.
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask)

        s_ind = start_outputs.max(dim=1)[1]
        e_ind = end_outputs.max(dim=1)[1]

        # argmax is always a valid position, so no clamping is needed.
        return torch.stack((s_ind, e_ind), dim=1).tolist()


In [0]:
# Tokenizer for the same checkpoint ROBERTA_SQUAD's encoder loads ('roberta-base');
# it replaces the previous tokenizer for this section.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

In [108]:
bs1 = ROBERTA_SQUAD()
# Load the fine-tuned weights onto the CPU: the model is still on the CPU here (it is
# moved to `device` in the prediction cell). map_location also lets a checkpoint that
# was saved from GPU tensors load on a CPU-only runtime.
bs1.load_state_dict(torch.load('/content/drive/My Drive/cis530project/roberta-squad.pt', map_location='cpu'))



  0%|          | 0/35 [00:00<?, ?it/s][A[A

  6%|▌         | 2/35 [00:00<00:02, 13.48it/s][A[A

 11%|█▏        | 4/35 [00:00<00:02, 14.18it/s][A[A

 17%|█▋        | 6/35 [00:00<00:02, 12.68it/s][A[A

 20%|██        | 7/35 [00:00<00:02, 10.74it/s][A[A

 23%|██▎       | 8/35 [00:00<00:02,  9.68it/s][A[A

 29%|██▊       | 10/35 [00:01<00:03,  8.17it/s][A[A

 34%|███▍      | 12/35 [00:01<00:02,  8.70it/s][A[A

 40%|████      | 14/35 [00:01<00:02,  8.78it/s][A[A

 46%|████▌     | 16/35 [00:01<00:02,  9.37it/s][A[A

 51%|█████▏    | 18/35 [00:01<00:01,  9.83it/s][A[A

 57%|█████▋    | 20/35 [00:01<00:01, 10.89it/s][A[A

 63%|██████▎   | 22/35 [00:02<00:01, 11.40it/s][A[A

 69%|██████▊   | 24/35 [00:02<00:01,  9.61it/s][A[A

 74%|███████▍  | 26/35 [00:02<00:01,  8.22it/s][A[A

 80%|████████  | 28/35 [00:02<00:00,  9.25it/s][A[A

 86%|████████▌ | 30/35 [00:03<00:00,  8.96it/s][A[A

 89%|████████▊ | 31/35 [00:03<00:00,  8.85it/s][A[A

 91%|█████████▏| 32/35

In [109]:
# Convert the SQuAD v2 dev examples into model inputs with the RoBERTa tokenizer.
# `features` keeps per-window metadata (example_index, token_to_orig_map) used later to
# map predicted positions back to context words; `dataset` holds the tensors fed to
# the model (the loop reads batch[0..2] as ids, attention mask, token type ids).
# Per the transformers SQuAD utilities, long contexts are split into overlapping
# windows stepped by doc_stride tokens, so one example can yield several features.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=512,
    doc_stride=128,
    max_query_length=128,
    is_training=False,
    return_dataset='pt',
    threads=1
)



convert squad examples to features:   0%|          | 0/11873 [00:00<?, ?it/s][A[A

convert squad examples to features:   0%|          | 1/11873 [00:00<35:32,  5.57it/s][A[A

convert squad examples to features:   0%|          | 33/11873 [00:00<25:06,  7.86it/s][A[A

convert squad examples to features:   1%|          | 65/11873 [00:00<17:45, 11.09it/s][A[A

convert squad examples to features:   1%|          | 97/11873 [00:00<12:39, 15.50it/s][A[A

convert squad examples to features:   1%|          | 129/11873 [00:00<09:02, 21.63it/s][A[A

convert squad examples to features:   2%|▏         | 193/11873 [00:00<06:29, 30.02it/s][A[A

convert squad examples to features:   2%|▏         | 225/11873 [00:01<04:45, 40.74it/s][A[A

convert squad examples to features:   2%|▏         | 289/11873 [00:01<03:30, 54.90it/s][A[A

convert squad examples to features:   3%|▎         | 321/11873 [00:01<02:39, 72.28it/s][A[A

convert squad examples to features:   3%|▎         | 353/11873 

In [0]:
# Rebuilt from the RoBERTa dataset. batch_size=1 and shuffle=False are required by the
# prediction loop: it reads indices[0] from each batch and pairs batch i with features[i].
dev_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [139]:
import json
from datetime import datetime
from transformers.data.processors.squad import SquadFeatures

# Predict an answer string for every SQuAD v2 dev question and write
# {qas_id: answer} to roberta-dev-preds.json, with "" meaning "no answer".
bs1 = bs1.to(device)
bs1.eval()

start_time = datetime.now()
outputs = dict()

# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for i, batch in enumerate(dev_loader):
        # Assumes dev_loader yields one feature per batch, in order (batch_size=1,
        # shuffle=False), so batch i corresponds to features[i].
        c_q_pairs = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        indices = bs1.predict(c_q_pairs, attention_mask, token_type_ids)
        start, end = indices[0][0], indices[0][1]

        # Only positions present in token_to_orig_map map back to context words; any
        # other predicted position (e.g. <s> at 0) yields the empty "no answer".
        token_to_orig_map = features[i].token_to_orig_map
        if start in token_to_orig_map and end in token_to_orig_map:
            start, end = token_to_orig_map[start], token_to_orig_map[end]
            tokens = examples[features[i].example_index].doc_tokens[start:end+1] if start <= end else ""
            answer = ' '.join(tokens)
        else:
            answer = ""

        # Defensive: the answer text comes from the example's doc_tokens, so this
        # only fires if the literal string appears in the context.
        if '<s>' in answer:
            answer = ""

        # NOTE(review): when an example is split into several windows, only the first
        # window's prediction is kept, even if it is "" and a later window found a span.
        q_id = examples[features[i].example_index].qas_id
        if q_id not in outputs:
            outputs[q_id] = answer

        if i % 100 == 0:
            print('done with example : {}'.format(i))

print('total time for dev predictions : {}'.format(datetime.now() - start_time))
with open('roberta-dev-preds.json', 'w') as f:
    json.dump(outputs, f)

done with example : 0
done with example : 100
done with example : 200
done with example : 300
done with example : 400
done with example : 500
done with example : 600
done with example : 700
done with example : 800
done with example : 900
done with example : 1000
done with example : 1100
done with example : 1200
done with example : 1300
done with example : 1400
done with example : 1500
done with example : 1600
done with example : 1700
done with example : 1800
done with example : 1900
done with example : 2000
done with example : 2100
done with example : 2200
done with example : 2300
done with example : 2400
done with example : 2500
done with example : 2600
done with example : 2700
done with example : 2800
done with example : 2900
done with example : 3000
done with example : 3100
done with example : 3200
done with example : 3300
done with example : 3400
done with example : 3500
done with example : 3600
done with example : 3700
done with example : 3800
done with example : 3900
done with ex

# distilRoBERTa Model

In [0]:
import torch

from torch import nn
from transformers import RobertaModel, RobertaConfig, RobertaTokenizer

class DISTILROBERTA_SQUAD(nn.Module):
    """Extractive QA model: a pretrained DistilRoBERTa encoder followed by an MLP span head.

    The head maps every encoder hidden state to two scores per token: one for
    "the answer starts here" and one for "the answer ends here".
    """

    def __init__(self):
        super(DISTILROBERTA_SQUAD, self).__init__()

        # Attribute names (roberta_model, fc_layers) are the state_dict key prefixes of
        # the saved checkpoint, so they must not be renamed.
        self.roberta_model = RobertaModel.from_pretrained('distilroberta-base')

        # 768 is assumed to be the encoder's hidden size; the final layer emits
        # (start, end) logits for each token.
        self.fc_layers = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

        self.criterion = nn.CrossEntropyLoss()

    def _span_logits(self, c_q_pairs, attention_mask):
        """Encode the context/question pairs and return (start_logits, end_logits).

        Both returned tensors have shape (batch, seq_len). The encoder is called
        without token_type_ids.
        """
        bert_encoded = self.roberta_model(
            input_ids=c_q_pairs,
            attention_mask=attention_mask
        )[0]
        fc_output = self.fc_layers(bert_encoded)
        # Integer indexing already drops the last axis; an extra `.squeeze(-1)` would
        # also collapse the sequence axis when seq_len == 1.
        return fc_output[:, :, 0], fc_output[:, :, 1]

    def forward(self, c_q_pairs, attention_mask, token_type_ids, start_indices, end_indices):
        """Return the training loss for a batch.

        token_type_ids is accepted for interface parity with BERT_SQUAD but unused.
        Gold start_indices/end_indices outside [0, seq_len - 1] are clamped into range.
        Returns 2 * cross_entropy(start) + cross_entropy(end) (weighting kept from the
        original code).
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask)

        last_position = start_outputs.shape[1] - 1
        start_indices = start_indices.clamp(0, last_position)
        end_indices = end_indices.clamp(0, last_position)

        start_loss = self.criterion(start_outputs, start_indices)
        end_loss = self.criterion(end_outputs, end_indices)

        return 2 * start_loss + end_loss

    @torch.no_grad()
    def predict(self, c_q_pairs, attention_mask, token_type_ids):
        """Return [[start, end], ...] token positions, one pair per batch row.

        Start and end are chosen independently by argmax, so end may precede start;
        callers must handle that case. token_type_ids is unused. Runs without autograd
        since it is inference only.
        """
        start_outputs, end_outputs = self._span_logits(c_q_pairs, attention_mask)

        s_ind = start_outputs.max(dim=1)[1]
        e_ind = end_outputs.max(dim=1)[1]

        # argmax is always a valid position, so no clamping is needed.
        return torch.stack((s_ind, e_ind), dim=1).tolist()


In [0]:
# Tokenizer for the same checkpoint DISTILROBERTA_SQUAD's encoder loads
# ('distilroberta-base'); it replaces the previous tokenizer for this section.
tokenizer = RobertaTokenizer.from_pretrained('distilroberta-base')

In [142]:
bs1 = DISTILROBERTA_SQUAD()
# Load the fine-tuned weights onto the CPU: the model is still on the CPU here (it is
# moved to `device` in the prediction cell). map_location also lets a checkpoint that
# was saved from GPU tensors load on a CPU-only runtime.
bs1.load_state_dict(torch.load('/content/drive/My Drive/cis530project/distilroberta-squad.pt', map_location='cpu'))

<All keys matched successfully>

In [143]:
# Convert the SQuAD v2 dev examples into model inputs with the DistilRoBERTa tokenizer.
# `features` keeps per-window metadata (example_index, token_to_orig_map) used later to
# map predicted positions back to context words; `dataset` holds the tensors fed to
# the model (the loop reads batch[0..2] as ids, attention mask, token type ids).
# Per the transformers SQuAD utilities, long contexts are split into overlapping
# windows stepped by doc_stride tokens, so one example can yield several features.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=512,
    doc_stride=128,
    max_query_length=128,
    is_training=False,
    return_dataset='pt',
    threads=1
)



convert squad examples to features:   0%|          | 0/11873 [00:00<?, ?it/s][A[A

convert squad examples to features:   0%|          | 1/11873 [00:00<36:33,  5.41it/s][A[A

convert squad examples to features:   0%|          | 33/11873 [00:00<25:45,  7.66it/s][A[A

convert squad examples to features:   1%|          | 65/11873 [00:00<18:15, 10.78it/s][A[A

convert squad examples to features:   1%|          | 97/11873 [00:00<13:00, 15.10it/s][A[A

convert squad examples to features:   1%|          | 129/11873 [00:00<09:17, 21.08it/s][A[A

convert squad examples to features:   1%|▏         | 161/11873 [00:00<06:39, 29.28it/s][A[A

convert squad examples to features:   2%|▏         | 193/11873 [00:00<04:52, 40.00it/s][A[A

convert squad examples to features:   2%|▏         | 225/11873 [00:01<03:37, 53.57it/s][A[A

convert squad examples to features:   2%|▏         | 257/11873 [00:01<02:43, 71.00it/s][A[A

convert squad examples to features:   2%|▏         | 289/11873 

In [0]:
# Rebuilt from the DistilRoBERTa dataset. batch_size=1 and shuffle=False are required by
# the prediction loop: it reads indices[0] from each batch and pairs batch i with features[i].
dev_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [145]:
import json
from datetime import datetime
from transformers.data.processors.squad import SquadFeatures

# Predict an answer string for every SQuAD v2 dev question and write
# {qas_id: answer} to distilroberta-dev-preds.json, with "" meaning "no answer".
bs1 = bs1.to(device)
bs1.eval()

start_time = datetime.now()
outputs = dict()

# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for i, batch in enumerate(dev_loader):
        # Assumes dev_loader yields one feature per batch, in order (batch_size=1,
        # shuffle=False), so batch i corresponds to features[i].
        c_q_pairs = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        indices = bs1.predict(c_q_pairs, attention_mask, token_type_ids)
        start, end = indices[0][0], indices[0][1]

        # Only positions present in token_to_orig_map map back to context words; any
        # other predicted position (e.g. <s> at 0) yields the empty "no answer".
        token_to_orig_map = features[i].token_to_orig_map
        if start in token_to_orig_map and end in token_to_orig_map:
            start, end = token_to_orig_map[start], token_to_orig_map[end]
            tokens = examples[features[i].example_index].doc_tokens[start:end+1] if start <= end else ""
            answer = ' '.join(tokens)
        else:
            answer = ""

        # Defensive: the answer text comes from the example's doc_tokens, so this
        # only fires if the literal string appears in the context.
        if '<s>' in answer:
            answer = ""

        # NOTE(review): when an example is split into several windows, only the first
        # window's prediction is kept, even if it is "" and a later window found a span.
        q_id = examples[features[i].example_index].qas_id
        if q_id not in outputs:
            outputs[q_id] = answer

        if i % 100 == 0:
            print('done with example : {}'.format(i))

print('total time for dev predictions : {}'.format(datetime.now() - start_time))
with open('distilroberta-dev-preds.json', 'w') as f:
    json.dump(outputs, f)

done with example : 0
done with example : 100
done with example : 200
done with example : 300
done with example : 400
done with example : 500
done with example : 600
done with example : 700
done with example : 800
done with example : 900
done with example : 1000
done with example : 1100
done with example : 1200
done with example : 1300
done with example : 1400
done with example : 1500
done with example : 1600
done with example : 1700
done with example : 1800
done with example : 1900
done with example : 2000
done with example : 2100
done with example : 2200
done with example : 2300
done with example : 2400
done with example : 2500
done with example : 2600
done with example : 2700
done with example : 2800
done with example : 2900
done with example : 3000
done with example : 3100
done with example : 3200
done with example : 3300
done with example : 3400
done with example : 3500
done with example : 3600
done with example : 3700
done with example : 3800
done with example : 3900
done with ex