In [1]:
import pprint

import numpy as np
import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.model_eval import (
    MultiEvaluator,
    OlsPValues,
    R2Score,
    ReconstructionError,
    VertexReconstructionError,
    collect_eval_results,
)
from polpo.preprocessing import (
    IndexMap,
    Map,
    NestingSwapper,
    PartiallyInitializedStep,
)
from polpo.preprocessing.load.pregnancy import (
    NeuroMaternalMeshLoader,
    NeuroMaternalTabularDataLoader,
)
from polpo.preprocessing.mesh.conversion import PvFromData, ToVertices
from polpo.preprocessing.mesh.io import FreeSurferReader
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.mri import segmtool2encoding
from polpo.preprocessing.np import ConcatenationIndices
from polpo.sklearn.adapter import AdapterFeatureUnion, AdapterPipeline, EvaluatedModel
from polpo.sklearn.preprocessing import ColumnIndexSelector
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression



In [2]:
DEBUG = 0  # 0: full run; >0: skip mesh alignment (slow); >19: also restrict to L_Hipp/R_Hipp only

Load the EDS scores (`EDS.Total`) of each participant at session 1.

In [3]:
# Tabular data loader configured to keep mothers and drop controls.
tab_loader = NeuroMaternalTabularDataLoader(
    keep_mothers=True,
    keep_control=False,
)
tab_data = tab_loader()

In [4]:
def _is_session_1(df):
    """Row mask selecting the rows whose ``ses`` column equals 1."""
    return df["ses"] == 1


# Session-1 rows -> indexed by participant -> EDS total -> {participant_id: score}.
eds_pipe = (
    ppd.DfFilter(_is_session_1)
    + ppd.IndexSetter("participant_id", drop=True)
    + ppd.ColumnsSelector("EDS.Total")
    + ppd.SeriesToDict()
)

eds_dict = eds_pipe(tab_data)

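Load the meshes of each structure and compute, per subject, the vector field given by the vertex displacements between the first two meshes.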

In [5]:
if DEBUG > 19:
    # Debug mode: only the two hippocampi, to keep the run short.
    structs = ["L_Hipp", "R_Hipp"]
else:
    encoding = segmtool2encoding("fsl")
    # Build a new list instead of calling `.remove` on `encoding.structs`:
    # in-place removal mutates the encoding object's own list, and re-running
    # this cell against the same (possibly shared) list would raise ValueError.
    structs = [struct for struct in encoding.structs if struct != "BrStem"]


n_structs = len(structs)

# Container toggle: `Map_` below follows it, so the per-subject steps use
# `ppdict.DictMap` when True and `Map` when False.
as_dict = False
# NOTE(review): `file_finder` is not used in the visible cells; candidate for removal.
file_finder = NeuroMaternalMeshLoader(as_dict=as_dict)

Map_ = ppdict.DictMap if as_dict else Map


# Read each mesh file (FreeSurferReader) and build a mesh object from it (PvFromData).
mesh_reader = ppdict.DictMap(Map_(FreeSurferReader() + PvFromData()))


def _alignment_target(meshes):
    """Return the first mesh of the first subject; every mesh is aligned to it.

    Parameters
    ----------
    meshes : dict
        Per-subject mesh containers of one structure (first entry = first subject).

    Returns
    -------
    The element at index 0 of the first subject's container.

    Raises
    ------
    ValueError
        If ``meshes`` is empty, so there is nothing to align against.
    """
    if not meshes:
        raise ValueError("Cannot select an alignment target: no meshes were loaded.")
    # Take the first value directly instead of materialising the full key list.
    first_subject_meshes = next(iter(meshes.values()))
    return first_subject_meshes[0]


prep_pipe = PartiallyInitializedStep(
    Step=lambda **kwargs: ppdict.DictMap(Map_(PvAlign(**kwargs))),
    # NB: aligns against first subject, t_0
    _target=_alignment_target,
    max_iterations=500,
)


if DEBUG:
    # because alignment is slow
    per_struct_pipe = mesh_reader
else:
    per_struct_pipe = mesh_reader + prep_pipe


# For each structure name (e.g. "L_Hipp"), build a mesh loader configured from
# the name itself: the part after "_" is the structure ("Hipp") and an "L"
# prefix marks the left hemisphere. The loaded meshes then go through
# `per_struct_pipe` (read, and align unless DEBUG).
# NOTE(review): presumably `HashWithIncoming` keys results by the incoming
# structure names — confirm.
structs_pipe = ppdict.HashWithIncoming(
    Map(
        PartiallyInitializedStep(
            Step=NeuroMaternalMeshLoader,
            # Follow the notebook-level toggle instead of hard-coding False, so the
            # loader output matches the `Map_` containers used by `per_struct_pipe`.
            as_dict=as_dict,
            pass_data=False,
            _struct=lambda name: name.split("_")[-1],
            _left=lambda name: name.split("_")[0] == "L",
            derivative="enigma",
        )
    )
) + ppdict.DictMap(per_struct_pipe)

structs_dict = structs_pipe(structs)

structs

['L_Thal',
 'R_Thal',
 'L_Caud',
 'R_Caud',
 'L_Puta',
 'R_Puta',
 'L_Pall',
 'R_Pall',
 'L_Hipp',
 'R_Hipp',
 'L_Amyg',
 'R_Amyg',
 'L_Accu',
 'R_Accu']

In [6]:
def _flat_displacement(verts):
    """Flattened vertex displacement ``verts[1] - verts[0]``.

    NOTE(review): only the first two vertex arrays are used; this assumes each
    subject has exactly two meshes (sessions) — confirm.
    """
    displacement = verts[1] - verts[0]
    return displacement.flatten()


# Per subject: list of meshes -> vertex arrays -> flattened displacement field.
meshes2flat_vfields = ppdict.DictMap(step=Map(ToVertices()) + _flat_displacement)
# Apply the per-subject transformation to every structure.
structs2flat_vfields = ppdict.DictMap(meshes2flat_vfields)

structs_flat_vfields_dict = structs2flat_vfields(structs_dict)

len(structs_flat_vfields_dict)

14

In [7]:
def _stack(arrays):
    """Stack a sequence of arrays along a new leading (sample) axis."""
    return np.stack(arrays)


def _stack_as_column(values):
    """Stack a sequence of scalars into a column array of shape (n, 1)."""
    return np.stack(values)[:, None]


def _concat_features(arrays):
    """Concatenate per-structure feature arrays along the last axis."""
    return np.concatenate(arrays, axis=-1)


# Combine vector fields and EDS scores into arrays: a dict of per-structure
# feature matrices (element 0) and a column target (element 1).
dataset_pipe = (
    IndexMap(index=0, step=ppdict.NestedDictSwapper())
    + ppdict.DictMerger()
    + NestingSwapper()
    + IndexMap(index=0, step=ppdict.ListDictSwapper())
    + IndexMap(index=0, step=ppdict.DictMap(_stack))
    + IndexMap(index=1, step=_stack_as_column)
)


structs_flat_vfields, eds = dataset_pipe([structs_flat_vfields_dict, eds_dict])

flat_vfields_pipe = ppdict.DictToValuesList() + _concat_features
flat_vfields = flat_vfields_pipe(structs_flat_vfields)

len(structs_flat_vfields), flat_vfields.shape, eds.shape

(14, (117, 81360), (117, 1))

In [8]:
def _make_struct_transformer(start_index, end_index):
    """Pipeline for one structure: select its columns, then a 1-component PLS.

    The PLS step is wrapped with reconstruction-error evaluators (global and
    vertex-wise, the latter under the "vertex" prefix).
    """
    selector = ColumnIndexSelector(start_index, end_index)
    evaluated_pls = EvaluatedModel(
        PLSRegression(n_components=1),
        MultiEvaluator(
            [
                ReconstructionError(),
                VertexReconstructionError(prefix="vertex"),
            ]
        ),
    )
    return AdapterPipeline([selector, ("transform", evaluated_pls)])


# Column range of each structure inside the concatenated feature matrix.
pipe = ppdict.ZipWithKeys(ConcatenationIndices(axis=-1))
indices = pipe(structs_flat_vfields)

feature_union = AdapterFeatureUnion(
    [
        (name, _make_struct_transformer(start_index, end_index))
        for name, (start_index, end_index) in indices.items()
    ]
)

feature_union

0,1,2
,transformer_list,"[('L_Thal', ...), ('R_Thal', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,7506

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e730823b7d0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,7506
,end_index,15012

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308093530>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,15012
,end_index,22518

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73080927e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,22518
,end_index,30024

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091be0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,30024
,end_index,37530

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091820>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,37530
,end_index,45036

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091370>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,45036
,end_index,48798

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090e60>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,48798
,end_index,52560

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090d10>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,52560
,end_index,60066

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73080909e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,60066
,end_index,67572

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090560>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,67572
,end_index,71676

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090740>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,71676
,end_index,75780

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090230>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,75780
,end_index,78570

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090200>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,78570
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73083acbf0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True


In [9]:
# Inner regressor: linear regression on the per-structure PLS outputs,
# evaluated with OLS p-values and R^2.
regressor = EvaluatedModel(
    LinearRegression(),
    MultiEvaluator([OlsPValues(), R2Score()]),
)

# Full model: per-structure PLS feature union followed by the regressor,
# with an R^2 evaluator on the whole pipeline.
model = EvaluatedModel(
    AdapterPipeline([("prep", feature_union), ("regr", regressor)]),
    R2Score(),
)

model

0,1,2
,model,AdapterPipeli...gression()))])
,evaluator,<polpo.model_...x7e73083756a0>

0,1,2
,transformer_list,"[('L_Thal', ...), ('R_Thal', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,7506

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e730823b7d0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,7506
,end_index,15012

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308093530>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,15012
,end_index,22518

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73080927e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,22518
,end_index,30024

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091be0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,30024
,end_index,37530

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091820>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,37530
,end_index,45036

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091370>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,45036
,end_index,48798

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090e60>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,48798
,end_index,52560

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090d10>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,52560
,end_index,60066

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73080909e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,60066
,end_index,67572

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090560>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,67572
,end_index,71676

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090740>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,71676
,end_index,75780

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090230>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,75780
,end_index,78570

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090200>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,78570
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73083acbf0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,model,LinearRegression()
,evaluator,<polpo.model_...x7e73083770b0>

0,1,2
,fit_intercept,True
,copy_X,True
,tol,1e-06
,n_jobs,
,positive,False


In [10]:
# Features: concatenated flattened vector fields of all structures; target: EDS.
X = flat_vfields
y = eds

model.fit(X, y)

0,1,2
,model,AdapterPipeli...gression()))])
,evaluator,<polpo.model_...x7e73083756a0>

0,1,2
,transformer_list,"[('L_Thal', ...), ('R_Thal', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,7506

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e730823b7d0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,7506
,end_index,15012

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308093530>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,15012
,end_index,22518

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73080927e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,22518
,end_index,30024

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091be0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,30024
,end_index,37530

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091820>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,37530
,end_index,45036

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308091370>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,45036
,end_index,48798

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090e60>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,48798
,end_index,52560

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090d10>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,52560
,end_index,60066

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73080909e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,60066
,end_index,67572

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090560>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,67572
,end_index,71676

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090740>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,71676
,end_index,75780

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090230>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,75780
,end_index,78570

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e7308090200>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,78570
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7e73083acbf0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,model,LinearRegression()
,evaluator,<polpo.model_...x7e73083770b0>

0,1,2
,fit_intercept,True
,copy_X,True
,tol,1e-06
,n_jobs,
,positive,False


In [11]:
# Gather the evaluator results of every nested EvaluatedModel, flattened into
# "/"-separated keys rooted at "full".
eval_res = collect_eval_results(model, unnest=True, outer_key="full")

print([key for key in eval_res])

['full', 'full/prep/L_Thal/transform', 'full/prep/R_Thal/transform', 'full/prep/L_Caud/transform', 'full/prep/R_Caud/transform', 'full/prep/L_Puta/transform', 'full/prep/R_Puta/transform', 'full/prep/L_Pall/transform', 'full/prep/R_Pall/transform', 'full/prep/L_Hipp/transform', 'full/prep/R_Hipp/transform', 'full/prep/L_Amyg/transform', 'full/prep/R_Amyg/transform', 'full/prep/L_Accu/transform', 'full/prep/R_Accu/transform', 'full/regr']


Eval results of full pipeline.

In [12]:
# Outer evaluator (R2Score) of the whole pipeline.
eval_res["full"]

{'r2': array([0.62420661])}

Evaluation results of the inner linear regression (per-structure PLS components -> EDS).

In [13]:
# Evaluators of the "regr" step: OlsPValues and R2Score on the LinearRegression.
eval_res["full/regr"]

{'mse': array([9.03188904]),
 'res_var': array([10.36010802]),
 'std_err': array([[0.0184778 , 0.01634002, 0.00865656, 0.00755614, 0.01273044,
         0.00816827, 0.03565424, 0.01711681, 0.02052865, 0.00853369,
         0.01111987, 0.01452142, 0.01794446, 0.04185686]]),
 't': array([[-2.99686012, -2.33440686,  2.0824051 ,  2.05395643,  1.9897186 ,
         -0.91474212, -3.04647361, -1.40236645,  2.32474417,  2.18004121,
         -0.26575589,  0.81488543,  2.45572263,  3.9397964 ]]),
 'pvals': array([[3.42544680e-03, 2.15343319e-02, 3.98075053e-02, 4.25376817e-02,
         4.92988243e-02, 3.62484419e-01, 2.94759675e-03, 1.63842042e-01,
         2.20671305e-02, 3.15542139e-02, 7.90963624e-01, 4.17036487e-01,
         1.57504969e-02, 1.49478976e-04]]),
 'adj-pvals': array([[0.04795626, 0.30148065, 0.55730507, 0.59552754, 0.69018354,
         1.        , 0.04126635, 1.        , 0.30893983, 0.441759  ,
         1.        , 1.        , 0.22050696, 0.00209271]]),
 'r2': array([0.62420661])}

Evaluation results of the per-structure PLS reconstruction.

In [14]:
# Reconstruction metrics of each structure's PLS step. Results are keyed as
# "full/prep/<struct>/transform"; the previous code printed the regression
# results ("full/regr") for every structure instead.
for struct in structs:
    print(f"{struct}:")
    pprint.pprint(eval_res[f"full/prep/{struct}/transform"])

L_Thal:
{'featurewise_rec_error': array([18.31822995, 14.22728718,  8.16419056, ..., 16.63549312,
       26.16067045,  8.41252593], shape=(7506,)),
 'rec_error_mse': np.float64(0.10510561075719328),
 'rec_error_sum': np.float64(92303.95757818865),
 'vertex-rec_error_mse': np.float64(0.3153168322715798),
 'vertex-rec_error_sum': np.float64(92303.95757818864),
 'vertex-vertexwise_rec_error': array([40.70970769, 40.83687244, 40.07225108, ..., 47.35045991,
       54.5758885 , 51.2086895 ], shape=(2502,))}
R_Thal:
{'featurewise_rec_error': array([16.93917036, 20.66090719,  6.78192553, ..., 12.57705357,
       22.38586751,  8.31401847], shape=(7506,)),
 'rec_error_mse': np.float64(0.09595917337096087),
 'rec_error_sum': np.float64(84271.53797272457),
 'vertex-rec_error_mse': np.float64(0.2878775201128826),
 'vertex-rec_error_sum': np.float64(84271.53797272457),
 'vertex-vertexwise_rec_error': array([44.38200308, 48.01309585, 45.81793757, ..., 37.89910737,
       44.95512191, 43.27693955], sh