
# scIDiff+ Demo: Dynamo Vector Field, Pathway Prior, and OT Alignment

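This notebook demonstrates a minimal, toy run of **scIDiff+** (diffusion + Dynamo-style vector-field prior + OT) on single-cell perturbation data.  
It is set up for the classic **PBMC IFN-β stimulation dataset** (Kang et al., 2018), read from a local `kang_pbmc_ifnb.h5ad` file with control/stimulated labels in `obs['stim']`.  
Nothing is downloaded. By default (`USE_SYNTHETIC = True`) the notebook runs on a synthetic control-vs-stimulated dataset; set `USE_SYNTHETIC = False` to use the real data. In this minimal version the vector-field prior is a fixed drift on a few IFN pathway genes, and the Dynamo package itself is not called.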

**References:**
- Kang et al. 2018, IFN-β PBMC dataset
- Dynamo: RNA velocity and vector-field learning (<https://dynamo-release.readthedocs.io/>)
- JAK–STAT / Interferon pathways (KEGG, MSigDB Hallmark IFN-α)


In [None]:

# ## Setup
# !pip install scanpy anndata dynamo-release gseapy networkx torch numpy scipy scikit-learn matplotlib seaborn tqdm

import numpy as np, pandas as pd, scanpy as sc, torch, torch.nn as nn, torch.nn.functional as F
import math, warnings
warnings.filterwarnings("ignore")


In [None]:

# ## Load data (synthetic by default; set USE_SYNTHETIC = False for the real Kang PBMC file)
USE_SYNTHETIC = True
SEED = 0

if USE_SYNTHETIC:
    # Simulate non-negative *counts*: the preprocessing cell runs normalize_total + log1p, which
    # assume count-like data. Gaussian values give negative/near-zero library sizes and NaNs after log1p.
    rng = np.random.default_rng(SEED)
    n_ctrl, n_drug, d = 500, 500, 300
    n_shift, fold_change = 30, 2.0   # the first n_shift genes have fold_change-x higher mean when stimulated
    gene_means = rng.gamma(shape=2.0, scale=1.0, size=d)
    stim_means = gene_means.copy()
    stim_means[:n_shift] *= fold_change
    X_ctrl = rng.poisson(gene_means, size=(n_ctrl, d)).astype(np.float32)
    X_drug = rng.poisson(stim_means, size=(n_drug, d)).astype(np.float32)
    adata = sc.AnnData(np.vstack([X_ctrl, X_drug]))
    adata.obs['stim'] = pd.Categorical(['ctrl']*n_ctrl + ['stim']*n_drug)
    # Give the first up-regulated genes IFN marker names so the pathway prior has targets here too.
    ifn_markers = ["ISG15", "STAT1", "MX1", "OAS1", "IRF7"]
    adata.var_names = ifn_markers + [f"g{i}" for i in range(len(ifn_markers), d)]
else:
    # Replace with real Kang PBMC load (control vs stim in obs['stim'])
    adata = sc.read_h5ad("kang_pbmc_ifnb.h5ad")

print(adata)


In [None]:

# Preprocess
# Re-runnable: every run starts again from the raw counts kept in layers["counts"], so executing
# this cell twice does not normalize / log-transform / scale already-processed values a second time.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()
adata.X = adata.layers["counts"].copy()
adata.uns.pop("log1p", None)  # scanpy's record that X was log-transformed; X is raw again here
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=1000, subset=True)
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, n_comps=50)
sc.pp.neighbors(adata, n_neighbors=15, n_pcs=50)
sc.tl.umap(adata)
sc.pl.umap(adata, color=['stim'], show=False);


In [None]:

# Define IFN pathway prior genes (subset)
IFN_GENES = ["ISG15","STAT1","MX1","OAS1","IRF7"]
present = [g for g in IFN_GENES if g in adata.var_names]
print("Present IFN genes:", present)

gene2idx = {g: i for i, g in enumerate(adata.var_names)}
# Always an int64 index tensor, also when no gene is present (shape (0,)), so it can be used
# directly for column indexing without a float/long dtype mismatch.
idx_path = torch.tensor([gene2idx[g] for g in present], dtype=torch.long)

class PathwayField:
    """Fixed pathway-prior drift over a chosen set of gene columns.

    Args:
        idx: 1-D index tensor of the pathway gene columns (may be empty).
        scale: constant drift magnitude assigned to every pathway gene.
    """

    def __init__(self, idx, scale=0.05):
        self.idx = idx
        self.scale = scale

    @torch.no_grad()
    def g(self, x):
        """Return a tensor like ``x``: ``scale`` on the pathway columns, 0 everywhere else."""
        drift = torch.zeros_like(x)
        if self.idx.numel() > 0:
            drift[:, self.idx] += self.scale
        return drift

# Pathway prior over the IFN gene columns selected above, using the default drift scale.
pf = PathwayField(idx=idx_path)


In [None]:

# OT utilities
def pairwise_cost(x, y, p=2):
    """All-pairs transport cost ``sum_k |x_ik - y_jk|**p`` between rows of x and rows of y.

    p=2 gives squared Euclidean distance and p=1 the L1 distance. Other positive p use
    ``|.|**p`` instead of silently falling back to L1.

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        p: exponent of the per-coordinate cost.

    Returns:
        (n, m) cost matrix.
    """
    diff = x[:, None, :] - y[None, :, :]
    if p == 2:
        return diff.pow(2).sum(-1)
    if p == 1:
        return diff.abs().sum(-1)
    return diff.abs().pow(p).sum(-1)

def sinkhorn(a, b, C, eps=0.05, n_iter=50):
    """Entropic-regularised optimal transport plan via log-domain Sinkhorn iterations.

    Args:
        a: (n,) source weights.
        b: (m,) target weights.
        C: (n, m) cost matrix.
        eps: entropic regularisation strength.
        n_iter: number of alternating (row, column) scaling updates.

    Returns:
        (n, m) transport plan P. After the final column update its column sums equal ``b``,
        and its row sums approach ``a`` as n_iter grows.
    """
    # Work with dual potentials f, g (u = exp(f/eps), v = exp(g/eps)) and logsumexp instead of
    # forming K = exp(-C/eps). For squared-distance costs C/eps is in the hundreds or thousands,
    # so K underflows to 0 and the plain iteration returns an all-zero plan.
    log_a = torch.log(a)
    log_b = torch.log(b)
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)
    for _ in range(n_iter):
        f = eps * (log_a - torch.logsumexp((g[None, :] - C) / eps, dim=1))
        g = eps * (log_b - torch.logsumexp((f[:, None] - C) / eps, dim=0))
    # Elementwise scaling instead of diag(u) @ K @ diag(v), which costs a dense O(n^3) matmul.
    return torch.exp((f[:, None] + g[None, :] - C) / eps)

def minibatch_ot_loss(x_gen, x_tgt, eps=0.05, p=2, iters=50):
    """Entropic OT cost between two minibatches with uniform weights.

    The transport plan is computed without gradient and treated as a constant. The returned
    value ``sum(P * C)`` is differentiable w.r.t. ``x_gen``/``x_tgt`` through the cost matrix
    whenever the inputs require grad.

    Args:
        x_gen: (B, d) generated samples.
        x_tgt: (B2, d) target samples.
        eps: entropic regularisation passed to ``sinkhorn``.
        p: cost exponent passed to ``pairwise_cost``.
        iters: number of Sinkhorn iterations.

    Returns:
        Scalar tensor with the transport cost.

    Raises:
        ValueError: if either batch is empty (uniform weights 1/B would be undefined).
    """
    B, B2 = x_gen.size(0), x_tgt.size(0)
    if B == 0 or B2 == 0:
        raise ValueError("minibatch_ot_loss needs non-empty batches")
    a = torch.full((B,), 1 / B, device=x_gen.device, dtype=x_gen.dtype)
    b = torch.full((B2,), 1 / B2, device=x_tgt.device, dtype=x_tgt.dtype)
    C = pairwise_cost(x_gen, x_tgt, p)
    with torch.no_grad():
        P = sinkhorn(a, b, C.detach(), eps, iters)
    return (P * C).sum()


In [None]:

# Minimal ScoreNet + ControlNet + reverse sampler
class ScoreNet(nn.Module):
    """MLP score model s(x, t), conditioned on the diffusion time t.

    Args:
        x_dim: number of features per sample.
        hid: hidden width.

    ``forward(x, t, c=None)`` takes x of shape (B, x_dim) and t as a scalar or a (B,) tensor,
    and returns a (B, x_dim) tensor. ``c`` is accepted for API compatibility and is not used.
    """

    def __init__(self, x_dim, hid=128):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(x_dim, hid), nn.ReLU(), nn.Linear(hid, x_dim))
        # A diffusion score depends on the noise level. Without a t input the network would have
        # to use one score for every t in [0, 1], so t is projected into the hidden layer.
        self.t_proj = nn.Linear(1, hid)

    def forward(self, x, t, c=None):
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(-1, 1)
        if t.size(0) == 1:
            # scalar time: share it across the batch
            t = t.expand(x.size(0), 1)
        h = self.fc[0](x) + self.t_proj(t)
        return self.fc[2](self.fc[1](h))

class ControlNet(nn.Module):
    """Small MLP that maps a state x to a learned control drift with the same width as x.

    Args:
        x_dim: number of features per sample.
        hid: hidden width.

    The conditioning argument ``c`` of ``forward`` is accepted but not used.
    """

    def __init__(self, x_dim, hid=64):
        super().__init__()
        layers = [nn.Linear(x_dim, hid), nn.ReLU(), nn.Linear(hid, x_dim)]
        self.fc = nn.Sequential(*layers)

    def forward(self, x, c=None):
        control = self.fc(x)
        return control

def beta_t(t, bmin=0.1, bmax=20.0):
    """Linear noise schedule: ``bmin`` at t=0 rising to ``bmax`` at t=1.

    Works elementwise for Python floats and tensors alike.
    """
    slope = bmax - bmin
    return bmin + t * slope

@torch.no_grad()
def reverse_sample(score, u_net, x_init, steps=100, path_field=None):
    """Euler-Maruyama integration of the reverse-time SDE from t=1 to t=0, plus a control drift.

    Args:
        score: callable ``score(x, t, c)`` returning a score estimate shaped like x.
        u_net: callable ``u_net(x, c)`` returning a learned control drift shaped like x.
        x_init: starting state at t=1; assumed (B, D) with one row per sample.
        steps: number of grid points on [1, 0] (``steps - 1`` integration steps).
        path_field: optional object with ``g(x)`` giving an extra prior drift (e.g. PathwayField).

    Returns:
        Tensor shaped like x_init holding the state at t=0.
    """
    x = x_init.clone()
    ts = torch.linspace(1., 0., steps, device=x.device)
    for i in range(steps - 1):
        t = ts[i].expand(x.size(0))
        bt = beta_t(t).unsqueeze(-1)
        # Reverse-time drift -(beta/2) x - beta * score, integrated with the negative step dt below.
        sde_drift = -(bt / 2) * x - bt * score(x, t, None)
        # The control / pathway drift should act along the generation direction (t: 1 -> 0).
        # Multiplying it by dt < 0 would flip its sign (a +scale pathway drift would push genes
        # down), so it is integrated with |dt|. Built out-of-place rather than += on the net output.
        guide = u_net(x, None)
        if path_field is not None:
            guide = guide + path_field.g(x)
        dt = (ts[i + 1] - ts[i]).item()  # negative: time runs backwards
        noise = torch.sqrt(torch.clamp(bt, min=1e-8)) * torch.randn_like(x) * abs(dt) ** 0.5
        x = x + sde_drift * dt + guide * abs(dt) + noise
    return x


In [None]:

# Train toy model
# Dense float32 copy of the preprocessed expression matrix. `.toarray()` works for scipy sparse
# matrices and arrays; the old `.A` shortcut is deprecated/removed in recent SciPy releases.
X = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
X = np.asarray(X, dtype=np.float32)
ctrl_idx = np.where(adata.obs['stim'].astype(str) == "ctrl")[0]
stim_idx = np.where(adata.obs['stim'].astype(str) == "stim")[0]

ctrl = torch.tensor(X[ctrl_idx], dtype=torch.float32)
drug = torch.tensor(X[stim_idx], dtype=torch.float32)

torch.manual_seed(0)  # reproducible initialisation, minibatches and noise
score = ScoreNet(X.shape[1]); u_net = ControlNet(X.shape[1])
opt = torch.optim.AdamW(list(score.parameters()) + list(u_net.parameters()), lr=1e-3)

N_TRAIN_STEPS, BATCH_SIZE = 3, 128
BETA_MIN, BETA_MAX = 0.1, 20.0  # must match the beta_t defaults used inside reverse_sample
T_MIN = 1e-3                     # keeps sigma(t) away from 0, where the score target diverges

for ep in range(N_TRAIN_STEPS):
    idx = torch.randperm(ctrl.size(0))[:BATCH_SIZE]
    x0 = ctrl[idx]
    t = T_MIN + (1.0 - T_MIN) * torch.rand(x0.size(0))
    # Perturbation kernel of the SDE used by reverse_sample: x_t = alpha(t) x0 + sigma(t) z, with
    # log alpha(t) = -0.5 * int_0^t beta(s) ds and sigma(t)^2 = 1 - alpha(t)^2.
    log_alpha = -0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2)
    alpha = torch.exp(log_alpha).unsqueeze(-1)
    sigma = torch.sqrt(1.0 - torch.exp(2.0 * log_alpha)).unsqueeze(-1)
    z = torch.randn_like(x0)
    xt = alpha * x0 + sigma * z
    s_pred = score(xt, t, None)
    # Denoising score matching weighted by sigma^2: the true score is -z / sigma.
    loss_score = F.mse_loss(sigma * s_pred, -z)

    # The sampler runs without gradient, so the OT distance cannot update u_net here; it is
    # reported as a diagnostic only. TODO: backprop through the sampler (or an adjoint) to train u_net.
    with torch.no_grad():
        x_gen = reverse_sample(score, u_net, x0, steps=50, path_field=pf)
        tgt = drug[torch.randperm(drug.size(0))[:x_gen.size(0)]]
        ot = minibatch_ot_loss(x_gen, tgt)

    opt.zero_grad(); loss_score.backward(); opt.step()
    print(f"step {ep}: score loss {loss_score.item():.4f} | OT(gen, stim) {ot.item():.4f}")


In [None]:

# Evaluate IFN gene shift: generated vs. real stimulated cells, both relative to the same control cells.
N_EVAL = 200
ctrl_eval = ctrl[:N_EVAL]
torch.manual_seed(1)  # reproducible sampling noise
xg = reverse_sample(score, u_net, ctrl_eval, steps=100, path_field=pf)  # decorated with no_grad
if idx_path.numel() > 0:
    ctrl_ifn = ctrl_eval[:, idx_path].mean().item()
    delta = xg[:, idx_path].mean().item() - ctrl_ifn
    delta_ref = drug[:, idx_path].mean().item() - ctrl_ifn
    print("Mean shift in IFN pathway genes (gen - ctrl):", delta)
    print("Reference shift in IFN pathway genes (stim - ctrl):", delta_ref)
else:
    delta = 0.0
    print("No IFN pathway genes present in adata.var_names; shift not computed.")
