In [9]:
import multiprocessing
import os
import pickle
import random
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import pandas as pd
from rapidfuzz.distance import Levenshtein as L


### Load file names

In [6]:
# HLA allele group whose carriers are analysed below; edit to re-run for another allele.
hla = 'A*02'
# Same label with the '*' removed, so it can be used in file names.
hla_name = ''.join(hla.split('*'))

In [7]:
# Keep only samples typed for the HLA of interest AND annotated with a CMV status,
# then derive the per-sample repertoire file names.
metadata = pd.read_csv("/Users/ishaharris/Projects/TCR/TCR-Isha/data/Repertoires/Cohort01_whole_metadata.tsv", sep="\t")
sample_tags = metadata['sample_tags']
# na=False on both filters: samples without tags would otherwise put NaN into the mask,
# which pandas refuses to use for boolean indexing.
has_hla = sample_tags.str.contains(f'HLA-{hla}', case=False, regex=False, na=False)
# Group the alternation so the word boundaries apply to both spellings
# (r'\bcytomegalovirus|CMV\b' parses as '\bcytomegalovirus' OR 'CMV\b').
# Non-capturing group avoids pandas' "pattern has match groups" warning.
has_cmv = sample_tags.str.contains(r'\b(?:cytomegalovirus|CMV)\b', case=False, na=False)
metadata_annotat = metadata[has_hla & has_cmv].reset_index(drop=True)

file_names = [f'{name}.slim.tsv' for name in metadata_annotat['sample_name']]

### Load high-confidence (VDJdb) CDR3 sequences

In [None]:
# Load the high-confidence (VDJdb) CDR3 sequences used as the reference set.
highconf_dir = '/Users/ishaharris/Projects/TCR/TCR-Isha/data/vdjdb/'
highconf_file_name = 'vdjdb_one_epitope.tsv'
highconf = pd.read_csv(os.path.join(highconf_dir, highconf_file_name), sep='\t')

aa_colname = 'CDR3'

# Drop missing CDR3s: a NaN (float) would break the len()-based bucketing in
# compute_freqs_and_contributors.
# NOTE(review): repertoires are matched on `cdr3_b_aa`; confirm this file holds only
# beta-chain (TRB) entries, otherwise alpha CDR3s end up in the reference set.
highconf_seqs = highconf[aa_colname].dropna().tolist()

len(highconf_seqs)

['CASSYSTGTPGIYTF',
 'CASTPAGGAPGELFF',
 'CASSLAPGATNEKLFF',
 'CASSFSGGAPGELFF',
 'CASSYFGGNTEAFF',
 'CASSLAPGATSEKLFF',
 'CASSYQTGTIYGYTF',
 'CASSPQTGAIYGYTF',
 'CASSALGGGGTGELFF',
 'CASSPVQGAFYNSPLHF',
 'CSARDFDRTGELFF',
 'CSVDETGGGETQYF',
 'CASSLAPGATNEKLFF',
 'CASSYQTGAIYGYTF',
 'CASSLAPGTTNEKLFF',
 'CASSSLYGDTGELFF',
 'CASSLVSGSPTGELFF',
 'CASSYATGIGNYGYTF',
 'CASSSVTEAFF',
 'CASSFGTGHTGELFF',
 'CANSYATGIGNYGYTF',
 'CASSYFGSQSEQYF',
 'CASNPSGGYTGELFF',
 'CASSLTTGTGSYGYTF',
 'CASSLQTGVRSYEQYF',
 'CASSFWREPTYEQYF',
 'CASSQADRAVYGYTF',
 'CASSLLIQGGENTEAFF',
 'CASSVQGYTEAFF',
 'CASSPGSINYGYTF',
 'CASSSAYYGYTF',
 'CASSESMIQHF',
 'CASSPRQGGKQPQHF',
 'CASHYGGSSYEQYF',
 'CAWSVSDIMNTEAFF',
 'CASSTPWGGTSGATDTQYF',
 'CASSFFSKKYNNEQFF',
 'CASSLEGFTEAFF',
 'CASSLSPTSGLSYEQYF',
 'CASSLGGGGFYEQYF',
 'CASSWNEQFF',
 'CASSLVGGRYGYTF',
 'CATSRDLVAETQYF',
 'CSARDPLDYVRTDTQYF',
 'CASSWLMGTTYNEQFF',
 'CSVDRYGGDTYEQYF',
 'CSVGGTLDTQYF',
 'CASSEVGATNYGYTF',
 'CASSYFPLADTQYF',
 'CASSSLGGAGTGELFF',
 'CASSR

### Function
- Also records, per patient, the clonotypes contributing to the CMV burden, keeping only the top 100 by summed frequency (configurable via `contrib_top_n`).

In [12]:
def _min_distance_to_highconf(seq: str, hc_buckets: dict, max_distance: int) -> int:
    """Return the smallest Levenshtein distance from `seq` to any high-confidence CDR3.

    `hc_buckets` maps sequence length -> list of high-confidence sequences of that length.
    Returns `max_distance + 1` when no sequence lies within `max_distance` edits.
    """
    best = max_distance + 1
    seq_len = len(seq)
    # Levenshtein distance is at least the length difference, so only buckets whose
    # length is within `max_distance` of `seq` can hold a match.
    for tgt_len in range(max(seq_len - max_distance, 0), seq_len + max_distance + 1):
        for hc in hc_buckets.get(tgt_len, ()):
            # With score_cutoff, rapidfuzz returns score_cutoff + 1 whenever the true
            # distance exceeds the cutoff, so `d < best` only holds for real improvements.
            d = L.distance(seq, hc, score_cutoff=best)
            if d < best:
                best = d
                if best == 0:
                    return 0
    return best


def _process_patient(
    fname: str,
    input_dir: str,
    hc_buckets: dict,
    max_distance: int,
    sep: str,
    seq_col: str,
    freq_col: str,
    contrib_top_n: int,
) -> tuple[dict, dict]:
    """Score one repertoire file; runs inside a worker process, so it must be module-level.

    Returns:
        row: {'patient_id': fname, 'dist_0': ..., ..., 'dist_<max_distance>': ...}, where
            dist_d is the summed `freq_col` of sequences whose nearest high-confidence
            sequence is exactly d edits away.
        top_contrib: {sequence: summed frequency} for the `contrib_top_n` sequences with the
            largest summed frequency among those within `max_distance`.
    """
    path = os.path.join(input_dir, fname)
    rep = pd.read_csv(path, sep=sep, usecols=[seq_col, freq_col]).dropna()
    seqs = rep[seq_col].to_numpy()
    freqs = rep[freq_col].to_numpy(dtype=float)

    # Explicit integer dtype: np.array([]) would be float64 and make np.bincount raise
    # for an empty repertoire.
    min_ds = np.fromiter(
        (_min_distance_to_highconf(s, hc_buckets, max_distance) for s in seqs),
        dtype=np.int64,
        count=len(seqs),
    )
    mask = min_ds <= max_distance

    # Total frequency at each exact distance 0..max_distance.
    freq_sums = np.bincount(min_ds[mask], weights=freqs[mask], minlength=max_distance + 1)

    # A sequence may appear on more than one row, so frequencies are summed per sequence.
    contrib_sums = defaultdict(float)
    for s, f in zip(seqs[mask], freqs[mask]):
        contrib_sums[s] += f
    top_contrib = dict(
        sorted(contrib_sums.items(), key=lambda kv: kv[1], reverse=True)[:contrib_top_n]
    )

    row = {'patient_id': fname}
    row.update({f'dist_{d}': float(freq_sums[d]) for d in range(max_distance + 1)})
    return row, top_contrib


def compute_freqs_and_contributors(
    highconf_seqs: list[str],
    file_names: list[str],
    max_distance: int,
    input_dir: str,
    output_file: str,
    contributors_output: str,
    sep: str = '\t',
    seq_col: str = 'cdr3_b_aa',
    freq_col: str = 'productive_frequency',
    contrib_top_n: int = 100,
    n_workers: int = None,
    sample_n: int = None,
    random_seed: int = None,
    start_method: str = None,
) -> tuple[pd.DataFrame, dict]:
    """Sum repertoire frequencies by Levenshtein distance to high-confidence CDR3s, in parallel.

    For every repertoire file, each sequence in `seq_col` gets the edit distance to its
    nearest sequence in `highconf_seqs`. The `freq_col` values of sequences at distance
    0..max_distance are summed into columns dist_0..dist_<max_distance>. Per patient, only
    the `contrib_top_n` contributing sequences with the largest summed frequency are kept.

    Results are checkpointed after every file to `output_file` (CSV) and
    `contributors_output` (pickle). Files already listed in `output_file` are skipped, so
    an interrupted run resumes when called again with the same arguments.

    Args:
        highconf_seqs: reference sequences (duplicates are ignored).
        file_names: repertoire file names, relative to `input_dir`.
        max_distance: largest edit distance counted.
        input_dir: directory containing the repertoire files.
        output_file: CSV checkpoint of the per-patient distance sums.
        contributors_output: pickle checkpoint of the per-patient top contributors.
        sep, seq_col, freq_col: how to read each repertoire file.
        contrib_top_n: number of contributor sequences kept per patient.
        n_workers: worker processes (None -> ProcessPoolExecutor default).
        sample_n: if set, process only a random subset of this many pending files.
        random_seed: seed for that subset; the global `random` state is left untouched.
        start_method: multiprocessing start method. Defaults to 'fork' where available,
            because under 'spawn' (macOS/Windows default) worker processes cannot import
            functions defined in a notebook. To use 'spawn', move the helpers into a .py
            module and import them.

    Returns:
        (freq_df, contributors). freq_df has one row per processed file, including rows
        loaded from `output_file`. contributors maps file name -> {sequence: summed frequency}.
        Per-file failures are printed and skipped, not raised.
    """
    # Bucket the de-duplicated reference sequences by length; a plain dict is sent to
    # each worker with the task.
    hc_buckets = defaultdict(list)
    for seq in dict.fromkeys(highconf_seqs):
        hc_buckets[len(seq)].append(seq)
    hc_buckets = dict(hc_buckets)

    # --- load or initialise checkpointed outputs ---
    if os.path.exists(output_file):
        freq_df = pd.read_csv(output_file)
    else:
        cols = ['patient_id'] + [f'dist_{d}' for d in range(max_distance + 1)]
        freq_df = pd.DataFrame(columns=cols)
    completed = set(freq_df['patient_id'])
    records = freq_df.to_dict('records')

    if os.path.exists(contributors_output):
        with open(contributors_output, 'rb') as f:
            # Only load checkpoints written by this function: pickle can execute code.
            contributors = pickle.load(f)
    else:
        contributors = {}

    pending = [fn for fn in file_names if fn not in completed]
    if sample_n is not None and sample_n < len(pending):
        # A private Random(seed) draws the same sample as random.seed(seed) + random.sample
        # without resetting the notebook's global RNG.
        rng = random.Random(random_seed) if random_seed is not None else random
        pending = rng.sample(pending, sample_n)
    total = len(pending)
    if total == 0:
        return freq_df, contributors

    if start_method is None and 'fork' in multiprocessing.get_all_start_methods():
        start_method = 'fork'
    mp_context = multiprocessing.get_context(start_method)

    # --- dispatch parallel jobs ---
    with ProcessPoolExecutor(max_workers=n_workers, mp_context=mp_context) as executor:
        futures = {
            executor.submit(
                _process_patient, fn, input_dir, hc_buckets, max_distance,
                sep, seq_col, freq_col, contrib_top_n,
            ): fn
            for fn in pending
        }
        for i, fut in enumerate(as_completed(futures), 1):
            fname = futures[fut]
            try:
                row, top_contrib = fut.result()
                print(f"[{i}/{total}] Processed {fname}")

                records.append(row)
                contributors[fname] = top_contrib
                # Rebuild from the accumulated records rather than concatenating onto a
                # possibly empty frame.
                freq_df = pd.DataFrame(records)

                # Checkpoint after every file so an interrupted run can resume.
                freq_df.to_csv(output_file, index=False)
                with open(contributors_output, 'wb') as f:
                    pickle.dump(contributors, f)
            except Exception as e:
                print(f"[{i}/{total}] ERROR {fname}: {e}")

    return freq_df, contributors


### Run the function

In [None]:
# Run configuration.
# NOTE(review): REPERTOIRE_DIR, MAX_DISTANCE and the output file names are placeholders.
# Confirm them before running; the original call passed no arguments and would raise TypeError.
REPERTOIRE_DIR = '/Users/ishaharris/Projects/TCR/TCR-Isha/data/Repertoires/'
MAX_DISTANCE = 2

freq_df, contributors = compute_freqs_and_contributors(
    highconf_seqs=highconf_seqs,
    file_names=file_names,
    max_distance=MAX_DISTANCE,
    input_dir=REPERTOIRE_DIR,
    output_file=f'cmv_freqs_{hla_name}.csv',
    contributors_output=f'cmv_contributors_{hla_name}.pkl',
)
freq_df.head()