Skip to content

Commit

Permalink
Change StateVector and CovarianceMatrix to simpler ndarray wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
sdhiscocks committed Jun 20, 2018
1 parent 90ad6cc commit a477dd7
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions stonesoup/types/state.py
Expand Up @@ -7,38 +7,44 @@
from .base import Type


class StateVector(Type, np.ndarray):
class StateVector(np.ndarray):
"""State vector wrapper for :class:`numpy.ndarray`
This class returns a view to a :class:`numpy.ndarray`, but ensures that
its initialised at a *Nx1* vector.
its initialised at a *Nx1* vector. It's called same as to
:func:`numpy.asarray`.
"""
value = Property(np.ndarray, doc='Array')

def __new__(cls, value, *args, **kwargs):
array = np.array(value)
def __new__(cls, *args, **kwargs):
array = np.asarray(*args, **kwargs)
if not (array.ndim == 2 and array.shape[1] == 1):
raise ValueError(
"state vector shape should be Nx1 dimensions: got {}".format(
array.shape))
return array.view(cls)

def __array_wrap__(self, array):
return np.asarray(array)

class CovarianceMatrix(Type, np.ndarray):

class CovarianceMatrix(np.ndarray):
"""Covariance matrix wrapper for :class:`numpy.ndarray`.
This class returns a view to a :class:`numpy.ndarray`, but ensures that
its initialised at a *NxN* matrix.
its initialised at a *NxN* matrix. It's called similar to
:func:`numpy.asarray`.
"""
value = Property(np.ndarray, doc='Array')

def __new__(cls, value, *args, **kwargs):
array = np.array(value)
def __new__(cls, *args, **kwargs):
array = np.asarray(*args, **kwargs)
if not array.ndim == 2:
raise ValueError("Covariance should have ndim of 2: got {}"
"".format(array.ndim))
return array.view(cls)

def __array_wrap__(self, array):
return np.asarray(array)


class State(Type):
"""State type.
Expand All @@ -49,7 +55,7 @@ class State(Type):
state_vector = Property(StateVector, doc='State vector.')

def __init__(self, state_vector, *args, **kwargs):
state_vector.view(StateVector)
state_vector = state_vector.view(StateVector)
super().__init__(state_vector, *args, **kwargs)

@property
Expand All @@ -68,7 +74,7 @@ class GaussianState(State):
covar = Property(CovarianceMatrix, doc='Covariance matrix of state.')

def __init__(self, state_vector, covar, *args, **kwargs):
covar.view(CovarianceMatrix)
covar = covar.view(CovarianceMatrix)
super().__init__(state_vector, covar, *args, **kwargs)
if self.state_vector.shape[0] != self.covar.shape[0]:
raise ValueError(
Expand Down

0 comments on commit a477dd7

Please sign in to comment.