In [None]:
import torch
import torch.nn.functional as F
from torch_sparse import SparseTensor
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse
from sklearn.cluster import KMeans
from sklearn.metrics import (
    adjusted_rand_score, normalized_mutual_info_score,
    adjusted_mutual_info_score, fowlkes_mallows_score,
    mutual_info_score, rand_score
)
from tqdm import tqdm
import copy

In [None]:
import json
from pathlib import Path

# Load the run configuration and pull out the settings of the selected dataset.
with open('config.json', 'r') as f:
    config = json.load(f)

selected_ds = config['selected_dataset']
ds_config = config["datasets"][selected_ds]
# `params` and `ds_params` are two names for the same per-dataset parameter dict.
params = ds_config["parameters"]
ds_params = params
data_path = ds_config["path"]

# Names consumed by the experiment loop further down.
target_zero_percentages = params["target_percentage"]
# JSON object keys are always strings; convert them to floats so they can be
# looked up with the numeric target percentages.
seeds_per_rate = {
    float(rate_key): seed_list
    for rate_key, seed_list in ds_params["seeds_per_rate"].items()
}
N_HVG = params["N_HVG"]
batch_size = params["batch_size"]
model_dir = Path(config["model_dir"])
bridge_iters = params["iter"]

print(f"Ready to process {selected_ds} with {N_HVG} HVGs and seeds {seeds_per_rate}")

In [None]:
# Execute the scGPT embedding-generation notebook headlessly and save the executed
# copy as "scGPT_Embeddings_executed.ipynb".
# NOTE(review): presumably this produces the Embeddings/{dataset}_Embeddings_{pct}_{seed}.h5ad
# files read in the experiment loop below — confirm.
!jupyter nbconvert \
  --execute \
  --to notebook \
  "scGPT Embeddings generation.ipynb" \
  --output "scGPT_Embeddings_executed.ipynb"

In [None]:
# Load the AnnData object from the path configured for the selected dataset.
adata = sc.read_h5ad(data_path)

In [None]:
def knn_fast(X, k, b, gcn_norm, sym):
    """Build a normalized cosine-similarity kNN graph over the rows of `X`.

    Rows are L2-normalized, so dot products are cosine similarities. For every row
    the ``k + 1`` highest-similarity rows are kept (the row itself is typically one
    of them). Similarities are computed ``b`` rows at a time so only a ``(b, n)``
    block is materialized at once.

    Parameters
    ----------
    X : torch.Tensor
        2-D tensor with one row per node; the graph is built on ``X.device``.
    k : int
        Neighbours per row (``k + 1`` entries per row are kept).
    b : int
        Number of rows per similarity block; must be >= 1.
    gcn_norm, sym : bool
        Forwarded to ``sparse_post_processing``.

    Returns
    -------
    torch.Tensor
        Sparse COO float tensor produced by ``sparse_post_processing``.

    Raises
    ------
    ValueError
        If ``b < 1`` (the previous loop would never terminate for ``b == 0``).
    """
    if b < 1:
        raise ValueError(f"batch size b must be >= 1, got {b}")
    device = X.device
    X = F.normalize(X, dim=1, p=2)
    n = X.shape[0]
    n_keep = k + 1

    # Indices are kept as int64 from the start: a float32 buffer cannot represent
    # every integer above 2**24 and would silently corrupt neighbour indices.
    values = torch.zeros(n * n_keep, device=device)
    rows = torch.zeros(n * n_keep, dtype=torch.long, device=device)
    cols = torch.zeros(n * n_keep, dtype=torch.long, device=device)

    for start in range(0, n, b):
        end = min(start + b, n)
        similarities = torch.mm(X[start:end], X.t())
        vals, inds = similarities.topk(k=n_keep, dim=-1)
        span = slice(start * n_keep, end * n_keep)
        values[span] = vals.reshape(-1)
        cols[span] = inds.reshape(-1)
        # Row id of each kept entry: start..end-1, each repeated n_keep times.
        rows[span] = torch.arange(start, end, device=device).repeat_interleave(n_keep)

    sparse_adj = SparseTensor(row=rows, col=cols, value=values).to(device)
    return sparse_post_processing(sparse_adj, gcn_norm=gcn_norm, sym=sym).to_torch_sparse_coo_tensor().float()

def sparse_post_processing(adj, add_self_loop=True, sym=True, gcn_norm=False):
    """Add self-loops, optionally symmetrize, then degree-normalize a ``SparseTensor``.

    Steps, in order:
      1. ``add_self_loop``: fill the diagonal with the value 2 (``torch_sparse.fill_diag``).
      2. ``sym``: replace ``adj`` with ``(adj + adj^T) / 2``.
      3. With ``deg`` the row sums: return ``D^{-1/2} A D^{-1/2}`` when ``gcn_norm``,
         otherwise ``D^{-1} A``. Inverse degrees equal to +inf (zero-degree rows)
         are replaced by 0.
    """
    from torch_sparse import fill_diag, sum as sparsesum, mul

    if add_self_loop:
        adj = fill_diag(adj, 2)
    if sym:
        half = torch.full((adj.size(0), 1), 0.5, device=adj.device())
        adj = mul(adj + adj.t(), half)

    deg = sparsesum(adj, dim=1)
    if not gcn_norm:
        deg_inv = deg.pow(-1)
        deg_inv = deg_inv.masked_fill(deg_inv == float('inf'), 0.)
        return mul(adj, deg_inv.view(-1, 1))

    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0.)
    row_scaled = mul(adj, deg_inv_sqrt.view(-1, 1))
    return mul(row_scaled, deg_inv_sqrt.view(1, -1))

def drop_data(data_t, rate, rng=None):
    """Simulate dropout by zeroing a random fraction of the non-zero entries of ``data_t.X``.

    Writes two dense arrays into ``data_t.obsm`` and returns ``data_t`` itself
    (it is modified in place):

    - ``'train'``: a copy of X with ``floor(rate * nnz)`` randomly chosen non-zero
      entries set to 0.
    - ``'test'``: X as a dense array. When ``data_t.X`` is already dense, this is the
      same object as ``data_t.X``.

    Parameters
    ----------
    data_t : AnnData-like
        Object with ``.X`` (dense array or scipy sparse matrix) and dict-like ``.obsm``.
    rate : float
        Fraction of non-zero entries to drop; values <= 0 drop nothing.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's global RNG (``np.random.choice``)
        is used, so results stay controlled by ``np.random.seed``.

    Returns
    -------
    The same ``data_t`` object.
    """
    X = data_t.X
    if scipy.sparse.issparse(X):
        X = X.toarray()
    X_train = np.copy(X)
    if rate > 0.0:
        i, j = np.nonzero(X)
        n_drop = int(np.floor(rate * len(i)))
        choice = np.random.choice if rng is None else rng.choice
        ix = choice(len(i), n_drop, replace=False)
        X_train[i[ix], j[ix]] = 0.0
    data_t.obsm['train'] = X_train
    data_t.obsm['test'] = X
    return data_t

def forward(x, sparse_adj, mask, iters):
    """Propagate `x` over the graph `sparse_adj` for `iters` steps.

    Each step computes ``x <- sparse_adj @ x`` with ``torch.sparse.mm``. When `mask`
    is true, entries that were non-zero in the input are reset to their input values
    after every step, so only the input's zero entries are filled in.

    Parameters
    ----------
    x : torch.Tensor
        Dense 2-D tensor; its row count must match the column count of `sparse_adj`.
    sparse_adj : torch.Tensor
        Sparse torch tensor; computation runs on its device.
    mask : bool
        Whether to restore observed (non-zero) input entries after each step.
    iters : int
        Number of propagation steps.

    Returns
    -------
    torch.Tensor
        The propagated tensor, moved to the CPU.
    """
    device = sparse_adj.device
    original_x = x.to(device)
    # Loop-invariant and on the same device as the propagated tensor: which input
    # entries are observed. Previously recomputed every step via torch.nonzero on a
    # tensor that stayed on the CPU.
    observed = (original_x != 0) if mask else None
    x = original_x
    for _ in tqdm(range(iters)):
        x = torch.sparse.mm(sparse_adj, x)
        if observed is not None:
            x = torch.where(observed, original_x, x)
    return x.cpu()


In [None]:
def run_method(method, adata, rate, seed, target_zero_percentage, adj_matrix=None, iters=None):
    """Simulate dropout, denoise with `method`, and score KMeans clustering of the result.

    Parameters
    ----------
    method : str
        Denoising method; only ``"BRIDGE"`` is supported.
    adata : AnnData-like
        Source data; read via ``.copy()`` and ``.obs["celltype"]``. Not modified.
    rate : float
        Fraction of non-zero entries to zero out (passed to ``drop_data``).
    seed : int
        Seed for NumPy's global RNG, which ``drop_data`` uses to pick entries.
    target_zero_percentage : float
        Only recorded in the returned dict.
    adj_matrix : torch sparse tensor
        Propagation matrix; required for BRIDGE.
    iters : int
        Number of propagation steps; required for BRIDGE.

    Returns
    -------
    dict
        Run metadata plus RI, NMI, AMI, FMI, MI and ARI of the KMeans labels
        against ``adata.obs["celltype"]``.

    Raises
    ------
    ValueError
        If `method` is unsupported, or `adj_matrix` / `iters` is missing.
    """
    # Validate before the (comparatively expensive) copy and dropout simulation.
    if method != "BRIDGE":
        raise ValueError(f"Unsupported method: {method}")
    if adj_matrix is None or iters is None:
        raise ValueError("BRIDGE requires both adj_matrix and iters")

    np.random.seed(seed)
    # Simulate dropout on a copy so the caller's object is left untouched.
    adata_copy = drop_data(adata.copy(), rate)
    filtered_matrix = torch.tensor(adata_copy.obsm["train"], dtype=torch.float32)
    actual_zero_percentage = (filtered_matrix == 0).sum().item() / filtered_matrix.numel()
    print(f"Actual zero percentage: {actual_zero_percentage*100:.2f}%")

    # BRIDGE: propagate the dropped-out matrix over the graph, without masking.
    denoised = forward(filtered_matrix, adj_matrix, iters=iters, mask=False)

    # Clustering evaluation: one KMeans cluster per distinct annotated cell type.
    labels_true = adata.obs["celltype"].values
    kmeans = KMeans(n_clusters=np.unique(labels_true).shape[0], n_init=20, random_state=0)
    labels_pred = kmeans.fit_predict(denoised.numpy())

    # sklearn metrics take (labels_true, labels_pred). All six are symmetric in
    # their arguments, so scores match the previous (pred, true) ordering.
    return {
        "method": method,
        "target_zero_percentage": target_zero_percentage,
        "dropout_rate": rate,
        "random_seed": seed,
        "actual_zero_percentage": actual_zero_percentage,
        "RI": rand_score(labels_true, labels_pred),
        "NMI": normalized_mutual_info_score(labels_true, labels_pred),
        "AMI": adjusted_mutual_info_score(labels_true, labels_pred),
        "FMI": fowlkes_mallows_score(labels_true, labels_pred),
        "MI": mutual_info_score(labels_true, labels_pred),
        "ARI": adjusted_rand_score(labels_true, labels_pred),
    }


In [None]:
# Initialize storage for all metrics: one list of per-run records per metric name.
metrics_dict = {metric: [] for metric in ['RI', 'NMI', 'AMI', 'FMI', 'MI', 'ARI']}

original_X = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
# Hoisted out of the loop: these depend only on the unmodified data.
n_entries = original_X.size
n_zeros = int(np.count_nonzero(original_X == 0))
n_nonzeros = n_entries - n_zeros
initial_zero_rate = n_zeros / n_entries
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for target_zero_percentage in target_zero_percentages:
    # Dropout rate needed to reach the target sparsity: the fraction of currently
    # non-zero entries that must be zeroed, clamped to [0, 1].
    if target_zero_percentage > initial_zero_rate:
        random_seeds = seeds_per_rate.get(target_zero_percentage, [0])
        needed_zeros = int(target_zero_percentage * n_entries)
        needed_drops = needed_zeros - n_zeros
        rate = max(0.0, min(1.0, needed_drops / n_nonzeros))
    else:
        # The raw data already meets the target: no simulated dropout, single seed.
        rate = 0.0
        random_seeds = [0]

    # One run per seed configured for this target in `seeds_per_rate` (default [0]).
    for seed in random_seeds:
        print(f"Running for target_zero={target_zero_percentage:.2f}, seed={seed}")
        # NOTE(review): int(x * 100) truncates float error, e.g. int(0.29 * 100) == 28.
        # Kept as-is so file names match the embedding-generation notebook — confirm it
        # uses the same formula (or switch both to round()).
        targetpercentage_100 = int(target_zero_percentage * 100)
        # Load precomputed scGPT embeddings
        file_name = f'Embeddings/{selected_ds}_Embeddings_{targetpercentage_100}_{seed}.h5ad'
        embd = sc.read(file_name)
        X_np = embd.obsm['X_scGPT']
        # The graph built from the embeddings is multiplied with adata's rows inside
        # run_method, so the row counts must agree.
        assert X_np.shape[0] == adata.n_obs, (
            f"{file_name}: {X_np.shape[0]} embedding rows vs {adata.n_obs} cells in adata"
        )
        X_torch = torch.tensor(X_np, dtype=torch.float32).to(device)
        # Generate adjacency matrix
        adj_matrix = knn_fast(X_torch, k=10, b=1000, gcn_norm=False, sym=True)
        # Execute the method and store results
        results_BRIDGE = run_method('BRIDGE', adata, rate, seed, target_zero_percentage, adj_matrix, bridge_iters)
        for metric in metrics_dict:
            metrics_dict[metric].append({
                'target_zero_percentage': target_zero_percentage,
                'dropout_rate': rate,
                'random_seed': seed,
                'actual_zero_percentage': results_BRIDGE['actual_zero_percentage'],
                'BRIDGE': results_BRIDGE[metric],
            })

In [None]:
# Persist results: one CSV per metric (e.g. "ARI.csv") in the working directory,
# with one row per (target sparsity, seed) run.
for metric in metrics_dict:
    data = metrics_dict[metric]
    df = pd.DataFrame(data)
    df.to_csv(f'{metric}.csv', index=False)