
Commit

Restored LOGITS to EXCLUDE_PRED_SET, added another option to return logits in batch_predict.
dantreiman committed May 16, 2022
1 parent 478d195 commit 279b60f
Showing 2 changed files with 20 additions and 5 deletions.
23 changes: 19 additions & 4 deletions ludwig/models/predictor.py
@@ -11,7 +11,7 @@
import torch
from tqdm import tqdm

-from ludwig.constants import COMBINED, LAST_HIDDEN
+from ludwig.constants import COMBINED, LAST_HIDDEN, LOGITS
from ludwig.data.dataset.base import Dataset
from ludwig.data.postprocessing import convert_to_dict
from ludwig.globals import (
@@ -26,7 +26,7 @@
from ludwig.utils.print_utils import repr_ordered_dict
from ludwig.utils.strings_utils import make_safe_filename

-EXCLUDE_PRED_SET = {LAST_HIDDEN}
+EXCLUDE_PRED_SET = {LOGITS, LAST_HIDDEN}
SKIP_EVAL_METRICS = {"confusion_matrix", "roc_curve"}

logger = logging.getLogger(__name__)
@@ -155,7 +155,22 @@ def _concat_preds(self, predictions):
# is a tensor that requires grad.
predictions[key] = torch.cat(pred_value_list, dim=0).clone().detach().cpu().numpy()

-def batch_evaluation(self, dataset, collect_predictions=False, collect_labels=False, dataset_name=None):
+def batch_evaluation(
+    self, dataset, collect_predictions=False, collect_logits=False, collect_labels=False, dataset_name=None
+):
+    """Batch evaluate model on dataset.
+
+    Params:
+        dataset (Union[str, dict, pandas.DataFrame]): source containing the entire dataset to be evaluated.
+        collect_predictions: Return model predictions.
+        collect_logits: Return model logits and final layer activations.
+        collect_labels: Return dataset labels in the predictions dictionary.
+
+    Returns:
+        Tuple of dictionaries of (metrics, predictions). The keys of metrics are determined by the metrics in the
+        model config. The keys of the predictions dictionary depend on which values are requested by the caller:
+        collect_predictions, collect_logits, collect_labels.
+    """
prev_model_training_mode = self.model.training # store previous model training mode
self.model.eval() # set model to eval mode

@@ -194,7 +209,7 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_labels=Fa
if collect_predictions:
for of_name, of_preds in preds.items():
for pred_name, pred_values in of_preds.items():
-if pred_name not in EXCLUDE_PRED_SET:
+if collect_logits or pred_name not in EXCLUDE_PRED_SET:
key = f"{of_name}_{pred_name}"
predictions[key].append(pred_values)
# accumulate labels from batch for each output feature
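For context, a minimal usage sketch of the new flag, assuming a model and dataset already built elsewhere with Ludwig; only batch_evaluation and its collect_logits argument come from this commit, the rest is illustrative:

    from ludwig.models.predictor import Predictor

    # model and dataset are hypothetical placeholders, prepared elsewhere with Ludwig.
    predictor = Predictor(model, batch_size=128)
    metrics, predictions = predictor.batch_evaluation(
        dataset,
        collect_predictions=True,
        collect_logits=True,  # keep LOGITS outputs even though LOGITS is in EXCLUDE_PRED_SET
        collect_labels=True,
        dataset_name="validation",
    )
    # With collect_logits=True, per-feature keys such as "<feature_name>_logits" should appear in predictions.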
2 changes: 1 addition & 1 deletion ludwig/models/trainer.py
@@ -1108,7 +1108,7 @@ def calibration(self, dataset, dataset_name: str, save_path: str):
return
predictor = Predictor(self.model, batch_size=self.eval_batch_size, horovod=self.horovod)
metrics, predictions = predictor.batch_evaluation(
-dataset, collect_predictions=True, collect_labels=True, dataset_name=dataset_name
+dataset, collect_predictions=True, collect_logits=True, collect_labels=True, dataset_name=dataset_name
)
for output_feature in self.model.output_features.values():
feature_logits_key = "%s_logits" % output_feature.feature_name
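A rough sketch of how the calibration path can then read per-feature logits out of the returned predictions dictionary, mirroring the feature_logits_key pattern shown in this hunk; the loop body past the lookup is abbreviated and illustrative:

    for output_feature in self.model.output_features.values():
        feature_logits_key = "%s_logits" % output_feature.feature_name
        if feature_logits_key in predictions:
            feature_logits = predictions[feature_logits_key]  # raw logits collected during batch_evaluation
            # ... fit this feature's calibration module using these logits and the collected labels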
