# Prepare raw data

In [None]:
import pandas as pd
import scanpy as sc
import numpy as np
from scipy.sparse import csr_matrix
import anndata as ad

# convenience functions for single cell analysis
from sctools import plot, integrate, io, score

In [None]:
sample_status_txt = '''A3,inflamed
A2,non-inflamed
A1,healthy
B3,inflamed
B2,non-inflamed
B1,healthy
C3,inflamed
C2,non-inflamed
C1,healthy'''

# Lookup table sample id -> condition, built from "<sample>,<condition>" lines.
sample_status = dict(
    line.split(',', 1)
    for line in sample_status_txt.splitlines()
)
sample_status

In [None]:
# Gzip-compressed, tab-separated expression matrix; the first column becomes
# the row index (rows are later used as genes, columns as cells).
matrix_tsv = snakemake.input.matrix_tsv
data = pd.read_csv(matrix_tsv, sep = '\t', compression = 'gzip', index_col = 0)
data

In [None]:

# `data` has genes as rows and cells as columns; AnnData wants cells x genes,
# so transpose before converting to a sparse CSR matrix.
X = csr_matrix(data.T.values)
X.eliminate_zeros()

cell_ids = data.columns
obs = pd.DataFrame(index = cell_ids)
# The sample id is the text after the last '-' in each cell name.
obs['sampleid'] = [cell_id.rsplit('-', 1)[-1] for cell_id in cell_ids]
# Unknown sample ids raise KeyError rather than silently becoming NaN.
obs['status'] = [sample_status[sid] for sid in obs['sampleid']]

var = pd.DataFrame(index = data.index)
adata = ad.AnnData(X = X, obs = obs, var = var)
adata

# Filter genes

In [None]:

nexpressed_threshold = 10
ngenes_passed = (adata.X.sum(axis = 0) >= nexpressed_threshold).sum()
ngenes = adata.var.shape[0]
print(
    f'{ngenes_passed} of {ngenes} are retained requiring their expression in {nexpressed_threshold} or more cells'
)

In [None]:

filtered = adata[:, adata.X.sum(axis = 0) >= nexpressed_threshold].copy()
filtered

# Plot raw data UMAP

In [None]:
def compute_raw_umap(adata, n_comps = 40):
    """Return a copy of ``adata`` with PCA, a neighbourhood graph and a UMAP.

    The input object is not modified. No ``normalize_total``/``log1p`` step
    is run because the expression matrix is already normalised and
    log1p-transformed.

    Parameters
    ----------
    adata : AnnData
        Object whose ``X`` holds log-normalised expression.
    n_comps : int, optional
        Number of principal components to compute (default 40).

    Returns
    -------
    AnnData
        The copy, with ``obsm['X_pca']``, the neighbour graph built from
        ``X_pca`` and the UMAP embedding added by scanpy.
    """
    tmp = adata.copy()
    sc.pp.pca(
        tmp,
        n_comps = n_comps,
        svd_solver = 'arpack'
    )
    sc.pp.neighbors(
        tmp,
        use_rep = 'X_pca'
    )
    sc.tl.umap(tmp)
    return tmp

In [None]:
# UMAP of the filtered, log-normalised data before scVI integration.
raw = compute_raw_umap(adata = filtered)

In [None]:

# Colour the pre-integration UMAP by sample and by condition. Each value is a
# (palette, second option) pair for the project helper; None presumably
# selects its default.
raw_umap_colors = {
    'sampleid': (None, None),
    'status': (snakemake.params.condition_palette, None),
}
raw_umap_outputs = (snakemake.output.umap_raw, snakemake.output.umap_raw_legend)
plot.misc.generate_and_save_for_figure(
    raw,
    raw_umap_colors,
    *raw_umap_outputs,
    size = 40,
    edgecolor = 'k',
    linewidths = 0.3
)

# Integrate with scVI

In [None]:

# X holds log1p-transformed values; undo the log before handing the data to
# the scVI integration helper. Do this on a copy rather than overwriting
# filtered.X: the in-place version leaves `filtered` altered and applies
# expm1 twice if the cell is re-run.
to_integrate = filtered.copy()
to_integrate.X = to_integrate.X.expm1()
integrated = integrate.integrate_data_scvi(
    to_integrate,
    'sampleid',
    train_size = 1
)
del to_integrate

In [None]:
# IPython help: display the source and docstring of the plotting helper used
# in the next cell (interactive inspection only; it computes nothing).
??plot.integrate.plot_integration_results

In [None]:

# The helper receives a mapping of integration results; `data_key` presumably
# selects the AnnData inside each result dict (integrated['data']).
integration_results = {'data': integrated}
panel_styles = [dict(size = 10, vmax = None) for _panel in ('status', 'sampleid')]
fig, axs = plot.integrate.plot_integration_results(
    integration_results,
    ['status', 'sampleid'],
    panel_styles,
    data_key = 'data'
)

# Plot integrated UMAPs

In [None]:
# Same sample / condition colouring as the raw UMAP, now on the integrated data.
integrated_umap_colors = {
    'sampleid': (None, None),
    'status': (snakemake.params.condition_palette, None),
}
integrated_umap_outputs = (
    snakemake.output.umap_integrated,
    snakemake.output.umap_integrated_legend,
)
plot.misc.generate_and_save_for_figure(
    integrated["data"],
    integrated_umap_colors,
    *integrated_umap_outputs,
    size = 40,
    edgecolor = 'k',
    linewidths = 0.3,
)

# Cell type annotation with CellTypist

In [None]:
# Leiden over-clustering used by CellTypist majority voting (resolution 10,
# described by the author as CellTypist's recommendation). It is computed
# here on the integrated object, which holds only highly variable genes,
# instead of being left to CellTypist on the all-gene object built below.
resolution = 10
leiden_key = f'leiden_scvi_{resolution}'
sc.tl.leiden(
    integrated["data"],
    resolution = resolution,
    key_added = leiden_key
)

In [None]:

# The integrated object holds only a gene subset; rebuild an all-gene object
# from its raw data with the project helper.
bdata = io.initialize_from_raw(integrated["data"])
# The stored values were expm1-transformed before integration, so re-apply log1p.
sc.pp.log1p(bdata)

In [None]:
import celltypist as ct

# Load the CellTypist model named in the workflow params and show its summary.
celltypist_model_name = snakemake.params.celltypist_model
model = ct.models.Model.load(model = celltypist_model_name)
model

In [None]:
# Annotate the all-gene, log1p-transformed object. Pass the Model object
# loaded in the previous cell instead of its name, so it is not loaded a
# second time and the annotation uses exactly the model inspected above.
# Majority voting refines per-cell predictions within each Leiden
# over-cluster computed on the integrated data.
predictions = ct.annotate(
    bdata,
    model = model,
    majority_voting = True,  # following the celltypist tutorial
    over_clustering = f'leiden_scvi_{resolution}'
)

In [None]:
cdata = predictions.to_adata()
# target column on the integrated object -> CellTypist result column
label_columns = {
    'cell_type_coarse': 'majority_voting',   # label after majority voting
    'cell_type_fine': 'predicted_labels',    # the label pre-majority voting
}
for target_col, source_col in label_columns.items():
    integrated["data"].obs[target_col] = cdata.obs[source_col]

In [None]:
# Point styling forwarded to the scatter; vmax caps the colour scale of the
# continuous LGR5 panel. Plotting the categorical columns also makes scanpy
# store their palettes in .uns (cell_type_coarse_colors is read further down).
umap_point_style = dict(size = 20, edgecolor = 'white', lw = 0.2, vmax = 1)
sc.pl.umap(
    integrated["data"],
    color = ['cell_type_coarse', 'cell_type_fine', 'LGR5'],
    legend_loc = 'on data',
    frameon = False,
    **umap_point_style
)

In [None]:
# Number of cells whose fine CellTypist label is 'Stem cells'.
is_stem = integrated["data"].obs['cell_type_fine'].eq('Stem cells')
is_stem.sum()

# Plot annotation UMAPs

In [None]:
adata = integrated["data"]
# Reuse the colours scanpy assigned to the coarse categories when they were
# plotted (stored in .uns['cell_type_coarse_colors'], in category order).
celltype_color_palettes = dict(
    zip(
        adata.obs['cell_type_coarse'].cat.categories,
        adata.uns['cell_type_coarse_colors'],
    )
)
# 'Stem cells' comes from the fine labels, so it needs its own colour.
celltype_color_palettes['Stem cells'] = '#EF5A7E'

In [None]:

coarse_labels = adata.obs['cell_type_coarse'].to_list()
fine_labels = adata.obs['cell_type_fine'].to_list()
# Coarse label everywhere, except cells whose fine label is 'Stem cells'.
selected_cell_types = [
    fine if fine == 'Stem cells' else coarse
    for coarse, fine in zip(coarse_labels, fine_labels)
]
adata.obs['cell_type_select'] = selected_cell_types

In [None]:
# Both annotation panels share the same colour map.
annotation_colors = {
    'cell_type_coarse': (celltype_color_palettes, None),
    'cell_type_select': (celltype_color_palettes, None),
}
annotation_outputs = (
    snakemake.output.umap_annotated,
    snakemake.output.umap_annotated_legend,
)
plot.misc.generate_and_save_for_figure(
    adata,
    annotation_colors,
    *annotation_outputs,
    size = 40,
    edgecolor = 'k',
    linewidths = 0.3
)

# Differential expression analysis

In [None]:
# All genes again, log1p-transformed, for differential expression.
adata = io.initialize_from_raw(integrated["data"])
sc.pp.log1p(adata)
# TA cells vs. all other cells: Wilcoxon test, Benjamini-Hochberg correction.
ta_de_params = dict(
    groupby = 'cell_type_coarse',
    groups = ['TA'],
    reference = 'rest',
    method = 'wilcoxon',
    corr_method = 'benjamini-hochberg',
)
sc.tl.rank_genes_groups(adata, **ta_de_params)

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np

# Only 'TA' was tested, so field 0 of each per-gene record holds its result.
de_ta = pd.DataFrame.from_dict(
    dict(
        genes = [name[0] for name in adata.uns['rank_genes_groups']['names']],
        padj = [p[0] for p in adata.uns['rank_genes_groups']['pvals_adj']],
        lfc = [lfc[0] for lfc in adata.uns['rank_genes_groups']['logfoldchanges']]
    )
)
# Wilcoxon p-values can underflow to exactly 0. -log10(0) is inf, which
# matplotlib cannot place on the axes, so the most significant genes would
# vanish from the plot. Clip at the smallest positive float instead.
de_ta['-log10padj'] = -np.log10(de_ta.padj.clip(lower = np.finfo(float).tiny))
de_ta.index = de_ta.genes
ax = sns.scatterplot(
    data = de_ta,
    x = 'lfc',
    y = '-log10padj'
)

In [None]:
# Fine-label 'Stem cells' vs. all other cells, same test settings as for TA.
# Results go to the default .uns['rank_genes_groups'] key, replacing the TA run.
stem_de_params = dict(
    groupby = 'cell_type_fine',
    groups = ['Stem cells'],
    reference = 'rest',
    method = 'wilcoxon',
    corr_method = 'benjamini-hochberg',
)
sc.tl.rank_genes_groups(adata, **stem_de_params)

In [None]:
# Only 'Stem cells' was tested, so field 0 of each per-gene record holds its result.
de_sc = pd.DataFrame.from_dict(
    dict(
        genes = [name[0] for name in adata.uns['rank_genes_groups']['names']],
        padj = [p[0] for p in adata.uns['rank_genes_groups']['pvals_adj']],
        lfc = [lfc[0] for lfc in adata.uns['rank_genes_groups']['logfoldchanges']]
    )
)
# Wilcoxon p-values can underflow to exactly 0. -log10(0) is inf, which
# matplotlib cannot place on the axes, so the most significant genes would
# vanish from the plot. Clip at the smallest positive float instead.
de_sc['-log10padj'] = -np.log10(de_sc.padj.clip(lower = np.finfo(float).tiny))
de_sc.index = de_sc.genes
ax = sns.scatterplot(
    data = de_sc,
    x = 'lfc',
    y = '-log10padj'
)

# Compute stemness score as described by Tirosh et al. 2016

In [None]:
is_stem_cell = adata.obs['cell_type_select'] == 'Stem cells'
is_not_healthy = adata.obs['status'] != 'healthy'
# Stem cells from inflamed and non-inflamed samples (healthy samples excluded).
cell_index = is_stem_cell & is_not_healthy
stem_cells = adata[cell_index, :]
stem_cells

In [None]:
# Stemness markers present in the data. Sorted so the list handed to the
# scoring function has the same order on every run: iteration order of a set
# of strings varies between Python processes because of hash randomisation.
stemness_markers = set(snakemake.params.stemness_markers)
genes = sorted(set(stem_cells.var.index) & stemness_markers)
print(f'{len(genes)} of {len(stemness_markers)} stemness markers found in the data')

In [None]:
import pandas as pd

stemness = score.gene_module_score(stem_cells, list(genes))
# Two groups only: healthy cells were excluded when selecting stem_cells, so
# every status other than 'inflamed' is reported as 'non-inflamed'.
two_group_status = [
    'inflamed' if condition == 'inflamed' else 'non-inflamed'
    for condition in stem_cells.obs['status']
]
score_df = pd.DataFrame(
    {
        'stemness_score': stemness,
        'status': two_group_status,
    },
    index = stem_cells.obs.index
)
score_df

In [None]:
import scipy.stats as stats

inflamed_scores = score_df.loc[score_df.status == 'inflamed', 'stemness_score']
noninflamed_scores = score_df.loc[score_df.status == 'non-inflamed', 'stemness_score']
# One-sided two-sample t-test (scipy default equal_var=True, i.e. pooled
# variance); alternative hypothesis: mean(inflamed) < mean(non-inflamed).
tres = stats.ttest_ind(inflamed_scores, noninflamed_scores, alternative = 'less')
tres

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Numeric x and categorical y: one horizontal violin per status group.
sns.violinplot(
    data = score_df,
    x = 'stemness_score',
    y = 'status',
    ax = ax,
)
# p-value of the one-sided t-test from the previous cell.
# NOTE(review): (0.4, 0.5) are data coordinates (x in stemness-score units,
# y between the two categorical rows); confirm this placement is intended.
ax.text(0.4, 0.5, 'p = {:.4e}'.format(tres.pvalue), ha = 'right')
fig.set_size_inches(6, 3)
fig.tight_layout()
fig.savefig(snakemake.output.stemness_score_plot)

# Rank plots

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


def map_group(x, lfc_cut, padj_cut):
    """Classify one differential-expression row as 'sig' or 'n.s.'.

    A row is significant when the absolute log fold change is strictly above
    ``lfc_cut`` and the adjusted p-value is strictly below ``padj_cut``.

    Parameters
    ----------
    x : row with ``lfc`` and ``padj`` attributes (e.g. a DataFrame row).
    lfc_cut : threshold on ``abs(x.lfc)``.
    padj_cut : threshold on ``x.padj``.
    """
    large_effect = abs(x.lfc) > lfc_cut
    small_padj = x.padj < padj_cut
    return 'sig' if (large_effect and small_padj) else 'n.s.'


def label_point(poi, ax, label_x = 600):
    """Annotate one gene on a rank plot with a leader line and a text label.

    Draws a thin black horizontal line from the gene's point at
    (``poi['rank']``, ``poi['lfc']``) to ``label_x``, then writes
    "<gene> padj = <padj>" left-aligned at the end of that line.

    Parameters
    ----------
    poi : mapping or pandas Series
        Entry with ``'rank'``, ``'lfc'``, ``'genes'`` and ``'padj'`` keys.
    ax : matplotlib Axes
        Axes to draw on (data coordinates are used for line and text).
    label_x : float, optional
        x position where the leader line ends and the label starts
        (default 600).
    """
    x = poi['rank']
    y = poi['lfc']
    ax.plot(
        [x, label_x],
        [y] * 2,
        c = 'k',
        lw = 0.5
    )
    gene_name = poi['genes']
    padj = poi['padj']

    ax.text(
        label_x, y,
        f'{gene_name} padj = {padj:.3e}',
        va = 'center',
        ha = 'left'
    )
    

palette = {
    'sig': (178/256, 24/256, 43/256),
    'n.s.': 'lightgrey'
}

zorders = {
    'sig': 2,
    'n.s.': 1
}

# Significance thresholds; the dashed guide line is drawn at the same lfc cut.
lfc_cut = 1
padj_cut = 1e-4
# Gene highlighted with a leader line and its padj in each panel.
highlight_gene = 'LGR5'

fig, axs = plt.subplots(2)

for ax, df, title in zip(
    axs,
    [de_ta, de_sc],
    ['TA cells', 'stem cells']
):
    # Up-regulated genes only. Take an explicit copy because columns are
    # added below; assigning into a filtered slice triggers pandas'
    # SettingWithCopyWarning.
    df = df.loc[df.lfc > 0, :].copy()
    # rank 1 = largest log2 fold change
    df['rank'] = df.lfc.rank(ascending = False)
    df['group'] = df.apply(
        map_group,
        axis = 1,
        lfc_cut = lfc_cut,
        padj_cut = padj_cut
    )
    for g, gdf in df.groupby('group'):
        ax.scatter(
            x = gdf['rank'],
            y = gdf['lfc'],
            color = palette[g],
            zorder = zorders[g],
            label = g,
            edgecolors = 'white',
            linewidths = 0.5
        )

    # One guide line per panel, not one per significance group.
    ax.axhline(
        lfc_cut,
        ls = '--',
        c = 'grey',
        lw = 1
    )

    # The highlight gene may be missing (e.g. lfc <= 0 rows were dropped
    # above); skip the label rather than failing with a KeyError.
    if highlight_gene in df.index:
        label_point(df.loc[highlight_gene, :], ax)
    else:
        print(f'{highlight_gene} is not up-regulated in {title}; no label drawn')

    ax.legend()
    ax.set_ylabel('log2FC')
    ax.set_xlabel('log2FC rank')
    ax.set_title(title)

fig.set_figwidth(3)
fig.set_figheight(6)
fig.tight_layout()
fig.savefig(snakemake.output.rank_plot)

# CellWhisperer embedding UMAPs

In [None]:
import scanpy as sc

# Pre-computed CellWhisperer dataset (h5ad) provided by the workflow.
cellwhisperer_path = snakemake.input.cellwhisperer_dataset
cwdata = sc.read_h5ad(cellwhisperer_path)
cwdata

In [None]:
# Store the CellWhisperer UMAP coordinates under the standard 'X_umap' key,
# presumably the key the UMAP plotting helper reads.
cw_umap = cwdata.obsm['X_cellwhisperer_umap']
cwdata.obsm['X_umap'] = cw_umap

In [None]:
# The CellWhisperer dataset uses its own column names ('sample_id', 'condition').
cellwhisperer_colors = {
    'sample_id': (None, None),
    'condition': (snakemake.params.condition_palette, None),
}
cellwhisperer_outputs = (
    snakemake.output.umap_cellwhisperer,
    snakemake.output.umap_cellwhisperer_legend,
)
plot.misc.generate_and_save_for_figure(
    cwdata,
    cellwhisperer_colors,
    *cellwhisperer_outputs,
    size = 40,
    edgecolor = 'k',
    linewidths = 0.3
)