In [None]:
# Import dependencies

import pandas as pd
import pickle
import numpy as np
import matplotlib.pyplot as plt
import math
from collections import Counter
from scipy.optimize import curve_fit
from scipy.stats import poisson
from scipy.special import zeta
import seaborn as sb
import matplotlib.cm as cm
from scipy.stats import binom
import copy
from Bio import Phylo
from ete3 import Tree, TreeStyle, NodeStyle
from scipy.interpolate import interp1d

from data_analysis_utils import *

In [None]:
# Input files (relative to the notebook directory)
tonsil_vdjs_path = 'Data/tonsil_vdjs.tsv'
metadata_path = "Data/tonsil_LR_spatialbc_metadata.csv"
UMI_collapsed_path = "Data/UMI_collapsed_on_tissue_read_list.csv"
clone_list_path = "Data/tonsil_all_ontissue_clone_list.csv"

# Load everything through data_analysis_utils.data_readin, unpacking its seven
# returned DataFrames in the order it returns them.
(
    df_IGH_IF,
    df_IGH_EF_PBdom,
    df_combined,
    df_TRB,
    metadata_df,
    df_IGH_EF,
    df_IGH,
) = data_readin(
    tonsil_vdjs_path,
    metadata_path,
    UMI_collapsed_path,
    clone_list_path,
)

# Collect sequences of multi-clone lineages for tree construction. FastTree has
# to be run separately on the FASTA files to build the trees; here
# write_files=False (presumably no files are written -- confirm in data_analysis_utils).
lineage_ids, lineage_lens, N_trees, N_lineages_tot = prepare_sequences_for_trees(
    df_IGH,
    df_IGH_IF,
    df_combined,
    write_files=False,
)

In [None]:
# FIGURE S1: Clone statistics across barcodes
#
# A clone is a unique vdj_sequence. Each row of df_IGH_* is counted as one UMI
# (cf. the '# UMIs per clone-location pair' axis below). Clones are split into:
#   - 'Unique clones': occur in exactly one row;
#   - 'Repeated (diff. spatial BCs only)': occur in >1 row, but never twice at
#     the same spatial barcode;
#   - 'Repeated clones (with same spatial BC)': occur at least twice at one
#     spatial barcode.

# Clone statistics (IF)
n_clones_IF = df_IGH_IF['vdj_sequence'].nunique()
vdj_counts_IF = df_IGH_IF['vdj_sequence'].value_counts()
n_clones_repeated_IF = (vdj_counts_IF > 1).sum()
vdj_barcode_counts_IF = df_IGH_IF.groupby(['vdj_sequence', 'st_barcode']).size()
n_clones_repeated_stbc_IF = (vdj_barcode_counts_IF > 1).groupby(level=0).any().sum()
values_IF = [n_clones_IF - n_clones_repeated_IF, n_clones_repeated_IF - n_clones_repeated_stbc_IF, n_clones_repeated_stbc_IF]
labels = ['Unique clones', 'Repeated (diff. spatial BCs only)', 'Repeated clones (with same spatial BC)']
colors = ['#66c2a5', '#3288bd', '#1b2838']  # teal, blue, dark navy

# Number of distinct spatial barcodes per clone, computed once with a groupby.
# (Previously the whole DataFrame was re-filtered for every clone: O(clones x rows).)
n_stbc_per_vdj_IF = df_IGH_IF.groupby('vdj_sequence')['st_barcode'].nunique()

# Clones seen more than once at some spatial barcode: at how many distinct barcodes is each seen?
vdj_with_repeats = vdj_barcode_counts_IF[vdj_barcode_counts_IF > 1].index.get_level_values(0).unique()
unique_st_barcodes_per_vdj = n_stbc_per_vdj_IF.loc[vdj_with_repeats].tolist()
print(("Fraction of repeatedly sampled clones sampled across multiple barcodes = " +  str(np.sum(np.array(unique_st_barcodes_per_vdj) > 1)/len(unique_st_barcodes_per_vdj))))

# Clones with exactly two UMIs (with a spatial barcode) in total...
vdj_counts_exact2 = vdj_barcode_counts_IF.groupby(level=0).sum()
vdj_twice = vdj_counts_exact2[vdj_counts_exact2 == 2].index

# ...of which both UMIs share the same spatial barcode.
vdj_twice_same_stbc = vdj_twice[n_stbc_per_vdj_IF.loc[vdj_twice].to_numpy() == 1].tolist()
fraction_twice_same_stbc = len(vdj_twice_same_stbc) / df_IGH_IF['vdj_sequence'].nunique()
print("Fraction of unique vdj_sequences appearing exactly twice, both with the same st_barcode =", fraction_twice_same_stbc)

# Clone statistics (EF, all)
n_clones_EF = df_IGH_EF['vdj_sequence'].nunique()
vdj_counts_EF = df_IGH_EF['vdj_sequence'].value_counts()
n_clones_repeated_EF = (vdj_counts_EF > 1).sum()
vdj_barcode_counts_EF = df_IGH_EF.groupby(['vdj_sequence', 'st_barcode']).size()
n_clones_repeated_stbc_EF= (vdj_barcode_counts_EF > 1).groupby(level=0).any().sum()
values_EF = [n_clones_EF - n_clones_repeated_EF, n_clones_repeated_EF - n_clones_repeated_stbc_EF, n_clones_repeated_stbc_EF]

# Clone statistics (EF, plasmablast-dominated)
n_clones_EF_PBdom = df_IGH_EF_PBdom['vdj_sequence'].nunique()
vdj_counts_EF_PBdom = df_IGH_EF_PBdom['vdj_sequence'].value_counts()
n_clones_repeated_EF_PBdom = (vdj_counts_EF_PBdom > 1).sum()
vdj_barcode_counts_EF_PBdom = df_IGH_EF_PBdom.groupby(['vdj_sequence', 'st_barcode']).size()
n_clones_repeated_stbc_EF_PBdom = (vdj_barcode_counts_EF_PBdom > 1).groupby(level=0).any().sum()
values_EF_PBdom = [
    n_clones_EF_PBdom - n_clones_repeated_EF_PBdom,
    n_clones_repeated_EF_PBdom - n_clones_repeated_stbc_EF_PBdom,
    n_clones_repeated_stbc_EF_PBdom
]

# Survival functions of UMIs per (clone, spatial barcode) pair
x_if, surv_if = survival_func(vdj_barcode_counts_IF.values)
x_ef, surv_ef = survival_func(vdj_barcode_counts_EF.values)
x_ef_pbdom, surv_ef_pbdom = survival_func(vdj_barcode_counts_EF_PBdom.values)

# Survival function for number of unique vdj_sequences per st_barcode in IF
vdj_per_stbc_IF = df_IGH_IF.groupby('st_barcode')['vdj_sequence'].nunique()
x_vdj_stbc, surv_vdj_stbc = survival_func(vdj_per_stbc_IF.values)

fig, axs = plt.subplots(1, 3, figsize=(9, 3))

# Bar chart (left subplot): a white, black-outlined bar for each region's total,
# filled in by the three clone classes stacked in `colors` order.
for region, region_values in [('IF', values_IF), ('EF (PBs)', values_EF_PBdom), ('EF (all)', values_EF)]:
    axs[0].bar([region], [sum(region_values)], color='white', edgecolor='black')
    bottom = 0
    for v, c in zip(region_values, colors):
        axs[0].bar([region], [v], bottom=bottom, color=c)
        bottom += v

axs[0].set_ylabel('Number of clones')
axs[0].set_xticks(['IF', 'EF (PBs)', 'EF (all)'])

# Survival function (middle subplot)
axs[1].plot(x_if, surv_if, label='IF', color='black', linewidth=3)
axs[1].plot(x_ef, surv_ef, label='EF (all)', color=[0.6,0,0], linewidth=3)
axs[1].plot(x_ef_pbdom, surv_ef_pbdom, label='EF (PBs)', color=[1,0.3,0.3], linestyle='--', linewidth=3)
axs[1].set_yscale('log')
axs[1].set_xlabel('# UMIs per clone-location pair, n')
axs[1].set_ylabel('Frac. pairs with >= n UMIs')
axs[1].set_xlim(left=1)
axs[1].set_xticks([1, 10, 20, 30])
axs[1].legend()

# Survival function for unique vdj_sequences per st_barcode (right subplot)
axs[2].plot(x_vdj_stbc, surv_vdj_stbc, color='black', linewidth=3)
axs[2].set_yscale('log')
axs[2].set_xlim(left=1)
axs[2].set_xticks([1,10,20])
axs[2].set_xlabel('# unique clones per spatial BC')
axs[2].set_ylabel('Frac. spatial BCs with >= n clones')

plt.tight_layout()

# Forward slashes work on every OS (a backslash path only resolves on Windows).
plt.savefig('Figures/Supplementary/SI_Clone_Barcode_Statistics_Subplots.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S2: Lineage calling threshold
#
# CDR3 clustering. For every candidate fractional threshold:
#   1. single-linkage cluster the unique VDJ sequences by CDR3 distance, within
#      groups sharing V family and CDR3 length (cutoff = threshold x CDR3 length),
#      using the precomputed CDR3 distance matrices in `mat_path`;
#   2. save the per-threshold cluster table and, for every CDR3 cluster, the
#      pairwise templated-V and templated-J distance matrices that the VJ step
#      in the next cell reads back.
# Length-normalised CDR3 distances from groups containing IGH sequences are
# collected at threshold 0.2 for the figure.

dist_list_CDR3 = []
thresholds = np.array([0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3])

import sys, os, time
import scipy.spatial.distance as ssd

# os.path.join rather than backslash-separated literals so the helper packages are found on any OS.
sys.path.append(os.path.join("snakemake_workflow", "scripts"))
sys.path.append(os.path.join("snakemake_workflow", "scripts", "cython_packages"))

from pacbio_vdj_utils import *
from pacbio_vdj_utils.cluster_vdj import *

airr_path = "Data/Clustering_Analysis/igblast_filtered_annotated_preprocessed.tsv"
mat_path = "Data/Clustering_Analysis/cdr3_dist_mats/"
OUTDIR = "Data/Clustering_Analysis/cdr3_clusters/"
V_SSTART_MAX = 0
C_SSTART_MAX = 0
SAMPLENAME = "tonsil_vdjs"

# ---- Threshold-independent preprocessing (previously repeated for every threshold) ----

df = pd.read_table(airr_path, usecols=['sequence',
                            'v_family',
                            'v_sequence_alignment',
                            'j_sequence_alignment',
                            'v_sequence_start',
                            'v_germline_start',
                            'v_germline_alignment',
                            'j_germline_alignment',
                            'cdr3_start',
                            'cdr3_end',
                            'cdr3',
                            'j_sequence_end','locus'])

df['cdr3_length'] = df['cdr3'].str.len()

sys.stderr.write(f"Starting with {df.shape[0]} reads...\n")

df = df[df.v_germline_start <= V_SSTART_MAX + 1]
# 'c_sstart' is not among the columns read above, so this filter currently never
# applies; it takes effect if the column is added to `usecols`.
if 'c_sstart' in df.columns:
    df = df[df['c_sstart'] <= C_SSTART_MAX + 1]
df = df.copy()  # own the data so the column assignments below do not write into a slice
sys.stderr.write(f"Keeping {df.shape[0]} reads that contain entire VDJ region...\n")

# VDJ region = read from the V alignment start (1-based) through the J alignment end.
df['vdj_sequence'] = df.apply(lambda x: x.sequence[int(x.v_sequence_start)-1:
                                    int(x.j_sequence_end)], axis = 1)

df['v_templated_len'] = df['cdr3_start'] - df['v_sequence_start']
df['j_templated_len'] = df['j_sequence_end'] - df['cdr3_end']

# drop reads with a gap (deletion) in the templated V or J alignment
df['sum_gap_len'] = df.v_sequence_alignment.map(lambda x: x.count("-"))
df['sum_gap_len'] = df['sum_gap_len'] + df.j_sequence_alignment.map(lambda x: x.count("-"))
df = df[df['sum_gap_len'] < 1].copy()

sys.stderr.write(f"Keeping {df.shape[0]} reads with no deletions in templated sequences...\n")

# drop reads with a gap in the germline alignment, i.e. an insertion in the read
df['sum_insertion_len'] = df.v_germline_alignment.map(lambda x: x.count("-"))
df['sum_insertion_len'] = df['sum_insertion_len'] + df.j_germline_alignment.map(lambda x: x.count("-"))
df = df[df['sum_insertion_len'] < 1]

sys.stderr.write(f"Keeping {df.shape[0]} reads with no insertions in templated sequences...\n")

unique_vdjs_base = df[['v_family',
                       'vdj_sequence',
                       'v_templated_len',
                       'j_templated_len',
                       'cdr3',
                       'cdr3_length','locus']].drop_duplicates(ignore_index=True)

# CDR3 distances are only compared within a V family and a CDR3 length.
unique_vdjs_base['cdr3_group'] = unique_vdjs_base['v_family'] + "_" + unique_vdjs_base.cdr3_length.astype(str)

cdr3_group_sizes = unique_vdjs_base.cdr3_group.value_counts()
unique_vdjs_base['cdr3_group_size'] = unique_vdjs_base.cdr3_group.map(cdr3_group_sizes)

sys.stderr.write("Verifying that all distance matrices are available...\n")
for cdr3_group in unique_vdjs_base.cdr3_group.unique():
    vfam, cdr3_len = cdr3_group.split("_")[0], cdr3_group.split("_")[1]
    binary_matrix_filename = os.path.join(mat_path, f'{SAMPLENAME}_{vfam}_{cdr3_len}_cdr3.npy')
    if not os.path.exists(binary_matrix_filename):
        sys.stderr.write(f"Cannot find the following file {binary_matrix_filename}. Aborting...\n")
        # An exception (rather than sys.exit) is the idiomatic way to stop inside a notebook kernel.
        raise FileNotFoundError(binary_matrix_filename)

# ---- Per-threshold clustering ----

# NOTE: the loop variable is FRACTIONAL_CUTOFF itself; the old index `j` was
# shadowed by the inner `for j in range(i)` distance loop.
for FRACTIONAL_CUTOFF in thresholds:
    unique_vdjs = unique_vdjs_base.copy()

    sys.stderr.write(f"Clustering {unique_vdjs.shape[0]} unique variable sequences...\n")

    TOTAL_CLUSTERS = 0
    unique_vdjs['cluster_id'] = -1

    start = time.time()

    for cdr3_group in unique_vdjs.cdr3_group.unique():
        subset = unique_vdjs[unique_vdjs['cdr3_group'] == cdr3_group]
        IGH_flag = (subset['locus'] == 'IGH').any()

        vfam, cdr3_len = cdr3_group.split("_")[0], int(cdr3_group.split("_")[1])

        binary_matrix_filename = os.path.join(mat_path, f'{SAMPLENAME}_{vfam}_{cdr3_len}_cdr3.npy')
        Ds = np.load(binary_matrix_filename, allow_pickle=False)
        n = Ds.shape[0]

        if n > 1:
            # Condensed (upper-triangle, diagonal excluded) form expected by get_cluster_ids.
            upper_triangular_indices = np.triu_indices(n, 1)
            condensed_distance_matrix = Ds[upper_triangular_indices]

            if FRACTIONAL_CUTOFF == 0.2 and IGH_flag:
                dist_list_CDR3.extend((condensed_distance_matrix / cdr3_len).tolist())

            # cluster CDR3s
            cutoff = np.uint8(FRACTIONAL_CUTOFF * cdr3_len)
            new_cluster_ids = get_cluster_ids(condensed_distance_matrix,
                                              cutoff=cutoff,
                                              method='single')
        else:
            new_cluster_ids = np.zeros(1)

        # Offset so cluster ids stay unique across CDR3 groups.
        unique_vdjs.loc[subset.index, 'cluster_id'] = TOTAL_CLUSTERS + new_cluster_ids
        TOTAL_CLUSTERS += max(new_cluster_ids) + 1

    sys.stderr.write(f"Clustering CDR3s took {round(time.time() - start)} seconds\n")

    unique_vdjs['templated_v'] = unique_vdjs.apply(lambda x: x.vdj_sequence[0:int(x.v_templated_len)], axis=1)
    # Slice from len - j_len (clamped at 0) instead of -j_len: seq[-0:] would return
    # the whole sequence when the templated J length is 0.
    unique_vdjs['templated_j'] = unique_vdjs.apply(
        lambda x: x.vdj_sequence[max(len(x.vdj_sequence) - int(x.j_templated_len), 0):], axis=1)

    unique_vdjs.to_csv(os.path.join(OUTDIR, f"{SAMPLENAME}_unique_vdjs_cdr3_clusters_threshold_{FRACTIONAL_CUTOFF}.tsv"), sep='\t')

    sys.stderr.write("Preparing to cluster templated sequences within cdr3 clusters...\n")

    cluster_sizes = unique_vdjs.cluster_id.value_counts()
    unique_vdjs['cluster_size'] = unique_vdjs.cluster_id.map(cluster_sizes)

    groups = unique_vdjs['cluster_id'].unique()

    start = time.time()

    sys.stderr.write(
        f"Calculating distance matrices for {len(groups)} unique clusters...\n")

    for cluster_id in groups:
        cluster_start = time.time()
        subset = unique_vdjs[unique_vdjs['cluster_id'] == cluster_id]

        for seq_element in ['templated_v', 'templated_j']:
            BINARY_MATRIX_FILENAME = os.path.join(
                OUTDIR, f'{SAMPLENAME}_{cluster_id}_{seq_element}_threshold_{FRACTIONAL_CUTOFF}.npy')

            templated_seqs = subset[seq_element].values
            n = len(templated_seqs)

            sys.stderr.write(
                f"Processing VDJ sequence subset: cluster_id={cluster_id}, n={n}\n")

            sys.stderr.write("\t computing distance matrix locally...\n")

            # Symmetric uint8 distance matrix; distances above 255 are clamped to 255.
            D = np.zeros((n, n), np.uint8)
            for a in range(n):
                for b in range(a):
                    d = np.uint8(min(distance(templated_seqs[a], templated_seqs[b]), 255))
                    D[a, b] = d
                    D[b, a] = d

            np.save(BINARY_MATRIX_FILENAME, D, allow_pickle=False)

        # Per-cluster time (the old message reported the cumulative time since `start`).
        sys.stderr.write(f"\t\t this took {time.time() - cluster_start} seconds\n")

    sys.stderr.write(f'Done! Execution took {(time.time() - start)} seconds\n')

dist_list_CDR3 = np.array(dist_list_CDR3)

In [None]:
# VJ filtering for final lineage IDs
#
# Within each CDR3 cluster written by the previous cell, single-linkage cluster
# the unique VDJs by templated-V distance + templated-J distance (saturating at
# 255), with cutoff = FRACTIONAL_CUTOFF x the longest templated V+J string in the
# cluster. Then annotate the full read table with the resulting lineage_id and
# write one table per threshold. Uses `thresholds` and the files from the CDR3
# clustering cell.

import os, sys, time
import scipy.spatial.distance as ssd

# Make the helper package importable even if the CDR3 cell has not been run in this kernel.
for _helper_path in (os.path.join("snakemake_workflow", "scripts"),
                     os.path.join("snakemake_workflow", "scripts", "cython_packages")):
    if _helper_path not in sys.path:
        sys.path.append(_helper_path)

from pacbio_vdj_utils.cluster_vdj import *
dist_list_VJ = []

CONTIGFILE = "Data/Clustering_Analysis/igblast_filtered_annotated_preprocessed.tsv"
SAMPLENAME = "tonsil_vdjs"
MATRIXDIR = "Data/Clustering_Analysis/cdr3_clusters/"

# The full read table does not depend on the threshold: read it and extract the
# VDJ region once (previously re-read for every threshold).
contigs_df = pd.read_table(CONTIGFILE)
contigs_df['vdj_sequence'] = contigs_df.apply(lambda x: x.sequence[int(x.v_sequence_start)-1:
                                                          int(x.j_sequence_end)], axis = 1)

for FRACTIONAL_CUTOFF in thresholds:

    OUTDIR = "Data/Clustering_Analysis/final_lineage_ids/tonsil_vdjs_threshold_" + str(FRACTIONAL_CUTOFF)
    UNIQUE_VDJ_FILE = "Data/Clustering_Analysis/cdr3_clusters/tonsil_vdjs_unique_vdjs_cdr3_clusters_threshold_" + str(FRACTIONAL_CUTOFF) + ".tsv"

    start = time.time()
    unique_vdjs = pd.read_table(UNIQUE_VDJ_FILE)

    sys.stderr.write("Verifying that all distance matrices are available...\n")
    for cluster_id in unique_vdjs.cluster_id.unique():
        for s in ['v', 'j']:
            BINARY_MATRIX_FILENAME = os.path.join(
                MATRIXDIR, f'{SAMPLENAME}_{cluster_id}_templated_{s}_threshold_{FRACTIONAL_CUTOFF}.npy')
            if not os.path.exists(BINARY_MATRIX_FILENAME):
                sys.stderr.write(f"Cannot find the following file {BINARY_MATRIX_FILENAME}. Aborting...\n")
                # An exception (rather than sys.exit) is the idiomatic way to stop inside a notebook kernel.
                raise FileNotFoundError(BINARY_MATRIX_FILENAME)

    sys.stderr.write(f"Clustering {unique_vdjs.shape[0]} unique variable sequences...\n")

    unique_vdjs['templated_v'] = unique_vdjs.apply(lambda x: x.vdj_sequence[0:int(x.v_templated_len)], axis=1)
    # Slice from len - j_len (clamped at 0) instead of -j_len: seq[-0:] would return
    # the whole sequence when the templated J length is 0.
    unique_vdjs['templated_j'] = unique_vdjs.apply(
        lambda x: x.vdj_sequence[max(len(x.vdj_sequence) - int(x.j_templated_len), 0):], axis=1)
    unique_vdjs['templated_vj'] = unique_vdjs['templated_v'] + "+" + unique_vdjs['templated_j']

    TOTAL_LINEAGES = 0
    unique_vdjs['lineage_id'] = -1

    for cluster_id in unique_vdjs.cluster_id.value_counts().index:
        subset = unique_vdjs[unique_vdjs['cluster_id'] == cluster_id]
        IGH_flag = (subset['locus'] == 'IGH').any()

        V_BINARY_MATRIX_FILENAME = os.path.join(
            MATRIXDIR, f'{SAMPLENAME}_{cluster_id}_templated_v_threshold_{FRACTIONAL_CUTOFF}.npy')
        J_BINARY_MATRIX_FILENAME = os.path.join(
            MATRIXDIR, f'{SAMPLENAME}_{cluster_id}_templated_j_threshold_{FRACTIONAL_CUTOFF}.npy')

        longest_templated_sequence = subset.templated_vj.str.len().max()

        sys.stderr.write(
            f"Processing VDJ sequence subset: cluster_id={cluster_id}, n={len(subset)}\n")

        Dv = np.load(V_BINARY_MATRIX_FILENAME, allow_pickle=False)
        Dj = np.load(J_BINARY_MATRIX_FILENAME, allow_pickle=False)

        # Saturating uint8 addition: sums above 255 are clamped to 255 instead of wrapping.
        Ds = np.minimum(Dv.astype(np.uint16) + Dj, 255).astype(np.uint8)
        n = Ds.shape[0]

        if n > 1:
            cutoff = np.uint8(FRACTIONAL_CUTOFF * longest_templated_sequence)
            # Condensed (upper-triangle, diagonal excluded) form expected by get_cluster_ids.
            upper_triangular_indices = np.triu_indices(n, 1)
            condensed_distance_matrix = Ds[upper_triangular_indices]

            if FRACTIONAL_CUTOFF == 0.2 and IGH_flag:
                dist_list_VJ.extend((condensed_distance_matrix / longest_templated_sequence).tolist())

            new_lineage_ids = get_cluster_ids(condensed_distance_matrix,
                                              cutoff=cutoff,
                                              method='single')
        else:
            new_lineage_ids = np.zeros(1)

        # Offset so lineage ids stay unique across CDR3 clusters.
        unique_vdjs.loc[subset.index, 'lineage_id'] = TOTAL_LINEAGES + new_lineage_ids
        TOTAL_LINEAGES += max(new_lineage_ids) + 1

        sys.stderr.write(f"[{time.time() - start}s]  found {new_lineage_ids.max() + 1} new lineages...\n")

    sys.stderr.write(f"Clustering templated sequences took {round(time.time() - start)} seconds\n")

    vdj_families = unique_vdjs[['vdj_sequence','lineage_id']].set_index('vdj_sequence').to_dict()['lineage_id']

    # Attach lineage ids to every read; reads whose VDJ was not clustered are dropped.
    df = contigs_df.assign(lineage_id=contigs_df['vdj_sequence'].map(vdj_families))
    df = df[df.lineage_id.notna()]

    df.to_csv(f'{OUTDIR}.tsv', sep = '\t', index = False)

dist_list_VJ = np.array(dist_list_VJ)


In [None]:
# Load the distance lists saved from the two clustering cells above, so Figure S2
# can be regenerated without re-running the (slow) clustering.
# NOTE: pickle.load can execute code from the file -- only load trusted local files.
fname = "Data/Preprocessed_Data/cluster_dists.pkl"
with open(fname, "rb") as f:
    dist_lists = pickle.load(f)
dist_list_CDR3 = dist_lists["dist_list_CDR3"]
dist_list_VJ = dist_lists["dist_list_VJ"]

# Same threshold grid as the clustering cells. Redefined here so that this cell
# depends only on files on disk, not on the clustering cells having run in this kernel.
thresholds = np.array([0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3])

x1, y1, y_err1 = pdf_histogram(dist_list_CDR3, np.linspace(0, 1, 30))
x2, y2, y_err2 = pdf_histogram(dist_list_VJ, np.linspace(0, 0.3, 30))

# Number of IGH lineages and IGH lineage sizes (rows per lineage) at each threshold.
n_lineages_threshold = np.zeros(len(thresholds))
all_lineage_sizes = []
for j, threshold in enumerate(thresholds):
    OUTDIR = "Data/Clustering_Analysis/final_lineage_ids/tonsil_vdjs_threshold_" + str(threshold)
    df = pd.read_table(f'{OUTDIR}.tsv')
    df_igh = df[df['locus'] == 'IGH']
    n_lineages_threshold[j] = df_igh.lineage_id.nunique()
    all_lineage_sizes.append(np.array(df_igh.groupby('lineage_id').size().values))


def _threshold_index(value):
    """Index of the threshold closest to `value` (tolerates float rounding)."""
    return int(np.argmin(np.abs(thresholds - value)))


# Look thresholds up by value instead of hard-coding positions [0], [5], [-1].
i_thr_low, i_thr_mid, i_thr_high = (_threshold_index(t) for t in (0.1, 0.2, 0.3))

# Unit-width bins for sizes 1..9, then n_extra_bins log-spaced edges from 10 up to
# e x the largest lineage at the highest threshold.
n_extra_bins = 10
bin_edges = np.append(np.arange(1,10),np.exp(np.linspace(np.log(10),1+np.max(np.log(all_lineage_sizes[i_thr_high])),n_extra_bins)))
bin_edges[9] = 10  # exp(log(10)) is not exactly 10 in floating point

x3, y3, y_err3 = pdf_histogram(all_lineage_sizes[i_thr_low], bin_edges)
x4, y4, y_err4 = pdf_histogram(all_lineage_sizes[i_thr_mid], bin_edges)
x5, y5, y_err5 = pdf_histogram(all_lineage_sizes[i_thr_high], bin_edges)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 4))

# Fixed seed: the small multiplicative x-jitter in the third panel (which keeps the
# three curves' markers from overlapping) is now reproducible from run to run.
jitter_rng = np.random.default_rng(0)

# First subplot: CDR3 distance distribution
axs[0].plot(x1, y1, color='black', linewidth=3, label='CDR3 distance')
axs[0].set_xlabel('Genetic distance')
axs[0].set_ylabel('Probability density')
axs[0].set_ylim([0.001, 21])
axs[0].set_xlim([0, 1])
axs[0].set_yscale('log')
# axvline spans the full axis height; the previous vlines(0.2, 0, 25) started at
# y=0, which has no position on a log axis.
axs[0].axvline(x=0.2, color='red', linestyle='--', linewidth=2, label='x = 0.2')

# Second subplot: Number of lineages vs threshold
axs[1].plot(thresholds, n_lineages_threshold, color='black', marker='o', linewidth=2, label='Number of lineages')
axs[1].set_xlabel('Lineage clustering threshold')
axs[1].set_ylabel('Number inferred lineages')
axs[1].axvline(x=0.2, color='red', linestyle='--', linewidth=2, label='x = 0.2')
axs[1].set_ylim([1, 10000])
axs[1].set_yticks([1, 5000, 10000])
axs[1].set_xticks([0.1, 0.2, 0.3])

# Third subplot: Lineage size distributions for different thresholds
size_curves = [
    (x3, y3, [0.2, 0.2, 0.8], 'Threshold = 0.1'),
    (x4, y4, [0, 0, 0], 'Threshold = 0.2'),
    (x5, y5, [0.8, 0.2, 0.2], 'Threshold = 0.3'),
]
for xs, ys, color, label in size_curves:
    axs[2].plot(xs * jitter_rng.normal(1, 0.05, size=len(xs)), ys, color=color, linewidth=3, alpha=0.8, label=label, marker='o')
axs[2].legend()
axs[2].set_yscale('log')
axs[2].set_xscale('log')
axs[2].set_xlabel('Lineage size')
axs[2].set_ylabel('Probability density')

plt.tight_layout()

# Forward slashes work on every OS (a backslash path only resolves on Windows).
plt.savefig('Figures/Supplementary/Lineage_Threshold.pdf', format='pdf')

plt.show()



In [None]:
# Load precomputed pairwise distances. Forward slashes: the previous
# 'Data\Preprocessed_Data' literal contained an invalid '\P' escape and only
# resolved on Windows.
# NOTE: pickle.load can execute code from the file -- only load trusted local files.
fname = 'Data/Preprocessed_Data/pairwise_distances.pickle'
with open(fname, 'rb') as f:
    pairwise_distances = pickle.load(f)
intra_dists = pairwise_distances['intra_dists']
all_inter_dists = pairwise_distances['all_inter_dists']
inter_dists = pairwise_distances['inter_dists']

# Distance scale used to normalise distances in later figures.
typ_intra_dist = np.median(intra_dists)

from scipy.spatial.distance import pdist

# TRB rows inside follicles: Follicles_seurat is a string other than 'nonFoll'.
df_TRB_IF = df_TRB[
    (df_TRB['Follicles_seurat'] != 'nonFoll') &
    (df_TRB['Follicles_seurat'].map(lambda x: isinstance(x, str)))
]

In [None]:
# FIGURE S3: Further analysis of migration distances

from scipy.spatial.distance import cdist

# Flag every spatial barcode that is extrafollicular (EF) or lies within
# `near_cutoff` (in the units of the x/y columns) of an EF barcode of the same section.
near_cutoff = 700

metadata_df['near_EF'] = metadata_df['Follicles_seurat'].apply(lambda x: x == 'nonFoll' or not isinstance(x, str))

for section in metadata_df['section'].unique():
    section_rows = metadata_df[metadata_df['section'] == section]
    ef_mask = section_rows['near_EF'].values
    ef_coords = section_rows.loc[ef_mask, ['x', 'y']].astype(float).values
    all_coords = section_rows[['x', 'y']].astype(float).values

    if len(ef_coords) == 0:
        continue

    near_any_ef = (cdist(all_coords, ef_coords) <= near_cutoff).any(axis=1)
    metadata_df.loc[section_rows.index[near_any_ef], 'near_EF'] = True

# In-follicle IGH rows whose spatial barcode is not near any EF region.
df_IGH_inner = df_IGH_IF[df_IGH_IF['st_barcode'].isin(metadata_df.loc[~metadata_df['near_EF'], 'spatial_bc'])]

# NOTE: pickle.load can execute code from the file -- only load trusted local files.
fname = 'Data/Preprocessed_Data/pairwise_distances.pickle'
with open(fname, 'rb') as f:
    pairwise_distances = pickle.load(f)
intra_dists = pairwise_distances['intra_dists']
all_inter_dists = pairwise_distances['all_inter_dists']
inter_dists = pairwise_distances['inter_dists']

typ_intra_dist = np.median(intra_dists)


def _cross_follicle_distances(df, skip_same_section=False):
    """Distances between rows of different follicles.

    For every pair of follicles (name1 < name2) and every row pair across them,
    compute euclidean_distance(row1, row2). Returns (all_pairs, same_lineage)
    as float arrays: all pair distances, and those where both rows share a
    lineage_id. With skip_same_section=True, pairs from the same section are skipped.

    Lists are grown instead of preallocating len(df)**2 slots, which could need
    many GB for a large DataFrame.
    """
    all_pairs, same_lineage = [], []
    rows_by_follicle = [(name, [row for _, row in group.iterrows()])
                        for name, group in df.groupby('follicle')]
    for name1, rows1 in rows_by_follicle:
        for name2, rows2 in rows_by_follicle:
            if name1 >= name2:
                continue
            for row1 in rows1:
                for row2 in rows2:
                    if skip_same_section and row1['section'] == row2['section']:
                        continue
                    d = euclidean_distance(row1, row2)  # computed once per pair
                    all_pairs.append(d)
                    if row1['lineage_id'] == row2['lineage_id']:
                        same_lineage.append(d)
    return np.array(all_pairs, dtype=float), np.array(same_lineage, dtype=float)


# Cross-follicle distances using only barcodes away from EF regions
all_inter_dists_inner, inter_dists_inner = _cross_follicle_distances(df_IGH_inner)

# Cross-follicle distances restricted to pairs from different tissue sections
all_inter_dists_crosssec, inter_dists_crosssec = _cross_follicle_distances(df_IGH_IF, skip_same_section=True)

In [None]:
# Load the precomputed distances (saved outputs of the cell above) so the figure
# can be drawn without recomputing all pairs. Forward slashes: the previous
# 'Data\Preprocessed_Data' literal only resolved on Windows.
# NOTE: pickle.load can execute code from the file -- only load trusted local files.
fname = 'Data/Preprocessed_Data/pairwise_distances.pickle'
with open(fname, 'rb') as f:
    pairwise_distances = pickle.load(f)
intra_dists = pairwise_distances['intra_dists']
all_inter_dists = pairwise_distances['all_inter_dists']
inter_dists = pairwise_distances['inter_dists']

typ_intra_dist = np.median(intra_dists)

fname = 'Data/Preprocessed_Data/pairwise_distances_misc.pickle'

with open(fname, 'rb') as f:
    pairwise_distances_misc = pickle.load(f)
all_inter_dists_inner = pairwise_distances_misc['all_inter_dists_inner']
inter_dists_inner = pairwise_distances_misc['inter_dists_inner']
all_inter_dists_crosssec = pairwise_distances_misc['all_inter_dists_crosssec']
inter_dists_crosssec = pairwise_distances_misc['inter_dists_crosssec']

# Distances are normalised by typ_intra_dist; each panel uses 40 shared bin edges
# covering both of its curves.
bin_edges = np.linspace(0,0.1+max(np.max(inter_dists_inner)/typ_intra_dist, np.max(all_inter_dists_inner)/typ_intra_dist),40)
x_dist1, y_dist1, y_dist_err1 = pdf_histogram(inter_dists_inner/typ_intra_dist, bin_edges)
x_dist2, y_dist2, y_dist_err2 = pdf_histogram(all_inter_dists_inner/typ_intra_dist, bin_edges)

bin_edges = np.linspace(0,0.1+max(np.max(inter_dists_crosssec)/typ_intra_dist, np.max(all_inter_dists_crosssec)/typ_intra_dist),40)
x_dist3, y_dist3, y_dist_err3 = pdf_histogram(inter_dists_crosssec/typ_intra_dist, bin_edges)
x_dist4, y_dist4, y_dist_err4 = pdf_histogram(all_inter_dists_crosssec/typ_intra_dist, bin_edges)


palette = sb.color_palette("colorblind", 4)
colors = [palette[0], palette[1]]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

ax1.plot(x_dist1, y_dist1, color=colors[0], label='Cross-follicle, within lineage', alpha=0.75, linewidth=3)
ax1.plot(x_dist2, y_dist2, color=colors[1], label='Cross-follicle, all pairs', alpha=0.75, linewidth=3)
ax1.set_xlabel('Euclidean distance')
ax1.set_ylabel('Probability density')
ax1.set_yticks([0, 0.1, 0.2, 0.3])
ax1.set_ylim([0, 0.36])
ax1.set_xlim([0, 10])
ax1.legend()

ax2.plot(x_dist3, y_dist3, color=colors[0], label='Cross-follicle, within lineage', alpha=0.75, linewidth=3)
ax2.plot(x_dist4, y_dist4, color=colors[1], label='Cross-follicle, all pairs', alpha=0.75, linewidth=3)
ax2.set_xlabel('Euclidean distance (cross-section)')
ax2.set_xlim([0, 10])


plt.tight_layout()

#plt.savefig('Figures/Supplementary/Locality_Analysis_With_Filtering.pdf', format='pdf')

plt.show()


In [None]:
# FIGURE S4: T cell spatial distribution

from scipy.spatial.distance import cdist, pdist

# Forward slashes: the previous 'Data\Preprocessed_Data' literal only resolved on Windows.
# NOTE: pickle.load can execute code from the file -- only load trusted local files.
fname = 'Data/Preprocessed_Data/pairwise_distances.pickle'
with open(fname, 'rb') as f:
    pairwise_distances = pickle.load(f)
intra_dists = pairwise_distances['intra_dists']
all_inter_dists = pairwise_distances['all_inter_dists']
inter_dists = pairwise_distances['inter_dists']

typ_intra_dist = np.median(intra_dists)

# TRB rows inside follicles: Follicles_seurat is a string other than 'nonFoll'.
df_TRB_IF = df_TRB[
    (df_TRB['Follicles_seurat'] != 'nonFoll') &
    (df_TRB['Follicles_seurat'].map(lambda x: isinstance(x, str)))
]

xy_cols = ['x', 'y']

# Same-clone pairs in different follicles. Pairs are taken by row position with
# pdist (each unordered pair of distinct rows once), replacing a Python double
# loop whose label-based `i <= j` test also dropped distinct rows that happened
# to share an index label.
dist_list = []
dist_list_avg = []
for clone_id, subset in df_TRB_IF.groupby('cloneId', sort=False):
    if len(subset) < 2:
        continue
    coords = subset[xy_cols].to_numpy(dtype=float)
    follicles = subset['Follicles_seurat'].to_numpy()
    row_a, row_b = np.triu_indices(len(subset), 1)  # same pair order as pdist's output
    clone_dists = pdist(coords)[follicles[row_a] != follicles[row_b]]
    if clone_dists.size > 0:
        dist_list.extend(clone_dists.tolist())
        dist_list_avg.append(clone_dists.mean())
dist_list = np.array(dist_list)/typ_intra_dist
dist_list_avg = np.array(dist_list_avg)/typ_intra_dist

# All row pairs between different follicles (each follicle pair once).
follicle_names = df_TRB_IF['Follicles_seurat'].unique()
coords_by_follicle = {
    name: group[xy_cols].to_numpy(dtype=float)
    for name, group in df_TRB_IF.groupby('Follicles_seurat')
}
full_chunks = []
for foll in follicle_names:
    for foll2 in follicle_names:
        if foll >= foll2:
            continue
        full_chunks.append(cdist(coords_by_follicle[foll], coords_by_follicle[foll2]).ravel())
dist_list_full = (np.concatenate(full_chunks) if full_chunks else np.array([]))/typ_intra_dist



In [None]:
# Load precomputed T cell distances (saved outputs of the cell above).
# Forward slashes: the previous 'Data\Preprocessed_Data' literal only resolved on Windows.
# NOTE: pickle.load can execute code from the file -- only load trusted local files.
fname = 'Data/Preprocessed_Data/T_cell_distance.pickle'
with open(fname, 'rb') as f:
    data = pickle.load(f)
dist_list_full = data['dist_list_full']
dist_list = data['dist_list']
dist_list_avg = data['dist_list_avg']

# Bin edges from 0 to just past the largest all-pairs distance; the per-clone
# averages use half as many (wider) bins.
max_dist_full = np.max(dist_list_full)
bins = np.linspace(0, max_dist_full+0.1, 50)
bins2 = np.linspace(0, max_dist_full+0.1, 25)

# Zero distances (identical coordinates) are excluded from the densities.
x3, y3, _ = pdf_histogram(dist_list_full[dist_list_full>0], bins)
x4, y4, _ = pdf_histogram(dist_list[dist_list>0], bins)
x5, y5, _ = pdf_histogram(dist_list_avg[dist_list_avg>0], bins2)

In [None]:
palette = sb.color_palette("colorblind", 4)
colors = [palette[0], palette[1], palette[2], palette[3]]

plt.figure(figsize=(6, 4))

# Same-clone pair densities (per UMI pair, and averaged within each clone)
plt.plot(x4, y4, color=colors[0], label='Clonal UMI pairs',alpha=0.75,linewidth=3)
plt.plot(x5, y5, color=[0.5,0.75,1], label='Clonal pairs, avg. across lineages',alpha=0.75,linewidth=3)

# Background: all cross-follicle UMI pairs
plt.plot(x3, y3, color=colors[1], label='All UMI pairs',alpha=0.75,linewidth=3)

plt.xlabel('Distance between spatial barcodes')
plt.ylabel('T cell pair density')
plt.legend()
plt.xlim([0,13])
plt.ylim([0,0.26])
plt.yticks([0,0.1,0.2])

# Forward slashes work on every OS (a backslash path only resolves on Windows).
plt.savefig('Figures/Supplementary/SI_T_Cell_Spatial.pdf', format='pdf')

plt.show()

In [None]:
# Fraction of T-cell UMI pairs that share a cloneId: pairs within the same follicle (pooled over
# follicles) vs. all UMI pairs in df_TRB_IF. A clone with k UMIs contributes k*(k-1) ordered
# same-clone pairs; n UMIs form n*(n-1) ordered pairs in total.
p_clone_same = 0
denom = 0

grouped = df_TRB_IF.groupby('Follicles_seurat')
for _, group in grouped:
    clone_counts = group['cloneId'].value_counts().to_numpy()
    p_clone_same += np.sum(clone_counts * (clone_counts - 1))
    denom += np.sum(clone_counts) * (np.sum(clone_counts) - 1)
p_clone_same /= denom

clone_counts_all = df_TRB_IF['cloneId'].value_counts().to_numpy()
n_umis_all = np.sum(clone_counts_all)
p_clone_overall = np.sum(clone_counts_all * (clone_counts_all - 1)) / (n_umis_all * (n_umis_all - 1))

# plt.subplots creates its own figure; a preceding plt.figure() only leaked an empty one.
fig, ax = plt.subplots(figsize=(4, 3))
ax.bar(['Same GC', 'All UMI pairs'], [p_clone_same, p_clone_overall], color=['black', 'red'],alpha=0.75)
ax.set_ylabel('Fraction clonal pairs')
ax.set_ylim([0,0.0015])
ax.set_yticks([0, 0.0005, 0.001, 0.0015])
ax.set_yticklabels(['0%', '0.05%', '0.1%', '0.15%'])
plt.savefig('Figures/Supplementary/T_Cell_Clone_Prob.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S5A: Plasma fraction distribution
# A spot is intrafollicular (IF) when its Follicles_seurat label is a string other than
# 'nonFoll'; every other spot (including missing / non-string labels) is extrafollicular (EF).
is_IF_spot = (metadata_df['Follicles_seurat'] != 'nonFoll') & (metadata_df['Follicles_seurat'].apply(lambda x: isinstance(x, str)))
plasma_fracs_IF = metadata_df.loc[is_IF_spot, 'plasma_frac']
plasma_fracs_EF = metadata_df.loc[~is_IF_spot, 'plasma_frac']

bin_edges = np.linspace(0, 1, 20)

x_dist1, y_dist1, y_dist_err1 = pdf_histogram(plasma_fracs_IF, bin_edges)
x_dist2, y_dist2, y_dist_err2 = pdf_histogram(plasma_fracs_EF, bin_edges)


plt.plot(x_dist1, y_dist1, 'gray', linewidth=3,label='Intrafollicular')
plt.plot(x_dist2, y_dist2, 'black', linewidth=3,label='Extrafollicular')

plt.xlabel('Plasma fraction')
plt.ylabel('Probability density')
plt.legend()
plt.xlim([0,1])
plt.yticks([0, 4, 8, 12])
plt.ylim([0,14])

# 0.9 is the plasma-fraction cutoff used to mark plasma-dominated EF spots in FIGURE S5B.
plt.axvline(x=0.9, color=[0.3,0.3,0.7], linestyle='--', linewidth=2)

plt.savefig('Figures/Supplementary/Plasma_Dom_Distribution.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S5B: Spatial location of plasmablasts
# Spots of section '17', drawn in three layers: EF spots with plasma_frac <= 0.9 (grey),
# plasma-dominated EF spots with plasma_frac > 0.9 (blue), then IF spots (pink).
# One scatter call per category replaces one call per spot. Spots with a missing plasma_frac
# are drawn grey only (the per-spot loops drew them both grey and blue).
section_df = metadata_df[metadata_df['section'] == '17']

section_foll = section_df['Follicles_seurat']
is_IF_spot = ((section_foll != 'nonFoll') & section_foll.apply(lambda v: isinstance(v, str))).astype(bool)
is_PB_dom = section_df['plasma_frac'] > 0.9
x_sec = section_df['x'].astype(float)
y_sec = section_df['y'].astype(float)

is_EF_other = ~is_IF_spot & ~is_PB_dom
is_EF_PB = ~is_IF_spot & is_PB_dom

plt.scatter(x_sec[is_EF_other], y_sec[is_EF_other], color=[0.7,0.7,0.7], s=20)
plt.scatter(x_sec[is_EF_PB], y_sec[is_EF_PB], color=[0.3,0.3,0.7], s=20)
plt.scatter(x_sec[is_IF_spot], y_sec[is_IF_spot], color=[1,0.8,1], s=20)

plt.gca().set_aspect('equal', 'box')

plt.tight_layout()

plt.savefig('Figures/Supplementary/Plasma_Dom_Visual.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S5C: Lineage age scales with plasmablast relatives
# For every lineage present in df_IGH_IF:
#   column 0: mean V-gene divergence (v_mismatch / V length) of its IF UMIs,
#   column 1: fraction of its UMIs in df_combined that are in df_IGH_EF_PBdom.
# Per-lineage quantities come from one groupby/value_counts pass each, instead of re-scanning
# three frames for every lineage.
mean_div_by_lineage = (
    df_IGH_IF
    .assign(umi_div=df_IGH_IF['v_mismatch'].astype(int) / df_IGH_IF['v_sequence_no_trunc'].str.len())
    .groupby('lineage_id', sort=False)['umi_div']
    .mean()
)
n_ef_by_lineage = df_IGH_EF_PBdom['lineage_id'].value_counts().reindex(mean_div_by_lineage.index, fill_value=0)
n_total_by_lineage = df_combined['lineage_id'].value_counts().reindex(mean_div_by_lineage.index)
ratio_by_lineage = n_ef_by_lineage / n_total_by_lineage

lineage_metrics = np.column_stack([mean_div_by_lineage.to_numpy(dtype=float), ratio_by_lineage.to_numpy(dtype=float)])

# Bin lineages by mean divergence; empty bins are left as NaN (skipped by the plot) instead of
# producing NaN through empty-array means with RuntimeWarnings. Lineages at or above 0.12 are
# outside the binned range.
bins = np.linspace(0,0.12,21)
n_bins = len(bins) - 1
mean_divs = np.full(n_bins, np.nan)
prop_EF = np.full(n_bins, np.nan)
err_EF = np.full(n_bins, np.nan)
for i in range(n_bins):
    mask = np.logical_and(lineage_metrics[:,0] >= bins[i], lineage_metrics[:,0] < bins[i+1])
    n_in_bin = np.sum(mask)
    if n_in_bin == 0:
        continue
    mean_divs[i] = np.mean(lineage_metrics[mask,0])
    prop_EF[i] = np.mean(lineage_metrics[mask,1])
    err_EF[i] = np.std(lineage_metrics[mask,1]) / np.sqrt(n_in_bin)

plt.plot(mean_divs, prop_EF, color='black',marker='o')
plt.fill_between(mean_divs, prop_EF - err_EF, prop_EF + err_EF, color='gray', alpha=0.3)

plt.xticks([0,0.03,0.06,0.09,0.12], [r'$0\%$', r'$3\%$', r'$6\%$', r'$9\%$', r'$12\%$'])
plt.xlabel('Mean lineage divergence from root')
plt.ylabel('Plasmablast proportion')
plt.ylim([0,0.5])
plt.xlim([0,0.12])

plt.savefig('Figures/Supplementary/Lineage_Age_PBs.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S6: Clonal burst pairwise analysis
# A "clonal burst" is the set of UMIs sharing one v_sequence_no_trunc. For each burst record its
# IF / EF UMI counts, number of distinct follicles, divergence from the germline root, and
# prob_distinct: the probability that two distinct non-EF UMIs drawn at random come from
# different follicles. (curve_fit is already imported at the top of the notebook.)

sequence_info = []
# groupby(sort=False) visits sequences in first-appearance order, like .unique(), but
# without re-filtering df_combined for every sequence.
for sequence, subset in df_combined.groupby('v_sequence_no_trunc', sort=False):
    is_EF_umi = subset['follicle'] == 'EF'
    num_appearances = len(subset)
    num_appearances_EF = int(is_EF_umi.sum())
    num_appearances_IF = num_appearances - num_appearances_EF
    foll_counts = subset.loc[~is_EF_umi, 'follicle'].value_counts().to_numpy()
    num_distinct_follicles = len(foll_counts)
    div_from_root = int(subset.iloc[0]['v_mismatch'])/len(sequence)

    prob_distinct = 0
    if num_distinct_follicles > 1:
        # Closed form of the sum over follicle pairs of 2*n_a*n_b / (N*(N-1)):
        # ordered cross-follicle pairs = (sum n_f)^2 - sum n_f^2, where N counts every non-'EF' UMI.
        n_labelled = foll_counts.sum()
        prob_distinct = (n_labelled**2 - np.sum(foll_counts**2)) / (num_appearances_IF * (num_appearances_IF - 1))

    sequence_info.append([sequence, num_appearances_IF, num_distinct_follicles, div_from_root, num_appearances_EF, prob_distinct])

clonal_burst_df = pd.DataFrame(sequence_info, columns=['v_sequence_no_trunc', 'num_appearances_IF', 'num_distinct_follicles', 'div_from_root', 'num_appearances_EF', 'prob_distinct'])
clonal_burst_df = clonal_burst_df[clonal_burst_df['num_appearances_IF'] + clonal_burst_df['num_appearances_EF'] >= 2]
clonal_burst_df = clonal_burst_df[clonal_burst_df['div_from_root'] > 0.01]

# Fraction of bursts with >= 2 IF UMIs that span more than one follicle (not plotted below).
mask = clonal_burst_df['num_appearances_IF'] >= 2
overall_prob = np.sum(clonal_burst_df[mask]['num_distinct_follicles'] > 1)/np.sum(mask)

# Mean prob_distinct per burst-size bin (bursts of 120+ IF UMIs fall outside the bins);
# empty bins stay NaN and are skipped by errorbar.
bins = np.array([2,3,4,5,6,7,8,9,10,15,20,30,40,60,80,100,120])
n_bins = len(bins) - 1
x = np.full(n_bins, np.nan)
prop_pairs_migrant = np.full(n_bins, np.nan)
prop_pairs_migrant_err = np.full(n_bins, np.nan)

for i in range(n_bins):
    mask = np.logical_and(clonal_burst_df['num_appearances_IF'] >= bins[i], clonal_burst_df['num_appearances_IF'] < bins[i+1])
    n_in_bin = np.sum(mask)
    if n_in_bin == 0:
        continue
    bursts_in_bin = clonal_burst_df[mask]
    prop_pairs_migrant[i] = np.mean(bursts_in_bin['prob_distinct'])
    prop_pairs_migrant_err[i] = np.std(bursts_in_bin['prob_distinct'])/np.sqrt(n_in_bin)
    x[i] = np.mean(bursts_in_bin['num_appearances_IF'])


plt.errorbar(x, prop_pairs_migrant, prop_pairs_migrant_err, marker='o', linewidth=2, linestyle='none', markersize=7, color='black', capsize=5)
plt.ylim([0,1])
plt.xscale('log')
plt.xticks([2, 3, 5,10, 20, 30, 50, 100], ['2', '3','5', '10', '20', '30','50', '100'])
plt.xlabel('Clonal burst size')
plt.ylabel('Frac. cross-follicle pairs in burst')

plt.savefig('Figures/Supplementary/Clonal_Burst_Pairwise.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S7A: Power law distributions of lineage sizes
# Histograms (log-log) of lineage sizes in IF UMIs, for all lineages and split by the number of
# follicles a lineage occupies, with a single power-law fit to the all-lineage curve.

# Per-follicle totals: UMIs, distinct VDJ clones and distinct lineages, largest follicle first.
foll_stats = df_IGH_IF.groupby('follicle').agg(
    n_UMIs = ('follicle','size'),
    n_clones = ('vdj_sequence', 'nunique'),
    n_lineages = ('lineage_id', 'nunique')
).reset_index()
foll_stats = foll_stats.sort_values(by='n_UMIs', ascending=False)

# The same three totals for all extrafollicular UMIs.
EF_stats = np.array([len(df_IGH_EF), df_IGH_EF['vdj_sequence'].nunique(), df_IGH_EF['lineage_id'].nunique()])

# Per-lineage IF statistics: UMI count, number of follicles occupied, distinct VDJs.
lineage_stats = df_IGH_IF.groupby('lineage_id').agg(
    n_UMIs=('lineage_id', 'size'),
    n_follicles=('follicle', 'nunique'),
    n_vdjs = ('vdj_sequence', 'nunique')
).reset_index()

# Per-lineage EF UMI and EF VDJ counts, merged onto lineage_stats (0 when a lineage has no EF UMIs).
ef_counts = df_IGH_EF['lineage_id'].value_counts().rename('EF_UMIs').reset_index()
ef_counts.columns = ['lineage_id', 'EF_UMIs']

ef_vdjs = df_IGH_EF.groupby('lineage_id')['vdj_sequence'].nunique().reset_index().rename(columns={'vdj_sequence': 'EF_VDJs'})
ef_counts = ef_counts.merge(ef_vdjs, on='lineage_id', how='left').fillna(0)
ef_counts['EF_VDJs'] = ef_counts['EF_VDJs'].astype(int)

lineage_stats = lineage_stats.merge(ef_counts, on='lineage_id', how='left').fillna(0)
lineage_stats['EF_UMIs'] = lineage_stats['EF_UMIs'].astype(int)
lineage_stats['EF_VDJs'] = lineage_stats['EF_VDJs'].astype(int)

# Bin edges: unit-width edges 1..9, then n_extra_bins log-spaced edges from 10 up to
# e * (largest lineage size). Index 9 (the first log-spaced edge, exp(log(10))) is reset to
# exactly 10 to remove floating-point error.
n_extra_bins = 10
bin_edges = np.append(np.arange(1,10),np.exp(np.linspace(np.log(10),1+np.max(np.log(lineage_stats['n_UMIs'])),n_extra_bins)))
bin_edges[9] = 10

# Lineage-size densities for all lineages and for lineages in 1, 2, 3 and >= 4 follicles.
# x_1..x_3 / y_1..y_3 are reused by the FIGURE S7B cell.
x_all, y_all, y_err_all = pdf_histogram(lineage_stats['n_UMIs'], bin_edges)
x_1, y_1, y_err_1 = pdf_histogram(lineage_stats[lineage_stats['n_follicles'] == 1]['n_UMIs'], bin_edges)
x_2, y_2, y_err_2 = pdf_histogram(lineage_stats[lineage_stats['n_follicles'] == 2]['n_UMIs'], bin_edges)
x_3, y_3, y_err_3 = pdf_histogram(lineage_stats[lineage_stats['n_follicles'] == 3]['n_UMIs'], bin_edges)
x_4, y_4, y_err_4 = pdf_histogram(lineage_stats[lineage_stats['n_follicles'] >= 4]['n_UMIs'], bin_edges)

# Global `colors` is also used by the FIGURE S7B cell.
palette = sb.color_palette("colorblind", 4)
colors = [palette[0], palette[1], palette[2], palette[3]]

plt.figure(figsize=(6, 6))
plt.plot(x_all, y_all, marker='o', color='black', label='All',alpha=0.75, linewidth=4)

# Fit the exponent in log space (target is np.log(y_all)).
# NOTE(review): assumes single_power_law_pdf returns log-density unless logY=False, and that
# pdf_histogram returns no zero-density bins (log(0) would break curve_fit) — confirm in
# data_analysis_utils.
popt_all, _ = curve_fit(single_power_law_pdf, x_all, np.log(y_all), p0=2.)
alpha_all = popt_all[0]
y_fit_all = single_power_law_pdf(x_all, alpha_all, logY=False)
print('alpha =', alpha_all)

plt.plot(x_all, y_fit_all, color=[0.8,0,0.8], linestyle='--', linewidth=2)

plt.plot(x_1, y_1, marker='o', color=colors[0], label='1-follicle',alpha=0.75, linewidth=4)
plt.plot(x_2, y_2, marker='o', color=colors[1], label='2-follicle',alpha=0.75,linewidth=4)
plt.plot(x_3, y_3, marker='o', color=colors[2], label='3-follicle',alpha=0.75,linewidth=4)
plt.plot(x_4, y_4, marker='o', color='gray', label='4+-follicle',alpha=0.75,linewidth=4)

plt.xlabel('Number of UMIs')
plt.ylabel('Probability density')

plt.legend()

#plt.gca().add_patch(plt.Rectangle((0.9, 10**-3.5), 9.3, 1.2 - 10**-3.5, fill=False, edgecolor='gray', linewidth=1,linestyle='--'))

plt.xscale('log')
plt.yscale('log')
plt.yticks([])

# NOTE(review): saving is disabled for this panel.
#plt.savefig(r'Figures\Supplementary\Full_Power_Laws.pdf', format='pdf')
plt.show()

In [None]:
# FIGURE S7B: Overlapping power law fit
# Jointly fit power_law_combined_pdf (in log space) to the first 9, 8 and 7 histogram points of
# the 1-, 2- and 3-follicle lineage-size distributions from the FIGURE S7A cell.
n_fit_1, n_fit_2, n_fit_3 = 9, 8, 7
x_1_pts, x_2_pts, x_3_pts = x_1[:n_fit_1], x_2[:n_fit_2], x_3[:n_fit_3]
x_for_fit = np.concatenate([x_1_pts, x_2_pts, x_3_pts])
y_for_fit = np.concatenate([y_1[:n_fit_1], y_2[:n_fit_2], y_3[:n_fit_3]])
popt, pcov = curve_fit(power_law_combined_pdf, x_for_fit, np.log(y_for_fit), p0=2.)
alpha = popt[0]
print(alpha)

y_fit = power_law_combined_pdf(x_for_fit, alpha, logY=False, cond_adj = False)

# Small multiplicative x-jitter keeps overlapping markers visible. A dedicated seeded generator
# makes the saved figure reproducible (the unseeded global RNG changed it on every run), and
# jitter lengths follow the actual slice lengths instead of assuming 9/8/7 points exist.
jitter_rng = np.random.default_rng(0)

plt.figure(figsize=(5, 4))
plt.plot(x_1_pts*jitter_rng.normal(1,0.05,size=len(x_1_pts)), y_1[:n_fit_1], marker='o', color=colors[0], label='1-follicle',linestyle='none',markersize=7)
# NOTE(review): 2- and 3-follicle points are rescaled by (1 - y_fit[0]) and
# (1 - y_fit[0] - y_fit[1]); the rationale lives outside this cell — confirm.
plt.plot(x_2_pts*jitter_rng.normal(1,0.05,size=len(x_2_pts)), y_2[:n_fit_2]*(1-y_fit[0]), marker='o', color=colors[1], label='2-follicle',linestyle='none',markersize=7)
plt.plot(x_3_pts*jitter_rng.normal(1,0.05,size=len(x_3_pts)), y_3[:n_fit_3]*(1-y_fit[0]-y_fit[1]), marker='o', color=colors[2], label='3-follicle',linestyle='none', markersize=7)

# Fitted curve over the 1-follicle points only.
plt.plot(x_for_fit[:len(x_1_pts)], y_fit[:len(x_1_pts)], color='black', linestyle='--')

plt.xlabel('Number of UMIs')
plt.ylabel('Scaled probability')

plt.legend(loc = 'lower left')

plt.xscale('log')
plt.yscale('log')

plt.savefig('Figures/Supplementary/Overlapping_Power_Law.pdf', format='pdf')
plt.show()

In [None]:
# FIGURE S8: Survival functions with filtering and alternative models
# For every lineage tree (EF UMIs excluded): root it at its ROOT* leaf, label tips and internal
# nodes with follicles, then tally per lineage l
#   migration_count[l, p, c]  branches whose parent is in follicle p and child in follicle c != p
#   migration_denoms[l, p]    total branch length whose parent lies in follicle p
# Each non-zero-length branch is also recorded as a row [start depth, end depth, migrated 0/1]
# of branch_start_ends_mig.

foll_names = df_IGH_IF['follicle'].unique()

migration_count = np.zeros((len(lineage_ids), len(foll_names), len(foll_names)), dtype=int)
migration_denoms = np.zeros((len(lineage_ids), len(foll_names)))
branch_rows = []

for ind in range(len(lineage_ids)):

    lin_id = int(float(lineage_ids[ind]))
    tree = Phylo.read("Data/Lineage_Seqs/Lineage_Trees/tree_lin_" + str(lin_id) + "_no_ef.newick", "newick")

    # Re-root on the first clade named ROOT*; an inferred root node is then collapsed.
    root_flag = False
    for clade in tree.find_clades():
        if clade.name and clade.name.startswith("ROOT"):
            tree.root_with_outgroup(clade)
            if clade.name == "ROOT__INFERRED": tree.collapse(clade)
            root_flag = True
            break
    if not root_flag: print("Error: No root found")

    label_tips(tree, df_IGH_IF, lineage_ids[ind], ignore_EF = True)
    fill_internal_nodes(tree, lineage_ids[ind], df_IGH_IF, track_dominant = True)

    branches = get_branch_annotations(tree)

    for branch in branches:
        if branch['branch_length'] == 0: continue

        # NOTE(review): np.argmax returns 0 when a label is absent from foll_names, so an
        # unmatched follicle label is silently attributed to foll_names[0].
        par_ind = np.argmax(foll_names == branch['parent_foll'])
        child_ind = np.argmax(foll_names == branch['child_foll'])
        migration_denoms[ind, par_ind] += branch['branch_length']

        migrated = branch['parent_foll'] != branch['child_foll']
        if migrated:
            migration_count[ind,par_ind,child_ind] += 1
        branch_rows.append([branch['parent_depth'], branch['parent_depth']+branch['branch_length'], 1 if migrated else 0])

# Stack once at the end: np.vstack inside the loop re-copied the whole array for every branch.
branch_start_ends_mig = np.array(branch_rows) if branch_rows else None

lin_denoms = np.sum(migration_denoms, axis=1)
migration_ct_lins = np.sum(migration_count, axis=(1,2))

In [None]:
# Per-lineage migration-rate estimates and a single-rate survival-function fit (FIGURE S8).
# tot_ct[l]: number of migration branches in lineage l; tot_denom[l]: its total branch length.
tot_ct = np.sum(migration_count,axis=(1,2))
tot_denom = np.sum(migration_denoms,axis=1)
# Fraction of lineages with no inferred migration at all.
frac_monofoll = np.sum(tot_ct == 0) / len(tot_ct)

# Only lineages with at least one migration get a rate estimate.
filt = tot_ct >= 1

tot_ct = tot_ct[filt]
tot_denom = tot_denom[filt]

# poiss_CI yields three numbers per lineage; column 0 is used below as a waiting time tau
# (rate = 1/tau). NOTE(review): presumably columns 1-2 are confidence bounds — confirm in
# data_analysis_utils.
tau_ests1 = np.zeros((len(tot_ct), 3))

for i in range(len(tot_ct)): tau_ests1[i] = poiss_CI(tot_ct[i], tot_denom[i])

# The same estimate per parent follicle, pooled over lineages (axis 1 of migration_count and
# of migration_denoms is the parent follicle). This overwrites tot_ct, tot_denom and filt with
# per-follicle values; later cells recompute the per-lineage versions before using them.
tot_ct = np.sum(migration_count,axis=(0,2))
tot_denom = np.sum(migration_denoms,axis=0)
filt = tot_ct >= 1

tot_ct = tot_ct[filt]
tot_denom = tot_denom[filt]

tau_ests2 = np.zeros((len(tot_ct), 3))

for i in range(len(tot_ct)): tau_ests2[i] = poiss_CI(tot_ct[i], tot_denom[i])

# Fit one migration rate shared by all lineages to the distribution of per-lineage rates,
# starting from their geometric mean; lin_denoms (every lineage, including those without
# migrations) is passed as the exposure.
rate_values = 1/tau_ests1[:,0]
rate_inferred = fit_survival_func_single_rate(rate_values, lin_denoms, np.exp(np.mean(np.log(rate_values))), points_side = 20, plot_fit = False)

# Survival function predicted by the fitted rate on 500 log-spaced rates spanning the observed
# range (presumably mean / std over helper-generated samples), and the empirical one.
x_vals = np.logspace(np.log10(rate_values.min()), np.log10(rate_values.max()), 500)
surv_sampled_mean, surv_sampled_std, frac_monofoll_sampled, mig_numbers_sampled = get_survival_func_single_rate(lin_denoms, rate_inferred, x_vals)

rates_data_sorted, survival_data = surv_func(rate_values)

In [None]:
# Power analysis for rate heterogeneity. Each lineage's exposure (lin_denoms) is multiplied by a
# gamma factor with mean 1 and coefficient of variation std_range[i]/rate_inferred. Migrations
# are drawn as Poisson counts at rate_inferred, a single-rate model is refitted, and the mean
# squared gap between the simulated and best-fit survival functions is stored in
# single_rate_err[i, j] (trial j).
# Seed the global NumPy RNG so the simulation is reproducible under Restart & Run All
# (NOTE(review): this also covers the helper functions only if they draw from np.random — confirm).
np.random.seed(0)

std_range = np.linspace(0,2,21)*rate_inferred
N_trials = 10
N_grid = 500  # number of points on the rate grid used to compare survival functions

single_rate_err = np.zeros((len(std_range),N_trials))

for j in range(N_trials):

    all_x_vals_gamma = np.zeros((N_grid, len(std_range)))
    all_gamma_surv_sampled_mean = np.zeros((N_grid, len(std_range)))
    all_gamma_surv_sampled_std = np.zeros((N_grid, len(std_range)))
    all_survival_data_gamma_interp = np.zeros((N_grid, len(std_range)))

    for i in range(len(std_range)):
        denoms_adj = np.copy(lin_denoms)
        if std_range[i] > 0:
            # Gamma(shape=r^2/s^2, scale=s^2/r) has mean r and std s; dividing by r gives a
            # mean-1 multiplier with CV s/r.
            denoms_adj *= np.random.gamma(rate_inferred**2/std_range[i]**2, std_range[i]**2/rate_inferred, len(denoms_adj))/rate_inferred
        # Poisson counts from the adjusted exposure, turned into rates with the unadjusted lengths;
        # lineages with no simulated migration are dropped (as with the real data).
        sampled_gamma_rates = np.random.poisson(denoms_adj*rate_inferred)/lin_denoms
        sampled_gamma_rates_filt = sampled_gamma_rates[sampled_gamma_rates > 0]

        gamma_rate_inferred = fit_survival_func_single_rate(sampled_gamma_rates_filt, lin_denoms, np.exp(np.mean(np.log(sampled_gamma_rates_filt))), points_side=20, plot_fit=False)

        x_vals_gamma = np.logspace(np.log10(sampled_gamma_rates_filt.min()), np.log10(sampled_gamma_rates_filt.max()), N_grid)
        gamma_surv_sampled_mean, gamma_surv_sampled_std, _, _ = get_survival_func_single_rate(lin_denoms, gamma_rate_inferred, x_vals_gamma)

        rates_data_sorted_gamma, survival_data_gamma = surv_func(sampled_gamma_rates_filt)

        # Empirical survival function evaluated on the same grid as the prediction.
        interp_survival_gamma = interp1d(rates_data_sorted_gamma, survival_data_gamma, bounds_error=False, fill_value="extrapolate")
        survival_data_gamma_interp = interp_survival_gamma(x_vals_gamma)

        all_x_vals_gamma[:,i] = x_vals_gamma
        all_gamma_surv_sampled_mean[:,i] = gamma_surv_sampled_mean
        all_gamma_surv_sampled_std[:,i] = gamma_surv_sampled_std
        all_survival_data_gamma_interp[:,i] = survival_data_gamma_interp

        single_rate_err[i,j] = np.mean((all_survival_data_gamma_interp[:,i] - all_gamma_surv_sampled_mean[:,i])**2)

In [None]:
# One example with strongly variable rates: each lineage's exposure is multiplied by an
# exponential factor with mean 1 (CV = 1), migrations are drawn as Poisson counts at
# rate_inferred, and the single-rate model is refitted to the simulated rates (lineages with
# zero simulated migrations dropped). Draws come from the global np.random stream, so results
# change between runs unless it was seeded earlier.
denoms_adj = np.copy(lin_denoms)
denoms_adj *= np.random.exponential(rate_inferred,len(denoms_adj))/rate_inferred
sampled_expon_rates = np.random.poisson(denoms_adj*rate_inferred)/lin_denoms
sampled_expon_rates_filt = sampled_expon_rates[sampled_expon_rates > 0]

expon_rate_inferred = fit_survival_func_single_rate(sampled_expon_rates_filt, lin_denoms, np.exp(np.mean(np.log(sampled_expon_rates_filt))), points_side = 20, plot_fit = False)

# Best-fit single-rate survival function on a 500-point log grid spanning the simulated rates.
x_vals_expon = np.logspace(np.log10(sampled_expon_rates_filt.min()), np.log10(sampled_expon_rates_filt.max()), 500)
expon_surv_sampled_mean, expon_surv_sampled_std, _, _ = get_survival_func_single_rate(lin_denoms, expon_rate_inferred, x_vals_expon)

# Empirical survival function of the simulated rates, interpolated onto x_vals_expon
# (linear extrapolation outside the observed range).
rates_data_sorted_expon, survival_data_expon = surv_func(sampled_expon_rates_filt)

interp_survival_expon = interp1d(rates_data_sorted_expon, survival_data_expon, bounds_error=False, fill_value="extrapolate")
survival_data_expon_interp = interp_survival_expon(x_vals_expon)

In [None]:
# Model-mismatch example: survival function of the exponentially varied simulated rates vs. the
# best-fit single-rate prediction (mean, with a +/- 1 std band).
expon_band_lo = expon_surv_sampled_mean - expon_surv_sampled_std
expon_band_hi = expon_surv_sampled_mean + expon_surv_sampled_std
rate_ticks = [1, 10, 100, 1000]
rate_tick_labels = [r'$(100\%)^{-1}$', r'$(10\%)^{-1}$', r'$(1\%)^{-1}$', r'$(0.1\%)^{-1}$']

plt.figure(figsize=(4, 4))
plt.plot(rates_data_sorted_expon, survival_data_expon, color='blue', label='Simulated data with variable rates', linewidth=4)
plt.plot(x_vals_expon, expon_surv_sampled_mean, color='red', linewidth=1.5, label='Best-fit single rate')
plt.fill_between(x_vals_expon, expon_band_lo, expon_band_hi, color='red', alpha=0.3)

plt.xscale('log')
plt.ylim([0, 1])
plt.xticks(rate_ticks, rate_tick_labels)
plt.xlabel('Migration rate (1/percent divergence)')
plt.ylabel('Survival function')

# Saving is disabled for this panel.
#plt.savefig(r'Figures\Supplementary\SI_Inferred_Rate_Survival.pdf', format='pdf')

plt.tight_layout()
plt.show()


In [None]:
# Refit the single-rate model after removing putative clonal bursts (see FIGURE S9): lineages
# with exactly one migration whose rate estimate is above the 90th percentile.
# (interp1d is already imported at the top of the notebook.)
tot_ct = np.sum(migration_count,axis=(1,2))
tot_denom = np.sum(migration_denoms,axis=1)
filt = tot_ct >= 1

# err_bar_filt is aligned with the rows of tau_ests1 (lineages with >= 1 migration).
err_bar_filt = np.logical_and(tot_ct[filt] < 2, 1/tau_ests1[:,0] > np.percentile(1/tau_ests1[:,0],90))

# Map the exclusion back onto all lineages; lineages without migrations are always kept
# (they still contribute exposure to the fit, as lin_denoms does in the unfiltered fit).
all_filt = np.ones(len(tot_ct), dtype=bool)
all_filt[filt] = np.logical_not(err_bar_filt)
mig_denoms_filt = tot_denom[all_filt]

rate_values_filt = 1/tau_ests1[np.logical_not(err_bar_filt),0]

rate_inferred_filt = fit_survival_func_single_rate(rate_values_filt, mig_denoms_filt, np.exp(np.mean(np.log(rate_values_filt))), points_side=20, plot_fit=False)

x_vals_filt = np.logspace(np.log10(rate_values_filt.min()), np.log10(rate_values_filt.max()), 500)
surv_sampled_mean_filt, surv_sampled_std_filt, _, _ = get_survival_func_single_rate(mig_denoms_filt, rate_inferred_filt, x_vals_filt)

rates_data_sorted_filt, survival_data_filt = surv_func(rate_values_filt)

In [None]:
# RMSE of the single-rate fit on the filtered real data (red dashed line) compared with the
# RMSE on simulated data as a function of the CV of per-lineage rates (black: mean over trials,
# grey band: +/- 1 std). Non-finite per-trial values are treated as missing.
interp_filtered_data = interp1d(rates_data_sorted_filt, survival_data_filt, bounds_error=False, fill_value="extrapolate")
filtered_data_interp = interp_filtered_data(x_vals_filt)
RMSE_data_filt = np.sqrt(np.mean((filtered_data_interp - surv_sampled_mean_filt)**2))

rmse_trials = np.sqrt(single_rate_err)
masked = np.where(np.isfinite(rmse_trials), rmse_trials, np.nan)
mean_err = np.nanmean(masked, axis=1)
std_err  = np.nanstd(masked, axis=1)

cv_values = std_range/rate_inferred
plt.plot(cv_values, mean_err,color='black')
plt.fill_between(cv_values, mean_err - std_err, mean_err + std_err, color='gray', alpha=0.3)

plt.plot([0,2],[RMSE_data_filt] * 2,color='red',linestyle='--',linewidth=2)
plt.ylabel('RMSE of single-rate prediction')
plt.xlabel('CV(Migration rates)')
plt.ylim([0,0.08])
plt.xlim([0,2])
plt.yticks([0,0.04,0.08])
plt.xticks([0,0.5,1,1.5,2])

# Saving is disabled for this panel.
#plt.savefig(r'Figures\Supplementary\SI_Rate_Variation_Power_Analysis.pdf', format='pdf')

In [None]:
# FIGURE S9: Filtering out putative clonal bursts
# Inferred per-lineage migration rate vs. number of migration events, for lineages with at
# least one migration. The shaded box marks the lineages excluded in the refit: a single
# migration and a rate above the 90th percentile. Requires the per-lineage tot_ct and filt
# computed in the refit cell above.

filt_cutoff = np.percentile(1/tau_ests1[:,0],90)

plt.ylim([1,250])
plt.yscale('log')
plt.yticks([100/50,100/10, 100/3, 100/1], [r'$(50\%)^{-1}$',r'$(10\%)^{-1}$',r'$(3\%)^{-1}$', r'$(1\%)^{-1}$'])
plt.gca().add_patch(plt.Rectangle((0.72, filt_cutoff), 0.7, 238 - filt_cutoff, color='green', alpha=0.2))
rect = plt.Rectangle((0.72, filt_cutoff), 0.7, 238 - filt_cutoff, fill=False, edgecolor='darkgreen', linewidth=3, linestyle='--')
plt.gca().add_patch(rect)
plt.scatter(tot_ct[filt], 1/tau_ests1[:,0],color='black',alpha=0.2)
plt.xscale('log')
plt.xticks([1,2,5,10,15],['1','2','5','10','15'])
plt.xlim([0.7,19])

plt.xlabel('Number of migration events in lineage')
plt.ylabel('Inferred migration rate (1/percent divergence)')

plt.savefig('Figures/Supplementary/SI_Clonal_Burst_Filtering.pdf', format='pdf')
plt.show()

In [None]:
# FIGURE S10A: Spatial distribution of early migrations
# Distances between UMI pairs in different follicles where both UMIs are "early" (V-gene
# divergence from root below 1%): all such pairs (all_inter_dists_early) and pairs from the
# same lineage (inter_dists_early). The next cell reloads these arrays from a cached pickle.

fname = 'Data/Preprocessed_Data/pairwise_distances.pickle'
with open(fname, 'rb') as f:
    pairwise_distances = pickle.load(f)
intra_dists = pairwise_distances['intra_dists']
all_inter_dists = pairwise_distances['all_inter_dists']
inter_dists = pairwise_distances['inter_dists']

typ_intra_dist = np.median(intra_dists)

# Keep only early UMIs up front instead of re-checking them in the pair loops. Results are
# accumulated in lists: the previous len(df_IGH_IF)**2 preallocations needed gigabytes of RAM
# for a few tens of thousands of UMIs.
EARLY_DIV_CUTOFF = 0.01
div_from_root_IF = df_IGH_IF['v_mismatch'].astype(int) / df_IGH_IF['v_sequence_no_trunc'].str.len()
df_IGH_IF_early = df_IGH_IF[div_from_root_IF < EARLY_DIV_CUTOFF]

all_inter_dists_early = []
inter_dists_early = []

grouped_early = df_IGH_IF_early.groupby('follicle')
for name1, group1 in grouped_early:
    for name2, group2 in grouped_early:
        if name1 >= name2: continue
        for _, row1 in group1.iterrows():
            for _, row2 in group2.iterrows():
                pair_dist = euclidean_distance(row1, row2)
                all_inter_dists_early.append(pair_dist)
                if row1['lineage_id'] == row2['lineage_id']:
                    inter_dists_early.append(pair_dist)

all_inter_dists_early = np.array(all_inter_dists_early, dtype=float)
inter_dists_early = np.array(inter_dists_early, dtype=float)

In [None]:
# FIGURE S10A (plot): compare distances between early clonal UMI pairs in different follicles
# with all early cross-follicle pairs, in units of the median intra-follicle distance.
# pickle is already imported at the top of the notebook. Forward slashes keep the paths
# portable (a backslash before 'P' is an invalid string escape and Windows-only).

fname = 'Data/Preprocessed_Data/pairwise_distances.pickle'
with open(fname, 'rb') as f:
    pairwise_distances = pickle.load(f)
intra_dists = pairwise_distances['intra_dists']
all_inter_dists = pairwise_distances['all_inter_dists']
inter_dists = pairwise_distances['inter_dists']

typ_intra_dist = np.median(intra_dists)

fname = 'Data/Preprocessed_Data/pairwise_distances_early_migrations.pickle'

with open(fname, 'rb') as f:
    pairwise_distances_cb = pickle.load(f)
all_inter_dists_early = pairwise_distances_cb['all_inter_dists_early']
inter_dists_early = pairwise_distances_cb['inter_dists_early']

inter_early_scaled = inter_dists_early/typ_intra_dist
all_inter_early_scaled = all_inter_dists_early/typ_intra_dist
bin_edges = np.linspace(0, 0.1 + max(np.max(inter_early_scaled), np.max(all_inter_early_scaled)), 20)

x_dist1, y_dist1, y_dist_err1 = pdf_histogram(inter_early_scaled, bin_edges)
x_dist2, y_dist2, y_dist_err2 = pdf_histogram(all_inter_early_scaled, bin_edges)

palette = sb.color_palette("colorblind", 4)
colors = [palette[0], palette[1]]

plt.figure(figsize=(6, 4))

plt.plot(x_dist1, y_dist1, color=colors[0], label='Between follicles, same clone',alpha=0.75,linewidth=3)
plt.plot(x_dist2, y_dist2, color=colors[1], label='Between follicles, all pairs',alpha=0.75,linewidth=3)

plt.xlabel('Euclidean distance between spatial barcodes')
plt.ylabel('Probability density')
plt.legend()
plt.ylim([0,0.25])
plt.yticks([0,0.1,0.2])

plt.savefig('Figures/Supplementary/Early_Migrations_Spatial.pdf', format='pdf')

plt.show()


In [None]:
# FIGURE S10B: Diversification after migration, filtering out early migrations
# Collect subtree statistics (divergence, descendants, depth) from get_subtree_data for every
# lineage tree, and, as a null model, from N_randomizations trees per lineage returned by
# randomize_migration_locations.
# NOTE(review): div_cutoff=0.01 presumably drops migrations shallower than 1% divergence (the
# "early migrations" of the title) — confirm in data_analysis_utils.

subtree_divs = []
shuffled_divs = []

subtree_descendants = []
shuffled_descendants = []

subtree_depths = []
shuffled_depths = []

N_randomizations = 100

for ind in range(len(lineage_ids)):

    lin_id = int(float(lineage_ids[ind]))
    tree = Phylo.read("Data/Lineage_Seqs/Lineage_Trees/tree_lin_" + str(lin_id) + "_no_ef.newick", "newick")

    # Re-root on the first clade named ROOT*; an inferred root node is then collapsed.
    root_flag = False
    for clade in tree.find_clades():
        if clade.name and clade.name.startswith("ROOT"):
            tree.root_with_outgroup(clade)
            if clade.name == "ROOT__INFERRED": tree.collapse(clade)
            root_flag = True
            break
    if not root_flag: print("Error: No root found")

    # Annotate tips from df_IGH_IF (EF ignored), then internal nodes.
    label_tips(tree, df_IGH_IF, lineage_ids[ind], ignore_EF = True)
    fill_internal_nodes(tree, lineage_ids[ind], df_IGH_IF, track_dominant = True)

    subtree_div, subtree_desc, subtree_depth = get_subtree_data(tree, cutoff = 0, div_cutoff = 0.01)

    subtree_divs.extend(subtree_div)
    subtree_descendants.extend(subtree_desc)
    subtree_depths.extend(subtree_depth)
    
    # The same `tree` is passed on every iteration. NOTE(review): presumably
    # randomize_migration_locations returns a fresh copy rather than mutating `tree` — confirm.
    # Draws are not seeded here, so the shuffled statistics vary between runs.
    for j in range(N_randomizations):
        tree_rand = randomize_migration_locations(tree)
        subtree_div, subtree_desc, subtree_depth = get_subtree_data(tree_rand, cutoff = 0, div_cutoff = 0.01)

        shuffled_divs.extend(subtree_div)
        shuffled_descendants.extend(subtree_desc)
        shuffled_depths.extend(subtree_depth)

subtree_divs = np.array(subtree_divs)
shuffled_divs = np.array(shuffled_divs)
subtree_descendants = np.array(subtree_descendants)
shuffled_descendants = np.array(shuffled_descendants)
subtree_depths = np.array(subtree_depths)
shuffled_depths = np.array(shuffled_depths)

In [None]:
# Survival function of the non-zero subtree divergences collected in the previous cell (data vs.
# randomized migration locations), plus an inset bar chart of the fraction of zero divergences.
x1, y1 = survival_func(subtree_divs[subtree_divs > 0])
x2, y2 = survival_func(shuffled_divs[shuffled_divs > 0])

zero_ct1 = np.sum(subtree_divs == 0)/len(subtree_divs)
zero_ct2 = np.sum(shuffled_divs == 0)/len(shuffled_divs)

plt.plot(x2[y2>0], y2[y2>0],color='red',label='Shuffled',linewidth=3,alpha=0.75)

plt.plot(x1[y1>0], y1[y1>0],color='black',label='Data',linewidth=3,alpha=0.75)
plt.xscale('log')
plt.xlabel('Subtree phylogenetic divergence, x')
plt.xticks([0.001,0.01,0.1,1],['0.1%','1%','10%','100%'])
plt.ylabel('Survival function, P(div. > x | div. > 0)')
plt.ylim([0,1])
plt.legend()
plt.savefig('Figures/Supplementary/Subtree_Divergence_Filtered.pdf', format='pdf')

plt.show()

# plt.subplots creates its own figure; a preceding plt.figure() only leaked an empty one.
fig, ax = plt.subplots(figsize=(4, 3))
ax.bar(['Data', 'Shuffled'], [zero_ct1, zero_ct2], color=['black', 'red'],alpha=0.75)
ax.set_ylabel('Fraction zero')
ax.set_ylim([0,1])
plt.savefig('Figures/Supplementary/Subtree_Divergence_Filtered_Inset.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S10C-D: Immigrant lineage statistics after filtering out early migrations
# A lineage present in follicle F counts as an immigrant into F when, in its tree, F labels none
# of the clades shallower than div_cutoff (nor their direct children). Per follicle this records
# the largest lineage frequency (max_f, with its lineage in max_f_id), the largest immigrant
# lineage frequency (max_f_migrant) and the fractions of lineages / UMIs that are immigrants.

div_cutoff = 0.01

freq_list = []
first_entry_list = []

folls = df_IGH_IF['follicle'].unique()
N_UMI = np.zeros(len(folls))
N_lineage = np.zeros(len(folls))

max_f = np.zeros(len(folls))
max_f_migrant = np.zeros(len(folls))
# Object dtype keeps full lineage ids: np.empty(..., dtype='str') creates a '<U1' array that
# silently truncated every stored id to its first character.
max_f_id = np.empty(len(folls), dtype=object)

prop_UMI_migrant = np.zeros(len(folls))
prop_lineage_migrant = np.zeros(len(folls))

for i in range(len(folls)):
    df_foll = df_IGH_IF[df_IGH_IF['follicle'] == folls[i]]
    N_UMI[i] = len(df_foll)

    lineages = df_foll['lineage_id'].unique()
    N_lineage[i] = len(lineages)

    for lineage in lineages:
        df_lin = df_foll[df_foll['lineage_id'] == lineage]
        lin_freq = len(df_lin)/N_UMI[i]
        if lin_freq > max_f[i]:
            max_f[i] = lin_freq
            max_f_id[i] = lineage

        # Skip lineages with at most one distinct v_mutations value in this follicle.
        # NOTE(review): presumably no informative tree exists for them — confirm.
        if df_lin['v_mutations'].nunique() <= 1: continue

        lin_id = int(float(lineage))
        tree = Phylo.read("Data/Lineage_Seqs/Lineage_Trees/tree_lin_" + str(lin_id) + "_no_ef.newick", "newick")

        # Re-root on the first clade named ROOT*; an inferred root node is then collapsed.
        root_flag = False
        for clade in tree.find_clades():
            if clade.name and clade.name.startswith("ROOT"):
                tree.root_with_outgroup(clade)
                if clade.name == "ROOT__INFERRED": tree.collapse(clade)
                root_flag = True
                break
        if not root_flag: print("Error: No root found")

        label_tips(tree, df_IGH, lineage, ignore_EF = True)
        fill_internal_nodes(tree,lineage, df_IGH, track_dominant = True)

        # Follicles on any clade shallower than div_cutoff, or on the children of such clades.
        # tree.depths() is computed once per tree rather than once per visited clade.
        clade_depths = tree.depths()
        follicles_near_root = set()
        for clade in tree.find_clades():
            if clade_depths.get(clade) < div_cutoff:
                follicles_near_root.add(clade.metadata['foll'])
                for child in clade.clades: follicles_near_root.add(child.metadata['foll'])

        if folls[i] not in follicles_near_root:
            if lin_freq > max_f_migrant[i]: max_f_migrant[i] = lin_freq

            prop_lineage_migrant[i] += 1/N_lineage[i]
            prop_UMI_migrant[i] += lin_freq

            first_entry_list.append(find_first_entry(tree, folls[i]))
            freq_list.append(lin_freq)

# Report lineages that are the largest lineage in more than one follicle.
# (Counter is already imported at the top of the notebook.)
max_lineages = []
for foll in df_IGH_IF['follicle'].unique():
    df_foll = df_IGH_IF[df_IGH_IF['follicle'] == foll]
    if not df_foll.empty:
        lineage_counts = df_foll['lineage_id'].value_counts()
        max_lineage = lineage_counts.idxmax()
        max_lineages.append(max_lineage)

counts = Counter(max_lineages)
for lineage, count in counts.items():
    if count > 1:
        print(f"{lineage}: {count}")
        

In [None]:
# Per-follicle immigrant statistics from the previous cell; marker area scales with the
# follicle's UMI count.
plt.scatter(prop_lineage_migrant, prop_UMI_migrant, s=N_UMI, color='black',alpha=0.5)
plt.xlabel('Prop. immigrant lineages')
plt.ylabel('Prop. immigrant UMIs')
plt.xlim([0,0.08])
plt.ylim([0,0.25])
plt.savefig('Figures/Supplementary/Prop_Immigrants_Filtered.pdf', format='pdf')

plt.show()

# Largest lineage frequency vs. largest immigrant-lineage frequency per follicle (y = x for reference).
plt.scatter(max_f, max_f_migrant,color='black',s=N_UMI,alpha=0.5)
plt.plot([0,0.4],[0,0.4],color='gray',linestyle='--')
plt.xlabel('Max. lineage frequency')
plt.ylabel('Max. immigrant lineage frequency')
plt.ylim([0,0.2])
plt.xlim([0,0.4])
plt.xticks([0,0.1,0.2,0.3,0.4])
plt.yticks([0,0.1,0.2])
plt.savefig('Figures/Supplementary/Max_Freq_Filtered.pdf', format='pdf')
plt.show()

In [None]:
# FIGURE S11: Additional subtree statistics
#
# For each multi-clone lineage tree, collect the per-subtree statistics returned by
# get_subtree_data (divergences, descendant counts, depths). Do this for the observed
# follicle labels and for N_randomizations copies whose migration locations were shuffled by
# randomize_migration_locations; the shuffled copies serve as the null for comparison.

subtree_divs = []
shuffled_divs = []

subtree_descendants = []
shuffled_descendants = []

subtree_depths = []
shuffled_depths = []

N_randomizations = 100

# Observed and shuffled trees must be summarized with identical cutoffs. Previously the
# shuffled call omitted div_cutoff and silently used get_subtree_data's default, so the
# data-vs-shuffled comparison was not guaranteed to be like-for-like.
SUBTREE_CUTOFF = 0
SUBTREE_DIV_CUTOFF = 0

for lineage in lineage_ids:

    lin_id = int(float(lineage))
    tree = Phylo.read("Data/Lineage_Seqs/Lineage_Trees/tree_lin_" + str(lin_id) + "_no_ef.newick", "newick")

    # Re-root on the first clade whose name starts with "ROOT"; a "ROOT__INFERRED" clade is
    # then collapsed out of the tree.
    root_flag = False
    for clade in tree.find_clades():
        if clade.name and clade.name.startswith("ROOT"):
            tree.root_with_outgroup(clade)
            if clade.name == "ROOT__INFERRED": tree.collapse(clade)
            root_flag = True
            break
    if not root_flag: print("Error: No root found")

    label_tips(tree, df_IGH, lineage, ignore_EF = True)
    fill_internal_nodes(tree, lineage, df_IGH, track_dominant = True)

    subtree_div, subtree_desc, subtree_depth = get_subtree_data(tree, cutoff = SUBTREE_CUTOFF, div_cutoff = SUBTREE_DIV_CUTOFF)
    subtree_divs.extend(subtree_div)
    subtree_descendants.extend(subtree_desc)
    subtree_depths.extend(subtree_depth)

    # Null distribution: same tree topology, shuffled migration locations
    for rand_idx in range(N_randomizations):
        tree_rand = randomize_migration_locations(tree)
        subtree_div, subtree_desc, subtree_depth = get_subtree_data(tree_rand, cutoff = SUBTREE_CUTOFF, div_cutoff = SUBTREE_DIV_CUTOFF)

        shuffled_divs.extend(subtree_div)
        shuffled_descendants.extend(subtree_desc)
        shuffled_depths.extend(subtree_depth)

subtree_divs = np.array(subtree_divs)
shuffled_divs = np.array(shuffled_divs)
subtree_descendants = np.array(subtree_descendants)
shuffled_descendants = np.array(shuffled_descendants)
subtree_depths = np.array(subtree_depths)
shuffled_depths = np.array(shuffled_depths)

In [None]:
def _plot_survival_comparison(ax, observed, shuffled):
    """Overlay survival functions (via survival_func) of observed and shuffled values on ax.

    Observed values are drawn in red ('Data') and shuffled values in black ('Shuffled').
    Points where the survival value is 0 are dropped so the curves can be shown on a log
    y-axis. The y-axis spans [1/(1.1*len(observed)), 1].
    """
    for values, color, label in ((observed, 'red', 'Data'), (shuffled, 'black', 'Shuffled')):
        x, y = survival_func(values)
        ax.plot(x[y > 0], y[y > 0], color=color, label=label)
    ax.set_yscale('log')
    # Lower limit derived from the plotted observed sample (previously the depth panel used
    # the descendant-count sample size instead of its own)
    ax.set_ylim([1/(1.1*len(observed)), 1])
    ax.legend()


fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left: number of descendants after a migration event
_plot_survival_comparison(axes[0], subtree_descendants, shuffled_descendants)
axes[0].set_xlim([0, 75])
axes[0].set_xlabel('Number of descendants after migration, x')
axes[0].set_ylabel('Survival function, P(X>x)')

# Right: tree depth after a migration event
_plot_survival_comparison(axes[1], subtree_depths, shuffled_depths)
axes[1].set_xscale('log')
axes[1].set_xlabel('Tree depth after migration, x')

plt.tight_layout()

# Forward slashes: portable across OSes (a backslash path is a literal filename on POSIX)
plt.savefig('Figures/Supplementary/Subtree_Stats.pdf', format='pdf')

plt.show()

In [None]:
# FIGURE S12: Frequency of migrant lineages in root vs non-root follicles
#
# Lineages are kept if they have more than one distinct v_mutations value and span more than
# one follicle. For each kept lineage, rebuild its rooted tree and treat the follicle that
# fill_internal_nodes assigns to the tree root as the "donor" follicle. Then compare the
# lineage's UMI count and within-follicle frequency there with its largest count / frequency
# in any other ("recipient") follicle.
#
# Outputs, aligned element-wise with one entry per retained lineage:
#   N_list_root / freq_list_root       - lineage UMIs / frequency in the donor follicle
#   N_list_nonroot / freq_list_nonroot - max. lineage UMIs / max. frequency over other follicles
#   lin_sizes                          - total rows (UMIs) of the lineage
# The per-lineage maxima deliberately do not reuse the global name `max_f`, which holds the
# per-follicle array plotted earlier.

freq_list_root = []
freq_list_nonroot = []
N_list_root = []
N_list_nonroot = []
lin_sizes = []

# Rows (UMIs) per follicle, computed once instead of re-scanning df_IGH_IF for every
# (lineage, follicle) pair
follicle_sizes = df_IGH_IF['follicle'].value_counts()

for lineage in df_IGH_IF['lineage_id'].unique():
    df_lin = df_IGH_IF[df_IGH_IF['lineage_id'] == lineage]
    if df_lin['v_mutations'].nunique() <= 1 or df_lin['follicle'].nunique() <= 1: continue

    lin_id = int(float(lineage))
    tree = Phylo.read("Data/Lineage_Seqs/Lineage_Trees/tree_lin_" + str(lin_id) + "_no_ef.newick", "newick")

    # Re-root on the first clade whose name starts with "ROOT"; a "ROOT__INFERRED" clade is
    # then collapsed out of the tree.
    root_flag = False
    for clade in tree.find_clades():
        if clade.name and clade.name.startswith("ROOT"):
            tree.root_with_outgroup(clade)
            if clade.name == "ROOT__INFERRED": tree.collapse(clade)
            root_flag = True
            break
    if not root_flag: print("Error: No root found")

    label_tips(tree, df_IGH_IF, lineage, ignore_EF = True)
    fill_internal_nodes(tree,lineage, df_IGH_IF, track_dominant = True)

    root_foll = tree.root.metadata['foll']

    # Rows of this lineage per follicle (value_counts drops NaN labels, which previously
    # caused a 0/0 ZeroDivisionError; the > 0 filter drops unused categorical levels)
    lin_foll_counts = df_lin['follicle'].value_counts()
    lin_foll_counts = lin_foll_counts[lin_foll_counts > 0]

    if root_foll not in lin_foll_counts.index:
        # Without a donor entry the root and non-root lists would drift out of alignment,
        # breaking the element-wise comparisons and scatter plots that consume them.
        print(f"Warning: root follicle {root_foll} has no UMIs from lineage {lineage}; skipping")
        continue
    lin_sizes.append(len(df_lin))

    # Donor (root) follicle
    root_N = lin_foll_counts.loc[root_foll]
    N_list_root.append(root_N)
    freq_list_root.append(root_N / follicle_sizes.loc[root_foll])

    # Recipient follicles: largest count and largest within-follicle frequency (the two
    # maxima may come from different follicles)
    nonroot_counts = lin_foll_counts.drop(root_foll)
    N_list_nonroot.append(nonroot_counts.max())
    freq_list_nonroot.append((nonroot_counts / follicle_sizes.loc[nonroot_counts.index]).max())

freq_list_root = np.array(freq_list_root)
freq_list_nonroot = np.array(freq_list_nonroot)
N_list_root = np.array(N_list_root)
N_list_nonroot = np.array(N_list_nonroot)
lin_sizes = np.array(lin_sizes)

In [None]:
# Fraction of lineages whose max. recipient follicle exceeds the donor (root) follicle:
# first by within-follicle frequency, then by number of UMIs.
frac_recipient_higher_freq = (freq_list_nonroot > freq_list_root).sum() / len(freq_list_root)
frac_recipient_higher_N = (N_list_nonroot > N_list_root).sum() / len(N_list_nonroot)
print(frac_recipient_higher_freq, frac_recipient_higher_N)

In [None]:
# Donor (tree-root) follicle vs. the recipient follicle with the largest value, one point per
# lineage: UMI counts (left) and within-follicle frequencies (right). Marker area is set by
# lin_sizes; the dashed red line is y = x.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left subplot: N_list_root vs N_list_nonroot
axes[0].scatter(N_list_root, N_list_nonroot, color='black', alpha=0.5, s=lin_sizes)
axes[0].set_xscale('log')
axes[0].set_yscale('log')
axes[0].plot([1, 250], [1, 250], color='red', linestyle='--')
axes[0].set_xlabel('Number UMIs in donor follicle')
axes[0].set_ylabel('Number UMIs in max. recipient follicle')
axes[0].set_xlim([0.9, 225])
axes[0].set_ylim([0.9, 225])
axes[0].set_aspect('equal', adjustable='box')

# Right subplot: freq_list_root vs freq_list_nonroot
axes[1].scatter(freq_list_root, freq_list_nonroot, color='black', alpha=0.5, s=lin_sizes)
axes[1].set_xscale('log')
axes[1].set_yscale('log')
# Endpoints kept strictly positive: 0 cannot be placed on a log axis (it was only being clipped)
axes[1].plot([1e-4, 1], [1e-4, 1], color='red', linestyle='--')
axes[1].set_xlabel('Frequency donor follicle')
axes[1].set_ylabel('Frequency in max. recipient follicle')
axes[1].set_xlim([0.0005, 0.3])
axes[1].set_ylim([0.0005, 0.3])
axes[1].set_aspect('equal', adjustable='box')

# Forward slashes: portable across OSes (the previous mixed '\' and '/' path only worked on Windows)
plt.savefig('Figures/Supplementary/Lineage_Origin_Destination.pdf', format='pdf')
plt.show()

