In [None]:
import os
from pathlib import Path
import random
import string

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import tcr_tda
from tcr_tda import (
    run_pipeline,
    build_olga_pgen_human_TRB,
    compute_distance_matrix,
    compute_distance_matrices_parallel,
    dm_to_graph,
    dm_to_graph_parallel_threshold,
    basic_network_metrics,
    compute_persistence,
    compute_persistence_batch,
    plot_persistence_diagram,
    plot_persistence_barcode,
    plot_betti_curve,
    node_removal_analysis,
    node_removal_analysis_batch,
    NodeRemovalResult,
)

AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # amino acid alphabet (20 one-letter residue codes)


def random_cdr3(length: int, rng=None) -> str:
    """Return a random string of ``length`` letters picked from ``AA_ALPHABET``.

    Parameters
    ----------
    length : int
        Number of letters to draw. ``0`` or a negative value yields ``""``.
    rng : optional
        Object exposing ``choice(seq)`` (e.g. a ``random.Random`` instance).
        When omitted, the module-level ``random`` generator is used, which is
        the original behaviour.

    Returns
    -------
    str
        The generated sequence.
    """
    chooser = random if rng is None else rng
    return "".join(chooser.choice(AA_ALPHABET) for _ in range(length))


def random_repertoire(n_seqs: int, min_len: int = 10, max_len: int = 14, seed=None):
    """Build ``n_seqs`` random sequences with lengths in ``[min_len, max_len]``.

    Parameters
    ----------
    n_seqs : int
        Number of sequences to generate.
    min_len, max_len : int
        Inclusive bounds on the sequence length.
    seed : int, optional
        When given (non-negative, below 2**32), lengths and letters are drawn
        from private generators seeded with this value, so the result is
        reproducible and the global RNG state is left untouched. When omitted,
        the global ``np.random`` / ``random`` generators are used, as before.

    Returns
    -------
    numpy.ndarray
        1-D array of ``str`` with ``dtype=object``.

    Raises
    ------
    ValueError
        Propagated from NumPy when ``min_len > max_len``.
    """
    if seed is None:
        # Legacy path: draw from the process-wide generators.
        lengths = np.random.randint(min_len, max_len + 1, size=n_seqs)
        residue_rng = None
    else:
        # Reproducible path: isolated generators for lengths and for letters.
        lengths = np.random.RandomState(seed).randint(min_len, max_len + 1, size=n_seqs)
        residue_rng = random.Random(seed)
    # int(L) converts the NumPy integer into a plain int for range().
    return np.array([random_cdr3(int(L), residue_rng) for L in lengths], dtype=object)


: 

In [None]:
N = 100  # depth per dataset
SEED = 42  # fixed so the synthetic repertoires are the same on every Restart & Run All

# random_repertoire draws lengths from NumPy's global generator and letters from
# the stdlib `random` module, so both have to be seeded for a reproducible run.
random.seed(SEED)
np.random.seed(SEED)

# Two independent synthetic repertoires of N sequences each; reused by every cell below.
seqs_A = random_repertoire(N)
seqs_B = random_repertoire(N)

# Show only a small slice rather than the whole arrays.
print("Dataset A (first 5):", seqs_A[:5])
print("Dataset B (first 5):", seqs_B[:5])


In [None]:
# Cell 3: test compute_distance_matrix
# Builds one distance matrix per synthetic repertoire. dm_A and dm_B are reused
# by the graph, persistence and node-removal cells further down.

dm_A = compute_distance_matrix(seqs_A)
dm_B = compute_distance_matrix(seqs_B)

# Print only the shapes instead of dumping the matrices.
# presumably (N, N) for N input sequences — TODO confirm against tcr_tda
print("Distance matrix A shape:", dm_A.shape)
print("Distance matrix B shape:", dm_B.shape)


In [None]:
# Cell 4: test compute_distance_matrices_parallel
# Passes both repertoires at once as a {dataset name: sequences} mapping.

seqs_by_dataset = {"A": seqs_A, "B": seqs_B}
dms_parallel = compute_distance_matrices_parallel(seqs_by_dataset)

# The result is iterated as (name, matrix) pairs; only the shapes are shown.
# presumably the keys mirror seqs_by_dataset — verify against tcr_tda
for name, dm in dms_parallel.items():
    print(f"{name}: shape {dm.shape}")


In [None]:
# Cell 5: test graph builders (sequential + parallel threshold) and basic_network_metrics


def _print_metrics(heading, metrics):
    """Print ``heading`` followed by one indented ``key: value`` line per metric."""
    print(heading)
    for key, value in metrics.items():
        print(f"  {key}: {value}")


# Both graphs are built from the dm_A computed above, with the same cut-off.
epsilon = 0.3  # toy threshold

# Sequential builder.
G_thresh = dm_to_graph(dm_A, nodes=seqs_A, mode="threshold", epsilon=epsilon)
metrics_thresh = basic_network_metrics(G_thresh)
_print_metrics("Sequential threshold graph metrics:", metrics_thresh)

# Parallel builder, same inputs, two workers.
G_thresh_parallel = dm_to_graph_parallel_threshold(
    dm_A,
    nodes=seqs_A,
    epsilon=epsilon,
    include_weights=True,
    max_workers=2,
)
metrics_thresh_parallel = basic_network_metrics(G_thresh_parallel)
_print_metrics("\nParallel threshold graph metrics:", metrics_thresh_parallel)


In [None]:
# Cell 6: TDA on single matrix and batch

# Homology dimension handed to the persistence routines; the node-removal
# cells below reuse this same variable.
hom_dim = 1

# Single-matrix call: unpacked as a (diagram, betti) pair.
diagram_A, betti_A = compute_persistence(dm_A, hom_dim=hom_dim)
print("Diagram A shape:", diagram_A.shape)
print("Betti A length:", len(betti_A))

# Batch call: takes a {name: distance matrix} mapping and is iterated below as
# name -> (diagram, betti).
tda_batch = compute_persistence_batch({"A": dm_A, "B": dm_B}, hom_dim=hom_dim)

for name, (dgm, bt) in tda_batch.items():
    print(f"{name}: diagram shape {dgm.shape}, betti length {len(bt)}")


In [None]:
# Cell 7: plot diagram, barcode, and Betti curve for dataset A
# Uses diagram_A / betti_A from the persistence cell above.
# NOTE(review): each plt.figure(...) call assumes the following plot_* helper
# draws on the current figure. If the helpers open their own figure, these
# calls leave empty figures behind — confirm against tcr_tda.

plt.figure(figsize=(4, 4))
plot_persistence_diagram(diagram_A, title="Random A – H1 diagram")

plt.figure(figsize=(6, 3))
plot_persistence_barcode(diagram_A, title="Random A – H1 barcode")

plt.figure(figsize=(6, 3))
plot_betti_curve(betti_A, title="Random A – H1 Betti curve")


In [None]:
# Cell 8: NRA for a single sequence (node_removal_analysis)

# Pick a sequence that definitely exists in seqs_A
target_seq = seqs_A[0]
print("Target sequence for NRA:", target_seq)

# Use G_thresh (graph on A) and dm_A
nra_result = node_removal_analysis(
    seq_to_remove=target_seq,
    dm=dm_A,
    node_labels=list(seqs_A),
    hom_dim=hom_dim,
    graph=G_thresh,
    pgen_fn=None,  # no OLGA here to keep it light
)

# Check the return type before touching its attributes, so an unexpected
# result fails here with a clear message instead of an AttributeError below.
assert isinstance(nra_result, NodeRemovalResult), (
    f"expected NodeRemovalResult, got {type(nra_result).__name__}"
)

print("\nNRA result (single):")
print("  seq:", nra_result.seq)
print("  index:", nra_result.index)
print("  delta_betti_sum:", nra_result.delta_betti_sum)
print("  pgen:", nra_result.pgen)
print("  network metrics BEFORE:", nra_result.network_metrics_before)
print("  network metrics AFTER:", nra_result.network_metrics_after)


In [None]:
# Cell 9: NRA for a few sequences in parallel (node_removal_analysis_batch)

# NOTE(review): node_removal_analysis_batch is already imported from tcr_tda in
# the first cell; this line rebinds the name from the subpackage path.
# Presumably kept on purpose to exercise that import path — confirm, otherwise
# it belongs in the import cell.
from tcr_tda.nra import node_removal_analysis_batch  # explicit import from subpackage

seqs_to_remove = list(seqs_A[:5])  # first 5 sequences

# Same distance matrix, labels, graph and hom_dim as the single-sequence call,
# with two workers and no Pgen function.
nra_results_batch = node_removal_analysis_batch(
    seqs_to_remove=seqs_to_remove,
    dm=dm_A,
    node_labels=list(seqs_A),
    hom_dim=hom_dim,
    graph=G_thresh,
    pgen_fn=None,
    max_workers=2,
)

# The result is iterated as sequence -> result object exposing delta_betti_sum.
print("\nNRA batch results (5 nodes):")
for seq, res in nra_results_batch.items():
    print(f"  {seq}: delta_betti_sum={res.delta_betti_sum}")


In [None]:
# Cell 10: test build_olga_pgen_human_TRB (optional, will skip if OLGA not installed)
# Best-effort cell: any failure is reported and the notebook keeps running.
# NOTE(review): the broad `except Exception` also covers the pgen_fn(test_seq)
# call, so an error raised while evaluating the sequence is reported with the
# same "OLGA not available" message — read the printed error to tell them apart.

try:
    pgen_fn = build_olga_pgen_human_TRB()
    # seqs_A holds randomly generated strings, so this is only a smoke test of the call.
    test_seq = seqs_A[1]
    pgen_val = pgen_fn(test_seq)
    print("OLGA Pgen test:")
    print("  sequence:", test_seq)
    print("  Pgen:", pgen_val)
except Exception as e:
    print("OLGA not available or failed to initialize, skipping Pgen test.")
    print("Error:", e)


In [None]:
# Cell 11: end-to-end pipeline test on synthetic data (depth 100)

# Every demo artefact lives under one base directory, relative to the working directory.
BASE_DIR = Path("demo_random_data")
DATA_DIR = BASE_DIR.joinpath("data")
OUT_DM = BASE_DIR.joinpath("networks")
OUT_TDA = BASE_DIR.joinpath("Topology")

# Only the input directory is created here.
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Wrap each synthetic repertoire in a one-column frame (column name: cdr3aa).
df_A = pd.DataFrame({"cdr3aa": seqs_A})
df_B = pd.DataFrame({"cdr3aa": seqs_B})

# Write both frames as tab-separated text files without the index column.
for file_name, frame in (("demo_A.txt", df_A), ("demo_B.txt", df_B)):
    frame.to_csv(DATA_DIR / file_name, sep="\t", index=False)

print("Wrote demo_A.txt and demo_B.txt to", DATA_DIR.resolve())

# Dataset filter that lets every dataset through.
def keep_all(name: str) -> bool:
    """Return ``True`` for any dataset ``name`` (the argument is ignored)."""
    return True

# Run pipeline WITHOUT OLGA (pgen_builder=None)
# The glob matches every .txt file in DATA_DIR, so files left over from an
# earlier run with different names would be matched as well.
# presumably run_pipeline creates OUT_DM / OUT_TDA itself — only DATA_DIR is
# created in this notebook; TODO confirm.
run_pipeline(
    data_glob=str(DATA_DIR / "*.txt"),
    out_dir_dm=str(OUT_DM),
    out_dir_tda=str(OUT_TDA),
    max_rows_per_dataset=None,  # use full depth 100
    hom_dim=1,  # same value as the hom_dim variable used in the earlier cells
    cdr3_col="cdr3aa",  # column name written into the demo files
    dataset_name_filter=keep_all,
    distance_metric=None,        # default: Needleman–Wunsch
    build_graphs=True,
    graph_mode="threshold",
    graph_epsilon=0.3,  # same toy threshold as `epsilon` in Cell 5
    parallel_graph_threshold=True,
    pgen_builder=None,           # skip OLGA for demo
    max_workers_dm=2,
    max_workers_tda=2,
    max_workers_node_removal=2,
)

# Point the reader at the two output directories (absolute paths).
print("\nPipeline finished. Check outputs in:")
print("  Distance / graphs ->", OUT_DM.resolve())
print("  TDA ->", OUT_TDA.resolve())


In [None]:
# Cell 12: quick inspection of saved outputs

import pickle

import networkx as nx  # kept so `nx` stays available after this cell

# Load combined NPZ
dm_npz_path = OUT_DM / "distance_matrices_and_nodes_all.npz"
# NOTE(security): allow_pickle=True can run arbitrary code while loading. It is
# acceptable here only because the file is expected to come from the pipeline
# run in this notebook — never point this at an untrusted file.
npz = np.load(dm_npz_path, allow_pickle=True)

# Show at most ten keys to keep the output short.
print("Keys in NPZ:", list(npz.keys())[:10])

# Load one graph
# Sorted so that "the first graph" is the same file on every run; the order in
# which glob yields paths is not guaranteed.
graph_files = sorted(OUT_DM.glob("graph_*.gpickle"))
print("Graph files:", [g.name for g in graph_files])

if graph_files:
    # nx.read_gpickle was removed in NetworkX 3.0. A .gpickle file is a plain
    # pickle, so loading it with `pickle` works on old and new NetworkX alike.
    # NOTE(security): same caveat as above — only unpickle files you produced.
    with open(graph_files[0], "rb") as fh:
        G_loaded = pickle.load(fh)
    print("Loaded graph:", graph_files[0].name)
    print("  n_nodes:", G_loaded.number_of_nodes())
    print("  n_edges:", G_loaded.number_of_edges())
