## scRNAseq CD8 Tm high salt and low salt preprocessing and analysis

Author: Maha Alissa Alkhalaf

Figures: Figure 3 (C-E), Figure 4 (M, O), Extended Data 10 (A, C-H)

In [None]:
import scanpy as sc
import scrublet as scr
import numpy as np
import gseapy as gp
import pandas as pd
import seaborn as sns
import scipy as sci
import matplotlib.pyplot as plt
import scipy.stats as stats
import celltypist
from pathlib import Path
import os
import re
from statannot import add_stat_annotation

## Preprocessing

In [None]:
# Load the low-salt and high-salt samples, one AnnData file per condition.
low_salt = sc.read_h5ad('../data/salt_data/low_salt.h5ad')
high_salt = sc.read_h5ad('../data/salt_data/high_salt.h5ad')

In [None]:
# Tag every cell with its condition. The plotting helpers further down
# split cells on this `Condition` column.
low_salt.obs['Condition'] = 'low salt'
high_salt.obs['Condition'] = 'high salt'

In [None]:
# Doublet detection on the high-salt sample, run on the matrix stored in `.raw`
# (presumably the unnormalised counts — TODO confirm).
# scrub_doublets() returns a pair that is stored per cell as the doublet score
# and the doublet call.
scrub = scr.Scrublet(high_salt.raw.X)
high_salt.obs['doublet_scores'], high_salt.obs['predicted_doublets'] = scrub.scrub_doublets()

In [None]:
# Same doublet detection for the low-salt sample; `scrub` is rebound, so the
# high-salt Scrublet object is no longer reachable after this cell.
scrub = scr.Scrublet(low_salt.raw.X)
low_salt.obs['doublet_scores'], low_salt.obs['predicted_doublets'] = scrub.scrub_doublets()

In [None]:
# Keep only the cells whose doublet call is False.
low_salt = low_salt[low_salt.obs['predicted_doublets'] == False, :]
high_salt = high_salt[high_salt.obs['predicted_doublets'] == False, :]

In [None]:
# Stack both samples into a single object; `label = 'dataset'` names the obs
# column that records which input object each cell came from.
adata = sc.concat([low_salt, high_salt], label = 'dataset')
adata

In [None]:
# Remove cells with fewer than 200 detected genes.
sc.pp.filter_cells(adata, min_genes = 200)

In [None]:
# Remove genes detected in fewer than 3 cells, then display the filtered object.
sc.pp.filter_genes(adata, min_cells = 3)
adata

In [None]:
# Scale every cell to a total of 10,000 and apply log1p.
# NOTE(review): doublet detection above read `.raw.X`, not `.X`; if `.X` of the
# input files was already normalised or log-transformed, this step transforms
# it a second time — confirm what the input files contain.
sc.pp.normalize_total(adata, target_sum = 1e4)
sc.pp.log1p(adata)

In [None]:
# Keep a copy of the normalised, log-transformed data in `.raw` before `.X`
# is modified in the next cell.
adata.raw = adata

In [None]:
# Replace NaN entries of the expression matrix with 0.
# NOTE(review): np.nan_to_num also replaces +/-inf with very large finite
# numbers by default, and this call assumes `adata.X` is a dense array —
# confirm both are intended.
adata.X = np.nan_to_num(adata.X, nan = 0)

In [None]:
# Save the merged, preprocessed object; the analysis section reloads this file.
adata.write('../data/salt_data/high_and_low_salt.h5ad')

## Module Score

In [None]:
# Reload the object written at the end of the preprocessing section.
adata = sc.read_h5ad('../data/salt_data/high_and_low_salt.h5ad')

In [None]:
def print_boxplot_stats(data, label):
    """Return the box-plot summary statistics of a numeric sample.

    Despite its name, this function prints nothing; it only returns values.

    Parameters
    ----------
    data : array-like
        Numeric sample. It is flattened before the statistics are computed.
    label : str
        Unused. Kept so that existing calls keep working.

    Returns
    -------
    dict
        Keys, in this order: ``min``, ``max``, ``q1``, ``median``, ``q3``,
        ``lower_whisker``, ``upper_whisker``. The whiskers are the most
        extreme observations that lie within 1.5 * IQR of the quartiles,
        i.e. the positions at which a box plot with the default ``whis = 1.5``
        draws them.
    """
    values = np.asarray(data).ravel()

    minimum = np.min(values)
    maximum = np.max(values)
    q1 = np.percentile(values, 25)
    median = np.median(values)
    q3 = np.percentile(values, 75)
    iqr = q3 - q1

    lower_fence = q1 - 1.5 * iqr
    upper_fence = q3 + 1.5 * iqr

    # A whisker ends on a real observation, not on the fence itself: with
    # outliers present the fence is a value that does not occur in the data.
    above_lower_fence = values[values >= lower_fence]
    below_upper_fence = values[values <= upper_fence]

    # The selections are only empty when the quartiles are NaN; fall back to
    # the clipped fence so that NaN propagates instead of raising.
    if above_lower_fence.size:
        lower_whisker = np.min(above_lower_fence)
    else:
        lower_whisker = np.max([minimum, lower_fence])
    if below_upper_fence.size:
        upper_whisker = np.max(below_upper_fence)
    else:
        upper_whisker = np.min([maximum, upper_fence])

    return {'min': minimum,
            'max': maximum,
            'q1': q1,
            'median': median,
            'q3': q3,
            'lower_whisker': lower_whisker,
            'upper_whisker': upper_whisker}
    
def gene_expression(adata, gene, conditions, condition, figure, df):
    """Plot one gene's expression in two groups of cells and log summary statistics.

    Draws a violin plot with an overlaid box plot and strip plot, runs
    Wilcoxon rank-sum tests of the first group against the second with the
    alternatives 'greater' and 'less', saves the figure as a PDF under
    ``../figures/`` and appends two rows of box-plot statistics to ``df``.

    Parameters
    ----------
    adata : AnnData-like
        Object exposing ``X``, ``obs`` and ``var``. Values are read from
        ``adata.X`` as they are.
    gene : str
        Gene name, looked up in ``adata.var.index``.
    conditions : sequence
        Two values of ``adata.obs[condition]``. The first group is recorded
        as 'High NaCl' and the second as 'Low NaCl' in the returned table.
    condition : str
        Name of the ``adata.obs`` column that holds the group labels.
    figure : str
        Figure label stored in the ``figure`` column of the appended rows.
    df : pandas.DataFrame
        Table to extend. It is not modified in place.

    Returns
    -------
    pandas.DataFrame
        ``df`` with one row per group appended (index reset).

    Raises
    ------
    IndexError
        If ``gene`` is not present in ``adata.var.index``.
    """
    # Resolve the gene before any figure is opened so that a failed lookup
    # leaves no empty figure behind. IndexError is the exception type the
    # lookup has always raised for an unknown gene.
    matches = np.where(adata.var.index == gene)[0]
    if matches.size == 0:
        raise IndexError(f'gene {gene!r} not found in adata.var.index')
    gene_index = matches[0]

    def _expression_of(group):
        # Expression of `gene` in the cells whose `condition` equals `group`.
        rows = np.asarray(adata.obs[condition] == group)
        values = adata.X[rows, gene_index]
        if hasattr(values, 'toarray'):
            # Sparse selections have to be densified before flattening.
            values = values.toarray()
        return np.asarray(values).flatten()

    high_salt = _expression_of(conditions[0])
    low_salt = _expression_of(conditions[1])
    groups = [high_salt, low_salt]

    # One row of summary statistics per group.
    summary_rows = []
    for values, label, group_type in ((high_salt, "high", 'High NaCl'),
                                      (low_salt, "low", 'Low NaCl')):
        row = print_boxplot_stats(values, label)
        row['figure'] = figure
        row['n'] = values.shape[0]
        row['type'] = group_type
        row['statistical_test'] = 'one-sided'
        summary_rows.append(pd.DataFrame([row]))
    df = pd.concat([df, *summary_rows], ignore_index=True)

    alternatives = ['greater', 'less']
    p_values = []
    for alternative in alternatives:
        _, p = stats.ranksums(high_salt, low_salt, alternative = alternative)
        p_values.append(f'{alternative} : {p:.4e}')

    fig = plt.figure(figsize = (2, 2.5), dpi = 300)

    colors = ['#ffa37b', '#A7C7E7']
    # set_palette changes the global colour cycle and returns None; the violin
    # and box plots below take their colours from that cycle, so no `palette`
    # argument is passed.
    sns.set_palette(sns.color_palette(colors))

    # NOTE(review): `kws` is carried over unchanged; confirm that the installed
    # seaborn version actually uses it.
    ax = sns.violinplot(data = groups, saturation = 0.9, width = 0.9, linewidth = 0.3, kws = {'linecolor' : 'black'})
    for collection in ax.collections:
        collection.set_edgecolor('black')

    sns.boxplot(data = groups, width = 0.4,
                boxprops = {'zorder': 2, 'edgecolor' : 'black'},
                capprops = {'color' : 'black'},
                whiskerprops = {'color' : 'black'},
                medianprops = {'color' : 'black'},
                showfliers = False,
                linewidth = 0.3,
                ax = ax)

    sns.stripplot(data = groups, color = 'black', ax = ax, size = 0.4)

    ax.set_ylabel('Log-scaled expression value', fontsize = 4)

    # First fix the tick labels at font size 4, then replace their text.
    ax.set_yticklabels(ax.get_yticks(), size = 4)
    ax.set_xticklabels(ax.get_xticklabels(), size = 4)

    labels = [item.get_text() for item in ax.get_yticklabels()]

    ax.set_xticklabels([conditions[0], conditions[1]])
    ax.set_yticklabels([str(round(float(label), 2)) for label in labels])

    ax.set_title(f'Gene: {gene}\nWilcoxon rank sum, p-values:\n {", ".join(p_values)}', fontsize = 4)
    sns.despine()
    plt.savefig(f'../figures/violin_plot_expression_values_of_{gene}_in_{conditions}.pdf', dpi = 300, bbox_inches = 'tight')
    plt.show()
    # Close the figure instead of only clearing it, so repeated calls do not
    # accumulate open figures.
    plt.close(fig)
    return df

## CD8 Tm scRNA-seq salt data

In [None]:
# Load the table of per-figure statistics; the `gene_expression` calls below
# append two rows each to it.
df = pd.read_csv('MAA_figures_info.csv')
df

In [None]:
# SLC7A5 expression, high vs. low salt; rows are logged under figure '4m'.
df = gene_expression(adata, 'SLC7A5', ['high salt', 'low salt'], 'Condition', '4m', df) 

In [None]:
# FABP5 expression, high vs. low salt; rows are logged under figure '4o'.
df = gene_expression(adata, 'FABP5', ['high salt', 'low salt'], 'Condition', '4o', df) 

In [None]:
# ICOS expression, high vs. low salt; rows are logged under 'Supplementary Figure 10a (1)'.
df  = gene_expression(adata, 'ICOS', ['high salt', 'low salt'], 'Condition', 'Supplementary Figure 10a (1)', df) 

In [None]:
# ITGAE expression, high vs. low salt; rows are logged under 'Supplementary Figure 10a (2)'.
df = gene_expression(adata, 'ITGAE', ['high salt', 'low salt'], 'Condition', 'Supplementary Figure 10a (2)', df) 

In [None]:
# PDCD1 expression, high vs. low salt; rows are logged under 'Supplementary Figure 10a (3)'.
df = gene_expression(adata, 'PDCD1', ['high salt', 'low salt'], 'Condition', 'Supplementary Figure 10a (3)', df) 

## Module score for gene sets

In [None]:
# NOTE(review): this re-reads the CSV and rebinds `df`, which discards the rows
# appended by the `gene_expression` calls in the previous section. The only
# `to_csv` in this notebook is in its last cell, so those rows are never
# written — confirm whether this reload is intended.
df = pd.read_csv('MAA_figures_info.csv')
df

In [None]:
def gen_text(row):
    """Print a one-line, legend-style description of one statistics row.

    Parameters
    ----------
    row : mapping
        Must provide the keys ``figure``, ``type``, ``min``, ``q1``,
        ``median``, ``q3``, ``max``, ``lower_whisker`` and ``upper_whisker``.

    Returns
    -------
    None
        The text is written to standard output. The name of the statistical
        test is fixed text and is not taken from ``row``.
    """
    header = f"{row['figure']}: One-tailed wilcoxon rank sum test. {row['type']}: "

    # (text shown in the output, key read from the row), in output order.
    shown_and_key = (
        ('min', 'min'),
        ('q1', 'q1'),
        ('median', 'median'),
        ('q3', 'q3'),
        ('max', 'max'),
        ('lower whisker', 'lower_whisker'),
        ('upper whisker', 'upper_whisker'),
    )
    details = ', '.join(f'{shown}={row[key]}' for shown, key in shown_and_key)

    print(header + details)

In [None]:
# Load the tissue-residency (TRM) gene signatures from selected sheets of the
# workbook. `column_names` keeps the signatures in loading order; a later cell
# pairs it positionally with figure labels.
marker_genes_tissue_residency = {}
column_names = []

for i in [15, 13, 11, 19, 21, 1, 9]:
    # The second positional argument of read_excel selects the sheet.
    trm = pd.read_excel('../data/gene_sets/TRM_Signatures.xlsx', i)
    # The header of the first column is used as the signature's name.
    column_name = trm.columns[0]
    column_names.append(column_name)
    # The first data row of that column is skipped; the rest is the gene list.
    marker_genes_tissue_residency[column_name] = list(trm[column_name])[1:]

In [None]:
def module_score(adata, data_set, data_set_name, figure, df):
    """Score a gene set per cell and compare high-salt with low-salt cells.

    Runs ``sc.tl.score_genes`` on ``adata``, draws a violin plot with an
    overlaid box plot and strip plot of the scores in both conditions, runs
    Wilcoxon rank-sum tests of high salt against low salt (two-sided,
    'greater' and 'less'), saves the figure as a PDF under ``../figures/`` and
    appends two rows of box-plot statistics to ``df``.

    Parameters
    ----------
    adata : AnnData-like
        Must have an ``obs`` column ``Condition`` with the values
        'high salt' and 'low salt'. ``adata.obs['score']`` is overwritten on
        every call.
    data_set : sequence of str
        Gene names passed to ``sc.tl.score_genes``.
    data_set_name : str
        Used in the plot title and in the name of the saved PDF.
    figure : str
        Figure label stored in the ``figure`` column of the appended rows.
    df : pandas.DataFrame
        Table to extend. It is not modified in place.

    Returns
    -------
    pandas.DataFrame
        ``df`` with one row per condition appended (index reset). The rows
        are labelled 'one-sided' even though the two-sided p-value is also
        computed and shown.
    """
    # With no score name given, the result lands in adata.obs['score'], which
    # is read back below.
    sc.tl.score_genes(adata, data_set)

    low_salt = np.array(adata.obs[adata.obs.Condition == 'low salt']['score'])
    high_salt = np.array(adata.obs[adata.obs.Condition == 'high salt']['score'])
    groups = [high_salt, low_salt]

    # One row of summary statistics per condition.
    summary_rows = []
    for values, label, group_type in ((high_salt, "high", 'High NaCl'),
                                      (low_salt, "low", 'Low NaCl')):
        row = print_boxplot_stats(values, label)
        row['figure'] = figure
        row['n'] = values.shape[0]
        row['type'] = group_type
        row['statistical_test'] = 'one-sided'
        summary_rows.append(pd.DataFrame([row]))
    df = pd.concat([df, *summary_rows], ignore_index=True)

    alternatives = ['two-sided', 'greater', 'less']
    p_values = []
    for alternative in alternatives:
        _, p = stats.ranksums(high_salt, low_salt, alternative = alternative)
        p_values.append(f'{alternative} : {p:.4e}')
    print(f'p-values:\n {", ".join(p_values)}')

    fig = plt.figure(figsize = (2, 2.5), dpi = 300)

    colors = ['#ffa37b', '#A7C7E7']
    # set_palette changes the global colour cycle and returns None; the violin
    # and box plots below take their colours from that cycle, so no `palette`
    # argument is passed.
    sns.set_palette(sns.color_palette(colors))

    # NOTE(review): `kws` is carried over unchanged; confirm that the installed
    # seaborn version actually uses it.
    ax = sns.violinplot(data = groups, saturation = 0.9, width = 0.9, linewidth = 0.3, kws = {'linecolor' : 'black'})
    for collection in ax.collections:
        collection.set_edgecolor('black')

    sns.boxplot(data = groups, width = 0.4,
                boxprops = {'zorder': 2, 'edgecolor' : 'black'},
                capprops = {'color' : 'black'},
                whiskerprops = {'color' : 'black'},
                medianprops = {'color' : 'black'},
                showfliers = False,
                linewidth = 0.3,
                ax = ax)

    sns.stripplot(data = groups, color = 'black', ax = ax, size = 0.4)

    ax.set_ylabel('Module score', fontsize = 4)

    # First fix the tick labels at font size 4, then replace their text.
    ax.set_yticklabels(ax.get_yticks(), size = 4)
    ax.set_xticklabels(ax.get_xticklabels(), size = 4)

    labels = [item.get_text() for item in ax.get_yticklabels()]

    ax.set_xticklabels(['High salt', 'Low salt'])
    ax.set_yticklabels([str(round(float(label), 2)) for label in labels])

    ax.set_title(f'{data_set_name}\nWilcoxon rank sum, p-values:\n {", ".join(p_values)}', fontsize=4)
    sns.despine()
    plt.savefig(f'../figures/violin_plot_module_score_of_{data_set_name}_genes_in_high_vs_low_salt.pdf', dpi = 300, bbox_inches = 'tight')
    plt.show()
    # Close the figure instead of only clearing it, so repeated calls do not
    # accumulate open figures.
    plt.close(fig)
    return df

In [None]:
# Gene list for the first module score; its `genes` column is used in the next cell.
cyto1 = pd.read_csv('../data/gene_sets/cyto_list1.csv')
cyto1

In [None]:
# Module score of the cyto_list1 genes; rows are logged under figure '3e'.
# The `genes` column is passed as a pandas Series here, whereas the later
# calls pass plain lists.
df = module_score(adata, cyto1['genes'], 'Cyto_list1 GO:0001916', '3e', df)

In [None]:
# Tissue-residency marker genes scored in the next cell. All symbols are
# written in upper case, like every other gene symbol used in this notebook.
tissure_residency_gustavo = ['XIST', 'UBC', 'LGALS3', 'MT-CO2', 'VIM', 'ANKRD28', 'RGS1', 'RGCC', 'HSPA1B', 'MT-ND4', 'HSP90AB1', 'PPP1R15A']

In [None]:
# Module score of the tissue-residency markers; rows are logged under
# 'Extended data 10c'.
# NOTE(review): the neighbouring calls label their panels
# 'Supplementary Figure 10...' — confirm which naming the table should use.
df  = module_score(adata, tissure_residency_gustavo, 'Tissue residency markers in host vs. donor CD8+ T cells', 'Extended data 10c', df)

In [None]:
# Score every TRM signature loaded earlier. `figures_names` is paired
# positionally with `column_names`, so both lists must stay in the same order
# (and zip silently stops at the shorter one).
figures_names = ['Supplementary Figure 10d (1)', 'Supplementary Figure 10d (2)', 'Supplementary Figure 10d (3)', 'Supplementary Figure 10e (1)', 'Supplementary Figure 10e (2)', 'Supplementary Figure 10f', 'Supplementary Figure 10g']
for column_name, f in zip(column_names, figures_names):
    data_set = marker_genes_tissue_residency[column_name]
    df = module_score(adata, data_set, column_name, f, df)

## GO Pathways

In [None]:
path = Path('../data/gene_sets/')

# Collect the gene sets whose file name contains 'GO-0005125'. Each file is
# read as one gene name per line; the dictionary key is the file name up to
# the first '.', with '_' replaced by ', '.
pathways = {}
for root, dirs, files in os.walk(path):
    for file in files:
        if 'GO-0005125' not in str(file): continue
        print(file)
        # os.walk also descends into sub-directories, so the file has to be
        # opened relative to `root`, not to the top-level `path`.
        with open(Path(root) / file) as f:
            lines = f.readlines()
        # Strip line endings and surrounding whitespace, and drop blank lines
        # so that no empty gene name ends up in the set.
        genes = [line.strip() for line in lines if line.strip()]
        pathway = file.replace('_', ', ').split('.')[0]
        pathways[pathway] = genes

In [None]:
# Score every GO gene set collected above. All of them are logged under the
# same figure label, '3d'.
for pathway in pathways:
    data_set = pathways[pathway]
    print(f'{pathway}, Number of genes: {len(data_set)}')
    df = module_score(adata, data_set, pathway, '3d', df)

In [None]:
# Two exhaustion gene lists; their `genes` columns are scored in the next cell.
exhaustion_list1 = pd.read_csv('../data/gene_sets/exhaustion_list1.csv')
exhaustion_list2 = pd.read_csv('../data/gene_sets/exhaustion_list2.csv')

In [None]:
# Module scores of both exhaustion lists, logged as panels (1) and (2) of
# 'Supplementary Figure 10h'.
df = module_score(adata, list(exhaustion_list1['genes']), 'exhaustion_list1', 'Supplementary Figure 10h (1)', df)
df = module_score(adata, list(exhaustion_list2['genes']), 'exhaustion_list2', 'Supplementary Figure 10h (2)', df)

In [None]:
# Effector gene list; its `genes` column is scored in the next cell.
eff_list2 = pd.read_csv('../data/gene_sets/effector_list2.csv')

In [None]:
# Module score of the effector list; rows are logged under figure '3c'.
# NOTE(review): the name passed here embeds the whole gene list, and
# `module_score` puts that name into both the plot title and the PDF file
# name — confirm the resulting file name is acceptable.
df = module_score(adata, list(eff_list2['genes']), f'Effector list2 {list(eff_list2["genes"])}', '3c', df)

In [None]:
# Write the extended table back to the file it was read from.
# NOTE(review): because rows are appended to what was read from this same
# file, re-running the notebook appends the same rows again.
df.to_csv('MAA_figures_info.csv', index=False)