In [18]:
#!/usr/bin/env python3
# === DTPSP-style reference time-point selection — single-cell-friendly
# Version with PE_S removed completely.
# Encapsulated Version with Normalization Flag and Fix for AttributeError

from __future__ import annotations
from pathlib import Path
import math, random
from typing import List, Tuple, Optional, Dict, Any
from time import perf_counter

import numpy as np
import anndata as ad
import scanpy as sc

import torch
import torch.nn as nn
import torch.optim as optim

# ------------------ Helpers ------------------
def _fmt_secs(sec):
    if sec < 1e-3: return f"{sec*1e6:.1f}µs"
    if sec < 1.0: return f"{sec*1e3:.1f}ms"
    return f"{sec:.3f}s"

def set_seed(s, device_str="cpu"):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with `s`.

    When `device_str` contains the substring "cuda", every CUDA device
    generator is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
    wants_cuda = "cuda" in device_str
    if wants_cuda:
        torch.cuda.manual_seed_all(s)

def sinusoidal_pe(t_idx: np.ndarray, d_model: int):
    """Sinusoidal positional encoding for a vector of integer positions.

    Parameters
    ----------
    t_idx : array-like of shape (N,)
        Positions (time indices) to encode.
    d_model : int
        Width of the encoding. Even columns hold sines and odd columns hold
        cosines of `pos * 10000^(-2i / d_model)`.

    Returns
    -------
    np.ndarray of shape (N, d_model), dtype float32.

    Notes
    -----
    An odd `d_model` is supported: there is then one more sine column than
    cosine columns, so the cosine block uses only the first `d_model // 2`
    frequencies.
    """
    pos = np.asarray(t_idx)[:, None].astype(np.float32)
    pe = np.zeros((pos.shape[0], d_model), dtype=np.float32)
    div = np.exp(np.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    angles = pos * div
    pe[:, 0::2] = np.sin(angles)
    # For odd d_model the last frequency has no cosine slot; trim it.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return pe

def _cosine_knn_graph(X_GT, k, device):
    # X_GT is (G, T) or similar features per gene
    Xn = X_GT / (np.linalg.norm(X_GT, axis=1, keepdims=True) + 1e-12)
    S_mat = Xn @ Xn.T
    np.fill_diagonal(S_mat, 0)
    rows, cols, vals = [], [], []
    G_ = X_GT.shape[0]
    for g in range(G_):
        idx = np.argpartition(-S_mat[g], k-1)[:k]
        for j in idx:
            rows.append(g); cols.append(j); vals.append(S_mat[g, j])
    rows.extend(range(G_))
    cols.extend(range(G_))
    vals.extend([1.0]*G_)
    edge_i = torch.tensor([rows, cols], dtype=torch.long, device=device)
    edge_v = torch.tensor(vals, dtype=torch.float32, device=device)
    deg = torch.zeros(G_, device=device).scatter_add_(0, edge_i[0], edge_v)
    dinv = deg.pow(-0.5); dinv[deg==0] = 0
    edge_v = dinv[edge_i[0]] * edge_v * dinv[edge_i[1]]
    return torch.sparse_coo_tensor(edge_i, edge_v, (G_, G_)).coalesce()


# ------------------ Model Classes ------------------
class TinyAE(nn.Module):
    """Small fully connected autoencoder.

    The encoder maps `d_in -> hidden_dim -> d_latent` with a ReLU after each
    linear layer; the decoder maps `d_latent -> hidden_dim -> d_in` with a
    ReLU only after its first layer.
    """

    def __init__(self, d_in: int, hidden_dim: int, d_latent: int = 64):
        super().__init__()
        # Exposed so callers can size downstream layers from the latent width.
        self.d_latent = d_latent

        encoder_layers = [
            nn.Linear(d_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, d_latent),
            nn.ReLU(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(d_latent, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, d_in),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return `(reconstruction, latent)` for the input batch `x`."""
        latent = self.enc(x)
        return self.dec(latent), latent

class Regressor(nn.Module):
    """Two-hidden-layer MLP that outputs one scalar per input row.

    `backbone` is Linear -> ReLU -> Dropout -> Linear -> ReLU and `head` is a
    single linear layer of width 1.
    """

    def __init__(self, d_in, hidden_dim, dropout):
        super().__init__()
        trunk = [
            nn.Linear(d_in, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.backbone = nn.Sequential(*trunk)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return predictions of shape `(batch, 1)`."""
        features = self.backbone(x)
        return self.head(features)

class GraphConv(nn.Module):
    """Single graph-convolution layer: `relu(adj @ (x W))`.

    `adj_matrix` is kept as a plain attribute (not a registered buffer), so
    it is not moved by `.to(device)` and must already live on the right
    device.
    """

    def __init__(self, din, dout, adj_matrix):
        super().__init__()
        self.lin = nn.Linear(din, dout, bias=False)
        self.adj = adj_matrix

    def forward(self, x):
        """Project node features, aggregate over the sparse adjacency, apply ReLU."""
        projected = self.lin(x)
        aggregated = torch.sparse.mm(self.adj, projected)
        return torch.relu(aggregated)

class GeneGCN(nn.Module):
    """Two stacked `GraphConv` layers that share one adjacency matrix.

    Feature widths go `t_in -> gcn_dim1 -> gcn_dim2`.
    """

    def __init__(self, t_in, gcn_dim1, gcn_dim2, adj_matrix):
        super().__init__()
        self.gc1 = GraphConv(t_in, gcn_dim1, adj_matrix)
        self.gc2 = GraphConv(gcn_dim1, gcn_dim2, adj_matrix)

    def forward(self, X_mask):
        """Return the second-layer node embeddings for the input `X_mask`."""
        hidden = self.gc1(X_mask)
        return self.gc2(hidden)


# ------------------ Data Utilities ------------------
class ReconDatasetMaskedFullT(torch.utils.data.Dataset):
    """Per-gene reconstruction dataset over the full time axis.

    `X_log` is indexed as (time, gene). Each item is a pair
    `(masked_profile, full_profile)` for one gene from `gene_idx`; the
    masked profile keeps only the time points listed in `S` and zeroes the
    rest.
    """

    def __init__(self, X_log, S, gene_idx, T_total):
        super().__init__()
        self.S = np.array(S)
        self.gene_idx = np.array(gene_idx)

        # Gene-major float32 copy of the data: one row per gene.
        self.X_target = X_log.T.astype(np.float32)

        keep = np.zeros(T_total, dtype=np.float32)
        keep[self.S] = 1
        self.X_masked = self.X_target * keep[None, :]

    def __len__(self):
        return len(self.gene_idx)

    def __getitem__(self, idx):
        gene = self.gene_idx[idx]
        masked_row = torch.from_numpy(self.X_masked[gene])
        target_row = torch.from_numpy(self.X_target[gene])
        return (masked_row, target_row)

class PairSampler(torch.utils.data.IterableDataset):
    """Iterable dataset yielding ready-made regressor batches `(Xb, yb)`.

    Each sample is a (gene, target time) pair. For a batch of pairs the
    features are the concatenation of

    * the autoencoder encoding of the gene's masked profile,
    * the gene's GCN embedding (taken from `z_gcn_full` when given,
      otherwise computed by running `gcn` on the whole masked matrix), and
    * the sinusoidal positional encoding of the target time,

    and the target is `X_log[target_time, gene]`.

    Parameters
    ----------
    genes_all, target_times_all : np.ndarray
        Parallel index arrays defining the pairs (the visible caller builds
        them as int64 via `make_pairs`).
    S : sequence of int
        Reference time points kept in the masked input.
    X_log : np.ndarray
        Expression matrix indexed as (time, gene).
    ae, gcn : nn.Module
        Autoencoder (its `enc` sub-module is used) and gene GCN.
    z_gcn_full : torch.Tensor or None
        Pre-computed GCN embeddings for all genes, or None to compute them.
    batch_g : int
        Number of pairs per yielded batch.
    require_grad : bool
        Whether encodings are computed with autograd enabled.
    device : str or torch.device
    pe_dim : int
        Width of the positional encoding.
    T_total : int
        Number of time points (length of the mask).

    Notes
    -----
    The masked matrix depends only on `S` and `X_log`, so it is built once
    per iteration instead of once per batch. When `require_grad` is False
    the GCN embedding of all genes is likewise computed once per iteration;
    this assumes the GCN weights are not modified while the iterator is
    being consumed, which holds for the evaluation loops in this file.
    """

    def __init__(self, genes_all, target_times_all, S, X_log, ae, gcn, z_gcn_full, batch_g, require_grad, device, pe_dim, T_total):
        super().__init__()
        self.genes_all = genes_all
        self.target_times_all = target_times_all
        self.S = list(S)
        self.X_log = X_log
        self.ae = ae
        self.gcn = gcn
        self.z_gcn_full = z_gcn_full
        self.batch_g = batch_g
        self.require_grad = require_grad
        self.N = len(genes_all)
        self.device = device
        self.pe_dim = pe_dim
        self.T_total = T_total

    def __iter__(self):
        order = np.arange(self.N); np.random.shuffle(order)
        if self.N == 0:
            return

        # ---- loop-invariant inputs: build once per pass ----
        S_arr = np.array(self.S, dtype=np.int64)
        mask = np.zeros(self.T_total, dtype=np.float32); mask[S_arr] = 1
        X_mask = self.X_log.T.astype(np.float32) * mask[None, :]

        X_mask_dev = None
        z_all_static = None
        if self.z_gcn_full is None:
            X_mask_dev = torch.from_numpy(X_mask).to(self.device)
            if not self.require_grad:
                # No gradients wanted: the embedding is the same for every batch.
                with torch.no_grad():
                    z_all_static = self.gcn(X_mask_dev)

        for i in range(0, self.N, self.batch_g):
            idx = order[i:i+self.batch_g]
            g_batch = self.genes_all[idx]
            t_batch = self.target_times_all[idx]
            g_batch_dev = torch.from_numpy(g_batch).to(self.device)

            gv_t = torch.from_numpy(X_mask[g_batch].astype(np.float32)).to(self.device)
            with torch.set_grad_enabled(bool(self.require_grad)):
                z_ae = self.ae.enc(gv_t)

            if self.z_gcn_full is not None:
                z_gcn = self.z_gcn_full[g_batch_dev]
            elif z_all_static is not None:
                z_gcn = z_all_static[g_batch_dev]
            else:
                # Training path: weights change between batches, so the
                # graph pass has to be repeated with autograd enabled.
                with torch.set_grad_enabled(True):
                    z_all = self.gcn(X_mask_dev)
                z_gcn = z_all[g_batch_dev]

            pe_t_np = sinusoidal_pe(t_batch, d_model=self.pe_dim)
            pe_t = torch.from_numpy(pe_t_np).to(self.device)

            Xb = torch.cat([z_ae, z_gcn, pe_t], dim=1)
            yb = torch.from_numpy(
                self.X_log[t_batch, g_batch].astype(np.float32)[:,None]
            ).to(self.device)
            yield Xb, yb


# ------------------ Main Class ------------------
class DTPSP_Selector:
    def __init__(self,
                 seed=1234, device=None,
                 max_ref=15, train_frac=0.8, shuffle_split=True,
                 ae_epochs=10, reg_epochs=10, ft_epochs=10,
                 batch_g=128, lr=1e-3, ft_lr_factor=0.2, lambda_r=1.0,
                 hidden=128, dropout=0.0, pe_dim=16,
                 k_neighb=30, gcn_dim1=128, gcn_dim2=64,
                 target_lib=1e6, beam_width=32, precompute_gcn=False):
        """Store the hyper-parameters and seed all random number generators.

        Every argument is kept on the instance under its upper-cased name
        (e.g. `max_ref` -> `self.MAX_REF`). When `device` is falsy, "cuda"
        is used if available and "cpu" otherwise.
        """
        # Reproducibility / placement
        self.SEED = seed
        self.DEVICE = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Search and gene split
        self.MAX_REF, self.BEAM_WIDTH = max_ref, beam_width
        self.TRAIN_FRAC, self.SHUFFLE_SPLIT = train_frac, shuffle_split

        # Optimisation schedule
        self.AE_EPOCHS, self.REG_EPOCHS, self.FT_EPOCHS = ae_epochs, reg_epochs, ft_epochs
        self.BATCH_G = batch_g
        self.LR, self.FT_LR_FACTOR = lr, ft_lr_factor
        self.LAMBDA_R = lambda_r

        # Network sizes
        self.HIDDEN, self.DROPOUT, self.PE_DIM = hidden, dropout, pe_dim
        self.K_NEIGHB = k_neighb
        self.GCN_DIM1, self.GCN_DIM2 = gcn_dim1, gcn_dim2
        self.PRECOMPUTE_GCN = precompute_gcn

        # Pre-processing
        self.TARGET_LIB = target_lib

        set_seed(self.SEED, self.DEVICE)

    def _normalize(self, adata):
        """Library-size normalise (when `TARGET_LIB` is set) and log1p `adata` in place.

        Returns the same `adata` object that was passed in.
        """
        target_sum = self.TARGET_LIB
        if target_sum is not None:
            sc.pp.normalize_total(adata, target_sum=target_sum)
        sc.pp.log1p(adata)
        return adata

    def build_gene_masked_fullT_tensor(self, X_log, S):
        T_total, G_total = X_log.shape
        X_full = X_log.T.astype(np.float32)
        mask = np.zeros(T_total, dtype=np.float32)
        mask[np.array(S, dtype=np.int64)] = 1.0
        X_mask = X_full * mask[None, :]
        return torch.from_numpy(X_mask).to(self.DEVICE)

    def fit(self, adata_input: ad.AnnData, normalize_data: bool = False, verbose: bool = True):
        """Run the full reference time-point selection on `adata_input`.

        Parameters
        ----------
        adata_input : AnnData
            Expression data with time points as observations and genes as
            variables. If `X` arrives transposed relative to
            `(n_obs, n_vars)` it is flipped back.
        normalize_data : bool
            When True, a copy of the input is normalised with `_normalize`
            first; otherwise the input is used as is and left untouched.
        verbose : bool
            Forwarded to `beam_search_select` to print one line per step.

        Returns
        -------
        (S_sel, pack_sel, hist)
            The selected time points, the trained model pack for that
            selection and the per-step history, exactly as returned by
            `beam_search_select`.

        Raises
        ------
        ValueError
            If the data contain fewer than two genes, since one gene is
            needed for training and one for validation.

        Side effects
        ------------
        Sets `X_log`, `T`, `G`, `times`, `genes_train`, `genes_val` and
        `A_NORM` on the instance and re-seeds all RNGs.
        """
        set_seed(self.SEED, self.DEVICE)

        adata = self._normalize(adata_input.copy()) if normalize_data else adata_input

        # Densify sparse input with `toarray()`; the `.A` shortcut has been
        # removed from current SciPy sparse containers.
        X_src = adata.X
        if hasattr(X_src, "toarray"):
            X0 = np.asarray(X_src.toarray())
        else:
            X0 = np.array(X_src, dtype=np.float32)

        T_obs, G_var = adata.n_obs, adata.n_vars
        if X0.shape == (T_obs, G_var):
            X_raw = X0.astype(np.float32)
        elif X0.shape == (G_var, T_obs):
            X_raw = X0.T.astype(np.float32)
        else:
            # Fallback heuristic: treat the shorter axis as time.
            X_raw = (X0 if X0.shape[0] <= X0.shape[1] else X0.T).astype(np.float32)

        self.X_log = X_raw
        self.T, self.G = self.X_log.shape
        self.times = np.arange(self.T, dtype=np.int64)

        if self.G < 2:
            raise ValueError(
                f"fit() needs at least 2 genes for a train/validation split, got {self.G}"
            )

        # Gene-level train/validation split.
        perm = np.random.permutation(self.G) if self.SHUFFLE_SPLIT else np.arange(self.G)
        n_train = max(1, int(round(self.TRAIN_FRAC * self.G)))
        if n_train >= self.G: n_train = self.G - 1
        genes_train = perm[:n_train]
        genes_val = perm[n_train:]
        if len(genes_val) == 0:
            genes_val = perm[-1:]
            genes_train = perm[:-1]

        self.genes_train = genes_train
        self.genes_val = genes_val

        # Gene-gene graph built from the complete (unmasked) profiles.
        self.A_NORM = _cosine_knn_graph(self.X_log.T, self.K_NEIGHB, self.DEVICE)

        S_sel, pack_sel, hist = self.beam_search_select(
            beam_width=self.BEAM_WIDTH,
            max_ref=self.MAX_REF,
            genes_train_in=self.genes_train,
            genes_val_in=self.genes_val,
            verbose=verbose
        )
        return S_sel, pack_sel, hist

    def make_pairs(self, S, genes_all, use_all_targets=True, n_targets_per_gene=None):
        remaining = np.setdiff1d(self.times, np.array(S))
        if len(remaining) == 0:
            remaining = self.times.copy()
        if use_all_targets:
            t_list = np.repeat(remaining, len(genes_all))
            g_list = np.tile(genes_all, len(remaining))
        else:
            n = int(n_targets_per_gene or 1)
            choices = np.random.choice(remaining, size=(len(genes_all), n), replace=True)
            g_list = np.repeat(genes_all, n)
            t_list = choices.ravel()
        return g_list.astype(np.int64), t_list.astype(np.int64)

    def pretrain_ae_stage1(self, S, gene_idx=None, d_latent=64):
        """Stage 1: train a fresh `TinyAE` to reconstruct the reference time points.

        The autoencoder receives each gene's masked full-length profile and
        is optimised with MSE measured only on the columns listed in `S`.
        `gene_idx` defaults to `self.genes_train`. Returns the trained
        autoencoder, which is left in training mode.
        """
        assert len(S) >= 1
        train_genes = self.genes_train if gene_idx is None else gene_idx

        recon_ds = ReconDatasetMaskedFullT(self.X_log, S, train_genes, self.T)
        recon_dl = torch.utils.data.DataLoader(recon_ds, batch_size=2048, shuffle=True)

        ae = TinyAE(self.T, self.HIDDEN, d_latent).to(self.DEVICE)
        optimizer = optim.Adam(ae.parameters(), lr=self.LR)
        mse = nn.MSELoss()
        ref_cols = torch.tensor(S, dtype=torch.long, device=self.DEVICE)

        ae.train()
        for _epoch in range(self.AE_EPOCHS):
            for masked, target in recon_dl:
                masked = masked.to(self.DEVICE)
                target = target.to(self.DEVICE)

                optimizer.zero_grad()
                recon = ae(masked)[0]
                # Only the reference columns contribute to the loss.
                loss = mse(
                    torch.index_select(recon, 1, ref_cols),
                    torch.index_select(target, 1, ref_cols),
                )
                loss.backward()
                optimizer.step()
        return ae

    def stage2_train_regressor(self, S, train_pairs, val_pairs, ae):
        """Stage 2: train the regressor (and, unless pre-computed, the GCN) on a frozen autoencoder.

        Parameters
        ----------
        S : sequence of int
            Reference time points kept in the masked input.
        train_pairs, val_pairs : tuple
            Each is `(genes, target_times)`, two parallel index arrays.
        ae : TinyAE
            Autoencoder from stage 1. All its parameters get
            `requires_grad = False` here.

        Returns
        -------
        (reg, gcn, z_gcn_full, mae, R2)
            The trained regressor, the GCN, the pre-computed GCN embeddings
            (None unless `PRECOMPUTE_GCN`), and validation MAE / R2. R2 is
            NaN when the validation targets have no variance.
        """
        g_tr, t_tr = train_pairs
        g_va, t_va = val_pairs

        # A new, randomly initialised GCN is created for every call.
        gcn = GeneGCN(self.T, self.GCN_DIM1, self.GCN_DIM2, self.A_NORM).to(self.DEVICE)

        if self.PRECOMPUTE_GCN:
            # Frozen GCN: embed all genes once and reuse the result.
            for p in gcn.parameters(): p.requires_grad = False
            with torch.no_grad():
                X_mask = self.build_gene_masked_fullT_tensor(self.X_log, S)
                z_gcn_full = gcn(X_mask)
        else:
            z_gcn_full = None

        # Regressor input = AE latent + GCN embedding + positional encoding.
        d_in = ae.d_latent + self.GCN_DIM2 + self.PE_DIM
        reg = Regressor(d_in, self.HIDDEN, self.DROPOUT).to(self.DEVICE)

        for p in ae.parameters(): p.requires_grad = False

        params = list(reg.parameters())
        if not self.PRECOMPUTE_GCN:
            params += list(gcn.parameters())
        opt = optim.Adam(params, lr=self.LR)
        loss_fn = nn.L1Loss()

        # The samplers already yield whole batches, hence batch_size=None.
        tr_ds = PairSampler(g_tr, t_tr, S, self.X_log, ae, gcn, z_gcn_full, self.BATCH_G, True, self.DEVICE, self.PE_DIM, self.T)
        va_ds = PairSampler(g_va, t_va, S, self.X_log, ae, gcn, z_gcn_full, self.BATCH_G, False, self.DEVICE, self.PE_DIM, self.T)
        tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=None)
        va_dl = torch.utils.data.DataLoader(va_ds, batch_size=None)

        reg.train()
        gcn.train() if (not self.PRECOMPUTE_GCN) else gcn.eval()
        for _ in range(self.REG_EPOCHS):
            for xb, yb in tr_dl:
                pred = reg(xb)
                loss = loss_fn(pred, yb)
                opt.zero_grad()
                loss.backward()
                opt.step()

        # ---- validation metrics ----
        reg.eval(); gcn.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for xb, yb in va_dl:
                yp = reg(xb)
                y_true.append(yb.cpu().numpy().ravel())
                y_pred.append(yp.cpu().numpy().ravel())
        y_true = np.hstack(y_true)
        y_pred = np.hstack(y_pred)
        mae = float(np.mean(np.abs(y_pred - y_true)))
        y_bar = float(np.mean(y_true))
        sst = float(np.sum((y_true - y_bar)**2))
        # R2 is undefined when the targets are constant.
        R2 = float("nan") if sst <= 0 else 1 - (np.sum((y_pred - y_true)**2) / sst)
        return reg, gcn, z_gcn_full, mae, R2

    def stage3_joint_finetune(self, S, train_pairs, val_pairs, ae, reg, gcn, z_gcn_full):
        """Stage 3: jointly fine-tune autoencoder, regressor and (optionally) GCN.

        Every step minimises `L_AE + LAMBDA_R * L_R`, where `L_R` is the L1
        regression loss on one batch of (gene, time) pairs and `L_AE` is the
        MSE reconstruction loss on the reference columns `S` for one batch
        of training genes. The learning rate is `LR * FT_LR_FACTOR`. When
        `PRECOMPUTE_GCN` is set the GCN stays frozen.

        Parameters
        ----------
        S : sequence of int
            Reference time points.
        train_pairs, val_pairs : tuple
            Each is `(genes, target_times)`, two parallel index arrays.
        ae, reg, gcn : nn.Module
            Models produced by stages 1 and 2; they are updated in place.
        z_gcn_full : torch.Tensor or None
            Pre-computed GCN embeddings, or None.

        Returns
        -------
        (ae, reg, gcn, z_gcn_full,
         mae_tr, R2_tr, mse_tr, mae_va, R2_va, mse_va)
            The models followed by train and validation metrics.
        """
        lr_ft = self.LR * self.FT_LR_FACTOR
        for p in ae.parameters(): p.requires_grad = True

        if self.PRECOMPUTE_GCN:
            for p in gcn.parameters(): p.requires_grad = False
            params = list(ae.parameters()) + list(reg.parameters())
        else:
            params = list(ae.parameters()) + list(reg.parameters()) + list(gcn.parameters())
        opt = optim.Adam(params, lr=lr_ft)
        loss_rec = nn.MSELoss()
        loss_reg = nn.L1Loss()

        recon_ds = ReconDatasetMaskedFullT(self.X_log, S, self.genes_train, self.T)
        recon_dl = torch.utils.data.DataLoader(recon_ds, batch_size=2048, shuffle=True)
        S_idx = torch.tensor(S, dtype=torch.long, device=self.DEVICE)

        g_tr, t_tr = train_pairs
        tr_ds = PairSampler(g_tr, t_tr, S, self.X_log, ae, gcn, z_gcn_full, self.BATCH_G, True, self.DEVICE, self.PE_DIM, self.T)
        tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=None)

        for _ in range(self.FT_EPOCHS):
            ae.train(); reg.train()
            gcn.eval() if self.PRECOMPUTE_GCN else gcn.train()
            recon_iter = iter(recon_dl)

            for xb, yb in tr_dl:
                pred = reg(xb)
                L_R = loss_reg(pred, yb)

                # Cycle the reconstruction loader: it can run out before
                # the pair loader does.
                try:
                    xr, yr = next(recon_iter)
                except StopIteration:
                    recon_iter = iter(recon_dl)
                    xr, yr = next(recon_iter)
                xr = xr.to(self.DEVICE); yr = yr.to(self.DEVICE)
                rec, _ = ae(xr)
                L_AE = loss_rec(
                    rec.index_select(1, S_idx),
                    yr.index_select(1, S_idx)
                )

                L = L_AE + self.LAMBDA_R * L_R
                opt.zero_grad()
                L.backward()
                opt.step()

        # ---- metrics (train first, then validation) ----
        reg.eval(); gcn.eval()
        mae_tr, R2_tr, mse_tr = self._pair_metrics(S, train_pairs, ae, reg, gcn, z_gcn_full)
        mae_va, R2_va, mse_va = self._pair_metrics(S, val_pairs, ae, reg, gcn, z_gcn_full)

        return ae, reg, gcn, z_gcn_full, mae_tr, R2_tr, mse_tr, mae_va, R2_va, mse_va

    def _pair_metrics(self, S, pairs, ae, reg, gcn, z_gcn_full):
        """Evaluate `reg` on `(genes, times)` pairs without gradients; return `(mae, R2, mse)`.

        R2 is NaN when the targets have zero variance.
        """
        g_idx, t_idx = pairs
        ds = PairSampler(g_idx, t_idx, S, self.X_log, ae, gcn, z_gcn_full, self.BATCH_G, False, self.DEVICE, self.PE_DIM, self.T)
        dl = torch.utils.data.DataLoader(ds, batch_size=None)

        y_true, y_pred = [], []
        with torch.no_grad():
            for xb, yb in dl:
                yp = reg(xb)
                y_true.append(yb.cpu().numpy().ravel())
                y_pred.append(yp.cpu().numpy().ravel())
        y_true = np.hstack(y_true)
        y_pred = np.hstack(y_pred)

        err = y_pred - y_true
        mae = float(np.mean(np.abs(err)))
        mse = float(np.mean(err**2))
        y_bar = float(np.mean(y_true))
        sst = np.sum((y_true - y_bar)**2)
        r2 = float("nan") if sst <= 0 else 1 - (np.sum(err**2) / sst)
        return mae, r2, mse

    def _score_and_train_for_S(self, S_try, genes_train_in, genes_val_in):
        ae = self.pretrain_ae_stage1(S_try, gene_idx=genes_train_in)
        
        def _targets_per_gene_for_len(s_len):
            return 1 if s_len <= 1 else 2
            
        ntpg = _targets_per_gene_for_len(len(S_try))
        g_tr, t_tr = self.make_pairs(S_try, genes_train_in, use_all_targets=False, n_targets_per_gene=ntpg)
        g_va, t_va = self.make_pairs(S_try, genes_val_in, use_all_targets=True)

        reg, gcn, z_gcn_full, mae2, R22 = self.stage2_train_regressor(
            S_try, (g_tr, t_tr), (g_va, t_va), ae
        )

        ae, reg, gcn, z_gcn_full, mae_tr, R2_tr, mse_tr, mae_va, R2_va, mse_va = self.stage3_joint_finetune(
            S_try, (g_tr, t_tr), (g_va, t_va), ae, reg, gcn, z_gcn_full
        )

        return (ae, reg, gcn, z_gcn_full), mae_tr, R2_tr, mse_tr, mae_va, R2_va, mse_va

    def beam_search_select(self, beam_width, max_ref, genes_train_in, genes_val_in, verbose=False):
        t0 = perf_counter()
        init_entries = []
        for t0_idx in range(self.T):
            S0 = [t0_idx]
            pack0, mae_tr0, R2_tr0, mse_tr0, mae_va0, R2_va0, mse_va0 = self._score_and_train_for_S(
                S0, genes_train_in, genes_val_in
            )
            init_entries.append({
                "S": S0, "pack": pack0,
                "mae_tr": mae_tr0, "R2_tr": R2_tr0, "mse_tr": mse_tr0,
                "mae_va": mae_va0, "R2_va": R2_va0, "mse_va": mse_va0,
                "added_t": t0_idx
            })
        init_entries.sort(
            key=lambda e: (e["mae_tr"], -np.nan_to_num(e["R2_tr"], nan=-1e9))
        )
        beam = init_entries[:beam_width]

        step_time = perf_counter() - t0
        history = [{
            "step": 1, "S": beam[0]["S"].copy(),
            "added_t": beam[0]["added_t"],
            "TRAIN_MAE": round(beam[0]["mae_tr"], 6),
            "TRAIN_MSE": round(beam[0]["mse_tr"], 6),
            "TRAIN_R2": round(beam[0]["R2_tr"], 6),
            "VAL_MAE": round(beam[0]["mae_va"], 6),
            "VAL_MSE": round(beam[0]["mse_va"], 6),
            "VAL_R2": round(beam[0]["R2_va"], 6),
            "time_sec": step_time
        }]

        if verbose:
            print(
                f"[Step 1] S={history[-1]['S']} added_t={history[-1]['added_t']} "
                f"MAE_train={history[-1]['TRAIN_MAE']:.6f} MSE_train={history[-1]['TRAIN_MSE']:.6f} "
                f"R2_train={history[-1]['TRAIN_R2']:.6f} | "
                f"MAE_val={history[-1]['VAL_MAE']:.6f} MSE_val={history[-1]['VAL_MSE']:.6f} "
                f"R2_val={history[-1]['VAL_R2']:.6f} time={_fmt_secs(history[-1]['time_sec'])}"
            )

        depth = 1
        L_target = min(max_ref, self.T)
        while depth < L_target:
            depth += 1
            d_start = perf_counter()
            candidates = []
            for entry in beam:
                S_curr = entry["S"]
                remaining = [t for t in range(self.T) if t not in S_curr]
                for t_star in remaining:
                    S_try = S_curr + [t_star]
                    pack, mae_tr, R2_tr, mse_tr, mae_va, R2_va, mse_va = \
                        self._score_and_train_for_S(S_try, genes_train_in, genes_val_in)
                    candidates.append({
                        "S": S_try, "pack": pack,
                        "mae_tr": mae_tr, "R2_tr": R2_tr, "mse_tr": mse_tr,
                        "mae_va": mae_va, "R2_va": R2_va, "mse_va": mse_va,
                        "added_t": t_star
                    })
            candidates.sort(
                key=lambda e: (e["mae_tr"], -np.nan_to_num(e["R2_tr"], nan=-1e9))
            )
            beam = candidates[:beam_width]
            d_time = perf_counter() - d_start
            top = beam[0]
            history.append({
                "step": depth,
                "S": top["S"].copy(),
                "added_t": top["added_t"],
                "TRAIN_MAE": round(top["mae_tr"], 6),
                "TRAIN_MSE": round(top["mse_tr"], 6),
                "TRAIN_R2": round(top["R2_tr"], 6),
                "VAL_MAE": round(top["mae_va"], 6),
                "VAL_MSE": round(top["mse_va"], 6),
                "VAL_R2": round(top["R2_va"], 6),
                "time_sec": d_time
            })
            if verbose:
                print(
                    f"[Step {depth}] S={history[-1]['S']} added_t={history[-1]['added_t']} "
                    f"MAE_train={history[-1]['TRAIN_MAE']:.6f} MSE_train={history[-1]['TRAIN_MSE']:.6f} "
                    f"R2_train={history[-1]['TRAIN_R2']:.6f} | "
                    f"MAE_val={history[-1]['VAL_MAE']:.6f} MSE_val={history[-1]['VAL_MSE']:.6f} "
                    f"R2_val={history[-1]['VAL_R2']:.6f} time={_fmt_secs(history[-1]['time_sec'])}"
                )

        best = min(
            beam,
            key=lambda e: (e["mae_tr"], -np.nan_to_num(e["R2_tr"], nan=-1e9))
        )
        return best["S"], best["pack"], history
    
    def predict_full_from_pack(self, S, pack):
        """Predict expression for every gene at every time point.

        Parameters
        ----------
        S : sequence of int
            Reference time points that were used to train `pack`.
        pack : tuple
            `(ae, reg, gcn, z_gcn_full)` as returned by the selection.

        Returns
        -------
        np.ndarray of shape (G, T), dtype float32
            Entry `[g, t]` is the regressor output for gene `g` at time `t`.

        Notes
        -----
        The masked input, the GCN embeddings and the autoencoder encodings
        do not depend on the target time, so they are computed once (per
        gene batch for the encoder) rather than once per (time, batch)
        pair. All three models are switched to eval mode.
        """
        ae, reg, gcn, z_gcn_full = pack
        reg.eval(); ae.eval(); gcn.eval()

        T_, G_ = self.T, self.G
        P = np.zeros((G_, T_), dtype=np.float32)

        with torch.no_grad():
            X_mask_t = self.build_gene_masked_fullT_tensor(self.X_log, S)
            z_gcn_all = z_gcn_full if z_gcn_full is not None else gcn(X_mask_t)

            # One positional-encoding row per target time.
            pe_all = torch.from_numpy(
                sinusoidal_pe(np.arange(T_, dtype=np.int64), d_model=self.PE_DIM)
            ).to(self.DEVICE)

            for start in range(0, G_, self.BATCH_G):
                stop = min(start + self.BATCH_G, G_)
                n_rows = stop - start

                z_ae = ae.enc(X_mask_t[start:stop].contiguous())
                z_gcn = z_gcn_all[start:stop]

                for t in range(T_):
                    pe_t = pe_all[t].unsqueeze(0).expand(n_rows, -1)
                    xb = torch.cat([z_ae, z_gcn, pe_t], dim=1)
                    P[start:stop, t] = reg(xb).cpu().numpy().ravel()
        return P


# Demo driver: load the bulk time-series data, run the beam search and
# predict the full gene x time matrix from the selected reference set.
if __name__ == "__main__":
    DATA_PATH = Path("bulk_data_lungalveoli_TPS.h5ad")
    try:
        adata = ad.read_h5ad(DATA_PATH)
    except Exception as e:
        # Any load failure falls back to random data so the demo still runs.
        print(f"Could not load {DATA_PATH}: {e}")
        print("Creating dummy data for demo...")
        T, G = 10, 500
        X_dummy = np.random.randn(T, G).astype(np.float32)
        adata = ad.AnnData(X=X_dummy)

    # All other hyper-parameters keep their constructor defaults.
    dtpsp = DTPSP_Selector(
        seed=1234,
        max_ref=15,
        beam_width=32,
        precompute_gcn=False
    )

    print("Starting Beam Search...")
    # normalize_data=False: the matrix is used exactly as stored in the file.
    S_sel, pack_sel, hist = dtpsp.fit(adata, normalize_data=False)
    
    print("\nSelected S:", S_sel)
    
    # Rows are genes, columns are time points.
    pred_GT = dtpsp.predict_full_from_pack(S_sel, pack_sel)
    print("Prediction shape:", pred_GT.shape)

Starting Beam Search...
[Step 1] S=[6] added_t=6 MAE_train=0.074042 MSE_train=0.083836 R2_train=0.754327 | MAE_val=0.070176 MSE_val=0.082016 R2_val=0.705493 time=41.706s
[Step 2] S=[6, 0] added_t=0 MAE_train=0.046605 MSE_train=0.033404 R2_train=0.902101 | MAE_val=0.046652 MSE_val=0.034845 R2_val=0.870695 time=499.206s
[Step 3] S=[0, 3, 6] added_t=6 MAE_train=0.036371 MSE_train=0.019087 R2_train=0.943127 | MAE_val=0.039161 MSE_val=0.024394 R2_val=0.910364 time=1892.235s
[Step 4] S=[4, 0, 6, 5] added_t=5 MAE_train=0.031413 MSE_train=0.017745 R2_train=0.946225 | MAE_val=0.033447 MSE_val=0.020655 R2_val=0.922659 time=1473.084s
[Step 5] S=[5, 6, 0, 2, 3] added_t=3 MAE_train=0.024092 MSE_train=0.011499 R2_train=0.964533 | MAE_val=0.027391 MSE_val=0.015175 R2_val=0.947022 time=1087.636s
[Step 6] S=[5, 0, 6, 2, 3, 4] added_t=4 MAE_train=0.016372 MSE_train=0.008374 R2_train=0.972686 | MAE_val=0.018590 MSE_val=0.012695 R2_val=0.956075 time=725.940s
[Step 7] S=[6, 4, 0, 2, 5, 3, 1] added_t=1 MAE_

In [None]:
import pdb,sys,os
import warnings
warnings.filterwarnings('ignore')
import anndata
import scanpy as sc
sc.settings.verbosity = 0
import argparse
import copy
import numpy as np
import scipy
import timeit

from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
from typing import Tuple
import scSemiProfiler as semi
from scSemiProfiler.utils import *
# ---- project configuration ----
name = 'single_cell_inference_project_lung_Alveolus_high'  # project directory
bulk = 'bulk_data_lung_Alveolus.h5ad'                      # bulk expression file
logged = False          # bulk data is not log-transformed yet
normed = False          # bulk data is not library-size normalised yet
geneselection = False
batch = 3

# ---- single-cell data, one AnnData per time point (t0 ... t6) ----
t0, t1, t2, t3, t4, t5, t6 = (
    anndata.read_h5ad(f"single_cell_inference_project_lung_Alveolus_high/sample_sc/t{i}.h5ad")
    for i in range(7)
)

In [None]:
import pdb, sys, os
import anndata
import scanpy as sc
import argparse
import copy
import numpy as np
from sklearn.metrics import pairwise_distances
from typing import Union
import matplotlib.pyplot as plt

def initsetup(name: str, bulk: str, logged: bool = False, normed: bool = True,
              geneselection: Union[bool, int] = True, representatives: list = None) -> None:
    """
    Initial setup of the semi-profiling pipeline, processing the bulk data,
    and assigning each sample to the nearest fixed representative.

    Parameters
    ----------
    name : str
        Project name. A directory of this name is created; if it already
        exists the function prints a message and returns without doing
        anything.
    bulk : str
        Path to bulk data as an h5ad file. Its `obs` must contain a
        `sample_ids` column.
    logged : bool
        Whether the data has been logged or not.
    normed : bool
        Whether the library size has been normalized or not.
    geneselection : bool or int
        `False` keeps all genes, `True` keeps the top 6000 highly variable
        genes, and an integer keeps that many highly variable genes.
    representatives : list
        Indices of fixed representative samples.

    Returns
    -------
    None

    Files written under `name/`
    ---------------------------
    sids.txt, hvgenes.npy, processed_bulkdata.h5ad,
    status/init_cluster_labels.txt, status/init_representatives.txt

    Notes
    -----
    Directories are created with `os.makedirs` rather than a shell
    `mkdir`, so project names containing spaces or shell metacharacters are
    handled safely and the code does not depend on a POSIX shell. The
    "logged but not normalized" input error is reported before anything is
    written to disk, so a failed call no longer leaves an empty project
    directory behind that would block the next attempt.

    Example
    -------
    >>> name = 'runexample'
    >>> bulk = 'example_data/bulkdata.h5ad'
    >>> logged = False
    >>> normed = True
    >>> geneselection = False
    >>> representatives = [0, 2, 5]  # Fixed representative indices
    >>> initsetup(name, bulk, logged, normed, geneselection, representatives)
    """

    print('Start initial setup')

    if os.path.isdir(name):
        print(name + ' exists. Please choose another name.')
        return

    if logged and not normed:
        print('Bad data preprocessing. Normalize library size before log-transformation.')
        return

    os.makedirs(name)
    os.makedirs(name + '/figures', exist_ok=True)

    bulkdata = anndata.read_h5ad(bulk)

    if not normed:
        sc.pp.normalize_total(bulkdata, target_sum=1e4)

    if not logged:
        sc.pp.log1p(bulkdata)

    sids = list(bulkdata.obs['sample_ids'])
    with open(name + '/sids.txt', 'w') as f:
        for sid in sids:
            f.write(sid + '\n')

    if geneselection is not False:
        n_top_genes = 6000 if geneselection is True else int(geneselection)
        sc.pp.highly_variable_genes(bulkdata, n_top_genes=n_top_genes)
        # Copy so later steps write into a real AnnData, not a view.
        bulkdata = bulkdata[:, bulkdata.var.highly_variable].copy()
    hvgenes = np.array(bulkdata.var.index)
    np.save(name + '/hvgenes.npy', hvgenes)

    # PCA cannot have more components than samples - 1.
    n_comps = min(100, bulkdata.X.shape[0] - 1)
    sc.tl.pca(bulkdata, n_comps=n_comps)

    bulkdata.write(name + '/processed_bulkdata.h5ad')

    if representatives is None or len(representatives) == 0:
        print("Please provide fixed representative indices.")
        return

    # Assign every sample to its nearest representative in PCA space.
    representatives_pca = bulkdata.obsm['X_pca'][representatives]
    distances = pairwise_distances(bulkdata.obsm['X_pca'], representatives_pca)
    cluster_labels = np.argmin(distances, axis=1)

    # Store the cluster labels
    os.makedirs(name + '/status', exist_ok=True)

    with open(name + '/status/init_cluster_labels.txt', 'w') as f:
        for label in cluster_labels:
            f.write(str(label) + '\n')

    with open(name + '/status/init_representatives.txt', 'w') as f:
        for rep in representatives:
            f.write(str(rep) + '\n')

    print('Initial setup finished. Among ' + str(len(sids)) +
          ' total samples, assigned to fixed representatives:')
    for i, rep in enumerate(representatives):
        print(f"Cluster {i} representative: {sids[rep]}")

    return
     

In [None]:
# Run the setup with time points 0, 3 and 6 as fixed representatives.
# NOTE(review): `geneselection=True` is hard-coded here and overrides the
# notebook-level `geneselection = False` set in the configuration cell;
# confirm which of the two is intended.
initsetup(name,bulk,logged=logged,normed=normed,geneselection=True,representatives=[0,3,6])

In [None]:
import anndata as ad
import hdf5plugin
import numpy as np

# Stack the single-cell data of the three representative time points,
# keeping only the genes they all share.
reps_processed = ad.concat([t0, t3,t6], axis=0, join='inner')

print(f"Number of observations (cells): {reps_processed.n_obs}")
print(f"Number of variables (genes): {reps_processed.n_vars}")

if 'cell_id' not in reps_processed.obs.columns:
    reps_processed.obs['cell_id'] = reps_processed.obs_names

if 'n_genes' not in reps_processed.obs.columns:
    # For a sparse X the row sum is an (n, 1) matrix; flatten it so the
    # column is a plain 1-D vector for dense and sparse input alike.
    reps_processed.obs['n_genes'] = np.asarray((reps_processed.X > 0).sum(axis=1)).ravel()


if 'gene_ids' not in reps_processed.var.columns:
    reps_processed.var['gene_ids'] = reps_processed.var_names


reps_processed.obs.columns = reps_processed.obs.columns.astype(str)
reps_processed.var.columns = reps_processed.var.columns.astype(str)

# Convert object dtype columns in obs and var to strings
for col in reps_processed.obs.columns:
    if reps_processed.obs[col].dtype == 'object':
        reps_processed.obs[col] = reps_processed.obs[col].astype(str)

for col in reps_processed.var.columns:
    if reps_processed.var[col].dtype == 'object':
        reps_processed.var[col] = reps_processed.var[col].astype(str)

print("Data types in obs:")
print(reps_processed.obs.dtypes)
print("Data types in var:")
print(reps_processed.var.dtypes)

# Genes selected by initsetup(), in the order they were saved.
hvgenes = np.load(name + '/hvgenes.npy', allow_pickle=True)

print("First few genes in hvgenes:", hvgenes[:5])

reps_genes = reps_processed.var_names

common_genes = np.intersect1d(hvgenes, reps_genes)

print(f"Number of genes in hvgenes: {len(hvgenes)}")
print(f"Number of genes in reps_processed: {len(reps_genes)}")
print(f"Number of common genes: {len(common_genes)}")

missing_in_reps = np.setdiff1d(hvgenes, reps_genes)
print(f"Number of genes in hvgenes not in reps_processed: {len(missing_in_reps)}")

# Keep the hvgenes order (np.intersect1d would sort alphabetically).
hvgenes_in_reps_ordered = [gene for gene in hvgenes if gene in reps_genes]

reps_filtered = reps_processed[:, hvgenes_in_reps_ordered].copy()


assert all(reps_filtered.var_names == hvgenes_in_reps_ordered), "Gene order does not match!"

In [None]:
# Save the gene-filtered representative single-cell data under the project
# folder, using the zstd filter supplied by hdf5plugin for compression.
# NOTE(review): reading this file back presumably requires hdf5plugin to be
# imported in the reading process -- confirm.
reps_filtered.write_h5ad(
      name+'/representative_sc.h5ad',
      compression=hdf5plugin.FILTERS["zstd"]
    )

In [None]:
# Hand the representative single-cell file written above to semi.scprocess.
# NOTE(review): presumably this is the preprocessing step for the
# representatives; the flags passed here (normed=True, logged=False,
# cellfilter=False, geneset=True) must match the actual state of
# representative_sc.h5ad -- TODO confirm against the scprocess documentation.
semi.scprocess(name=name,singlecell=name+'/representative_sc.h5ad',normed=True,logged=False,cellfilter=False,threshold=1e-3,geneset=True,weight=0.5,k=15)

In [None]:
# Read back the sample ids, representatives and cluster labels that initsetup
# wrote (one value per line). Context managers guarantee the files are closed
# even if parsing fails.

# Sample ids, in the order they were written (one per line, kept verbatim
# apart from surrounding whitespace so positions still match the bulk data).
with open(name + '/sids.txt', 'r') as f:
    sids = [line.strip() for line in f]

# Row positions of the representative samples; blank lines are ignored.
with open(name + '/status/init_representatives.txt', 'r') as f:
    repres = [int(line.strip()) for line in f if line.strip()]

# Cluster index of every sample (index into `repres`); blank lines are ignored.
with open(name + '/status/init_cluster_labels.txt', 'r') as f:
    cl = [int(line.strip()) for line in f if line.strip()]

print('representatives:',repres)
print('cluster labels:',cl)

In [None]:
import torch

# Ask PyTorch to release cached GPU memory before starting inference.
torch.cuda.empty_cache()


# Paths to the status files written by initsetup.
# NOTE: this rebinds the globals `representatives` and `cluster` to file
# paths (strings), not to lists of indices.
representatives = name + '/status/init_representatives.txt'
cluster = name + '/status/init_cluster_labels.txt'

bulktype = 'pseudobulk'
# NOTE(review): the device is hard-coded to 'cuda:0', so this cell presumably
# needs a CUDA-capable GPU -- confirm whether a CPU fallback is wanted.
semi.scinfer(name, representatives,cluster,bulktype, device='cuda:0')

In [None]:
# Assemble the semi-profiled cohort from the representatives (`repres`) and
# the per-sample cluster labels (`cl`) read from the status files.
# NOTE(review): `bulkpath` is a bare relative file name, so it is resolved
# against the current working directory, not the project folder `name`.
cluster_labels = cl
semisdata = assemble_cohort(name,
                repres,
                cl,
                celltype_key = 'celltype',
                sample_info_keys = ['sample_ids'],
                bulkpath= 'bulk_data_lung_Alveolus.h5ad')
     

In [None]:

# Read the combined AnnData of ground-truth single-cell data from the project
# folder; it is compared against the semi-profiled data in the cells below.
combined_adata = anndata.read_h5ad(name+"/combined_data.h5ad")

In [None]:
#filter out NA celltypes
import pandas as pd

# Values treated as "no cell type". Only the string entries can match after
# the labels are stringified; true missing values are detected with isna().
invalid_values = [None, pd.NA, float('nan'), 'nan', 'NA']

def filter_invalid_celltypes(adata):
    """
    Return a copy of ``adata`` without cells whose cell type is missing.

    A cell is dropped when ``adata.obs['celltype']`` is a real missing value
    (``None``, ``NaN``, ``pd.NA``) or, after conversion to a stripped string,
    is one of the placeholders in ``invalid_values`` (e.g. ``'nan'``, ``'NA'``).

    Parameters
    ----------
    adata:
        Object exposing ``.obs['celltype']`` and supporting boolean row
        indexing followed by ``.copy()`` (an AnnData in this notebook).

    Returns
    -------
        A copy of ``adata`` restricted to the cells with a valid cell type.
    """
    labels = adata.obs['celltype']
    # Test missingness on the raw values: stringifying first would turn None
    # into 'None' and pd.NA into '<NA>', which the placeholder list misses.
    is_missing = labels.isna()
    is_placeholder = labels.astype(str).str.strip().isin(invalid_values)
    keep = ~(is_missing | is_placeholder)
    return adata[keep].copy()

# Drop cells without a valid cell type from both datasets.
# NOTE: both globals are overwritten in place, so the unfiltered objects are
# no longer available after this cell has run.
combined_adata = filter_invalid_celltypes(combined_adata)
semisdata = filter_invalid_celltypes(semisdata)

print(f"Filtered combined_adata cells: {combined_adata.n_obs}")
print(f"Filtered semisdata cells: {semisdata.n_obs}")

# Persist the filtered datasets. The paths are relative, so the files land in
# the current working directory rather than in the project folder `name`.
combined_adata.write_h5ad('combined_adata_filtered.h5ad')
semisdata.write_h5ad('semisdata_filtered.h5ad')

In [None]:

# Visualize the distribution of the assembled ground-truth data and the
# semi-profiled data. The representatives and cluster labels are passed as
# paths to the status files written by initsetup, and figures go to
# <name>/figures.
# The three returned objects are bound to `combined_data`, `gtdata` and
# `semidata`; `semidata.uns['celltype_colors']` is read by a later cell.
combined_data,gtdata,semidata = compare_umaps(
            semidata = semisdata,
            gtdata = combined_adata,
            name = name,
            representatives = name + '/status/init_representatives.txt',
            cluster_labels = name + '/status/init_cluster_labels.txt',
            celltype_key = 'celltype',
            save = name+"/figures"
            )
     

In [None]:
def composition_by_group(
    adata: anndata.AnnData,
    colormap: Union[str, list] = None,
    groupby: str = None,
    title: str = 'Cell type composition',
    save: str = None,
    name: str = None
) -> None:
    """
    Visualizing the cell type composition in each group.

    One horizontal stacked bar is drawn per distinct value of
    ``adata.obs[groupby]``; each segment is the proportion of one cell type
    (categories of ``adata.obs['celltype']``) as returned by
    ``celltype_proportion``.

    Parameters
    ----------
    adata:
        The dataset to investigate. ``adata.obs['celltype']`` must be
        categorical.
    colormap:
        Colors indexed by cell type position (``colormap[i]`` colors the i-th
        category). When None, ``adata.uns['celltype_colors']`` is used, with
        the legacy key ``'celltypes_colors'`` as a fallback.
    groupby:
        The key in .obs specifying groups.
    title:
        Plot title.
    save:
        File name (without extension) to save the plot as a PDF file.
    name:
        Folder to save the file in, if provided.

    Returns
    -------
        None

    Raises
    ------
    ValueError
        If ``adata.obs[groupby]`` contains no groups.

    Example
    -------
    >>> groupby = 'states_collection_sum'
    >>> composition_by_group(
    >>>     adata=gtdata,
    >>>     groupby=groupby,
    >>>     title='Ground truth'
    >>> )
    """
    totaltypes = np.array(adata.obs['celltype'].cat.categories)

    if colormap is None:
        # The palette for obs key 'celltype' is stored as 'celltype_colors';
        # 'celltypes_colors' is kept as a fallback for older data.
        if 'celltype_colors' in adata.uns:
            colormap = adata.uns['celltype_colors']
        else:
            colormap = adata.uns['celltypes_colors']

    conditions = np.unique(adata.obs[groupby])
    n = conditions.shape[0]
    if n == 0:
        raise ValueError(f"No groups found in adata.obs[{groupby!r}]")

    # Per-group cell type proportions, in the order of `conditions`.
    percentages = [
        celltype_proportion(adata[adata.obs[groupby] == condition], totaltypes)
        for condition in conditions
    ]

    # squeeze=False always yields a 2-D array of Axes, so a single group
    # (n == 1) is handled the same way as several.
    fig, axs = plt.subplots(n, 1, figsize=(n, 1), squeeze=False)
    axs = axs[:, 0]
    axs[0].set_title(title)

    for j in range(n):
        ax = axs[j]
        left = 0
        for i in range(len(totaltypes)):
            # Stack the segments: each bar starts where the previous ended.
            ax.barh(conditions[j], percentages[j][i], left=left, color=colormap[i])
            left = left + percentages[j][i]

        ax.set_xlim([0, 1])
        ax.set_yticklabels([])
        ax.yaxis.set_tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

        # Only the bottom bar keeps its x tick labels.
        if j != n - 1:
            ax.set_xticklabels([])

        # Group label to the left of the bar.
        ax.text(-0.01, 0, conditions[j], ha='right', va='center')

    patches = [mpatches.Patch(color=colormap[i], label=totaltypes[i]) for i in range(len(totaltypes))]
    axs[-1].legend(handles=patches, loc='center left', bbox_to_anchor=(1.1, n))

    plt.xlabel('Proportion')

    if save is not None:
        save_path = f"{name}/{save}.pdf" if name else f"{save}.pdf"
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')

    plt.show()


In [None]:
# Visualize the cell type composition of the ground-truth data, one bar per
# value of obs['sample_ids'].
# The palette is taken from `semidata` (returned by compare_umaps above).
# With save="/composition_gt" and name=name the figure is written to
# f"{name}//composition_gt.pdf" (the leading slash doubles the separator).
groupby = 'sample_ids'
composition_by_group(
    adata = combined_adata,
    groupby = groupby,
    title = 'Ground truth',
    colormap = semidata.uns['celltype_colors'],
    save = "/composition_gt",
    name = name
    )

In [None]:
# Compare enrichment results between the ground-truth and semi-profiled data
# for the "AT1" cell type.
# NOTE(review): enrichment_comparison is defined outside this section;
# how it interprets save="figures" should be confirmed there.
enrichment_comparison(name, combined_adata, semisdata, celltype_key = 'celltype', selectedtype = "AT1", save = "figures")

In [None]:
def enrichment_comparison_reactome(name:str,
                                   gtdata:anndata.AnnData,
                                   semisdata:anndata.AnnData,
                                   celltype_key:str,
                                   selectedtype:str,
                                   save = None
                                  ) -> Tuple[int, float, float, float]:
    """
    Compare the enrichment analysis results using the real-profiled and semi-profiled datasets, using Reactome pathway sets.

    For the selected cell type, the top DEGs (t-test ranking, up to 100) of
    each dataset are submitted to Enrichr ('Reactome_2022'). The top terms of
    both analyses are merged, and for every merged term the p-value from each
    dataset is plotted as -log10(p); a term missing from one analysis gets
    p = 1 there.

    Side effects: ``sc.tl.rank_genes_groups`` is run on both inputs (writing to
    their ``.uns``), and Enrichr reports are written to ``<name>/gseapygt`` and
    ``<name>/gseapysemi``.

    Parameters
    ----------
    name:
        Project name
    gtdata:
        Real-profiled (ground truth) data (AnnData object)
    semisdata:
        Semi-profiled dataset (AnnData object)
    celltype_key:
        The key in anndata.AnnData.obs that stores cell type information
    selectedtype:
        The selected cell type to analyze
    save:
        File name prefix; figures are saved as
        ``<name>/figures/<save><selectedtype> Reactome.pdf`` and ``.jpg``

    Returns
    -------
    CommonDEGs : int
        The number of overlapping DEGs between real and semi-profiled data
    HypergeometricP : float
        P-value of hypergeometric test examining the overlap between two versions of DEGs
    PearsonR : float
        Pearson correlation between bar lengths in real-profiled and semi-profiled bar plots
    PearsonP : float
        P-value of the Pearson correlation test

    Raises
    ------
    KeyError
        If ``selectedtype`` is not a group of ``celltype_key`` in either dataset.

    Example
    -------
    >>> _ = enrichment_comparison_reactome(name, gtdata, semisdata, celltype_key='celltypes', selectedtype='CD4')
    """
    n_top_degs = 100   # number of top-ranked DEGs used per dataset
    n_top_terms = 10   # number of top enrichment terms taken from each analysis

    gtdeg = _reactome_top_degs(gtdata, celltype_key, selectedtype, n_top_degs)
    semideg = _reactome_top_degs(semisdata, celltype_key, selectedtype, n_top_degs)

    gt_gene_set = set(gtdeg)
    c = sum(1 for g in semideg if g in gt_gene_set)

    hyperpval = hypert(semisdata.X.shape[1], len(gtdeg), len(semideg), c)
    print('p-value of hypergeometric test for overlapping DEGs:', str(float(hyperpval)))

    gt_outdir = name + '/gseapygt'
    semi_outdir = name + '/gseapysemi'
    os.makedirs(gt_outdir, exist_ok=True)
    os.makedirs(semi_outdir, exist_ok=True)

    gseapy.enrichr(gene_list=gtdeg, gene_sets='Reactome_2022', outdir=gt_outdir)
    gtsets, gtps = _read_enrichr_report(gt_outdir + '/Reactome_2022.human.enrichr.reports.txt')
    gtdic = dict(zip(gtsets, gtps))

    gseapy.enrichr(gene_list=semideg, gene_sets='Reactome_2022', outdir=semi_outdir)
    semisets, semips = _read_enrichr_report(semi_outdir + '/Reactome_2022.human.enrichr.reports.txt')
    semidic = dict(zip(semisets, semips))

    # Start from the top ground-truth terms and look each one up in the
    # semi-profiled results (p = 1 when the term is absent there).
    terms = list(gtsets[:n_top_terms])
    real_data = list(gtps[:n_top_terms])
    sim_data = [semidic.get(term, 1) for term in terms]

    # Add the top semi-profiled terms that are not listed yet, with the
    # matching ground-truth p-value (p = 1 when absent).
    for term, p in zip(semisets[:n_top_terms], semips[:n_top_terms]):
        if term in terms:
            continue
        terms.append(term)
        sim_data.append(p)
        real_data.append(gtdic.get(term, 1))

    # Reverse so the most significant ground-truth term is drawn at the top.
    real_data = np.flip(real_data)
    sim_data = np.flip(sim_data)
    terms = np.flip(terms)
    sim_bar_lengths = [-np.log10(p) for p in sim_data]
    real_bar_lengths = [-np.log10(p) for p in real_data]

    res = scipy.stats.pearsonr(np.array(sim_bar_lengths), np.array(real_bar_lengths))
    print('Significance correlation:', res)

    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(8, 5))
    bar_width = 0.4
    y = np.arange(len(sim_data)) + 1
    ax1.barh(y, real_bar_lengths, height=bar_width, color='green', label='Real')
    ax1.set_xlabel('-log10(p)')
    ax1.set_ylabel('Term')
    ax1.set_title('Real Data (' + str(len(gtdeg)) + ' DEGs)')
    ax2.barh(y, sim_bar_lengths, height=bar_width, color='blue', label='Simulated')
    ax2.set_xlabel('-log10(p)')
    ax2.set_title('Semi-profiled Data(' + str(len(semideg)) + ' DEGs)')

    max_val = max(max(sim_bar_lengths), max(real_bar_lengths))
    ax1.set_xlim(0, max_val + 1)
    ax2.set_xlim(0, max_val + 1)
    ax1.invert_xaxis()
    ax1.set_yticks(y)
    ax2.set_yticklabels(terms)
    fig.suptitle(selectedtype + ' Reactome (' + str(c) + ' Overlap DEGs)')

    if save is not None:
        out_prefix = name + '/figures/' + save + selectedtype + ' Reactome'
        # Make sure the target folder exists before saving.
        os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
        plt.savefig(out_prefix + '.pdf', bbox_inches='tight')
        plt.savefig(out_prefix + '.jpg', dpi=600, bbox_inches='tight')
    plt.show()

    return c, float(hyperpval), res[0], res[1]


def _reactome_top_degs(adata, celltype_key, group, n_top):
    """Rank genes per group with a t-test and return up to ``n_top`` top genes of ``group``."""
    sc.tl.rank_genes_groups(adata, celltype_key, method='t-test')
    ranked_names = adata.uns['rank_genes_groups']['names']
    # Select the column by group name rather than by position, so the result
    # does not depend on how the groups happen to be ordered.
    if str(group) not in (ranked_names.dtype.names or ()):
        raise KeyError(f"{group!r} is not a group of obs[{celltype_key!r}]")
    return ranked_names[str(group)][:n_top].tolist()


def _read_enrichr_report(path):
    """Parse an Enrichr report file into (terms, p-values) taken from columns 1 and 4."""
    terms, pvals = [], []
    with open(path, 'r') as report:
        next(report, None)  # skip the header row
        for line in report:
            fields = line.split('\t')
            if len(fields) < 5:
                continue  # blank or truncated line
            terms.append(fields[1])
            pvals.append(float(fields[4]))
    return terms, pvals


In [None]:
# Reactome enrichment comparison for the "AT1" cell type.
# NOTE: `save` is used as a file-name prefix inside <name>/figures, so with
# save="figures" the output files are
# <name>/figures/figuresAT1 Reactome.pdf and .jpg.
enrichment_comparison_reactome(name, combined_adata, semisdata, celltype_key = 'celltype', selectedtype = "AT1", save = "figures")