add storing prediction/qrel files option for test set (#125)

* add prediction/qrel files dump option

* fix comment
Victor0118 committed Jun 19, 2018
1 parent 5cc027f commit 921a45e7caff0552f0e11a29041ba4e49abc4d92
@@ -10,3 +10,5 @@ trec_eval-9.0.5
 *.pt
 text/
 kim_cnn/data
+.results
+.qrel
@@ -22,7 +22,7 @@ class EvaluatorFactory(object):
     }

     @staticmethod
-    def get_evaluator(dataset_cls, model, embedding, data_loader, batch_size, device, nce=False):
+    def get_evaluator(dataset_cls, model, embedding, data_loader, batch_size, device, nce=False, keep_results=False):
         if data_loader is None:
             return None

@@ -38,5 +38,5 @@ def get_evaluator(dataset_cls, model, embedding, data_loader, batch_size, device
             raise ValueError('{} is not implemented.'.format(dataset_cls))

         return evaluator_map[dataset_cls.NAME](
-            dataset_cls, model, embedding, data_loader, batch_size, device
+            dataset_cls, model, embedding, data_loader, batch_size, device, keep_results
         )
@@ -3,13 +3,14 @@ class Evaluator(object):
     Evaluates a model on a Dataset, using metrics specific to the Dataset.
     """

-    def __init__(self, dataset_cls, model, embedding, data_loader, batch_size, device):
+    def __init__(self, dataset_cls, model, embedding, data_loader, batch_size, device, keep_results=False):
         self.dataset_cls = dataset_cls
         self.model = model
         self.embedding = embedding
         self.data_loader = data_loader
         self.batch_size = batch_size
         self.device = device
+        self.keep_results = keep_results

     def get_sentence_embeddings(self, batch):
         sent1 = self.embedding(batch.sentence_1).transpose(1, 2)
@@ -28,7 +28,9 @@ def get_scores(self):

         qids = list(map(lambda n: int(round(n * 10, 0)) / 10, qids))

-        mean_average_precision, mean_reciprocal_rank = get_map_mrr(qids, predictions, true_labels, self.data_loader.device)
+        mean_average_precision, mean_reciprocal_rank = get_map_mrr(qids, predictions, true_labels,
+                                                                   self.data_loader.device,
+                                                                   keep_results=self.keep_results)
         test_cross_entropy_loss /= len(batch.dataset.examples)

         return [mean_average_precision, mean_reciprocal_rank, test_cross_entropy_loss], ['map', 'mrr', 'cross entropy loss']
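A note on the qid mapping in the hunk above: it rounds each floating-point question id back to one decimal place, presumably because the ids pick up floating-point noise when they travel through the batch iterator as tensors. A minimal, runnable sketch of what that expression does (the sample values are illustrative, not taken from the repo):

    # Illustrative only: round noisy float qids back to one decimal place,
    # mirroring the lambda used in get_scores() above.
    qids = [32.100000381469727, 1.0, 5.199999809265137]
    qids = list(map(lambda n: int(round(n * 10, 0)) / 10, qids))
    print(qids)  # [32.1, 1.0, 5.2]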
@@ -29,8 +29,9 @@ def get_logger():
     return logger


-def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device):
-    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device)
+def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device, keep_results=False):
+    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device,
+                                                           keep_results=keep_results)
     scores, metric_names = saved_model_evaluator.get_scores()
     logger.info('Evaluation metrics for {}'.format(split_name))
     logger.info('\t'.join([' '] + metric_names))
@@ -79,6 +80,9 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
 parser.add_argument('--tensorboard', action='store_true', default=False,
                     help='use TensorBoard to visualize training (default: false)')
 parser.add_argument('--run-label', type=str, help='label to describe run')
+parser.add_argument('--keep-results', action='store_true',
+                    help='store the output score and qrel files into disk for the test set')

 args = parser.parse_args()

 device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() and args.device >= 0 else 'cpu')
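The new flag is a plain store_true switch, so it defaults to False and is exposed on the parsed namespace as args.keep_results (argparse maps the dash to an underscore), which is what the test-set call at the bottom of the script passes along. A small, self-contained sketch of that behaviour:

    # Illustrative only: how the '--keep-results' switch surfaces on the parsed args.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--keep-results', action='store_true',
                        help='store the output score and qrel files into disk for the test set')

    print(parser.parse_args([]).keep_results)                  # False (default)
    print(parser.parse_args(['--keep-results']).keep_results)  # True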
@@ -114,9 +118,12 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
 else:
     raise ValueError('optimizer not recognized: it should be either adam or sgd')

-train_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, train_loader, args.batch_size, args.device)
-test_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, test_loader, args.batch_size, args.device)
-dev_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, dev_loader, args.batch_size, args.device)
+train_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, train_loader, args.batch_size,
+                                                 args.device)
+test_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, test_loader, args.batch_size,
+                                                args.device)
+dev_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, dev_loader, args.batch_size,
+                                               args.device)

 trainer_config = {
     'optimizer': optimizer,
@@ -147,4 +154,4 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
 model.load_state_dict(state_dict)
 if dev_loader:
     evaluate_dataset('dev', dataset_cls, model, embedding, dev_loader, args.batch_size, args.device)
-evaluate_dataset('test', dataset_cls, model, embedding, test_loader, args.batch_size, args.device)
+evaluate_dataset('test', dataset_cls, model, embedding, test_loader, args.batch_size, args.device, args.keep_results)
@@ -3,5 +3,5 @@

 class TRECQAEvaluatorNCE(QAEvaluator):

-    def __init__(self, dataset_cls, model, data_loader, batch_size, device):
-        super(TRECQAEvaluatorNCE, self).__init__(dataset_cls, model, data_loader, batch_size, device)
+    def __init__(self, dataset_cls, model, data_loader, batch_size, device, keep_results=False):
+        super(TRECQAEvaluatorNCE, self).__init__(dataset_cls, model, data_loader, batch_size, device, keep_results)
@@ -3,5 +3,5 @@

 class WikiQAEvaluatorNCE(QAEvaluator):

-    def __init__(self, dataset_cls, model, data_loader, batch_size, device):
-        super(WikiQAEvaluatorNCE, self).__init__(dataset_cls, model, data_loader, batch_size, device)
+    def __init__(self, dataset_cls, model, data_loader, batch_size, device, keep_results=False):
+        super(WikiQAEvaluatorNCE, self).__init__(dataset_cls, model, data_loader, batch_size, device, keep_results)
@@ -3,7 +3,7 @@
 import time


-def get_map_mrr(qids, predictions, labels, device=0):
+def get_map_mrr(qids, predictions, labels, device=0, keep_results=False):
     """
     Get the map and mrr using the trec_eval utility.
     qids, predictions, labels should have the same length.
@@ -30,7 +30,11 @@ def get_map_mrr(qids, predictions, labels, device=0):
     mean_average_precision = float(trec_out_lines[0].split('\t')[-1])
     mean_reciprocal_rank = float(trec_out_lines[1].split('\t')[-1])

-    os.remove(qrel_fname)
-    os.remove(results_fname)
+    if keep_results:
+        print("Saving prediction file to {}".format(results_fname))
+        print("Saving qrel file to {}".format(qrel_fname))
+    else:
+        os.remove(results_fname)
+        os.remove(qrel_fname)

     return mean_average_precision, mean_reciprocal_rank

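With keep_results on, the qrel and prediction (run) files that get_map_mrr would otherwise delete stay on disk, so they can be inspected by hand or re-scored later with trec_eval. The writer itself lives earlier in this function and is not shown in the hunk; the sketch below only illustrates the standard TREC qrel and run layouts that trec_eval consumes, with file names, run tag, and sample values that are illustrative rather than taken from the repo:

    # Illustrative only: standard TREC qrel and run file layouts.
    qids = ['32.1', '32.1']
    docnos = ['0', '1']            # answer/document identifiers
    labels = [1, 0]                # relevance judgments
    predictions = [0.87, 0.12]     # model scores

    with open('example.qrel', 'w') as f:
        for qid, docno, label in zip(qids, docnos, labels):
            # qrel format: qid  iteration  docno  relevance
            f.write('{} 0 {} {}\n'.format(qid, docno, label))

    with open('example.results', 'w') as f:
        for rank, (qid, docno, score) in enumerate(zip(qids, docnos, predictions)):
            # run format: qid  Q0  docno  rank  score  run_tag
            f.write('{} Q0 {} {} {} example_run\n'.format(qid, docno, rank, score))

Files in this shape can then be fed back to trec_eval (for example, trec_eval -m map -m recip_rank example.qrel example.results) to recompute MAP and MRR offline.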