Skip to content

Commit

Permalink
parametrize tests for proba and score_one
Browse files Browse the repository at this point in the history
  • Loading branch information
Styren committed Mar 6, 2021
1 parent 47d6ac6 commit 75c6525
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions river/compose/test_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import pytest

from river import compose, linear_model, preprocessing
from river import compose, linear_model, preprocessing, anomaly


def test_pipeline_funcs():
Expand Down Expand Up @@ -53,29 +54,31 @@ def b(x):
assert str(pipeline) == "a + b"


def test_no_learn_unsupervised_predict_one():
@pytest.mark.parametrize("func", [compose.Pipeline.predict_one, compose.Pipeline.predict_proba_one])
def test_no_learn_unsupervised_one(func):
pipeline = compose.Pipeline(
("scale", preprocessing.StandardScaler()),
("lin_reg", linear_model.LinearRegression()),
("log_reg", linear_model.LogisticRegression()),
)

dataset = [(dict(a=x, b=x), x) for x in range(100)]

for x, y in dataset:
counts_pre = dict(pipeline.steps["scale"].counts)
pipeline.predict_one(x, learn_unsupervised=True)
func(pipeline, x, learn_unsupervised=True)
counts_post = dict(pipeline.steps["scale"].counts)
pipeline.predict_one(x, learn_unsupervised=False)
func(pipeline, x, learn_unsupervised=False)
counts_no_learn = dict(pipeline.steps["scale"].counts)

assert counts_pre != counts_post
assert counts_post == counts_no_learn


def test_no_learn_unsupervised_predict_many():
@pytest.mark.parametrize("func", [compose.Pipeline.predict_many, compose.Pipeline.predict_proba_many])
def test_no_learn_unsupervised_many(func):
pipeline = compose.Pipeline(
("scale", preprocessing.StandardScaler()),
("lin_reg", linear_model.LinearRegression()),
("log_reg", linear_model.LogisticRegression()),
)

dataset = [(dict(a=x, b=x), x) for x in range(100)]
Expand All @@ -84,9 +87,27 @@ def test_no_learn_unsupervised_predict_many():
X = pd.DataFrame([x for x, y in dataset][i : i + 5])

counts_pre = dict(pipeline.steps["scale"].counts)
pipeline.predict_many(X, learn_unsupervised=True)
func(pipeline, X, learn_unsupervised=True)
counts_post = dict(pipeline.steps["scale"].counts)
pipeline.predict_many(X, learn_unsupervised=False)
func(pipeline, X, learn_unsupervised=False)
counts_no_learn = dict(pipeline.steps["scale"].counts)

assert counts_pre != counts_post
assert counts_post == counts_no_learn

def test_no_learn_unsupervised_score_one():
pipeline = compose.Pipeline(
("scale", preprocessing.StandardScaler()),
("anomaly", anomaly.HalfSpaceTrees()),
)

dataset = [(dict(a=x, b=x), x) for x in range(100)]

for x, y in dataset:
counts_pre = dict(pipeline.steps["scale"].counts)
pipeline.score_one(x, learn_unsupervised=True)
counts_post = dict(pipeline.steps["scale"].counts)
pipeline.score_one(x, learn_unsupervised=False)
counts_no_learn = dict(pipeline.steps["scale"].counts)

assert counts_pre != counts_post
Expand Down

0 comments on commit 75c6525

Please sign in to comment.