In [None]:
import os
import math
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    mean_squared_error, mean_absolute_error, r2_score,
    average_precision_score
)
from scipy.sparse import csr_matrix
import muon as mu
from sb3_contrib.trpo import TRPO
import mudatasets as mds
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
warnings.filterwarnings("ignore", category=RuntimeWarning)

# -----------------------------
# Reproducibility & device
# -----------------------------
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device("cpu")

# -----------------------------
# Run configuration
# -----------------------------
MAX_STEPS = 200        # episode length cap handed to the environment
PERTURB_PROB = 0.8     # per-step probability of an extra random perturbation event
MAX_PERTURB = 40       # upper bound on perturbation draws per perturbation event
N_EVAL_EPISODES = 30   # episodes / cells sampled during evaluation

# OUT_DIR used to be commented out here, so PLOTS_DIR raised NameError unless
# an earlier notebook cell had already defined it. Keep such an earlier value
# when present; otherwise fall back to the name from the commented-out line.
OUT_DIR = globals().get("OUT_DIR", "pbmc_multi_output")
PLOTS_DIR = os.path.join(OUT_DIR, "pseudotime_plots")
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# -----------------------------
# Load PBMC multiome dataset
# -----------------------------
mdata = mds.load("pbmc10k_multiome", full=True)
mdata.var_names_make_unique()
print("Available modalities:", list(mdata.mod.keys()))

# -----------------------------
# Subset cells for speed
# -----------------------------
# Sampling without replacement raises if more cells are requested than exist,
# so the request is capped at the number of available RNA cells.
rna_cells = mdata.mod['rna'].obs_names
n_subset = min(10000, len(rna_cells))
subset_cells = np.random.choice(rna_cells, size=n_subset, replace=False)
rna = mdata.mod['rna'][subset_cells].copy()
atac = mdata.mod['atac'][subset_cells].copy()
adt = mdata.mod['adt'][subset_cells].copy() if 'adt' in mdata.mod else None

# -----------------------------
# Common cells across modalities
# -----------------------------
# Every modality was indexed with the same `subset_cells`; the intersection is
# only reported here and is not used to re-subset the modality objects.
common_cells = rna.obs_names.intersection(atac.obs_names)
if adt is not None:
    common_cells = common_cells.intersection(adt.obs_names)
common_cells = np.array(common_cells)
print("Number of common cells:", len(common_cells))

# -----------------------------
# Sparse matrices
# -----------------------------
rna_X = csr_matrix(rna.X)
atac_X = csr_matrix(atac.X)
adt_X = csr_matrix(adt.X) if adt is not None else None

# -----------------------------
# Simple GAT Encoder
# -----------------------------
class SimpleGAT(nn.Module):
    """Two-layer graph-attention encoder built from ``GATConv`` layers.

    The first layer uses ``heads`` attention heads; the second layer takes
    ``hidden_dim * heads`` input features and emits ``out_dim`` features with
    a single head. ``dropout`` is forwarded to both layers.
    """

    def __init__(self, in_dim, hidden_dim=64, out_dim=32, heads=2, dropout=0.2):
        super().__init__()
        first_layer_width = hidden_dim * heads
        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads, dropout=dropout)
        self.gat2 = GATConv(first_layer_width, out_dim, heads=1, dropout=dropout)

    def forward(self, x, edge_index):
        """Return node embeddings for features ``x`` on graph ``edge_index``."""
        hidden = torch.relu(self.gat1(x, edge_index))
        return self.gat2(hidden, edge_index)

# -----------------------------
# kNN Graph Helper
# -----------------------------
def compute_knn_graph(X, k=5):
    """Build a directed k-nearest-neighbour edge list using cosine distance.

    Args:
        X: Sample-by-feature matrix accepted by ``NearestNeighbors``.
        k: Number of neighbours queried per sample. The samples are queried
            against themselves, so any neighbour equal to the sample's own
            index is dropped as a self-loop.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``: row 0 holds the
        source sample index, row 1 the neighbour index. Edges are ordered by
        source index, then by neighbour rank.
    """
    nbrs = NearestNeighbors(n_neighbors=k, metric='cosine').fit(X)
    _, indices = nbrs.kneighbors(X)

    # Vectorised equivalent of the nested "for i / for j in indices[i]" loop.
    sources = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
    targets = indices.ravel()
    keep = sources != targets

    # Stacking keeps the (2, 0) shape when no edge survives (e.g. k=1), where
    # transposing an empty Python list would have produced shape (0,).
    edges = np.stack([sources[keep], targets[keep]]).astype(np.int64, copy=False)
    return torch.from_numpy(edges).contiguous()

# -----------------------------
# GAT embedding function
# -----------------------------
def gat_embedding(X_np):
    """Embed the rows of ``X_np`` with a freshly initialised ``SimpleGAT``.

    A 3-nearest-neighbour graph is built over the rows and a new, untrained
    ``SimpleGAT`` is applied once, so the result is a random graph-attention
    projection of the input rather than a learned representation.

    Args:
        X_np: Dense 2-D array (samples x features).

    Returns:
        ``numpy.ndarray`` with one embedding row per input row.
    """
    X = torch.tensor(X_np, dtype=torch.float32)
    edge_index = compute_knn_graph(X_np, k=3)
    model = SimpleGAT(in_dim=X.shape[1])
    # Modules start in training mode; without eval() the dropout configured in
    # the GAT layers stays active and randomly perturbs the embeddings.
    model.eval()
    with torch.no_grad():
        Z = model(X, edge_index)
    return Z.numpy()

# -----------------------------
# Compute embeddings per modality
# -----------------------------
# NOTE: .toarray() densifies each modality's full matrix before embedding.
z_rna = gat_embedding(rna_X.toarray())
z_atac = gat_embedding(atac_X.toarray())
z_adt = gat_embedding(adt_X.toarray()) if adt is not None else None

modality_data = {'rna': z_rna, 'atac': z_atac}
if adt is not None:
    modality_data['adt'] = z_adt
print({k:v.shape for k,v in modality_data.items()})

# -----------------------------
# Scale & train/test split
# -----------------------------
expression_train, expression_test = {}, {}
pseudotime_train, pseudotime_test = {}, {}
scalers = {}

for mod, data in modality_data.items():
    # The same random_state and row count give identical index splits for
    # every modality, which the column-wise fusion below relies on.
    train_idx, test_idx = train_test_split(np.arange(data.shape[0]), test_size=0.2, random_state=SEED)
    # Fit the scaler on training rows only so that test-set statistics do not
    # leak into the normalisation applied to the training data.
    scaler = StandardScaler().fit(data[train_idx])
    scalers[mod] = scaler
    expression_train[mod] = scaler.transform(data[train_idx])
    expression_test[mod] = scaler.transform(data[test_idx])
    # Placeholder pseudotime: the row index of each cell.
    pseudotime = np.arange(data.shape[0], dtype=np.float32)
    pseudotime_train[mod] = pseudotime[train_idx]
    pseudotime_test[mod] = pseudotime[test_idx]

# -----------------------------
# Fuse modalities
# -----------------------------
fused_train = np.concatenate([expression_train[m] for m in expression_train.keys()], axis=1)
fused_test = np.concatenate([expression_test[m] for m in expression_test.keys()], axis=1)

modality_dims = {mod: expression_train[mod].shape[1] for mod in expression_train.keys()}
starts = np.cumsum([0] + list(modality_dims.values()))[:-1]
# Column range (start, end) occupied by each modality inside the fused matrix.
modality_splits = {}
idx = 0
for mod, dim in modality_dims.items():
    start = idx
    end = idx + dim
    modality_splits[mod] = (start, end)
    idx = end
total_features = fused_train.shape[1]
selected_gene_names = [f"feat_{i}" for i in range(total_features)]
print("Modality splits:", modality_splits)

# -----------------------------
# Gene names per modality
# -----------------------------
gene_name_mod = {}
for mod, (start, end) in modality_splits.items():
    gene_name_mod[mod] = [f"{mod}_{i}" for i in range(end-start)]

    
    
    
    
class PerGeneAdaptiveThreshold:
    """Exponential moving average of per-feature rewards, kept per modality.

    Args:
        modality_dims: Mapping of modality name to its number of features;
            every feature index starts with a threshold of 0.0.
        alpha: Weight given to the newest reward in each update.
    """

    def __init__(self, modality_dims, alpha=0.1):
        self.thresholds = {mod: {i: 0.0 for i in range(dim)} for mod, dim in modality_dims.items()}
        self.alpha = alpha

    def update(self, gene_rewards):
        """Blend new rewards into the stored thresholds.

        Args:
            gene_rewards: Mapping ``modality -> {feature index -> reward}``.
                ``None`` and NaN rewards are skipped.
        """
        for mod, rewards in gene_rewards.items():
            # Unknown modalities are tracked from zero, matching the tolerant
            # lookups in get(), instead of raising KeyError.
            mod_thresholds = self.thresholds.setdefault(mod, {})
            for gene_id, reward in rewards.items():
                if reward is None:
                    continue
                value = float(reward)
                # Check NaN after conversion: np.float32 is not a `float`
                # subclass, so an isinstance-based guard lets its NaN through.
                if math.isnan(value):
                    continue
                previous = mod_thresholds.get(gene_id, 0.0)
                mod_thresholds[gene_id] = self.alpha * value + (1 - self.alpha) * previous

    def get(self, mod, gene_id):
        """Return the threshold for ``(mod, gene_id)``, or 0.0 if untracked."""
        return float(self.thresholds.get(mod, {}).get(gene_id, 0.0))

# Shared threshold tracker sized from the fused feature layout; every
# threshold starts at 0.0.
adaptive_thresholds = PerGeneAdaptiveThreshold(modality_dims)
    
    
    
    
    
# -----------------------------
# Evaluate & plot per-gene/per-modality
# -----------------------------
def _resolve_gene_names(gene_names, mod_name, start_idx, end_idx):
    """Return one display name per feature in ``[start_idx, end_idx)``."""
    n_features = end_idx - start_idx
    fallback = [f"g{i}" for i in range(n_features)]
    if gene_names is None:
        return fallback
    if isinstance(gene_names, dict):
        # Mapping modality -> list of names local to that modality.
        names = list(gene_names.get(mod_name, fallback))
        return names if len(names) >= n_features else fallback
    names = list(gene_names)
    if len(names) >= end_idx:
        # Flat list aligned with the columns of the evaluated matrix.
        return names[start_idx:end_idx]
    if len(names) == n_features:
        return names
    return fallback


def evaluate_and_plot_multi_modality(model, algo_name, expression_test, pseudotime_test,
                                     gene_names, modality_splits, adaptive_thresholds=None,
                                     n_episodes=N_EVAL_EPISODES, save_dir=PLOTS_DIR):
    """Compute per-feature evaluation metrics for a trained policy.

    The first ``n_episodes`` rows of ``expression_test`` (wrapping around if
    there are fewer rows) are fed to ``model.predict``; the predicted value of
    each feature is ``clip(observation + action, -5, 5)``. Classification
    metrics compare the signs of the observed and predicted values, regression
    metrics compare the values themselves.

    Despite the name, no plots are produced; ``save_dir`` is only created.

    Args:
        model: Object with an SB3-style ``predict(obs, deterministic=True)``.
        algo_name: Label written to the ``Algorithm`` column.
        expression_test: 2-D array (cells x features) matching the policy's
            observation layout.
        pseudotime_test: Accepted for interface compatibility; not used.
        gene_names: ``{modality: [names]}``, a flat list of names, or ``None``.
            Missing or too-short entries fall back to ``g0, g1, ...``.
        modality_splits: ``{modality: (start, end)}`` column ranges.
        adaptive_thresholds: Accepted for interface compatibility; not used.
        n_episodes: Number of cells evaluated per feature.
        save_dir: Directory created if missing.

    Returns:
        ``pandas.DataFrame`` with one row per (modality, feature).

    Raises:
        ValueError: If there are no cells or ``n_episodes`` is below 1.
    """
    os.makedirs(save_dir, exist_ok=True)

    expression_test = np.asarray(expression_test)
    n_cells = expression_test.shape[0]
    if n_cells == 0 or n_episodes < 1:
        raise ValueError("Evaluation needs at least one cell and one episode.")

    # A deterministic prediction depends only on the observation, so query the
    # policy once per sampled cell instead of once per (feature, cell) pair.
    cell_ids = np.arange(n_episodes) % n_cells
    observations = expression_test[cell_ids]
    actions = np.stack([
        np.asarray(model.predict(obs, deterministic=True)[0]).ravel()
        for obs in observations
    ])
    # NOTE(review): the action is added unscaled here; confirm whether it
    # should be scaled the way the training environment scales its actions.

    results = []
    for mod_name, (start_idx, end_idx) in modality_splits.items():
        print(f"Evaluating modality: {mod_name} (genes {start_idx}:{end_idx})")
        mod_gene_names = _resolve_gene_names(gene_names, mod_name, start_idx, end_idx)

        for gene_idx in range(start_idx, end_idx):
            original_vals = observations[:, gene_idx]
            perturbed_vals = np.clip(original_vals + actions[:, gene_idx], -5, 5)

            # Binary labels: is the (scaled) value above zero?
            y_true = (original_vals > 0).astype(int)
            y_pred = (perturbed_vals > 0).astype(int)

            acc = accuracy_score(y_true, y_pred)
            prec = precision_score(y_true, y_pred, average="weighted", zero_division=0)
            rec = recall_score(y_true, y_pred, average="weighted", zero_division=0)
            f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
            try:
                auprc = average_precision_score(y_true, y_pred)
            except Exception:
                # Deliberate best-effort: a degenerate label set must not abort
                # the whole per-feature sweep.
                auprc = np.nan

            mse = mean_squared_error(original_vals, perturbed_vals)
            rmse = math.sqrt(mse)
            mae = mean_absolute_error(original_vals, perturbed_vals)
            r2 = r2_score(original_vals, perturbed_vals)
            # Correlation is undefined when either series is constant.
            if np.std(original_vals) != 0 and np.std(perturbed_vals) != 0:
                pc = np.corrcoef(original_vals, perturbed_vals)[0, 1]
            else:
                pc = 0.0

            results.append({
                "Algorithm": algo_name,
                "Modality": mod_name,
                "Gene": mod_gene_names[gene_idx - start_idx],
                "Accuracy": acc,
                "Precision": prec,
                "Recall": rec,
                "F1": f1,
                "AUPRC": auprc,
                "Final Expression MSE": mse,
                "Final Expression RMSE": rmse,
                "Final Expression MAE": mae,
                "Final Expression R²": r2,
                "Final Expression PearsonCorr": pc
            })

    return pd.DataFrame(results)

# -----------------------------
# Example usage
# -----------------------------
# df_metrics = evaluate_and_plot_multi_modality(
#     model=None,
#     algo_name="PPO",
#     expression_test=fused_test,
#     pseudotime_test=pseudotime_test,
#     gene_name_mod=gene_name_mod,
#     modality_splits=modality_splits
# )
# print(df_metrics.head())


# -----------------------------
# Scale & train/test split for each modality
# -----------------------------

def scale_and_split(modality_data, test_size=0.2, seed=SEED):
    """Split each modality into train/test rows and standardise it.

    Args:
        modality_data: Mapping of modality name to a 2-D array
            (cells x features).
        test_size: Fraction (or count) of cells held out, as accepted by
            ``train_test_split``.
        seed: ``random_state`` for the split. With equal row counts every
            modality receives the same row partition.

    Returns:
        Tuple ``(expression_train, expression_test, pseudotime_train,
        pseudotime_test, scalers)`` of dicts keyed by modality. Pseudotime is
        a placeholder equal to each cell's row index.
    """
    expression_train, expression_test = {}, {}
    pseudotime_train, pseudotime_test = {}, {}
    scalers = {}

    for mod, data in modality_data.items():
        n_cells = data.shape[0]
        train_idx, test_idx = train_test_split(np.arange(n_cells), test_size=test_size, random_state=seed)

        # Fit on training rows only; fitting on all rows would leak test-set
        # statistics into the normalisation of the training data.
        scaler = StandardScaler().fit(data[train_idx])
        scalers[mod] = scaler
        expression_train[mod] = scaler.transform(data[train_idx])
        expression_test[mod] = scaler.transform(data[test_idx])

        pseudotime = np.arange(n_cells, dtype=np.float32)
        pseudotime_train[mod] = pseudotime[train_idx]
        pseudotime_test[mod] = pseudotime[test_idx]

    return expression_train, expression_test, pseudotime_train, pseudotime_test, scalers


# -----------------------------
# Adaptive per-gene thresholds (simple moving average)
# -----------------------------
class PerGeneAdaptiveThreshold:
    """Exponential moving average of per-feature rewards, kept per modality.

    NOTE: a class with this name is also defined earlier in the file; this
    definition rebinds the name, while objects created before this point keep
    the earlier class.

    Args:
        modality_dims: Mapping of modality name to its number of features;
            every feature index starts with a threshold of 0.0.
        alpha: Weight given to the newest reward in each update.
    """

    def __init__(self, modality_dims, alpha=0.1):
        # modality_dims: dict mod -> int (number of features)
        self.thresholds = {mod: {i: 0.0 for i in range(dim)} for mod, dim in modality_dims.items()}
        self.alpha = alpha

    def update(self, gene_rewards):
        """Blend new rewards into the stored thresholds.

        Args:
            gene_rewards: Mapping ``modality -> {feature index -> reward}``.
                ``None`` and NaN rewards are skipped.
        """
        for mod, rewards in gene_rewards.items():
            # Unknown modalities are tracked from zero, matching the tolerant
            # lookups in get(), instead of raising KeyError.
            mod_thresholds = self.thresholds.setdefault(mod, {})
            for gene_id, reward in rewards.items():
                if reward is None:
                    continue
                value = float(reward)
                # Check NaN after conversion: np.float32 is not a `float`
                # subclass, so an isinstance-based guard lets its NaN through.
                if math.isnan(value):
                    continue
                previous = mod_thresholds.get(gene_id, 0.0)
                mod_thresholds[gene_id] = self.alpha * value + (1 - self.alpha) * previous

    def get(self, mod, gene_id):
        """Return the threshold for ``(mod, gene_id)``, or 0.0 if untracked."""
        return float(self.thresholds.get(mod, {}).get(gene_id, 0.0))


# -----------------------------
# PBMC_CRISPR_MultiModalEnv (Gym)
# -----------------------------
import numpy as np
import torch
import torch.nn as nn

# ... other imports and class definitions ...

import numpy as np
import gym
from gym import spaces

import numpy as np
import gym
from gym import spaces

class PBMC_CRISPR_MultiModalEnv(gym.Env):
    """Gym environment that steers a perturbed cell profile toward a target cell.

    Each episode picks a random start cell and a target cell with a larger row
    index (or the start cell itself when none exists), applies random
    knock-out / over-expression perturbations to the start state, and rewards
    each action by the reduction in mean squared error to the target.

    Uses the classic Gym signatures: ``reset`` returns only the observation
    and ``step`` returns ``(obs, reward, done, info)``.

    Args:
        expression_dict_or_matrix: Either ``{modality: 2-D array}`` (fused
            column-wise in ``modality_splits`` order) or an already fused
            2-D matrix (cells x features).
        pseudotime_dict_or_array: Either ``{modality: 1-D array}`` (the first
            entry is used) or a 1-D array. Its form is independent of the
            expression argument's form.
        modality_splits: ``{modality: (start, end)}`` column ranges in the
            fused matrix.
        max_steps: Episode length.
        adaptive_thresholds: Optional object exposing ``get(mod, local_idx)``;
            the returned values are subtracted from the reward at every step.
        device: Stored on the instance; not used by the environment itself.
        action_magnitude: Scale applied to each action component.
        perturb_prob: Per-step probability of an extra perturbation event.
        max_perturb: Maximum number of perturbation draws per event.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, expression_dict_or_matrix, pseudotime_dict_or_array, modality_splits,
                 max_steps=MAX_STEPS, adaptive_thresholds=None, device='cpu',
                 action_magnitude=0.25, perturb_prob=0.1, max_perturb=3):
        super().__init__()

        # Accept either dicts (per-modality) or already-fused matrices
        if isinstance(expression_dict_or_matrix, dict):
            # build fused matrix in modality_splits order
            self.expression_dict = {
                mod: np.asarray(expression_dict_or_matrix[mod], dtype=np.float32)
                for mod in modality_splits
            }
            self.expression = np.concatenate(
                [self.expression_dict[mod] for mod in modality_splits], axis=1
            )
        else:
            self.expression = np.asarray(expression_dict_or_matrix, dtype=np.float32)
            self.expression_dict = {
                mod: self.expression[:, s:e] for mod, (s, e) in modality_splits.items()
            }

        # Pseudotime is resolved on its own so that a fused matrix can be
        # paired with a per-modality dict and vice versa; previously the
        # expression type dictated how pseudotime was read.
        if isinstance(pseudotime_dict_or_array, dict):
            pseudotime = next(iter(pseudotime_dict_or_array.values()))
        else:
            pseudotime = pseudotime_dict_or_array
        self.pseudotime = np.asarray(pseudotime, dtype=np.float32)

        self.modality_splits = modality_splits
        self.modality_dims = {m: (e - s) for m, (s, e) in modality_splits.items()}
        self.n_cells, self.n_genes = self.expression.shape
        self.max_steps = max_steps
        self.adaptive_thresholds = adaptive_thresholds
        self.device = device
        self.action_magnitude = action_magnitude
        self.perturb_prob = perturb_prob
        self.max_perturb = max_perturb

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.n_genes,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.n_genes,), dtype=np.float32)

        # bookkeeping: reset to 0 on reset(), incremented on every step()
        self.current_cell = 0

    def reset(self, seed=None, options=None):
        """Start a new episode and return the initial (perturbed) observation.

        ``seed`` and ``options`` are accepted but ignored; randomness comes
        from NumPy's global random state.
        """
        self.idx = np.random.randint(self.n_cells)
        self.state = self.expression[self.idx].copy()
        self.original_state = self.state.copy()

        # Target: a random cell with a larger row index than the start cell.
        eligible_idxs = np.arange(self.idx + 1, self.n_cells)
        if len(eligible_idxs) == 0:
            eligible_idxs = np.array([self.idx])
        target_idx = np.random.choice(eligible_idxs)
        self.target = self.expression[target_idx].copy()

        self.steps = 0
        # history[0] is recorded before the initial perturbation, so the first
        # step's reward is measured against the unperturbed start state.
        self.history = [self.state.copy()]
        self.knockout_genes = set()
        self.overexpressed_genes = set()
        self._apply_crispr_perturbation()
        self.current_cell = 0
        return self.state.copy()

    def _apply_crispr_perturbation(self):
        """Knock out (set to 0) or over-express (double) random features in place."""
        n_perturb = np.random.randint(1, self.max_perturb + 1)
        for _ in range(n_perturb):
            gene = np.random.randint(0, self.n_genes)
            if np.random.rand() < 0.5:
                self.state[gene] = 0.0
                self.knockout_genes.add(int(gene))
            else:
                self.state[gene] = self.state[gene] * 2.0
                self.overexpressed_genes.add(int(gene))

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        Raises:
            ValueError: If ``action`` does not have one entry per feature.
        """
        action = np.asarray(action, dtype=np.float32).ravel()
        if action.shape[0] != self.n_genes:
            raise ValueError("Action length mismatch.")
        # Vectorised form of the per-feature "add scaled delta, then clip"
        # loop; results agree up to floating-point rounding.
        self.state[:] = np.clip(self.state + action * self.action_magnitude, -5.0, 5.0)

        if np.random.rand() < self.perturb_prob:
            self._apply_crispr_perturbation()

        # Reward is the drop in MSE-to-target relative to the previous state.
        old_mse = float(np.mean((self.history[-1] - self.target) ** 2))
        new_mse = float(np.mean((self.state - self.target) ** 2))
        reward = old_mse - new_mse

        # subtract adaptive thresholds per modality (if provided)
        if self.adaptive_thresholds is not None:
            for mod, (start, end) in self.modality_splits.items():
                for local_idx in range(end - start):
                    reward -= self.adaptive_thresholds.get(mod, local_idx)

        self.steps += 1
        self.history.append(self.state.copy())
        done = self.steps >= self.max_steps
        info = {}
        self.current_cell += 1
        return self.state.copy(), float(reward), done, info

    def render(self, mode='human'):
        """Print the step counter, a state preview and the perturbed feature ids."""
        print(f"Step {self.steps} - state (first 10): {self.state[:10]}")
        print(f"Knockouts: {sorted(list(self.knockout_genes))[:10]}, Overexpr: {sorted(list(self.overexpressed_genes))[:10]}")




# ... rest of the script ...
# ... rest of the script ...
# -----------------------------
# SB3 wrapper (returns obs only on reset)
# -----------------------------
class GRNEnvWrapper(gym.Env):
    """Adapter that exposes a base env through the classic Gym call shapes.

    ``reset`` always yields just the observation and ``step`` always yields a
    4-tuple ``(obs, reward, done, info)``, whichever API the wrapped
    environment follows.
    """

    def __init__(self, base_env):
        super().__init__()
        self.env = base_env
        feature_shape = (self.env.n_genes,)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=feature_shape, dtype=np.float32)
        self.action_space = self.env.action_space

    def reset(self, **kwargs):
        """Reset the wrapped env and return the observation as float32."""
        outcome = self.env.reset(**kwargs)
        # The wrapped env may return either obs or (obs, info).
        first_obs = outcome[0] if isinstance(outcome, tuple) else outcome
        return np.asarray(first_obs, dtype=np.float32)

    def step(self, action):
        """Step the wrapped env and normalise its result to a 4-tuple."""
        outcome = self.env.step(action)
        five_tuple = isinstance(outcome, tuple) and len(outcome) == 5
        if not five_tuple:
            obs, reward, done, info = outcome
        else:
            # Newer API: merge the two termination flags into one.
            obs, reward, terminated, truncated, info = outcome
            done = bool(terminated or truncated)
        return np.asarray(obs, dtype=np.float32), float(reward), bool(done), info

    def seed(self, seed=None):
        """Seed the wrapped env (when it supports it) and NumPy's global RNG."""
        if hasattr(self.env, 'seed'):
            self.env.seed(seed)
        np.random.seed(seed)
        return [seed]


# -----------------------------
# make_env_factory
# -----------------------------
def make_env_factory(expression, pseudotime, modality_splits, adaptive_thresholds, modality_choice="all"):
    """Return a zero-argument callable that builds one wrapped environment.

    ``modality_choice`` is accepted but does not influence the environment
    that is built.
    """
    env_kwargs = dict(
        expression_dict_or_matrix=expression,
        pseudotime_dict_or_array=pseudotime,
        modality_splits=modality_splits,
        max_steps=MAX_STEPS,
        adaptive_thresholds=adaptive_thresholds,
        device=DEVICE,
        action_magnitude=0.25,
        perturb_prob=PERTURB_PROB,
        max_perturb=MAX_PERTURB,
    )

    def _init():
        # A fresh environment is created on every call.
        return GRNEnvWrapper(PBMC_CRISPR_MultiModalEnv(**env_kwargs))

    return _init


# -----------------------------
# evaluation helper for SB3 (mean/std reward)
# -----------------------------
def evaluate_model_sb3(model, env, n_eval_episodes=10, deterministic=True):
    """Roll out ``model`` in ``env`` and summarise the episode returns.

    Rewards that arrive as a list, tuple or array are averaged into a single
    value before being added to the running episode return.

    Returns:
        ``(mean_reward, std_reward)`` over ``n_eval_episodes`` episodes.
    """
    episode_returns = []
    for _episode in range(n_eval_episodes):
        obs = env.reset()
        episode_return = 0.0
        finished = False
        while not finished:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, reward, finished, info = env.step(action)
            is_vector = isinstance(reward, (list, tuple, np.ndarray))
            episode_return += float(np.mean(reward)) if is_vector else float(reward)
        episode_returns.append(episode_return)
    return np.mean(episode_returns), np.std(episode_returns)




# -----------------------------
# Train & evaluate PPO / TRPO / TRPO->PPO
# -----------------------------
def train_and_evaluate_algorithms(modality_choice, expression_train, pseudotime_train,
                                  expression_test, pseudotime_test, modality_splits,
                                  gene_names, adaptive_thresholds,
                                  algorithms_to_run=("ppo", "trpo", "trpo_to_ppo"),
                                  train_steps=100000, save_dir=OUT_DIR):
    """Train the requested RL algorithms and write their evaluation metrics.

    Args:
        modality_choice: ``"all"`` to use every column, or one key of
            ``modality_splits`` to train on that modality's columns only.
        expression_train, expression_test: Fused 2-D matrices
            (cells x features); they are column-sliced when a single
            modality is chosen.
        pseudotime_train, pseudotime_test: ``{modality: 1-D array}``.
        modality_splits: ``{modality: (start, end)}`` column ranges.
        gene_names: Accepted for interface compatibility; per-feature names
            are generated as ``"<modality>_<i>"``.
        adaptive_thresholds: Forwarded to the environments and the evaluator.
        algorithms_to_run: Any of ``"ppo"``, ``"trpo"``, ``"trpo_to_ppo"``;
            unknown names are reported and skipped.
        train_steps: Total training timesteps per algorithm. For
            ``"trpo_to_ppo"`` one tenth goes to TRPO and the rest to PPO.
        save_dir: Directory receiving the metric CSV files.

    Returns:
        ``(trained_models, results_reward)``: models keyed by algorithm name
        and ``(mean, std)`` evaluation rewards keyed by algorithm name.
    """
    os.makedirs(save_dir, exist_ok=True)

    # -----------------------------
    # Modality-specific data
    # -----------------------------
    if modality_choice == "all":
        expr_train_mod = expression_train
        expr_test_mod = expression_test
        pseudo_train_mod = pseudotime_train[next(iter(pseudotime_train.keys()))]  # pick any modality's pseudotime
        pseudo_test_mod = pseudotime_test[next(iter(pseudotime_test.keys()))]
        splits_for_env_mod = modality_splits
    else:
        start, end = modality_splits[modality_choice]
        expr_train_mod = expression_train[:, start:end]
        expr_test_mod = expression_test[:, start:end]
        pseudo_train_mod = pseudotime_train[modality_choice]
        pseudo_test_mod = pseudotime_test[modality_choice]
        splits_for_env_mod = {modality_choice: (0, end - start)}

    # Names are built locally in the "<modality>_<i>" format rather than read
    # from a module-level global.
    gene_names_by_mod = {
        mod: [f"{mod}_{i}" for i in range(e - s)]
        for mod, (s, e) in splits_for_env_mod.items()
    }

    # -----------------------------
    # Environments
    # -----------------------------
    train_env = DummyVecEnv([
        make_env_factory(expr_train_mod, pseudo_train_mod, splits_for_env_mod, adaptive_thresholds, modality_choice)
    ])
    # The training env is not normalised, so the eval env must not be either:
    # wrapping only the eval env in a fresh VecNormalize would feed the policy
    # observations on a different scale from the one it was trained on.
    eval_env = DummyVecEnv([
        make_env_factory(expr_test_mod, pseudo_test_mod, splits_for_env_mod, adaptive_thresholds, modality_choice)
    ])

    trained_models = {}
    results_reward = {}
    per_algo_metrics = []

    # -----------------------------
    # Algorithm-specific kwargs
    # -----------------------------
    ppo_kwargs = dict(
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.0,
        vf_coef=0.5,
        max_grad_norm=0.5,
        learning_rate=1e-4,
        policy_kwargs=dict(net_arch=[dict(pi=[256, 256], vf=[256, 256])], activation_fn=nn.Tanh)
    )

    trpo_kwargs = dict(
        gamma=0.99,
        gae_lambda=0.95,
        cg_max_steps=15,
        cg_damping=0.01,
        line_search_shrinking_factor=0.8,
        n_critic_updates=10,
        learning_rate=1e-4,
        policy_kwargs=dict(net_arch=[dict(pi=[256, 256], vf=[256, 256])], activation_fn=nn.Tanh)
    )

    # Budget for the two-stage schedule, derived from train_steps instead of
    # fixed numbers (10000 / 90000 at the default of 100000).
    warmup_steps = max(1, train_steps // 10)
    finetune_steps = max(1, train_steps - warmup_steps)

    # -----------------------------
    # Training Loop
    # -----------------------------
    for algo_name in algorithms_to_run:
        print(f"\n===== Training {algo_name.upper()} =====")

        if algo_name == "ppo":
            model = PPO("MlpPolicy", train_env, verbose=1, seed=SEED, **ppo_kwargs)
            model.learn(total_timesteps=train_steps)

        elif algo_name == "trpo":
            model = TRPO("MlpPolicy", train_env, verbose=1, seed=SEED, **trpo_kwargs)
            model.learn(total_timesteps=train_steps)

        elif algo_name == "trpo_to_ppo":
            # Stage 1: TRPO
            model_trpo = TRPO("MlpPolicy", train_env, verbose=1, seed=SEED, **trpo_kwargs)
            model_trpo.learn(total_timesteps=warmup_steps)

            # Stage 2: PPO warm start
            model = PPO("MlpPolicy", train_env, verbose=1, seed=SEED, **ppo_kwargs)
            try:
                # try parameter transfer if shapes match
                model.set_parameters(model_trpo.get_parameters())
                print("✅ Parameters transferred from TRPO → PPO")
            except Exception as e:
                # Deliberate best-effort: on any mismatch PPO trains from scratch.
                print("⚠️ Parameter transfer failed, PPO starts fresh:", e)
            model.learn(total_timesteps=finetune_steps)

        else:
            print(f"⚠️ Unknown algorithm: {algo_name}")
            continue

        # Save trained model
        trained_models[algo_name] = model

        # -----------------------------
        # Evaluate
        # -----------------------------
        mean_r, std_r = evaluate_model_sb3(model, eval_env, n_eval_episodes=N_EVAL_EPISODES)
        results_reward[algo_name] = (mean_r, std_r)
        print(f"✅ Eval {algo_name}: mean={mean_r:.4f}, std={std_r:.4f}")

        metrics_df = evaluate_and_plot_multi_modality(
            model=model,
            algo_name=algo_name,
            expression_test=expr_test_mod,
            pseudotime_test=pseudo_test_mod,
            gene_names=gene_names_by_mod,
            modality_splits=splits_for_env_mod,
            adaptive_thresholds=adaptive_thresholds,
            n_episodes=N_EVAL_EPISODES,
            save_dir=os.path.join(save_dir, "plots")
        )
        metrics_df.to_csv(os.path.join(save_dir, f"{modality_choice}_{algo_name}_metrics.csv"), index=False)
        per_algo_metrics.append(metrics_df)

    # -----------------------------
    # Save testing metrics
    # -----------------------------
    if not per_algo_metrics:
        # pd.concat raises on an empty list; nothing was trained, so there is
        # nothing to aggregate.
        print("⚠️ No algorithms were trained; no metrics written.")
        return trained_models, results_reward

    final_test_df = pd.concat(per_algo_metrics, ignore_index=True)
    final_test_df.to_csv(os.path.join(save_dir, f"testing_per_gene_metrics_{modality_choice}.csv"), index=False)
    test_summary_df = final_test_df.groupby("Algorithm").mean(numeric_only=True)
    test_summary_df.to_csv(os.path.join(save_dir, f"testing_overall_metrics_{modality_choice}.csv"))

    print(f"\n✅ Training & evaluation complete for all optimizers on modality={modality_choice}.")

    return trained_models, results_reward

def evaluate_model_sb3(model, env, n_eval_episodes=10, deterministic=True):
    """
    Evaluate a Stable-Baselines3 model.

    NOTE: this re-definition replaces the function of the same name defined
    earlier in the file, so it is the one used at call time. Rewards are
    reduced to a scalar float at every step, as the earlier definition does,
    so list/tuple/array rewards from a vectorised env are handled and each
    episode return is a plain float.

    Returns:
        mean_reward, std_reward
    """
    all_rewards = []

    for _episode in range(n_eval_episodes):
        obs = env.reset()
        done = False
        total_reward = 0.0

        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, reward, done, info = env.step(action)
            if isinstance(reward, (list, tuple, np.ndarray)):
                # Vectorised envs return one reward per sub-environment.
                total_reward += float(np.mean(reward))
            else:
                total_reward += float(reward)

        all_rewards.append(total_reward)

    mean_reward = np.mean(all_rewards)
    std_reward = np.std(all_rewards)
    return mean_reward, std_reward


# -----------------------------
# Example run (change args as needed)
# -----------------------------
#"pbmc_multi_output"
# Train on the ATAC columns only: the fused matrices are passed in and the
# function slices out the requested modality itself.
trained_models, results_reward = train_and_evaluate_algorithms(
    modality_choice="atac",
    expression_train=fused_train,
    pseudotime_train=pseudotime_train,
    expression_test=fused_test,
    pseudotime_test=pseudotime_test,
    modality_splits=modality_splits,
    gene_names=selected_gene_names,
    adaptive_thresholds=adaptive_thresholds,
    algorithms_to_run=["trpo", "ppo", "trpo_to_ppo"],
    train_steps=100_000,
    save_dir="pbmc_multi_output_atac",
)
        

■ File filtered_feature_bc_matrix.h5 from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/filtered_feature_bc_matrix.h5
■ Checksum is validated (md5) for filtered_feature_bc_matrix.h5
■ File atac_fragments.tsv.gz from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_fragments.tsv.gz
■ Checksum is validated (md5) for atac_fragments.tsv.gz
■ File atac_fragments.tsv.gz.tbi from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_fragments.tsv.gz.tbi
■ Checksum is validated (md5) for atac_fragments.tsv.gz.tbi
■ File atac_peaks.bed from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_peaks.bed
■ Checksum is validated (md5) for atac_peaks.bed
■ File atac_peak_annotation.tsv from pbmc10k_multiome has been found at /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_peak_annotation.tsv
■ Checksum is validated (md5) for atac_peak_annota

  warn("Dataset is in the 10X .h5 format and can't be loaded as backed.")
  utils.warn_names_duplicates("var")


Added `interval` annotation for features from /Users/boabangfrancis/mudatasets/pbmc10k_multiome/filtered_feature_bc_matrix.h5


  utils.warn_names_duplicates("var")
  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


Added peak annotation from /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_peak_annotation.tsv to .uns['atac']['peak_annotation']
Added gene names to peak annotation in .uns['atac']['peak_annotation']
Located fragments file: /Users/boabangfrancis/mudatasets/pbmc10k_multiome/atac_fragments.tsv.gz
pysam is not available. It is required to work with the fragments file.                 Install pysam from PyPI (`pip install pysam`)                 or from GitHub (`pip install git+https://github.com/pysam-developers/pysam`)
Available modalities: ['rna', 'atac']
Number of common cells: 10000
{'rna': (10000, 32), 'atac': (10000, 32)}
Modality splits: {'rna': (0, 32), 'atac': (32, 64)}

===== Training TRPO =====
Using cpu device




-----------------------------
| time/              |      |
|    fps             | 4341 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
----------------------------------------
| time/                     |          |
|    fps                    | 3198     |
|    iterations             | 2        |
|    time_elapsed           | 1        |
|    total_timesteps        | 4096     |
| train/                    |          |
|    explained_variance     | -0.00201 |
|    is_line_search_success | 1        |
|    kl_divergence_loss     | 0.00406  |
|    learning_rate          | 0.0001   |
|    n_updates              | 1        |
|    policy_objective       | 7.5e+06  |
|    std                    | 1        |
|    value_loss             | 39.1     |
----------------------------------------
----------------------------------------
| time/                     |          |
|    fps                    | 2948     |
|    iterat