In [2]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
import os
import itertools
import random

os.environ["GEOMSTATS_BACKEND"] = "pytorch"  # noqa: E402
import geomstats.backend as gs


from H2_SurfaceMatch.utils.input_output import plotGeodesic
import H2_SurfaceMatch.utils.input_output as h2_io
import H2_SurfaceMatch.utils.utils
import src.datasets.utils as data_utils
import project_menstrual.default_config as default_config
from src.regression import check_euclidean, training

import src.setcwd

src.setcwd.main()

INFO: Using pytorch backend


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Working directory:  /home/adele/code/my28brains/src
Directory added to path:  /home/adele/code/my28brains
Directory added to path:  /home/adele/code/my28brains/src
Directory added to path:  /home/adele/code/my28brains/H2_SurfaceMatch
Directory added to path:  /home/adele/code/my28brains/src/notebooks
Directory added to path:  /home/adele/code/my28brains/src/notebooks/csv


In [2]:
# multiple linear regression, as done in project_menstrual/main_2_linear_regression

(
    space,
    y,
    all_hormone_levels,
    true_intercept,
    true_coef,
) = data_utils.load_real_data(default_config)


n_train = int(default_config.train_test_split * len(y))

X_indices = np.arange(len(y))
# Shuffle the array to get random values
random.shuffle(X_indices)
train_indices = X_indices[:n_train]
train_indices = np.sort(train_indices)
test_indices = X_indices[n_train:]
test_indices = np.sort(test_indices)

# TODO: instead, save these values in main_2, and then load them here. or, figure out how to predict the mesh using just the intercept and coef learned here, and then load them.

progesterone_levels = gs.array(all_hormone_levels["Prog"].values)
estrogen_levels = gs.array(all_hormone_levels["Estro"].values)
dheas_levels = gs.array(all_hormone_levels["DHEAS"].values)
lh_levels = gs.array(all_hormone_levels["LH"].values)
fsh_levels = gs.array(all_hormone_levels["FSH"].values)
shbg_levels = gs.array(all_hormone_levels["SHBG"].values)

X_multiple = gs.vstack(
    (
        progesterone_levels,
        estrogen_levels,
        dheas_levels,
        lh_levels,
        fsh_levels,
        shbg_levels,
    )
).T  # NOTE: copilot thinks this should be transposed.

(
    multiple_intercept_hat,
    multiple_coef_hat,
    mr,
) = training.fit_linear_regression(y, X_multiple)

mr_score_array = training.compute_R2(y, X_multiple, test_indices, train_indices)

X_multiple_predict = gs.array(X_multiple.reshape(len(X_multiple), -1))
y_pred_for_mr = mr.predict(X_multiple_predict)
y_pred_for_mr = y_pred_for_mr.reshape([len(X_multiple), len(y[0]), 3])

<module 'project_menstrual.default_config' from '/home/adele/code/my28brains/project_menstrual/default_config.py'>
project_dir: /home/adele/code/my28brains/project_menstrual
Using menstrual mesh data (from reparameterized directory)

e. (Sort) Found 29 .plys for (left, -1) in /home/adele/code/my28brains/project_menstrual/results/1_preprocess/d_reparameterized
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
Day 20 has no data. Skipping.
DayID not to use: 20
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape  (1146, 3)
vertices.shape 

RPly: Unable to open file


In [3]:
# Create widgets for the four hormones
progesterone_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='Progesterone')
FSH_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='FSH')
LH_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='LH')
estrogen_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='Estrogen')
DHEAS_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='DHEAS')
SHBG_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='SHBG')

output = widgets.Output()

# project_dir = default_config.project_dir
# print(project_dir)
# save_path = os.path.join("./temp_plys/widget_mesh")
# print(save_path)

def plot_hormone_levels(change):
    with output:
        clear_output(wait=True)
        # Retrieve the current values of the sliders
        progesterone = gs.array(progesterone_slider.value)
        FSH = gs.array(FSH_slider.value)
        LH = gs.array(LH_slider.value)
        estrogen = gs.array(estrogen_slider.value)
        SHBG = gs.array(SHBG_slider.value)
        DHEAS = gs.array(DHEAS_slider.value)

        # Predict Mesh
        X_multiple = gs.vstack(
            (
                progesterone,
                estrogen,
                DHEAS,
                LH,
                FSH,
                SHBG,
            )
        ).T
        
        X_multiple_predict = gs.array(X_multiple.reshape(len(X_multiple), -1))        

        y_pred_for_mr = mr.predict(X_multiple_predict)
        print(y_pred_for_mr.shape)
        y_pred_for_mr = y_pred_for_mr.reshape([len(y[0]), 3])
        


        # Decimate mesh for plotting
        vertices = y_pred_for_mr
        faces = gs.array(space.faces).numpy()
        n_faces_after_decimation = int(faces.shape[0] / 110)
        vertices, faces = H2_SurfaceMatch.utils.utils.decimate_mesh(
            vertices, faces, n_faces_after_decimation
        )

        print(vertices.shape)

        # Plotting
        mesh = pv.PolyData(vertices, faces)

        mesh.plot(show_edges=True, line_width=5)
            
# Attach the plotting function to the sliders
progesterone_slider.observe(plot_hormone_levels, names='value')
FSH_slider.observe(plot_hormone_levels, names='value')
LH_slider.observe(plot_hormone_levels, names='value')
estrogen_slider.observe(plot_hormone_levels, names='value')
DHEAS_slider.observe(plot_hormone_levels, names='value')
SHBG_slider.observe(plot_hormone_levels, names='value')

# Display the widgets
widgets_display = widgets.VBox([progesterone_slider, FSH_slider, LH_slider, estrogen_slider, DHEAS_slider, SHBG_slider, output])
display(widgets_display)

VBox(children=(FloatSlider(value=0.5, description='Progesterone', max=1.0, step=0.01), FloatSlider(value=0.5, …

In [None]:
# mesh points
vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0.5, 0.5, -1]])

# mesh faces
faces = np.hstack(
    [
        [4, 0, 1, 2, 3],  # square
        [3, 0, 1, 4],  # triangle
        [3, 1, 2, 4],  # triangle
    ]
)

surf = pv.PolyData(vertices, faces)

# plot each face with a different color
surf.plot(
    scalars=np.arange(3),
    cpos=[-1, 1, 0.5],
    show_scalar_bar=False,
    show_edges=True,
    line_width=5,
)