In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import os

import igraph
import seaborn as sns
import scipy.stats

import pyrepseq as prs
import pyrepseq.plotting as pp
from metaclonotypist import *

# Apply the 'seaborn-v0_8-paper' matplotlib style sheet to every figure drawn below.
plt.style.use('seaborn-v0_8-paper')

# Change here to define chain and MHC class

In [None]:
# TCR chain to analyse; its upper-cased first letter is used to build the TCR column names.
chain = 'beta' # needs to be in ['alpha', 'beta']
# MHC class whose alleles are tested; validated in the MHC class filter further below.
mhc_class = 'II' # needs to be in ['both', 'I', 'II']
# Name of the sub-directory of output/ that all results of this run are written to.
dir_out = f"metaclonotypist_{chain}_mhc{mhc_class}"

# Load TCR data

In [None]:
# Upper-cased first letter of the chain name ('A' or 'B'); suffix of the TCR column names.
chain_letter = chain[0].upper()

df = pd.read_csv(f'data/combined_subsampled_5000_10000_{chain}.csv.gz')

# Discard rows lacking a V gene or a CDR3, then keep only CDR3s longer than 5 characters.
df = df.dropna(subset=[f'TR{chain_letter}V', f'CDR3{chain_letter}'])
df = df.loc[df[f'CDR3{chain_letter}'].map(len) > 5]
df.head()

# Load HLA data

In [None]:
# Read the HLA table, using its first column as the row index.
hlas = pd.read_csv('data/hladata.csv', index_col=0)
# flatten_hlas comes from the `metaclonotypist` star import. Judging from how `hlas` is
# used below (rows looked up by donor ID, columns filtered by allele name and summed),
# it presumably returns one row per donor and one indicator column per allele -- TODO confirm.
hlas = flatten_hlas(hlas)
hlas.head()

# Define parameters

In [None]:
# Statistical test, passed as `method` to hla_association.
testmethod = 'fisher'
# Minimum clonal_count a clone needs to be kept.
mincount = 2
# Neighbour search limits passed to metaclonotypist / nearest_neighbor_tcrdist.
max_edits = 2
max_tcrdist = 10 if chain == 'alpha' else 15
# Graph clustering algorithm and its keyword arguments, forwarded to metaclonotypist.
clustering = 'leiden'
clustering_kwargs = dict(resolution=0.1,
                         objective_function='CPM',
                         n_iterations=4)
# Minimum number of donors required both for an HLA allele and for a cluster to be tested.
min_donors = 4

# Create the output directory; exist_ok makes the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
newpath = f'output/{dir_out}'
os.makedirs(newpath, exist_ok=True)

# Run analyses

In [None]:
# filter clones < mincount
df = df[df['clonal_count']>=mincount]
# only keep samples found in both datasets
# (the printed set lists donors that have TCR data but no HLA typing)
print(set(df['UIN'])-set(hlas.index))
df = df[df['UIN'].isin(hlas.index)]
# Order the HLA rows by first appearance of each donor in df. Going through set(), as
# before, yields an iteration order that can change between kernel sessions for string
# IDs, which in turn changes the result of any later (even seeded) row permutation.
hlas = hlas.loc[df['UIN'].unique()]
len(df['UIN'].unique()), len(df)

In [None]:
# filter hlas < min_donors
frequent_alleles = hlas.columns[hlas.sum(axis=0) >= min_donors]
hlas = hlas[frequent_alleles]
# filter MHC class: class II allele columns are recognised by a name starting with 'D'
if mhc_class not in ('both', 'I', 'II'):
    raise NotImplementedError("mhc_class needs to be in ['both', 'I', 'II']")
if mhc_class != 'both':
    is_class_ii = hlas.columns.str.startswith('D')
    wanted = is_class_ii if mhc_class == 'II' else ~is_class_ii
    hlas = hlas[hlas.columns[wanted]]
len(hlas.columns)

In [None]:
# Cluster the clones; the result is indexed by labels of df (it is used to index df below).
clusters = metaclonotypist(
    df,
    chain=chain,
    max_edits=max_edits,
    max_tcrdist=max_tcrdist,
    clustering=clustering,
    clustering_kwargs=clustering_kwargs,
)
# Annotate every clustered clone with its donor ID and CDR3, looked up in df by index label.
clusters['Sample.ID'] = df.loc[clusters.index, 'UIN']
clusters[f'CDR3{chain_letter}'] = df.loc[clusters.index, f'CDR3{chain_letter}']
len(clusters['cluster'].unique())

In [None]:
# filter clusters < min_donors
# Number of distinct donors contributing to each cluster.
ndonors = clusters.groupby('cluster')['Sample.ID'].apply(lambda ids: len(ids.unique()))
public_clusters = ndonors[ndonors >= min_donors].index
clusters = clusters[clusters['cluster'].isin(public_clusters)]
len(clusters['cluster'].unique())

In [None]:
# Run the HLA association test (hla_association comes from the `metaclonotypist` star
# import) on the donor-filtered clusters and the filtered HLA table. The result is used
# below as a table with 'cluster', 'hla', 'pvalue', 'odds_ratio' and 'significant' columns.
cluster_association = hla_association(clusters, hlas,
                                      method=testmethod)

In [None]:
# Save the complete association table (all rows, not only the significant ones).
cluster_association.to_csv(f'output/{dir_out}/clusterassociation_{chain}_mhc{mhc_class}.csv')

In [None]:
# Clusters with at least one significant association; a cluster may be significant for
# several alleles, so rows and clusters are counted separately.
significant_cluster_ids = cluster_association.loc[cluster_association['significant'], 'cluster']
nmetaclones = len(significant_cluster_ids.unique())
cluster_association['significant'].sum(), nmetaclones

In [None]:
# Copy of the association table with every +inf replaced by 400 so that those rows can be
# drawn on the log-scaled odds-ratio axis below. The replacement applies to all columns;
# 400 is presumably just a finite stand-in larger than any observed odds ratio -- TODO confirm.
cluster_association_noinf = cluster_association.replace(np.inf, 400, inplace=False)

# HLA shuffling

In [None]:
# shuffle hlas
# Randomly reassign the HLA rows to donors: the data stay untouched, only the donor labels
# in the index are permuted. A dedicated, seeded generator makes the shuffled control (and
# the numbers derived from it) identical after Restart Kernel -> Run All.
shuffle_rng = np.random.default_rng(0)
hlas_shuffled = hlas.copy()
hlas_shuffled.index = shuffle_rng.permutation(hlas_shuffled.index.to_numpy())

In [None]:
# Same association test as for the real data, but against the donor-shuffled HLA table;
# plotted next to the real result below for comparison.
cluster_association_shuffled = hla_association(clusters, hlas_shuffled, method=testmethod)

In [None]:
# Save the complete association table obtained with shuffled HLA labels.
cluster_association_shuffled.to_csv(f'output/{dir_out}/clusterassociation_shuffled_{chain}_mhc{mhc_class}.csv')

In [None]:
# Number of rows flagged significant when the HLA labels are shuffled.
cluster_association_shuffled['significant'].sum()

In [None]:
# Copy of the shuffled-label table with every +inf replaced by 400 (in all columns) so that
# those rows can be drawn on the log-scaled odds-ratio axis below.
cluster_association_shuffled_noinf = cluster_association_shuffled.replace(np.inf, 400, inplace=False)

# Visualization and saving of results

In [None]:
# Volcano plots (odds ratio vs. -log10 p value) for the real data and the shuffled control.
fig, axes = plt.subplots(figsize=(4.5, 2.5), ncols=2, sharex=True, sharey=True)
# (title, table with inf replaced for the x axis, original table supplying the p values)
panels = [('Data', cluster_association_noinf, cluster_association),
          ('Shuffled HLA', cluster_association_shuffled_noinf, cluster_association_shuffled)]
for ax, (title, assoc_noinf, assoc) in zip(axes, panels):
    sns.scatterplot(ax=ax, data=assoc_noinf,
                    x='odds_ratio',
                    y=-np.log10(assoc['pvalue']),
                    hue='significant',
                    s=5)
    ax.set_title(title)
    ax.set_xscale('log')
    # The plotted quantity is -log10(p), so label it as such (it used to read 'p value').
    ax.set_ylabel(r'$-\log_{10}$ p value')
    ax.set_xlabel('odds ratio')
    ax.legend(loc='upper left', title='significant')
# Annotate the data panel with the number of clusters having a significant association.
axes[0].text(0.1, 0.5, f'$n={nmetaclones}$', transform=axes[0].transAxes)
fig.tight_layout()
fig.savefig(f'output/{dir_out}/volcano_plot_{chain}_mhc{mhc_class}.pdf')

In [None]:
# Rows of the association table flagged significant. A cluster can appear in several rows
# here (it is de-duplicated on 'cluster' further below).
hla_metaclones = cluster_association[cluster_association['significant']]
hla_metaclones.head(20)

In [None]:
# Bar chart: number of significant associations per HLA allele.
hla_counts = hla_metaclones['hla'].value_counts()
fig, ax = plt.subplots()
ax.bar(hla_counts.index, hla_counts)
ax.tick_params(axis='x', labelrotation=90)
ax.set_ylabel('# HLA-associated metaclonotypes')
fig.tight_layout()
fig.savefig(f'output/{dir_out}/hla_association_{chain}_mhc{mhc_class}.png')

In [None]:
# Clones belonging to a cluster with at least one significant HLA association.
# reset_index() keeps the df row label of each clone in the 'index' column.
sig_clusters = clusters[clusters['cluster'].isin(hla_metaclones['cluster'])].reset_index()
print(len(sig_clusters))
# One row per (clone, associated allele) pair.
sig_clusters = sig_clusters.merge(hla_metaclones, on='cluster')
# Keep only the pairs whose donor actually carries the associated allele. The mask is
# coerced to bool explicitly: the former .iloc[list] only acts as a mask when the HLA
# entries are real booleans, whereas 0/1 integers would select rows 0 and 1 by position.
hla_match = np.array([bool(hlas.loc[sample_id, hla])
                      for sample_id, hla in zip(sig_clusters['Sample.ID'], sig_clusters['hla'])],
                     dtype=bool)
sig_clusters = sig_clusters[hla_match]
len(sig_clusters)

In [None]:
# One V/J/CDR3 sequence logo per significant cluster, drawn from the clones whose donor
# carries the associated allele (i.e. the rows kept in sig_clusters).
for cluster in hla_metaclones['cluster'].unique():
    tcrs = df.loc[sig_clusters[(sig_clusters['cluster']==cluster)]['index']]
    # (a stray no-op statement, `tcrs[...].apply` without a call, was removed here)
    pp.seqlogos_vj(tcrs, cdr3_column=f'CDR3{chain_letter}',
                   v_column=f'TR{chain_letter}Vshort',
                   j_column=f'TR{chain_letter}J')
    # seqlogos_vj does not hand back the figure here, so save whichever figure is current.
    plt.gcf().savefig(f'output/{dir_out}/{chain}_mhc{mhc_class}_{cluster}_seqlogo.pdf', dpi=300)

In [None]:
# One row per cluster: of all its significant associations keep the one with the
# smallest p value, and renumber the rows from 0.
hla_metaclones_unique = (
    hla_metaclones
    .sort_values('pvalue')
    .drop_duplicates(subset='cluster', keep='first')
    .reset_index(drop=True)
)

In [None]:
# Only the last expression of a cell is displayed, so the count has to be printed
# explicitly; as a bare expression it was silently discarded.
print(len(hla_metaclones_unique))
hla_metaclones_unique.head(20)

In [None]:
for i, (cluster, hla) in hla_metaclones_unique[['cluster', 'hla']].iterrows():
    seqs = df.loc[clusters[(clusters['cluster']==cluster)].index]

    neighbors = prs.nearest_neighbor_tcrdist(seqs,
                                             max_edits=max_edits,
                                             max_tcrdist=max_tcrdist,
                                             chain=chain)
    edges = np.array(neighbors)[:, :2]
    g = igraph.Graph(edges, n=len(seqs))
    g.simplify()

    g.vs['ID'] = list(seqs['UIN'])
    
    sample_ids = clusters[clusters['cluster'] == cluster]['Sample.ID']
    g.vs['HLA'] = hlas.loc[sample_ids, hla].apply(lambda x: 'o' if x else '')
    
    g.es['weight'] = 1.0*np.exp(-np.array(neighbors[:, 2])/max_tcrdist)
    c0 = mpl.colors.to_rgba('C0')
    c1 = mpl.colors.to_rgba('C1')
    g.vs['color'] = hlas.loc[sample_ids, hla].apply(lambda x: c1 if x else c0)
    
    edge_idx = np.array([(e.source, e.target) for e in g.es])
    same_sample = (np.array(seqs.iloc[edge_idx[:, 0]]['UIN'])
                   == np.array(seqs.iloc[edge_idx[:, 1]]['UIN']))
    c2 = mpl.colors.to_rgba('C3')
    c3 = mpl.colors.to_rgba('.3')
    g.es['color'] = [c2 if s else c3 for s in same_sample]
    
    width, height = 2.0, 1.0
    scale = 10.0
    layout = g.layout('kk',
                      minx=np.zeros(len(seqs)),
                      maxx=scale*width*np.ones(len(seqs)),
                      miny=np.zeros(len(seqs)),
                      maxy=scale*height*np.ones(len(seqs)))
    fig, ax = plt.subplots(figsize=(width, height))
    igraph.plot(g, target=ax,
                layout=layout,
                vertex_frame_width=0,
                vertex_size=2,
                edge_width=g.es['weight'])
    fig.tight_layout(pad=0.0)
    fig.savefig(f'output/{dir_out}/{chain}_mhc{mhc_class}_{cluster}_graph.pdf', dpi=300)

In [None]:
# Second view of every metaclonotype graph: same neighbour graph, but with each vertex
# coloured by the donor the clone comes from.
for i, (cluster, hla) in hla_metaclones_unique[['cluster', 'hla']].iterrows():
    seqs = df.loc[clusters[clusters['cluster'] == cluster].index]
    n_seqs = len(seqs)

    neighbors = prs.nearest_neighbor_tcrdist(seqs,
                                             chain=chain,
                                             max_edits=max_edits,
                                             max_tcrdist=max_tcrdist)
    g = igraph.Graph(np.array(neighbors)[:, :2], n=n_seqs)
    g.simplify()

    # One colour per donor, evenly spaced along the 'gist_rainbow' colormap.
    unique_ids = seqs['UIN'].unique()
    palette = plt.colormaps['gist_rainbow'](np.linspace(0, 1, len(unique_ids)))
    id_to_color = {id_: tuple(rgba) for id_, rgba in zip(unique_ids, palette)}
    g.vs['color'] = list(seqs['UIN'].map(id_to_color))

    # Kamada-Kawai layout confined to a box with the aspect ratio of the figure.
    width, height = 2.0, 1.0
    scale = 10.0
    layout = g.layout('kk',
                      minx=np.zeros(n_seqs),
                      maxx=scale*width*np.ones(n_seqs),
                      miny=np.zeros(n_seqs),
                      maxy=scale*height*np.ones(n_seqs))
    fig, ax = plt.subplots(figsize=(width, height))
    igraph.plot(g, target=ax,
                layout=layout,
                vertex_frame_width=0,
                vertex_size=2,
                vertex_color=g.vs['color'],
                # thinner edges for larger clusters
                edge_width=10.0/n_seqs)
    fig.tight_layout(pad=0.0)
    fig.savefig(f'output/{dir_out}/{chain}_mhc{mhc_class}_{cluster}_graph2.pdf', dpi=300)

In [None]:
# 'Vs': the distinct V genes found in each cluster, joined with '|'.
v_gene_column = f'TR{chain_letter}Vshort'
hla_metaclones_unique['Vs'] = [
    '|'.join(df.loc[clusters[clusters['cluster'] == cluster_id].index, v_gene_column].unique())
    for cluster_id in hla_metaclones_unique['cluster']
]
hla_metaclones_unique.head(81)

In [None]:
# Dry run of the consensus computation for every metaclonotype, printing diagnostics for
# clusters on which it fails. The checks now sit inside the loop: they used to be dedented,
# so only the last cluster was ever examined (and an empty table raised a NameError).
for x in hla_metaclones_unique['cluster']:
    seqs = df.loc[clusters[(clusters['cluster']==x)].index][f'CDR3{chain_letter}']
    if seqs.empty:
        print(f"No sequences found for cluster {x}.")
        continue
    try:
        prs.seqs_to_consensus(seqs)
    except EnvironmentError:
        # Show the offending sequences and their alignment to help debugging.
        print(seqs)
        aligned = prs.align_seqs([str(s) for s in seqs])
        print(aligned)

In [None]:
def generate_consensus(cluster):
    """Return the consensus CDR3 sequence of one cluster.

    The CDR3s of all clones assigned to `cluster` are looked up in the notebook-level
    `df` via the index of `clusters` and passed to `prs.seqs_to_consensus`.

    Parameters
    ----------
    cluster
        Cluster identifier, compared against the 'cluster' column of `clusters`.

    Returns
    -------
    The value returned by `prs.seqs_to_consensus`, or None when the cluster has no
    sequences or the consensus computation raises (a message is printed in both cases).
    """
    member_index = clusters[clusters['cluster'] == cluster].index
    cdr3s = df.loc[member_index][f'CDR3{chain_letter}']
    if cdr3s.empty:
        print(f"No sequences found for cluster {cluster}.")
        return None
    try:
        consensus = prs.seqs_to_consensus(cdr3s)
    except Exception as err:
        # Best effort: report the failure and leave this cluster without a consensus.
        print(f"Error processing cluster {cluster}: {err}")
        consensus = None
    return consensus

# 'consensus': one consensus CDR3 per metaclonotype (None where generate_consensus gave up).
hla_metaclones_unique['consensus'] = hla_metaclones_unique['cluster'].apply(generate_consensus)

In [None]:
def get_cluster_regex(cluster):
    """Return a regular expression summarising the CDR3s of one cluster.

    The CDR3s of all clones assigned to `cluster` are looked up in the notebook-level
    `df` via the index of `clusters` and passed to `prs.seqs_to_regex`.

    Parameters
    ----------
    cluster
        Cluster identifier, compared against the 'cluster' column of `clusters`.

    Returns
    -------
    The value returned by `prs.seqs_to_regex`, or None when the cluster has no
    sequences or the computation raises (a message is printed in both cases).
    """
    is_member = clusters['cluster'] == cluster
    cdr3s = df.loc[clusters[is_member].index][f'CDR3{chain_letter}']
    if cdr3s.empty:
        print(f"No sequences found for cluster {cluster}.")
        return None
    try:
        pattern = prs.seqs_to_regex(cdr3s)
    except Exception as err:
        # Best effort: report the failure and leave this cluster without a regex.
        print(f"Error processing cluster {cluster}: {err}")
        pattern = None
    return pattern

# 'regex': one CDR3 regular expression per metaclonotype (None where get_cluster_regex gave up).
hla_metaclones_unique['regex'] = hla_metaclones_unique['cluster'].apply(get_cluster_regex)

In [None]:
# 'CDR3s': every CDR3 of the cluster (duplicates included), joined with '|'.
cdr3_column = f'CDR3{chain_letter}'
hla_metaclones_unique['CDR3s'] = [
    '|'.join(df.loc[clusters[clusters['cluster'] == cluster_id].index, cdr3_column])
    for cluster_id in hla_metaclones_unique['cluster']
]

In [None]:
# Save the per-metaclonotype table, including the Vs, consensus, regex and CDR3s columns
# added above.
hla_metaclones_unique.to_csv(f'output/{dir_out}/hlametaclonotypes_{chain}_mhc{mhc_class}.csv')
# used as Supplementary Tables S4-5 and S7-8

# Analyze metaclonotype coverage

In [None]:
# sig_clusters holds one row per (clone, associated allele carried by its donor), so a
# clone whose donor carries two associated alleles appears twice. Count every clone (and
# its reads) only once; otherwise the fractions below are inflated.
sig_index = sig_clusters['index'].unique()
data = [
        # significant (cluster, allele) pairs
        ['nassociations', len(hla_metaclones)],
        # clusters with at least one significant association
        ['nmetaclones', len(hla_metaclones['cluster'].unique())],
        # significant pairs obtained with shuffled HLA labels
        ['nshuffled', cluster_association_shuffled['significant'].sum()],
        # fraction of clones that ended up in a donor-filtered cluster
        ['clustered_fraction', len(clusters)/len(df)],
        # fraction of clones in a significant cluster whose donor carries the allele
        ['sig_clonotype_fraction', len(sig_index)/len(df)],
        # the same, weighted by clonal_count
        ['sig_read_fraction', df.loc[sig_index, 'clonal_count'].sum()/df['clonal_count'].sum()],
        # fraction of donors contributing at least one such clone
        ['id_fraction', len(sig_clusters['Sample.ID'].unique())/len(df['UIN'].unique())]
       ]
index, values = list(zip(*data))
s = pd.Series(index=index, data=values, name='results')
s

In [None]:
# Save the summary; index=True writes the metric names as the first CSV column.
s.to_csv(f'output/{dir_out}/metaclonotype_coverage_{chain}_mhc{mhc_class}.csv', index=True)