In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, explained_variance_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import MinMaxScaler
from itertools import product, combinations
from scipy.sparse import coo_matrix
from scipy.stats import spearmanr, pearsonr
import networkx as nx
from utils_plot import *

In [None]:
# 1. Data Normalization
def preprocessing(alpha1, alpha2, N, ampl_threshold=0.2):
    """
    Align the two replicates with the binding-site matrix, filter, and centre.

    Steps performed, in order:
      1. keep the genes present in the index of `N`, `alpha1` and `alpha2` (sorted);
      2. keep the genes whose half peak-to-peak amplitude exceeds
         `ampl_threshold` in BOTH replicates;
      3. drop the columns of `N` whose sum over the kept genes is zero;
      4. double-centre each replicate (subtract row means and column means,
         add back the grand mean). `N` itself is not centred.

    Parameters
    ----------
    alpha1, alpha2 : pandas.DataFrame
        Replicate matrices indexed by gene name.
    N : pandas.DataFrame
        Matrix indexed by gene name; its column labels are the TF names.
    ampl_threshold : float
        Minimum amplitude required in both replicates.

    Returns
    -------
    alpha1_norm, alpha2_norm : numpy.ndarray
        Double-centred replicates restricted to the kept genes.
    N : numpy.ndarray
        Filtered binding-site matrix (kept genes x active TFs).
    targetnames_filtered : numpy.ndarray
        Names of the kept genes.
    tf_names_filtered : numpy.ndarray
        Column labels of `N` that survived the inactive-TF filter.
    """
    # TF labels come from N itself rather than from a notebook-level `tf_names`
    # global, which may be undefined or already filtered by a previous run.
    tf_labels = N.columns.to_numpy()

    targetnames = np.array(sorted(set(N.index) & set(alpha1.index) & set(alpha2.index)))
    print("Genes in common :", len(targetnames))

    alpha1, alpha2 = alpha1.loc[targetnames].to_numpy(), alpha2.loc[targetnames].to_numpy()

    # Half peak-to-peak amplitude of every gene, per replicate
    ampl1 = (alpha1.max(axis=1)-alpha1.min(axis=1))/2
    ampl2 = (alpha2.max(axis=1)-alpha2.min(axis=1))/2
    ind = (ampl1 > ampl_threshold) & (ampl2 > ampl_threshold)
    alpha1, alpha2 = alpha1[ind,:], alpha2[ind,:]
    targetnames_filtered = targetnames[ind]

    N = N.loc[targetnames_filtered].to_numpy()

    # Identify TFs that are not present in any gene
    inactive_tfs = np.where(N.sum(axis=0) == 0)[0]
    print(f"Number of inactive TFs: {len(inactive_tfs)}")
    N = np.delete(N, inactive_tfs, axis=1)
    tf_names_filtered = np.delete(tf_labels, inactive_tfs)

    print(f"Kept genes: {N.shape[0]} (ampl > {ampl_threshold})")
    alpha1_norm = alpha1 - np.mean(alpha1, axis=1, keepdims=True) - np.mean(alpha1, axis=0, keepdims=True) + np.mean(alpha1)
    alpha2_norm = alpha2 - np.mean(alpha2, axis=1, keepdims=True) - np.mean(alpha2, axis=0, keepdims=True) + np.mean(alpha2)
    #N_norm = N - np.mean(N, axis=0, keepdims=True) #We will optimize the sparse matrix, so we need to keep the absolute zero values.

    return alpha1_norm, alpha2_norm, N, targetnames_filtered, tf_names_filtered

In [None]:
#Model 1
def svd_regression_with_lambda_CV(alpha, N, lambdas, n_splits=5, seed=42):
    """
    Regress alpha on N through one SVD of N and choose the ridge penalty by k-fold CV.

    The rows of N (genes) are split into `n_splits` shuffled folds. For each
    fold and each candidate lambda, the activity matrix is estimated from the
    training rows using the shrinkage factors s / (s**2 + lambda) and scored on
    the held-out rows. The lambda with the lowest fold-averaged validation MSE
    is then used to refit on all rows. The CV curves (validation MSE,
    validation and training explained variance) are plotted as a side effect.

    Parameters
    ----------
    alpha : array, one row per row of N
    N : array of shape (rows, columns)
    lambdas : sequence of candidate regularisation strengths
    n_splits : number of CV folds
    seed : random_state passed to KFold

    Returns
    -------
    A_star : array of shape (N.shape[1], alpha.shape[1]), fitted on all rows
    lambda_opt : the selected entry of `lambdas`
    """
    n_rows = N.shape[0]
    n_lambdas = len(lambdas)

    # One decomposition of the full matrix, reused by every fold and every lambda
    U, s, VT = np.linalg.svd(N, full_matrices=False)
    V = VT.T

    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    # Rows index the lambda grid, columns index the folds
    mse_val = np.zeros((n_lambdas, n_splits))
    ev_val = np.zeros((n_lambdas, n_splits))
    ev_train = np.zeros((n_lambdas, n_splits))

    fold = 0
    for rows_train, rows_val in splitter.split(np.arange(n_rows)):
        alpha_tr, alpha_va = alpha[rows_train, :], alpha[rows_val, :]
        N_tr, N_va = N[rows_train, :], N[rows_val, :]
        # Training targets projected on the left singular vectors (training rows only)
        projected = U[rows_train, :].T @ alpha_tr

        for idx in range(n_lambdas):
            shrink = s / (s**2 + lambdas[idx])
            activities = V @ (shrink[:, None] * projected)

            pred_tr = N_tr @ activities
            pred_va = N_va @ activities

            mse_val[idx, fold] = mean_squared_error(alpha_va.T, pred_va.T)
            ev_val[idx, fold] = explained_variance_score(alpha_va, pred_va)
            ev_train[idx, fold] = explained_variance_score(alpha_tr, pred_tr)
        fold += 1

    # Fold averages for every lambda
    mean_val_errors = mse_val.mean(axis=1)
    mean_val_explained_variances = ev_val.mean(axis=1)
    mean_train_explained_variances = ev_train.mean(axis=1)

    # Lowest mean validation MSE wins
    lambda_opt = lambdas[np.argmin(mean_val_errors)]

    # Refit on every row with the selected lambda
    shrink_opt = s / (s**2 + lambda_opt)
    A_star = V @ (shrink_opt[:, None] * (U.T @ alpha))

    # CV curves: MSE on the left axis, explained variances on the right axis
    fig, ax1 = plt.subplots(figsize=(4,3))
    mse_color = 'tab:blue'
    ax1.set_xlabel('Lambda')
    ax1.set_ylabel('Validation MSE', color=mse_color)
    ax1.plot(lambdas, mean_val_errors, color=mse_color, label='Validation MSE')
    ax1.tick_params(axis='y', labelcolor=mse_color)
    ax1.set_xscale('log')
    ax1.grid(True)

    ax2 = ax1.twinx()
    ax2.set_ylabel('Explained Variance (%)')
    ax2.plot(lambdas, mean_val_explained_variances * 100, color='tab:green', linestyle='--', label='Validation EV')
    ax2.plot(lambdas, mean_train_explained_variances * 100, color='tab:orange', linestyle='--', label='Training EV')
    ax2.tick_params(axis='y')

    # Single legend gathering the lines of both axes
    handles_left, labels_left = ax1.get_legend_handles_labels()
    handles_right, labels_right = ax2.get_legend_handles_labels()
    ax2.legend(handles_left + handles_right, labels_left + labels_right, loc='best')

    fig.suptitle('Cross-Validation: MSE and EV vs Lambda')
    fig.tight_layout()
    plt.show()

    return A_star, lambda_opt

In [None]:
# Model 2: OscilloTF
# 2 Define Ridge Regression Model with Trainable Sparse W
class TrainableModel(nn.Module):
    """
    Bilinear model alpha ~ exp(W) @ A with a trainable, fixed-sparsity weight matrix.

    The non-zero entries of `N` define the sparsity pattern. Their logarithms
    initialise the trainable vector `W`, so the effective weights exp(W) stay
    strictly positive during optimisation. `A` is initialised from `A_split`.

    Parameters
    ----------
    N : array-like accepted by scipy.sparse.coo_matrix, shape (genes, TFs)
        Every non-zero entry must be strictly positive (its log is taken).
    alpha, num_tfs, num_thetas :
        Accepted for interface compatibility; not used by the constructor.
    A_split : array-like, initial value of A (copied into a float32 parameter)
    lambdaW : weight of the L1 penalty on W (the log-weights)
    lambdaA : weight of the L2 penalty on A

    Raises
    ------
    ValueError
        If a non-zero entry of `N` is negative or NaN.
    """

    def __init__(self, N, alpha, A_split, num_tfs, num_thetas, lambdaW=0.01, lambdaA=0.01):
        super().__init__()

        self.lambdaW = lambdaW  # L1 regularization for W
        self.lambdaA = lambdaA  # L2 regularization for A

        # Convert N to COO format to get the coordinates and values of its non-zero entries
        sparse_matrix = coo_matrix(N)
        values = sparse_matrix.data

        # log() of a negative/NaN entry would silently turn W into NaN; fail early instead
        if not np.all(values > 0):
            raise ValueError("All non-zero entries of N must be strictly positive (W = log(N)).")

        # Index tensors are registered as non-persistent buffers: they follow
        # `.to(device)` with the module but leave state_dict() unchanged.
        self.register_buffer("i", torch.tensor(sparse_matrix.row, dtype=torch.long), persistent=False)
        self.register_buffer("j", torch.tensor(sparse_matrix.col, dtype=torch.long), persistent=False)

        # A is dense; W holds one trainable value per non-zero element of N
        self.A = nn.Parameter(torch.tensor(A_split, dtype=torch.float32))
        self.W = nn.Parameter(torch.tensor(np.log(values), dtype=torch.float32)) #Transform for exponential optimisation N = exp(W) <=> W = log(N)

        self.num_genes, self.num_tfs = N.shape

    def forward(self):
        """Return the reconstruction exp(W) @ A, with exp(W) scattered to a dense (genes, TFs) matrix."""
        N_sparse = torch.sparse_coo_tensor(
            indices=torch.stack([self.i, self.j]),
            values=torch.exp(self.W),
            size=(self.num_genes, self.num_tfs)
        )
        N_dense_tensor = N_sparse.to_dense()

        # Compute the reconstructed alpha matrix.
        return torch.matmul(N_dense_tensor, self.A)

    def loss(self, alpha_true):
        """Sum of squared errors plus lambdaW * sum(|W|) plus lambdaA * sum(A**2)."""
        alpha_pred = self.forward()
        main_loss = torch.sum((alpha_true - alpha_pred) ** 2)
        l1_loss = torch.sum(torch.abs(self.W))  # L1 on W (log-weights)
        l2_loss = torch.sum(self.A ** 2)         # L2 on A

        #smoothness_loss = 10*torch.sum((self.x[:, 1:] - self.x[:, :-1])**2) #Avoid spikes in A
        #cyclic_loss = torch.sum((self.x[:, 0] - self.x[:, -1])**2) #Make activities more cyclic
        #slope_loss = torch.sum((self.x[:, 1] - self.x[:, 0] - (self.x[:, -1] - self.x[:, -2]))**2) #Make slopes matc

        total_loss = main_loss + self.lambdaW * l1_loss + self.lambdaA * l2_loss
        return total_loss

In [None]:
def train_model(N, alpha, alpha_test, lambdaW, lambdaA, A_split, tol=(0.04, 150), patience=20, num_epochs=400, lr=0.01):
    """
    Fit a TrainableModel on `alpha` with Adam while monitoring it on `alpha_test`.

    Every epoch performs one full-batch optimisation step, then rebuilds the
    dense weight matrix exp(W) and evaluates the regularised loss and the
    explained variance on `alpha_test`. An epoch counts as an improvement only
    if the test loss beats the best recorded value by BOTH the relative margin
    tol[0] and the absolute margin tol[1]. Training stops after `patience`
    consecutive epochs without improvement, or after `num_epochs` epochs.
    The loss / EV curves are plotted as a side effect.

    Parameters
    ----------
    N : (genes, TFs) matrix giving the sparsity pattern and initial weights
    alpha, alpha_test : torch tensors used for training and for monitoring
    lambdaW, lambdaA : regularisation strengths forwarded to TrainableModel
    A_split : initial value of the activity matrix A
    tol : (relative, absolute) improvement margins for early stopping
    patience : number of consecutive non-improving epochs tolerated
    num_epochs : maximum number of epochs (must be >= 1)
    lr : Adam learning rate

    Returns
    -------
    (W_dense, A, train_loss) : two NumPy arrays and a float, all taken from the
    LAST epoch that was run (not from the best epoch). `train_loss` is the
    training loss evaluated just before that epoch's optimiser step.

    Raises
    ------
    ValueError
        If `num_epochs` is smaller than 1 (there would be nothing to return).
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    num_genes, num_tfs = N.shape
    num_thetas = alpha.shape[1]

    model = TrainableModel(N, alpha, A_split, num_tfs, num_thetas, lambdaW, lambdaA)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_EV = -float("inf")
    best_test_loss = float("inf")
    patience_counter = 0

    train_loss_history = []
    test_loss_history = []
    test_EV_history = []

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss = model.loss(alpha)
        loss.backward()
        optimizer.step()

        # Validation step on detached copies of the freshly updated parameters
        W_sparse_vector = model.W.detach()
        W_dense = torch.sparse_coo_tensor(
            indices=torch.stack([model.i, model.j]),
            values=torch.exp(W_sparse_vector),
            size=(num_genes, num_tfs)
        ).to_dense()
        A = model.A.detach()

        R_test = torch.matmul(W_dense, A)

        # sklearn works on NumPy arrays
        EV_test = explained_variance_score(alpha_test.numpy(), R_test.detach().numpy())

        # Same objective as TrainableModel.loss, evaluated on the test matrix.
        # Converted to a Python float so histories and comparisons hold plain
        # numbers, like train_loss_history, instead of 0-dim tensors.
        main_loss = torch.sum((alpha_test - R_test) ** 2)
        l1_loss = torch.sum(torch.abs(W_sparse_vector))  # L1 on W
        l2_loss = torch.sum(A ** 2)                      # L2 on the unconstrained A
        total_test_loss = (main_loss + lambdaW * l1_loss + lambdaA * l2_loss).item()

        train_loss_history.append(loss.item())
        test_loss_history.append(total_test_loss)
        test_EV_history.append(EV_test)

        # Early stopping: both the relative and the absolute margin must be met
        is_relative_loss_better = total_test_loss < best_test_loss * (1 - tol[0])
        is_absolute_loss_better = total_test_loss < best_test_loss - tol[1]
        if is_relative_loss_better and is_absolute_loss_better:
            best_EV = EV_test
            best_test_loss = total_test_loss
            patience_counter = 0  # Reset patience
        else:
            patience_counter += 1

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.0f}, Loss test: {total_test_loss:.0f}, EV_test: {EV_test*100:.2f}%")

        if patience_counter >= patience:
            # Report which criterion failed on the epoch that triggered the stop
            if not is_relative_loss_better:
                print(f"Early stopping (relative={tol[0]}) at epoch {epoch+1}. Best Loss test: {best_test_loss:.0f}. Best EV_test: {best_EV*100:.2f}%")
            if not is_absolute_loss_better:
                print(f"Early stopping (absolute={tol[1]}) at epoch {epoch+1}. Best Loss test: {best_test_loss:.0f}. Best EV_test: {best_EV*100:.2f}%")
            break

    # One point per epoch actually run
    epochs_run = range(len(train_loss_history))

    fig, ax1 = plt.subplots(figsize=(4, 3))
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss value')
    ax1.plot(epochs_run, train_loss_history, color='tab:red', label="Train Loss")
    ax1.plot(epochs_run, test_loss_history, color='tab:green', label="Test Loss")
    ax1.tick_params(axis='y')
    ax1.grid(True)
    ax1.legend(loc='center right')

    ax2 = ax1.twinx()
    ax2.set_ylabel('EV Test', color='tab:blue')
    ax2.plot(epochs_run, test_EV_history, color='tab:blue', linestyle='--')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    fig.suptitle('Loss and EV Over Epochs')
    fig.tight_layout()
    plt.show()

    return W_dense.numpy(), A.numpy(), loss.item()#, total_test_loss

In [None]:
# 4. Cross-Validation for Lambda Optimization
def cross_val_lambda(N, alpha1, alpha2, lambdaW_values, lambdaA_values, A_split, A_split_2, tol=(0.04, 150), patience=20):
    """
    Grid-search (lambdaW, lambdaA) by two-way cross-testing between the replicates.

    For every pair of the grid, one model is trained on `alpha1` and scored on
    `alpha2`, another is trained on `alpha2` and scored on `alpha1`; the pair
    is ranked by the average of the two test explained variances. The EV
    surface over the grid is plotted as a side effect.

    Parameters
    ----------
    N : (genes, TFs) matrix forwarded to train_model
    alpha1, alpha2 : torch tensors, the two replicates
    lambdaW_values, lambdaA_values : 1-D grids of candidate penalties
    A_split, A_split_2 : initial activities for the alpha1 / alpha2 models
    tol, patience : early-stopping settings forwarded to train_model

    Returns
    -------
    (best_lambdaW, best_lambdaA) : the first grid point reaching the highest
    average test explained variance.
    """
    best_lambdaW, best_lambdaA, best_EV = None, None, -np.inf
    EVs_avg = []

    # sklearn metrics receive NumPy arrays, as in cross_train
    alpha1_np, alpha2_np = alpha1.numpy(), alpha2.numpy()

    for lambdaW, lambdaA in product(lambdaW_values, lambdaA_values):
        # 3 significant digits: a fixed '.2f' would print small lambdas as 0.00
        print(f"Testing lambdaW = {lambdaW:.3g}, lambdaA = {lambdaA:.3g}")

        # Train on alpha1, test on alpha2
        W1, A1, _ = train_model(N, alpha1, alpha2, lambdaW, lambdaA, A_split, tol, patience)
        EV1 = explained_variance_score(alpha2_np, W1 @ A1)

        # Train on alpha2, test on alpha1
        W2, A2, _ = train_model(N, alpha2, alpha1, lambdaW, lambdaA, A_split_2, tol, patience)
        EV2 = explained_variance_score(alpha1_np, W2 @ A2)

        avg_EV = (EV1 + EV2) / 2
        EVs_avg.append(avg_EV)
        print(f"lambdaW={lambdaW:.3g}, lambdaA={lambdaA:.3g}, EV={avg_EV*100:.2f}%\n")

        if avg_EV > best_EV:
            best_lambdaW, best_lambdaA, best_EV = lambdaW, lambdaA, avg_EV

    # product() iterates lambdaW in the outer loop, hence the (W, A) reshape
    EV_surface = np.array(EVs_avg).reshape(len(lambdaW_values), len(lambdaA_values))
    LambdaW, LambdaA = np.meshgrid(lambdaW_values, lambdaA_values, indexing='ij')

    # Plot
    plt.figure(figsize=(7, 5))
    cp = plt.contourf(LambdaW, LambdaA, EV_surface, levels=30, cmap='viridis')
    plt.colorbar(cp, label='Average Explained Variance')
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('λ_W (W regularization)')
    plt.ylabel('λ_A (A regularization)')
    plt.title('EV Surface over λ_W, λ_A')
    if best_lambdaW is not None:
        # Mark the very pair that is returned, so plot and return value cannot disagree
        plt.scatter([best_lambdaW], [best_lambdaA], color='red', label='Optimum')
        plt.legend()
    plt.grid(True, which='both', ls='--', lw=0.3)
    plt.tight_layout()
    plt.show()

    print(f"Best λ_W = {best_lambdaW:.3g}, Best λ_A = {best_lambdaA:.3g}, Best EV = {best_EV*100:.2f}%\n")
    return best_lambdaW, best_lambdaA

In [None]:
# 5. Cross train for best model
def cross_train(N, alpha1, alpha2, best_lambdaW, best_lambdaA, A_split, A_split_2, tol=(0.04, 150), patience=20):
    """
    Train one model per replicate with fixed penalties and cross-test them.

    A first model is trained on `alpha1` (monitored on `alpha2`), a second one
    on `alpha2` (monitored on `alpha1`). The train and test explained variances
    of both reconstructions W @ A are averaged and printed.

    Returns
    -------
    (W1, A1, W2, A2) : weights and activities of the alpha1-trained model,
    followed by those of the alpha2-trained model, as returned by train_model.
    """
    def fit_one_direction(banner, alpha_fit, alpha_held_out, A_init):
        # Announce, train, then score the reconstruction on both replicates
        print(banner)
        W, A, _ = train_model(N, alpha_fit, alpha_held_out, best_lambdaW, best_lambdaA, A_init, tol, patience)
        reconstruction = W @ A
        ev_fit = explained_variance_score(alpha_fit.numpy(), reconstruction)
        ev_held_out = explained_variance_score(alpha_held_out.numpy(), reconstruction)
        return W, A, ev_fit, ev_held_out

    W1, A1, EV1_train, EV1_test = fit_one_direction("Training on α1, testing on α2...", alpha1, alpha2, A_split)
    W2, A2, EV2_train, EV2_test = fit_one_direction("Training on α2, testing on α1...", alpha2, alpha1, A_split_2)

    avg_EV_train = (EV1_train + EV2_train) / 2
    avg_EV_test = (EV1_test + EV2_test) / 2
    print(f"Average EV_train: {avg_EV_train*100:.2f}%")
    print(f"Average EV_test: {avg_EV_test*100:.2f}%")

    return W1, A1, W2, A2

In [None]:
def standardize_amplitudes(matrices_from, matrices_to = False, target_amp=0.2):
    """
    Rescale every row of each matrix to a single target amplitude.

    For each matrix of `matrices_from`, each row is multiplied by
    target_amp / amp, where amp is the row's half peak-to-peak amplitude.
    When `matrices_to` is given, the columns of matrices_to[i] are divided by
    the same per-row factors, so the product matrices_to[i] @ matrices_from[i]
    is left unchanged.

    Rows with zero amplitude (constant rows) cannot be rescaled; they are left
    untouched (scale factor 1) instead of producing inf/NaN.

    Parameters
    ----------
    matrices_from : sequence of 2-D arrays (e.g. TF x theta)
    matrices_to : False/None, or a sequence (list or stacked array) of 2-D
        arrays (e.g. genes x TF), one per entry of `matrices_from`
    target_amp : amplitude every non-constant row is rescaled to

    Returns
    -------
    (standardized_matrices_from, standardized_matrices_to) : two lists; the
    second one is empty when `matrices_to` is not provided.
    """
    # Identity checks: `!= False` is ambiguous for NumPy arrays and true for None
    has_targets = matrices_to is not None and matrices_to is not False

    standardized_matrices_from = []
    standardized_matrices_to = []
    for i, matrix_from in enumerate(matrices_from):
        amp = (np.max(matrix_from, axis=1) - np.min(matrix_from, axis=1)) / 2
        is_flat = amp == 0
        # Guard the division, then neutralise the factor of constant rows
        scale = np.where(is_flat, 1.0, target_amp / np.where(is_flat, 1.0, amp))
        standardized_matrices_from.append(matrix_from * scale[:, np.newaxis]) #TF x theta
        if has_targets:
            standardized_matrices_to.append(matrices_to[i] / scale[np.newaxis, :]) #genes x TF

    return standardized_matrices_from, standardized_matrices_to

In [None]:
# Load Data & Run
fileAlpha1 = "alpha_snrna_rep1_5000_1_2p75.csv"
fileAlpha2 = "alpha_snrna_rep2_5000_1_2p75.csv"
fileBSM = 'data_binding_site_matrix.txt'
process = ["transcription", "\u03B1"]
theta_smooth = np.round(np.linspace(0.01, 1.00, 100), 2)  # 100 bins from 0.01 to 1.00
ampl_threshold=0.1

# Define Lambda Values
lambdaW_values = np.logspace(-3, -1, 10)
lambdaA_values = np.logspace(1, 2, 10)
# Setting best_lambdaW to 0 triggers the grid search inside the run loop below
#best_lambdaW, best_lambdaA = 0, 0
best_lambdaW, best_lambdaA = 0.0215, 77.43

N = pd.read_csv(fileBSM, sep="\t",index_col=0)
tf_names = N.columns
alpha1 = pd.read_csv(fileAlpha1, sep=",",index_col=0)
alpha2 = pd.read_csv(fileAlpha2, sep=",",index_col=0)
#A_split = np.load("activities_export/ampl_"+str(ampl_threshold)+"/A_star_split.npy")

#Select common genes and normalize
print(alpha1.shape, alpha2.shape, N.shape)
alpha1_norm, alpha2_norm, N_norm, targetnames, tf_names = preprocessing(alpha1, alpha2, N, ampl_threshold=ampl_threshold)
print(alpha1_norm.shape, alpha2_norm.shape, N_norm.shape, "\n")

n_runs = 5
A1_list = []
W1_list = []
A2_list = []
W2_list = []
A1_split_list = []
A2_split_list = []

# Inputs and settings that do not change between runs are built once
N_tensor = torch.tensor(N_norm, dtype=torch.float32)  # (genes, TFs)
alpha1_tensor = torch.tensor(alpha1_norm, dtype=torch.float32)  # (genes, thetas)
alpha2_tensor = torch.tensor(alpha2_norm, dtype=torch.float32)  # (genes, thetas)
lambdas = np.logspace(0, 6, 40)
tol = (0.03, 200) #(0.07, 800) for ampl 0.2  #np.linspace(0.01, 0.07, 7)
patience = 20 #np.linspace(5, 30, 7)

# The per-run seeds are drawn from a seeded generator so that
# "Restart Kernel -> Run All" reproduces the same runs.
MASTER_SEED = 0
run_seeds = np.random.default_rng(MASTER_SEED).integers(1, 100, n_runs)

for seed in run_seeds:
    seed = int(seed)  # plain Python int for torch / numpy / sklearn seeding
    torch.manual_seed(seed) #useless now that init is not random anymore
    np.random.seed(seed)
    # Model 1 on each replicate; its activities initialise Model 2
    A_split, lambda_opt1 = svd_regression_with_lambda_CV(alpha1_norm, N_norm, lambdas, seed=seed)
    A_split_2, lambda_opt2 = svd_regression_with_lambda_CV(alpha2_norm, N_norm, lambdas, seed=seed)
    A1_split_list.append(A_split)
    A2_split_list.append(A_split_2)
    #print("Best lambda:", lambda_opt1)
    #print("Best lambda:", lambda_opt2)

    # Optimize Lambda
    # NOTE(review): the grid search uses its own tol/patience literals rather
    # than the `tol` / `patience` defined above - confirm this is intended.
    if (best_lambdaW == 0):
        best_lambdaW, best_lambdaA = cross_val_lambda(N_tensor, alpha1_tensor, alpha2_tensor, lambdaW_values, lambdaA_values, A_split, A_split_2, tol=(0.04, 150), patience=20)
    # Train and Cross-Test
    print("Seed :", seed, "Tolerance :", tol, "Patience :", patience)
    W1, A1, W2, A2 = cross_train(N_tensor, alpha1_tensor, alpha2_tensor, best_lambdaW, best_lambdaA, A_split, A_split_2, tol, patience)
    A1_list.append(A1)
    W1_list.append(W1)
    A2_list.append(A2)
    W2_list.append(W2)

In [None]:
# Reproducibility heatmaps: correlations of activities (A) and weights (W)
# between runs, and between each run and its Model 1 starting point.
n_tfs = A1_list[0].shape[0]
n_genes = W1_list[0].shape[0]
A_list, W_list, A_split_list = A1_list, W1_list, A1_split_list


def _rowwise_pearson(X, Y, n_rows):
    """Pearson r between row r of X and row r of Y, for r in range(n_rows)."""
    corrs = np.zeros(n_rows)
    for r in range(n_rows):
        corrs[r], _ = pearsonr(X[r, :], Y[r, :])
    return corrs


def _store_pair(vals, a, b, A_corrs, W_corrs):
    """Write the four rounded summaries symmetrically at positions (a, b) and (b, a)."""
    summaries = (np.mean(A_corrs), np.median(A_corrs), np.mean(W_corrs), np.median(W_corrs))
    for panel, value in enumerate(summaries):
        vals[panel][a, b] = vals[panel][b, a] = round(value, 3)


# Panels: mean A, median A, mean W, median W. Row/column 0 is Model 1.
heatmap_vals = np.zeros((4, n_runs+1, n_runs+1)) #A_split  A1_1  A1_2  A1_3  A1_4  A1_5
heatmap_labels = ['Model 1'] + ["Rep "+str(run+1) for run in range(n_runs)]

for panel in range(4): #Diagonal to 1
    np.fill_diagonal(heatmap_vals[panel], 1.0)

for i, j in combinations(range(n_runs), 2): #Corr of A1 between runs
    A1_i, A1_j = A_list[i], A_list[j]
    W1_i, W1_j = W_list[i], W_list[j]

    A_corrs = _rowwise_pearson(A1_i, A1_j, n_tfs)    # per TF correlation of activities
    W_corrs = _rowwise_pearson(W1_i, W1_j, n_genes)  # per gene correlation of weights
    _store_pair(heatmap_vals, i+1, j+1, A_corrs, W_corrs)

for l in range(n_runs): #A_split corr with corresponding A after run
    A_split, A1, W1 = A_split_list[l], A_list[l], W_list[l]

    A_corrs = _rowwise_pearson(A_split, A1, n_tfs)   # per TF correlation of activities
    W_corrs = _rowwise_pearson(N_norm, W1, n_genes)  # per gene correlation of weights
    _store_pair(heatmap_vals, l+1, 0, A_corrs, W_corrs)

panel_titles = ["Mean A Corr", "Median A Corr", "Mean W Corr", "Median W Corr"]
for i, title in enumerate(panel_titles):
    plt.figure(figsize=(8, 6))
    plt.title(title)
    ax = sns.heatmap(heatmap_vals[i], cbar=True, annot=True, fmt=".2f", vmin=0, vmax=1, yticklabels=heatmap_labels,xticklabels=heatmap_labels)
    plt.show()

In [None]:
# Position of 'E2f1' in tf_names (raises IndexError if the name is absent),
# then hand A_split to the utils_plot helper for that TF.
# NOTE(review): A_split is whatever the previous cells last bound it to.
BP_nb = np.where(tf_names == 'E2f1')[0][0]
plot_binding_protein_activity(tf_names, A_split, process, theta_smooth, BP_nb=BP_nb)

In [None]:
# Model 1 reconstruction (fixed N_norm times A_split); the cell output is its
# explained variance against alpha1_norm.
R_split = N_norm @ A_split
explained_variance_score(alpha1_norm, R_split)

In [None]:
# Font sizes for the figures produced from this cell onwards
plt.rcParams.update({
    'font.size': 13,          # base font size
    'axes.labelsize': 13,     # axis label font size
    'axes.titlesize': 11,     
    'xtick.labelsize': 9,
    'ytick.labelsize': 11,
})
plt.figure(figsize=(7,5))
# Compute TF importance: sum of absolute weights per TF
# Activities of every run are rescaled to amplitude 0.2 and the weights are
# rescaled by the inverse factors (see standardize_amplitudes).
A1_list_standardized, W1_list_standardized = standardize_amplitudes(A1_list, W1_list, target_amp=0.2)
print("tol=", tol, "ampl=", ampl_threshold, "Patience=", patience)
# Only the first run is inspected below
W = W1_list_standardized[0]
# W_key_TF comes from utils_plot; presumably ranks/plots the top_k TFs - TODO confirm
TFs_sumW_df = W_key_TF(W, tf_names, top_k=20)

In [None]:
# Compute gene importance: sum of weights per gene
# W_key_gene comes from utils_plot; presumably ranks/plots the top_k genes - TODO confirm
Genes_sumW_df = W_key_gene(W, targetnames, top_k=20)
#Genes_sumW_df.to_csv("mESC_genes_sumW.csv")

In [None]:
def get_tf_targets(W, tf_names, gene_names, tf_query, top_k=20):
    """
    List the target genes of one TF, ordered by decreasing weight, and plot the top ones.

    Parameters
    ----------
    W : 2-D array (genes x TFs) of weights
    tf_names : array of TF names, one per column of W
    gene_names : sequence of gene names, one per row of W
    tf_query : name of the TF to look up (IndexError if it is not in tf_names)
    top_k : maximum number of genes shown in the bar plot

    Returns
    -------
    dict mapping gene name -> weight for every gene with a non-zero weight,
    in decreasing order of weight.
    """
    # Find TF index
    tf_idx = np.where(tf_names == tf_query)[0][0]
    W_tf = W[:, tf_idx]

    # Filter out zero weights
    nonzero_idx = np.where(W_tf != 0)[0]
    W_tf_nonzero = W_tf[nonzero_idx]
    gene_names_nonzero = [gene_names[i] for i in nonzero_idx]

    # Sort by (signed) weight, descending
    sorted_idx = np.argsort(W_tf_nonzero)[::-1]
    top_gene_names = [gene_names_nonzero[i] for i in sorted_idx]
    weights = [W_tf_nonzero[i] for i in sorted_idx]
    print(str(len(weights))+" targets for "+tf_query)

    df = pd.DataFrame({'sum_W': weights}, index=top_gene_names)
    #df = df.sort_values('sum_W', ascending=False)

    # --- Plot the top targets ---
    top_df = df.head(top_k)
    # A TF may have fewer than top_k targets: size the x positions on the rows
    # actually available, otherwise plt.bar receives mismatched lengths.
    positions = np.arange(len(top_df))

    plt.figure(figsize=(7,5))
    plt.bar(positions, top_df['sum_W'], color='red')
    plt.xticks(positions, top_df.index, rotation=45, ha='right')
    plt.ylabel("Binding site counts (W)")
    #plt.title(f"Top {top_k} TFs by ∑W (After amplitude standardization)")
    plt.axhline(0, color='black', linewidth=0.8)
    plt.tight_layout()
    plt.show()

    # Build dictionary
    targets_dic = dict(zip(top_gene_names, weights))
    return targets_dic

# Targets of E2f1 in the standardized weights W of the first run
targets_dic = get_tf_targets(W, tf_names, targetnames, tf_query="E2f1")
#targets_dic

In [None]:
# Display the strictly positive standardized weights of the first run, in ascending order
np.sort(W1_list_standardized[0][W1_list_standardized[0] > 0].flatten())

In [None]:
# Activities of the first two runs. Both come from A1_list, so A2 here is the
# second run and not the alpha2-trained model.
A1 = A1_list[0]
A2 = A1_list[1]
residuals = A1 - A2
std_residuals = np.std(residuals, axis=1)

# Per-TF Pearson correlation between the two runs
n_tfs = A1.shape[0]
pearson_rs = np.empty(n_tfs)
for tf_row in range(n_tfs):
    pearson_rs[tf_row] = pearsonr(A1[tf_row, :], A2[tf_row, :])[0]

# Histogram of the residual std per TF => check for biases
fig_hist, ax_hist = plt.subplots(figsize=(6, 4))
ax_hist.hist(std_residuals, bins=30, color='steelblue', alpha=0.8)
ax_hist.set_xlabel("Std of Residuals (A1 - A2)")
ax_hist.set_ylabel("Number of TFs")
ax_hist.set_title("Distribution of Std Residuals per TF")
ax_hist.grid(True)
fig_hist.tight_layout()
plt.show()

# Pearson r against residual std => identify shape and/or amplitude disagreements
fig_scatter, ax_scatter = plt.subplots(figsize=(6, 5))
ax_scatter.scatter(pearson_rs, std_residuals, alpha=0.8, edgecolor='k')
ax_scatter.set_xlabel("Pearson Correlation (A1 vs A2)")
ax_scatter.set_ylabel("Std of Residuals (A1 - A2)")
ax_scatter.set_title("Per-TF Reproducibility: Shape vs Magnitude")
ax_scatter.grid(True)
fig_scatter.tight_layout()
plt.show()

In [None]:
# Work with the first two runs from here on (both trained on alpha1:
# A2/W2 are the second run, not the alpha2-trained model).
A1, A2 = A1_list[0], A1_list[1]
W1, W2 = W1_list[0], W1_list[1]
# Sorted value ranges of the weights and activities of the first run
print(np.sort(W1.flatten()))
print(np.sort(A1.flatten()))

In [None]:
#We smooth activities
# (smoothing is currently disabled; the raw activities are used)
#A1 = fourier_fit(A1, theta_smooth)
#A2 = fourier_fit(A2, theta_smooth)
# Model 2 reconstructions for the two selected runs
R1 = W1 @ A1
R2 = W2 @ A2

In [None]:
# Larger fonts for the figures produced from this cell onwards
plt.rcParams.update({
    'font.size': 22,          # base font size
    'axes.labelsize': 20,     # axis label font size
    'axes.titlesize': 20,     
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

In [None]:
# Explained variance of the first run's reconstruction R1 against alpha2_norm,
# computed separately for every gene (row)
expl = np.array([
    explained_variance_score(alpha2_norm[gene, :], R1[gene, :])
    for gene in range(alpha1_norm.shape[0])
])
expl_sorted = np.sort(expl)

plt.hist(expl_sorted, bins=150)
plt.xlim(-2,1)
plt.xlabel("Explained variance")
plt.ylabel("Number of genes")
plt.title("Distribution of gene-wise EVs for transcription")
plt.grid(True)
plt.show()

# Global score over the whole matrix, then the mean of the gene-wise scores as cell output
print(explained_variance_score(alpha2_norm, R1))
np.mean(expl)

In [None]:
# Position of 'E2f4' in tf_names (IndexError if absent), then plot its activity from A1
BP_nb = np.where(tf_names == 'E2f4')[0][0]
#BP_nb = 27
plot_binding_protein_activity(tf_names, A1, process, theta_smooth, BP_nb=BP_nb)
# Share of strictly positive weights among the non-zero weights of that TF's column
print(f"Positive W1 among target genes of {tf_names[BP_nb]} : {np.sum(W1[:, BP_nb] > 0)}/{np.sum(W1[:, BP_nb] != 0)} ({np.sum(W1[:, BP_nb] > 0)/np.sum(W1[:, BP_nb] != 0)*100:.2f}%)")

In [None]:
#MODEL 1
# Compare the Model 1 reconstruction R_split with each replicate for one gene
# (row index of 'Arfgef3' in targetnames; IndexError if absent)
n = np.where(targetnames == 'Arfgef3')[0][0]
#n = 138
print("Train")
plot_rate_comparison(targetnames, alpha1_norm, R_split, process, theta_smooth, target_nb=n)
print("Test")
plot_rate_comparison(targetnames, alpha2_norm, R_split, process, theta_smooth, target_nb=n)

In [None]:
#MODEL 2
# Compare the Model 2 reconstruction R1 with each replicate for one gene
# (row index of 'Ankrd10' in targetnames; IndexError if absent)
n = np.where(targetnames == 'Ankrd10')[0][0]
#n = 4972
print("Train")
plot_rate_comparison(targetnames, alpha1_norm, R1, process, theta_smooth, target_nb=n)
print("Test")
plot_rate_comparison(targetnames, alpha2_norm, R1, process, theta_smooth, target_nb=n)

In [None]:
# compute_reproducibility comes from utils_plot; it is called here on the
# activities and on the weights of the two selected runs.
compute_reproducibility(A1, A2, alpha1_norm, alpha2_norm, metric="TF activities")
compute_reproducibility(W1, W2, alpha1_norm, alpha2_norm, metric="W site counts")
#compute_reproducibility(R1, R2, alpha1_norm, alpha2_norm, metric="Reconstruction")

In [None]:
#plot_heatmap(A_split, ylabels=tf_names, display_limit=25, cmap='RdBu_r', title=" TF activities across the cell cycle")
#plot_heatmap(alpha1_norm, cmap='RdBu_r', title=" transcription rates across the cell cycle")
# Font sizes for the heatmap below
plt.rcParams.update({
    'font.size': 20,          # base font size
    'axes.labelsize': 18,     # axis label font size
    'axes.titlesize': 18,     
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})
#plot_heatmap(alpha1_norm, cmap='RdBu_r', title="Gene transcription (scRNA-seq) across the cell cycle")
# Heatmap of the Model 1 activities (A_split as last bound by earlier cells)
plot_heatmap(A_split, cmap='RdBu_r', title="TFs activities on transcription across the cell cycle")

In [None]:
#Heatmap of BPs activity along cell cycle (Export)
# TFs of interest, named as in the binding-site matrix columns (note the merged "E2f2_E2f5" entry)
key_tfs = ["Smad3", "Hbp1", "E2f1", "E2f2_E2f5", "E2f3", "E2f4", "E2f6", "E2f7", "E2f8", "Sp1", "Hes1", "Elf3", "Tfap4"]
# plot_heatmap_list comes from utils_plot; presumably returns the TFs it displayed - TODO confirm
tf_displayed = plot_heatmap_list(A1, tf_names, key_tfs, clip=True)
print(tf_displayed)

In [None]:
# Prior expectations per TF, used by the comparison loop at the end of the notebook:
#   "ranges"     - list of (start, end) intervals, presumably on the theta_smooth axis - TODO confirm
#   "inhibitory" - True when the TF is treated as an inhibitor (the correlation sign is flipped)
expected_activity = {
    "Smad3": {"ranges": [(0.15, 0.35)], "inhibitory": False}, #True?
    "Hbp1": {"ranges": [(0.15, 0.35)], "inhibitory": True},
    "E2f1": {"ranges": [(0.15, 0.35)], "inhibitory": False},
    "E2f2_E2f5": {"ranges": [(0.15, 0.35)], "inhibitory": False},
    "E2f3": {"ranges": [(0.15, 0.35)], "inhibitory": False},
    "E2f4": {"ranges": [(0.05, 0.25)], "inhibitory": True}, #Not 100% confident
    "E2f6": {"ranges": [(0.3, 1)], "inhibitory": True}, #Not 100% confident
    "E2f7": {"ranges": [(0.3, 1)], "inhibitory": True}, #Not 100% confident
    "E2f8": {"ranges": [(0.3, 1)], "inhibitory": True}, #Not 100% confident
    "Sp1": {"ranges": [(0.15, 0.35)], "inhibitory": False}, #Can be both?
    "Hes1": {"ranges": [(0.15, 0.35)], "inhibitory": True}
}

In [None]:
#### TF EXPRESSION AND BIOLOGICAL MEANING ####

In [None]:
# TFs whose own expression is compared with their inferred activity.
# NOTE(review): "E2f2" and "E2f5" are listed separately here, whereas
# expected_activity is keyed by "E2f2_E2f5"; they only matter if they pass the filter below.
key_tfs = ["Smad3", "Hbp1", "E2f1", "E2f2", "E2f3", "E2f4", "E2f5", "E2f6", "E2f7", "E2f8", "Sp1", "Hes1"]

# Keep the TFs that are rows of both raw replicates AND present in tf_names
tf_names_filtered = np.array([tf for tf in key_tfs if tf in alpha1.index and tf in alpha2.index and tf in tf_names])
print("TFs in common :", str(len(tf_names_filtered))+"/"+str(len(key_tfs)))

alpha1_f, alpha2_f = alpha1.loc[tf_names_filtered], alpha2.loc[tf_names_filtered]
alpha1_n, alpha2_n = alpha1_f.to_numpy(), alpha2_f.to_numpy()

#Standardize amplitudes
#A_standard = A_standard - np.mean(A_standard, axis=1, keepdims=True)
# Average the two replicates, double-centre the result, then bring both the
# expression rows and the activities A1 to the same amplitude.
alpha_sn_n = (alpha1_n + alpha2_n) / 2
alpha_sn_norm = alpha_sn_n - np.mean(alpha_sn_n, axis=1, keepdims=True) - np.mean(alpha_sn_n, axis=0, keepdims=True) + np.mean(alpha_sn_n)
matrices_from, _ = standardize_amplitudes([alpha_sn_norm, A1])
alpha_sn_norm, A_standard = matrices_from

In [None]:
# Font sizes for the per-TF figures below; the right axis spine is switched on
plt.rcParams.update({
    'font.size': 16,          # base font size
    'axes.labelsize': 15,     # axis label font size
    'axes.titlesize': 15,     
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'axes.spines.right': True
})

In [None]:
# For every retained TF: plot expression vs activity, then score their agreement
# (sign-adjusted Spearman correlation) and the activity against its expected window.
corrs = []
z_vals = []
tf_name_list = list(tf_names)  # built once; used to locate each TF's row in A_standard
for tf in range(len(tf_names_filtered)):
    tf_label = tf_names_filtered[tf]
    expected = expected_activity[tf_label]
    activity_row = A_standard[tf_name_list.index(tf_label)]

    plot_TF_exp_activity(theta_smooth, alpha_sn_norm, A_standard, tf_names, tf_names_filtered, tf)
    corr = spearmanr(alpha_sn_norm[tf], activity_row)[0]
    action = "activator" if not expected["inhibitory"] else "inhibitor"
    if (action == "inhibitor"):
        # An inhibitor is expected to be anti-correlated with its own expression
        corr = -corr
    print(f"scRNA & A correlation : {corr:.3f} ({ action })\n")
    z_val = compute_tf_activity_difference(activity_row, theta_smooth, expected["ranges"], expected["inhibitory"])
    corrs.append(corr)
    z_vals.append(z_val)
    # Single quotes inside the f-string: reusing the outer quote type is a
    # SyntaxError before Python 3.12.
    print(f"Expected activity range : {expected['ranges']}")
    print(f"TF activity biological z-score : {z_val:.2f} ({ action })")
print(f"\nGlobal correlation :{np.mean(corrs):.3f}")
print(f"Global z-score :{np.mean(z_vals):.2f}")