In [2]:
from typing import List, NamedTuple, Protocol, Callable
from jaxtyping import Float, Array, Scalar


In [14]:
### --- Parameter data types ---
class ModelParams(object):
    def __init__(self, param_dict: dict):
        self._tree = param_dict

    def by_type(self):
        static = {k: self._tree[k] for k in self._static}
        return (static,)
    
    @classmethod
    def from_types(cls, static,):
        return cls({**static})

class PoseModelParams():
    n_feats: int = property(lambda self: self._tree['n_feats'])

    _static = ['n_feats']

class GMMParams(PoseModelParams):
    n_components: int = property(lambda self: self._tree['n_components'])

    _static = PoseModelParams._static + ['n_components']

class JointModelParams(object):
    def __init__(self, pose, morph):
        self._pose = pose
        self._morph = morph

    def by_type(self):
        pose_dicts = self._pose.by_type()
        morph_dicts = self._morph.by_type()
        return tuple({'pose': pose_dicts[i],
                      'morph': morph_dicts[i]})

    def from_types(model: JointModel, static: dict,):
        return JointModelParams(
            model.pose.ParamClass.from_types(static['pose']),
            model.morph.ParamClass.from_types(static['morph']))

SyntaxError: invalid syntax (18468965.py, line 35)

In [8]:
params = GMMParams({'n_components': 3, 'n_feats': 2})

In [13]:
type(params), params.__class__

(__main__.GMMParams, __main__.GMMParams)

In [None]:
# --- Model types ---
class PoseModel(Protocol):
    ParamClass: type
    load_params: Callable
    pose_logprob: Callable
    aux_distribution: Callable
    construct_params: Callable
    init_hyperparams: Callable
    init_params: Callable
    log_prior: Callable
    reports: Callable
    sample_hyperparams: Callable
    sample_params: Callable
    sample_poses: Callable
    
    

### Scan / model fitting code

In [1]:
setup_scan = None
run_scan = None
update_config = None
kpsn_project = None
setup_model = None

project_dir = None

In [None]:
setup_scan(
    project_dir / "scans" / "scan0",
    hyperparam_values = {
        "morph.variance_prior": [1e-2, 1e-1, 1e0, 1e1, 1e2],
    })
update_config(
    project_dir / "scans" / "scan0" / "base_model_config.yml",
    {"hyperparams.pose.n_components": 3,}
)
run_scan(project_dir / "scans" / "scan0")

In [None]:
paths = kpsn_project(project_dir)

setup_scan(
    paths.scan("scan0"),
    {"hyperparams.morph.variance_prior": [1e-2, 1e-1, 1e0, 1e1, 1e2],})
update_config(
    paths.scan("scan0") / "base_model_config.yml",
    {"hyperparams.pose.n_components": 3,}
)
run_scan(
    paths.scan("scan0"),
    paths.models)

In [None]:
# defaults contain: model.pose: gmm, model.morph: affine
setup_model(
    paths.model("model0"),
    config = {"hyperparams.pose.n_components": 3}
)