Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add roberta support #1503

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions deeppavlov/configs/squad/refactor_squad_torch_bert.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
{
"dataset_reader": {
"class_name": "squad_dataset_reader",
"data_path": "{DOWNLOADS_PATH}/squad/"
},
"dataset_iterator": {
"class_name": "squad_iterator",
"seed": 1337,
"shuffle": true
},
"chainer": {
"in": [
"context_raw",
"question_raw"
],
"in_y": [
"ans_raw",
"ans_raw_start"
],
"pipe": [
{
"class_name": "torch_squad_transformers_preprocessor",
"vocab_file": "{TRANSFORMER}",
"do_lower_case": "{LOWERCASE}",
"max_seq_length": 768,
"return_tokens": true,
"in": [
"question_raw",
"context_raw"
],
"out": [
"bert_features",
"subtokens"
]
},
{
"class_name": "squad_bert_mapping",
"do_lower_case": "{LOWERCASE}",
"in": [
"context_raw",
"bert_features",
"subtokens"
],
"out": [
"subtok2chars",
"char2subtoks"
]
},
{
"class_name": "squad_bert_ans_preprocessor",
"do_lower_case": "{LOWERCASE}",
"in": [
"ans_raw",
"ans_raw_start",
"char2subtoks"
],
"out": [
"ans",
"ans_start",
"ans_end"
]
},
{
"class_name": "torch_transformers_squad",
"pretrained_bert": "{TRANSFORMER}",
"save_path": "{MODEL_PATH}/model",
"load_path": "{MODEL_PATH}/model",
"optimizer": "AdamW",
"optimizer_parameters": {
"lr": 2e-05,
"weight_decay": 0.01,
"betas": [
0.9,
0.999
],
"eps": 1e-06
},
"learning_rate_drop_patience": 2,
"learning_rate_drop_div": 2.0,
"in": [
"bert_features"
],
"in_y": [
"ans_start",
"ans_end"
],
"out": [
"ans_start_predicted",
"ans_end_predicted",
"logits"
]
},
{
"class_name": "squad_bert_ans_postprocessor",
"in": [
"ans_start_predicted",
"ans_end_predicted",
"context_raw",
"bert_features",
"subtok2chars",
"subtokens"
],
"out": [
"ans_predicted",
"ans_start_predicted",
"ans_end_predicted"
]
}
],
"out": [
"ans_predicted",
"ans_start_predicted",
"logits"
]
},
"train": {
"show_examples": false,
"evaluation_targets": [
"valid"
],
"log_every_n_batches": 250,
"val_every_n_batches": 500,
"batch_size": 10,
"pytest_max_batches": 2,
"pytest_batch_size": 5,
"validation_patience": 10,
"metrics": [
{
"name": "squad_v1_f1",
"inputs": [
"ans",
"ans_predicted"
]
},
{
"name": "squad_v1_em",
"inputs": [
"ans",
"ans_predicted"
]
},
{
"name": "squad_v2_f1",
"inputs": [
"ans",
"ans_predicted"
]
},
{
"name": "squad_v2_em",
"inputs": [
"ans",
"ans_predicted"
]
}
],
"class_name": "torch_trainer"
},
"metadata": {
"variables": {
"LOWERCASE": true,
"TRANSFORMER": "roberta-base",
"ROOT_PATH": "~/.deeppavlov",
"DOWNLOADS_PATH": "{ROOT_PATH}/downloads",
"MODELS_PATH": "{ROOT_PATH}/models",
"MODEL_PATH": "{MODELS_PATH}/squad_torch_bert/{TRANSFORMER}"
},
"download": [
{
"url": "http://files.deeppavlov.ai/v1/squad/squad_torch_bert.tar.gz",
"subdir": "{ROOT_PATH}/models"
}
]
}
}
5 changes: 4 additions & 1 deletion deeppavlov/models/preprocessors/squad_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,10 @@ def __call__(self, contexts, bert_features, *args, **kwargs):
subtokens = args[0][batch_counter]
else:
subtokens = features.tokens
context_start = subtokens.index('[SEP]') + 1
if 'SEP' in subtokens:
context_start = subtokens.index('[SEP]') + 1
else:
context_start = subtokens.index('<s>') + 1
idx = 0
subtok2char: Dict[int, int] = {}
char2subtok: Dict[int, int] = {}
Expand Down
6 changes: 4 additions & 2 deletions deeppavlov/models/torch_bert/torch_transformers_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def train_on_batch(self, features: List[InputFeatures], y_st: List[List[int]], y
y_end = [x[0] for x in y_end]
b_y_st = torch.from_numpy(np.array(y_st)).to(self.device)
b_y_end = torch.from_numpy(np.array(y_end)).to(self.device)

input_ = {
'input_ids': b_input_ids,
'attention_mask': b_input_masks,
Expand Down Expand Up @@ -184,7 +184,9 @@ def __call__(self, features: List[InputFeatures]) -> Tuple[List[int], List[int],
b_input_ids = torch.cat(input_ids, dim=0).to(self.device)
b_input_masks = torch.cat(input_masks, dim=0).to(self.device)
b_input_type_ids = torch.cat(input_type_ids, dim=0).to(self.device)

if self.pretrained_bert == 'roberta-base':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like we should generalize this condition (not only to roberta-base)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

b_input_type_ids = b_input_type_ids.unsqueeze(1).expand(-1, b_input_ids.shape[-1])

input_ = {
'input_ids': b_input_ids,
'attention_mask': b_input_masks,
Expand Down