In [1]:
import torch
import pandas as pd
import os
import scanpy as sc

# Mouse brain

In [2]:
# COMMOT
# Load the mouse-brain AnnData file; later cells read its
# 'commot-user_database-sum-receiver' entry from `.obsm`.
# NOTE(review): presumably written by an earlier COMMOT run — confirm provenance.
adata=sc.read_h5ad("./COMMOT/mouse.h5ad")

In [3]:
# Inspect the table stored under this obsm key. In the recorded output its
# columns are named like `r-Vip-Vipr2`, plus a final `r-total-total` column.
print(adata.obsm['commot-user_database-sum-receiver'])

      r-Vip-Vipr2  r-Prok2-Prokr2  r-Pdgfc-Pdgfra  r-Penk-Oprk1  r-total-total
0        0.000000             0.0        0.000000      0.000000       0.000000
1        0.000000             0.0        0.000000      0.000000       0.000000
2        0.000000             0.0        0.343711      0.525687       0.869399
3        0.000000             0.0        0.000000      0.000000       0.000000
4        0.000000             0.0        1.437130      0.000000       1.437130
...           ...             ...             ...           ...            ...
6132     0.000000             0.0        0.662891      0.664754       1.327645
6133     0.321654             0.0        0.321632      0.321654       0.964940
6134     0.000000             0.0        0.000000      0.000000       0.000000
6135     0.000000             0.0        0.000000      0.000000       0.000000
6136     0.000000             0.0        0.000000      0.000000       0.000000

[6137 rows x 5 columns]


In [4]:
# Read the processed per-row table for mouse1 / slice201 and report its dimensions.
# NOTE(review): this is an absolute, user-specific cluster path, so the cell is
# not portable — consider a configurable data directory.
cell_states_df=pd.read_csv("/vast/palmer/scratch/wang_zuoheng/xx244/GITIII_backup/data/Mouse_brain/processed1/mouse1_slice201.csv")
print(cell_states_df.shape)

(6137, 425)


In [5]:
# `genes` is used below as a column selector on the CSV table.
# Assumes it is a list of column labels that all exist in `cell_states_df` — TODO confirm.
genes=torch.load("./data/mouse/genes.pth")
# Keep only the selected columns, as a NumPy array; this becomes the regression target.
cell_states=cell_states_df.loc[:,genes].values

In [6]:
# Sanity check: dimensions of the target array after column selection.
print(cell_states.shape)

(6137, 254)


In [7]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

def fit_linear_regression(data_matrix, target_matrix, test_size=0.2, random_state=42):
    """
    Fit a linear regression from data_matrix to target_matrix and report its
    held-out error next to a zero-prediction baseline.

    The rows are split into training and validation subsets, a
    LinearRegression model is fit on the training subset, and the mean squared
    error (MSE) of its predictions on the validation subset is computed.
    As a reference value, the MSE of predicting zero for every entry of the
    full target matrix is also computed. Both values are printed.

    Notes:
    - The baseline is the mean of the squared targets, not the variance: the
      two coincide only when the targets have zero mean.
    - The baseline is computed over all rows, whereas the model MSE is
      computed over the validation rows only, so their difference mixes two
      slightly different row sets.

    Parameters:
    data_matrix (numpy.ndarray): Shape (n, c1), input features.
    target_matrix (numpy.ndarray): Shape (n, c2), target values.
    test_size (float or int): Forwarded to train_test_split; with the default
        of 0.2 the split is 80% training / 20% validation.
    random_state (int or None): Forwarded to train_test_split; the default of
        42 makes the split reproducible.

    Returns:
    tuple: (validation_mse, total_target_mse)
    """
    # Split rows into training and validation subsets (80% / 20% by default)
    X_train, X_val, y_train, y_val = train_test_split(
        data_matrix,
        target_matrix,
        test_size=test_size,
        random_state=random_state,
    )

    # Fit linear regression model on the training rows only
    model = LinearRegression()
    model.fit(X_train, y_train)

    # Predict on validation set
    y_pred = model.predict(X_val)

    # Compute MSE loss on validation set
    validation_mse = mean_squared_error(y_val, y_pred)
    print(f"Validation MSE: {validation_mse:.6f}")

    # Baseline: MSE of an all-zeros prediction over the whole target matrix
    # (mean squared target value; equals the variance only for zero-mean targets)
    total_target_mse = mean_squared_error(target_matrix, np.zeros_like(target_matrix))
    print(f"Total Target MSE: {total_target_mse:.6f}")

    return validation_mse, total_target_mse

In [8]:
# Regress the selected cell-state columns on the COMMOT obsm table.
# fit_linear_regression returns (validation_mse, total_target_mse).
mse_commot=fit_linear_regression(data_matrix=adata.obsm['commot-user_database-sum-receiver'].values, target_matrix=cell_states)
# Printed value = total_target_mse - validation_mse, i.e. an absolute reduction
# in MSE relative to predicting zero; it is not normalised to a fraction.
# The baseline term is computed on all rows, the model term on validation rows only.
print("variance_explained:",mse_commot[1]-mse_commot[0])

Validation MSE: 0.172656
Total Target MSE: 0.175501
variance_explained: 0.00284529540641279


In [9]:
# GITIII
# Load the saved GITIII evaluation output for mouse1 / slice201 onto the CPU;
# later cells read results["y"] and results["y_pred"].
# weights_only=False performs full unpickling, which can execute arbitrary
# code — only load files from a trusted source.
# NOTE(review): absolute, user-specific cluster path — not portable.
results=torch.load("/vast/palmer/scratch/wang_zuoheng/xx244/GITIII_backup/Mouse_brain_evaluate/edges/"+"edges_mouse1_slice201.pth",map_location=torch.device('cpu'), weights_only=False)

In [10]:
# Baseline: mean of the squared targets, i.e. the MSE of predicting zero everywhere.
all_variance = results["y"].square().mean()
# Mean squared residual between the stored targets and the stored predictions.
mse_GITIII = (results["y"] - results["y_pred"]).square().mean()
# Report baseline, model MSE, and the absolute reduction in MSE.
print(all_variance, mse_GITIII, all_variance - mse_GITIII)

tensor(0.1755) tensor(0.1435) tensor(0.0320)


# NSCLC

In [25]:
# COMMOT
# Same pipeline as the mouse-brain section, applied to the NSCLC sample Lung6.
# This cell rebinds adata, cell_states_df, genes, cell_states and mse_commot,
# replacing the values from the previous section, and requires
# fit_linear_regression to have been defined by an earlier cell.
adata=sc.read_h5ad("./COMMOT/NSCLC.h5ad")

# NOTE(review): absolute, user-specific cluster path — not portable.
cell_states_df=pd.read_csv("/vast/palmer/scratch/wang_zuoheng/xx244/GITIII_backup/data/NSCLC/processed1/Lung6.csv")
# Assumes `genes` is a list of column labels present in the CSV — TODO confirm.
genes=torch.load("./data/NSCLC/genes.pth")
cell_states=cell_states_df.loc[:,genes].values

# Returns (validation_mse, total_target_mse); the printed difference is an
# absolute MSE reduction versus predicting zero, not a normalised fraction.
mse_commot=fit_linear_regression(data_matrix=adata.obsm['commot-user_database-sum-receiver'].values, target_matrix=cell_states)
print("variance_explained:",mse_commot[1]-mse_commot[0])

Validation MSE: 0.082926
Total Target MSE: 0.092379
variance_explained: 0.009453542227465847


In [26]:
# GITIII
# Load the saved GITIII evaluation output for Lung6 onto the CPU, rebinding
# `results`; the next cell reads results["y"] and results["y_pred"].
# weights_only=False performs full unpickling, which can execute arbitrary
# code — only load files from a trusted source.
results=torch.load("/vast/palmer/scratch/wang_zuoheng/xx244/GITIII_backup/NSCLC_evaluate/edges/"+"edges_Lung6.pth",map_location=torch.device('cpu'), weights_only=False)

In [27]:
# Baseline: mean of the squared targets, i.e. the MSE of predicting zero everywhere.
all_variance = results["y"].square().mean()
# Mean squared residual between the stored targets and the stored predictions.
mse_GITIII = (results["y"] - results["y_pred"]).square().mean()
# Report the absolute reduction in MSE relative to the zero baseline.
print(all_variance - mse_GITIII)

tensor(0.0156)


# BC

In [31]:
# COMMOT
# Same pipeline as the earlier sections, applied to the BC sample sample1_rep1.
# This cell rebinds adata, cell_states_df, genes, cell_states and mse_commot,
# replacing the values from the previous section, and requires
# fit_linear_regression to have been defined by an earlier cell.
adata=sc.read_h5ad("./COMMOT/BC.h5ad")

# NOTE(review): absolute, user-specific cluster path — not portable.
cell_states_df=pd.read_csv("/vast/palmer/scratch/wang_zuoheng/xx244/GITIII_backup/data/BC/processed1/sample1_rep1.csv")
# Assumes `genes` is a list of column labels present in the CSV — TODO confirm.
genes=torch.load("./data/BC/genes.pth")
cell_states=cell_states_df.loc[:,genes].values

# Returns (validation_mse, total_target_mse); the printed difference is an
# absolute MSE reduction versus predicting zero, not a normalised fraction.
mse_commot=fit_linear_regression(data_matrix=adata.obsm['commot-user_database-sum-receiver'].values, target_matrix=cell_states)
print("variance_explained:",mse_commot[1]-mse_commot[0])

Validation MSE: 0.131581
Total Target MSE: 0.133038
variance_explained: 0.0014569300435102184


In [33]:
# GITIII
# Load the saved GITIII evaluation output for sample1_rep1 onto the CPU,
# rebinding `results`; the next cell reads results["y"] and results["y_pred"].
# weights_only=False performs full unpickling, which can execute arbitrary
# code — only load files from a trusted source.
results=torch.load("/vast/palmer/scratch/wang_zuoheng/xx244/GITIII_backup/BC_evaluate/edges/"+"edges_sample1_rep1.pth",map_location=torch.device('cpu'), weights_only=False)

In [34]:
# Baseline: mean of the squared targets, i.e. the MSE of predicting zero everywhere.
all_variance = results["y"].square().mean()
# Mean squared residual between the stored targets and the stored predictions.
mse_GITIII = (results["y"] - results["y_pred"]).square().mean()
# Report the absolute reduction in MSE relative to the zero baseline.
print(all_variance - mse_GITIII)

tensor(0.0307)
