# Trajectory Inference

This notebook performs trajectory analysis on the filtered dataset, including:
1. RNA velocity computation
2. Pseudotime inference
3. Trajectory visualization
4. Cell state dynamics


In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

# Add src to path
sys.path.insert(0, str(Path.cwd().parent))

from src.trajectory import infer_trajectories
from src.preprocess import correct_batch
from src.utils import load_adata, normalize_data, find_hvg

# Session-wide scanpy configuration: logging verbosity level 3 and default
# figure parameters (80 dpi, white face color) for every plot in this notebook.
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, facecolor='white')


In [None]:
# Read the filtered AnnData file and report how many cells and genes it holds.
adata = load_adata("../data/synthetic/adata_filtered.h5ad")
print(f"Loaded: {adata.n_obs:,} cells, {adata.n_vars:,} genes")

# Preprocess in one pass: the log1p-normalized object is handed straight to
# find_hvg, which is asked for the top 2000 genes.
adata = find_hvg(
    normalize_data(adata, method='log1p'),
    n_top_genes=2000,
)

# Correct batch effects with the 'harmony' method, treating each value of
# `patient_id` as a separate batch.
adata = correct_batch(
    adata,
    batch_key='patient_id',
    method='harmony',
)


In [None]:
# Run the project's trajectory pipeline with both the velocity and the
# pseudotime steps switched on (n_pcs=50).
adata = infer_trajectories(adata, n_pcs=50, compute_pseudotime=True, compute_velocity=True)

# Report what the returned object carries: its per-cell annotation columns
# (`obs`) and its stored embeddings (`obsm`).
print("Trajectory inference complete!")
print(f"Available keys: {list(adata.obs.columns)}")
print(f"Available embeddings: {list(adata.obsm.keys())}")


In [None]:
# Visualize trajectories: two side-by-side UMAP panels, one colored by
# timepoint and one colored by velocity pseudotime.

# Above this many categories the timepoint colorbar keeps its default numeric
# ticks, because one text label per category would overlap.
_MAX_LABELED_TIMEPOINTS = 20


def _missing_inputs(requirements):
    """Return the labels of all (label, present) pairs whose flag is False."""
    return [label for label, present in requirements if not present]


def _note_unavailable(ax, title, missing):
    """Title an otherwise empty panel and state which inputs were not found."""
    ax.set_title(title)
    ax.text(
        0.5,
        0.5,
        "Not available; missing: " + ", ".join(missing),
        ha='center',
        va='center',
        transform=ax.transAxes,
    )
    ax.set_axis_off()


def _scatter_umap(ax, coords, colors, *, cmap, title, label):
    """Scatter the first two columns of *coords* on *ax*, colored by *colors*.

    Returns the colorbar attached to *ax* so the caller can adjust its ticks.
    """
    points = ax.scatter(
        coords[:, 0],
        coords[:, 1],
        c=colors,
        cmap=cmap,
        s=1,
        alpha=0.5,
    )
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title(title)
    return plt.colorbar(points, ax=ax, label=label)


fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Both panels are drawn on the UMAP embedding, so its presence is a
# requirement for each of them (previously only the first panel checked it,
# and the second raised KeyError when the embedding was absent).
has_umap = 'X_umap' in adata.obsm

# Left panel: UMAP colored by timepoint
timepoint_missing = _missing_inputs([
    ("obsm['X_umap']", has_umap),
    ("obs['timepoint']", 'timepoint' in adata.obs),
])
if timepoint_missing:
    _note_unavailable(axes[0], 'UMAP by Timepoint', timepoint_missing)
else:
    timepoints = pd.Categorical(adata.obs['timepoint'])
    timepoint_bar = _scatter_umap(
        axes[0],
        adata.obsm['X_umap'],
        timepoints.codes,
        cmap='viridis',
        title='UMAP by Timepoint',
        label='Timepoint',
    )
    # The colors encode integer category codes; put the category names on the
    # colorbar so the scale is readable instead of showing 0.0, 0.5, 1.0, ...
    n_timepoints = len(timepoints.categories)
    if n_timepoints <= _MAX_LABELED_TIMEPOINTS:
        timepoint_bar.set_ticks(list(range(n_timepoints)))
        timepoint_bar.set_ticklabels([str(cat) for cat in timepoints.categories])

# Right panel: UMAP colored by pseudotime
pseudotime_missing = _missing_inputs([
    ("obsm['X_umap']", has_umap),
    ("obs['velocity_pseudotime']", 'velocity_pseudotime' in adata.obs),
])
if pseudotime_missing:
    _note_unavailable(axes[1], 'UMAP by Pseudotime', pseudotime_missing)
else:
    _scatter_umap(
        axes[1],
        adata.obsm['X_umap'],
        adata.obs['velocity_pseudotime'],
        cmap='plasma',
        title='UMAP by Pseudotime',
        label='Pseudotime',
    )

plt.tight_layout()
plt.show()


In [None]:
# Save trajectory data: write the AnnData object, as modified by the cells
# above, to a new .h5ad file in the same directory the input was read from,
# then confirm on stdout.
adata.write("../data/synthetic/adata_with_trajectories.h5ad")
print("Trajectory data saved.")
