Skip to content

Commit

Permalink
update the type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Mar 27, 2024
1 parent c5efb35 commit 73b4a4b
Show file tree
Hide file tree
Showing 14 changed files with 672 additions and 646 deletions.
74 changes: 34 additions & 40 deletions src/regmod/composite_models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base Model
"""

import logging
from copy import deepcopy
from typing import Dict, List, Optional
Expand All @@ -19,15 +20,9 @@
logger = logging.getLogger(__name__)

# Per-model-type lookup used to validate ``mtype`` in ``BaseModel``. For each
# supported model type, take the smooth function named by the model's default
# ``inv_link`` spec and keep its ``inv_fun``.
link_funs = {
    "gaussian": fun_dict[GaussianModel.default_param_specs["mu"]["inv_link"]].inv_fun,
    "poisson": fun_dict[PoissonModel.default_param_specs["lam"]["inv_link"]].inv_fun,
    "binomial": fun_dict[BinomialModel.default_param_specs["p"]["inv_link"]].inv_fun,
}

model_constructors = {
Expand Down Expand Up @@ -86,19 +81,23 @@ class BaseModel(NodeModel):
Overwrite the append function in NodeModel.
"""

def __init__(self,
name: str,
y: str,
variables: List[Variable],
df: Optional[pd.DataFrame] = None,
weights: str = "weights",
mtype: str = "gaussian",
prior_mask: Optional[Dict] = None,
**param_specs):
def __init__(
self,
name: str,
y: str,
variables: List[Variable],
df: Optional[DataFrame] = None,
weights: str = "weights",
mtype: str = "gaussian",
prior_mask: Optional[Dict] = None,
**param_specs,
):

super().__init__(name)
if any(mtype not in model_config
for model_config in (link_funs, model_constructors)):
if any(
mtype not in model_config
for model_config in (link_funs, model_constructors)
):
raise ValueError(f"Not supported model type {mtype}")
data = deepcopy(data)
variables = list(deepcopy(variables))
Expand All @@ -108,9 +107,7 @@ def __init__(self,
self.df = df
self.weights = weights
self.variables = {v.name: v for v in variables}
self.param_specs = {"variables": variables,
"use_offset": True,
**param_specs}
self.param_specs = {"variables": variables, "use_offset": True, **param_specs}
self.model = None
self.prior_mask = {} if prior_mask is None else prior_mask

Expand Down Expand Up @@ -149,10 +146,7 @@ def fit(self, **fit_options):
if self.model is None:
model_constructor = model_constructors[self.mtype]
self.model = model_constructor(
self.y,
df=self.df,
weights=self.weights,
param_specs=self.param_specs
self.y, df=self.df, weights=self.weights, param_specs=self.param_specs
)
self.model.fit(**fit_options)
message = f"fit_node;finish;{self.level};{self.name};"
Expand All @@ -179,17 +173,19 @@ def get_draws(self, df: DataFrame = None, size: int = 1000) -> DataFrame:
pred_data = self.model.df.copy()
pred_data.attach_df(df)

coefs_draws = np.random.multivariate_normal(self.model.opt_coefs,
self.model.opt_vcov,
size=size)
draws = np.vstack([
self.model.params[0].get_param(coefs_draw, pred_data)
for coefs_draw in coefs_draws
])
df_draws = pd.DataFrame(
coefs_draws = np.random.multivariate_normal(
self.model.opt_coefs, self.model.opt_vcov, size=size
)
draws = np.vstack(
[
self.model.params[0].get_param(coefs_draw, pred_data)
for coefs_draw in coefs_draws
]
)
df_draws = DataFrame(
draws.T,
columns=[f"{self.col_value}_{i}" for i in range(size)],
index=df.index
index=df.index,
)

return pd.concat([df, df_draws], axis=1)
Expand All @@ -216,8 +212,7 @@ def get_posterior(self) -> Dict:
# use minimum standard deviation of the posterior distribution
sd = np.maximum(0.1, np.sqrt(np.diag(self.model.opt_vcov)))
vnames = [v.name for v in self.param_specs["variables"]]
slices = sizes_to_slices([self.variables[name].size
for name in vnames])
slices = sizes_to_slices([self.variables[name].size for name in vnames])
return {
name: GaussianPrior(mean=mean[slices[i]], sd=sd[slices[i]])
for i, name in enumerate(vnames)
Expand All @@ -240,8 +235,7 @@ def append(self, node: NodeModel, rank: int = 0):
primary children.
"""
if rank >= 1:
raise ValueError(f"{type(self).__name__} can only have primary "
"link.")
raise ValueError(f"{type(self).__name__} can only have primary " "link.")
super().append(node, rank=rank)

def __repr__(self) -> str:
Expand Down
82 changes: 41 additions & 41 deletions src/regmod/function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
Function module
"""

from dataclasses import dataclass, field
from typing import Callable
from regmod._typing import Callable

import numpy as np

Expand Down Expand Up @@ -60,8 +61,8 @@ def exp_d2fun(x):

def expit_fun(x):
neg_indices = x < 0
z = np.exp(-np.sqrt(x*x))
y = 1/(1 + z)
z = np.exp(-np.sqrt(x * x))
y = 1 / (1 + z)
if np.isscalar(x):
if neg_indices:
y = 1 - y
Expand All @@ -71,15 +72,15 @@ def expit_fun(x):


def expit_dfun(x):
    """Return the first derivative of the expit (logistic) function at ``x``.

    The derivative is an even function of ``x``, so it is evaluated at
    ``-|x|``; the exponential then never exceeds one and cannot overflow.

    Parameters
    ----------
    x
        Scalar or NumPy array of evaluation points.

    Returns
    -------
    ``z / (1 + z)**2`` with ``z = exp(-|x|)``, same shape as ``x``.
    """
    # np.abs replaces sqrt(x * x), whose intermediate square overflows to
    # inf (with a runtime warning) for very large |x|.
    z = np.exp(-np.abs(x))
    return z / (1 + z) ** 2


def expit_d2fun(x):
neg_indices = x < 0
z = np.exp(-np.sqrt(x*x))
y = z*(z - 1)/(z + 1)**3
z = np.exp(-np.sqrt(x * x))
y = z * (z - 1) / (z + 1) ** 3
if np.isscalar(x):
if neg_indices:
y = -y
Expand All @@ -93,55 +94,54 @@ def log_fun(x):


def log_dfun(x):
    """Return the first derivative of the natural logarithm, ``1 / x``.

    Works elementwise on NumPy arrays; no check is made that ``x`` is nonzero.
    """
    return 1 / x


def log_d2fun(x):
    """Return the second derivative of the natural logarithm, ``-1 / x**2``.

    Works elementwise on NumPy arrays; no check is made that ``x`` is nonzero.
    """
    return -1 / x**2


def logit_fun(x):
    """Return the logit (log-odds) of ``x``, ``log(x / (1 - x))``.

    Works elementwise on NumPy arrays; the input is not checked or clipped.
    """
    return np.log(x / (1 - x))


def logit_dfun(x):
    """Return the first derivative of the logit, ``1 / (x * (1 - x))``.

    Works elementwise on NumPy arrays; the input is not checked or clipped.
    """
    return 1 / (x * (1 - x))


def logit_d2fun(x):
    """Return the second derivative of the logit.

    Computed as ``(2 * x - 1) / (x * (1 - x))**2``. Works elementwise on
    NumPy arrays; the input is not checked or clipped.
    """
    return (2 * x - 1) / (x * (1 - x)) ** 2


# Registry of the available smooth functions. Each entry bundles a function
# with its inverse and its first and second derivatives under a unique name.
fun_list = [
    SmoothFunction(
        name="identity",
        fun=identity_fun,
        inv_fun=identity_fun,
        dfun=identity_dfun,
        d2fun=identity_d2fun,
    ),
    SmoothFunction(
        name="exp", fun=exp_fun, inv_fun=log_fun, dfun=exp_dfun, d2fun=exp_d2fun
    ),
    SmoothFunction(
        name="expit",
        fun=expit_fun,
        inv_fun=logit_fun,
        dfun=expit_dfun,
        d2fun=expit_d2fun,
    ),
    SmoothFunction(
        name="log", fun=log_fun, inv_fun=exp_fun, dfun=log_dfun, d2fun=log_d2fun
    ),
    SmoothFunction(
        name="logit",
        fun=logit_fun,
        inv_fun=expit_fun,
        dfun=logit_dfun,
        d2fun=logit_d2fun,
    ),
]


# Name-indexed view of ``fun_list`` so a smooth function can be looked up by
# its ``name`` attribute.
fun_dict = {fun.name: fun for fun in fun_list}
Loading

0 comments on commit 73b4a4b

Please sign in to comment.