In [1]:
import pprint

import numpy as np
import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.model_eval import (
    MultiEvaluator,
    OlsPValues,
    R2Score,
    ReconstructionError,
    ShapeCollector,
    VertexReconstructionError,
    collect_eval_results,
)
from polpo.preprocessing import (
    IndexMap,
    Map,
    NestingSwapper,
    PartiallyInitializedStep,
)
from polpo.preprocessing.load.pregnancy import (
    NeuroMaternalMeshLoader,
    NeuroMaternalTabularDataLoader,
)
from polpo.preprocessing.mesh.conversion import PvFromData, ToVertices
from polpo.preprocessing.mesh.io import FreeSurferReader
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.mri import segmtool2encoding
from polpo.preprocessing.np import ConcatenationIndices
from polpo.sklearn.adapter import AdapterFeatureUnion, AdapterPipeline, EvaluatedModel
from polpo.sklearn.preprocessing import ColumnIndexSelector
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split



In [2]:
DEBUG = 0  # >0: no alignment, 20: fewer structs

EDS.

In [3]:
pipe = NeuroMaternalTabularDataLoader(
    keep_mothers=True,
    keep_control=False,
)

tab_data = pipe()

In [4]:
eds_pipe = (
    ppd.DfFilter(lambda df: df["ses"] == 1)
    + ppd.IndexSetter("participant_id", drop=True)
    + ppd.ColumnsSelector("EDS.Total")
    + ppd.SeriesToDict()
)

eds_dict = eds_pipe(tab_data)

Meshes and vector fields.

In [5]:
if DEBUG > 19:
    structs = ["L_Hipp", "R_Hipp"]
else:
    encoding = segmtool2encoding("fsl")
    structs = encoding.structs
    structs.remove("BrStem")


n_structs = len(structs)

as_dict = False
file_finder = NeuroMaternalMeshLoader(as_dict=as_dict)

Map_ = ppdict.DictMap if as_dict else Map


mesh_reader = ppdict.DictMap(Map_(FreeSurferReader() + PvFromData()))

prep_pipe = PartiallyInitializedStep(
    Step=lambda **kwargs: ppdict.DictMap(Map_(PvAlign(**kwargs))),
    # NB: aligns against first subject, t_0
    _target=lambda meshes: meshes[list(meshes.keys())[0]][0],
    max_iterations=500,
)


if DEBUG:
    # because alignment is slow
    per_struct_pipe = mesh_reader
else:
    per_struct_pipe = mesh_reader + prep_pipe


pipe = ppdict.HashWithIncoming(
    Map(
        PartiallyInitializedStep(
            Step=NeuroMaternalMeshLoader,
            as_dict=False,
            pass_data=False,
            _struct=lambda name: name.split("_")[-1],
            _left=lambda name: name.split("_")[0] == "L",
            derivative="enigma",
        )
    )
) + ppdict.DictMap(per_struct_pipe)

structs_dict = pipe(structs)

structs

['L_Thal',
 'R_Thal',
 'L_Caud',
 'R_Caud',
 'L_Puta',
 'R_Puta',
 'L_Pall',
 'R_Pall',
 'L_Hipp',
 'R_Hipp',
 'L_Amyg',
 'R_Amyg',
 'L_Accu',
 'R_Accu']

In [6]:
meshes2flat_vfields = ppdict.DictMap(
    step=Map(ToVertices()) + (lambda verts: (verts[1] - verts[0]).flatten())
)
structs2flat_vfields = ppdict.DictMap(meshes2flat_vfields)

structs_flat_vfields_dict = structs2flat_vfields(structs_dict)

len(structs_flat_vfields_dict)

14

In [7]:
dataset_pipe = (
    IndexMap(index=0, step=ppdict.NestedDictSwapper())
    + ppdict.DictMerger()
    + NestingSwapper()
    + IndexMap(index=0, step=ppdict.ListDictSwapper())
    + IndexMap(index=0, step=ppdict.DictMap(lambda x: np.stack(x)))
    + IndexMap(index=1, step=lambda x: np.stack(x)[:, None])
)


structs_flat_vfields, eds = dataset_pipe([structs_flat_vfields_dict, eds_dict])

flat_vfields = (ppdict.DictToValuesList() + (lambda x: np.concatenate(x, axis=-1)))(
    structs_flat_vfields
)

len(structs_flat_vfields), flat_vfields.shape, eds.shape

(14, (117, 81360), (117, 1))

In [8]:
pipe = ppdict.ZipWithKeys(ConcatenationIndices(axis=-1))
indices = pipe(structs_flat_vfields)

feature_union = AdapterFeatureUnion(
    [
        (
            name,
            AdapterPipeline(
                [
                    ColumnIndexSelector(start_index, end_index),
                    (
                        "transform",
                        EvaluatedModel(
                            PLSRegression(n_components=1),
                            MultiEvaluator(
                                [
                                    ReconstructionError(),
                                    VertexReconstructionError(prefix="vertex"),
                                    ShapeCollector(),
                                ]
                            ),
                        ),
                    ),
                ]
            ),
        )
        for name, (start_index, end_index) in indices.items()
    ]
)

feature_union

0,1,2
,transformer_list,"[('L_Thal', ...), ('R_Thal', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,7506

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb776b3fb0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,7506
,end_index,15012

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589f620>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,15012
,end_index,22518

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589e810>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,22518
,end_index,30024

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589e450>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,30024
,end_index,37530

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589df10>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,37530
,end_index,45036

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589ddf0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,45036
,end_index,48798

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589d760>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,48798
,end_index,52560

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589d1f0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,52560
,end_index,60066

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589d550>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,60066
,end_index,67572

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589cf50>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,67572
,end_index,71676

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589cf80>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,71676
,end_index,75780

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589ccb0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,75780
,end_index,78570

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589c680>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,78570
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589c740>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True


In [9]:
model = EvaluatedModel(
    AdapterPipeline(
        [
            ("prep", feature_union),
            (
                "regr",
                EvaluatedModel(
                    LinearRegression(),
                    MultiEvaluator([OlsPValues(), R2Score(), ShapeCollector()]),
                ),
            ),
        ]
    ),
    MultiEvaluator([R2Score(), ShapeCollector()]),
)

model

0,1,2
,model,AdapterPipeli...gression()))])
,evaluator,<polpo.model_...x7cbb75b48170>

0,1,2
,transformer_list,"[('L_Thal', ...), ('R_Thal', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,7506

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb776b3fb0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,7506
,end_index,15012

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589f620>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,15012
,end_index,22518

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589e810>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,22518
,end_index,30024

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589e450>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,30024
,end_index,37530

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589df10>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,37530
,end_index,45036

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589ddf0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,45036
,end_index,48798

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589d760>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,48798
,end_index,52560

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589d1f0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,52560
,end_index,60066

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589d550>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,60066
,end_index,67572

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589cf50>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,67572
,end_index,71676

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589cf80>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,71676
,end_index,75780

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589ccb0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,75780
,end_index,78570

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589c680>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,78570
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x7cbb7589c740>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,model,LinearRegression()
,evaluator,<polpo.model_...x7cbb75b49b50>

0,1,2
,fit_intercept,True
,copy_X,True
,tol,1e-06
,n_jobs,
,positive,False


In [10]:
X, y = flat_vfields, eds

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)


model.fit(X_train, y_train);

In [11]:
model.predict_eval(X_test, y_test);

In [12]:
eval_res_train = collect_eval_results(model, unnest=True, outer_key="full")

print(list(eval_res_train.keys()))

['full', 'full/prep/L_Thal/transform', 'full/prep/R_Thal/transform', 'full/prep/L_Caud/transform', 'full/prep/R_Caud/transform', 'full/prep/L_Puta/transform', 'full/prep/R_Puta/transform', 'full/prep/L_Pall/transform', 'full/prep/R_Pall/transform', 'full/prep/L_Hipp/transform', 'full/prep/R_Hipp/transform', 'full/prep/L_Amyg/transform', 'full/prep/R_Amyg/transform', 'full/prep/L_Accu/transform', 'full/prep/R_Accu/transform', 'full/regr']


In [13]:
eval_res_test = collect_eval_results(model, unnest=True, outer_key="full", train=False)

print(list(eval_res_test.keys()))

['full', 'full/prep/L_Thal/transform', 'full/prep/R_Thal/transform', 'full/prep/L_Caud/transform', 'full/prep/R_Caud/transform', 'full/prep/L_Puta/transform', 'full/prep/R_Puta/transform', 'full/prep/L_Pall/transform', 'full/prep/R_Pall/transform', 'full/prep/L_Hipp/transform', 'full/prep/R_Hipp/transform', 'full/prep/L_Amyg/transform', 'full/prep/R_Amyg/transform', 'full/prep/L_Accu/transform', 'full/prep/R_Accu/transform', 'full/regr']


Eval results of full pipeline.

In [14]:
eval_res_train["full"]

{'r2': array([0.68314331]),
 'X-shape': (93, 81360),
 'y-shape': (93, 1),
 'y_pred-shape': (93, 1)}

In [15]:
eval_res_test["full"]

{'r2': array([-0.5064373]),
 'X-shape': (24, 81360),
 'y-shape': (24, 1),
 'y_pred-shape': (24, 1)}

Equivalently, for the R2:

In [16]:
r2_score(y_train, model.predict(X_train))

0.6831433074088054

In [17]:
r2_score(y_test, model.predict(X_test))

-0.5064373028701794

Eval results of inner regression (PLS component -> EDS).

In [18]:
eval_res_train["full/regr"]

{'mse': array([7.26348815]),
 'res_var': array([8.66031279]),
 'std_err': array([[0.01905456, 0.01787855, 0.01719213, 0.00773973, 0.01596845,
         0.01608923, 0.0188406 , 0.01749317, 0.01347226, 0.00933822,
         0.01972814, 0.01513745, 0.02794816, 0.04333656]]),
 't': array([[-2.98845145, -2.80277649,  2.13555716,  1.4780238 , -1.50452225,
         -0.76852259, -1.44339002,  1.5374797 ,  2.14585654,  1.35634591,
         -0.74215467,  1.35023658,  2.26558754,  3.66469955]]),
 'pvals': array([[3.74812800e-03, 6.38876943e-03, 3.58532618e-02, 1.43428242e-01,
         1.36485627e-01, 4.44497457e-01, 1.52914978e-01, 1.28222804e-01,
         3.49942502e-02, 1.78902328e-01, 4.60222679e-01, 1.80845076e-01,
         2.62501306e-02, 4.49729927e-04]]),
 'adj-pvals': array([[0.05247379, 0.08944277, 0.50194567, 1.        , 1.        ,
         1.        , 1.        , 1.        , 0.4899195 , 1.        ,
         1.        , 1.        , 0.36750183, 0.00629622]]),
 'r2': array([0.68314331]),
 

In [19]:
eval_res_test["full/regr"]

{'mse': array([41.73825159]),
 'res_var': array([111.30200424]),
 'std_err': array([[0.15418333, 0.20786913, 0.32828407, 0.16272751, 0.14304639,
         0.14825106, 0.13795442, 0.12086049, 0.17117646, 0.11914672,
         0.23102542, 0.11965553, 0.35723083, 0.65204415]]),
 't': array([[-0.36932407, -0.2410631 ,  0.11183846,  0.0702985 , -0.16795174,
         -0.0834054 , -0.19712553,  0.22253251,  0.16888735,  0.10630476,
         -0.06337543,  0.17081646,  0.17724953,  0.24356552]]),
 'pvals': array([[0.72042596, 0.81490832, 0.91340609, 0.94549334, 0.87033477,
         0.93535485, 0.8481092 , 0.82886745, 0.86962   , 0.91767235,
         0.9508528 , 0.86814664, 0.86323731, 0.81302839]]),
 'adj-pvals': array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),
 'r2': array([-0.5064373]),
 'X-shape': (24, 14),
 'y-shape': (24, 1),
 'y_pred-shape': (24, 1)}

Eval results of PLS reconstruction.

In [20]:
for struct in structs:
    print(f"{struct}:")

    print("train")
    pprint.pprint(eval_res_train["full/regr"])

    print("test")
    pprint.pprint(eval_res_test["full/regr"])

L_Thal:
train
{'X-shape': (93, 14),
 'adj-pvals': array([[0.05247379, 0.08944277, 0.50194567, 1.        , 1.        ,
        1.        , 1.        , 1.        , 0.4899195 , 1.        ,
        1.        , 1.        , 0.36750183, 0.00629622]]),
 'mse': array([7.26348815]),
 'pvals': array([[3.74812800e-03, 6.38876943e-03, 3.58532618e-02, 1.43428242e-01,
        1.36485627e-01, 4.44497457e-01, 1.52914978e-01, 1.28222804e-01,
        3.49942502e-02, 1.78902328e-01, 4.60222679e-01, 1.80845076e-01,
        2.62501306e-02, 4.49729927e-04]]),
 'r2': array([0.68314331]),
 'res_var': array([8.66031279]),
 'std_err': array([[0.01905456, 0.01787855, 0.01719213, 0.00773973, 0.01596845,
        0.01608923, 0.0188406 , 0.01749317, 0.01347226, 0.00933822,
        0.01972814, 0.01513745, 0.02794816, 0.04333656]]),
 't': array([[-2.98845145, -2.80277649,  2.13555716,  1.4780238 , -1.50452225,
        -0.76852259, -1.44339002,  1.5374797 ,  2.14585654,  1.35634591,
        -0.74215467,  1.35023658,  2.