# How to do mesh-valued regression for multiple structures?

NB: this is an alternative to using a for loop over ["How to do mesh-valued regression?"](./mesh_valued_regression.ipynb).

In [1]:
import numpy as np
import pyvista as pv
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

import polpo.preprocessing.pd as ppd
from polpo.models import DictMeshes2Comps, ObjectRegressor
from polpo.preprocessing import (
    IndexMap,
    Map,
    NestingSwapper,
    PartiallyInitializedStep,
)
from polpo.preprocessing.dict import (
    DictMap,
    DictMerger,
    HashWithIncoming,
    ListDictSwapper,
    NestedDictSwapper,
)
from polpo.preprocessing.load.pregnancy import (
    DenseMaternalCsvDataLoader,
    DenseMaternalMeshLoader,
)
from polpo.preprocessing.mesh.io import PvReader
from polpo.preprocessing.mesh.registration import PvAlign
from polpo.preprocessing.mri import segmtool2encoding



In [2]:
# Visualization switch: set to False to keep PyVista's default Jupyter backend.
STATIC_VIZ = True

if STATIC_VIZ:
    # Select PyVista's "static" Jupyter backend for every plot in this notebook.
    pv.set_jupyter_backend("static")

## Loading meshes 

In [3]:
# Segmentation tool and subject used throughout the notebook.
tool, subject_id = "fsl", "01"

# The tool-specific encoding provides the structure keys to load and model.
encoding = segmtool2encoding(tool)
struct_keys = encoding.structs
n_structs = len(struct_keys)

In [4]:
def _align_each(**kwargs):
    """Wrap `PvAlign(**kwargs)` in a `DictMap`."""
    return DictMap(PvAlign(**kwargs))


def _first_mesh(meshes):
    """Return the value stored under the first key of `meshes`."""
    first_key = list(meshes.keys())[0]
    return meshes[first_key]


# Alignment step whose `_target` is derived from the incoming meshes
# (presumably resolved at call time and forwarded to `PvAlign` as `target`
# -- confirm against `PartiallyInitializedStep`).
prep_pipe = PartiallyInitializedStep(
    Step=_align_each,
    _target=_first_mesh,
    max_iterations=500,
)

In [5]:
def _struct_from_key(name):
    """Return the part of `name` after its last underscore (all of it if none)."""
    return name.rpartition("_")[2]


def _is_left_key(name):
    """Return True when the part of `name` before its first underscore is "L"."""
    return name.partition("_")[0] == "L"


# Loader for a single structure key, followed by reading each loaded entry.
struct_mesh_loader = PartiallyInitializedStep(
    Step=DenseMaternalMeshLoader,
    pass_data=False,
    subject_id=subject_id,
    _struct=_struct_from_key,
    _left=_is_left_key,
    as_dict=True,
) + DictMap(PvReader())

# Run the loader over every structure key and hash the results with the keys.
mesh_loader = HashWithIncoming(Map(struct_mesh_loader))

# Load, then apply the alignment step to each structure's meshes.
pipe = mesh_loader + DictMap(prep_pipe)

meshes = pipe(struct_keys)

## Loading tabular data

In [6]:
# Subject "01" is flagged as the pilot; the flag is forwarded to the loader.
pilot = subject_id == "01"

# NOTE: `pipe` is rebound here; it previously held the mesh-loading pipeline.
pipe = DenseMaternalCsvDataLoader(pilot=pilot, subject_id=subject_id)

# Calling the loader without arguments returns the tabular data
# (presumably a pandas DataFrame, given the `ppd` steps applied to it -- confirm).
df = pipe()

INFO: Data has already been downloaded... using cached file ('/home/luisfpereira/.herbrain/data/maternal/28Baby_Hormones.csv').


Here, we filter the tabular data.

In [7]:
# Row filter on the "stage" column against ["post"] with `negate=True`
# (presumably this drops the "post" rows -- confirm against `DfIsInFilter`).
session_selector = ppd.DfIsInFilter("stage", ["post"], negate=True)

# Chain: filter sessions, select the "gestWeek" column, then apply
# `SeriesToDict` (presumably yielding a dict keyed by the table index -- confirm).
predictor_selector = (
    session_selector + ppd.ColumnsSelector("gestWeek") + ppd.SeriesToDict()
)

In [8]:
# Apply the predictor pipeline to the loaded table; `x_dict` is merged with the
# meshes when building the regression inputs.
x_dict = predictor_selector(df)

## Merge data

We get the data in the proper format for fitting.

In [9]:
def _as_column(values):
    """Convert `values` to an array and insert a new axis at position 1."""
    return np.array(values)[:, np.newaxis]


# Element 0 of the input is the predictor dict, element 1 the nested mesh dict.
dict_pipe = (
    IndexMap(NestedDictSwapper(), index=1)
    + DictMerger()
    + NestingSwapper()
    + IndexMap(_as_column, index=0)
    + IndexMap(ListDictSwapper(), index=1)
)

# meshes_ : dict[list]
X, meshes_ = dict_pipe([x_dict, meshes])

## Create and fit regressor

In [10]:
# Dimensionality reduction with 4 components.
pca = PCA(n_components=4)

# Mesh-to-components transformer configured with `n_structs` pipes
# (presumably one per structure, each using the PCA above -- confirm against
# `DictMeshes2Comps`).
objs2y = DictMeshes2Comps(n_pipes=n_structs, dim_reduction=pca)

In [11]:
# Linear regression (with intercept) wrapped so that `objs2y` handles the
# conversion between meshes and the regression targets.
model = ObjectRegressor(LinearRegression(fit_intercept=True), objs2y=objs2y)

In [12]:
# Fit on the predictor array `X` and the per-structure mesh lists `meshes_`.
model.fit(X, meshes_)

## Evaluate fit

`model.predict` outputs meshes, but `LinearRegression` is fit on their `PCA` components. We can therefore evaluate [`r2_score`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html) in component space by applying `objs2y.transform` to both the true and the predicted meshes.

NB: these scores are computed on the training data, so they measure goodness of fit rather than generalization.

In [13]:
# Predict meshes on the training inputs.
meshes_pred = model.predict(X)

# Map true and predicted meshes to the representation the regressor was fit on.
y_true, y_pred = (
    objs2y.transform(meshes_),
    objs2y.transform(meshes_pred),
)

# One R^2 value per output column.
scores = r2_score(y_true, y_pred, multioutput="raw_values")

# Group the scores into one row per structure.
{
    struct_key: struct_scores
    for struct_key, struct_scores in zip(
        struct_keys, scores.reshape(n_structs, -1)
    )
}

{'BrStem': array([0.43173305, 0.00093644, 0.1275041 , 0.02389214]),
 'L_Thal': array([0.01415539, 0.53246232, 0.01803249, 0.19446462]),
 'R_Thal': array([0.29409484, 0.19884918, 0.00710387, 0.04048802]),
 'L_Caud': array([0.19577603, 0.43087557, 0.0341979 , 0.01492186]),
 'R_Caud': array([0.28438233, 0.25672117, 0.06444697, 0.03906281]),
 'L_Puta': array([0.66647758, 0.01291946, 0.03550197, 0.02858578]),
 'R_Puta': array([0.31350877, 0.13438541, 0.12970828, 0.1337651 ]),
 'L_Pall': array([0.3591528 , 0.04855132, 0.24886364, 0.057122  ]),
 'R_Pall': array([0.52617966, 0.02088384, 0.06184903, 0.07608355]),
 'L_Hipp': array([0.04356306, 0.0064415 , 0.24583457, 0.00145246]),
 'R_Hipp': array([0.03846712, 0.0608323 , 0.25813046, 0.00213181]),
 'L_Amyg': array([0.0095485 , 0.25933399, 0.39715459, 0.00524612]),
 'R_Amyg': array([0.31313143, 0.10727696, 0.01111527, 0.00072597]),
 'L_Accu': array([3.66642465e-01, 9.99459512e-03, 5.68689499e-02, 1.89262827e-04]),
 'R_Accu': array([0.26589629, 0.