<a href="https://colab.research.google.com/github/gmestrallet/TCRscRNAseqAnalysis/blob/main/TCRscRNASeqAnalysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to analyze TCR-seq data coupled with scRNA-seq data using scanpy and scirpy, with the data stored in Google Drive. More information is available at https://scirpy.scverse.org/en/latest/tutorials/tutorial_3k_tcr.html

In [None]:
# Mount Google Drive so the notebook can read files stored there.
# `google.colab` is imported here, so this cell presumably only runs on Google Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# `os` is used below to change the working directory and to create folders.
import os

In [None]:
# Replace 'RNAseq_folder' with the path to your folder in Google Drive, or use '/content/' for local storage.
# The working directory is then switched to this folder.
rna_seq_path = '/content/drive/My Drive/RNAseq_folder'
os.chdir(rna_seq_path)

In [None]:
# Create a 'data' directory inside the project folder and make it the working directory.
os.makedirs('data', exist_ok=True)  # Creates 'data' directory if it doesn't exist
os.chdir('data')

In [None]:
# Create a 'write' directory inside 'data' (the working directory at this point).
os.makedirs('write', exist_ok=True)

In [None]:
#Install necessary libraries and import
!pip install scanpy  # Make sure scanpy is installed
!pip install igraph  # Make sure igraph is installed
!pip install leidenalg  # Make sure leidenalg is installed
!pip install muon # Install the muon package
!pip install scirpy # Install the scirpy package
!pip install cycler # Install the cycler package
!pip install bbknn # Install the bbknn package
!pip install parasail # Install the parasail package
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import tarfile
import warnings
from glob import glob
import anndata
import muon as mu
import scirpy as ir
from cycler import cycler
from matplotlib import cm as mpl_cm

In [None]:
sc.set_figure_params(figsize=(4, 4))
sc.settings.verbosity = 2  # verbosity: errors (0), warnings (1), info (2), hints (3)

In [None]:
adatas_tcr = {}
adatas_gex = {}

In [None]:
# Load the TCR data
adata_tcr_1 = ir.io.read_10x_vdj(
    "/content/drive/My Drive/RNAseq_folder/n1_filtered_contig_annotations.csv"
)
adata_tcr_1.shape
adatas_tcr['1'] = adata_tcr_1

In [None]:
# Load the TCR data. Duplicate this step for each TCR sample
adata_tcr_2 = ir.io.read_10x_vdj(
    "/content/drive/My Drive/RNAseq_folder/n2_filtered_contig_annotations.csv"
)
adata_tcr_2.shape
adatas_tcr['2'] = adata_tcr_2

In [None]:
# Stack all per-sample TCR objects collected in `adatas_tcr` into a single AnnData.
# NOTE(review): `index_unique` is left commented out, so a barcode present in several
# samples presumably ends up duplicated in obs_names — confirm this is acceptable for
# the barcode matching done further down.
adata_tcr = anndata.concat(adatas_tcr)#, index_unique="_")

In [None]:
adata_tcr

In [None]:
#Load scRNAseq data.
adata_1 = sc.read_10x_h5('/content/drive/My Drive/RNAseq_folder/1_sample_filtered_feature_bc_matrix.h5')
adata_1.obs['your_condition'] = '1'
adata_1.var_names_make_unique()
adata_1
adatas_gex['1'] = adata_1

In [None]:
# Load the scRNASeq data. Duplicate this step for each scRNASeq sample
adata_3 = sc.read_10x_h5('/content/drive/My Drive/RNAseq_folder/3_sample_filtered_feature_bc_matrix.h5')
adata_3.obs['your_condition'] = '3'
adata_3.var_names_make_unique()
adata_3
adatas_gex['3'] = adata_3

In [None]:
# Stack all per-sample expression objects collected in `adatas_gex` into one AnnData.
# `index_unique="_"` is passed so that barcodes are made unique using the dictionary keys.
adata_gex = anndata.concat(adatas_gex, index_unique="_")

In [None]:
adata_gex

In [None]:
# Barcodes of the merged expression data; reused by the barcode-matching cell below.
cell_barcodes_adata_gex = adata_gex.obs_names

# Print the first few cell barcodes
print(cell_barcodes_adata_gex[:5])

In [None]:
# Barcodes of the merged TCR data, shown for comparison with the expression barcodes.
cell_barcodes_adata_tcr = adata_tcr.obs_names

# Print the first few cell barcodes
print(cell_barcodes_adata_tcr[:5])

In [None]:
# Harmonise the cell barcodes of `adata_tcr` with those of `adata_gex`, then keep only
# the cells present in both objects. Requires `adata_gex` and `adata_tcr` to be loaded.
#
# `adata_gex` barcodes carry the "_<sample key>" suffix added by
# `anndata.concat(..., index_unique="_")`, while `adata_tcr` barcodes do not.
# The expression barcodes are indexed once by their suffix-free part, instead of being
# scanned in full for every TCR barcode (which is quadratic in the number of cells).
gex_barcodes = adata_gex.obs_names
known_gex_barcodes = set(gex_barcodes)
gex_barcodes_by_base = {}
for gex_barcode in gex_barcodes:
    # 'AAACCTGCAGACGTAG-1_3' -> 'AAACCTGCAGACGTAG-1'; names without '_' are kept whole.
    base_barcode = gex_barcode.rsplit("_", 1)[0]
    gex_barcodes_by_base.setdefault(base_barcode, []).append(gex_barcode)

# Map each TCR barcode to the expression barcode it corresponds to.
barcode_mapping = {}
for cell_barcode_tcr in adata_tcr.obs_names.unique():
    if cell_barcode_tcr in known_gex_barcodes:
        # Already in the expression format (e.g. when this cell is re-run).
        continue
    matching_gex_barcodes = gex_barcodes_by_base.get(cell_barcode_tcr, [])
    if len(matching_gex_barcodes) == 1:
        # Only unambiguous matches are used; a barcode found in several expression
        # samples cannot be assigned to one of them and is left unmapped.
        barcode_mapping[cell_barcode_tcr] = matching_gex_barcodes[0]

# Rename the TCR barcodes; unmapped barcodes keep their original name.
adata_tcr.obs_names = [barcode_mapping.get(cell_barcode, cell_barcode) for cell_barcode in adata_tcr.obs_names]

# The same barcode coming from two TCR samples is renamed to the same expression barcode.
n_duplicated_tcr = int(adata_tcr.obs_names.duplicated().sum())
if n_duplicated_tcr:
    warnings.warn(
        f"{n_duplicated_tcr} TCR barcodes are duplicated after matching; "
        "check that TCR and expression samples use consistent sample keys."
    )

# Extract the common barcodes between adata_gex and adata_tcr
common_barcodes = set(adata_gex.obs_names).intersection(adata_tcr.obs_names)

# Keep only the shared cells. `.copy()` turns the sliced views into standalone objects,
# since the preprocessing cells below modify the data in place.
adata_tcr = adata_tcr[adata_tcr.obs_names.isin(common_barcodes)].copy()
adata_gex = adata_gex[adata_gex.obs_names.isin(common_barcodes)].copy()

# Check that the cell barcodes in adata_tcr now match the format in adata_gex
print(adata_tcr.obs_names[:20])

In [None]:
mdata = mu.MuData({"gex": adata_gex, "airr": adata_tcr})

In [None]:
mdata.obs["your_condition"] = adata_gex.obs["your_condition"]

In [None]:
mdata

In [None]:
# First-pass embedding of the expression modality, followed by receptor chain indexing and QC.
# NOTE(review): log1p is applied to the same matrix again in later cells of this
# notebook — confirm that the repeated log-transformation is intended.
sc.pp.log1p(mdata["gex"])
sc.pp.pca(mdata["gex"], svd_solver="arpack")
sc.pp.neighbors(mdata["gex"])
sc.tl.umap(mdata["gex"])
ir.pp.index_chains(mdata)
ir.tl.chain_qc(mdata)

In [None]:
# Three UMAP panels side by side: CD3E (left), receptor type (middle), condition (right).
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 4), gridspec_kw={"wspace": 0.5})
mu.pl.embedding(mdata, basis="gex:umap", color="CD3E", ax=ax0, show=False) # Replace by Cd3e if murine data.
mu.pl.embedding(mdata, basis="gex:umap", color="your_condition", ax=ax2, show=False)
# Called last and without show=False — presumably this is the call that renders the figure.
mu.pl.embedding(mdata, basis="gex:umap", color="airr:receptor_type", ax=ax1)

In [None]:
%%time
# Batch-aware neighbour graph, using the condition as the batch key.
sc.external.pp.bbknn(mdata["gex"], batch_key="your_condition")

In [None]:
# Recompute the UMAP after the neighbour graph was rebuilt by BBKNN.
sc.tl.umap(mdata["gex"])

In [None]:
# UMAP coloured by condition.
sc.pl.umap(mdata["gex"], color=["your_condition"])

In [None]:
sc.pp.filter_cells(mdata["gex"], min_genes=200)
sc.pp.filter_genes(mdata["gex"], min_cells=3)

In [None]:
mdata["gex"].var['mt'] = mdata["gex"].var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes as 'mt'
sc.pp.calculate_qc_metrics(mdata["gex"], qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

In [None]:
sc.pl.violin(mdata["gex"], ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'],
             jitter=0.4, multi_panel=True)

In [None]:
sc.pp.filter_genes(mdata["gex"], min_cells=10)
sc.pp.filter_cells(mdata["gex"], min_genes=100)

In [None]:
# Per-cell normalisation, log-transformation, highly variable genes, PCA and neighbours.
# NOTE(review): `sc.pp.normalize_per_cell` is scanpy's legacy normalisation function, and
# the matrix is normalised and log-transformed again in the cells below — confirm that
# this repetition is intended.
sc.pp.normalize_per_cell(mdata["gex"])
sc.pp.log1p(mdata["gex"])
sc.pp.highly_variable_genes(mdata["gex"], flavor="cell_ranger", n_top_genes=5000)
sc.tl.pca(mdata["gex"])
sc.pp.neighbors(mdata["gex"])

In [None]:
# QC scatter plots: total counts against mitochondrial fraction and against detected genes.
sc.pl.scatter(mdata["gex"], x='total_counts', y='pct_counts_mt')
sc.pl.scatter(mdata["gex"], x='total_counts', y='n_genes_by_counts')

In [None]:
# Normalise every cell to a total of 10,000.
sc.pp.normalize_total(mdata["gex"], target_sum=10000)

In [None]:
sc.pp.log1p(mdata["gex"])

In [None]:
# Flag highly variable genes using mean and dispersion thresholds.
sc.pp.highly_variable_genes(mdata["gex"], min_mean=0.0125, max_mean=3, min_disp=0.5)

In [None]:
sc.pl.highly_variable_genes(mdata["gex"])

In [None]:
# Keep the current (unscaled) state of the object in `.raw` before scaling.
mdata["gex"].raw = mdata["gex"]

In [None]:
# Scale the matrix, with max_value=10 passed as the clipping limit.
sc.pp.scale(mdata["gex"], max_value=10)

In [None]:
#sc.pp.pca(mdata["gex"])
%%time
sc.external.pp.bbknn(mdata["gex"], batch_key="your_condition")
#sc.pp.neighbors(mdata["gex"])
sc.tl.umap(mdata["gex"])
sc.pl.umap(mdata["gex"], color='your_condition')

In [None]:
# Differential expression between conditions, then a plot of the top-ranked genes.
sc.tl.rank_genes_groups(mdata["gex"], groupby="your_condition")
sc.pl.rank_genes_groups(mdata["gex"])

In [None]:
# Recompute the PCA on the current (scaled) matrix.
sc.tl.pca(mdata["gex"], svd_solver='arpack')

In [None]:
sc.pl.pca(mdata["gex"], color='CD3E') # Replace by Cd3e if murine data.

In [None]:
# Variance explained per principal component (log scale); helps choosing `n_pcs` below.
sc.pl.pca_variance_ratio(mdata["gex"], log=True)

In [None]:
# NOTE(review): this presumably replaces the neighbour graph built by BBKNN in the
# previous section — confirm that dropping the batch-aware graph is intended.
sc.pp.neighbors(mdata["gex"], n_neighbors=10, n_pcs=40)

In [None]:
sc.tl.umap(mdata["gex"])

In [None]:
# UMAP coloured by T-cell marker genes (use murine gene symbols for mouse data).
sc.pl.umap(mdata["gex"], color=['CD3E','CD4','CD8A'])

In [None]:
# Leiden clustering; the cluster labels are read from obs['leiden'] in the cells below.
sc.tl.leiden(mdata["gex"])

In [None]:
sc.pl.umap(mdata["gex"], color=['your_condition','leiden'])

In [None]:
# Marker genes per Leiden cluster (Wilcoxon test), showing the top 25 genes per cluster.
sc.tl.rank_genes_groups(mdata["gex"], 'leiden', method='wilcoxon')
sc.pl.rank_genes_groups(mdata["gex"], n_genes=25, sharey=False)

In [None]:
marker_genes = ['CD3E','CD4','CD8A']

In [None]:
pd.DataFrame(mdata["gex"].uns['rank_genes_groups']['names']).head(5)

In [None]:
result = mdata["gex"].uns['rank_genes_groups']
groups = result['names'].dtype.names
pd.DataFrame(
    {group + '_' + key[:1]: result[key][group]
    for group in groups for key in ['names', 'pvals']}).head(5)

In [None]:
sc.pl.violin(mdata["gex"], marker_genes, groupby='leiden')

In [None]:
# Human-readable cluster labels: one per Leiden cluster, in category order.
# Placeholders are generated to match the number of clusters actually found, so this
# cell does not fail when Leiden returns a number of clusters other than 13.
# Replace the placeholders with your own annotations, for example:
#   new_T_cluster_names = ['CD8 effector', 'CD4 naive', ...]
n_leiden_clusters = len(mdata["gex"].obs['leiden'].cat.categories)
new_T_cluster_names = [f'Your_cluster_{i}' for i in range(1, n_leiden_clusters + 1)]
mdata["gex"].rename_categories('leiden', new_T_cluster_names)

In [None]:
sc.pl.umap(mdata["gex"], color='leiden', legend_loc='on data', title='', frameon=False, legend_fontsize=8, legend_fontoutline=2, save='.png')

In [None]:
sc.pl.umap(mdata["gex"], color=['leiden'])

In [None]:
sc.pl.rank_genes_groups_dotplot(mdata["gex"], n_genes=4, values_to_plot='logfoldchanges', min_logfoldchange=3, vmax=7, vmin=-7, cmap='bwr')

In [None]:
sc.pl.dotplot(mdata["gex"], marker_genes, groupby='leiden', title='');

In [None]:
from matplotlib.pyplot import rc_context

In [None]:
with rc_context({'figure.figsize': (9, 1.5)}):
    sc.pl.rank_genes_groups_violin(mdata["gex"], n_genes=20, jitter=False, save='.png')

In [None]:
sc.tl.embedding_density(mdata["gex"], groupby='your_condition')

In [None]:
sc.pl.embedding_density(mdata["gex"], groupby='your_condition')

In [None]:
cluster_counts = pd.Series(mdata["gex"].obs['leiden']).value_counts()
cluster_percentages = cluster_counts / cluster_counts.sum() * 100
cluster_percentages

In [None]:
cumulative_percentages = cluster_percentages.cumsum()

# Create a bar plot of the percentages
fig, ax = plt.subplots()
ax.bar(cluster_percentages.index, cluster_percentages.values, label='Percentages')
ax.set_xlabel('Cluster')
plt.xticks(rotation=90)
ax.set_ylabel('Percentage')
ax.set_title('Cluster Percentages')
ax.legend()

# Create a line plot of the cumulative percentages
ax2 = ax.twinx()
ax2.plot(cluster_percentages.index, cumulative_percentages.values, color='red', marker='o', label='Cumulative Percentages')
ax2.set_ylabel('Cumulative Percentage')
ax2.legend(loc='upper right')

# Show the plot
plt.show()

In [None]:
group_column = "your_condition"
cluster_column = "leiden"
group_counts = mdata["gex"].obs.groupby([group_column, cluster_column]).size()
group_percentages = group_counts / group_counts.groupby(level=0).transform(sum) * 100
group_percentages = group_percentages.reset_index(name="Percentage")

# Print the resulting table
print(group_percentages)

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
group_percentages.pivot(index=cluster_column, columns=group_column, values="Percentage").plot(kind="bar", ax=ax)

# Set plot properties
ax.set_xlabel("Cluster")
ax.set_ylabel("Percentage")
ax.set_title("Percentage of Cells in Each Cluster for Different Groups")
ax.legend(title=group_column)

# Display the plot
plt.show()

In [None]:
# Calculate the percentage of cells in each cluster for different groups
group_column = "your_condition"
cluster_column = "leiden"
# `observed=False` keeps every (condition, cluster) combination, which is what pandas
# has done by default so far; stating it avoids the pandas deprecation warning.
group_counts = mdata["gex"].obs.groupby([group_column, cluster_column], observed=False).size()
# The aggregation is named as a string because passing the builtin `sum` to `transform`
# is deprecated.
group_percentages = group_counts / group_counts.groupby(level=0, observed=False).transform("sum") * 100
group_percentages = group_percentages.reset_index(name="Percentage")

# Stacked bars: one bar per condition, one coloured segment per cluster.
fig, ax = plt.subplots(figsize=(10, 6))
group_percentages.pivot(index=group_column, columns=cluster_column, values="Percentage").plot(kind="bar", stacked=True, ax=ax)

# Set plot properties
ax.set_xlabel("Group")
ax.set_ylabel("Percentage of Clusters")
ax.set_title("Percentage of Cells in Each Cluster for Different Groups")
# The legend entries are the clusters (the pivot columns), so it is titled accordingly.
ax.legend(title=cluster_column, bbox_to_anchor=(1.05, 1), loc="upper left")

# Display the plot
plt.show()

In [None]:
# Calculate the number of cells in each cluster for each group
group_sizes = mdata["gex"].obs.groupby(['your_condition', 'leiden'], observed=False).size()
# Wide table (rows = conditions, columns = clusters).
cluster_counts = group_sizes.unstack().fillna(0)
# Long table with one row per (condition, cluster) pair, as needed for grouped bars.
cluster_counts_long = group_sizes.reset_index(name='n_cells')

# Set the figure size
plt.figure(figsize=(10, 6))

# One bar per cluster and per condition. Passing the wide table directly would instead
# draw, for every cluster, the mean count over the conditions with an error bar.
sns.barplot(data=cluster_counts_long, x='leiden', y='n_cells', hue='your_condition', palette='viridis')

# Set the plot labels and title
plt.xlabel('Cluster')
plt.ylabel('Number of Cells')
plt.title('Number of Cells in Each Cluster for Each Group')

# Rotate the x-axis labels if needed
plt.xticks(rotation=90)

# Show the plot
plt.show()

In [None]:
# Cell counts as a table: one row per condition, one column per cluster.
cluster_counts = (
    mdata["gex"].obs
    .groupby(['your_condition', 'leiden'])
    .size()
    .unstack()
    .fillna(0)
)

plt.figure(figsize=(10, 6))

# Annotated heatmap of the counts.
sns.heatmap(
    cluster_counts,
    annot=True,
    fmt='g',
    cmap='YlGnBu',
)

plt.xlabel('Cluster')
plt.ylabel('Group')
plt.title('Number of Cells\nComparison Between Groups for Each Cluster')

plt.show()

In [None]:
# Create the output folders for the TCR analysis.
# A `!cd` shell command runs in its own subshell and does not change the notebook's
# working directory, so the folders are created with explicit paths instead.
os.makedirs(os.path.join('dataTCR', 'write'), exist_ok=True)
# The working directory is unchanged, so relative 'write/...' paths still resolve here.
os.makedirs('write', exist_ok=True)

In [None]:
resultsmdataTCR_file = 'write/mdataTCRs.h5ad'  # the file that will store the analysis results

In [None]:
mdata["gex"].obsm["X_umap_TCR"] = mdata["gex"].obsm["X_umap"]

In [None]:
mu.pl.embedding(
    mdata,
    basis="gex:umap",
    color=["gex:your_condition"],
    ncols=3,
    wspace=0.7,
)
mu.pl.embedding(
    mdata,
    basis="gex:umap",
    color=["CD8A", "CD4"],
    ncols=3,
    wspace=0.7,
)

In [None]:
ir.pp.index_chains(mdata)

In [None]:
ir.tl.chain_qc(mdata)

In [None]:
_ = ir.pl.group_abundance(
    mdata, groupby="airr:receptor_subtype", target_col="gex:your_condition"
)

In [None]:
_ = ir.pl.group_abundance(mdata, groupby="airr:chain_pairing", target_col="gex:your_condition")

In [None]:
print(
    "Fraction of cells with more than one pair of TCRs: {:.2f}".format(
        np.sum(
            mdata.obs["airr:chain_pairing"].isin(
                ["extra VJ", "extra VDJ", "two full chains", "multichain"]
            )
        )
        / mdata["airr"].n_obs
    )
)

In [None]:
mu.pl.embedding(
    mdata, basis="gex:umap", color="airr:chain_pairing", groups="multichain"
)

In [None]:
mu.pp.filter_obs(mdata, "airr:chain_pairing", lambda x: x != "multichain")

In [None]:
mu.pp.filter_obs(
    mdata, "airr:chain_pairing", lambda x: ~np.isin(x, ["orphan VDJ", "orphan VJ"])
)

In [None]:
mdata

In [None]:
ax = ir.pl.group_abundance(mdata, groupby="airr:chain_pairing", target_col="gex:your_condition")

In [None]:
# Define clonotypes from the receptor sequences.
# using default parameters, `ir_dist` will compute nucleotide sequence identity
ir.pp.ir_dist(mdata)
ir.tl.define_clonotypes(mdata, receptor_arms="all", dual_ir="primary_only")

In [None]:
# Lay out the clonotype network, restricted with min_cells=2.
ir.tl.clonotype_network(mdata, min_cells=2)

In [None]:
# Number of cells per condition (missing values are counted as their own group).
mdata.obs.groupby("gex:your_condition", dropna=False).size()

In [None]:
# Clonotype network coloured by condition.
_ = ir.pl.clonotype_network(
    mdata, color="gex:your_condition", base_size=20, label_fontsize=9, panel_size=(7, 7)
)

In [None]:
# Sequence distances on the amino-acid level with the alignment metric (cutoff 15).
ir.pp.ir_dist(
    mdata,
    metric="alignment",
    sequence="aa",
    cutoff=15,
)

In [None]:
# Clonotype clusters based on the amino-acid alignment distances computed above.
ir.tl.define_clonotype_clusters(
    mdata, sequence="aa", metric="alignment", receptor_arms="all", dual_ir="any"
)

In [None]:
# Network layout for the clonotype clusters, restricted with min_cells=3.
ir.tl.clonotype_network(mdata, min_cells=3, sequence="aa", metric="alignment")

In [None]:
# Clonotype cluster network coloured by condition.
_ = ir.pl.clonotype_network(
    mdata, color="gex:your_condition", label_fontsize=9, panel_size=(7, 7), base_size=20
)

In [None]:
# CDR3 amino-acid sequences found in clonotype cluster "169" (replace "169" with your
# cluster of interest), with the number of cells per combination of sequences.
with ir.get.airr_context(mdata, "junction_aa", ["VJ_1", "VDJ_1", "VJ_2", "VDJ_2"]):
    cdr3_ct_169 = (
        # TODO astype(str) is required due to a bug in pandas ignoring `dropna=False`. It seems fixed in pandas 2.x
        mdata.obs.loc[lambda x: x["airr:cc_aa_alignment"] == "169"]
        .astype(str)
        .groupby(
            [
                "VJ_1_junction_aa",
                "VDJ_1_junction_aa",
                "VJ_2_junction_aa",
                "VDJ_2_junction_aa",
                "airr:receptor_subtype",
            ],
            observed=True,
            dropna=False,
        )
        .size()
        .reset_index(name="n_cells")
    )
cdr3_ct_169

In [None]:
# Same clustering as before with same_v_gene=True, stored under a separate key.
ir.tl.define_clonotype_clusters(
    mdata,
    sequence="aa",
    metric="alignment",
    receptor_arms="all",
    dual_ir="any",
    same_v_gene=True,
    key_added="cc_aa_alignment_same_v",
)

In [None]:
# find clonotypes with more than one `clonotype_same_v`
ct_different_v = mdata.obs.groupby("airr:cc_aa_alignment").apply(
    lambda x: x["airr:cc_aa_alignment_same_v"].nunique() > 1
)
# Keep only the cluster ids for which the condition above is True.
ct_different_v = ct_different_v[ct_different_v].index.values.tolist()
ct_different_v

In [None]:
# V genes of the clusters found above: one row per distinct combination.
with ir.get.airr_context(mdata, "v_call", ["VJ_1", "VDJ_1"]):
    ct_different_v_df = (
        mdata.obs.loc[
            lambda x: x["airr:cc_aa_alignment"].isin(ct_different_v),
            [
                "airr:cc_aa_alignment",
                "airr:cc_aa_alignment_same_v",
                "VJ_1_v_call",
                "VDJ_1_v_call",
            ],
        ]
        .sort_values("airr:cc_aa_alignment")
        .drop_duplicates()
        .reset_index(drop=True)
    )
ct_different_v_df

In [None]:
# Clonal expansion; the resulting columns are plotted on the UMAP in the next cell.
ir.tl.clonal_expansion(mdata)

In [None]:
mu.pl.embedding(
    mdata, basis="gex:umap", color=["airr:clonal_expansion", "airr:clone_id_size"]
)

In [None]:
# Clonal expansion per Leiden cluster, not normalised (normalize=False).
_ = ir.pl.clonal_expansion(
    mdata, target_col="clone_id", groupby="gex:leiden", clip_at=4, normalize=False
)

In [None]:
# Same plot with the default settings.
ir.pl.clonal_expansion(mdata, target_col="clone_id", groupby="gex:leiden")

In [None]:
# Clonotype diversity per Leiden cluster (normalised Shannon entropy).
_ = ir.pl.alpha_diversity(
    mdata, metric="normalized_shannon_entropy", groupby="gex:leiden"
)

In [None]:
# Clonotypes (limited with max_cols=10) and their distribution over the Leiden clusters.
_ = ir.pl.group_abundance(
    mdata, groupby="airr:clone_id", target_col="gex:leiden", max_cols=10
)

In [None]:
# Same plot, normalised using the condition column.
_ = ir.pl.group_abundance(
    mdata,
    groupby="airr:clone_id",
    target_col="gex:leiden",
    max_cols=10,
    normalize="gex:your_condition",
)

In [None]:
# Clonotypes (limited with max_cols=15) per condition, normalised using the condition column.
# The positional `mdata` argument has to come before the keyword arguments; placing it
# after `normalize=...` is a SyntaxError.
_ = ir.pl.group_abundance(
    mdata,
    groupby="airr:clone_id",
    target_col="gex:your_condition",
    normalize="gex:your_condition",
    max_cols=15,
    figsize=(5, 3),
)

In [None]:
# Clonotypes (limited with max_cols=15) per condition, without normalisation.
_ = ir.pl.group_abundance(
    mdata,
    groupby="airr:clone_id",
    target_col="gex:your_condition",
    max_cols=15,
    figsize=(5, 3),
)

In [None]:
# Compare the coarse clonotype clusters with the fine clonotypes; the result is
# plotted from the `airr:is_convergent` column in the next cell.
ir.tl.clonotype_convergence(mdata, key_coarse="cc_aa_alignment", key_fine="clone_id")

In [None]:
mu.pl.embedding(mdata, "gex:umap", color="airr:is_convergent")

In [None]:
# V gene of the first VJ chain (limited with max_cols=10) per Leiden cluster, normalised.
with ir.get.airr_context(mdata, "v_call"):
    ir.pl.group_abundance(
        mdata,
        groupby="VJ_1_v_call",
        target_col="gex:leiden",
        normalize=True,
        max_cols=10,
    )

In [None]:
# V(D)J segment usage over all cells, limited to 30 ribbons.
_ = ir.pl.vdj_usage(
    mdata,
    full_combination=False,
    max_segments=None,
    max_ribbons=30,
    fig_kws={"figsize": (8, 5)},
)

In [None]:
# V(D)J segment usage restricted to selected clonotypes.
ir.pl.vdj_usage(
    mdata[mdata.obs["airr:clone_id"].isin(["183", "186", "187", "185", "9", "184", "27", "23", "8", "10"]), :], #Replace with your clones of interest.
    max_ribbons=None,
    max_segments=100,
)

In [None]:
# Spectratype coloured by Leiden cluster, drawn as bars.
ir.pl.spectratype(mdata, color="gex:leiden", viztype="bar", fig_kws={"dpi": 120})

In [None]:
# Same spectratype drawn as shifted curves.
ir.pl.spectratype(
    mdata,
    color="gex:leiden",
    viztype="curve",
    curve_layout="shifted",
    fig_kws={"dpi": 120},
    kde_kws={"kde_norm": False},
)

In [None]:
# Spectratype of the first VDJ chain for selected V genes (human TRBV genes; replace
# with your genes of interest), coloured by V gene.
# Normalisation uses `your_condition`, the only sample annotation created in this
# notebook; the previously referenced `gex:patient` column is never defined here.
with ir.get.airr_context(mdata, "v_call"):
    ir.pl.spectratype(
        mdata[
            mdata.obs["VDJ_1_v_call"].isin(
                ["TRBV20-1", "TRBV7-2", "TRBV28", "TRBV5-1", "TRBV7-9"]
            ),
            :,
        ],
        chain="VDJ_1",
        color="VDJ_1_v_call",
        normalize="gex:your_condition",
        fig_kws={"dpi": 120},
    )

In [None]:
# Repertoire overlap between conditions; with inplace=False the results are returned.
df, dst, lk = ir.tl.repertoire_overlap(mdata, "gex:your_condition", inplace=False)
df.head()

In [None]:
# Overlap plot for one pair of conditions.
ir.pl.repertoire_overlap(
    mdata, "gex:your_condition", pair_to_plot=["1", "3"], fig_kws={"dpi": 120} #Replace by your condition names.
)

In [None]:
# Clonotype modularity for the amino-acid clonotype clusters; the result is plotted
# from the `airr:clonotype_modularity` column in the next cell.
ir.tl.clonotype_modularity(mdata, target_col="airr:cc_aa_alignment")

In [None]:
mu.pl.embedding(mdata, basis="gex:umap", color="airr:clonotype_modularity")

In [None]:
# Clonotype network coloured by modularity.
_ = ir.pl.clonotype_network(
    mdata,
    color="clonotype_modularity",
    label_fontsize=9,
    panel_size=(6, 6),
    base_size=20,
)

In [None]:
ir.pl.clonotype_modularity(mdata, base_size=20)

In [None]:
clonotypes_top_modularity = list(
    mdata.obs.set_index("airr:cc_aa_alignment")["airr:clonotype_modularity"]
    .sort_values(ascending=False)
    .index.unique()
    .values[:20]
)

In [None]:
test_ad = mu.pl.embedding(
    mdata,
    basis="gex:umap",
    color="airr:cc_aa_alignment",
    groups=clonotypes_top_modularity,
    palette=cycler(color=mpl_cm.Dark2_r.colors),
)

In [None]:
# Since sc.tl.rank_genes_group does not support MuData, we need to temporarily add
# the AIRR columns to the gene expression AnnData object
with ir.get.obs_context(
    mdata["gex"], {"cc_aa_alignment": mdata.obs["airr:cc_aa_alignment"]}
) as tmp_ad:
    # Each selected clonotype cluster is tested against all remaining cells.
    sc.tl.rank_genes_groups(
        tmp_ad,
        "cc_aa_alignment",
        groups=clonotypes_top_modularity,
        reference="rest",
        method="wilcoxon",
    )
    # Only two panels are created, and `zip` stops at the shorter input, so only the
    # first two clonotype clusters of the list are plotted.
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    for ct, ax in zip(clonotypes_top_modularity, axs):
        sc.pl.rank_genes_groups_violin(
            tmp_ad, groups=[ct], n_genes=15, ax=ax, show=False, strip=False
        )

In [None]:
# Clonotypes that differ in abundance between two Leiden clusters.
# `your_condition` is used as the replicate column because it is the only sample
# annotation created in this notebook (`gex:patient` is never defined here).
# The case/control labels must be existing cluster names; replace the placeholders
# below with your clusters of interest.
freq, stat = ir.tl.clonotype_imbalance(
    mdata,
    replicate_col="gex:your_condition",
    groupby="gex:leiden",
    case_label="Your_cluster_1",
    control_label="Your_cluster_2",
    inplace=False,
)
# First ten clonotypes of the returned statistics table.
top_differential_clonotypes = stat["clone_id"].tolist()[:10]

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4), gridspec_kw={"wspace": 0.6})
mu.pl.embedding(mdata, basis="gex:umap", color="gex:leiden", ax=ax1, show=False)
mu.pl.embedding(
    mdata,
    basis="gex:umap",
    color="airr:clone_id",
    groups=top_differential_clonotypes,
    ax=ax2,
    # increase size of highlighted dots
    size=[
        80 if c in top_differential_clonotypes else 30
        for c in mdata.obs["airr:clone_id"][mdata.mod["gex"].obs_names]
    ],
    palette=cycler(color=mpl_cm.Dark2_r.colors),
)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4), gridspec_kw={"wspace": 0.6})
mu.pl.embedding(mdata, basis="gex:umap", color="gex:leiden", ax=ax1, show=False)
mu.pl.embedding(
    mdata,
    basis="gex:umap",
    color="airr:clone_id",
    groups=['9','1','19','77'], #Replace with your clones of interest.
    ax=ax2,
    # increase size of highlighted dots
    size=[
        80 for c in mdata.obs["airr:clone_id"][mdata.mod["gex"].obs_names]
    ],
    palette=cycler(color=mpl_cm.Dark2_r.colors),
)

In [None]:
# Repertoire overlap between two Leiden clusters.
# The labels must match the cluster names exactly, including case; the names assigned
# earlier in this notebook follow the pattern 'Your_cluster_N'.
_ = ir.pl.repertoire_overlap(
    mdata, "gex:leiden", pair_to_plot=["Your_cluster_1", "Your_cluster_2"], fig_kws={"dpi": 120} #Replace with your clusters of interest.
)

In [None]:
# Genes ranked for clonotype "93" (replace with your clone of interest).
# The clonotype ids are attached to the expression object only inside this context.
clone_id_column = {"clone_id": mdata.obs["airr:clone_id"]}
with ir.get.obs_context(mdata["gex"], clone_id_column) as tmp_ad:
    sc.tl.rank_genes_groups(
        tmp_ad,
        "clone_id",
        groups=["93"],
        method="wilcoxon",
    )
    sc.pl.rank_genes_groups_violin(
        tmp_ad,
        groups="93",
        n_genes=15,
    )

In [None]:
# Load the VDJdb reference dataset shipped with scirpy.
vdjdb = ir.datasets.vdjdb()

In [None]:
# Distances between the receptors of this dataset and the reference
# (identity metric on amino-acid sequences).
ir.pp.ir_dist(mdata, vdjdb, metric="identity", sequence="aa")

In [None]:
# Match the cells against the reference using the distances computed above.
ir.tl.ir_query(
    mdata,
    vdjdb,
    metric="identity",
    sequence="aa",
    receptor_arms="any",
    dual_ir="any",
)

In [None]:
# Table of matches with the antigen species and gene columns of the reference.
ir.tl.ir_query_annotate_df(
    mdata,
    vdjdb,
    metric="identity",
    sequence="aa",
    include_ref_cols=["antigen.species", "antigen.gene"],
).tail()

In [None]:
# Annotate the cells with the antigen species, using the "most-frequent" strategy.
ir.tl.ir_query_annotate(
    mdata,
    vdjdb,
    metric="identity",
    sequence="aa",
    include_ref_cols=["antigen.species"],
    strategy="most-frequent",
)

In [None]:
# UMAP coloured by the annotated antigen species.
mu.pl.embedding(mdata, "gex:umap", color="airr:antigen.species")