In [1]:
import muon as mu
import scanpy as sc

In [20]:
# -----------------------------
# Full updated notebook cell: PBMC Multi-modal PPO with GAT embeddings
# -----------------------------
# Imports & settings
# -----------------------------
import os
import math
import time
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
from sklearn.neighbors import NearestNeighbors
import numpy as np

from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    mean_squared_error, mean_absolute_error, r2_score,
    average_precision_score
)

import muon as mu
import mudatasets as mds
import scanpy as sc
from stable_baselines3 import PPO
import torch
from torch import optim
import gym
from gym import spaces

import torch
import torch.nn as nn
from torch.optim import Optimizer

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
# 📦 Imports
import cptac
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import umap
import seaborn as sns
from scipy.sparse import csr_matrix
warnings.filterwarnings("ignore", category=RuntimeWarning)

# -----------------------------
# PyTorch Geometric for GAT embeddings
# -----------------------------
import torch_geometric
from torch_geometric.nn import GATConv
from torch_geometric.data import Data

# -----------------------------
# Reproducibility & device
# -----------------------------
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device("cpu")

MAX_STEPS = 100
PERTURB_PROB = 0.5
MAX_PERTURB = 10
N_EVAL_EPISODES = 30
OUT_DIR = "pbmc_multi_output"
PLOTS_DIR = os.path.join(OUT_DIR, "pseudotime_plots")
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# -----------------------------
# Load & preprocess PBMC multiome
# -----------------------------
mdata = mds.load("pbmc10k_multiome", full=True)
mdata.var_names_make_unique()
print("Available modalities:", list(mdata.mod.keys()))

# Select a small subset of cells (e.g., 1000)
np.random.seed(42)
subset_cells = np.random.choice(mdata.mod['rna'].obs_names, size=5000, replace=False)

# Subset each modality
rna = mdata.mod['rna'][subset_cells].copy()
atac = mdata.mod['atac'][subset_cells].copy()
adt = mdata.mod['adt'][subset_cells].copy() if 'adt' in mdata.mod else None

# -----------------------------
# Build common cell indices
# -----------------------------
common_cells = rna.obs_names.intersection(atac.obs_names)
if adt is not None:
    common_cells = common_cells.intersection(adt.obs_names)
common_cells = np.array(common_cells)
print("Number of common cells:", len(common_cells))

# -----------------------------
# Sparse matrix access instead of slicing AnnData
# -----------------------------
rna_X = csr_matrix(rna.X)
atac_X = csr_matrix(atac.X)
# ADT (optional)
if adt is not None:
    adt_X = csr_matrix(adt.X)

    
    
    
# -----------------------------
# Simple GAT encoder
# -----------------------------
import torch
from torch_geometric.nn import GATConv

class SimpleGAT(nn.Module):
    def __init__(self, in_dim, hidden_dim=8, out_dim=4, heads=2, dropout=0.2):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_dim*heads, out_dim, heads=1, dropout=dropout)

    def forward(self, x, edge_index):
        x = self.gat1(x, edge_index)
        x = torch.relu(x)
        x = self.gat2(x, edge_index)
        return x

# -----------------------------
# kNN graph helper
# -----------------------------
from sklearn.neighbors import NearestNeighbors
import torch

def compute_knn_graph(X, k=5):
    nbrs = NearestNeighbors(n_neighbors=k, metric='cosine').fit(X)
    distances, indices = nbrs.kneighbors(X)
    edge_index = []
    for i in range(X.shape[0]):
        for j in indices[i]:
            if i != j:
                edge_index.append([i, j])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    return edge_index

# -----------------------------
# Compute embeddings
# -----------------------------
def gat_embedding(X_np):
    X = torch.tensor(X_np, dtype=torch.float32)
    edge_index = compute_knn_graph(X_np, k=3)
    model = SimpleGAT(in_dim=X.shape[1])
    with torch.no_grad():
        Z = model(X, edge_index)
    return Z.numpy()

# -----------------------------
# Run embeddings
# -----------------------------
z_rna = gat_embedding(rna_X.toarray())
z_atac = gat_embedding(atac_X.toarray())
if adt is not None:
    z_adt = gat_embedding(adt_X.toarray())

modality_data = {'rna': z_rna, 'atac': z_atac}
if adt is not None:
    modality_data['adt'] = z_adt

print({k:v.shape for k,v in modality_data.items()})


# -----------------------------
# Scale & train/test split
# -----------------------------
expression_train, expression_test = {}, {}
pseudotime_train, pseudotime_test = {}, {}
scalers = {}

for mod, data in modality_data.items():
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(data)
    scalers[mod] = scaler
    train_idx, test_idx = train_test_split(np.arange(data_scaled.shape[0]), test_size=0.2, random_state=SEED)
    expression_train[mod] = data_scaled[train_idx]
    expression_test[mod] = data_scaled[test_idx]
    pseudotime = np.arange(data_scaled.shape[0], dtype=np.float32)
    pseudotime_train[mod] = pseudotime[train_idx]
    pseudotime_test[mod] = pseudotime[test_idx]

# -----------------------------
# Fuse modalities
# -----------------------------
fused_train = np.concatenate([expression_train[m] for m in expression_train.keys()], axis=1)
fused_test = np.concatenate([expression_test[m] for m in expression_test.keys()], axis=1)

modality_dims = {mod: expression_train[mod].shape[1] for mod in expression_train.keys()}
starts = np.cumsum([0] + list(modality_dims.values()))[:-1]
modality_splits = {}
idx = 0
for mod, dim in modality_dims.items():
    modality_splits[mod] = (idx, idx+dim)
    idx += dim

print("Modality splits:", modality_splits)
print("Fused train shape:", fused_train.shape)
    
    
    
# -----------------------------
# Scaling & train/test split
# -----------------------------
expression_train, expression_test = {}, {}
pseudotime_train, pseudotime_test = {}, {}
scalers = {}

for mod, data in modality_data.items():
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(data)
    scalers[mod] = scaler
    train_idx, test_idx = train_test_split(np.arange(data_scaled.shape[0]), test_size=0.2, random_state=SEED)
    expression_train[mod] = data_scaled[train_idx]
    expression_test[mod] = data_scaled[test_idx]
    pseudotime = np.arange(data_scaled.shape[0], dtype=np.float32)
    pseudotime_train[mod] = pseudotime[train_idx]
    pseudotime_test[mod] = pseudotime[test_idx]

fused_train = np.concatenate([expression_train[m] for m in expression_train.keys()], axis=1)
fused_test = np.concatenate([expression_test[m] for m in expression_test.keys()], axis=1)

# -----------------------------
# Modality splits
# -----------------------------
modality_dims = {mod: expression_train[mod].shape[1] for mod in expression_train.keys()}
starts = np.cumsum([0] + list(modality_dims.values()))[:-1]
modality_splits = {}
idx = 0
for mod, dim in modality_dims.items():
    start = idx
    end = idx + dim
    modality_splits[mod] = (start, end)
    idx = end
total_features = fused_train.shape[1]
selected_gene_names = [f"feat_{i}" for i in range(total_features)]
print("Modality splits:", modality_splits)

# -----------------------------
# Adaptive thresholds
# -----------------------------
class PerGeneAdaptiveThreshold:
    def __init__(self, modality_dims, alpha=0.1):
        self.thresholds = {mod: {i: 0.0 for i in range(dim)} for mod, dim in modality_dims.items()}
        self.alpha = alpha
    def update(self, gene_rewards):
        for mod, rewards in gene_rewards.items():
            for gene_id, reward in rewards.items():
                if reward is None or (isinstance(reward, float) and np.isnan(reward)):
                    continue
                self.thresholds[mod][gene_id] = self.alpha*float(reward) + (1-self.alpha)*self.thresholds[mod].get(gene_id,0.0)
    def get(self, mod, gene_id):
        return float(self.thresholds.get(mod, {}).get(gene_id,0.0))

adaptive_thresholds = PerGeneAdaptiveThreshold(modality_dims)



# -----------------------------
# Step 9: Multi-modal PBMC CRISPR environment
# -----------------------------
class PBMC_CRISPR_MultiModalEnv(gym.Env):
    metadata = {"render_modes": ["human"]}

    def __init__(self, expression_dict_or_matrix, pseudotime_dict_or_array, modality_splits,
                 max_steps=MAX_STEPS, adaptive_thresholds=None, device='cpu',
                 action_magnitude=0.25, perturb_prob=0.1, max_perturb=3):
        super().__init__()

        # Accept either dicts (per-modality) or already-fused matrices
        if isinstance(expression_dict_or_matrix, dict):
            # build fused matrix in modality_splits order
            self.expression_dict = {mod: np.asarray(expression_dict_or_matrix[mod], dtype=np.float32) for mod in modality_splits.keys()}
            self.expression = np.concatenate([self.expression_dict[mod] for mod in modality_splits.keys()], axis=1)
            # pseudotime: use first modality's pseudotime (cell-level)
            self.pseudotime = np.asarray(next(iter(pseudotime_dict_or_array.values())), dtype=np.float32)
        else:
            self.expression = np.asarray(expression_dict_or_matrix, dtype=np.float32)
            self.expression_dict = {}
            start = 0
            for mod, (s, e) in modality_splits.items():
                self.expression_dict[mod] = self.expression[:, s:e]
            self.pseudotime = np.asarray(pseudotime_dict_or_array, dtype=np.float32)

        self.modality_splits = modality_splits
        self.modality_dims = {m: (e - s) for m, (s, e) in modality_splits.items()}
        self.n_cells, self.n_genes = self.expression.shape
        self.max_steps = max_steps
        self.adaptive_thresholds = adaptive_thresholds
        self.device = device
        self.action_magnitude = action_magnitude
        self.perturb_prob = perturb_prob
        self.max_perturb = max_perturb

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.n_genes,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.n_genes,), dtype=np.float32)

        # bookkeeping
        self.current_cell = 0

    def reset(self, seed=None, options=None):
        self.idx = np.random.randint(self.n_cells)
        self.state = self.expression[self.idx].copy()
        self.original_state = self.state.copy()

        eligible_idxs = np.where(np.arange(self.n_cells) > self.idx)[0]
        if len(eligible_idxs) == 0:
            eligible_idxs = np.array([self.idx])
        target_idx = np.random.choice(eligible_idxs)
        self.target = self.expression[target_idx].copy()

        self.steps = 0
        self.history = [self.state.copy()]
        self.knockout_genes = set()
        self.overexpressed_genes = set()
        self._apply_crispr_perturbation()
        self.current_cell = 0
        return self.state.copy()

    def _apply_crispr_perturbation(self):
        n_perturb = np.random.randint(1, self.max_perturb + 1)
        for _ in range(n_perturb):
            gene = np.random.randint(0, self.n_genes)
            if np.random.rand() < 0.5:
                self.state[gene] = 0.0
                self.knockout_genes.add(int(gene))
            else:
                self.state[gene] = self.state[gene] * 2.0
                self.overexpressed_genes.add(int(gene))

    def step(self, action):
        action = np.asarray(action, dtype=np.float32).ravel()
        if action.shape[0] != self.n_genes:
            raise ValueError("Action length mismatch.")
        for i, delta in enumerate(action):
            self.state[i] = np.clip(self.state[i] + delta * self.action_magnitude, -5.0, 5.0)

        if np.random.rand() < self.perturb_prob:
            self._apply_crispr_perturbation()

        old_mse = float(np.mean((self.history[-1] - self.target) ** 2))
        new_mse = float(np.mean((self.state - self.target) ** 2))
        reward = old_mse - new_mse

        # subtract adaptive thresholds per modality (if provided)
        if self.adaptive_thresholds is not None:
            for mod, (start, end) in self.modality_splits.items():
                for local_idx, g in enumerate(range(start, end)):
                    reward -= self.adaptive_thresholds.get(mod, local_idx)

        self.steps += 1
        self.history.append(self.state.copy())
        terminated = self.steps >= self.max_steps
        done = terminated
        info = {}
        self.current_cell += 1
        return self.state.copy(), float(reward), done, info

    def render(self, mode='human'):
        print(f"Step {self.steps} - state (first 10): {self.state[:10]}")
        print(f"Knockouts: {sorted(list(self.knockout_genes))[:10]}, Overexpr: {sorted(list(self.overexpressed_genes))[:10]}")

# -----------------------------
# SB3 Env wrapper
# -----------------------------
import gym
from gym import spaces
import numpy as np

class GRNEnvWrapper(gym.Env):
    def __init__(self, base_env):
        super().__init__()
        self.env = base_env
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.env.n_genes,), dtype=np.float32
        )
        self.action_space = self.env.action_space

    def reset(self, **kwargs):
        result = self.env.reset(**kwargs)
        # Handle Gym >=0.26 returning (obs, info)
        if isinstance(result, tuple) and len(result) == 2:
            obs, info = result
            return np.asarray(obs, dtype=np.float32), info
        else:
            return np.asarray(result, dtype=np.float32)

    def step(self, action):
        result = self.env.step(action)
        if len(result) == 5:  # Gym >=0.26
            obs, reward, terminated, truncated, info = result
            done = terminated or truncated
        else:
            obs, reward, done, info = result
        return np.asarray(obs, dtype=np.float32), reward, done, info

    def seed(self, seed=None):
        # Optional: delegate seeding to the base environment
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)
        np.random.seed(seed)
        return [seed]


# -----------------------------
# make_env_factory (multi-modality aware)
# -----------------------------
def make_env_factory(expression, pseudotime, modality_splits, adaptive_thresholds,
                     perturb_prob=PERTURB_PROB, max_perturb=MAX_PERTURB):
    def _init():
        base_env = PBMC_CRISPR_MultiModalEnv(
            expression_dict_or_matrix=expression,
            pseudotime_dict_or_array=pseudotime,
            modality_splits=modality_splits,
            max_steps=MAX_STEPS,
            adaptive_thresholds=adaptive_thresholds,
            device=DEVICE,
            action_magnitude=0.25,
            perturb_prob=perturb_prob,
            max_perturb=max_perturb
        )
        return GRNEnvWrapper(base_env)
    return _init

# -----------------------------
# evaluate_and_plot_multi_modality
# -----------------------------
def evaluate_and_plot_multi_modality(model, algo_name, expression_test, pseudotime_test,
                                     gene_names, modality_splits, adaptive_thresholds,
                                     n_episodes=50, save_dir=PLOTS_DIR):
    os.makedirs(save_dir, exist_ok=True)
    results = []
    eval_env_factory = make_env_factory(expression_test, pseudotime_test, modality_splits, adaptive_thresholds)

    for mod_name, (start_idx, end_idx) in modality_splits.items():
        print(f"Evaluating modality: {mod_name}")
        for gene_idx in range(start_idx, end_idx):
            gene_name = gene_names[gene_idx] if gene_idx < len(gene_names) else f"g{gene_idx}"
            y_true, y_pred = [], []
            perturbed_vals, original_vals, pseudotimes = [], [], []

            for ep in range(n_episodes):
                env = eval_env_factory()
                try:
                    obs = env.reset()
                except Exception:
                    continue

                original = env.env.original_state.copy()
                target = env.env.target.copy()
                pt_idx = getattr(env.env, "current_cell", 0)
                pt_value = float(env.env.pseudotime[pt_idx]) if len(env.env.pseudotime) > 0 else 0.0

                done = False
                traj_pred = []

                while not done:
                    try:
                        action, _ = model.predict(obs, deterministic=True)
                        step_result = env.step(action)
                        if len(step_result) == 5:
                            obs, reward, terminated, truncated, info = step_result
                            done = terminated or truncated
                        else:
                            obs, reward, done, info = step_result
                    except Exception:
                        break
                    traj_pred.append(env.env.state[gene_idx])

                if len(traj_pred) == 0:
                    continue

                final_state = env.env.history[-1]
                delta = float(final_state[gene_idx] - original[gene_idx])
                label = 1 if target[gene_idx] > original[gene_idx] else 0
                prediction = 1 if delta > 0 else 0

                y_true.append(label)
                y_pred.append(prediction)
                perturbed_vals.append(float(final_state[gene_idx]))
                original_vals.append(float(original[gene_idx]))
                pseudotimes.append(pt_value)

            if len(y_true) == 0:
                continue

            # metrics
            acc = accuracy_score(y_true, y_pred)
            prec = precision_score(y_true, y_pred, average="weighted", zero_division=0)
            rec = recall_score(y_true, y_pred, average="weighted", zero_division=0)
            f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
            try:
                auprc = average_precision_score(y_true, y_pred)
            except Exception:
                auprc = np.nan

            mse = mean_squared_error(original_vals, perturbed_vals)
            rmse = math.sqrt(mse)
            mae = mean_absolute_error(original_vals, perturbed_vals)
            r2 = r2_score(original_vals, perturbed_vals)
            pc = np.corrcoef(original_vals, perturbed_vals)[0, 1] if np.std(original_vals) != 0 else 0.0

            results.append({
                "Algorithm": algo_name,
                "Modality": mod_name,
                "Gene": gene_name,
                "Accuracy": acc,
                "Precision": prec,
                "Recall": rec,
                "F1": f1,
                "AUPRC": auprc,
                "Final Expression MSE": mse,
                "Final Expression RMSE": rmse,
                "Final Expression MAE": mae,
                "Final Expression R²": r2,
                "Final Expression PearsonCorr": pc
            })

            # plot pseudotime
            try:
                df = pd.DataFrame({
                    "pseudotime": pseudotimes,
                    "original_expression": original_vals,
                    "perturbed_expression": perturbed_vals
                })
                df['delta'] = df['perturbed_expression'] - df['original_expression']
                df['label'] = df['delta'].apply(lambda x: "Up" if x > 0 else "Down")
                plt.figure(figsize=(8, 4))
                sns.scatterplot(data=df, x="pseudotime", y="perturbed_expression", hue="label", style="label")
                sns.lineplot(data=df.sort_values('pseudotime'), x="pseudotime", y="perturbed_expression", lw=1, alpha=0.5)
                plt.title(f"{algo_name} — {mod_name} — {gene_name} Perturbation")
                plt.xlabel("Pseudotime")
                plt.ylabel("Expression (z-score)")
                plt.grid(True)
                plt.tight_layout()
                plt.savefig(os.path.join(save_dir, f"{algo_name}_{mod_name}_{gene_name}.png"), dpi=300)
                plt.close()
            except Exception as e:
                print("Plot error:", e)

    return pd.DataFrame(results)

# -----------------------------
# Custom optimizers (Padam, ASGDAdam, ASGDAmsgrad)
# -----------------------------
# Implementations included so we can use them as optimizer_class via policy_kwargs in PPO wrappers.
# (Identical implementations as earlier; kept minimal here to register classes.)

def count_nonzero(tensor):
    return int((tensor != 0).sum().item())

class ASGDAdam(Optimizer):
    """ASGD-style optimizer using separate lr_min/lr_max per step."""
    def __init__(self, params, lr=None, beta1=0.9, beta2=0.999, eps=1e-8,
                 lr_min=1e-4, lr_max=3e-4):
        defaults = dict(beta1=beta1, beta2=beta2, eps=eps, lr_min=lr_min, lr_max=lr_max)
        super().__init__(params, defaults)
        self.last_total_nonzero_fmin = 0
        self.last_total_nonzero_fmax = 0
        self.last_lr = lr_max

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        total_nonzero_fmin, total_nonzero_fmax = 0, 0
        for group in self.param_groups:
            beta1, beta2, eps = group['beta1'], group['beta2'], group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if not state:
                    state['t'] = 0
                    state['m'] = torch.zeros_like(p)
                    state['v'] = torch.zeros_like(p)
                    state['v_prev'] = torch.zeros_like(p)
                m, v = state['m'], state['v']
                state['t'] += 1
                t = state['t']
                m.mul_(beta1).add_(g, alpha=1-beta1)
                v.mul_(beta2).addcmul_(g, g, value=1-beta2)
                dv = v - state['v_prev']
                state['v_prev'].copy_(v)
                f_min = (dv > 0).to(dtype=torch.int32)
                f_max = (dv <= 0).to(dtype=torch.int32)
                total_nonzero_fmin += count_nonzero(f_min)
                total_nonzero_fmax += count_nonzero(f_max)
                mhat = m / (1 - beta1 ** t)
                state['step_dir'] = mhat / (v.sqrt().add(eps))
        use_lr_min = (total_nonzero_fmax < total_nonzero_fmin) #
        self.last_total_nonzero_fmin = total_nonzero_fmin
        self.last_total_nonzero_fmax = total_nonzero_fmax
        for group in self.param_groups:
            lr = group['lr_min'] if use_lr_min else group['lr_max']
            self.last_lr = lr
            for p in group['params']:
                if p.grad is not None and 'step_dir' in self.state[p]:
                    p.add_(self.state[p]['step_dir'], alpha=-lr)
        return loss

class ASGDAmsgrad(Optimizer):
    """ASGD-style optimizer mimicking AMSGrad"""
    def __init__(self, params, lr=None, beta1=0.9, beta2=0.999, eps=1e-8,
                 lr_min=1e-5, lr_max=3e-4):
        defaults = dict(beta1=beta1, beta2=beta2, eps=eps, lr_min=lr_min, lr_max=lr_max)
        super().__init__(params, defaults)
        self.last_total_nonzero_fmin = 0
        self.last_total_nonzero_fmax = 0
        self.last_lr = lr_max

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        total_nonzero_fmin, total_nonzero_fmax = 0, 0
        for group in self.param_groups:
            beta1, beta2, eps = group['beta1'], group['beta2'], group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if not state:
                    state['t'] = 0
                    state['m'] = torch.zeros_like(p)
                    state['v'] = torch.zeros_like(p)
                    state['v_hat'] = torch.zeros_like(p)
                    state['v_prev'] = torch.zeros_like(p)
                m, v, v_hat = state['m'], state['v'], state['v_hat']
                state['t'] += 1
                t = state['t']
                m.mul_(beta1).add_(g, alpha=1-beta1)
                v.mul_(beta2).addcmul_(g, g, value=1-beta2)
                torch.maximum(v_hat, v, out=v_hat)
                denom = v_hat.sqrt().add(eps)
                dv = v - state['v_prev']
                state['v_prev'].copy_(v)
                f_min = (dv > 0).to(dtype=torch.int32)
                f_max = (dv <= 0).to(dtype=torch.int32)
                total_nonzero_fmin += count_nonzero(f_min)
                total_nonzero_fmax += count_nonzero(f_max)
                mhat = m / (1 - beta1 ** t)
                state['step_dir'] = mhat / denom
        use_lr_min = (total_nonzero_fmax <  total_nonzero_fmin) # 
        self.last_total_nonzero_fmin = total_nonzero_fmin
        self.last_total_nonzero_fmax = total_nonzero_fmax
        for group in self.param_groups:
            lr = group['lr_min'] if use_lr_min else group['lr_max']
            self.last_lr = lr
            for p in group['params']:
                if p.grad is not None and 'step_dir' in self.state[p]:
                    p.add_(self.state[p]['step_dir'], alpha=-lr)
        return loss

class Padam(Optimizer):
    """Padam optimizer"""
    def __init__(self, params, lr=1e-3, betas=(0.9,0.999), eps=1e-8, weight_decay=0, amsgrad=False, p=0.125):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, p=p)
        super().__init__(params, defaults)
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Padam does not support sparse gradients')
                state = self.state[p]
                amsgrad = group['amsgrad']
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                if amsgrad:
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1
                denom = denom.pow(group['p'])
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
        return loss

    
# --------------------------
# PPO Wrappers
# --------------------------
class PPOAdamAMSGrad(PPO):
    def __init__(self, *args, **kwargs):
        if "policy_kwargs" not in kwargs: kwargs["policy_kwargs"] = {}
        kwargs["policy_kwargs"].update({"optimizer_class": optim.Adam, "optimizer_kwargs": {"amsgrad": True}})
        super().__init__(*args, **kwargs)

class PPOAdam(PPO):
    def __init__(self, *args, **kwargs):
        if "policy_kwargs" not in kwargs: kwargs["policy_kwargs"] = {}
        kwargs["policy_kwargs"].update({"optimizer_class": optim.Adam})
        super().__init__(*args, **kwargs)

class PPOSGD(PPO):
    def __init__(self, *args, **kwargs):
        if "policy_kwargs" not in kwargs: kwargs["policy_kwargs"] = {}
        kwargs["policy_kwargs"].update({"optimizer_class": optim.SGD})
        super().__init__(*args, **kwargs)

class PPOPadam(PPO):
    def __init__(self, *args, **kwargs):
        if "policy_kwargs" not in kwargs: kwargs["policy_kwargs"] = {}
        kwargs["policy_kwargs"].update({"optimizer_class": Padam, "optimizer_kwargs": {"amsgrad": True}})
        super().__init__(*args, **kwargs)

class PPOASGDAdam(PPO):
    def __init__(self, *args, **kwargs):
        if "policy_kwargs" not in kwargs: kwargs["policy_kwargs"] = {}
        kwargs["policy_kwargs"].update({"optimizer_class": ASGDAdam, "optimizer_kwargs": {"lr_min":1e-7, "lr_max":1e-4}})
        super().__init__(*args, **kwargs)

class PPOASGDAmsgrad(PPO):
    def __init__(self, *args, **kwargs):
        if "policy_kwargs" not in kwargs: kwargs["policy_kwargs"] = {}
        kwargs["policy_kwargs"].update({"optimizer_class": ASGDAmsgrad, "optimizer_kwargs": {"lr_min":1e-7, "lr_max":1e-4}})
        super().__init__(*args, **kwargs)


    
    
    
# -----------------------------
# Training orchestration: train_compare_optimizers
# -----------------------------
import os
import pandas as pd
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

import os
import pandas as pd
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

def train_and_evaluate_optimizers(modality_choice="all",
                                  expression_train=None, pseudotime_train=None,
                                  expression_test=None, pseudotime_test=None,
                                  modality_splits=None, gene_names=None,
                                  adaptive_thresholds=None,
                                  optimizers_to_run=None,
                                  train_steps=10000,
                                  save_dir="outputs"):
    
    os.makedirs(save_dir, exist_ok=True)
    
    if optimizers_to_run is None:
        optimizers_to_run = ["adam", "amsgrad"]

    algo_map = {
        "sgd": PPOSGD,
        "amsgrad": PPOAdamAMSGrad,
        "adam": PPOAdam,
        "padam": PPOPadam,
        "asgdadam": PPOASGDAdam,
        "asgdaamsgrad": PPOASGDAmsgrad,
    }

    base_kwargs = {
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "clip_range": 0.2,
        "ent_coef": 0.0,
        "vf_coef": 0.5,
        "max_grad_norm": 0.5,
        "policy_kwargs": dict(net_arch=[dict(pi=[256, 256], vf=[256, 256])], activation_fn=nn.Tanh)
    }

    # -----------------------------
    # Modality-specific data
    # -----------------------------
    if modality_choice == "all":
        expr_train_mod = expression_train
        expr_test_mod = expression_test
        pseudo_train_mod = pseudotime_train[next(iter(pseudotime_train.keys()))]  # pick any modality's pseudotime
        pseudo_test_mod = pseudotime_test[next(iter(pseudotime_test.keys()))]
        gene_names_mod = gene_names
        splits_for_env_mod = modality_splits
    else:
        start, end = modality_splits[modality_choice]
        expr_train_mod = expression_train[:, start:end]
        expr_test_mod = expression_test[:, start:end]
        pseudo_train_mod = pseudotime_train[modality_choice]
        pseudo_test_mod = pseudotime_test[modality_choice]
        splits_for_env_mod = {modality_choice: (0, end - start)}
        gene_names_mod = [f"{modality_choice}_{i}" for i in range(end - start)]

    # -----------------------------
    # Environments
    # -----------------------------
    train_env = DummyVecEnv([make_env_factory(expr_train_mod, pseudo_train_mod, splits_for_env_mod, adaptive_thresholds)])
    eval_env = DummyVecEnv([make_env_factory(expr_test_mod, pseudo_test_mod, splits_for_env_mod, adaptive_thresholds)])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, clip_obs=10.)

    trained_models = {}
    results_reward = {}
    train_metrics_df = []
    test_metrics_df = []

    # -----------------------------
    # Training loop
    # -----------------------------
    for opt_name in optimizers_to_run:
        PPOClass = algo_map[opt_name]
        print(f"\n--- Training {opt_name} on modality={modality_choice} ---")

        model = PPOClass("MlpPolicy", train_env, verbose=1, seed=SEED, **base_kwargs)
        model.learn(total_timesteps=train_steps)
        trained_models[opt_name] = model

        # Evaluate mean reward on test environment
        mean_r, std_r = evaluate_model_sb3(model, eval_env, n_eval_episodes=N_EVAL_EPISODES)
        results_reward[opt_name] = (mean_r, std_r)
        print(f"Eval mean reward ({opt_name}): {mean_r:.4f} ± {std_r:.4f}")

        # -----------------------------
        # Per-gene evaluation on training data
        # -----------------------------
# -----------------------------
# Per-gene evaluation on training data
# -----------------------------
# -----------------------------
# Per-gene evaluation on training data
# -----------------------------
        df_train = evaluate_and_plot_multi_modality(
            model=model,
            algo_name=f"{opt_name}_train",
            expression_test=expr_train_mod,
            pseudotime_test=pseudo_train_mod,
            modality_splits=splits_for_env_mod,  # pass the modality splits
            gene_names=gene_names_mod,
           # encoder_path="gcn_encoder.pth",
            adaptive_thresholds=adaptive_thresholds,
            n_episodes=30,
            save_dir=PLOTS_DIR
        )
        train_metrics_df.append(df_train)

        # -----------------------------
        # Per-gene evaluation on test data
        # -----------------------------
        df_test = evaluate_and_plot_multi_modality(
            model=model,
            algo_name=f"{opt_name}_test",
            expression_test=expr_test_mod,
            pseudotime_test=pseudo_test_mod,
            modality_splits=splits_for_env_mod,
            gene_names=gene_names_mod,
          #  encoder_path="gcn_encoder.pth",
            adaptive_thresholds=adaptive_thresholds,
            n_episodes=30,
            save_dir=PLOTS_DIR
        )
        test_metrics_df.append(df_test)


    # -----------------------------
    # Save training metrics
    # -----------------------------
    final_train_df = pd.concat(train_metrics_df, ignore_index=True)
    final_train_df.to_csv(os.path.join(save_dir, f"training_per_gene_metrics_{modality_choice}.csv"), index=False)
    train_summary_df = final_train_df.groupby("Algorithm").mean(numeric_only=True)
    train_summary_df.to_csv(os.path.join(save_dir, f"training_overall_metrics_{modality_choice}.csv"))

    # -----------------------------
    # Save testing metrics
    # -----------------------------
    final_test_df = pd.concat(test_metrics_df, ignore_index=True)
    final_test_df.to_csv(os.path.join(save_dir, f"testing_per_gene_metrics_{modality_choice}.csv"), index=False)
    test_summary_df = final_test_df.groupby("Algorithm").mean(numeric_only=True)
    test_summary_df.to_csv(os.path.join(save_dir, f"testing_overall_metrics_{modality_choice}.csv"))

    print(f"\n✅ Training & evaluation complete for all optimizers on modality={modality_choice}.")

    return trained_models, results_reward

# -----------------------------
# Evaluate a SB3 model
# -----------------------------
def evaluate_model_sb3(model, env, n_eval_episodes=10, deterministic=True):
    """
    Evaluate a Stable-Baselines3 model.
    
    Returns:
        mean_reward, std_reward
    """
    all_rewards = []

    for episode in range(n_eval_episodes):
        obs = env.reset()
        done = False
        total_reward = 0.0

        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, reward, done, info = env.step(action)
            total_reward += reward

        all_rewards.append(total_reward)

    mean_reward = np.mean(all_rewards)
    std_reward = np.std(all_rewards)
    return mean_reward, std_reward


■ File filtered_feature_bc_matrix.h5 from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/filtered_feature_bc_matrix.h5
■ Checksum is validated (md5) for filtered_feature_bc_matrix.h5
■ File atac_fragments.tsv.gz from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_fragments.tsv.gz
■ Checksum is validated (md5) for atac_fragments.tsv.gz
■ File atac_fragments.tsv.gz.tbi from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_fragments.tsv.gz.tbi
■ Checksum is validated (md5) for atac_fragments.tsv.gz.tbi
■ File atac_peaks.bed from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_peaks.bed
■ Checksum is validated (md5) for atac_peaks.bed
■ File atac_peak_annotation.tsv from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_peak_annotation.tsv
■ Checksum is validated (md5) for atac_peak_annota

  warn("Dataset is in the 10X .h5 format and can't be loaded as backed.")
  utils.warn_names_duplicates("var")


Added `interval` annotation for features from /Users/boabangfrancis/mudatasets/pbmc10k_multiome/filtered_feature_bc_matrix.h5


  utils.warn_names_duplicates("var")
  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


Added peak annotation from /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_peak_annotation.tsv to .uns['atac']['peak_annotation']
Added gene names to peak annotation in .uns['atac']['peak_annotation']
Located fragments file: /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_fragments.tsv.gz
pysam is not available. It is required to work with the fragments file.                 Install pysam from PyPI (`pip install pysam`)                 or from GitHub (`pip install git+https://github.com/pysam-developers/pysam`)
Available modalities: ['rna', 'atac']
Number of common cells: 5000
{'rna': (5000, 4), 'atac': (5000, 4)}
Modality splits: {'rna': (0, 4), 'atac': (4, 8)}
Fused train shape: (4000, 8)
Modality splits: {'rna': (0, 4), 'atac': (4, 8)}


In [23]:
adaptive_thresholds = PerGeneAdaptiveThreshold(modality_dims)
trained_models, rewards = train_and_evaluate_optimizers(
    modality_choice="rna",  # or "all", "atac", "adt"
    expression_train=fused_train,
    pseudotime_train=pseudotime_train,
    expression_test=fused_test,
    pseudotime_test=pseudotime_test,
    modality_splits=modality_splits,
    gene_names=selected_gene_names,
    adaptive_thresholds=adaptive_thresholds,
    optimizers_to_run=["sgd", "adam","amsgrad",  "adam", "padam","asgdadam",  "asgdaamsgrad"],
    train_steps=1000000,
    save_dir="outputs"
)





--- Training sgd on modality=rna ---
Using cpu device
-----------------------------
| time/              |      |
|    fps             | 6143 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 4282          |
|    iterations           | 2             |
|    time_elapsed         | 0             |
|    total_timesteps      | 4096          |
| train/                  |               |
|    approx_kl            | 2.4592446e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -5.68         |
|    explained_variance   | -0.00184      |
|    learning_rate        | 0.0003        |
|    loss                 | 58.2          |
|    n_updates            | 10            |
|    policy_gradient_loss | -0.000219     |
|    std                  | 1  



-----------------------------
| time/              |      |
|    fps             | 7013 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 3683         |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0037106352 |
|    clip_fraction        | 0.0433       |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.7         |
|    explained_variance   | -0.00184     |
|    learning_rate        | 0.0003       |
|    loss                 | 51.4         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00595     |
|    std                  | 1.01         |
|    value_loss           | 449          |
----------------



-----------------------------
| time/              |      |
|    fps             | 6999 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 3215         |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0033206567 |
|    clip_fraction        | 0.0406       |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.69        |
|    explained_variance   | -0.00184     |
|    learning_rate        | 0.0003       |
|    loss                 | 51.3         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00625     |
|    std                  | 1.01         |
|    value_loss           | 449          |
----------------



-----------------------------
| time/              |      |
|    fps             | 7009 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 3447         |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0033206567 |
|    clip_fraction        | 0.0406       |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.69        |
|    explained_variance   | -0.00184     |
|    learning_rate        | 0.0003       |
|    loss                 | 51.3         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00625     |
|    std                  | 1.01         |
|    value_loss           | 449          |
----------------



-----------------------------
| time/              |      |
|    fps             | 7401 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 3026          |
|    iterations           | 2             |
|    time_elapsed         | 1             |
|    total_timesteps      | 4096          |
| train/                  |               |
|    approx_kl            | 1.0802818e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -5.68         |
|    explained_variance   | -0.00184      |
|    learning_rate        | 0.0003        |
|    loss                 | 58.1          |
|    n_updates            | 10            |
|    policy_gradient_loss | -0.000474     |
|    std                  | 1             |
|    value_loss           | 459           



-----------------------------
| time/              |      |
|    fps             | 7258 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 2507         |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0017994352 |
|    clip_fraction        | 0.0122       |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.68        |
|    explained_variance   | -0.00184     |
|    learning_rate        | 0.0003       |
|    loss                 | 55.4         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00671     |
|    std                  | 1            |
|    value_loss           | 456          |
----------------



-----------------------------
| time/              |      |
|    fps             | 7375 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 2580         |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0017975197 |
|    clip_fraction        | 0.0122       |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.68        |
|    explained_variance   | -0.00184     |
|    learning_rate        | 0.0003       |
|    loss                 | 55.4         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00671     |
|    std                  | 1            |
|    value_loss           | 456          |
----------------