Skip to content

Commit

Permalink
slightly better handling of model args
Browse files Browse the repository at this point in the history
  • Loading branch information
Zsailer committed Jun 4, 2018
1 parent 43616a9 commit 1598186
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
10 changes: 9 additions & 1 deletion epistasis/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@
A library of models to decompose high-order epistasis in genotype-phenotype
maps.
"""
# Import linear models
from .linear import (EpistasisLinearRegression,
EpistasisLasso,
EpistasisRidge,
EpistasisElasticNet)

# Import nonlinear models
from .nonlinear import (EpistasisNonlinearRegression,
EpistasisPowerTransform,
EpistasisSpline)

# Import classifiers
from .classifiers import (EpistasisLogisticRegression,
EpistasisGaussianMixture)
EpistasisGaussianMixture,
EpistasisGaussianProcess)

# Import Pipeline object fro stitching models.
from .pipeline import EpistasisPipeline
30 changes: 22 additions & 8 deletions epistasis/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,13 @@ def gpm(self):

def _X(self, data=None, method=None):
"""Handle the X argument in this model."""
# Get object type.
obj = data.__class__

X = data
# If X is None, see if we saved an array.
if X is None:

# Get X from genotypes
X = genotypes_to_X(
self.gpm.wildtype,
Expand All @@ -447,7 +451,8 @@ def _X(self, data=None, method=None):
model_type=self.model_type
)

elif isinstance(X, str) and X in self.gpm.genotypes:
elif obj is str and X in self.gpm.genotypes:

# Get X from genotypes
X = genotypes_to_X(
self.gpm.wildtype,
Expand All @@ -458,15 +463,15 @@ def _X(self, data=None, method=None):
)

# If X is a keyword in Xbuilt, use it.
elif isinstance(X, str) and X in self.Xbuilt:
elif obj is str and X in self.Xbuilt:
X = self.Xbuilt[X]

# If 2-d array, keep as so.
elif isinstance(X, np.ndarray) and X.ndim == 2:
elif obj is np.ndarray and X.ndim == 2:
pass

# If list of genotypes.
elif isinstance(X, list) or isinstance(X, np.ndarray):
elif obj in [list, np.ndarray, pd.DataFrame, pd.Series]:
# Get X from genotypes
X = genotypes_to_X(
self.gpm.wildtype,
Expand All @@ -484,44 +489,53 @@ def _X(self, data=None, method=None):

def _y(self, data=None, method=None):
"""Handle y arguments in this model."""
# Get object type.
obj = data.__class__
y = data

if y is None:
return self.gpm.phenotypes

elif isinstance(y, np.ndarray) or isinstance(y, list):
elif obj in [list, np.ndarray, pd.Series, pd.DataFrame]:
return y

else:
raise Exception("y is invalid.")

def _yerr(self, data=None, method=None):
"""Handle yerr argument in this model."""
# Get object type.
obj = data.__class__
yerr = data
if yerr is None:
return self.gpm.std.upper

elif isinstance(yerr, np.ndarray) or isinstance(yerr, list):
elif obj in [list, np.ndarray, pd.Series, pd.DataFrame]:
return yerr
else:
raise Exception("yerr is invalid.")

def _thetas(self, data=None, method=None):
"""Handle yerr argument in this model."""
# Get object type.
obj = data.__class__
thetas = data
if thetas is None:
return self.thetas

elif isinstance(thetas, np.ndarray) or isinstance(thetas, list):
elif obj in [list, np.ndarray, pd.Series, pd.DataFrame]:
return thetas
else:
raise Exception("thetas is invalid.")

def _lnprior(self, data=None, method=None):
# Get object type.
obj = data.__class__
_lnprior = data
if _lnprior is None:
return np.zeros(self.gpm.n)

elif isinstance(_lnprior, np.ndarray) or isinstance(_lnprior, list):
elif obj in [list, np.ndarray, pd.Series, pd.DataFrame]:
return _lnprior
else:
raise Exceptison("_prior is invalid.")
Expand Down
1 change: 1 addition & 0 deletions epistasis/models/classifiers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .logistic import EpistasisLogisticRegression
from .gmm import EpistasisGaussianMixture
from .gaussian_process import EpistasisGaussianProcess

0 comments on commit 1598186

Please sign in to comment.