# **Experiments on synthetic data**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lealagonotte/Geometric_data_analysis_project/blob/raph-colab/experiments/main_synthetic.ipynb)  
See on Hugging Face: https://huggingface.co/datasets/xingjiepan/PerturbMulti

## **Colab setup**

In [1]:
!git clone https://github.com/lealagonotte/Geometric_data_analysis_project.git
%cd Geometric_data_analysis_project/
!git clone https://github.com/raphaelrubrice/Perturb-OT.git
!git checkout raph-colab

Cloning into 'Geometric_data_analysis_project'...
remote: Enumerating objects: 116, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (90/90), done.[K
remote: Total 116 (delta 55), reused 68 (delta 23), pack-reused 0 (from 0)[K
Receiving objects: 100% (116/116), 4.47 MiB | 10.58 MiB/s, done.
Resolving deltas: 100% (55/55), done.
/content/Geometric_data_analysis_project
Cloning into 'Perturb-OT'...
remote: Enumerating objects: 906, done.[K
remote: Counting objects: 100% (55/55), done.[K
remote: Compressing objects: 100% (44/44), done.[K
remote: Total 906 (delta 8), reused 44 (delta 2), pack-reused 851 (from 2)[K
Receiving objects: 100% (906/906), 38.95 MiB | 14.54 MiB/s, done.
Resolving deltas: 100% (118/118), done.
M	Perturb-OT
Branch 'raph-colab' set up to track remote branch 'raph-colab' from 'origin'.
Switched to a new branch 'raph-colab'


In [2]:
!pip install -r requirements-env.txt

Processing ./Perturb-OT/ott (from -r requirements-env.txt (line 4))
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Processing ./Perturb-OT/perturbot (from -r requirements-env.txt (line 5))
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Processing ./Perturb-OT/scvi-tools (from -r requirements-env.txt (line 6))
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting absl-py==2.3.1 (from -r requirements-env.txt (line 8))
  Downloading absl_py-2.3.1-py3-none-any.whl.metadata (3.3 kB)
Collecting anndata==0.10.9 (from -r requirements-env.txt (line 13))
  Downloading anndata-0.10.9-py3-none-any.whl.metadata (6.9 kB)
Collectin

YOU WILL NEED TO RESTART THE SESSION AFTER THE PREVIOUS CELL.
After restarting you should have a working env.

Let's check that we have succesfully recreated the env

In [1]:
import sys

from perturbot.match import (
    get_coupling_cotl,
    get_coupling_cotl_sinkhorn,
    get_coupling_egw_labels_ott,
    get_coupling_egw_all_ott,
    get_coupling_eot_ott,
    get_coupling_leot_ott,
    get_coupling_egw_ott,
    get_coupling_cot,
    get_coupling_cot_sinkhorn,
    get_coupling_gw_labels,
    get_coupling_fot,
)
from perturbot.predict import train_mlp
import ot
print("Imports succeeded")

Imports succeeded


## **Imports**

In [None]:
import torch
import perturbot.match
import perturbot.predict
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split
import seaborn as sns
sns.set_theme("paper")
import pandas as pd
from perturbot.eval.prediction import get_evals, get_evals_preds
from perturbot.eval.match import get_FOSCTTM_single, get_FOSCTTM
import ot
import itertools, time, os
from tqdm.auto import tqdm

## **Generate data according to Appendix G from [Ryu et al., 2024](https://arxiv.org/abs/2405.00838)**

In [3]:
## Generate synthetic data as specified in Annex G

def generate_synthetic_data(n_cells_per_pert=50, n_perturbations=10, d_latent=5, p_X=50, p_Y=200):
    """Generate synthetic data as specified in Annex G """
    print(f"Generation of {n_cells_per_pert * n_perturbations} synthetic cells")
    n_total_cells = n_cells_per_pert * n_perturbations

    # Z ~ N(0, 0.1)
    Z_base = np.random.normal(0.0, np.sqrt(0.1), (n_total_cells, d_latent))

    # A_X, A_Y, b_X, b_Y, s_X, s_Y
    A_X = np.random.normal(0.0, 1.0, (d_latent, p_X))
    A_Y = np.random.normal(0.0, 1.0, (d_latent, p_Y))
    b_X = np.random.normal(0.0, 1.0, (p_X,))
    b_Y = np.random.normal(0.0, 1.0, (p_Y,))
    s_X = np.random.gamma(1.0, 1.0, (p_X,))
    s_Y = np.random.gamma(1.0, 1.0, (p_Y,))

    # zeta_X, zeta_Y (Technical noise)
    std_X = np.sqrt(np.exp(np.random.normal(-3, 0)))
    zeta_X = np.random.normal(0, std_X, (n_total_cells, d_latent))
    std_Y = np.sqrt(np.exp(np.random.normal(-3, 0)))
    zeta_Y = np.random.normal(0, std_Y, (n_total_cells, d_latent))

    # Labels and perturbation
    labels = np.repeat(np.arange(n_perturbations), n_cells_per_pert)
    target_dims = np.zeros(n_perturbations, dtype=int)
    target_dims[1:] = (np.arange(n_perturbations - 1) % d_latent)
    effect_sizes = np.zeros(n_perturbations)
    effect_sizes[1:] = np.maximum(3.0, np.random.gamma(1.0, 1.0, n_perturbations - 1))
    penetrance = np.random.beta(1.0, 10.0, n_total_cells)

    Z_perturbed = Z_base.copy()
    for i in range(n_total_cells):
        label_idx = labels[i]
        if label_idx > 0: # No perturbation for label 0
            target_dim = target_dims[label_idx]
            effect = effect_sizes[label_idx]
            Z_perturbed[i, target_dim] += effect * penetrance[i]

    # Final Generation
    Z_noisy_X = Z_perturbed + zeta_X
    Z_noisy_Y = Z_perturbed + zeta_Y
    X = ((Z_noisy_X @ A_X) + b_X) * s_X
    Y = ((Z_noisy_Y @ A_Y) + b_Y) * s_Y

    # Identifier of the cells
    ids = np.arange(n_total_cells).reshape(-1, 1)

    return X, Y, labels, ids

In [4]:
#We generate the data as presented in Annex G
X_full, Y_full, labels_full, ids_full = generate_synthetic_data(
        n_cells_per_pert=50,
        n_perturbations=10,
        d_latent=5,
        p_X=50,
        p_Y=200
    )

Generation of 500 synthetic cells


In [5]:
def format_data_for_coupling(X, Y, labels):
    """
    Transform the data into dictionaries keyed by integer labels.
    """
    X_dict = {}
    Y_dict = {}
    unique_labels = np.unique(labels)

    for l_numpy in unique_labels:
        l_python = int(l_numpy)

        indices = np.where(labels == l_numpy)[0]

        X_dict[l_python] = X[indices]
        Y_dict[l_python] = Y[indices]

    return (X_dict, Y_dict)

In [6]:
#prepare folds for cross-validation

X = np.array(X_full)
Y = np.array(Y_full)
labels = np.array(labels_full)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

folds = []

for fold_idx, (train_index, test_index) in enumerate(skf.split(X, labels), 1):

    # --- 1) Split brut ---
    X_train_raw, X_test_raw = X[train_index], X[test_index]
    Y_train_raw, Y_test_raw = Y[train_index], Y[test_index]
    lab_train_raw, lab_test_raw = labels[train_index], labels[test_index]

    # Format the data for coupling

    data_train = format_data_for_coupling(X_train_raw, Y_train_raw, lab_train_raw)
    data_test  = format_data_for_coupling(X_test_raw,  Y_test_raw,  lab_test_raw)
    unique_labels_test = sorted(data_test[0].keys())

    #Ground truth test set preparation

    X_test_concat = np.concatenate([data_test[0][l] for l in unique_labels_test], axis=0)
    X_test_t = torch.tensor(X_test_concat, dtype=torch.double)

    Y_test_concat = np.concatenate([data_test[1][l] for l in unique_labels_test], axis=0)

    unique_labels_train = sorted(data_train[0].keys())

    labels_train_concat = np.concatenate([
        np.full(len(data_train[0][l]), l) for l in unique_labels_train
    ])

    # --- 5) Sauvegarde du fold ---
    folds.append({
        "fold": fold_idx,

        "X_train_raw": X_train_raw,
        "X_test_raw":  X_test_raw,
        "Y_train_raw": Y_train_raw,
        "Y_test_raw":  Y_test_raw,
        "lab_train_raw": lab_train_raw,
        "lab_test_raw":  lab_test_raw,

        "data_train": data_train,
        "data_test":  data_test,

        "X_test_torch": X_test_t,
        "Y_test_concat": Y_test_concat,

        "labels_train_concat": labels_train_concat,
    })

    print(f"Fold {fold_idx} created:",
          X_train_raw.shape, X_test_raw.shape)


Fold 1 created: (400, 50) (100, 50)
Fold 2 created: (400, 50) (100, 50)
Fold 3 created: (400, 50) (100, 50)
Fold 4 created: (400, 50) (100, 50)
Fold 5 created: (400, 50) (100, 50)


## **Compute couplings with diverse methods**
(GWOT, COOT, FGWOT, with or without labels, with or without entropic regularization)

In [7]:
#EGWOT labeled
legw = [perturbot.match.get_coupling_egw_labels_ott(
        folds[i]["data_train"],
        eps=0.001
    ) for i in range(len(folds))]

running EGWL with ott
GW called
lse step
updating linearization
Label considered for Sinkhorn run
lse step
updating linearization
Label considered for Sinkhorn run
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.535
Done running LEGWOT with ott
running EGWL with ott
GW called
lse step
updating linearization
Label considered for Sinkhorn run
lse step
updating linearization
Label considered for Sinkhorn run
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.492
Done running LEGWOT with ott
running EGWL with ott
GW called
lse step
updating linearization
Label considered for Sinkhorn run
lse step
updating linearization
Label considered for Sinkhorn run
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop o

In [8]:
#EGWOT without labels
egw = [perturbot.match.get_coupling_egw_all_ott(
        folds[i]["data_train"],
        eps=0.005
    ) for i in range(len(folds))]

running EGWOT with ott
GW called
lse step
updating linearization
lse step
updating linearization
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.014
Done running EGWOT with ott
running EGWOT with ott
GW called
lse step
updating linearization
lse step
updating linearization
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.014
Done running EGWOT with ott
running EGWOT with ott
GW called
lse step
updating linearization
lse step
updating linearization
5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.017
Done running EGWOT with ott
running EGWOT with ott
GW called
lse step
updating linearization
lse step
updating linearizati

In [9]:
#EGWOT per labels
egwper = [perturbot.match.get_coupling_egw_ott(
        folds[i]["data_train"],
        eps=0.050
    ) for i in range(len(folds))]

GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearization
lse step
updating linearization
GW called
lse step
updating linearizat

In [10]:
#COOT labeled
cotl= [perturbot.match.get_coupling_cotl(
        folds[i]["data_train"],
    ) for i in range(len(folds))]

M_0:2.0807757240709446 - 9.979260661462156
M_1:1.9576554364051793 - 10.695716332959297
M_2:1.9554863818289678 - 15.559588021754632
M_3:2.1764447796976265 - 14.311995659199445
M_4:2.136620590661646 - 9.661518660147577
M_5:2.191343394354142 - 8.484049171487667
M_6:2.114891508200784 - 13.851452273958127
M_7:2.1945293772877617 - 11.035159084601723
M_8:2.1087095675881575 - 13.0066007730346
M_9:2.0609096425162714 - 12.281403566229883
M:0.012201868744381105 - 823.2843595750833
It 0 Delta: 0.06999999999999976  Loss: 25.67649195831666
M_0:0.4581908439838984 - 8.110383495445143
M_1:0.5504603209463741 - 8.467301514823104
M_2:0.5131228929233251 - 11.46995449244077
M_3:0.46855901588044757 - 12.274733507813181
M_4:0.48363200042993215 - 9.545162977697828
M_5:0.5243299476177494 - 8.077377835616801
M_6:0.7770172026203339 - 11.938007828074328
M_7:0.5382157399111311 - 10.708639275272043
M_8:0.34614478425679573 - 9.715440865062266
M_9:0.6072757733088914 - 12.458662415708435
M:0.01217299043027878 - 721.970

In [7]:
# Ram will explod

#ECOOT labeled
#ecotl = [perturbot.match.get_coupling_cotl_sinkhorn(
#        folds[i]["data_train"],
#        eps=1e-2
#    ) for i in range(len(folds))]

calculating with eps 0.01
M_0:2.681655854678219 - 10.588568142066041
lse step
M_1:2.6879365420784818 - 11.206690694341459
lse step
M_2:2.551115168056964 - 12.587408323173422
lse step
M_3:3.0134638950953736 - 13.125055993762468
lse step
M_4:3.043866030269688 - 13.903979301580849
lse step
M_5:2.523920643321646 - 11.653478548143733
lse step
M_6:2.913235280682311 - 16.53657575265407
lse step
M_7:2.8232613686742365 - 15.127454092680663
lse step
M_8:2.5162272236145498 - 10.640903265906534
lse step
M_9:2.836039545645148 - 13.457203319610564
lse step
M:8.116556698363331e-05 - 1596.986501753395
lse step
It 0 Delta: 0.013767839217219684  Loss: 42.418918655422
M_0:1.1599790762433138 - 11.134167453337694
lse step
M_1:1.1774822964379092 - 8.648858133099928
lse step
M_2:1.228801757788319 - 12.439884767921182
lse step
M_3:1.379328107177072 - 11.400863323819623
lse step
M_4:1.5504132897141307 - 13.039834572399785
lse step
M_5:1.2646746218477307 - 10.24598963736837
lse step
M_6:1.4503250382529038 - 14.

In [11]:
#COOT without labels
cot=    [perturbot.match.get_coupling_cot(
        folds[i]["data_train"],
    ) for i in range(len(folds))]

Delta: 0.11993746088859492  Loss: 2.49764674096519
Delta: 0.1492831966024853  Loss: 1.6041744980884043
Delta: 0.09426406484516758  Loss: 1.5574699467813164
Delta: 0.08126575420626933  Loss: 1.5405351029550278
Delta: 0.06325021371669094  Loss: 1.5377404485540809
Delta: 0.058722813232689665  Loss: 1.535788727904237
Delta: 0.06306640938985257  Loss: 1.5326769842014576
Delta: 0.050367962909822775  Loss: 1.5318723455075802
Delta: 0.0  Loss: 1.5318723455075802
converged at iter  8
Delta: 0.11993746088859497  Loss: 2.5847412651707047
Delta: 0.15053323475353314  Loss: 1.7517536634045496
Delta: 0.1063680924774783  Loss: 1.6622027771349046
Delta: 0.09443131218761702  Loss: 1.627044565110887
Delta: 0.10116561600250473  Loss: 1.5743897332615346
Delta: 0.09204998997439034  Loss: 1.5508775169748912
Delta: 0.07510705129197362  Loss: 1.5472178073407246
Delta: 0.059850236487159064  Loss: 1.5453849707800489
Delta: 0.04878865319652406  Loss: 1.5444512892344306
Delta: 0.037838821814150075  Loss: 1.5441392

In [12]:
#ECOOT without labels
ecot = [perturbot.match.get_coupling_cot_sinkhorn(
        folds[i]["data_train"],
        eps=0.500
    ) for i in range(len(folds))]

calculating with eps 0.5
lse step
lse step
Delta: 0.00048459246771793274  Loss: 4.698151771105321
lse step
lse step
Delta: 2.0483441858232254e-06  Loss: 4.698133917433545
lse step
lse step
Delta: 1.0717863041520559e-08  Loss: 4.6981340788062065
lse step
lse step
Delta: 2.8053090961321914e-09  Loss: 4.698134074360974
converged at iter  3
calculating with eps 0.5
lse step
lse step
Delta: 0.0004657801196778213  Loss: 4.795991302579692
lse step
lse step
Delta: 2.1942764760751743e-06  Loss: 4.795974513409747
lse step
lse step
Delta: 9.33427113380958e-09  Loss: 4.795974895737295
lse step
lse step
Delta: 8.017729258291695e-10  Loss: 4.795974894832713
converged at iter  3
calculating with eps 0.5
lse step
lse step
Delta: 0.0004918393208294741  Loss: 4.724306080610469
lse step
lse step
Delta: 2.18456852962845e-06  Loss: 4.724287346152258
lse step
lse step
Delta: 9.55880619102345e-09  Loss: 4.724287440646741
converged at iter  2
calculating with eps 0.5
lse step
lse step
Delta: 0.000502281814179

In [25]:
#Fused GW
def get_coupling_fused_gw(X_train, Y_train, labels_train, epsilon=0.001, alpha=0.99, max_iter=500):
    """
    Compute the Fused Gromov-Wasserstein coupling between X_train and Y_train
    using labels_train for the fused cost.
    """
    # -----------------------------
    # 1. Matrices de coût C1 et C2
    # -----------------------------
    # Calcul des matrices de coût quadratiques normalisées
    C1 = ot.dist(X_train, X_train, metric='euclidean')**2
    C2 = ot.dist(Y_train, Y_train, metric='euclidean')**2
    C1=C1/C1.max()
    C2=C2/C2.max()

    # -----------------------------
    # 2. Matrice de coût fused M
    # -----------------------------
    # Ici M[i,j] = 1 si labels différents, 0 sinon
    start = time.time()
    n_train = X_train.shape[0]
    m_train = Y_train.shape[0]
    labels_X = labels_train[:, np.newaxis]  # shape (n_train, 1)
    labels_Y = labels_train[np.newaxis, :]  # shape (1, m_train)

    # M = 1 si labels différents
    M = (labels_X != labels_Y).astype(float)

    print("Calcul du transport plan Fused Gromov-Wasserstein (alpha=1 -> GW uniquement)...")
    T = ot.gromov.entropic_fused_gromov_wasserstein(
        M, C1, C2, alpha=alpha, epsilon=epsilon, max_iter=max_iter, verbose=True
    )
    end= time.time()
    runtime = end - start
    return T, {"time": runtime}
fgw= [get_coupling_fused_gw(folds[i]["X_train_raw"], folds[i]["Y_train_raw"], folds[i]["labels_train_concat"], epsilon=0.001, alpha=0.99, max_iter=500) for i in range(len(folds))]

Calcul du transport plan Fused Gromov-Wasserstein (alpha=1 -> GW uniquement)...




It.  |Err         
-------------------
    0|1.516728e-02|
   10|2.191249e-03|
   20|8.704066e-04|
   30|1.507325e-04|
   40|3.021989e-05|
   50|6.315276e-06|
   60|1.332392e-06|
   70|2.816917e-07|
   80|5.958113e-08|
   90|1.260330e-08|
  100|2.666052e-09|
  110|5.639682e-10|
Calcul du transport plan Fused Gromov-Wasserstein (alpha=1 -> GW uniquement)...
It.  |Err         
-------------------
    0|1.470532e-02|
   10|3.191627e-04|
   20|1.258381e-04|
   30|1.055900e-04|
   40|1.752555e-04|
   50|4.854371e-04|
   60|8.408451e-04|
   70|3.860429e-04|
   80|2.004740e-04|
   90|1.392129e-04|
  100|2.297872e-04|
  110|2.129602e-04|
  120|4.309297e-05|
  130|8.864244e-06|
  140|1.887187e-06|
  150|4.028552e-07|
  160|8.600182e-08|
  170|1.835911e-08|
  180|3.919136e-09|
  190|8.366188e-10|
Calcul du transport plan Fused Gromov-Wasserstein (alpha=1 -> GW uniquement)...
It.  |Err         
-------------------
    0|1.444320e-02|
   10|2.610911e-04|
   20|1.162999e-05|
   30|9.825475e-07|
   

## **Train MLP to evaluate performances on obtained couplings**

In [26]:
#We train the MLP models for each coupling and evaluate their performance on the test set compared with the Ground Truth
folds_results=[]
for k in range(len(folds)):

  couplings = [legw, egw, egwper, cotl, cot, ecot, fgw]
  coupling_names = [
      "EGWOT Labeled (legw)",
      "EGWOT All (egw)",
      "EGWOT Per Label (egwper)",
      "COOT Labeled (cotl)",
      "COOT All (cot)",
      "ECOOT All (ecot)",
      "Fused GW (fgw)"
  ]

  data_train = folds[k]["data_train"]
  Y_test_concatenated = folds[k]["Y_test_concat"]
  X_test_t = folds[k]["X_test_torch"]
  results_pred = []

  print("\n Evaluation of the prediction models")

  for i, coupling in enumerate(couplings):
      coupling_name = coupling_names[i]
      coupling_for_train = coupling[k][0]
      time_for_train = coupling[k][1]
      model, pred_log = perturbot.predict.train_mlp(data_train, coupling_for_train)

      model.eval()
      with torch.no_grad():
          Y_pred = model(X_test_t).cpu().numpy().astype(np.float64)


      Y_test_concatenated = Y_test_concatenated.astype(np.float64)
      metrics_df_pred= get_evals_preds(
          Y_test_concatenated,  # Y_true (Grund Truth)
          [Y_pred],
          pred_labels=[coupling_name],
          full=False

      )

      metrics_dict = metrics_df_pred[coupling_name].to_dict()
      metrics_dict["method"] = coupling_name
      metrics_dict["time"] = time_for_train
      results_pred.append(metrics_dict)


      print(f"[{i+1}/{len(couplings)}] {coupling_name}: "
            f"MSE={metrics_dict.get('MSE', np.nan):.4f}, "
            f"Pearson={metrics_dict.get('Pearson_corr', np.nan):.4f}")


  results_df = pd.DataFrame(results_pred)
  print(f"\nSummary Table of Prediction Metrics for fold{k}")
  print(results_df)
  folds_results.append(results_df)
print(folds_results)


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si


 Evaluation of the prediction models
[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 184/2000:   9%|▉         | 184/2000 [00:08<01:28, 20.56it/s, v_num=70, train_loss_epoch=0.963]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.706. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[1/7] EGWOT Labeled (legw): MSE=1.0271, Pearson=0.8006
[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 122/2000:   6%|▌         | 122/2000 [00:05<01:25, 22.09it/s, v_num=71, train_loss_epoch=1.46]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.018. Signaling Trainer to stop.
[2/7] EGWOT All (egw): MSE=1.5639, Pearson=0.6847


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 123/2000:   6%|▌         | 123/2000 [00:06<01:33, 20.12it/s, v_num=72, train_loss_epoch=1.54]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.883. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[3/7] EGWOT Per Label (egwper): MSE=1.4573, Pearson=0.7109
[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 166/2000:   8%|▊         | 166/2000 [00:07<01:24, 21.63it/s, v_num=73, train_loss_epoch=0.527]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.775. Signaling Trainer to stop.
[4/7] COOT Labeled (cotl): MSE=1.2810, Pearson=0.7534


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 284/2000:  14%|█▍        | 284/2000 [00:13<01:24, 20.29it/s, v_num=74, train_loss_epoch=0.157]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.319. Signaling Trainer to stop.
[5/7] COOT All (cot): MSE=3.3427, Pearson=0.4131


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 93/2000:   5%|▍         | 93/2000 [00:05<01:44, 18.31it/s, v_num=75, train_loss_epoch=1.47]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.013. Signaling Trainer to stop.
[6/7] ECOOT All (ecot): MSE=1.6096, Pearson=0.6722


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 163/2000:   8%|▊         | 163/2000 [00:07<01:30, 20.38it/s, v_num=76, train_loss_epoch=0.726]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.641. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[7/7] Fused GW (fgw): MSE=1.1441, Pearson=0.7748

Summary Table of Prediction Metrics for fold0
   Pearson_corr  Spearman_corr  Pearson_samples  Spearman_samples       MSE  \
0      0.800622       0.776271         0.509766          0.493502  1.027143   
1      0.684694       0.671555         0.079354          0.077061  1.563907   
2      0.710925       0.694591         0.262618          0.252481  1.457340   
3      0.753402       0.729243         0.430479          0.415664  1.281006   
4      0.413148       0.451287        -0.151932         -0.146565  3.342661   
5      0.672182       0.663124        -0.004590         -0.006229  1.609624   
6      0.774815       0.751191         0.446359          0.430379  1.144098   

                     method                                               time  
0      EGWOT Labeled (legw)  {'n_iters_outer': 5, 'converged_inner': True, ...  
1           EGWOT All (egw)  {'n_iters_outer': 5, 'converged_inner': True, ...  
2  EGWOT Per Label (egwper) 

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 147/2000:   7%|▋         | 147/2000 [00:07<01:30, 20.41it/s, v_num=77, train_loss_epoch=0.992]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.959. Signaling Trainer to stop.
[1/7] EGWOT Labeled (legw): MSE=1.0234, Pearson=0.7803


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 156/2000:   8%|▊         | 156/2000 [00:08<01:34, 19.43it/s, v_num=78, train_loss_epoch=1.55]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.149. Signaling Trainer to stop.
[2/7] EGWOT All (egw): MSE=1.4167, Pearson=0.6984


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 118/2000:   6%|▌         | 118/2000 [00:05<01:28, 21.37it/s, v_num=79, train_loss_epoch=1.65]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.096. Signaling Trainer to stop.
[3/7] EGWOT Per Label (egwper): MSE=1.3077, Pearson=0.7263


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 184/2000:   9%|▉         | 184/2000 [00:08<01:28, 20.49it/s, v_num=80, train_loss_epoch=0.479]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.706. Signaling Trainer to stop.
[4/7] COOT Labeled (cotl): MSE=1.2138, Pearson=0.7533


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 229/2000:  11%|█▏        | 229/2000 [00:10<01:24, 20.91it/s, v_num=81, train_loss_epoch=0.193]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.315. Signaling Trainer to stop.
[5/7] COOT All (cot): MSE=2.9421, Pearson=0.4499


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 105/2000:   5%|▌         | 105/2000 [00:04<01:25, 22.22it/s, v_num=82, train_loss_epoch=1.56]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.117. Signaling Trainer to stop.
[6/7] ECOOT All (ecot): MSE=1.4229, Pearson=0.6966


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 222/2000:  11%|█         | 222/2000 [00:11<01:28, 20.10it/s, v_num=83, train_loss_epoch=0.735]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.594. Signaling Trainer to stop.
[7/7] Fused GW (fgw): MSE=1.1773, Pearson=0.7467

Summary Table of Prediction Metrics for fold1
   Pearson_corr  Spearman_corr  Pearson_samples  Spearman_samples       MSE  \
0      0.780305       0.757490         0.454507          0.412807  1.023418   
1      0.698431       0.684333         0.023662          0.025860  1.416657   
2      0.726289       0.705541         0.243902          0.225655  1.307737   
3      0.753270       0.735876         0.440067          0.425613  1.213834   
4      0.449900       0.484402        -0.053044         -0.055300  2.942125   
5      0.696604       0.683544         0.028240          0.026618  1.422915   
6      0.746741       0.728712         0.397436          0.361577  1.177309   

                     method                               

INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 151/2000:   8%|▊         | 151/2000 [00:08<01:41, 18.18it/s, v_num=84, train_loss_epoch=0.994]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.677. Signaling Trainer to stop.
[1/7] EGWOT Labeled (legw): MSE=1.1610, Pearson=0.7738


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 115/2000:   6%|▌         | 115/2000 [00:05<01:31, 20.67it/s, v_num=85, train_loss_epoch=1.57]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.832. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[2/7] EGWOT All (egw): MSE=1.5980, Pearson=0.6742
[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 156/2000:   8%|▊         | 156/2000 [00:08<01:35, 19.26it/s, v_num=86, train_loss_epoch=1.47]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.876. Signaling Trainer to stop.
[3/7] EGWOT Per Label (egwper): MSE=1.3898, Pearson=0.7274


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 238/2000:  12%|█▏        | 238/2000 [00:12<01:29, 19.64it/s, v_num=87, train_loss_epoch=0.401]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.642. Signaling Trainer to stop.
[4/7] COOT Labeled (cotl): MSE=1.6026, Pearson=0.6929


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 221/2000:  11%|█         | 221/2000 [00:11<01:30, 19.63it/s, v_num=88, train_loss_epoch=0.18]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.292. Signaling Trainer to stop.
[5/7] COOT All (cot): MSE=3.3768, Pearson=0.3945


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 104/2000:   5%|▌         | 104/2000 [00:05<01:35, 19.89it/s, v_num=89, train_loss_epoch=1.48]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.153. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[6/7] ECOOT All (ecot): MSE=1.5672, Pearson=0.6811
[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 148/2000:   7%|▋         | 148/2000 [00:07<01:33, 19.81it/s, v_num=90, train_loss_epoch=0.803]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.558. Signaling Trainer to stop.
[7/7] Fused GW (fgw): MSE=1.2714, Pearson=0.7554

Summary Table of Prediction Metrics for fold2
   Pearson_corr  Spearman_corr  Pearson_samples  Spearman_samples       MSE  \
0      0.773815       0.747606         0.434367          0.414779  1.160966   
1      0.674191       0.663938        -0.018014         -0.011572  1.598013   
2      0.727377       0.703116         0.318236          0.309425  1.389794   
3      0.692897       0.675110         0.323306          0.318581  1.602633   
4      0.394530       0.437790        -0.116077         -0.117253  3.376834   
5      0.681096       0.666093        -0.009852         -0.009338  1.567201   
6      0.755375       0.728088         0.395436          0.379626  1.271363   

                     method                               

INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 162/2000:   8%|▊         | 162/2000 [00:08<01:37, 18.85it/s, v_num=91, train_loss_epoch=1.05]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.741. Signaling Trainer to stop.
[1/7] EGWOT Labeled (legw): MSE=1.1471, Pearson=0.7689


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 136/2000:   7%|▋         | 136/2000 [00:06<01:29, 20.73it/s, v_num=92, train_loss_epoch=1.51]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.995. Signaling Trainer to stop.
[2/7] EGWOT All (egw): MSE=1.4975, Pearson=0.6914


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 115/2000:   6%|▌         | 115/2000 [00:06<01:44, 18.05it/s, v_num=93, train_loss_epoch=1.61]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.872. Signaling Trainer to stop.
[3/7] EGWOT Per Label (egwper): MSE=1.3501, Pearson=0.7295


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 173/2000:   9%|▊         | 173/2000 [00:08<01:34, 19.35it/s, v_num=94, train_loss_epoch=0.506]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.816. Signaling Trainer to stop.
[4/7] COOT Labeled (cotl): MSE=1.3419, Pearson=0.7362


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 251/2000:  13%|█▎        | 251/2000 [00:12<01:28, 19.84it/s, v_num=95, train_loss_epoch=0.172]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.283. Signaling Trainer to stop.
[5/7] COOT All (cot): MSE=2.5360, Pearson=0.5439


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 176/2000:   9%|▉         | 176/2000 [00:08<01:28, 20.57it/s, v_num=96, train_loss_epoch=1.61]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.900. Signaling Trainer to stop.
[6/7] ECOOT All (ecot): MSE=1.4878, Pearson=0.6952


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 246/2000:  12%|█▏        | 246/2000 [00:12<01:28, 19.79it/s, v_num=97, train_loss_epoch=0.755]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.604. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[7/7] Fused GW (fgw): MSE=1.2285, Pearson=0.7597

Summary Table of Prediction Metrics for fold3
   Pearson_corr  Spearman_corr  Pearson_samples  Spearman_samples       MSE  \
0      0.768944       0.748035         0.413426          0.384563  1.147146   
1      0.691398       0.679952        -0.003563         -0.008798  1.497462   
2      0.729464       0.711522         0.260704          0.246269  1.350096   
3      0.736159       0.715597         0.382388          0.367790  1.341881   
4      0.543881       0.545694         0.071101          0.062603  2.536006   
5      0.695229       0.687307         0.017760          0.023164  1.487841   
6      0.759720       0.739254         0.406498          0.389093  1.228467   

                     method                                               time  
0      EGWOT Labeled (legw)  {'n_iters_outer': 5, 'converged_inner': True, ...  
1           EGWOT All (egw)  {'n_iters_outer': 5, 'converged_inner': True, ...  
2  EGWOT Per Label (egwper) 

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 136/2000:   7%|▋         | 136/2000 [00:07<01:36, 19.27it/s, v_num=98, train_loss_epoch=0.987]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.630. Signaling Trainer to stop.
[1/7] EGWOT Labeled (legw): MSE=1.0767, Pearson=0.7962


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 132/2000:   7%|▋         | 132/2000 [00:06<01:38, 18.93it/s, v_num=99, train_loss_epoch=1.59]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.013. Signaling Trainer to stop.


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[2/7] EGWOT All (egw): MSE=1.6037, Pearson=0.6817
[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 121/2000:   6%|▌         | 121/2000 [00:05<01:32, 20.21it/s, v_num=100, train_loss_epoch=1.51]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.012. Signaling Trainer to stop.
[3/7] EGWOT Per Label (egwper): MSE=1.4029, Pearson=0.7315


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 189/2000:   9%|▉         | 189/2000 [00:09<01:32, 19.53it/s, v_num=101, train_loss_epoch=0.479]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.728. Signaling Trainer to stop.
[4/7] COOT Labeled (cotl): MSE=1.3765, Pearson=0.7355


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 285/2000:  14%|█▍        | 285/2000 [00:14<01:27, 19.59it/s, v_num=102, train_loss_epoch=0.153]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.301. Signaling Trainer to stop.
[5/7] COOT All (cot): MSE=3.2036, Pearson=0.4307


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 174/2000:   9%|▊         | 174/2000 [00:09<01:35, 19.08it/s, v_num=103, train_loss_epoch=1.52]
Monitored metric val_loss did not improve in the last 45 records. Best score: 1.013. Signaling Trainer to stop.
[6/7] ECOOT All (ecot): MSE=1.5845, Pearson=0.6847


INFO: GPU available: True (cuda), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
INFO: 
  | Name         | Type       | Params
--------------------------------------------
0 | model        | Sequential | 15.5 K
  | other params | n/a        | 200   
--------------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params si

[34mINFO    [0m Running sanity check on val set[33m...[0m                                                                        


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 228/2000:  11%|█▏        | 228/2000 [00:12<01:36, 18.37it/s, v_num=104, train_loss_epoch=0.734]
Monitored metric val_loss did not improve in the last 45 records. Best score: 0.540. Signaling Trainer to stop.
[7/7] Fused GW (fgw): MSE=1.1060, Pearson=0.7885

Summary Table of Prediction Metrics for fold4
   Pearson_corr  Spearman_corr  Pearson_samples  Spearman_samples       MSE  \
0      0.796234       0.774503         0.468296          0.457856  1.076708   
1      0.681674       0.665929        -0.030459         -0.030090  1.603715   
2      0.731520       0.707762         0.305637          0.298811  1.402888   
3      0.735495       0.710584         0.354215          0.349575  1.376545   
4      0.430682       0.456679        -0.084386         -0.080491  3.203643   
5      0.684730       0.673178         0.051322          0.044923  1.584544   
6      0.788456       0.761155         0.466546          0.450340  1.105953   

                     method                              

In [27]:
metrics = ['Pearson_corr', 'Spearman_corr', 'Pearson_samples', 'Spearman_samples', 'MSE']
all_data = pd.concat(folds_results, keys=range(len(folds_results)), names=['fold', 'index'])
mean_df = all_data.groupby('method')[metrics].mean()
std_df = all_data.groupby('method')[metrics].std(ddof=1)
n_folds = len(folds_results)
sem_df = std_df / np.sqrt(n_folds)
summary_df = mean_df.copy()
for col in metrics:
    summary_df[col] = mean_df[col].round(3).astype(str) + ' ± ' + sem_df[col].round(3).astype(str)

summary_df

Unnamed: 0_level_0,Pearson_corr,Spearman_corr,Pearson_samples,Spearman_samples,MSE
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
COOT All (cot),0.446 ± 0.026,0.475 ± 0.019,-0.067 ± 0.038,-0.067 ± 0.036,3.08 ± 0.156
COOT Labeled (cotl),0.734 ± 0.011,0.713 ± 0.011,0.386 ± 0.022,0.375 ± 0.02,1.363 ± 0.066
ECOOT All (ecot),0.686 ± 0.005,0.675 ± 0.005,0.017 ± 0.011,0.016 ± 0.01,1.534 ± 0.035
EGWOT All (egw),0.686 ± 0.004,0.673 ± 0.004,0.01 ± 0.019,0.01 ± 0.019,1.536 ± 0.035
EGWOT Labeled (legw),0.784 ± 0.006,0.761 ± 0.006,0.456 ± 0.016,0.433 ± 0.019,1.087 ± 0.029
EGWOT Per Label (egwper),0.725 ± 0.004,0.705 ± 0.003,0.278 ± 0.014,0.267 ± 0.016,1.382 ± 0.025
Fused GW (fgw),0.765 ± 0.007,0.742 ± 0.006,0.422 ± 0.014,0.402 ± 0.016,1.185 ± 0.029


In [28]:
folds_results_match=[]
for k in range(len(folds)):

    print("Evaluation of coupling metrics on the training data")
    #Metrics evaluation on the training set
    data_train = folds[k]["data_train"]
    Xs_train_dict = data_train[0]
    Xt_train_dict = data_train[1]

    labeled_couplings = [legw[k], cotl[k]]
    labeled_coupling_names = [
        "EGWOT Labeled (legw)",
        "COOT Labeled (cotl)"
    ]
    global_couplings = [egw[k], cot[k], ecot[k], fgw[k]]
    global_coupling_names = [
        "EGWOT Global (egw)",
        "COOT Global (cot)",
        "ECOOT Global (ecot)",
        "Fused GW (fgw)"
    ]

    results_match = []
    #Labeled couplings
    for i, coupling in enumerate(labeled_couplings):
        coupling_name = labeled_coupling_names[i]
        T_dict = coupling[0]

        #Bary FSTTM (Barycentre FOSCTTM: use_barycenter=True)
        foscttm_bary_list, median_foscttm_bary = get_FOSCTTM(
            T_dict=T_dict, Xs_dict=Xs_train_dict, Xt_dict=Xt_train_dict, use_barycenter=True
        )

        all_foscttm_bary = foscttm_bary_list
        mean_foscttm_bary = np.nanmean(all_foscttm_bary)

        results_match.append({
            "method": coupling_name,
            "Bary_FOSCTTM": median_foscttm_bary,
            "Mean_Bary_FOSCTTM": mean_foscttm_bary,
        })

    #Global couplings
    for i, coupling in enumerate(global_couplings):
        coupling_name = global_coupling_names[i]
        T = coupling[0]

        foscttm_bary_list, median_foscttm_bary = get_FOSCTTM_single(
            T=T, Xs_dict=Xs_train_dict, Xt_dict=Xt_train_dict, use_barycenter=True
        )

        all_foscttm_bary = foscttm_bary_list
        mean_foscttm_bary = np.nanmean(all_foscttm_bary)

        results_match.append({
            "method": coupling_name,
            "Bary_FOSCTTM": median_foscttm_bary,

            "Mean_Bary_FOSCTTM": mean_foscttm_bary,

        })

    results_match_df = pd.DataFrame(results_match)
    print("\nSummary Table of Coupling Metrics (Training Set)")
    print(results_match_df)
    folds_results_match.append(results_match_df)
print(folds_results_match)

Evaluation of coupling metrics on the training data

Summary Table of Coupling Metrics (Training Set)
                 method  Bary_FOSCTTM  Mean_Bary_FOSCTTM
0  EGWOT Labeled (legw)      0.244480           0.244480
1   COOT Labeled (cotl)      0.306980           0.306980
2    EGWOT Global (egw)      0.488722           0.489098
3     COOT Global (cot)      0.609023           0.577556
4   ECOOT Global (ecot)      0.500000           0.500069
5        Fused GW (fgw)      0.154135           0.247143
Evaluation of coupling metrics on the training data

Summary Table of Coupling Metrics (Training Set)
                 method  Bary_FOSCTTM  Mean_Bary_FOSCTTM
0  EGWOT Labeled (legw)      0.238108           0.238108
1   COOT Labeled (cotl)      0.280833           0.280833
2    EGWOT Global (egw)      0.496241           0.497632
3     COOT Global (cot)      0.508772           0.519524
4   ECOOT Global (ecot)      0.500000           0.500031
5        Fused GW (fgw)      0.166667           0.24317

In [29]:
metrics = ['Bary_FOSCTTM', 'Mean_Bary_FOSCTTM' ]
all_data = pd.concat(folds_results_match, keys=range(len(folds_results_match)), names=['fold', 'index'])
mean_df = all_data.groupby('method')[metrics].mean()
std_df = all_data.groupby('method')[metrics].std(ddof=1)
n_folds = len(folds_results_match)
sem_df = std_df / np.sqrt(n_folds)
summary_df = mean_df.copy()
for col in metrics:
    summary_df[col] = mean_df[col].round(3).astype(str) + ' ± ' + sem_df[col].round(3).astype(str)

summary_df

Unnamed: 0_level_0,Bary_FOSCTTM,Mean_Bary_FOSCTTM
method,Unnamed: 1_level_1,Unnamed: 2_level_1
COOT Global (cot),0.536 ± 0.039,0.524 ± 0.028
COOT Labeled (cotl),0.294 ± 0.005,0.294 ± 0.005
ECOOT Global (ecot),0.5 ± 0.0,0.5 ± 0.0
EGWOT Global (egw),0.498 ± 0.003,0.498 ± 0.002
EGWOT Labeled (legw),0.248 ± 0.005,0.248 ± 0.005
Fused GW (fgw),0.168 ± 0.005,0.249 ± 0.005


## **Computation times**

In [30]:
# Time information for the couplings
for k in range(len(folds)):
    results_df=folds_results[k]

    for index, log_dict in results_df['time'].items():
        if isinstance(log_dict, dict):
            print(f"Ligne {index} - Clés : {list(log_dict.keys())}")
        else:
            print(f"Ligne {index} - Ce n'est pas un dictionnaire : {type(log_dict)}")

    def extract_total_time(log_dict):
        if 'time' in log_dict:
            return log_dict['time']
        total_time = 0
        found_time = False
        for key, val in log_dict.items():
            if isinstance(val, dict) and 'time' in val:
                total_time += val['time']
                found_time = True

        if found_time:
            return total_time

        return None

    results_df['extracted_time'] = results_df['time'].apply(extract_total_time)

    print(results_df[['method', 'extracted_time']])
    folds_results[k]=results_df

Ligne 0 - Clés : ['n_iters_outer', 'converged_inner', 'converged_outer', 'GW cost', 'time', 'cost_time']
Ligne 1 - Clés : ['n_iters_outer', 'converged_inner', 'converged_outer', 'GW cost', 'time', 'cost_time']
Ligne 2 - Clés : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Ligne 3 - Clés : ['cost', 'time']
Ligne 4 - Clés : ['cost', 'time']
Ligne 5 - Clés : ['cost', 'time']
Ligne 6 - Clés : ['time']
                     method  extracted_time
0      EGWOT Labeled (legw)       11.299049
1           EGWOT All (egw)        2.930561
2  EGWOT Per Label (egwper)       26.871022
3       COOT Labeled (cotl)        0.069252
4            COOT All (cot)        1.505712
5          ECOOT All (ecot)       18.682953
6            Fused GW (fgw)       23.932863
Ligne 0 - Clés : ['n_iters_outer', 'converged_inner', 'converged_outer', 'GW cost', 'time', 'cost_time']
Ligne 1 - Clés : ['n_iters_outer', 'converged_inner', 'converged_outer', 'GW cost', 'time', 'cost_time']
Ligne 2 - Clés : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Lign

In [31]:
metrics = ['Pearson_corr', 'Spearman_corr', 'Pearson_samples', 'Spearman_samples', 'MSE', 'extracted_time']
all_data = pd.concat(folds_results, keys=range(len(folds_results)), names=['fold', 'index'])
mean_df = all_data.groupby('method')[metrics].mean()
std_df = all_data.groupby('method')[metrics].std(ddof=1)
n_folds = len(folds_results)
sem_df = std_df / np.sqrt(n_folds)
summary_df = mean_df.copy()
for col in metrics:
    summary_df[col] = mean_df[col].round(3).astype(str) + ' ± ' + sem_df[col].round(3).astype(str)

summary_df

Unnamed: 0_level_0,Pearson_corr,Spearman_corr,Pearson_samples,Spearman_samples,MSE,extracted_time
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
COOT All (cot),0.446 ± 0.026,0.475 ± 0.019,-0.067 ± 0.038,-0.067 ± 0.036,3.08 ± 0.156,0.957 ± 0.232
COOT Labeled (cotl),0.734 ± 0.011,0.713 ± 0.011,0.386 ± 0.022,0.375 ± 0.02,1.363 ± 0.066,0.087 ± 0.01
ECOOT All (ecot),0.686 ± 0.005,0.675 ± 0.005,0.017 ± 0.011,0.016 ± 0.01,1.534 ± 0.035,14.235 ± 1.376
EGWOT All (egw),0.686 ± 0.004,0.673 ± 0.004,0.01 ± 0.019,0.01 ± 0.019,1.536 ± 0.035,2.454 ± 0.162
EGWOT Labeled (legw),0.784 ± 0.006,0.761 ± 0.006,0.456 ± 0.016,0.433 ± 0.019,1.087 ± 0.029,5.054 ± 1.562
EGWOT Per Label (egwper),0.725 ± 0.004,0.705 ± 0.003,0.278 ± 0.014,0.267 ± 0.016,1.382 ± 0.025,24.008 ± 0.959
Fused GW (fgw),0.765 ± 0.007,0.742 ± 0.006,0.422 ± 0.014,0.402 ± 0.016,1.185 ± 0.029,31.417 ± 9.82


## **Coupling matrices visualization**

In [None]:
def _safe_name(name: str) -> str:
    """Make a filesystem-friendly name from a method string."""
    return (
        name.lower()
            .replace(" ", "_")
            .replace("(", "")
            .replace(")", "")
            .replace("/", "_")
    )

def plot_labeled_coupling_all_folds(coupling_list, method_name, label_to_plot, max_display=500, save_dir="."):
    """
    For a labeled method (coupling_list[fold] = (T_dict, info)):
    create one figure with subplots for all folds for a given label.
    """
    n_folds = len(coupling_list)

    fig, axes = plt.subplots(
        1, n_folds,
        figsize=(4 * n_folds, 4),
        squeeze=False
    )
    axes = axes[0]  # row of axes

    for fold in range(n_folds):
        ax = axes[fold]
        T_dict = coupling_list[fold][0]

        if label_to_plot not in T_dict:
            ax.set_visible(False)
            continue

        P = T_dict[label_to_plot]
        if hasattr(P, "cpu"):
            P = P.cpu().numpy()

        # optionally crop for display if huge
        ds = min(P.shape[0], max_display)
        dt = min(P.shape[1], max_display)
        P_disp = P[:ds, :dt]

        sns.heatmap(
            P_disp,
            cmap="viridis",
            cbar=(fold == n_folds - 1),  # only last subplot has colorbar
            xticklabels=False,
            yticklabels=False,
            square=True,
            ax=ax
        )
        ax.set_title(f"Fold {fold}\n{ds}×{dt}", fontsize=10)

    fig.suptitle(f"{method_name} - Label {label_to_plot}", fontsize=14)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    os.makedirs(save_dir, exist_ok=True)
    fname = f"{_safe_name(method_name)}_label_{label_to_plot}.png"
    path = os.path.join(save_dir, fname)
    plt.savefig(path, dpi=150)
    plt.show()

    print(f"Saved: {path}")

def plot_global_coupling_all_folds(coupling_list, method_name, max_display=500, save_dir="."):
    """
    For a global/unlabeled method (coupling_list[fold] = (T, info)):
    create one figure with subplots for all folds.
    """
    n_folds = len(coupling_list)

    fig, axes = plt.subplots(
        1, n_folds,
        figsize=(4 * n_folds, 4),
        squeeze=False
    )
    axes = axes[0]

    for fold in range(n_folds):
        ax = axes[fold]
        T = coupling_list[fold][0]
        if hasattr(T, "cpu"):
            T = T.cpu().numpy()

        ds = min(T.shape[0], max_display)
        dt = min(T.shape[1], max_display)
        P_disp = T[:ds, :dt]

        sns.heatmap(
            P_disp,
            cmap="viridis",
            cbar=(fold == n_folds - 1),
            xticklabels=False,
            yticklabels=False,
            square=True,
            ax=ax
        )
        ax.set_title(f"Fold {fold}\n{ds}×{dt}", fontsize=10)

    fig.suptitle(f"{method_name} - All Folds", fontsize=14)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    os.makedirs(save_dir, exist_ok=True)
    fname = f"{_safe_name(method_name)}.png"
    path = os.path.join(save_dir, fname)
    plt.savefig(path, dpi=150)
    plt.show()

    print(f"Saved: {path}")


In [None]:
print("\n=== VISUALIZATION OF COUPLING MATRICES - ALL FOLDS IN SUBPLOTS ===\n")

save_dir = "./synthetic_coupling_plots"

# ---- Labeled methods ----
labeled_coupling_lists = [legw, cotl, ecotl]
labeled_coupling_names = [
    "EGWOT Labeled (legw)",
    "COOT Labeled (cotl)",
    "ECOOT Labeled (ecotl)",
]

for c_list, name in zip(labeled_coupling_lists, labeled_coupling_names):
    for label in np.unique(labels):
        print(f"\n{name} — Label {label}: plotting all folds...")
        plot_labeled_coupling_all_folds(
            coupling_list=c_list,
            method_name=name,
            label_to_plot=label,
            max_display=500,
            save_dir=save_dir,
        )

# ---- Global / full-matrix methods ----
global_coupling_lists = [egw, egwper, cot, ecot, fgw]
global_coupling_names = [
    "EGWOT Global (egw)",
    "EGWOT Per Label (egwper)",
    "COOT Global (cot)",
    "ECOOT Global (ecot)",
    "Fused GW (fgw)",
]

for c_list, name in zip(global_coupling_lists, global_coupling_names):
    print(f"\n{name}: plotting all folds...")
    plot_global_coupling_all_folds(
        coupling_list=c_list,
        method_name=name,
        max_display=500,
        save_dir=save_dir,
    )

## **Grid search on regularization coefficient**

In [None]:
# # GRID SEARCH SUR EPSILON POUR LES MÉTHODES ENTROPIC
# epsilon_values = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 0.5, 1.0]

# entropic_methods = [
#     ("EGWOT Labeled (legw)", perturbot.match.get_coupling_egw_labels_ott),
#     ("EGWOT All (egw)",      perturbot.match.get_coupling_egw_all_ott),
#     ("EGWOT Per Label (egwper)", perturbot.match.get_coupling_egw_ott),
#     ("ECOOT Labeled (ecotl)", perturbot.match.get_coupling_cotl_sinkhorn),
#     ("ECOOT All (ecot)",     perturbot.match.get_coupling_cot_sinkhorn),
#     ("Fused GW (fgw)",       get_coupling_fused_gw)
# ]

# results_grid = []


# def extract_time_robust(log_dict):
#     if not isinstance(log_dict, dict): return np.nan
#     if 'time' in log_dict: return log_dict['time']
#     first_key = list(log_dict.keys())[0]
#     if isinstance(log_dict[first_key], dict) and 'time' in log_dict[first_key]:
#         total_time = sum([v['time'] for k, v in log_dict.items() if isinstance(v, dict) and 'time' in v])
#         return total_time
#     return np.nan

# print(f"--- Begining of Grid Search on {len(epsilon_values)} epsilon values ---\n")

# for eps in epsilon_values:
#     print(f"\n>>> Testing Epsilon = {eps}")
    

#     for method_name, method_func in entropic_methods:
        
#         print(f"   Running {method_name}...", end=" ")
        
#         try:
#             coupling_res = method_func(data_train, eps=eps)
            
#             coupling_mat = coupling_res[0]
#             log_dict = coupling_res[1]     
            

#             model, pred_log = perturbot.predict.train_mlp(data_train, coupling_mat)
            
#             model.eval()
#             with torch.no_grad():
#                 Y_pred = model(test_adata_PROT.X).cpu().numpy().astype(np.float64)
            

#             Y_test_concatenated = Y_test_concatenated.astype(np.float64)
#             metrics_df_pred = get_evals_preds(
#                 Y_test_concatenated, 
#                 [Y_pred],              
#                 pred_labels=[method_name], 
#                 full=False
#             )
            
#             metrics_dict = metrics_df_pred[method_name].to_dict()
#             metrics_dict["method"] = method_name
#             metrics_dict["epsilon"] = eps
#             metrics_dict["time"] = extract_time_robust(log_dict)
            
#             results_grid.append(metrics_dict)
            
#             print(f"Done. (MSE={metrics_dict.get('MSE', np.nan):.4f}, Time={metrics_dict.get('time', np.nan):.2f}s)")
            
#         except Exception as e:
#             print(f"FAILED. Error: {e}")

#             results_grid.append({
#                 "method": method_name,
#                 "epsilon": eps,
#                 "MSE": np.nan,
#                 "error": str(e)
#             })


# df_grid_results = pd.DataFrame(results_grid)


# df_grid_results = df_grid_results.sort_values(by=["method", "epsilon"])

# print("\n--- Summary Table of the Grid Search ---")

# cols_to_show = ['method', 'epsilon', 'MSE', 'Pearson_corr', 'time']
# available_cols = [c for c in cols_to_show if c in df_grid_results.columns]
# print(df_grid_results[available_cols])

In [None]:
# # On trie d'abord par 'method' (pour regrouper), puis par 'MSE' croissant (le plus petit en premier)
# df_sorted = df_grid_results.sort_values(by=['method', 'MSE'], ascending=[True, True])

# print("\n--- Résultats triés par Modèle (Meilleur MSE en premier) ---")
# print(df_sorted[available_cols])