Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Integrate BERT into Castor framework (#17)
* Remove unused classes in models/bert * Split data4bert module into multiple processors * Refactor BERT tokenizer
- Loading branch information
1 parent
af92a55
commit e4244ec
Showing
17 changed files
with
348 additions
and
408 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,74 @@ | |||
import os | |||
|
|||
import torch | |||
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset | |||
from tqdm import tqdm | |||
|
|||
from datasets.processors.bert_processor import convert_examples_to_features, accuracy | |||
from utils.tokenization4bert import BertTokenizer | |||
|
|||
|
|||
class BertEvaluator(object):
    """Evaluates a fine-tuned BERT classifier on a processor's dev split.

    Builds the dev-set features with the processor's label list and the
    BERT tokenizer, runs the model batch-by-batch without gradients, and
    reports average loss and accuracy both to stdout and to
    ``eval_results.txt`` under ``args.output_dir``.
    """

    def __init__(self, model, processor, args):
        """Store the model/processor and load dev examples and tokenizer.

        Args:
            model: a BERT classifier; called as ``model(input_ids,
                segment_ids, input_mask[, label_ids])`` — with labels it
                returns the loss, without labels it returns logits.
            processor: dataset processor providing ``get_dev_examples``
                and ``get_labels``.
            args: namespace with ``data_dir``, ``bert_model``,
                ``do_lower_case``, ``max_seq_length``, ``eval_batch_size``,
                ``device`` and ``output_dir``.
        """
        self.args = args
        self.model = model
        self.processor = processor
        self.eval_examples = self.processor.get_dev_examples(args.data_dir)
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    def evaluate(self):
        """Run evaluation over the dev set.

        Returns:
            dict with keys ``'eval_loss'`` (mean loss per batch) and
            ``'eval_accuracy'`` (correct predictions / num examples).
        """
        label_list = self.processor.get_labels()
        eval_features = convert_examples_to_features(self.eval_examples,
                                                     label_list,
                                                     self.args.max_seq_length,
                                                     self.tokenizer)

        print("Num. of examples =", len(self.eval_examples))
        # BUG FIX: a %-style format string was previously passed straight to
        # print(), which printed the literal "%d" instead of the batch size.
        print("Batch size =", self.args.eval_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # SequentialSampler: deterministic order, no shuffling during eval.
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=self.args.eval_batch_size)

        self.model.eval()

        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(self.args.device)
            input_mask = input_mask.to(self.args.device)
            segment_ids = segment_ids.to(self.args.device)
            label_ids = label_ids.to(self.args.device)

            with torch.no_grad():
                # With labels the model returns the loss; without labels,
                # the logits.
                tmp_eval_loss = self.model(input_ids, segment_ids, input_mask, label_ids)
                logits = self.model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        # Guard against an empty dev set, which previously raised
        # ZeroDivisionError here.
        eval_loss = eval_loss / max(nb_eval_steps, 1)
        eval_accuracy = eval_accuracy / max(nb_eval_examples, 1)

        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy}

        output_eval_file = os.path.join(self.args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            print("***** Eval results *****")
            for key in sorted(result.keys()):
                # BUG FIX: the %-style format string was previously passed to
                # print() unformatted, echoing the literal "%s = %s".
                print("  %s = %s" % (key, str(result[key])))
                writer.write("%s = %s\n" % (key, str(result[key])))
        # Previously computed but never returned; returning it is
        # backward-compatible and lets callers consume the metrics.
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Original file line | Diff line number | Diff line change |
---|---|---|---|
@@ -0,0 +1,33 @@ | |||
import os | |||
|
|||
from datasets.processors.bert_processor import BertProcessor, InputExample | |||
|
|||
|
|||
class AAPDProcessor(BertProcessor):
    """Processor for the AAPD dataset in TSV form (label in column 0,
    text in column 1, with a header row)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn raw TSV rows into InputExamples, skipping the header row."""
        examples = []
        for row_idx, row in enumerate(lines):
            if row_idx == 0:
                continue  # first row is the column header
            example = InputExample(
                guid="%s-%s" % (set_type, row_idx),
                text_a=row[1],
                text_b=None,
                label=row[0],
            )
            examples.append(example)
        return examples
Oops, something went wrong.