In [1]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy.sparse import load_npz, identity, issparse
from numpy.linalg import pinv, matrix_rank, svd
from scipy.linalg import solve_discrete_are, eigvals, norm
from scipy.sparse.linalg import eigs as sparse_eigs

import networkx as nx
from tqdm.auto import tqdm
import time
import argparse

A_path = "o2/A_final_stable.npz"      # repo file
genes_path = "o2/genes_final.csv"     # expects column 'gene' or single column
expr_path = "data/expr_common_full.csv"  # expression table: rows = samples, cols = genes (or transpose)

T_gram = 40               # finite horizon for Gramian approximations (trace)
m_cont = 30               # depth for controllability rank checks
max_greedy_size = 200     # safety cap on greedy selected set size
objective = "rank"  # options: 'gramian_trace', 'min_eig', 'rank'
seed_with_structural = True
verbose = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if not os.path.exists(A_path):
    raise FileNotFoundError(f"{A_path} not found. Run from repo root or change path.")

A = load_npz(A_path).astype(float)
n = A.shape[0]
print("Loaded A:", A.shape, "sparse:", issparse(A))

GENES = None
if os.path.exists(genes_path):
    gdf = pd.read_csv(genes_path, header=0)
    # try to infer gene column
    if "gene" in gdf.columns:
        GENES = gdf["gene"].astype(str).tolist()
    else:
        # if single column
        GENES = gdf.iloc[:,0].astype(str).tolist()
    assert len(GENES) == n, f"Gene list length {len(GENES)} != A.shape[0] {n}"
    print("Loaded genes list, n =", len(GENES))
else:
    # fallback: generate numeric gene names
    GENES = [f"g{i}" for i in range(n)]
    print("genes file not found; using synthetic gene names.")

Loaded A: (8378, 8378) sparse: True
Loaded genes list, n = 8378


In [3]:
if os.path.exists(expr_path):
    expr_df = pd.read_csv(expr_path, index_col=0)
    print("Expression data", expr_df)
    print("Loaded expression:", expr_df.shape)
    # Ensure columns correspond to gene list (attempt to align)
    # If expr has same columns as genes, reorder; else assume same order
    if set(expr_df.columns) >= set(GENES):
        expr_df = expr_df[GENES]   # reorder to genes list
    else:
        # if mismatch, try transpose if rows are genes
        if set(expr_df.index) >= set(GENES):
            expr_df = expr_df.T
            if set(expr_df.columns) >= set(GENES):
                expr_df = expr_df[GENES]
        else:
            print("Warning: gene names in expression file do not match gene list. Assuming same ordering.")
            # then we assume expr_df.columns order == genes
    # Build x*: median of 'normal' samples if it can find them, else global median
    # Look for typical normal identifiers in sample names
    sample_names = expr_df.index.astype(str).tolist()
    normal_mask = [("normal" in s.lower() or "adj" in s.lower() or "tumor-adj" in s.lower()) for s in sample_names]
    if any(normal_mask):
        x_star = expr_df.loc[np.array(normal_mask)].median(axis=0).values
        print("Using median of detected normal samples as x*")
    else:
        x_star = expr_df.median(axis=0).values
        print("Using global median across samples as x*")
else:
    # fallback: zero vector
    x_star = np.zeros(n)
    print("Expression file not found; using x* = zero vector (NOT recommended).")

assert x_star.shape[0] == n, "x* length mismatch"
print("x* : ", x_star)

Expression data              MB-0006  MB-0028  MB-0035  MB-0046  MB-0050  MB-0053  MB-0054  \
Hugo_Symbol                                                                  
A2M           2.4671  -0.8292  -0.3122  -0.1450  -0.7651   1.1408   2.1077   
A4GALT       -0.4412   0.4318  -1.9995   0.5552   0.1737  -1.5991  -0.8954   
AAAS          0.7683  -0.7058   0.0129  -0.6430  -1.2141  -0.4286  -0.1641   
AACS         -1.3245   0.2130  -0.0626   0.6073  -0.8968  -1.5375  -1.0142   
AADACL2       0.6082  -0.5639  -0.2895   1.1172  -0.3718   1.0195   0.3245   
...              ...      ...      ...      ...      ...      ...      ...   
ZWINT        -0.6574   0.3098   0.0264   1.4851  -1.6942   1.0965   1.0998   
ZXDC         -1.5690  -0.2822  -0.5964  -0.2320   0.4545  -1.9882  -2.7958   
ZYG11B        0.0940  -0.1729   0.4592  -1.0761   0.3331   0.3317   0.9267   
ZYX          -0.7766  -2.5108   0.5377  -0.2600   0.1423  -3.6021  -2.0316   
ZZEF1        -0.9416  -1.4202  -1.0823   0.2592 

In [4]:
if issparse(A):
    absA = abs(A)
    degrees = np.array(absA.sum(axis=0)).ravel() + np.array(absA.sum(axis=1)).ravel() # in+out
else:
    absA = np.abs(A)
    degrees = absA.sum(axis=0) + absA.sum(axis=1)

# rank genes by degree
rank_idx = np.argsort(-degrees)
top_k = 20
top_genes = [(GENES[i], float(degrees[i])) for i in rank_idx[:top_k]]
top_df = pd.DataFrame(top_genes, columns=["gene","degree"])
print(top_df.head(10))

        gene     degree
0     IGFBP7  66.265798
1     TUBA1C  54.582640
2     UBE2G2  52.866151
3      ITGB1  48.068044
4      SPCS1  44.949030
5     RPS27A  44.542268
6     HNRNPK  40.958177
7      GSG1L  39.184623
8  GABARAPL2  38.430002
9       MXD3  37.708785


In [5]:
def select_ctrl(top_k: int = 5, genes_list: list = None) -> list:
    if genes_list is None:
        genes_list = []
        k_control = top_k
        selected_control_idx = rank_idx[:k_control].tolist()
        control_gene_names = [GENES[i] for i in selected_control_idx]
        print("Selected control genes (k={}):".format(k_control), control_gene_names)
    else:
        selected_control_idx = []
        for g in genes_list:
            if g in GENES:
                idx = GENES.index(g)
                selected_control_idx.append(idx)
            else:
                print(f"Warning: gene {g} not found in gene list; skipping.")
        k_control = len(selected_control_idx)
        control_gene_names = [GENES[i] for i in selected_control_idx]
        print("Selected control genes (k={}):".format(k_control), control_gene_names)

    B = np.zeros((n, k_control))
    for j, idx in enumerate(selected_control_idx):
        B[idx, j] = 1.0   # direct actuation
        
    return B, selected_control_idx

In [6]:
select_ctrl(genes_list=["TP53", "BRCA1", "EGFR"])
select_ctrl()  # default top 

Selected control genes (k=3): ['TP53', 'BRCA1', 'EGFR']
Selected control genes (k=5): ['IGFBP7', 'TUBA1C', 'UBE2G2', 'ITGB1', 'SPCS1']


(array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], shape=(8378, 5)),
 [3238, 7535, 7603, 3372, 6662])

In [7]:
def structural_driver_nodes(A_sparse, genes):
    """
    Construct a bipartite graph (left copy L, right copy R).
    For every directed edge u->v in A, add (u_L, v_R).
    Compute maximum matching; unmatched right nodes are driver nodes (Liu et al).
    """
    if not issparse(A_sparse):
        A_sparse = A_sparse
    n = A_sparse.shape[0]
    # Build directed edges from sparse matrix
    Acoo = A_sparse.tocoo()
    G = nx.DiGraph()
    # bipartite graph as networkx Graph:
    B = nx.DiGraph()
    # Add nodes
    L_nodes = [f"l_{i}" for i in range(n)]
    R_nodes = [f"r_{i}" for i in range(n)]
    BG = nx.Graph()
    BG.add_nodes_from(L_nodes, bipartite=0)
    BG.add_nodes_from(R_nodes, bipartite=1)
    # Add edges (u_L, v_R) for each nonzero A[u,v]
    for u, v in zip(Acoo.row, Acoo.col):
        BG.add_edge(f"l_{u}", f"r_{v}")
    # maximum bipartite matching (networkx returns dict)
    matching = nx.algorithms.bipartite.matching.hopcroft_karp_matching(BG, top_nodes=L_nodes)
    # matching maps both ways; extract matched right nodes
    matched_right = set()
    for k,v in matching.items():
        if k.startswith("l_"):
            matched_right.add(v)
    # unmatched right nodes are driver nodes (their r_i absent from matched_right)
    unmatched_r = [i for i in range(n) if f"r_{i}" not in matched_right]
    driver_genes = [GENES[i] for i in unmatched_r]
    return unmatched_r, driver_genes, matching


def A_pwr_precomp(A_sparse, T):
    """Return a list of functions that compute A^t @ v efficiently.
       For moderate n we produce dense powers. For large n keep sparse logic: compute iteratively.
    """
    n = A_sparse.shape[0]
    if issparse(A_sparse) and n > 2000:
        # for large sparse, we'll just return None and compute power via iterative multiplication when needed
        return None
    else:
        A_dense = A_sparse.toarray() if issparse(A_sparse) else A_sparse
        powers = [np.eye(n)]
        cur = np.eye(n)
        for _ in range(1, T):
            cur = A_dense.dot(cur)
            powers.append(cur.copy())
        return powers

def B_gramian_Tr(A_sparse, B_mat, T, A_powers=None):
    # W_T = sum_{i=0}^{T-1} A^i B B^T (A^i)^T
    # trace(W_T) = sum_{i=0}^{T-1} trace(B^T (A^i)^T A^i B) = sum ||A^i B||_F^2
    n, k = B_mat.shape
    total = 0.0
    if A_powers is not None:
        # use powers to compute A^i B quickly: A_powers[i] @ B
        for i in range(T):
            AiB = A_powers[i].dot(B_mat)
            total += np.sum(AiB * AiB)   # Frobenius squared
    else:
        # iterative multiplication
        cur = B_mat.copy()
        for i in range(T):
            total += np.sum(cur * cur)
            cur = A_sparse.dot(cur)
    return float(total)

def B_min_eig(A_sparse, B_mat, T):
    # Build W explicitly (only for smaller n)
    n, k = B_mat.shape
    if n > 1200:
        return np.nan
    if issparse(A_sparse):
        A = A_sparse.toarray()
    else:
        A = A_sparse
    W = np.zeros((n,n))
    cur = B_mat.copy()
    for i in range(T):
        W += cur.dot(cur.T)
        cur = A.dot(cur)
    eigs = np.linalg.eigvalsh(W)
    return float(np.min(eigs))

def B_ctrlability_rank(A_sparse, B_mat, m):
    # build [B, A B, A^2 B, ...] up to m terms and compute rank
    n, k = B_mat.shape
    if k == 0:
        return 0  # no actuators => rank zero (avoid calling matrix_rank on empty arrays)
    cols = [B_mat]
    cur = B_mat.copy()
    for i in range(1, min(m, n)):
        cur = A_sparse.dot(cur)
        cols.append(cur.copy())
    M = np.hstack(cols)
    r = matrix_rank(M)
    return int(r)





In [8]:
from math import isfinite

def compute_Z_krylov(A_sparse, idx, T=T_gram, tol=1e-8, max_rank=None):
    """
    Compute low-rank factor Z (n x r) approximating finite-horizon discrete Gramian
      W_T = sum_{t=0}^{T-1} (A^t b) (A^t b)^T
    for b = e_idx, using a Krylov basis approach.
    Returns Z numpy array shape (n, r) where r <= T.
    Remarks:
      - Works for sparse or dense A (uses A.dot(v)).
      - Builds orthonormal basis V = [v1..vr] and small S = sum_{i} (V^T w_i)(V^T w_i)^T.
      - Final W_approx = V S V^T, and we return Z = V @ sqrtm(S) (via eig-decomp of S).
    """
    n = A_sparse.shape[0]
    # basis vector b
    b = np.zeros(n)
    b[idx] = 1.0
    # helper to multiply A * vec
    Aop = A_sparse.dot if issparse(A_sparse) else (lambda x: A_sparse.dot(x))

    V_cols = []            # list of numpy arrays (orthonormal basis vectors)
    S = None               # small Gramian in V coordinates (r x r), updated incrementally
    # current vector w = A^i b (starting with i=0: w = b)
    w = b.copy()
    # iterate for T steps
    for t in range(T):
        # project w onto current V to get coords c = V^T w
        if len(V_cols) == 0:
            # first vector, create new basis vector from w
            wnorm = np.linalg.norm(w)
            if wnorm < tol:
                # zero vector (rare) -> return empty Z or small placeholder
                return np.zeros((n, 0))
            v_new = w / wnorm
            V_cols.append(v_new)
            # S becomes [[ (v_new^T w)^2 ]] = [[wnorm^2]]
            S = np.array([[wnorm * wnorm]], dtype=float)
        else:
            # form c = V^T w
            Vmat = np.column_stack(V_cols)  # n x r
            c = Vmat.T.dot(w)               # r-vector
            # compute residual r = w - V c
            residual = w - Vmat.dot(c)
            res_norm = np.linalg.norm(residual)
            if res_norm > tol:
                # append new orthonormal basis vector
                v_new = residual / res_norm
                # update S to incorporate new coordinate dimension
                r_old = S.shape[0]
                # expand S to (r_old+1 x r_old+1)
                S_new = np.zeros((r_old + 1, r_old + 1), dtype=float)
                S_new[:r_old, :r_old] = S
                # compute new c_full = [c; res_norm] is the coordinates of w in new basis
                c_full = np.concatenate([c, np.array([res_norm])])
                S_new += np.outer(c_full, c_full)
                S = S_new
                V_cols.append(v_new)
            else:
                # residual nearly zero: w lies in span(V). S += c c^T
                S += np.outer(c, c)
        # advance w <- A * w for next power
        w = Aop(w)
        # optional early stop if rank exceeds max_rank
        if (max_rank is not None) and (S.shape[0] >= max_rank):
            # still accumulate this step's contribution if needed: project current w and add then break
            # We already handled current w above; break now to limit rank growth
            break

    # At end, build Z = V * sqrtm(S)
    if S is None or len(V_cols) == 0:
        return np.zeros((n, 0))
    Vmat = np.column_stack(V_cols)   # n x r
    # small eigen-decomposition of S (r x r)
    evals, evecs = np.linalg.eigh(S)
    # clip tiny negative numerical noise
    evals[evals < 0] = 0.0
    # discard tiny eigenvalues
    keep = evals > (tol * np.max(evals) if np.max(evals) > 0 else tol)
    if np.sum(keep) == 0:
        # all tiny -> return empty
        return np.zeros((n, 0))
    evals_kept = evals[keep]
    evecs_kept = evecs[:, keep]
    # sqrt(S) = evecs_kept * diag(sqrt(evals_kept)) * evecs_kept^T
    # but Z = V * (evecs_kept * diag(sqrt(evals_kept))) -> compute small R = evecs_kept * sqrt(evals_kept)
    R_small = evecs_kept * np.sqrt(evals_kept)[np.newaxis, :]
    Z = Vmat.dot(R_small)   # n x r_eff
    return Z

# Precompute Zs for candidate pool (Krylov-based)
def precompute_basis_Zs_krylov(A_sparse, candidate_indices, T=T_gram, tol=1e-8, max_rank=None, parallel=False):
    """
    Build Zdict: index -> Z (n x r_i) using compute_Z_krylov.
    parallel: not implemented in this snippet (could use multiprocessing).
    """
    Zdict = {}
    t0 = time.time()
    for idx in tqdm(candidate_indices, desc="Precomputing Z (Krylov)"):
        try:
            Z = compute_Z_krylov(A_sparse, idx, T=T, tol=tol, max_rank=max_rank)
            Zdict[idx] = Z
        except Exception as e:
            print(f"[Z compute error] idx {idx}: {e}")
            Zdict[idx] = np.zeros((A_sparse.shape[0], 0))
    dt = time.time() - t0
    total_rank = sum(Z.shape[1] for Z in Zdict.values())
    print(f"Precomputed Z for {len(candidate_indices)} candidates in {dt:.1f}s; total rank sum {total_rank}")
    return Zdict

# Fast objective computations using Zdict (unchanged from earlier)
def gramian_trace_from_Zlist(Zlist):
    if len(Zlist) == 0:
        return 0.0
    s = 0.0
    for Z in Zlist:
        s += float(np.sum(Z * Z))
    return s

def gramian_min_eig_from_Zlist(Zlist, tol=1e-12):
    if len(Zlist) == 0:
        return 0.0
    Z_concat = np.hstack([Z for Z in Zlist if Z.shape[1] > 0])
    if Z_concat.size == 0:
        return 0.0
    S = Z_concat.T.dot(Z_concat)   # small matrix r_total x r_total
    # eigenvalues of W are eigenvalues of S
    evals = np.linalg.eigvalsh(S)
    # numerical floor
    evals[evals < tol] = 0.0
    return float(np.min(evals)) if evals.size > 0 else 0.0

# Modified greedy selection that uses the Krylov Zdict (same interface as previous)
def greedy_selection_with_Zs_krylov(A_sparse, genes, Zdict, candidate_pool, seed_indices=None,
                                    objective="gramian_trace", T=T_gram, m_cont=30, max_selected=200, verbose=True):
    n = A_sparse.shape[0]
    selected = []
    if seed_indices:
        selected = list(seed_indices)
    available = set(candidate_pool) - set(selected)
    Z_selected = [Zdict[i] for i in selected] if len(selected) > 0 else []
    # initial objective
    if objective == "gramian_trace":
        cur_val = gramian_trace_from_Zlist(Z_selected)
    elif objective == "min_eig":
        cur_val = gramian_min_eig_from_Zlist(Z_selected)
    elif objective == "rank":
        B_current = np.zeros((n, len(selected)))
        for j, idx in enumerate(selected):
            B_current[idx, j] = 1.0
        cur_val = B_ctrlability_rank(A_sparse, B_current, m_cont)
    else:
        cur_val = gramian_trace_from_Zlist(Z_selected)
    if verbose:
        print("Initial objective ({}) = {:.6g} (selected {} genes)".format(objective, cur_val, len(selected)))
    iter_count = 0
    while iter_count < max_selected and len(selected) < n and len(available) > 0:
        iter_count += 1
        best_gain = -np.inf
        best_gene = None
        best_new_val = None
        # iterate candidate pool only
        for g in available:
            Zcand_list = Z_selected + [Zdict[g]]
            if objective == "gramian_trace":
                val = gramian_trace_from_Zlist(Zcand_list)
                gain = val - cur_val
                if gain > best_gain:
                    best_gain = gain; best_new_val = val; best_gene = g
            elif objective == "min_eig":
                val = gramian_min_eig_from_Zlist(Zcand_list)
                gain = val - cur_val
                if gain > best_gain:
                    best_gain = gain; best_new_val = val; best_gene = g
            elif objective == "rank":
                # compute rank increase cheaply for one added actuator
                Bcand = np.zeros((n, len(selected) + 1))
                for j, idx in enumerate(selected):
                    Bcand[idx, j] = 1.0
                Bcand[g, -1] = 1.0
                val = B_ctrlability_rank(A_sparse, Bcand, m_cont)
                if val > cur_val and (best_gene is None or val > best_new_val):
                    best_gain = val - cur_val; best_new_val = val; best_gene = g
            else:
                val = gramian_trace_from_Zlist(Zcand_list)
                gain = val - cur_val
                if gain > best_gain:
                    best_gain = gain; best_new_val = val; best_gene = g
        if best_gene is None:
            if verbose:
                print("No candidate in pool improved objective. Stopping.")
            break
        selected.append(best_gene)
        available.remove(best_gene)
        Z_selected.append(Zdict[best_gene])
        cur_val = best_new_val
        if verbose:
            print(f"Iter {iter_count}: added {genes[best_gene]} (idx {best_gene}), new obj={cur_val:.6g}, selected={len(selected)}")
        # quick stop if rank full
        if objective != "rank":
            B_curr = np.zeros((n, len(selected)))
            for j, idx in enumerate(selected):
                B_curr[idx, j] = 1.0
            r = B_ctrlability_rank(A_sparse, B_curr, m_cont)
            if r == n:
                if verbose:
                    print("Controllability rank reached n. Stopping greedy.")
                break
        else:
            if cur_val == n:
                if verbose:
                    print("Rank objective reached n. Stopping greedy.")
                break
    return selected

In [None]:
# structural unmatched indices already computed earlier
unmatched_indices, driver_genes, matching = structural_driver_nodes(A, GENES)
print("structural lower bound:", len(unmatched_indices))

# build candidate pool: union of structural unmatched nodes + top degree genes
top_k_for_pool = 500   # adjust based on resources (e.g., 200 - 2000)
top_candidates = list(rank_idx[:top_k_for_pool])
candidate_pool = sorted(set(unmatched_indices).union(set(top_candidates)))
print("Candidate pool size:", len(candidate_pool))

# precompute Z for candidate pool (this is the expensive-but-one-time step)
Zdict = precompute_basis_Zs_krylov(A, candidate_pool, T=T_gram, tol=1e-8, max_rank=None)


# run greedy selection using Zdict
seed = unmatched_indices if seed_with_structural else None
selected = greedy_selection_with_Zs_krylov(A, GENES, Zdict, candidate_pool, seed_indices=seed,
                                           objective=objective, T=T_gram, m_cont=m_cont,
                                           max_selected=max_greedy_size, verbose=True)

# save results
df_sel = pd.DataFrame({"index": selected, "gene": [GENES[i] for i in selected]})
df_sel.to_csv("greedy_selected_drivers_Zbased.csv", index=False)
print("Saved greedy_selected_drivers_Zbased.csv")

# final rank check using B_final built from selected set
B_final = np.zeros((n, len(selected)))
for j, idx in enumerate(selected):
    B_final[idx, j] = 1.0
final_rank = B_ctrlability_rank(A, B_final, m_cont)
print("Final controllability rank with selected set:", final_rank, "out of", n)


In [10]:
selected_genes = [    
    "ZNF512",
    "ZNF525",
    "ZNF205",
    "SH3GL1",
    "ZNF550",
    "SLC5A3",
    "ZNF787",
    "PIK3R4",
    "POLN",
    "EPG5",
    "WDR53",
    "ALDH16A1",
    "NRBP1",
    "ARFRP1",
    "EEF1G",
    "PTTG1IP",
    "LTB4R",
    "QTRT1",
    "VPS18",
    "PCSK4",
    "TPM1",
    "APP",
    "NXNL1",
    "PLEKHN1",
    "ZNF587",
    "NEURL4",
    "SRPK3"
]


In [11]:
B, selected = select_ctrl(genes_list=selected_genes)
# B, selected = select_ctrl(1000)

Selected control genes (k=27): ['ZNF512', 'ZNF525', 'ZNF205', 'SH3GL1', 'ZNF550', 'SLC5A3', 'ZNF787', 'PIK3R4', 'POLN', 'EPG5', 'WDR53', 'ALDH16A1', 'NRBP1', 'ARFRP1', 'EEF1G', 'PTTG1IP', 'LTB4R', 'QTRT1', 'VPS18', 'PCSK4', 'TPM1', 'APP', 'NXNL1', 'PLEKHN1', 'ZNF587', 'NEURL4', 'SRPK3']
