In [None]:
"""Utility classes and functions related to gresit.

Copyright (c) 2025 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

In [1]:
import numpy as np
import torch
from gresit.group_resit import GroupResit
from gresit.independence_tests import HSIC
from gresit.torch_models import Multioutcome_MLP

In [2]:
# Configuration for the synthetic example. Keeping the seed and the sample size
# in one place guarantees that all three groups share the same number of rows.
SEED = 42  # seed of the NumPy generator, fixed for reproducibility
N_SAMPLES = 2000  # number of joint observations drawn for every group

rng = np.random.default_rng(SEED)

# Two source groups: standard-normal samples with 2 and 3 columns respectively.
X_1 = rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=N_SAMPLES)
X_2 = rng.multivariate_normal(mean=np.zeros(3), cov=np.eye(3), size=N_SAMPLES)

# X_3 depends on both X_1 and X_2 through products of their columns ...
X_3_signal = np.column_stack(
    [X_1[:, 0] * X_2[:, 0] + X_1[:, 1] * X_2[:, 1], X_1[:, 1] * X_2[:, 2]]
)
# ... plus additive Gaussian noise scaled by 0.1. The noise is drawn after X_1
# and X_2, so the generator is consumed in the same order as before.
X_3_noise = 0.1 * rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=N_SAMPLES)
X_3 = X_3_signal + X_3_noise

# Every group must describe the same observations, one row per sample.
assert X_1.shape[0] == X_2.shape[0] == X_3.shape[0] == N_SAMPLES

In [3]:
# Collect the three sample matrices under their group names; this mapping is
# what gets passed to GroupResit.learn_graph further down.
data_dict = dict(X_1=X_1, X_2=X_2, X_3=X_3)

In [4]:
# Emit one Markdown table row per group: name, shape, dtype and the first two
# samples rounded to three decimals.
for group_name, samples in data_dict.items():
    preview = np.round(samples[:2], decimals=3).tolist()
    row = f"| `{group_name}` | {samples.shape} | {samples.dtype} | `{preview}` |"
    print(row)

| `X_1` | (2000, 2) | float64 | `[[0.305, -1.04], [0.75, 0.941]]` |
| `X_2` | (2000, 3) | float64 | `[[0.253, 0.895, 0.273], [2.239, 1.43, -0.308]]` |
| `X_3` | (2000, 2) | float64 | `[[-0.836, -0.194], [2.878, -0.318]]` |


In [5]:
# The data above is seeded through NumPy only. Seed torch's global RNG as well
# (same value as the data seed) before the network is built, so that repeated
# runs start from the same state.
# NOTE(review): assumes Multioutcome_MLP draws its initial weights / batching
# from torch's global generator -- confirm in gresit.torch_models.
torch.manual_seed(42)

# Configure GroupResit with an MLP regressor instance and HSIC (passed
# uninstantiated) as the independence test, then fit it on the group samples.
gresit = GroupResit(regressor=Multioutcome_MLP(), test=HSIC)
learned_graph = gresit.learn_graph(data_dict)

 48%|█████████████████████████████                                | 143/300 [00:04<00:05, 29.67it/s]


Early stopping triggered


100%|█████████████████████████████████████████████████████████████| 300/300 [00:06<00:00, 43.25it/s]
 35%|█████████████████████▎                                       | 105/300 [00:02<00:04, 41.60it/s]


Early stopping triggered


 47%|████████████████████████████▋                                | 141/300 [00:02<00:03, 51.60it/s]


Early stopping triggered


 27%|████████████████▋                                             | 81/300 [00:01<00:04, 51.58it/s]


Early stopping triggered

In [8]:
# Build the interactive figure from the fitted GroupResit object; the handle is
# kept because the export cell below writes this same figure to HTML.
fig = gresit.show_interactive()
# Bare trailing expression: lets the notebook render the figure as the cell output.
fig

In [14]:
# Save the figure as an HTML file whose contents can be copy-pasted elsewhere.
# full_html=True requests a complete HTML document; include_plotlyjs="cdn"
# requests that plotly.js be referenced from a CDN instead of being embedded
# (per Plotly's write_html options -- assumes `fig` is a Plotly figure).
html_target = "../docs/imgs/interactive_graph.html"
export_options = {"full_html": True, "include_plotlyjs": "cdn"}
fig.write_html(html_target, **export_options)