Skip to content

Commit

Permalink
make tests work with current source code state
Browse files Browse the repository at this point in the history
revert this commit when merging into/with python-control#431

(remove statesp_test.py::test_copy_constructor_nodt if not applicable)
  • Loading branch information
bnavigator committed Dec 29, 2020
1 parent 2082796 commit abb9940
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 166 deletions.
75 changes: 22 additions & 53 deletions control/tests/discrete_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import numpy as np
import pytest

from control import (StateSpace, TransferFunction, bode, common_timebase,
evalfr, feedback, forced_response, impulse_response,
isctime, isdtime, rss, sample_system, step_response,
timebase)
from control import StateSpace, TransferFunction, feedback, step_response, \
isdtime, timebase, isctime, sample_system, bode, impulse_response, \
evalfr, timebaseEqual, forced_response, rss


class TestDiscrete:
Expand Down Expand Up @@ -52,21 +51,13 @@ class Tsys:

return T

def testCompatibleTimebases(self, tsys):
"""test that compatible timebases don't throw errors and vice versa"""
common_timebase(tsys.siso_ss1.dt, tsys.siso_tf1.dt)
common_timebase(tsys.siso_ss1.dt, tsys.siso_ss1c.dt)
common_timebase(tsys.siso_ss1d.dt, tsys.siso_ss1.dt)
common_timebase(tsys.siso_ss1.dt, tsys.siso_ss1d.dt)
common_timebase(tsys.siso_ss1.dt, tsys.siso_ss1d.dt)
common_timebase(tsys.siso_ss1d.dt, tsys.siso_ss3d.dt)
common_timebase(tsys.siso_ss3d.dt, tsys.siso_ss1d.dt)
with pytest.raises(ValueError):
# cont + discrete
common_timebase(tsys.siso_ss1d.dt, tsys.siso_ss1c.dt)
with pytest.raises(ValueError):
# incompatible discrete
common_timebase(tsys.siso_ss1d.dt, tsys.siso_ss2d.dt)
def testTimebaseEqual(self, tsys):
"""Test for equal timebases and not so equal ones"""
assert timebaseEqual(tsys.siso_ss1, tsys.siso_tf1)
assert timebaseEqual(tsys.siso_ss1, tsys.siso_ss1c)
assert not timebaseEqual(tsys.siso_ss1d, tsys.siso_ss1c)
assert not timebaseEqual(tsys.siso_ss1d, tsys.siso_ss2d)
assert not timebaseEqual(tsys.siso_ss1d, tsys.siso_ss3d)

def testSystemInitialization(self, tsys):
# Check to make sure systems are discrete time with proper variables
Expand All @@ -84,18 +75,6 @@ def testSystemInitialization(self, tsys):
assert tsys.siso_tf2d.dt == 0.2
assert tsys.siso_tf3d.dt is True

# keyword argument check
# dynamic systems
assert TransferFunction(1, [1, 1], dt=0.1).dt == 0.1
assert TransferFunction(1, [1, 1], 0.1).dt == 0.1
assert StateSpace(1,1,1,1, dt=0.1).dt == 0.1
assert StateSpace(1,1,1,1, 0.1).dt == 0.1
# static gain system, dt argument should still override default dt
assert TransferFunction(1, [1,], dt=0.1).dt == 0.1
assert TransferFunction(1, [1,], 0.1).dt == 0.1
assert StateSpace(0,0,1,1, dt=0.1).dt == 0.1
assert StateSpace(0,0,1,1, 0.1).dt == 0.1

def testCopyConstructor(self, tsys):
for sys in (tsys.siso_ss1, tsys.siso_ss1c, tsys.siso_ss1d):
newsys = StateSpace(sys)
Expand Down Expand Up @@ -135,7 +114,6 @@ def test_timebase_conversions(self, tsys):
assert timebase(tf1*tf2) == timebase(tf2)
assert timebase(tf1*tf3) == timebase(tf3)
assert timebase(tf1*tf4) == timebase(tf4)
assert timebase(tf3*tf4) == timebase(tf4)
assert timebase(tf2*tf1) == timebase(tf2)
assert timebase(tf3*tf1) == timebase(tf3)
assert timebase(tf4*tf1) == timebase(tf4)
Expand All @@ -150,36 +128,33 @@ def test_timebase_conversions(self, tsys):

# Make sure discrete time without sampling is converted correctly
assert timebase(tf3*tf3) == timebase(tf3)
assert timebase(tf3*tf4) == timebase(tf4)
assert timebase(tf3+tf3) == timebase(tf3)
assert timebase(tf3+tf4) == timebase(tf4)
assert timebase(feedback(tf3, tf3)) == timebase(tf3)
assert timebase(feedback(tf3, tf4)) == timebase(tf4)

# Make sure all other combinations are errors
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf2 * tf3
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf3 * tf2
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf2 * tf4
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf4 * tf2
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf2 + tf3
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf3 + tf2
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf2 + tf4
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
tf4 + tf2
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
feedback(tf2, tf3)
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
feedback(tf3, tf2)
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
feedback(tf2, tf4)
with pytest.raises(ValueError, match="incompatible timebases"):
with pytest.raises(ValueError, match="different sampling times"):
feedback(tf4, tf2)

def testisdtime(self, tsys):
Expand Down Expand Up @@ -237,7 +212,6 @@ def testAddition(self, tsys):
sys = tsys.siso_ss1c + tsys.siso_ss1c
sys = tsys.siso_ss1d + tsys.siso_ss1d
sys = tsys.siso_ss3d + tsys.siso_ss3d
sys = tsys.siso_ss1d + tsys.siso_ss3d

with pytest.raises(ValueError):
StateSpace.__add__(tsys.mimo_ss1c, tsys.mimo_ss1d)
Expand All @@ -252,7 +226,6 @@ def testAddition(self, tsys):
sys = tsys.siso_tf1c + tsys.siso_tf1c
sys = tsys.siso_tf1d + tsys.siso_tf1d
sys = tsys.siso_tf2d + tsys.siso_tf2d
sys = tsys.siso_tf1d + tsys.siso_tf3d

with pytest.raises(ValueError):
TransferFunction.__add__(tsys.siso_tf1c, tsys.siso_tf1d)
Expand All @@ -275,7 +248,6 @@ def testMultiplication(self, tsys):
sys = tsys.siso_ss1d * tsys.siso_ss1
sys = tsys.siso_ss1c * tsys.siso_ss1c
sys = tsys.siso_ss1d * tsys.siso_ss1d
sys = tsys.siso_ss1d * tsys.siso_ss3d

with pytest.raises(ValueError):
StateSpace.__mul__(tsys.mimo_ss1c, tsys.mimo_ss1d)
Expand All @@ -289,7 +261,6 @@ def testMultiplication(self, tsys):
sys = tsys.siso_tf1d * tsys.siso_tf1
sys = tsys.siso_tf1c * tsys.siso_tf1c
sys = tsys.siso_tf1d * tsys.siso_tf1d
sys = tsys.siso_tf1d * tsys.siso_tf3d

with pytest.raises(ValueError):
TransferFunction.__mul__(tsys.siso_tf1c, tsys.siso_tf1d)
Expand All @@ -314,7 +285,6 @@ def testFeedback(self, tsys):
sys = feedback(tsys.siso_ss1d, tsys.siso_ss1)
sys = feedback(tsys.siso_ss1c, tsys.siso_ss1c)
sys = feedback(tsys.siso_ss1d, tsys.siso_ss1d)
sys = feedback(tsys.siso_ss1d, tsys.siso_ss3d)

with pytest.raises(ValueError):
feedback(tsys.mimo_ss1c, tsys.mimo_ss1d)
Expand All @@ -328,7 +298,6 @@ def testFeedback(self, tsys):
sys = feedback(tsys.siso_tf1d, tsys.siso_tf1)
sys = feedback(tsys.siso_tf1c, tsys.siso_tf1c)
sys = feedback(tsys.siso_tf1d, tsys.siso_tf1d)
sys = feedback(tsys.siso_tf1d, tsys.siso_tf3d)

with pytest.raises(ValueError):
feedback(tsys.siso_tf1c, tsys.siso_tf1d)
Expand Down
83 changes: 1 addition & 82 deletions control/tests/lti_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from control import c2d, tf, tf2ss, NonlinearIOSystem
from control.lti import (LTI, common_timebase, damp, dcgain, isctime, isdtime,
from control.lti import (LTI, damp, dcgain, isctime, isdtime,
issiso, pole, timebaseEqual, zero)
from control.tests.conftest import slycotonly

Expand Down Expand Up @@ -72,84 +72,3 @@ def test_dcgain(self):
sys = tf(84, [1, 2])
np.testing.assert_equal(sys.dcgain(), 42)
np.testing.assert_equal(dcgain(sys), 42)

@pytest.mark.parametrize("dt1, dt2, expected",
[(None, None, True),
(None, 0, True),
(None, 1, True),
pytest.param(None, True, True,
marks=pytest.mark.xfail(
reason="returns false")),
(0, 0, True),
(0, 1, False),
(0, True, False),
(1, 1, True),
(1, 2, False),
(1, True, False),
(True, True, True)])
def test_timebaseEqual_deprecated(self, dt1, dt2, expected):
"""Test that timbaseEqual throws a warning and returns as documented"""
sys1 = tf([1], [1, 2, 3], dt1)
sys2 = tf([1], [1, 4, 5], dt2)

print(sys1.dt)
print(sys2.dt)

with pytest.deprecated_call():
assert timebaseEqual(sys1, sys2) is expected
# Make sure behaviour is symmetric
with pytest.deprecated_call():
assert timebaseEqual(sys2, sys1) is expected

@pytest.mark.parametrize("dt1, dt2, expected",
[(None, None, None),
(None, 0, 0),
(None, 1, 1),
(None, True, True),
(True, True, True),
(True, 1, 1),
(1, 1, 1),
(0, 0, 0),
])
@pytest.mark.parametrize("sys1", [True, False])
@pytest.mark.parametrize("sys2", [True, False])
def test_common_timebase(self, dt1, dt2, expected, sys1, sys2):
"""Test that common_timbase adheres to :ref:`conventions-ref`"""
i1 = tf([1], [1, 2, 3], dt1) if sys1 else dt1
i2 = tf([1], [1, 4, 5], dt2) if sys2 else dt2
assert common_timebase(i1, i2) == expected
# Make sure behaviour is symmetric
assert common_timebase(i2, i1) == expected

@pytest.mark.parametrize("i1, i2",
[(True, 0),
(0, 1),
(1, 2)])
def test_common_timebase_errors(self, i1, i2):
"""Test that common_timbase throws errors on invalid combinations"""
with pytest.raises(ValueError):
common_timebase(i1, i2)
# Make sure behaviour is symmetric
with pytest.raises(ValueError):
common_timebase(i2, i1)

@pytest.mark.parametrize("dt, ref, strictref",
[(None, True, False),
(0, False, False),
(1, True, True),
(True, True, True)])
@pytest.mark.parametrize("objfun, arg",
[(LTI, ()),
(NonlinearIOSystem, (lambda x: x, ))])
def test_isdtime(self, objfun, arg, dt, ref, strictref):
"""Test isdtime and isctime functions to follow convention"""
obj = objfun(*arg, dt=dt)

assert isdtime(obj) == ref
assert isdtime(obj, strict=True) == strictref

if dt is not None:
ref = not ref
strictref = not strictref
assert isctime(obj) == ref
assert isctime(obj, strict=True) == strictref
34 changes: 3 additions & 31 deletions control/tests/statesp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,13 @@ def sys623(self):

@pytest.mark.parametrize(
"dt",
[(), (None, ), (0, ), (1, ), (0.1, ), (True, )],
[(None, ), (0, ), (1, ), (0.1, ), (True, )],
ids=lambda i: "dt " + ("unspec" if len(i) == 0 else str(i[0])))
@pytest.mark.parametrize(
"argfun",
[pytest.param(
lambda ABCDdt: (ABCDdt, {}),
id="A, B, C, D[, dt]"),
pytest.param(
lambda ABCDdt: (ABCDdt[:4], {'dt': dt_ for dt_ in ABCDdt[4:]}),
id="A, B, C, D[, dt=dt]"),
pytest.param(
lambda ABCDdt: ((StateSpace(*ABCDdt), ), {}),
id="sys")
Expand All @@ -109,7 +106,7 @@ def test_constructor(self, sys322ABCD, dt, argfun):
@pytest.mark.parametrize("args, exc, errmsg",
[((True, ), TypeError,
"(can only take in|sys must be) a StateSpace"),
((1, 2), ValueError, "1, 4, or 5 arguments"),
((1, 2), ValueError, "1 or 4 arguments"),
((np.ones((3, 2)), np.ones((3, 2)),
np.ones((2, 2)), np.ones((2, 2))),
ValueError, "A must be square"),
Expand All @@ -133,16 +130,6 @@ def test_constructor_invalid(self, args, exc, errmsg):
with pytest.raises(exc, match=errmsg):
ss(*args)

def test_constructor_warns(self, sys322ABCD):
"""Test ambiguos input to StateSpace() constructor"""
with pytest.warns(UserWarning, match="received multiple dt"):
sys = StateSpace(*(sys322ABCD + (0.1, )), dt=0.2)
np.testing.assert_almost_equal(sys.A, sys322ABCD[0])
np.testing.assert_almost_equal(sys.B, sys322ABCD[1])
np.testing.assert_almost_equal(sys.C, sys322ABCD[2])
np.testing.assert_almost_equal(sys.D, sys322ABCD[3])
assert sys.dt == 0.1

def test_copy_constructor(self):
"""Test the copy constructor"""
# Create a set of matrices for a simple linear system
Expand All @@ -164,22 +151,6 @@ def test_copy_constructor(self):
linsys.A[0, 0] = -3
np.testing.assert_array_equal(cpysys.A, [[-1]]) # original value

def test_copy_constructor_nodt(self, sys322):
"""Test the copy constructor when an object without dt is passed
FIXME: may be obsolete in case gh-431 is updated
"""
sysin = sample_system(sys322, 1.)
del sysin.dt
sys = StateSpace(sysin)
assert sys.dt == defaults['control.default_dt']

# test for static gain
sysin = StateSpace([], [], [], [[1, 2], [3, 4]], 1.)
del sysin.dt
sys = StateSpace(sysin)
assert sys.dt is None

def test_matlab_style_constructor(self):
"""Use (deprecated) matrix-style construction string"""
with pytest.deprecated_call():
Expand Down Expand Up @@ -382,6 +353,7 @@ def test_freq_resp(self):
np.testing.assert_almost_equal(phase, true_phase)
np.testing.assert_equal(omega, true_omega)

@pytest.mark.skip("is_static_gain is introduced in gh-431")
def test_is_static_gain(self):
A0 = np.zeros((2,2))
A1 = A0.copy()
Expand Down
1 change: 1 addition & 0 deletions control/tests/xferfcn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def test_evalfr_siso(self, dt, omega, resp):
resp,
atol=1e-3)

@pytest.mark.skip("is_static_gain is introduced in gh-431")
def test_is_static_gain(self):
numstatic = 1.1
denstatic = 1.2
Expand Down

0 comments on commit abb9940

Please sign in to comment.