In [1]:
import pprint

import numpy as np
import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.model_eval import (
    MultiEvaluator,
    OlsPValues,
    R2Score,
    ReconstructionError,
    VertexReconstructionError,
)
from polpo.models import MultiTransform, SupervisedXEmbeddingRegressor
from polpo.preprocessing import (
    IndexMap,
    Map,
    NestingSwapper,
    PartiallyInitializedStep,
)
from polpo.preprocessing.load.pregnancy import (
    NeuroMaternalMeshLoader,
    NeuroMaternalTabularDataLoader,
)
from polpo.preprocessing.mesh.conversion import PvFromData, ToVertices
from polpo.preprocessing.mesh.io import FreeSurferReader
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.mri import segmtool2encoding
from polpo.sklearn.adapter import EvaluatedModel
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression



In [2]:
# Master switch: when True, the slow mesh-alignment step is skipped.
DEBUG = False

# Stricter switch that additionally restricts the run to two structures.
# It can only be on while DEBUG is on; flip the trailing literal to enable it.
FULL_DEBUG = DEBUG and False

EDS total scores (`EDS.Total`, session 1), used below as the regression target.

In [3]:
# Tabular records (later filtered on their `ses` column). The flags presumably
# keep the mothers and drop the control group -- confirm against the loader.
pipe = NeuroMaternalTabularDataLoader(keep_control=False, keep_mothers=True)
tab_data = pipe()

In [4]:
# Individual steps, named so the composed pipeline reads as a sentence.
_keep_first_session = ppd.DfFilter(lambda table: table["ses"] == 1)
_index_by_participant = ppd.IndexSetter("participant_id", drop=True)
_select_eds_total = ppd.ColumnsSelector("EDS.Total")
_series_to_dict = ppd.SeriesToDict()

# Rows with ses == 1 -> indexed by participant -> "EDS.Total" column -> dict.
eds_pipe = (
    _keep_first_session
    + _index_by_participant
    + _select_eds_total
    + _series_to_dict
)
eds_dict = eds_pipe(tab_data)

# Trailing semicolon: evaluate without echoing the dict in the output.
eds_dict;

Meshes and vector fields.

In [5]:
# Structures to analyse: only the two hippocampi in full-debug mode, otherwise
# every structure of the "fsl" encoding except the brain stem.
if FULL_DEBUG:
    structs = ["L_Hipp", "R_Hipp"]
else:
    encoding = segmtool2encoding("fsl")
    # Filter into a new list rather than calling `.remove` on
    # `encoding.structs`: removing in place alters the encoding's own list
    # (if it is handed out by reference), which makes the cell non-idempotent
    # and can raise ValueError when it is run again on the same object.
    structs = [name for name in encoding.structs if name != "BrStem"]


n_structs = len(structs)

# Single switch for the container type the mesh loader returns (dict vs.
# list); it drives the loaders and the matching mapper below.
as_dict = False
file_finder = NeuroMaternalMeshLoader(as_dict=as_dict)

Map_ = ppdict.DictMap if as_dict else Map


# Outer DictMap walks the loader's dict; the inner mapper walks each value's
# meshes, reading the file and converting it with PvFromData.
mesh_reader = ppdict.DictMap(Map_(FreeSurferReader() + PvFromData()))

prep_pipe = PartiallyInitializedStep(
    Step=lambda **kwargs: ppdict.DictMap(Map_(PvAlign(**kwargs))),
    # NB: aligns against first subject, t_0
    _target=lambda meshes: meshes[next(iter(meshes))][0],
    max_iterations=500,
)


if DEBUG:
    # because alignment is slow
    per_struct_pipe = mesh_reader
else:
    per_struct_pipe = mesh_reader + prep_pipe


# For every structure name: build a loader from the name ("L_Hipp" ->
# struct "Hipp", left=True), keep the result keyed by that name, then run
# the per-structure pipeline on each entry.
pipe = ppdict.HashWithIncoming(
    Map(
        PartiallyInitializedStep(
            Step=NeuroMaternalMeshLoader,
            # Follow the shared flag so the loader output always matches `Map_`.
            as_dict=as_dict,
            pass_data=False,
            _struct=lambda name: name.split("_")[-1],
            _left=lambda name: name.split("_")[0] == "L",
            derivative="enigma",
        )
    )
) + ppdict.DictMap(per_struct_pipe)

structs_dict = pipe(structs)

structs

['L_Thal',
 'R_Thal',
 'L_Caud',
 'R_Caud',
 'L_Puta',
 'R_Puta',
 'L_Pall',
 'R_Pall',
 'L_Hipp',
 'R_Hipp',
 'L_Amyg',
 'R_Amyg',
 'L_Accu',
 'R_Accu']

In [6]:
def _flat_displacement(verts):
    """Return ``verts[1] - verts[0]`` flattened to one dimension."""
    displacement = verts[1] - verts[0]
    return displacement.flatten()


# Per dict entry: meshes -> vertex arrays -> flattened difference between the
# second and the first mesh.
meshes2flat_vfields = ppdict.DictMap(
    step=Map(ToVertices()) + _flat_displacement,
)
# Same transformation, applied to every structure.
structs2flat_vfields = ppdict.DictMap(meshes2flat_vfields)

structs_flat_vfields_dict = structs2flat_vfields(structs_dict)

len(structs_flat_vfields_dict)

14

In [7]:
# Takes [per-structure vector fields, EDS scores] and regroups them so that
# item 0 becomes a dict of stacked arrays and item 1 a column vector.
dataset_pipe = (
    IndexMap(index=0, step=ppdict.NestedDictSwapper())
    + ppdict.DictMerger()
    + NestingSwapper()
    + IndexMap(index=0, step=ppdict.ListDictSwapper())
    # Stack each entry's list of rows into one array.
    + IndexMap(index=0, step=ppdict.DictMap(lambda rows: np.stack(rows)))
    # Stack the scores and add a trailing axis.
    + IndexMap(index=1, step=lambda scores: np.stack(scores)[:, None])
)

structs_flat_vfields, eds = dataset_pipe([structs_flat_vfields_dict, eds_dict])

# Join all structures along the last axis into a single feature matrix.
_concat_structs = ppdict.DictToValuesList() + (
    lambda blocks: np.concatenate(blocks, axis=-1)
)
flat_vfields = _concat_structs(structs_flat_vfields)

len(structs_flat_vfields), flat_vfields.shape, eds.shape

(14, (117, 81360), (117, 1))

In [8]:
def _make_struct_encoder():
    """Build one single-component PLS model wrapped with reconstruction metrics."""
    return EvaluatedModel(
        PLSRegression(n_components=1),
        MultiEvaluator(
            [
                ReconstructionError(),
                VertexReconstructionError(prefix="vertex"),
            ]
        ),
    )


# One independent encoder per structure, plus the number of feature columns
# each structure occupies.
struct_encoders = [_make_struct_encoder() for _ in range(n_structs)]
struct_dims = [vertices.shape[-1] for vertices in structs_flat_vfields.values()]
x_encoder = MultiTransform(struct_encoders, dim=struct_dims)

# Linear regression stage, evaluated with OLS p-values and R^2.
eds_regressor = EvaluatedModel(
    LinearRegression(),
    MultiEvaluator([OlsPValues(), R2Score()]),
)

# Full model: encoder + regressor, itself scored with R^2.
model = EvaluatedModel(
    SupervisedXEmbeddingRegressor(x_encoder, eds_regressor),
    R2Score(),
)

model

In [9]:
# Features: concatenated vector fields; target: EDS column vector.
X = flat_vfields
y = eds

model.fit(X, y)

Eval results of full pipeline.

In [10]:
# Result of the outermost evaluator passed to `EvaluatedModel` (an `R2Score`);
# presumably populated by the preceding `model.fit` call.
model.eval_result_

{'r2': array([0.62420661])}

Eval results of inner regression (PLS component -> EDS).

In [11]:
# Results attached to the fitted regression stage; presumably produced by the
# `MultiEvaluator([OlsPValues(), R2Score()])` wrapped around `LinearRegression`.
model.regressor_.eval_result_

{'mse': array([9.03188904]),
 'res_var': array([10.36010802]),
 'std_err': array([[0.19154738, 0.17478454, 0.32074219, 0.40455368, 0.21538441,
         0.31696358, 0.17442715, 0.26581885, 0.17805295, 0.22533223,
         0.5376841 , 0.34254519, 0.20837282, 0.15122466]]),
 't': array([[2.99686012, 2.33440686, 2.0824051 , 2.05395643, 1.9897186 ,
         0.91474212, 3.04647361, 1.40236645, 2.32474417, 2.18004121,
         0.26575589, 0.81488543, 2.45572263, 3.9397964 ]]),
 'pvals': array([[3.42544680e-03, 2.15343319e-02, 3.98075053e-02, 4.25376817e-02,
         4.92988243e-02, 3.62484419e-01, 2.94759675e-03, 1.63842042e-01,
         2.20671305e-02, 3.15542139e-02, 7.90963624e-01, 4.17036487e-01,
         1.57504969e-02, 1.49478976e-04]]),
 'adj-pvals': array([[0.04795626, 0.30148065, 0.55730507, 0.59552754, 0.69018354,
         1.        , 0.04126635, 1.        , 0.30893983, 0.441759  ,
         1.        , 1.        , 0.22050696, 0.00209271]]),
 'r2': array([0.62420661])}

Eval results of PLS reconstruction.

In [12]:
# Take the labels from `structs_flat_vfields` rather than `structs`: that dict
# supplied the per-structure `dim` list and the column order of the feature
# matrix, so its key order is the order the transforms were built for
# (assuming `transforms_` keeps the order of the list given to MultiTransform).
for struct, transform in zip(structs_flat_vfields, model.encoder_.transforms_):
    print(f"{struct}:")
    pprint.pprint(transform.eval_result_)

L_Thal:
{'featurewise_rec_error': array([18.31822995, 14.22728718,  8.16419056, ..., 16.63549312,
       26.16067045,  8.41252593], shape=(7506,)),
 'rec_error_mse': np.float64(0.10510561075719327),
 'rec_error_sum': np.float64(92303.95757818864),
 'vertex-rec_error_mse': np.float64(0.31531683227157986),
 'vertex-rec_error_sum': np.float64(92303.95757818867),
 'vertex-vertexwise_rec_error': array([40.70970769, 40.83687244, 40.07225108, ..., 47.35045991,
       54.5758885 , 51.2086895 ], shape=(2502,))}
R_Thal:
{'featurewise_rec_error': array([16.93917036, 20.66090719,  6.78192553, ..., 12.57705357,
       22.38586751,  8.31401847], shape=(7506,)),
 'rec_error_mse': np.float64(0.09595917337096085),
 'rec_error_sum': np.float64(84271.53797272456),
 'vertex-rec_error_mse': np.float64(0.28787752011288253),
 'vertex-rec_error_sum': np.float64(84271.53797272456),
 'vertex-vertexwise_rec_error': array([44.38200308, 48.01309585, 45.81793757, ..., 37.89910737,
       44.95512191, 43.27693955], 