increased number of samples and add back mock.patch
geoffreyangus committed Jun 24, 2022
Parent: 69c3eb5 · Commit: eb78129
Showing 1 changed file with 24 additions and 8 deletions.
tests/integration_tests/test_postprocessing.py (24 additions, 8 deletions)
@@ -14,10 +14,13 @@
# ==============================================================================

import os
from functools import partial
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import torch

from ludwig.api import LudwigModel
from ludwig.constants import NAME, TRAINER
@@ -47,6 +50,7 @@ def test_binary_predictions(tmpdir, backend, distinct_values):
input_features,
output_features,
os.path.join(tmpdir, "dataset.csv"),
num_examples=100,
)
data_df = pd.read_csv(data_csv_path)

@@ -55,10 +59,9 @@ def test_binary_predictions(tmpdir, backend, distinct_values):
data_df[feature[NAME]] = data_df[feature[NAME]].map(lambda x: true_value if x else false_value)
data_df.to_csv(data_csv_path, index=False)

config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 2}}

preds_df = predict_with_backend(tmpdir, config, data_csv_path, backend)
config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 1}}

preds_df = predict_with_backend(tmpdir, config, data_csv_path, backend, num_predict_samples=len(data_df))
cols = set(preds_df.columns)
assert f"{feature[NAME]}_predictions" in cols
assert f"{feature[NAME]}_probabilities_{str(false_value)}" in cols
@@ -96,6 +99,7 @@ def test_binary_predictions_with_number_dtype(tmpdir, backend, distinct_values):
input_features,
output_features,
os.path.join(tmpdir, "dataset.csv"),
num_examples=100,
)
data_df = pd.read_csv(data_csv_path)

@@ -104,10 +108,9 @@ def test_binary_predictions_with_number_dtype(tmpdir, backend, distinct_values):
data_df[feature[NAME]] = data_df[feature[NAME]].map(lambda x: true_value if x else false_value)
data_df.to_csv(data_csv_path, index=False)

config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 2}}

preds_df = predict_with_backend(tmpdir, config, data_csv_path, backend)
config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 1}}

preds_df = predict_with_backend(tmpdir, config, data_csv_path, backend, num_predict_samples=len(data_df))
cols = set(preds_df.columns)
assert f"{feature[NAME]}_predictions" in cols
assert f"{feature[NAME]}_probabilities_False" in cols
@@ -128,7 +131,7 @@ def test_binary_predictions_with_number_dtype(tmpdir, backend, distinct_values):
assert np.allclose(prob_0, 1 - prob_1)


def predict_with_backend(tmpdir, config, data_csv_path, backend):
def predict_with_backend(tmpdir, config, data_csv_path, backend, num_predict_samples=None):
with init_backend(backend):
if backend == "ray":
backend = RAY_BACKEND_CONFIG
@@ -141,6 +144,19 @@ def predict_with_backend(tmpdir, config, data_csv_path, backend):
)
# Check that metadata JSON saves and loads correctly
ludwig_model = LudwigModel.load(os.path.join(output_directory, "model"))
preds_df, _ = ludwig_model.predict(dataset=data_csv_path)

# Produce an even mix of True and False predictions, as the model may be biased towards
# one direction without training
def random_logits(*args, num_predict_samples=None, **kwargs):
return torch.tensor(np.random.uniform(low=-1.0, high=1.0, size=(num_predict_samples,)))

if num_predict_samples is not None:
with mock.patch(
"ludwig.features.binary_feature.BinaryOutputFeature.logits",
partial(random_logits, num_predict_samples=num_predict_samples),
):
preds_df, _ = ludwig_model.predict(dataset=data_csv_path)
else:
preds_df, _ = ludwig_model.predict(dataset=data_csv_path)

return preds_df
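
For reference, the patching trick in the new helper relies only on standard-library behavior: mock.patch swaps the logits attribute on the class for the duration of the with block, and functools.partial bakes the sample count into the replacement, so every feature instance created while the patch is active returns randomized logits and both predicted classes show up even from a barely trained model. Below is a minimal, self-contained sketch of the same idea; it is not part of the commit, and FakeBinaryOutputFeature, random_logits, and num_samples are illustrative stand-ins rather than Ludwig's API.

    from functools import partial
    from unittest import mock

    import numpy as np
    import torch


    class FakeBinaryOutputFeature:
        """Stand-in for a binary output head; not Ludwig's class."""

        def logits(self, inputs):
            # An untrained head might lean towards a single class.
            return torch.ones(len(inputs))


    def random_logits(*args, num_samples=None, **kwargs):
        # Uniform logits in [-1, 1) give a roughly even split of positive and negative values.
        return torch.tensor(np.random.uniform(low=-1.0, high=1.0, size=(num_samples,)))


    num_samples = 8
    with mock.patch.object(
        FakeBinaryOutputFeature, "logits", partial(random_logits, num_samples=num_samples)
    ):
        feature = FakeBinaryOutputFeature()
        predictions = feature.logits(list(range(num_samples))) > 0
        print(predictions.tolist())  # a mix of True and False rather than one constant class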
