In [1]:
import pprint

import numpy as np
import polpo.preprocessing.dict as ppdict
import polpo.preprocessing.pd as ppd
from polpo.model_eval import (
    MultiEvaluator,
    OlsPValues,
    R2Score,
    ReconstructionError,
    ShapeCollector,
    VertexReconstructionError,
    collect_eval_results,
)
from polpo.preprocessing import (
    IndexMap,
    Map,
    NestingSwapper,
    PartiallyInitializedStep,
)
from polpo.preprocessing.load.pregnancy import (
    CacheableMeshLoader,
    NeuroMaternalMultiMeshLoader,
    NeuroMaternalTabularDataLoader,
)
from polpo.preprocessing.mesh.conversion import ToVertices
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.mri import segmtool2encoding
from polpo.preprocessing.np import ConcatenationIndices
from polpo.sklearn.adapter import AdapterFeatureUnion, AdapterPipeline, EvaluatedModel
from polpo.sklearn.preprocessing import ColumnIndexSelector
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split



In [2]:
DEBUG = 0  # >0: no alignment, 20: fewer structs

EDS.

In [3]:
pipe = NeuroMaternalTabularDataLoader(
    keep_mothers=True,
    keep_control=False,
)

tab_data = pipe()

In [4]:
eds_pipe = (
    ppd.DfFilter(lambda df: df["ses"] == 1)
    + ppd.IndexSetter("participant_id", drop=True)
    + ppd.ColumnsSelector("EDS.Total")
    + ppd.SeriesToDict()
)

eds_dict = eds_pipe(tab_data)

Meshes and vector fields.

In [5]:
if DEBUG > 19:
    structs = ["L_Hipp", "R_Hipp"]
else:
    encoding = segmtool2encoding("fsl")
    structs = encoding.structs
    structs.remove("BrStem")


n_structs = len(structs)

as_dict = False


no_cache_pipe = NeuroMaternalMultiMeshLoader(as_mesh=True)

if not DEBUG:
    # because alignment is slow
    per_struct_pipe = PartiallyInitializedStep(
        Step=lambda **kwargs: ppdict.NestedDictMap(
            PvAlign(**kwargs), inner_is_dict=as_dict, depth=1
        ),
        # NB: aligns against first subject, t_0
        _target=lambda meshes: meshes[list(meshes.keys())[0]][0],
        max_iterations=500,
    )

    no_cache_pipe += ppdict.DictMap(per_struct_pipe)

# NB: use reset_cache if changes to the pipeline
pipe = CacheableMeshLoader(
    "cached_nm_meshes",
    no_cache_pipe,
    use_cache=not DEBUG,
    cache=not DEBUG,
)

In [6]:
structs_dict = pipe(structs)

structs

['L_Thal',
 'R_Thal',
 'L_Caud',
 'R_Caud',
 'L_Puta',
 'R_Puta',
 'L_Pall',
 'R_Pall',
 'L_Hipp',
 'R_Hipp',
 'L_Amyg',
 'R_Amyg',
 'L_Accu',
 'R_Accu']

In [7]:
meshes2flat_vfields = ppdict.DictMap(
    step=Map(ToVertices()) + (lambda verts: (verts[1] - verts[0]).flatten())
)
structs2flat_vfields = ppdict.DictMap(meshes2flat_vfields)

structs_flat_vfields_dict = structs2flat_vfields(structs_dict)

len(structs_flat_vfields_dict)

14

In [8]:
dataset_pipe = (
    IndexMap(index=0, step=ppdict.NestedDictSwapper())
    + ppdict.DictMerger()
    # [(x_0, y_0), ..., (x_n, y_n)]
    + NestingSwapper()
    # [(x_0, x_1, ..., x_n), (y_0, y_1, ..., y_n)]
    + IndexMap(index=0, step=ppdict.ListDictSwapper())
    # inverts first swapping
    + IndexMap(index=0, step=ppdict.DictMap(lambda x: np.stack(x)))
    # makes eds a array instead of list
    + IndexMap(index=1, step=lambda x: np.stack(x)[:, None])
)


structs_flat_vfields, eds = dataset_pipe([structs_flat_vfields_dict, eds_dict])

flat_vfields = (ppdict.DictToValuesList() + (lambda x: np.concatenate(x, axis=-1)))(
    structs_flat_vfields
)

len(structs_flat_vfields), flat_vfields.shape, eds.shape

(14, (117, 81360), (117, 1))

In [9]:
pipe = ppdict.ZipWithKeys(ConcatenationIndices(axis=-1))
indices = pipe(structs_flat_vfields)

feature_union = AdapterFeatureUnion(
    [
        (
            name,
            AdapterPipeline(
                [
                    ColumnIndexSelector(start_index, end_index),
                    (
                        "transform",
                        EvaluatedModel(
                            PLSRegression(n_components=1),
                            MultiEvaluator(
                                [
                                    ReconstructionError(),
                                    VertexReconstructionError(prefix="vertex"),
                                    ShapeCollector(),
                                ]
                            ),
                        ),
                    ),
                ]
            ),
        )
        for name, (start_index, end_index) in indices.items()
    ]
)

feature_union

0,1,2
,transformer_list,"[('L_Accu', ...), ('L_Amyg', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,2790

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795eff5c6600>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,2790
,end_index,6894

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795effa8fe60>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,6894
,end_index,14400

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795eff7b7aa0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,14400
,end_index,21906

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795effd3b560>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,21906
,end_index,25668

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795f04d68860>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,25668
,end_index,33174

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b2e70>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,33174
,end_index,40680

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3500>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,40680
,end_index,43470

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b2930>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,43470
,end_index,47574

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3830>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,47574
,end_index,55080

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b39e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,55080
,end_index,62586

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3b90>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,62586
,end_index,66348

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3d40>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,66348
,end_index,73854

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3ef0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,73854
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6040e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True


In [10]:
model = EvaluatedModel(
    AdapterPipeline(
        [
            ("prep", feature_union),
            (
                "regr",
                EvaluatedModel(
                    LinearRegression(),
                    MultiEvaluator([OlsPValues(), R2Score(), ShapeCollector()]),
                ),
            ),
        ]
    ),
    MultiEvaluator([R2Score(), ShapeCollector()]),
)

model

0,1,2
,model,AdapterPipeli...gression()))])
,evaluator,<polpo.model_...x795efd6058b0>

0,1,2
,transformer_list,"[('L_Accu', ...), ('L_Amyg', ...), ...]"
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True

0,1,2
,start_index,0
,end_index,2790

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795eff5c6600>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,2790
,end_index,6894

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795effa8fe60>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,6894
,end_index,14400

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795eff7b7aa0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,14400
,end_index,21906

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795effd3b560>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,21906
,end_index,25668

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795f04d68860>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,25668
,end_index,33174

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b2e70>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,33174
,end_index,40680

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3500>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,40680
,end_index,43470

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b2930>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,43470
,end_index,47574

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3830>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,47574
,end_index,55080

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b39e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,55080
,end_index,62586

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3b90>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,62586
,end_index,66348

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3d40>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,66348
,end_index,73854

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6b3ef0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,start_index,73854
,end_index,81360

0,1,2
,model,PLSRegression(n_components=1)
,evaluator,<polpo.model_...x795efd6040e0>

0,1,2
,n_components,1
,scale,True
,max_iter,500
,tol,1e-06
,copy,True

0,1,2
,model,LinearRegression()
,evaluator,<polpo.model_...x795efd605a90>

0,1,2
,fit_intercept,True
,copy_X,True
,tol,1e-06
,n_jobs,
,positive,False


In [11]:
X, y = flat_vfields, eds

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)


model.fit(X_train, y_train);

In [12]:
model.predict_eval(X_test, y_test);

In [13]:
eval_res_train = collect_eval_results(model, unnest=True, outer_key="full")

print(list(eval_res_train.keys()))

['full', 'full/prep/L_Accu/transform', 'full/prep/L_Amyg/transform', 'full/prep/L_Caud/transform', 'full/prep/L_Hipp/transform', 'full/prep/L_Pall/transform', 'full/prep/L_Puta/transform', 'full/prep/L_Thal/transform', 'full/prep/R_Accu/transform', 'full/prep/R_Amyg/transform', 'full/prep/R_Caud/transform', 'full/prep/R_Hipp/transform', 'full/prep/R_Pall/transform', 'full/prep/R_Puta/transform', 'full/prep/R_Thal/transform', 'full/regr']


In [14]:
eval_res_test = collect_eval_results(model, unnest=True, outer_key="full", train=False)

print(list(eval_res_test.keys()))

['full', 'full/prep/L_Accu/transform', 'full/prep/L_Amyg/transform', 'full/prep/L_Caud/transform', 'full/prep/L_Hipp/transform', 'full/prep/L_Pall/transform', 'full/prep/L_Puta/transform', 'full/prep/L_Thal/transform', 'full/prep/R_Accu/transform', 'full/prep/R_Amyg/transform', 'full/prep/R_Caud/transform', 'full/prep/R_Hipp/transform', 'full/prep/R_Pall/transform', 'full/prep/R_Puta/transform', 'full/prep/R_Thal/transform', 'full/regr']


Eval results of full pipeline.

In [15]:
eval_res_train["full"]

{'r2': array([0.68462256]),
 'X-shape': (93, 81360),
 'y-shape': (93, 1),
 'y_pred-shape': (93, 1)}

In [16]:
eval_res_test["full"]

{'r2': array([-0.62393415]),
 'X-shape': (24, 81360),
 'y-shape': (24, 1),
 'y_pred-shape': (24, 1)}

Equivalently, for the R2:

In [17]:
r2_score(y_train, model.predict(X_train))

0.6846225643173502

In [18]:
r2_score(y_test, model.predict(X_test))

-0.6239341469056814

Eval results of inner regression (PLS component -> EDS).

In [19]:
eval_res_train["full/regr"]

{'mse': array([6.68480562]),
 'res_var': array([7.97034516]),
 'std_err': array([[0.01927579, 0.00733735, 0.00683092, 0.01658303, 0.01847505,
         0.01015303, 0.03169532, 0.03757794, 0.0314626 , 0.00924864,
         0.01345595, 0.03092522, 0.0077391 , 0.01943069]]),
 't': array([[ 2.30125875, -0.16685889,  1.62679709, -1.15437511, -2.33392997,
         -2.83596645, -2.53161826, -2.5841925 , -2.11353078, -1.39043022,
          1.08790983, -1.34108367,  0.2568656 ,  2.50509851]]),
 'pvals': array([[0.02404756, 0.86791289, 0.10781365, 0.25187099, 0.02217522,
         0.00581739, 0.01336401, 0.01162716, 0.03775207, 0.168353  ,
         0.27998512, 0.1837855 , 0.79795887, 0.01432596]]),
 'adj-pvals': array([[0.33666585, 1.        , 1.        , 1.        , 0.31045311,
         0.08144342, 0.18709612, 0.16278028, 0.52852892, 1.        ,
         1.        , 1.        , 1.        , 0.20056349]]),
 'r2': array([0.68462256]),
 'X-shape': (93, 14),
 'y-shape': (93, 1),
 'y_pred-shape': (93, 1

In [20]:
eval_res_test["full/regr"]

{'mse': array([49.38338631]),
 'res_var': array([131.68903017]),
 'std_err': array([[0.21953602, 0.07885852, 0.11581516, 0.31982838, 0.28294   ,
         0.28858938, 0.42307034, 0.44786144, 0.27365608, 0.18693499,
         0.18151239, 0.41155944, 0.0841037 , 0.18689082]]),
 't': array([[ 0.20205607, -0.01552531,  0.09595054, -0.05985408, -0.15239793,
         -0.09977382, -0.18966222, -0.2168274 , -0.24299544, -0.06879176,
          0.08064939, -0.10077114,  0.02363639,  0.26045046]]),
 'pvals': array([[0.84436633, 0.98795182, 0.9256624 , 0.95357985, 0.88223489,
         0.92271101, 0.85378237, 0.83317818, 0.81345655, 0.94665953,
         0.93748574, 0.92194133, 0.98165841, 0.80037736]]),
 'adj-pvals': array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),
 'r2': array([-0.62393415]),
 'X-shape': (24, 14),
 'y-shape': (24, 1),
 'y_pred-shape': (24, 1)}

Eval results of PLS reconstruction.

In [21]:
for struct in structs:
    print(f"{struct}:")

    print("train")
    pprint.pprint(eval_res_train[f"full/prep/{struct}/transform"])

    print("test")
    pprint.pprint(eval_res_test[f"full/prep/{struct}/transform"])

L_Thal:
train
{'X-shape': (93, 7506),
 'featurewise_rec_error': array([14.50835673, 11.70867862,  6.50976769, ..., 13.50285457,
       20.4929383 ,  6.91328294], shape=(7506,)),
 'rec_error_mse': np.float64(0.11530862093852919),
 'rec_error_sum': np.float64(80492.10531510781),
 'vertex-rec_error_mse': np.float64(0.34592586281558757),
 'vertex-rec_error_sum': np.float64(80492.10531510781),
 'vertex-vertexwise_rec_error': array([32.72680303, 33.143953  , 32.12779427, ..., 36.86656543,
       43.73405074, 40.90907581], shape=(2502,)),
 'y-shape': (93, 1),
 'y_pred-shape': (93, 1)}
test
{'X-shape': (24, 7506),
 'featurewise_rec_error': array([4.01293196, 1.75728445, 1.25950337, ..., 2.85353165, 4.41114796,
       1.54124499], shape=(7506,)),
 'rec_error_mse': np.float64(0.10368035639230486),
 'rec_error_sum': np.float64(18677.394121935366),
 'vertex-rec_error_mse': np.float64(0.31104106917691454),
 'vertex-rec_error_sum': np.float64(18677.394121935366),
 'vertex-vertexwise_rec_error': arra