<a href="https://colab.research.google.com/github/klundquist/scTCR-project/blob/main/fig1A.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's dependencies into the runtime.
# Only scirpy is pinned (0.16.1); scanpy, matplotlib and muon resolve to whatever
# version pip selects at install time, so results may drift between runs.
# NOTE(review): consider pinning the remaining packages as well.
!pip install scirpy==0.16.1 scanpy matplotlib muon

In [None]:
# 1. Setup: Imports, download, load full Wu2020

import os
import muon as mu
import numpy as np
import scanpy as sc
import scirpy as ir
from matplotlib import pyplot as plt

# Global scanpy settings: default figure size and logging verbosity.
sc.set_figure_params(figsize=(4, 4))
sc.settings.verbosity = 2

# NOTE(review): `use_3k` is not read by any of the cells shown here; the full
# dataset is always loaded regardless of its value.
use_3k = False  # Use full dataset, not just 3k
file_path = "data/wu2020.h5mu"

# Load the dataset from the on-disk cache when it exists; otherwise fetch it
# through scirpy and write it to `file_path` so later runs skip the download.
if os.path.exists(file_path):
    print("Loading full wu2020 dataset from disk...")
    mdata = mu.read(file_path)
else:
    print("Downloading full wu2020 dataset...")
    mdata = ir.datasets.wu2020()
    # Derive the cache directory from `file_path` itself so the two cannot
    # drift apart if the path is edited. dirname() is "" for a bare filename,
    # and os.makedirs("") raises, hence the guard.
    cache_dir = os.path.dirname(file_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    mdata.write(file_path)
print(mdata)

In [None]:
# 2. QC/filter
#
# Preprocesses the "gex" modality of `mdata`. Every call below discards its
# return value, so each step is relied on to update mdata["gex"] in place;
# the order of the calls therefore matters.

# Gene filter (min_cells=10) followed by cell filter (min_genes=100).
sc.pp.filter_genes(mdata["gex"], min_cells=10)
sc.pp.filter_cells(mdata["gex"], min_genes=100)
# Per-cell normalisation, then log1p transform.
# NOTE(review): `normalize_per_cell` is, as far as I know, marked deprecated in
# recent scanpy releases in favour of `normalize_total`; scanpy is not pinned
# in the install cell, so confirm this call still exists in the installed version.
sc.pp.normalize_per_cell(mdata["gex"])
sc.pp.log1p(mdata["gex"])
# Highly-variable gene selection (cell_ranger flavor, top 5000 genes).
sc.pp.highly_variable_genes(mdata["gex"], flavor="cell_ranger", n_top_genes=5000)
# PCA and neighbour graph, both with default arguments.
sc.tl.pca(mdata["gex"])
sc.pp.neighbors(mdata["gex"])
# Removed mdata.update() here as it seems to be causing issues with the airr modality
# (the MuData container is instead synchronised after the airr filtering step further down).

In [None]:
# Step 3: TCR chain QC, clonotype assignment with proper filtering and syncing


# Index chains and assign primary/secondary chains per cell
# Build the per-cell chain index, then run scirpy's chain QC; the filter below
# reads the resulting "chain_pairing" column from the airr modality's obs table.
ir.pp.index_chains(mdata)
ir.tl.chain_qc(mdata)

# Chain-pairing labels whose cells are removed from the airr modality.
excluded_pairings = ["multichain", "orphan VDJ", "orphan VJ"]

# Boolean mask over airr cells: True for every cell whose pairing label is not excluded.
airr_all = mdata["airr"]
keep_cell = ~airr_all.obs["chain_pairing"].isin(excluded_pairings).to_numpy()
airr_filtered = airr_all[keep_cell].copy()

# Swap the filtered modality into the container and resynchronise it.
mdata.mod["airr"] = airr_filtered
mdata.update()

# Sequence-distance computation with scirpy's default arguments.
ir.pp.ir_dist(mdata)

# Clonotype definition: all receptor arms must match, and only the primary
# chain pair is considered for cells with dual receptors.
ir.tl.define_clonotypes(mdata, receptor_arms="all", dual_ir="primary_only")

# Show the container to confirm the filtering and the new annotations.
print(mdata)

In [None]:
# 4. Merge cell metadata
#
# Builds `merged_obs`, one row per cell, by combining the per-cell annotation
# tables of the "gex" and "airr" modalities. The plotting cell further down
# reads its 'patient', 'source' and 'clonotype_orig' columns from this table.

mdata.update()  # resynchronise the container before reading the per-modality tables
gex_obs = mdata.mod["gex"].obs
airr_obs = mdata.mod["airr"].obs
# Inner join on the index: only cells present in BOTH modalities are kept, so
# cells dropped by either the gex QC or the airr chain filter are excluded.
# NOTE(review): DataFrame.join raises if the two tables share a column name and
# no suffix is given — this assumes the two obs tables have disjoint columns.
merged_obs = gex_obs.join(airr_obs, how="inner")  # index: cell barcodes

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# ----------------------------
# Parameters
# ----------------------------
# Patients drawn in the scatter grid, one panel each, in this order.
patients_to_plot = ["Renal1", "Renal2", "Renal3", "Lung6"]

# Colour per expansion category; the keys must match the strings returned by
# expansion_category(), because the category column is mapped through this dict.
PALETTE = {
    "Blood_singleton": "#a6cee3",   # light blue
    "Blood_multiplet": "#1f78b4",   # dark blue
    "NAT_singleton": "#b2df8a",     # light green
    "NAT_multiplet": "#33a02c",     # dark green
    "Tumor_singleton": "#fb9a99",   # light pink
    "Tumor_multiplet": "#e31a1c",   # dark red
    "Dual": "#6a3d9a",              # purple
    "Other": "#cccccc"              # gray (catch-all)
}

# Scaling constants
# Marker area is computed as s_min + s_scale * sqrt(blood clone size).
s_min = 10
s_scale = 40
alpha_pts = 0.7         # marker opacity passed to scatter
jitter_decades = 0.25   # jitter width in log10 space
# Floor applied to the plotted fractions (values are clipped to [min_frac, 1]),
# so clones with zero cells in a compartment can still be placed on a log axis.
min_frac = 1e-6         # much lower cutoff to show singletons

# ----------------------------
# Expansion category function
# ----------------------------
def expansion_category(n, t, b):
    """Classify a clonotype by where its cells were observed.

    Parameters
    ----------
    n, t, b : numbers
        Cell counts of the clonotype in NAT, Tumor and Blood respectively.

    Returns
    -------
    str
        "Dual" when the clone has cells in both NAT and Tumor (blood count is
        ignored); "<Compartment>_singleton" / "<Compartment>_multiplet" when
        the clone is confined to exactly one compartment with a count of 1 /
        more than 1; "Other" for everything else (e.g. blood plus a single
        tissue, or no cells at all).
    """
    # Shared between NAT and Tumor: takes precedence and cannot coincide with
    # any single-compartment case, which all require n == 0 or t == 0.
    if n > 0 and t > 0:
        return "Dual"

    # (own count, the two other counts, label for exactly one, label for more than one)
    single_compartment_rules = (
        (b, (n, t), "Blood_singleton", "Blood_multiplet"),
        (n, (t, b), "NAT_singleton", "NAT_multiplet"),
        (t, (n, b), "Tumor_singleton", "Tumor_multiplet"),
    )
    for own, (first_other, second_other), singleton, multiplet in single_compartment_rules:
        if first_other == 0 and second_other == 0:
            if own == 1:
                return singleton
            if own > 1:
                return multiplet

    return "Other"

# ----------------------------
# Main scatter panels
# ----------------------------
# One scatter panel per patient: each point is a clonotype placed at its share
# of the patient's NAT cells (x) and Tumor cells (y), coloured by expansion
# category and sized by its number of blood cells.

# Seeded generator for the jitter, so the figure is identical on every run
# instead of depending on the global NumPy random state.
JITTER_SEED = 0
rng = np.random.default_rng(JITTER_SEED)

# Compartments counted per clonotype; the order fixes the column order of `counts`.
SOURCE_COLUMNS = ['NAT', 'Tumor', 'Blood']

# Points clipped to the min_frac floor (or sitting at 1.0) are jittered by up to
# jitter_decades in either direction. The axes therefore extend that far beyond
# [min_frac, 1]; with limits of exactly [min_frac, 1] about half of the floor
# points would fall outside the axes and silently disappear.
axis_lo = min_frac * 10 ** (-jitter_decades)
axis_hi = 10 ** jitter_decades

# Two-column grid sized to the patient list (2 x 2 and 14 x 12 inches for four
# patients). squeeze=False keeps `axes` two-dimensional even for a single row.
n_cols = 2
n_rows = max(1, -(-len(patients_to_plot) // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(7 * n_cols, 6 * n_rows),
    sharex=True, sharey=True,
    squeeze=False,
)

for idx, patient in enumerate(patients_to_plot):
    ax = axes[idx // n_cols, idx % n_cols]
    title = patient
    data = merged_obs[merged_obs['patient'] == patient].copy()

    # A patient without any cells gets an empty, formatted panel.
    if not data.empty:
        # Cells per (clonotype, source). observed=True stops categorical
        # columns from emitting a row for every category level, i.e. for
        # clonotypes that belong to other patients.
        counts = (
            data.groupby(['clonotype_orig', 'source'], observed=True)
            .size()
            .unstack(fill_value=0)
        )
        # Plain object labels so that reindex can add a compartment that is
        # absent for this patient even when 'source' is categorical.
        counts.columns = counts.columns.astype(object)
        counts = counts.reindex(columns=SOURCE_COLUMNS, fill_value=0)
        # Drop clonotypes with no cells in any plotted compartment; they would
        # otherwise pile up as gray "Other" points in the bottom-left corner.
        counts = counts[counts.sum(axis=1) > 0].copy()

        # Denominators: all of the patient's cells in each tissue.
        nat_total = int((data['source'] == 'NAT').sum())
        tumor_total = int((data['source'] == 'Tumor').sum())

        counts['NAT_frac'] = counts['NAT'] / nat_total if nat_total > 0 else 0.0
        counts['Tumor_frac'] = counts['Tumor'] / tumor_total if tumor_total > 0 else 0.0
        # Clip to [min_frac, 1] so zero fractions can be placed on a log axis.
        counts['NAT_frac_plot'] = np.clip(counts['NAT_frac'].to_numpy(), min_frac, 1.0)
        counts['Tumor_frac_plot'] = np.clip(counts['Tumor_frac'].to_numpy(), min_frac, 1.0)

        # Categories & sizes
        counts['Category'] = [
            expansion_category(n, t, b)
            for n, t, b in counts[SOURCE_COLUMNS].itertuples(index=False)
        ]
        counts['Color'] = counts['Category'].map(PALETTE)
        counts['size'] = s_min + s_scale * np.sqrt(counts['Blood'].to_numpy())

        # Uniform jitter in log10 space so clones with identical fractions
        # (above all the many singletons) do not sit exactly on top of each other.
        logx = np.log10(counts['NAT_frac_plot'].to_numpy())
        logy = np.log10(counts['Tumor_frac_plot'].to_numpy())
        logx = logx + rng.uniform(-jitter_decades, jitter_decades, size=len(logx))
        logy = logy + rng.uniform(-jitter_decades, jitter_decades, size=len(logy))
        x_plot, y_plot = 10 ** logx, 10 ** logy

        ax.scatter(
            x_plot, y_plot,
            s=counts['size'].to_numpy(),
            c=counts['Color'].to_numpy(),
            edgecolor='k', linewidth=0.15,
            alpha=alpha_pts
        )

        # Weighted Pearson correlation of log10 clone sizes over clones seen in
        # both NAT and Tumor, weighted by 1 + blood clone size. It uses the
        # un-jittered counts, so the statistic does not depend on the seed.
        dual = counts[(counts['NAT'] > 0) & (counts['Tumor'] > 0)]
        if len(dual) > 1:
            x = np.log10(dual['NAT'].to_numpy().astype(float))
            y = np.log10(dual['Tumor'].to_numpy().astype(float))
            w = 1.0 + dual['Blood'].to_numpy().astype(float)
            xm, ym = np.average(x, weights=w), np.average(y, weights=w)
            num = np.sum(w * (x - xm) * (y - ym))
            den = np.sqrt(np.sum(w * (x - xm) ** 2) * np.sum(w * (y - ym) ** 2))
            # den == 0 means no variance in x or y; report NaN rather than divide by zero.
            rw = num / den if den > 0 else np.nan
            nD = len(dual)
            title = f"{patient} (nD={nD}, rw={rw:.2f})"

    # Panel formatting (applied after the scatter so the limits are final).
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(axis_lo, axis_hi)
    ax.set_ylim(axis_lo, axis_hi)
    ax.plot([axis_lo, axis_hi], [axis_lo, axis_hi], 'k--', lw=1)  # y = x reference line
    ax.set_xlabel("Normalized clone size in NAT")
    ax.set_ylabel("Normalized clone size in Tumor")
    ax.set_title(title)

# Hide grid cells that have no patient (odd-length or empty patient list).
for unused_ax in axes.flat[len(patients_to_plot):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

# ----------------------------
# Separate legend figure
# ----------------------------
# Stand-alone figure that carries only the legends; its main axes are blanked
# so that nothing but the legend boxes and the inset is visible.
fig_leg, ax_leg = plt.subplots(figsize=(6, 5))
ax_leg.axis("off")

# --- Legend 1: one colour swatch per expansion category ---------------------
# "Other" is in PALETTE but is not listed here.
legend_categories = [
    "Blood_singleton", "Blood_multiplet",
    "NAT_singleton", "NAT_multiplet",
    "Tumor_singleton", "Tumor_multiplet",
    "Dual",
]
cat_handles = [
    mpatches.Patch(color=PALETTE[category], label=category.replace("_", " "))
    for category in legend_categories
]
cat_labels = [handle.get_label() for handle in cat_handles]
leg1 = ax_leg.legend(
    cat_handles, cat_labels,
    title="Expansion pattern",
    loc="upper left",
    frameon=False,
)
# Registered as a plain artist so it stays on the axes when the next legend()
# call replaces the axes' own legend.
ax_leg.add_artist(leg1)

# --- Legend 2: marker size as a function of blood clone size ----------------
# Empty scatters serve as handles; their area uses the same formula as the panels.
blood_sizes = [0, 1, 2, 5, 10, 20, 50, 100]
size_handles = [
    plt.scatter([], [], s=s_min + s_scale * np.sqrt(blood_count), color="k")
    for blood_count in blood_sizes
]
size_labels = [str(blood_count) for blood_count in blood_sizes]
leg2 = ax_leg.legend(
    size_handles, size_labels,
    title="Clone size, blood",
    loc="upper right",
    ncol=2,
    frameon=False,
)
ax_leg.add_artist(leg2)

# --- Inset: NAT x Tumour gradient key ----------------------------------------
# The image is a synthetic 25 x 25 gradient built from two linspace ramps; it
# is not computed from the data.
inset_ax = fig_leg.add_axes([0.35, 0.05, 0.25, 0.25])  # [left, bottom, width, height] in figure fractions
rising_ramp = np.linspace(0, 1, 25)
falling_ramp = np.linspace(1, 0, 25)
g = np.outer(rising_ramp, falling_ramp)
inset_ax.imshow(g, cmap="cool", origin="lower")

# Label only the two ends of each axis (pixel 0 and pixel 24).
end_ticks = [0, 24]
end_labels = ["0", "∞"]
inset_ax.set_xticks(end_ticks)
inset_ax.set_xticklabels(end_labels, fontsize=8)
inset_ax.set_yticks(end_ticks)
inset_ax.set_yticklabels(end_labels, fontsize=8)
inset_ax.set_xlabel("Clone size, NAT", fontsize=8, labelpad=1)
inset_ax.set_ylabel("Clone size, Tumour", fontsize=8, labelpad=1)

# Strip tick marks and the frame so only the gradient and its labels remain.
inset_ax.tick_params(axis='both', which='both', length=0)
for spine in inset_ax.spines.values():
    spine.set_visible(False)

plt.tight_layout()
plt.show()
