diff --git a/pytext/task/tasks.py b/pytext/task/tasks.py index 307be3249..f86f5da5a 100644 --- a/pytext/task/tasks.py +++ b/pytext/task/tasks.py @@ -199,6 +199,15 @@ class Config(NewTask.Config): # for multi-label classification task, # choose MultiLabelClassificationMetricReporter + @classmethod + def format_prediction(cls, predictions, scores, context, target_names): + for prediction, score in zip(predictions, scores): + score_with_name = {n: s for n, s in zip(target_names, score.tolist())} + yield { + "prediction": target_names[prediction.data], + "score": score_with_name, + } + class DocumentRegressionTask(NewTask): class Config(NewTask.Config):