# In [None]:
import numpy as np
import pandas as pd
from scipy.special import gamma, digamma, polygamma
from scipy import stats
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KernelDensity, NearestNeighbors
from scipy.signal import find_peaks
import networkx as nx
import os
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

def safe_log(x, epsilon=1e-10):
    """Natural logarithm with the argument floored at ``epsilon``.

    Flooring keeps zeros and negatives from producing ``-inf``/``nan``;
    works element-wise on arrays.
    """
    floored = np.maximum(x, epsilon)
    return np.log(floored)

def h(x, mu, sigma, r, b, epsilon=1e-10):
    """Evaluate the generalized skew-normal (GSN) density at ``x``.

    density = b / (2^(1 + 1/b) * Gamma(1/b) * sigma)
              * exp(-|x - mu|^b / (2 * sigma^b * (1 + r*sign(x - mu))^b))

    ``b``, ``sigma``, the skew factor ``1 + r*sign(x - mu)`` and the exponent's
    denominator are floored at ``epsilon`` to avoid division by zero.
    Element-wise on arrays via NumPy broadcasting.
    """
    shape = np.maximum(b, epsilon)
    scale = np.maximum(sigma, epsilon)
    offset = x - mu

    norm_const = shape / (gamma(1 / shape) * scale * 2 ** (1 + 1 / shape))

    skew = np.maximum(1 + r * np.sign(offset), epsilon)
    spread = 2 * (scale ** shape) * skew ** shape
    kernel = np.exp(-(np.abs(offset) ** shape) / np.maximum(spread, epsilon))
    return norm_const * kernel

def weight(x, mu, sigma, r, b):
    """Per-point M-step weight |x - mu|^(b-2) / |1 + r*sign(x - mu)|^b.

    Both bases are clamped to [1e-8, 1e4], 1e-8 is added to numerator and
    denominator, and the result is capped at 1e4 so the updates stay finite.
    ``sigma`` is accepted for signature symmetry but not used.
    Element-wise on arrays.
    """
    eps, cap = 1e-8, 1e4
    dist = np.minimum(np.maximum(np.abs(x - mu), eps), cap)
    skew = np.minimum(np.maximum(np.abs(1 + r * np.sign(x - mu)), eps), cap)
    ratio = (dist ** (b - 2) + eps) / (skew ** b + eps)
    return np.minimum(ratio, cap)

def update_pi(q):
    """M-step update of a mixing weight: the average responsibility ``q``."""
    return np.average(q)

def update_mu(data, q, mu, sigma, r, b):
    """M-step update of a component's location ``mu``.

    Returns the weighted mean of ``data`` with per-point weights
    ``q[i] * weight(data[i], mu, sigma, r, b)`` evaluated at the current
    parameters. The weight sum is floored at 1e-12 so a component with no
    responsibility does not divide by zero (the result is then 0).

    Computed with vectorized NumPy sums instead of a per-element Python loop.
    """
    data = np.asarray(data, dtype=float)
    q = np.asarray(q, dtype=float)
    wq = q * weight(data, mu, sigma, r, b)
    return np.sum(wq * data) / np.maximum(np.sum(wq), 1e-12)

def update_sigma(data, pi, q, mu, sigma, r, b):
    """M-step update of a component's scale ``sigma``.

    sigma^b = b * sum_i q_i * w_i * (x_i - mu)^2 / (2 * n * pi)
    with w_i = weight(x_i, mu, sigma, r, b) ~ |x_i - mu|^(b-2) / (1 + r*sign)^b,
    so the numerator approximates sum_i q_i |x_i - mu|^b / (1 + r*sign)^b.
    ``pi`` and the denominator are floored at 1e-12.

    Computed with vectorized NumPy sums instead of a per-element Python loop.
    """
    data = np.asarray(data, dtype=float)
    q = np.asarray(q, dtype=float)
    w = weight(data, mu, sigma, r, b)
    result = np.sum(q * w * ((data - mu) ** 2))
    denominator = 2 * len(q) * np.maximum(pi, 1e-12)
    return (b * result / np.maximum(denominator, 1e-12)) ** (1 / b)

def update_r(data, q, mu, r, b):
    """M-step update of a component's skewness ``r`` in [-1, 1].

    Setting the derivative of the expected log-likelihood to zero gives
    ((1 + r) / (1 - r))^(b + 1) = U / L, where
        U = sum_i q_i * (x_i - mu)^b over points with x_i >= mu,
        L = sum_i q_i * (mu - x_i)^b over points with x_i <  mu.
    Hence r = (rho - 1) / (rho + 1) with rho = (U / L)^(1 / (b + 1)),
    computed in log space with 1e-8 added to U and L.

    When L < 1e-8 (essentially no weighted mass below ``mu``) the update
    returns 0.0 (symmetric) rather than the degenerate r -> 1.
    ``r`` (the current value) is accepted for interface compatibility only.
    Accepts arrays or plain sequences for ``data`` and ``q``.
    """
    epsilon = 1e-8
    data = np.asarray(data, dtype=float)
    q = np.asarray(q, dtype=float)

    dev = data - mu
    upper = np.sum(q * np.where(dev >= 0, dev, 0.0) ** b)
    lower = np.sum(q * np.where(dev < 0, -dev, 0.0) ** b)

    if lower < epsilon:
        return 0.0

    ratio = np.exp((np.log(upper + epsilon) - np.log(lower + epsilon)) / (b + 1))
    return np.clip(1 - 2 / (ratio + 1), -1, 1)

def f_val(x, mu, sigma, r, b, epsilon=1e-8):
    """Standardized skewed distance |x - mu| / (sigma * (1 + r*sign(x - mu)) + epsilon).

    ``b`` is unused; it is kept so the signature matches the other GSN helpers.
    Element-wise on arrays.
    """
    delta = x - mu
    skewed_scale = sigma * (1 + r * np.sign(delta)) + epsilon
    return np.abs(delta) / skewed_scale

def update_b(data, pi, q, mu, sigma, r, b, max_iter=10):
    """One Newton-Raphson step for a component's GSN shape parameter ``b``.

    With f_i = f_val(x_i, mu, sigma, r, b) and n = len(data):
        G = n*pi*(1/b + ln2/b^2 + psi(1/b)/b^2) - 0.5 * sum_i q_i f_i^b ln f_i
        H = -n*pi*(1/b^2 + 2 ln2/b^3 + 2 psi(1/b)/b^3 + psi'(1/b)/b^4)
            - 0.5 * sum_i q_i f_i^b (ln f_i)^2
    (first and second derivatives of the expected complete-data
    log-likelihood w.r.t. b). ``pi`` is floored at 1e-12.

    Numerical guards: f is clipped to [1e-5, 1e5] before the power, and
    ln f (floored via ``safe_log``) is clipped to [-1e5, 1e5].

    Args:
        max_iter: unused -- a single Newton step is taken. Kept for
            backward compatibility.

    Returns:
        ``b - G / H``; ``b`` unchanged when |H| < 1e-12; ``max(b, 1e-3)``
        when the step is non-finite or <= 1e-6.
    """
    data = np.asarray(data, dtype=float)
    q = np.asarray(q, dtype=float)

    fv = f_val(data, mu, sigma, r, b)
    fv_pow_b = np.clip(fv, 1e-5, 1e5) ** b
    log_fv = np.clip(safe_log(fv), -1e5, 1e5)

    term2 = np.sum(q * fv_pow_b * log_fv)
    term4 = np.sum(q * fv_pow_b * (log_fv ** 2))

    n_pi = len(data) * np.maximum(pi, 1e-12)
    inv_b = 1 / b

    term1 = inv_b + np.log(2) / (b ** 2) + digamma(inv_b) / (b ** 2)
    G = n_pi * term1 - 0.5 * term2

    term3 = (1 / (b ** 2) + 2 * np.log(2) / (b ** 3)
             + 2 * digamma(inv_b) / (b ** 3) + polygamma(1, inv_b) / (b ** 4))
    H = -n_pi * term3 - 0.5 * term4

    if np.abs(H) < 1e-12:
        return b
    b_new = b - G / H
    if not np.isfinite(b_new) or b_new <= 1e-6:
        b_new = max(b, 1e-3)
    return b_new

def log_likelihood(data, q, pi, mu, sigma, r, b, epsilon=1e-10, max_value=1e5, max_exp=500):
    """Evidence lower bound (ELBO) of the GSN mixture for responsibilities ``q``.

    Returns
        sum_{i,j} q[i, j] * (log pi_j + log f_j(x_i) - log q[i, j])
    where f_j is the GSN density
        log f_j(x) = log b - (1 + 1/b) log 2 - log Gamma(1/b) - log sigma
                     - 0.5 * (|x - mu| / (sigma * (1 + r*sign(x - mu))))^b.
    When ``q`` is the exact posterior this equals the mixture log-likelihood.

    Args:
        data: 1-D observations (length n).
        q: (n, K) responsibilities.
        pi, mu, sigma, r, b: length-K component parameters.
        epsilon: floor applied inside every log and to the skew/scale factor.
        max_value: accepted for backward compatibility; no longer used, since
            the exponent is computed in the stable ``(|x-mu| / scale)^b`` form.
        max_exp: cap on the per-point exponent so far outliers stay finite.
    """
    data = np.asarray(data, dtype=float)
    q = np.asarray(q, dtype=float)
    _, K = q.shape

    def _log(v):
        return np.log(np.maximum(v, epsilon))

    total = 0.0
    for j in range(K):
        b_j = b[j]
        log_norm = (
            _log(pi[j])
            + _log(b_j)
            - (1 + 1 / b_j) * np.log(2)
            - _log(gamma(1 / b_j))
            - _log(sigma[j])
        )
        dev = data - mu[j]
        skew = np.maximum(1 + r[j] * np.sign(dev), epsilon)
        scaled = np.abs(dev) / np.maximum(sigma[j] * skew, epsilon)
        exponent = np.minimum(0.5 * scaled ** b_j, max_exp)

        q_j = q[:, j]
        total += np.sum(q_j * (log_norm - exponent - _log(q_j)))
    return total

def em_algorithm(data, initial_params, max_iter=300, tol=1e-9):
    """Fit a K-component generalized skew-normal (GSN) mixture by EM.

    Args:
        data: 1-D observations (converted to a float array).
        initial_params: ``(pi, mu, sigma, r, b)``, each of length K.
        max_iter: maximum number of EM iterations.
        tol: stop once the relative change of ``log_likelihood`` between two
            consecutive iterations falls below this value.

    Returns:
        ``(pi, mu, sigma, r, b, ll)``: the final parameter arrays and the
        last objective value (``-inf`` if no iteration ran, e.g. max_iter < 1).
    """
    data = np.asarray(data, dtype=float)
    pi, mu, sigma, r, b = (np.asarray(p, dtype=float) for p in initial_params)
    K = len(pi)
    n = len(data)
    q = np.zeros((n, K))
    ll_prev = -np.inf
    ll_curr = -np.inf  # returned unchanged if the loop body never runs

    for t in range(max_iter):
        # E-step: responsibilities proportional to pi_j * density_j(x).
        for j in range(K):
            q[:, j] = pi[j] * h(data, mu[j], sigma[j], r[j], b[j])
        q = q / np.maximum(q.sum(axis=1, keepdims=True), 1e-12)

        # M-step: per component, update pi, then mu (with old sigma/r/b),
        # sigma (with new mu), r (new mu) and finally b (new mu/sigma/r).
        pi_new = np.zeros(K)
        mu_new = np.zeros(K)
        sigma_new = np.zeros(K)
        r_new = np.zeros(K)
        b_new = np.zeros(K)

        for j in range(K):
            pi_new[j] = 1.0 if K == 1 else update_pi(q[:, j])
            mu_new[j] = update_mu(data, q[:, j], mu[j], sigma[j], r[j], b[j])
            sigma_new[j] = update_sigma(data, pi_new[j], q[:, j], mu_new[j], sigma[j], r[j], b[j])
            r_new[j] = update_r(data, q[:, j], mu_new[j], r[j], b[j])
            b_new[j] = update_b(data, pi_new[j], q[:, j], mu_new[j], sigma_new[j], r_new[j], b[j])

        pi, mu, sigma, r, b = pi_new, mu_new, sigma_new, r_new, b_new
        ll_curr = log_likelihood(data, q, pi, mu, sigma, r, b)

        # Relative-change stopping rule; skipped while ll_prev is unusable.
        if t > 0 and np.isfinite(ll_prev) and ll_prev != 0:
            if np.abs(ll_curr / ll_prev - 1) < tol:
                break
        ll_prev = ll_curr

    return pi, mu, sigma, r, b, ll_curr

def generalized_skew_normal_pdf(x, mu, sigma, b, r):
    """GSN density at ``x``; note the parameter order is (mu, sigma, b, r).

    density = b / (2^(1 + 1/b) * Gamma(1/b) * sigma)
              * exp(-(|x - mu| / (sigma * (1 + r*sign(x - mu))))^b / 2)

    ``sigma`` and the skewed scale are floored at 1e-12, and the result is
    floored at 1e-300 so its log is always finite. Element-wise on arrays.
    """
    scale = np.maximum(sigma, 1e-12)
    coef = b / (2 ** (1 + 1 / b) * gamma(1 / b) * scale)
    side_scale = np.maximum(scale * (1 + r * np.sign(x - mu)), 1e-12)
    z = np.abs(x - mu) / side_scale
    density = coef * np.exp(-(z ** b) / 2.0)
    return np.maximum(density, 1e-300)

def generalized_skew_normal_rvs(mu, sigma, b, r, size):
    """Draw ``size`` GSN samples by rejection sampling (NumPy global RNG).

    Proposal: uniform on [mu - W, mu + W] with W = 20 * sigma * (1 + |r|),
    i.e. 20 units of the wider side's scale; mass outside the window is
    truncated.
    Envelope: the density's peak (attained at x = mu), so a candidate is
    accepted when ``u * peak < pdf(x)``. A bare ``u < pdf(x)`` is only valid
    when the peak is <= 1 and silently flattens narrow components otherwise.
    Candidates are drawn in vectorized batches.

    Returns:
        1-D array of length ``size`` (empty for size <= 0).

    Raises:
        ValueError: if the parameters give a non-finite window or peak
            (sampling would otherwise never terminate).
    """
    size = int(size)
    if size <= 0:
        return np.array([])

    half_width = 20 * sigma * (1 + abs(r))
    peak = float(generalized_skew_normal_pdf(mu, mu, sigma, b, r))
    if not (np.isfinite(half_width) and np.isfinite(peak)):
        raise ValueError(f"invalid GSN parameters: mu={mu}, sigma={sigma}, b={b}, r={r}")
    lo, hi = mu - half_width, mu + half_width

    # Upper bound on the per-candidate acceptance probability; used only to size batches.
    accept_rate = min(1.0, 1.0 / max(2 * half_width * peak, 1e-12))

    chunks, n_have = [], 0
    while n_have < size:
        need = size - n_have
        batch = int(min(np.ceil(1.2 * need / accept_rate) + 16, 1_000_000))
        cand = np.random.uniform(lo, hi, batch)
        u = np.random.uniform(0, 1, batch)
        accepted = cand[u * peak < generalized_skew_normal_pdf(cand, mu, sigma, b, r)][:need]
        chunks.append(accepted)
        n_have += len(accepted)
    return np.concatenate(chunks)

def ks_test_mixture(data, pi, mu, sigma, b, r, size):
    """Two-sample KS test of ``data`` against a synthetic draw from the mixture.

    Component i contributes ``int(pi[i] * size)`` samples. Truncation means
    the synthetic sample can be slightly smaller than ``size``. Only the
    samples that are used are generated, which is distributionally the same
    as drawing ``size`` and keeping a prefix, but cheaper.

    Returns:
        ``(ks_statistic, p_value)``, or ``(1.0, 0.0)`` when no synthetic
        samples result (e.g. all ``pi[i] * size < 1``).
    """
    pieces = [
        generalized_skew_normal_rvs(mu[i], sigma[i], b[i], r[i], size=int(pi[i] * size))
        for i in range(len(b))
    ]
    synthetic = np.concatenate(pieces) if pieces else np.array([])
    if len(synthetic) == 0:
        return 1.0, 0.0
    ks_statistic, p_value = stats.ks_2samp(data, synthetic)
    return ks_statistic, p_value

def detect_kde_peaks(samples_mixed, height=0.01, grid_n=1000):
    """Locate local maxima of a Gaussian KDE of ``samples_mixed``.

    The bandwidth follows Silverman's rule of thumb, 1.06 * std * n^(-1/5),
    falling back to 0.1 when the std is not positive. The KDE is evaluated on
    ``grid_n`` evenly spaced points spanning [min, max]; peaks must reach
    ``height`` in density.

    Returns:
        ``(samples as an (n, 1) array, peak x-positions, peak heights)``.
    """
    column = np.array(samples_mixed).reshape(-1, 1)
    spread = np.std(column)
    if spread > 0:
        bw = 1.06 * spread * (len(column) ** (-1 / 5))
    else:
        bw = 0.1

    kde = KernelDensity(bandwidth=bw)
    kde.fit(column)
    grid = np.linspace(np.min(column), np.max(column), grid_n)
    density = np.exp(kde.score_samples(grid[:, np.newaxis]))

    peak_idx, peak_props = find_peaks(density, height=height)
    heights = peak_props.get("peak_heights", density[peak_idx])
    return column, grid[peak_idx], heights

def half_range_mode_auto_h(data):
    """Estimate the mode as the centre of the densest fixed-width window.

    The window width is ``(max - min) / max(1, n // 2)``. For every sorted
    point x_i, the window [x_i, x_i + width] is considered. The window holding
    the most points wins (the first one on ties), and the estimate is the
    midpoint between x_i and the last point inside that window.

    Uses ``np.searchsorted`` (O(n log n)) instead of rescanning forward from
    every point (O(n^2) worst case); the result is the same.

    Returns:
        ``(mode_estimate, width)`` as floats; ``(0.0, 0.0)`` for empty input.
    """
    data = np.sort(np.asarray(data))
    n = len(data)
    if n == 0:
        return 0.0, 0.0
    width = (data[-1] - data[0]) / max(1, n // 2)

    # end[i] = one past the last index j with data[j] <= data[i] + width.
    end = np.searchsorted(data, data + width, side="right")
    counts = end - np.arange(n)
    best = int(np.argmax(counts))  # first maximum, matching a strict ">" scan
    mode_estimate = (data[best] + data[end[best] - 1]) / 2
    return float(mode_estimate), float(width)

def sgn_log_likelihood(theta, x):
    mu, sigma, r, b = theta
    x = np.asarray(x)
    n = len(x)
    if sigma <= 0 or b <= 0:
        return -np.inf

    denominator = sigma * (1 + r * np.sign(x - mu))
    if np.any(denominator == 0):
        return -np.inf
    hhh = np.abs(x - mu) / denominator

    try:
        ll = (
            -n * ((1 + 1 / b) * np.log(2) + gamma(1 / b) + np.log(sigma) - np.log(b))
            - np.sum(0.5 * (hhh ** b))
        )
    except Exception:
        return -np.inf
    return ll

def return_initialparams_for_K(samples_mixed, K_target, height=0.01):
    """Heuristic initial parameters for a K-component GSN mixture.

    Steps:
      1. Detect KDE peaks and keep the ``K_target`` tallest as cluster seeds.
         If fewer peaks exist (including none), fall back to one seed at the
         sample mean and ``K_target = 1``.
      2. Assign every sample to its nearest seed.
      3. Per cluster: mu = ``half_range_mode_auto_h`` mode; sigma = RMS
         deviation about mu, floored at 0.1; r = 1 - 2 * (fraction of points
         <= mu); b = argmax of ``sgn_log_likelihood`` over 2000 values in
         [0.1, 100]. Empty clusters get the global mean/std, r = 0 and b = 2.

    Returns:
        ``(pi, mu, sigma, r, b)``: float64 arrays whose length is the
        (possibly reduced) ``K_target``.
    """
    samples_mixed2, peak_positions, peak_heights = detect_kde_peaks(samples_mixed, height=height)
    K_target = max(int(K_target), 1)

    if len(peak_positions) < K_target:
        # Heights must be replaced together with positions; otherwise the
        # argsort below indexes the old peak list (IndexError), or selects
        # nothing at all when no peak was detected.
        peak_positions = np.array([np.mean(samples_mixed2)])
        peak_heights = np.array([1.0])
        K_target = 1

    chosen = np.argsort(-np.asarray(peak_heights))[:K_target]
    peak_positions = np.asarray(peak_positions)[chosen]

    # Hard-assign each sample to its nearest seed.
    nn = NearestNeighbors(n_neighbors=1)
    nn.fit(peak_positions.reshape(-1, 1))
    _, labels = nn.kneighbors(samples_mixed2)
    labels = labels.flatten()

    cluster_sizes = np.bincount(labels, minlength=K_target)
    pi_init = cluster_sizes / np.maximum(len(samples_mixed2), 1)
    pi_init = np.maximum(pi_init, 1e-12)
    pi_init = pi_init / np.sum(pi_init)

    mu_init = np.zeros(K_target, dtype=np.float64)
    sigma_init = np.zeros(K_target, dtype=np.float64)
    r_init = np.zeros(K_target, dtype=np.float64)
    b_init = np.zeros(K_target, dtype=np.float64)

    b_values = np.linspace(0.1, 100, 2000)  # shared grid for the b search

    for k in range(K_target):
        cluster_points = samples_mixed2[labels == k].flatten()
        if len(cluster_points) == 0:
            mu_init[k] = float(np.mean(samples_mixed2))
            sigma_init[k] = float(np.std(samples_mixed2) + 1e-8)
            r_init[k] = 0.0
            b_init[k] = 2.0
            continue

        mode, _ = half_range_mode_auto_h(cluster_points)
        mu_init[k] = mode

        rms = np.sqrt(np.mean((cluster_points - mu_init[k]) ** 2))
        sigma_init[k] = max(float(rms), 0.1)

        r_init[k] = 1 - 2 * np.mean((cluster_points <= mu_init[k]).astype(int))

        lls = [
            sgn_log_likelihood((mu_init[k], sigma_init[k], r_init[k], bb), cluster_points)
            for bb in b_values
        ]
        b_init[k] = float(b_values[int(np.argmax(lls))])

    return (pi_init, mu_init, sigma_init, r_init, b_init)

def fit_best_by_ks(data, samples_mixed, ks_size=2000, height=0.01,
                   em_max_iter=300, em_tol=1e-9):
    """Fit GSN mixtures for K = n_peaks..1 and keep the best by KS p-value.

    The candidate K values start at the number of KDE peaks of
    ``samples_mixed`` (at least 1). For each K, the model is initialised with
    ``return_initialparams_for_K``, refined by ``em_algorithm`` on ``data``,
    and scored by ``ks_test_mixture``. The first model with the strictly
    largest p-value wins. A summary is printed.

    Returns:
        ``(best_record, all_records, pi, mu, sigma, r, b)`` for the winner.
    """
    _, peaks, _ = detect_kde_peaks(samples_mixed, height=height)
    k_max = max(len(peaks), 1)

    results = []
    best = None
    for K in reversed(range(1, k_max + 1)):
        init = return_initialparams_for_K(samples_mixed, K_target=K, height=height)
        pi, mu, sigma, r, b, ll = em_algorithm(data, init, max_iter=em_max_iter, tol=em_tol)
        ks_stat, p_value = ks_test_mixture(data, pi, mu, sigma, b, r, size=ks_size)

        record = dict(K=K, p_value=p_value, ks_stat=ks_stat,
                      pi=pi, mu=mu, sigma=sigma, r=r, b=b, loglik=ll)
        results.append(record)
        if best is None or p_value > best["p_value"]:
            best = record

    for record in results:
        print(f"K={record['K']}, KS={record['ks_stat']:.6g}, p-value={record['p_value']:.6g}")
    print(f"\nBest: K={best['K']}, p-value={best['p_value']:.6g}, KS={best['ks_stat']:.6g}")

    return best, results, best["pi"], best["mu"], best["sigma"], best["r"], best["b"]


def calc_spearman_network(X, feature_names, n_edges_keep):
    """Graph of the ``n_edges_keep`` strongest pairwise |Spearman rho| links.

    Every column pair of ``X`` is scored by |rho|; undefined correlations
    (NaN, e.g. for constant columns) count as 0. Edges are ranked by
    descending |rho| (ties keep pair order) and the top ``n_edges_keep``
    become weighted edges. Only features touched by a kept edge become nodes.
    """
    n_features = X.shape[1]
    scored = []
    for a in range(n_features):
        for c in range(a + 1, n_features):
            rho, _pval = spearmanr(X[:, a], X[:, c])
            strength = 0 if np.isnan(rho) else rho
            scored.append((feature_names[a], feature_names[c], abs(strength)))

    top = sorted(scored, key=lambda e: e[2], reverse=True)[:n_edges_keep]
    graph = nx.Graph()
    graph.add_edges_from((u, v, {"weight": w}) for u, v, w in top)
    return graph

def network_intersection(G1, G2):
    """Graph holding the undirected edges present in both ``G1`` and ``G2``.

    Edge weights are taken from ``G2``. Every node of ``G1`` is included,
    even if it has no shared edge.
    """
    shared = set(frozenset(e) for e in G1.edges()) & set(frozenset(e) for e in G2.edges())
    merged = nx.Graph()
    for pair in shared:
        a, c = tuple(pair)
        merged.add_edge(a, c, weight=G2[a][c]['weight'])
    merged.add_nodes_from(G1.nodes())
    return merged

def build_mst(G, feature_names):
    """Parent array from a maximum spanning tree of G's largest component.

    The maximum spanning tree (by edge ``weight``) of the largest connected
    component is rooted at that component's first node and traversed
    breadth-first.

    Returns:
        List of length ``len(feature_names)``. ``parents[k]`` is the index
        (in ``feature_names``) of feature k's BFS parent; it is -1 for the
        root, for every feature outside the largest component, and for all
        features when G has no nodes.
    """
    parents = [-1] * len(feature_names)
    if G.number_of_nodes() == 0:
        return parents  # max() over no components would raise ValueError

    largest_comp = max(nx.connected_components(G), key=len)
    G_sub = G.subgraph(largest_comp).copy()
    T = nx.maximum_spanning_tree(G_sub)
    root = next(iter(G_sub.nodes))

    idx = {g: i for i, g in enumerate(feature_names)}
    for parent, child in nx.bfs_edges(T, root):
        parents[idx[child]] = idx[parent]
    return parents

class TAN_GSN_Classifier:
    """Tree-augmented naive Bayes classifier with GSN-mixture conditionals.

    Per class ``c`` the model learns:

    * ``class_probs[c]`` -- the empirical class frequency;
    * ``parents[c]`` -- parent feature index per feature (-1 = none). These
      come from the maximum spanning tree of the largest component of
      (class-specific Spearman network) intersected with (STRING network);
    * per feature, a GSN mixture fitted either to the raw values (no parent)
      or to the residuals of a one-variable ``LinearRegression`` on the parent.

    ``predict`` picks, per row, the class maximising
    log prior + sum of per-feature log mixture densities.
    """

    def __init__(self, height=0.01, em_max_iter=300, em_tol=1e-9):
        self.class_probs = {}   # class -> empirical prior
        self.parents = {}       # class -> list of parent feature indices (-1 = none)
        self.class_params = {}  # class -> per-feature {None: ("mix", pi, mu, sigma, r, b)}
        self.regressors = {}    # class -> per-feature LinearRegression, or None without a parent
        self.height = height    # KDE peak-height threshold forwarded to fit_best_by_ks
        self.em_max_iter = em_max_iter
        self.em_tol = em_tol

    @staticmethod
    def _build_string_net(edges, feature_names):
        """Graph on all features with the STRING edges whose endpoints are both features."""
        feature_set = set(feature_names)  # O(1) membership instead of scanning the list per edge
        net = nx.Graph()
        net.add_nodes_from(feature_names)
        for row in edges.itertuples():
            u, v, w = row.node1, row.node2, row.combined_score
            if u in feature_set and v in feature_set:
                net.add_edge(u, v, weight=w)
        return net

    def _fit_mixture(self, values):
        """Fit a GSN mixture (K chosen by KS p-value) to ``values``; return (pi, mu, sigma, r, b)."""
        values = np.asarray(values, dtype=float)
        _, _, pi, mu, sigma, r, b = fit_best_by_ks(
            data=values,
            samples_mixed=values,
            ks_size=len(values),
            height=self.height,
            em_max_iter=self.em_max_iter,
            em_tol=self.em_tol,
        )
        return pi, mu, sigma, r, b

    def fit(self, X, y, feature_names, string_edge_file):
        """Learn priors, per-class feature trees and per-feature conditionals.

        Args:
            X: (n_samples, n_features) array.
            y: class labels, length n_samples.
            feature_names: column names of X, matched against STRING node names.
            string_edge_file: CSV path with columns node1, node2, combined_score.
        """
        self.classes = np.unique(y)
        n_features = X.shape[1]

        edges = pd.read_csv(string_edge_file)
        # Each Spearman network keeps as many edges as the STRING file has
        # distinct unordered pairs.
        n_edges_keep = len({frozenset((row.node1, row.node2)) for row in edges.itertuples()})
        # The STRING network is class-independent, so build it once.
        string_net = self._build_string_net(edges, feature_names)

        self.parents = {}
        for c in self.classes:
            self.class_probs[c] = np.mean(y == c)
            Xc = X[y == c]
            spearman_net = calc_spearman_network(Xc, feature_names, n_edges_keep)
            intersect_net = network_intersection(spearman_net, string_net)
            self.parents[c] = build_mst(intersect_net, feature_names)

        self.class_params = {c: [{} for _ in range(n_features)] for c in self.classes}
        self.regressors = {c: [None] * n_features for c in self.classes}

        for c in self.classes:
            Xc = X[y == c]
            parent_list = self.parents[c]
            for i in range(n_features):
                vals = Xc[:, i]
                p = parent_list[i]
                if p == -1:
                    target, lr = vals, None
                else:
                    parent_vals = Xc[:, p].reshape(-1, 1)
                    lr = LinearRegression().fit(parent_vals, vals)
                    target = vals - lr.predict(parent_vals)

                pi, mu, sigma, r, b = self._fit_mixture(target)
                self.class_params[c][i][None] = ("mix", pi, mu, sigma, r, b)
                self.regressors[c][i] = lr

    def _mixture_pdf(self, x, pi, mu, sigma, r, b):
        """Mixture density at scalar ``x``, floored at 1e-300."""
        pdf = 0.0
        for k in range(len(pi)):
            pdf += pi[k] * generalized_skew_normal_pdf(x, mu[k], sigma[k], b[k], r[k])
        return np.maximum(pdf, 1e-300)

    def predict(self, X):
        """Return the highest-scoring class for each row of ``X``.

        Non-finite scores (e.g. NaN from degenerate parameters) are treated as
        -inf, so every row still gets a prediction; ties go to the first
        class in ``self.classes``.
        """
        preds = []
        for x in X:
            best_c, best_logp = None, -np.inf
            for c in self.classes:
                lp = np.log(np.maximum(self.class_probs[c], 1e-300))

                for i, val in enumerate(x):
                    p = self.parents[c][i]
                    _, pi, mu, sigma, r, b = self.class_params[c][i][None]

                    if p == -1:
                        resid = val
                    else:
                        lr = self.regressors[c][i]
                        # Equivalent to lr.predict([[x[p]]])[0] for this one-feature
                        # model, without sklearn's per-call validation overhead.
                        resid = val - (lr.intercept_ + lr.coef_[0] * x[p])

                    pdf = self._mixture_pdf(resid, pi, mu, sigma, r, b)
                    lp += np.log(np.maximum(pdf, 1e-300))

                if not np.isfinite(lp):
                    lp = -np.inf
                if best_c is None or lp > best_logp:
                    best_logp, best_c = lp, c

            preds.append(best_c)
        return np.array(preds)


def _read_expression(path):
    """Read an expression CSV indexed by its first column, coercing all cells to numeric."""
    return pd.read_csv(path, index_col=0).apply(pd.to_numeric)


def _design_matrix(normal_df, cancer_df, columns):
    """Samples-by-``columns`` frame (normal samples first, then cancer) and 0/1 labels."""
    stacked = pd.concat([normal_df, cancer_df], axis=1).T
    labels = np.array([0] * normal_df.shape[1] + [1] * cancer_df.shape[1])
    return stacked[columns], labels


# Files prefixed "A" are used for training, files prefixed "B" for testing.
train_cancer = _read_expression(r"Acancer_expression_batch_corrected.csv")
train_normal = _read_expression(r"Anormal_expression_batch_corrected.csv")
test_cancer = _read_expression(r"Bcancer_expression_batch_corrected.csv")
test_normal = _read_expression(r"Bnormal_expression_batch_corrected.csv")

genes_in_pathway = pd.read_csv(r"HSA05226_genes.csv", header=None)[0].tolist()

# Genes present in all four matrices, restricted to the pathway (order follows common_genes).
common_genes = (
    train_cancer.index
    .intersection(train_normal.index)
    .intersection(test_cancer.index)
    .intersection(test_normal.index)
)
_pathway_lookup = set(genes_in_pathway)
features = [g for g in common_genes if g in _pathway_lookup]

X_train_df, y_train = _design_matrix(train_normal, train_cancer, features)
X_train = X_train_df.values
X_test_df, y_test = _design_matrix(test_normal, test_cancer, features)
X_test = X_test_df.values

string_edge_file = r"HSA05226.csv"

# Standardize with training statistics only.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

clf = TAN_GSN_Classifier(height=0.01, em_max_iter=300, em_tol=1e-9)
clf.fit(X_train_scaled, y_train, features, string_edge_file)

# Evaluate on N_RUNS random subsamples (TEST_RATIO of the test set, without
# replacement) and save per-run metrics to CSV.
N_RUNS = 50
TEST_RATIO = 0.8

out_dir = r"result"
os.makedirs(out_dir, exist_ok=True)

results = []
n_test = X_test_scaled.shape[0]
n_sample = int(n_test * TEST_RATIO)

for run in range(N_RUNS):
    np.random.seed(run)  # reproducible subsample per run
    idx = np.random.choice(n_test, n_sample, replace=False)

    X_test_sub = X_test_scaled[idx]
    y_test_sub = y_test[idx]

    y_pred = clf.predict(X_test_sub)

    # labels=[0, 1] forces a 2x2 matrix even when the subsample and its
    # predictions contain only one class; otherwise ravel() has a single
    # element and the four-way unpack raises.
    tn, fp, fn, tp = confusion_matrix(y_test_sub, y_pred, labels=[0, 1]).ravel()

    sen = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    spe = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    youden = sen + spe - 1
    acc = accuracy_score(y_test_sub, y_pred)
    f1 = f1_score(y_test_sub, y_pred, zero_division=0)

    results.append([run + 1, acc, f1, sen, spe, youden])
    print(f"Run {run+1:02d}: ACC={acc:.4f}, F1={f1:.4f}, Sens={sen:.4f}, Spec={spe:.4f}, Youden={youden:.4f}")

df_results = pd.DataFrame(results, columns=["Run", "ACC", "F1", "Sens", "Spec", "Youden"])
out_csv = os.path.join(out_dir, "tan_gsn_test_results_80perc_random50.csv")
df_results.to_csv(out_csv, index=False)

print("=" * 70)
print(f"Test results saved to:\n{out_csv}")
print("=" * 70)
