# Attempt on integrating scRNA data with SCVI tools
## Initial imports and definition of helper functions

In [None]:
import scanpy as sc
import scvi
import glob
import os
from functools import reduce
import anndata as ad
import matplotlib.pyplot as plt

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42


def integrate_data_scvi(
    adata, 
    batch_key, 
    categorical_covariate_keys = None,
    continuous_covariate_keys = None,
    use_highly_variable_genes = True,
    n_top_genes = 4000,
    use_gpu = True,
    max_epochs = None,
    train_size = 0.9
    
):
    """Train an scVI model on ``adata`` and compute a UMAP from its latent space.

    ``adata`` is modified in place: the current ``X`` is copied into
    ``layers['counts']``, ``raw`` is set, the latent representation is stored
    in ``obsm['X_scvi']`` and neighbors/UMAP are computed from it. When
    ``use_highly_variable_genes`` is truthy the object is additionally
    subset to the ``n_top_genes`` selected genes.

    Parameters
    ----------
    adata
        AnnData object; ``X`` is used as the count matrix.
    batch_key
        Forwarded to ``SCVI.setup_anndata`` as ``batch_key``.
    categorical_covariate_keys, continuous_covariate_keys
        Forwarded unchanged to ``SCVI.setup_anndata``.
    use_highly_variable_genes
        Whether to run ``sc.pp.highly_variable_genes`` (``seurat_v3``,
        ``subset = True``) on the ``counts`` layer first.
    n_top_genes
        Number of genes kept by the highly-variable-gene selection.
    use_gpu, max_epochs, train_size
        Forwarded unchanged to ``model.train``.

    Returns
    -------
    dict
        ``{'data': adata, 'model': model}``.
    """
    # keep the untouched matrix around before any gene subsetting happens
    adata.layers['counts'] = adata.X.copy()
    adata.raw = adata

    if use_highly_variable_genes:
        print('computing highly variable genes')
        hvg_options = dict(
            n_top_genes = n_top_genes,
            layer = 'counts',
            subset = True,
            flavor = 'seurat_v3'
        )
        sc.pp.highly_variable_genes(adata, **hvg_options)

    setup_options = dict(
        layer = 'counts',
        batch_key = batch_key,
        categorical_covariate_keys = categorical_covariate_keys,
        continuous_covariate_keys = continuous_covariate_keys
    )
    scvi.model.SCVI.setup_anndata(adata, **setup_options)

    # non default parameters from scVI tutorial and scIB github
    # see https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html
    # and https://github.com/theislab/scib/blob/main/scib/integration.py
    model = scvi.model.SCVI(adata, n_layers = 2, n_latent = 30, gene_likelihood = 'nb')
    model.train(use_gpu = use_gpu, max_epochs = max_epochs, train_size = train_size)

    latent = model.get_latent_representation()
    adata.obsm['X_scvi'] = latent

    print('compute umap from scvi embedding')
    sc.pp.neighbors(adata, use_rep = 'X_scvi')
    sc.tl.umap(adata)

    return dict(data = adata, model = model)


def initialize_from_raw(file):
    """Build a fresh AnnData object from the ``.raw`` slot of an h5ad file.

    The returned object holds ``raw.X`` cast to ``int64`` as ``X``, a copy of
    the stored ``obs`` and a copy of ``raw.var``. Nothing else from the file
    (layers, obsm, uns, ...) is carried over.

    Parameters
    ----------
    file
        Path handed to ``sc.read_h5ad``.

    Returns
    -------
    anndata.AnnData

    Raises
    ------
    ValueError
        If the file has no ``.raw`` attribute.
    """
    adata = sc.read_h5ad(file)
    if adata.raw is None:
        raise ValueError(f'{file} has no .raw attribute to initialize from')

    # cast explicitly instead of passing ``dtype`` to the constructor, which
    # newer anndata releases deprecate
    return ad.AnnData(
        X = adata.raw.X.astype(np.int64),
        obs = adata.obs.copy(),
        var = adata.raw.var.copy()
    )

def plot_leiden_clusterings(
    data_dict, 
    resolutions, 
    data_key = None, 
    legend_loc = 'on_data', 
    panelheight = 5, 
    panelwidth = 6, 
    subsample = 1, 
    size = None
):
    """Plot Leiden clusterings at several resolutions for each dataset.

    One row is drawn per entry of ``data_dict`` and one column per entry of
    ``resolutions``. Clustering runs on a copy of each dataset, so the
    inputs are not modified.

    Parameters
    ----------
    data_dict
        Mapping of name to AnnData, or to a container from which the AnnData
        is fetched via ``data_key``.
    resolutions
        Leiden resolutions to compute and plot.
    data_key
        Key used to fetch the AnnData from each ``data_dict`` value; if
        falsy, the value itself is used.
    legend_loc
        Forwarded to ``sc.pl.umap``.
        NOTE(review): other cells in this file pass 'on data' (with a
        space); confirm that the default 'on_data' is intended.
    panelheight, panelwidth
        Size of a single panel in inches.
    subsample
        Fraction of cells (drawn without replacement) that is plotted.
    size
        Forwarded to ``sc.pl.umap``.

    Returns
    -------
    (fig, axs)
        ``axs`` has the same shape ``plt.subplots`` returns by default
        (a single Axes, a 1D array or a 2D array).
    """
    # squeeze = False always yields a 2D grid, so layouts with a single row
    # and/or a single column can be indexed the same way as full grids
    fig, grid = plt.subplots(
        len(data_dict), 
        len(resolutions),
        squeeze = False
    )

    for row_axs, d in zip(grid, data_dict.values()):
        tmp = d[data_key].copy() if data_key else d.copy()

        idx = np.random.choice(
            tmp.obs.index,
            size = int(tmp.shape[0] * subsample),
            replace = False
        )

        for ax, resolution in zip(row_axs, resolutions):
            cluster_key = f'leiden_scvi_{resolution}'
            # cluster all cells, but only draw the subsample
            sc.tl.leiden(
                tmp, 
                key_added = cluster_key,
                resolution = resolution
            )
            sc.pl.umap(
                tmp[idx],
                color = cluster_key,
                frameon = False,
                show = False,
                ax = ax,
                size = size,
                legend_loc = legend_loc
            )

    fig.set_figwidth(panelwidth * len(resolutions))
    fig.set_figheight(panelheight * len(data_dict))
    fig.tight_layout()

    # hand back the axes in the shape callers got before (squeezed)
    axs = grid.item() if grid.size == 1 else grid.squeeze()
    return fig, axs

def plot_integration_results(data_dict, color_keys, params_list = None, data_key = None, panelheight = 5, panelwidth = 6, subsample = 1):
    """Plot UMAPs of several datasets side by side, one row per color key.

    Parameters
    ----------
    data_dict
        Mapping of name to AnnData, or to a container from which the AnnData
        is fetched via ``data_key``. Each entry becomes one column whose top
        panel is titled with the name.
    color_keys
        Keys passed as ``color`` to ``sc.pl.umap``, one per row.
    params_list
        Optional list of extra keyword-argument dicts for ``sc.pl.umap``,
        aligned with ``color_keys``. Defaults to no extra arguments.
    data_key
        Key used to fetch the AnnData from each ``data_dict`` value; if
        falsy, the value itself is used.
    panelheight, panelwidth
        Size of a single panel in inches.
    subsample
        Fraction of cells to plot. Values below 1 draw that fraction of
        cells without replacement; otherwise all cells are plotted in their
        stored order.

    Returns
    -------
    (fig, axs)
        ``axs`` has the same shape ``plt.subplots`` returns by default
        (a single Axes, a 1D array or a 2D array).
    """
    if not params_list:
        params_list = [{} for _ in color_keys]

    # squeeze = False keeps the grid 2D so a single color key or a single
    # dataset does not need special-cased indexing
    fig, grid = plt.subplots(len(color_keys), len(data_dict), squeeze = False)
    for i, (k, d) in enumerate(data_dict.items()):
        data = d[data_key] if data_key else d

        # previously the sampled index was computed but never applied, so
        # ``subsample`` had no effect on the plot
        if subsample < 1:
            idx = np.random.choice(
                data.obs.index,
                size = int(data.shape[0] * subsample),
                replace = False
            )
            data = data[idx]

        for ax, color_key, kwargs in zip(grid[:, i], color_keys, params_list):
            sc.pl.umap(
                data,
                color = color_key,
                show = False,
                frameon = False,
                ax = ax,
                **kwargs
            )

        # label the column after plotting so the dataset name replaces the
        # title scanpy put on the top panel
        grid[0, i].set_title(k)

    fig.set_figwidth(panelwidth * len(data_dict))
    fig.set_figheight(panelheight * len(color_keys))
    fig.tight_layout()

    # hand back the axes in the shape callers got before (squeezed)
    axs = grid.item() if grid.size == 1 else grid.squeeze()
    return fig, axs

def plot_clustering_and_expression(data_dict, cluster_key, expression_keys, params_list = None, data_key = None, figwidth = 20):
    """Plot a clustering and selected expression keys on the UMAP of each dataset.

    Each entry of ``data_dict`` becomes one column; the first row shows
    ``cluster_key`` and the following rows show ``expression_keys``.

    Parameters
    ----------
    data_dict
        Mapping of name to AnnData, or to a container from which the AnnData
        is fetched via ``data_key``.
    cluster_key
        Key drawn in the first row.
    expression_keys
        Sequence of keys drawn in the remaining rows.
    params_list
        Optional list of extra keyword-argument dicts for ``sc.pl.umap``,
        aligned with ``[cluster_key, *expression_keys]``.
    data_key
        Key used to fetch the AnnData from each ``data_dict`` value; if
        falsy, the value itself is used.
    figwidth
        Total figure width in inches; the height is derived from it.

    Returns
    -------
    (fig, axs)
        ``axs`` has the same shape ``plt.subplots`` returns by default
        (a single Axes, a 1D array or a 2D array).
    """
    # unpacking accepts any sequence (list, tuple, ...) of expression keys
    color_keys = [cluster_key, *expression_keys]
    if not params_list:
        params_list = [{} for _ in color_keys]

    # squeeze = False keeps the grid 2D regardless of how many rows and
    # columns are requested
    fig, grid = plt.subplots(len(color_keys), len(data_dict), squeeze = False)
    for i, (k, d) in enumerate(data_dict.items()):
        data = d[data_key] if data_key else d

        for ax, color_key, kwargs in zip(grid[:, i], color_keys, params_list):
            sc.pl.umap(
                data,
                color = color_key,
                frameon = False,
                show = False,
                ax = ax,
                **kwargs
            )

        # label the column after plotting so the dataset name replaces the
        # title scanpy put on the top panel
        grid[0, i].set_title(k)

    fig.set_figwidth(figwidth)
    fig.set_figheight(len(color_keys) * figwidth / (len(color_keys) + 1))
    fig.tight_layout()

    # hand back the axes in the shape callers got before (squeezed)
    axs = grid.item() if grid.size == 1 else grid.squeeze()
    return fig, axs


def majority_vote(adata, prediction_col, clustering_resolution, majority_col_name = None):
    """Assign each cell the most frequent label of its Leiden cluster.

    Runs Leiden clustering at ``clustering_resolution``, cross-tabulates the
    labels in ``adata.obs[prediction_col]`` against the clusters and gives
    every cell the label with the highest count in its cluster. ``adata.obs``
    is modified in place: it gains a ``clustering`` column and a categorical
    column named ``majority_col_name`` (``'majority_voting'`` if not given);
    existing columns with these names are replaced. The temporary
    ``leiden_scvi_<resolution>`` column is removed from ``adata.obs`` again.

    Partly taken from celltypist.
    """
    leiden_key = f'leiden_scvi_{clustering_resolution}'
    sc.tl.leiden(
        adata,
        resolution = clustering_resolution,
        key_added = leiden_key
    )
    # take the cluster labels out of obs; they come back as 'clustering' below
    cluster_of_cell = adata.obs.pop(leiden_key)

    # rows: predicted labels, columns: clusters -> winning label per cluster
    contingency = pd.crosstab(adata.obs[prediction_col], cluster_of_cell)
    winner_per_cluster = contingency.idxmax(axis=0)

    # look the winner up for every cell and realign with the cell index
    per_cell = winner_per_cluster[cluster_of_cell].reset_index()
    per_cell.index = adata.obs.index

    if not majority_col_name:
        majority_col_name = 'majority_voting'

    new_columns = ['clustering', majority_col_name]
    per_cell.columns = new_columns
    per_cell[majority_col_name] = per_cell[majority_col_name].astype('category')

    # drop stale results of an earlier call so the join does not collide
    for column in new_columns:
        if column in adata.obs.columns:
            adata.obs.pop(column)

    adata.obs = adata.obs.join(per_cell)
    

def assign_celltypes(adata, marker_matrix, layer, clustering_resolution = 10, normalized = False):
    """Predict cell types with CellAssign and smooth them by majority vote.

    A CellAssign model is trained on a copy of ``adata`` restricted to the
    genes in ``marker_matrix.index``. The per-cell prediction with the
    highest value is written to ``adata.obs['predicted_labels']`` and then
    passed through ``majority_vote`` at ``clustering_resolution``.

    Note that ``adata`` is modified in place; in particular ``adata.X`` is
    replaced by ``adata.layers[layer]`` when ``layer`` is truthy.

    Parameters
    ----------
    adata
        AnnData object to annotate.
    marker_matrix
        Marker table handed to ``scvi.external.CellAssign``; its index
        selects the genes.
    layer
        Name of the layer to use as ``X``; falsy to keep the current ``X``.
    clustering_resolution
        Leiden resolution used for the majority vote.
    normalized
        If truthy, a constant size factor of 1 is used; otherwise size
        factors are derived from the row sums of ``adata.X``.
    """
    if layer:
        adata.X = adata.layers[layer]

    marker_subset = adata[:, marker_matrix.index]
    bdata = marker_subset.copy()

    if normalized:
        size_factor = 1
    else:
        # library size over all genes of adata, not just the markers
        lib_size = adata.X.sum(1)
        # NOTE(review): this shrinks the factor as the library grows; the
        # scvi-tools CellAssign tutorial appears to use
        # lib_size / mean(lib_size) instead - confirm which is intended
        size_factor = np.min(lib_size) / lib_size
    bdata.obs['size_factor'] = size_factor

    scvi.external.CellAssign.setup_anndata(bdata, size_factor_key = 'size_factor')
    model = scvi.external.CellAssign(bdata, marker_matrix)
    model.train()

    predictions = model.predict()
    best_label = predictions.idxmax(axis=1)
    adata.obs['predicted_labels'] = best_label.values

    majority_vote(adata, 'predicted_labels', clustering_resolution)

## Data integration using SCVI
This step aims to harmonize the data and remove batch effects.

In [None]:
adata = sc.read_h5ad(
    '../data/skin_inflammatory_disease.qcfiltered.h5ad'
)
adata

In [None]:
# visualize raw data
tmp = adata.copy()
tmp.layers['counts'] = tmp.X.copy()
sc.pp.normalize_total(
    tmp, 
    target_sum = 1e4
)
sc.pp.log1p(tmp)
sc.pp.highly_variable_genes(
    tmp,
    n_top_genes = 4000,
    layer = "counts",
    flavor = "seurat_v3",
)
sc.pp.pca(
    tmp, 
    n_comps = 40, 
    svd_solver = 'arpack',
    use_highly_variable = True,
)
sc.pp.neighbors(
    tmp,
    use_rep = 'X_pca'
)
sc.tl.umap(tmp)

sc.pl.umap(
    tmp,
    color = ['patient_id', 'status', 'CD3D', 'FOXP3'],
    frameon = False,
    size = 5,
    vmax = 0.5
)

del tmp

In [None]:
integration_params = {
    'full': {
        'kwargs': dict(),
        'key': 'sample_id'
    }
}

# batches with fewer cells than this are dropped before integration
min_cells_per_batch = 5

results = {}
for key, params in integration_params.items():
    integration_key = params['key']
    print(key)

    # count the cells of every batch once and keep cells of large enough
    # batches (replaces one DataFrame lookup per cell)
    batch_sizes = adata.obs[integration_key].value_counts()
    large_enough = adata.obs[integration_key].isin(
        batch_sizes[batch_sizes >= min_cells_per_batch].index
    )

    # integrate on the configured key rather than a hard-coded column name
    results[key] = integrate_data_scvi(
        adata[large_enough.to_numpy()].copy(),
        integration_key,
        train_size = 1,
        **params['kwargs']
    )

    results[key]['data'].write(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )

    results[key]['model'].save(
        f'../data/skin_{key}.integration.scvi.model',
        overwrite = True
    )

In [None]:
# restore results if kernel breaks or gets shut down
results = {}
for key in ['full']:
    data = sc.read_h5ad(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )
    results[key] = {
        'data': data,
        'model': scvi.model.SCVI.load(
            f'../data/skin_{key}.integration.scvi.model', 
            adata = data
        )
    }

In [None]:
fig, axs = plot_integration_results(
    results,
    ['patient_id', 'status', 'FOXP3', 'CD3D'],
    [
        dict(size = 1, vmax = None),
        dict(size = 1, vmax = None),
        dict(size = 10, vmax = 1),
        dict(size = 10, vmax = 10)
    ],
    data_key = 'data',
    panelwidth = 7
)

## Identification of T-cell subsets
This section aims at extracting T-cells from full dataset based on CD4 and FOXP3 expression (FOXP3 is taken into account here because the ultimate goal is to extract regulatory T-cells)

In [None]:
fig, axs = plot_leiden_clusterings(
    results,
    [0.1, 0.2],
    data_key = 'data',
    legend_loc = 'right margin'
)

In [None]:
resolution = 0.1
for k, d in results.items():
    sc.tl.leiden(
        d['data'], 
        key_added = f'leiden_scvi_{resolution}',
        resolution = resolution
    )

In [None]:
fig, axs = plot_clustering_and_expression(
    results,
    f'leiden_scvi_{resolution}',
    ['CD3D', 'FOXP3'],
    [
        dict(size = 5, vmax = None),
        dict(size = 10, vmax = 10),
        dict(size = 10, vmax = 1)
    ],
    data_key = 'data'
)

In [None]:
tcell_clusters = {
    'full': [c for c in '1'.split(',')],
}

There definitely seems to be a bit of a trade-off between convenience and recovery of cell identity when integrating tissue and PBMC at once. However, there is not much difference in the number of FOXP3-positive T cells when integrating PBMC and tissue separately compared to integrating everything at once: 6,630 in full (clusters 0, 1, 4, 7, 8 and 13) versus 6,613 in PBMC (clusters 0, 3, 6, 8 and 10) and tissue (clusters 0, 1, 3, 4 and 8) combined. In general, this annotation could also be done with CellAssign or CellTypist to get the full spectrum of cell types, but here we are currently only interested in T cells. For code to do the full cell type annotation with either of the tools, see below.

In [None]:
clustering = f'leiden_scvi_{resolution}'
for k, d in results.items():
    d['data'].obs['coarse_cell_types'] = d['data'].obs[clustering].apply(
        lambda x: 'Tcell' if x in tcell_clusters[k] else 'other'
    )
    print(k, d['data'].obs.groupby('coarse_cell_types').count().iloc[:, 0])
    
fig, axs = plot_clustering_and_expression(
    results,
    'coarse_cell_types',
    ['CD3D'],
    [
        dict(size = 1, vmax = None),
        dict(size = 10, vmax = 10)
    ],
    data_key = 'data'
)

In [None]:
for k, d in results.items():
    d['data'].write(
        f'../data/skin_inflammatory_disease.{k}.integrated.clustered.h5ad'
    )

## Integration of T-cell subsets for identification of regulatory T-cells
Here we extract the identified T-cell subset from the original dataset and integrate it anew to avoid incorporating artifacts introduced into the embedding by other celltypes we are not interested in.

In [None]:
# load data
adata_dict = {}
format_string = '../data/skin_inflammatory_disease.{k}.integrated.clustered.h5ad'
for key in ['full']:
    adata = initialize_from_raw(format_string.format(k = key))
    adata_dict[key] = adata[adata.obs.coarse_cell_types == 'Tcell'].copy()
    del adata
    
adata_dict

In [None]:
# integrate tcells
# NOTE(review): this cell used to pass `use_highly_variable_genes = hvg`, but
# `hvg` is not defined in any visible cell. The setting is now explicit in
# 'kwargs'. False is chosen because the later cells only load and plot results
# whose key carries no 'hvg' suffix - confirm that this matches the run that
# produced the saved files.
integration_params = {
    'tcells_full': {
        'kwargs': dict(use_highly_variable_genes = False),
        'key': 'sample_id'
    }
}

# batches with fewer cells than this are dropped before integration
min_cells_per_batch = 5

results = {}
for key, adata in adata_dict.items():
    key = f'tcells_{key}'
    params = integration_params[key]
    integration_key = params['key']

    print(key)
    # count the cells of every batch once and keep cells of large enough
    # batches (replaces one DataFrame lookup per cell)
    batch_sizes = adata.obs[integration_key].value_counts()
    large_enough = adata.obs[integration_key].isin(
        batch_sizes[batch_sizes >= min_cells_per_batch].index
    )
    results[key] = integrate_data_scvi(
        adata[large_enough.to_numpy()].copy(),
        integration_key,
        train_size = 1,
        **params['kwargs']
    )

    results[key]['data'].write(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )

    results[key]['model'].save(
        f'../data/skin_{key}.integration.scvi.model',
        overwrite = True
    )

In [None]:
# restore results if kernel breaks or gets shut down
# The integration cell saves data and model under the plain key, and later
# cells look up results['tcells_full'], so the plain key is used here too
# (the former `key + hvg` relied on an `hvg` variable no visible cell defines).
results = {}
for key in ['tcells_full']:
    data = sc.read_h5ad(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )
    results[key] = {
        'data': data,
        'model': scvi.model.SCVI.load(
            f'../data/skin_{key}.integration.scvi.model', 
            adata = data
        )
    }

In [None]:
fig, axs = plot_integration_results(
    {k: d for k, d in results.items() if not 'hvg' in k},
    ['status', 'FOXP3', 'CD3D'],
    [
        dict(size = 1, vmax = None),
        dict(size = 10, vmax = 1),
        dict(size = 10, vmax = 10)
    ],
    data_key = 'data'
)

In [None]:
fig, axs = plot_integration_results(
    {k: d for k, d in results.items() if not 'hvg' in k},
    ['status', 'FOXP3', 'CD8A', 'CD8B', 'CD4'],
    [
        dict(size = 1, vmax = None),
        dict(size = 10, vmax = 1),
        dict(size = 10, vmax = 5),
        dict(size = 10, vmax = 5),
        dict(size = 10, vmax = 2)
    ],
    data_key = 'data'
)

## Cell type annotation of T-cell subset
This step is done manually as both CellAssign as well as CellTypist performed poorly on the dataset.

In [None]:
# load data
# Data and model are stored under the plain key and later cells look up
# results['tcells_full'], so the plain key is used throughout (the former
# `key + hvg` relied on an `hvg` variable no visible cell defines).
results = {}
for key in ['tcells_full']:
    data = sc.read_h5ad(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )
    model = scvi.model.SCVI.load(
        f'../data/skin_{key}.integration.scvi.model', 
        adata = data
    )

    # keep the model-based normalized expression next to the counts
    data.layers['scvi_normalized'] = model.get_normalized_expression(
        library_size = 1e4
    )

    results[key] = {
        'data': data,
        'model': model
    }

{k: d['data'] for k, d in results.items()}

In [None]:
res = 0.5
for k, data in results.items():
    adata = data['data']
    key_added = f'leiden_scvi_{res}'
    sc.tl.leiden(
        adata, 
        key_added = key_added,
        resolution = res
    )
    fig, axs = plt.subplots(1, 2)
    for ax, color_key in zip(axs, [key_added, 'FOXP3']):
        sc.pl.umap(
            adata,
            color = color_key,
            show = False,
            ax = ax,
            size = 50,
            vmax = 2,
            legend_loc = 'on data',
            legend_fontsize = 20,
            frameon = False
        )   
        cax = fig.axes[-1]
        cax.tick_params(axis = 'y', labelsize = 15)
        ax.set_title(color_key, fontsize = 20)
    
    fig.set_figwidth(15)
    fig.set_figheight(7.5)
    fig.tight_layout()

In [None]:
adata = results['tcells_full']['data']
adata.write_h5ad(
    '../data/skin_inflammatory_disease.tcells_full.clustered.h5ad'
)

In [None]:
def rank_genes_group_to_data_frame(adata, rank_keys, groupby, n_genes):
    """Collect ``rank_genes_groups`` results into one long-format table.

    For every key in ``rank_keys`` the first ``n_genes`` entries of each
    group are read from ``adata.uns['rank_genes_groups']``. Groups are the
    categories of ``adata.obs[groupby]``.

    Returns
    -------
    pandas.DataFrame
        A ``group`` column followed by one column per entry of
        ``rank_keys``; the rows of each group are stacked in category order.
    """
    groups = adata.obs[groupby].cat.categories

    def _long_table(stat):
        # one column per group -> melt into (group, value) rows
        per_group = adata.uns['rank_genes_groups'][stat]
        wide = pd.DataFrame({g: per_group[g][:n_genes] for g in groups})
        return wide.melt(value_name = stat, var_name = 'group')

    tables = [_long_table(stat) for stat in rank_keys]

    # the first table contributes the group labels, the others only values
    pieces = [tables[0]]
    pieces.extend(table.iloc[:, 1] for table in tables[1:])
    return pd.concat(pieces, axis = 1)


rank_keys = ['names', 'scores', 'logfoldchanges', 'pvals', 'pvals_adj']
groupby = 'leiden_scvi_0.5'
n_genes = 50

# to_csv does not create missing directories, and the tables are written
# before scanpy saves its first figure
os.makedirs('figures', exist_ok = True)

for k, d in results.items():
    print(k)
    adata = d['data']
    adata.layers['cpm'] = sc.pp.normalize_total(
        adata,
        target_sum = 1e4,
        inplace = False
    )['X']
    adata.layers['logcpm'] = sc.pp.log1p(
        adata.layers['cpm'],
        copy = True
    )

    # rank the clusters once per expression layer; `tag` is the suffix used
    # in the output file names
    for layer, tag in [('logcpm', 'logcpm'), ('scvi_normalized', 'scvi')]:
        sc.tl.rank_genes_groups(
            adata, 
            groupby, 
            use_raw = False,
            layer = layer,
            method = 'wilcoxon'
        )
        rank_results = rank_genes_group_to_data_frame(
            adata, 
            rank_keys, 
            groupby, 
            n_genes = n_genes
        )
        rank_results.to_csv(
            f'figures/rank_genes_groups_{groupby}_{k}_{tag}.tsv',
            sep = '\t',
            index = False
        )
        sc.pl.rank_genes_groups(
            adata, 
            n_genes = n_genes, 
            sharey = False,
            save = f'_{k}_{tag}'
        )

In [None]:
# Read the manually curated cluster annotation (all columns as strings) and
# keep the rows belonging to the skin dataset.
annotated_clusters = pd.read_csv(
    '../annotation_celltypes.txt',
    sep = '\t',
    dtype = str
)
annotated_clusters = annotated_clusters.loc[annotated_clusters.data == 'skin', :]
# NOTE(review): `tissue` is not defined in any cell visible above this one
# (it is only assigned as a loop variable in a later cell) - confirm which
# value of the `type` column is meant for the 'tcells_full' integration.
celltype_annotation = annotated_clusters.loc[
    annotated_clusters.type == tissue, 
    ['leiden_cluster', 'cell_type']
].drop_duplicates()

# Turn the table into a lookup: leiden cluster label -> cell type.
celltype_annotation = {
    r.leiden_cluster: r.cell_type for i, r in celltype_annotation.iterrows()
}

# Map every cell's cluster label to its annotated cell type; a cluster
# without an entry in the lookup raises a KeyError.
adata = results['tcells_full']['data']
adata.obs['cell_subtype'] = adata.obs['leiden_scvi_0.5'].apply(
    lambda x: celltype_annotation[x]
)

In [None]:
adata = results[f'tcells_full']['data']
fig, axs = plt.subplots(1, 2)
for ax, color_key in zip(axs, ['cell_subtype', 'FOXP3']):
    sc.pl.umap(
        adata,
        color = color_key,
        show = False,
        ax = ax,
        size = 50,
        vmax = 2,
        legend_loc = 'on data',
        legend_fontsize = 20,
        frameon = False
    )   
    cax = fig.axes[-1]
    cax.tick_params(axis = 'y', labelsize = 15)
    ax.set_title(color_key, fontsize = 20)

fig.set_figwidth(15)
fig.set_figheight(7.5)
fig.tight_layout()

In [None]:
adata.write_h5ad(
    f'../data/skin_inflammatory_disease.tcells_full.annotated.h5ad'
)

In [None]:
results = {}
for tissue in ['tissue', 'pbmc']:
    adata = sc.read_h5ad(
        f'../data/skin_inflammatory_disease.tcells_{tissue}.annotated.h5ad'
    )
    # work on a copy: assigning the layer object itself would let the
    # in-place normalization and log transform below overwrite the stored
    # raw counts
    adata.X = adata.layers['counts'].copy()
    # NOTE(review): 10e4 equals 1e5, while the other cells of this notebook
    # normalize to 1e4 - confirm the intended target sum
    sc.pp.normalize_total(
        adata,
        target_sum = 10e4
    )
    sc.pp.log1p(adata)
    adata.layers['logcpm'] = adata.X
    results[tissue] = adata
    
results

In [None]:
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
mpl.rcParams['pdf.fonttype'] = 42

cmap = LinearSegmentedColormap.from_list(
    'petrols',
    ['#edf8b1', '#7fcdbb', '#2c7fb8'][::-1],
    255
)

for k, adata in results.items():
    sc.tl.umap(
        adata,
        min_dist = 0.2,
        spread = 1.25
    )

    for color_key, kwargs, figsize in zip(
        ['Status', 'sample_id', 'cell_subtype', 'FOXP3', 'IKZF2', 'IL2RA'],
        [
            {'size': 20}, 
            {'size': 20}, 
            {'size': 20}, 
            {'vmax': 3, 'color_map': cmap, 'size': 20}, 
            {'vmax': 3, 'color_map': cmap, 'size': 20}, 
            {'vmax': 3, 'color_map': cmap, 'size': 20}
        ],
        [(5, 5), (5, 6), (5, 7.6), (5, 5.5), (5, 5.5), (5, 5.5)]
    ):
        fig, ax = plt.subplots()
        sc.pl.umap(
            adata,
            color = color_key,
            show = False,
            ax = ax,
            **kwargs
        )
        if color_key == 'sample_id':
            ax.get_legend().remove()

        figheight, figwidth = figsize
        fig.set_figwidth(figwidth)
        fig.set_figheight(figheight)
        fig.tight_layout()

## Cell abundance analysis with Milo

In [None]:
# load data
results = {}
format_string = '../data/skin_inflammatory_disease.{k}.integrated.h5ad'
for key in ['tcells_full']:
    results[key] = {}
    results[key]['data'] = sc.read_h5ad(
        format_string.format(k = key)
    )

In [None]:
import os
os.environ['R_HOME'] = '/users/daniel.malzl/.conda/envs/scpython/lib/R'
os.environ['TZ'] = 'Europe/Vienna' # needs to be set in some cases to avoid tzlocal error

In [None]:
import milopy
import milopy.core as milo
for k, d in results.items():
    adata = d['data']
    # Everything that is not 'normal' counts as disease. The category order
    # is given explicitly to enforce the disease vs healthy comparison;
    # otherwise the alphabetical order (healthy vs disease) would be used.
    # The column is assigned as a whole: setting it through
    # `.loc[:, 'condition']` writes the values into the existing string
    # column in newer pandas versions and loses the categories.
    adata.obs['condition'] = pd.Categorical(
        np.where(adata.obs['status'] == 'normal', 'healthy', 'disease'),
        categories = ['healthy', 'disease']
    )
    milo.make_nhoods(adata)
    milo.count_nhoods(adata, sample_col = 'sample_id')
    milo.DA_nhoods(adata, design="~condition")

In [None]:
for i, (k, d) in enumerate(results.items()):
    fig, axs = plt.subplots(1, 3)
    adata = d['data']
    sc.pl.umap(
        adata,
        color = 'FOXP3',
        frameon = False,
        show = False,
        size = 50,
        ax = axs[0],
        vmax = 5
    )
    
    sc.pl.umap(
        adata,
        color = 'SAT1',
        frameon = False,
        show = False,
        size = 50,
        ax = axs[1],
        vmax = 25
    )
    
        
    milopy.utils.build_nhood_graph(adata)
    milopy.plot.plot_nhood_graph(
        adata, 
        alpha=0.1, 
        min_size=5, 
        ax = axs[2], 
        show = False
    )
    for ax in axs:
        ax.set_title(ax.get_title(), fontsize = 20)
        
    for ax in fig.axes[-2:]:
        ax.tick_params(
            labelsize = 15
        )
    fig.set_figheight(5)
    fig.set_figwidth(15)
    fig.tight_layout()

In [None]:
adata = results['tcells_full']['data']
adata

In [None]:
res = 0.5
key_added = f'leiden_scvi_{res}'
sc.tl.leiden(
    adata, 
    key_added = key_added,
    resolution = res
)
fig, axs = plt.subplots(1, 2)
for ax, color_key in zip(axs, [key_added, 'FOXP3']):
    sc.pl.umap(
        adata,
        color = color_key,
        show = False,
        ax = ax,
        size = 50,
        vmax = 2,
        legend_loc = 'on data',
        legend_fontsize = 20,
        frameon = False
    )   
    cax = fig.axes[-1]
    cax.tick_params(axis = 'y', labelsize = 15)
    ax.set_title(color_key, fontsize = 20)

fig.set_figwidth(15)
fig.set_figheight(7.5)
fig.tight_layout()

In [None]:
def annotate_adata_on_gene_hi_lo(adata, gene):
    """Label every cell as ``<gene>_hi`` or ``<gene>_lo``.

    The cutoff is the median of the strictly positive values of ``gene`` in
    ``adata.X``; cells at or above the cutoff are ``hi``, all others ``lo``.
    If no cell has a positive value, every cell is labelled ``lo``.

    Parameters
    ----------
    adata
        AnnData-like object; ``adata[:, gene].X`` may be sparse or dense.
    gene
        Name of the gene to threshold on.

    Returns
    -------
    numpy.ndarray
        One string label per cell.
    """
    matrix = adata[:, gene].X
    # sparse matrices need densifying; dense arrays have no ``toarray``
    if hasattr(matrix, 'toarray'):
        matrix = matrix.toarray()
    expression_values = np.asarray(matrix).flatten()

    lo_label, hi_label = f'{gene}_lo', f'{gene}_hi'
    expressed = expression_values[expression_values > 0]
    if expressed.size == 0:
        # no expressing cell: the median would be NaN, so nothing is 'hi'
        return np.full(expression_values.shape, lo_label)

    cutoff = np.median(expressed)
    return np.where(expression_values >= cutoff, hi_label, lo_label)

bdata = adata[adata.obs['leiden_scvi_0.5'] == '2'].copy()
bdata.obs['sat1_status'] = annotate_adata_on_gene_hi_lo(
    bdata,
    'SAT1'
)
sc.pl.umap(
    bdata,
    color = ['SAT1', 'sat1_status'],
    vmax = 10,
    size = 10
)
majority_vote(
    bdata,
    'sat1_status',
    5,
    'sat1_status_majority_vote'
)
sc.pl.umap(
    bdata,
    color = ['SAT1', 'sat1_status', 'sat1_status_majority_vote'],
    vmax = 10,
    size = 10
)
bdata.obs.groupby('sat1_status_majority_vote').count()
adata.obs['sat1_status'] = 'SAT1_lo'
adata.obs.loc[bdata.obs.index, 'sat1_status'] = bdata.obs.sat1_status_majority_vote
milopy.utils.annotate_nhoods(adata, anno_col = 'sat1_status')

for col in ['nhood_annotation', 'nhood_annotation_frac']:
    bdata.uns['nhood_adata'].obs.loc[:, col] = adata.uns['nhood_adata'].obs.loc[:, col]

nhood_adata = bdata.uns['nhood_adata']
bdata.uns['nhood_adata'] = nhood_adata[nhood_adata.obs.index_cell.isin(bdata.obs.index)].copy()
bdata.uns['nhood_adata'].obs

# Subclustering Tregs

In [None]:
# load data
adata_dict = {}
format_string = '../data/skin_inflammatory_disease.tcells_full.annotated.h5ad'
adata = initialize_from_raw(format_string)
adata_dict['tregs'] = adata[adata.obs['cell_subtype'] == 'regulatory T'].copy()
del adata
    
adata_dict

In [None]:
fig, axs = plt.subplots(len(adata_dict), 3)
for i, (k, adata) in enumerate(adata_dict.items()):
    tmp = adata.copy()
    tmp.layers['counts'] = tmp.X.copy()
    sc.pp.normalize_total(
        tmp, 
        target_sum = 1e4
    )
    sc.pp.log1p(tmp)
    sc.pp.highly_variable_genes(
        tmp,
        n_top_genes = 4000,
        layer = "counts",
        flavor = "seurat_v3",
    )
    sc.pp.pca(
        tmp, 
        n_comps = 40, 
        svd_solver = 'arpack',
        use_highly_variable = True,
    )
    sc.pp.neighbors(
        tmp,
        use_rep = 'X_pca'
    )
    sc.tl.umap(tmp)
    
    for ax, color_key in zip(axs, ['status', 'tissue', 'SAT1']):
        sc.pl.umap(
            tmp,
            color = color_key,
            size = 10,
            vmax = 5,
            ax = ax,
            show = False
        )
        
    axs[0].set_ylabel(k)
    
    del tmp
    
fig.set_figwidth(20)
fig.set_figheight(len(adata_dict) * 5)
fig.tight_layout()

In [None]:
# integrate tregs
# NOTE(review): this cell used to pass `use_highly_variable_genes = hvg`, but
# `hvg` is not defined in any visible cell. The setting is now explicit in
# 'kwargs'. False is chosen because the later cells only plot results whose
# key carries no 'hvg' suffix - confirm that this matches the run that
# produced the saved files.
integration_params = {
    'tregs': {
        'kwargs': dict(use_highly_variable_genes = False),
        'key': 'sample_id'
    }
}

# batches with fewer cells than this are dropped before integration
min_cells_per_batch = 5

results = {}
for key, adata in adata_dict.items():
    key = key.replace('tcells', 'tregs')
    params = integration_params[key]
    integration_key = params['key']

    print(key)
    # count the cells of every batch once and keep cells of large enough
    # batches (replaces one DataFrame lookup per cell)
    batch_sizes = adata.obs[integration_key].value_counts()
    large_enough = adata.obs[integration_key].isin(
        batch_sizes[batch_sizes >= min_cells_per_batch].index
    )
    results[key] = integrate_data_scvi(
        adata[large_enough.to_numpy()].copy(),
        integration_key,
        train_size = 1,
        **params['kwargs']
    )

    results[key]['data'].write(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )

    results[key]['model'].save(
        f'../data/skin_{key}.integration.scvi.model',
        overwrite = True
    )

In [None]:
# restore results if kernel breaks or gets shut down
results = {}
for key in ['tregs']:
    data = sc.read_h5ad(
        f'../data/skin_inflammatory_disease.{key}.integrated.h5ad'
    )
    results[key] = {
        'data': data,
        'model': scvi.model.SCVI.load(
            f'../data/skin_{key}.integration.scvi.model', 
            adata = data
        )
    }

In [None]:
for k, d in results.items():
    d['data'].layers['scvi_normalized'] = d['model'].get_normalized_expression(library_size = 1e4)

In [None]:
fig, axs = plot_integration_results(
    {k: d for k, d in results.items() if not 'hvg' in k},
    ['tissue', 'status', 'SAT1'],
    [
        dict(size = 50, vmax = None),
        dict(size = 50, vmax = None),
        dict(size = 50, vmax = 20, layer = 'scvi_normalized')
    ],
    data_key = 'data'
)

In [None]:
import seaborn as sns
fig, axs = plt.subplots(1, 3)
for ax, (k, d) in zip(axs, results.items()):
    adata = d['data']
    # adata = d['data'].copy()
    # sc.pp.normalize_total(adata, target_sum = 1e4)
    # sc.pp.log1p(adata)
    # sc.pp.scale(adata)
    # expr_vals = adata[:, 'SAT1'].X.toarray()
    # expr_vals = np.log(d['data'][:, 'SAT1'].layers['scvi_normalized'].toarray())
    expr_vals = adata[:, 'SAT1'].layers['scvi_normalized'].toarray()
    # mean, std = expr_vals.mean(), expr_vals.std()
    df = pd.DataFrame(
        expr_vals,
        index = adata.obs.index,
        columns = ['SAT1']
    )
    sns.histplot(
        x = 'SAT1',
        data = df,
        ax = ax
    )
    
    # alternatively use percentiles
    percentiles = df['SAT1'].quantile([0.5, 0.75, 0.9, 0.95]).to_dict()
    for i, (p, v) in enumerate(percentiles.items()):
        p = int(p * 100)
    # for p, v in zip(['mean', '1xStd'], [mean, mean + std]):
        ax.axvline(v, ls = '--')
        ax.text(
            v,
            ax.get_ylim()[1] - ax.get_ylim()[1] * 0.05 * (i + 1),
            '{}% ({})'.format(p, (df['SAT1'] >= v).sum()),
            va = 'top'
        )

    
fig.set_figwidth(20)
fig.set_figheight(6)
fig.tight_layout()

In [None]:
for k, d in results.items(): 
    adata = d['data']
    # expr_vals = np.log(adata[:, 'SAT1'].layers['scvi_normalized'].toarray())
    expr_vals = adata[:, 'SAT1'].layers['scvi_normalized'].toarray()
    df = pd.DataFrame(
        expr_vals,
        index = adata.obs.index,
        columns = ['SAT1']
    )
    
    # mean, std = expr_vals.mean(), expr_vals.std()
    
    x = expr_vals
    # c = mean + std
    c = df['SAT1'].quantile(0.5)
    adata.obs['sat1_status'] = np.select(
        [x < c, x >= c],
        ['lo', 'hi']
    )
    
fig, axs = plot_integration_results(
    {k: d for k, d in results.items() if not 'hvg' in k},
    ['tissue', 'status', 'sat1_status', 'SAT1'],
    [
        dict(size = 50, vmax = None),
        dict(size = 50, vmax = None),
        dict(size = 50, vmax = None),
        dict(size = 50, vmax = 20, layer = 'scvi_normalized')
    ],
    data_key = 'data'
)

In [None]:
fig, axs = plt.subplots(3, 3)
for k, d in results.items():
    majority_vote(
        d['data'],
        'sat1_status',
        5,
        'sat1_status_majority_vote'
    )
    
for i, (k, d) in enumerate(results.items()):
    col_axs = axs[:, i]
    sc.pl.umap(
        d['data'],
        color = 'sat1_status_majority_vote',
        frameon = False,
        show = False,
        ax = col_axs[0],
        size = 50
    )
    col_axs[0].set_title(k)
    
    sc.pl.umap(
        d['data'],
        color = 'SAT1',
        layer = 'scvi_normalized',
        frameon = False,
        show = False,
        ax = col_axs[1],
        vmax = 20,
        size = 50
    )
    
    sc.pl.violin(
        d['data'],
        groupby = 'sat1_status_majority_vote',
        layer = 'scvi_normalized',
        keys = ['SAT1'],
        log = True,
        use_raw = False,
        ax = col_axs[2],
        show = False
    )
    

fig.set_figwidth(15)
fig.set_figheight(15)
fig.tight_layout()
fig.savefig('../plots/skin_inflammatory_disease_tregs_sat1hi_annotated.png')

In [None]:
for k, d in results.items():
    print(d['data'].obs.groupby('sat1_status_majority_vote').count().iloc[:, 0])

In [None]:
for k, d in results.items():
    d['data'].write(
        f'../data/skin_inflammatory_disease.{k}.integrated.annotated.h5ad'
    )

In [None]:
# load data
results = {}
format_string = '../data/skin_inflammatory_disease.{k}.integrated.annotated.h5ad'
for key in ['tregs']:
    results[key] = {}
    results[key]['data'] = sc.read_h5ad(
        format_string.format(k = key)
    )

In [None]:
import os
# environment variables for the R installation of the conda environment
# NOTE(review): presumably required by the R-backed calls in the following cells - confirm
os.environ.update(
    {
        'R_HOME': '/users/daniel.malzl/.conda/envs/scpython/lib/R',
        # needs to be set in some cases to avoid tzlocal error
        'TZ': 'Europe/Vienna',
    }
)

In [None]:
import milopy
import milopy.core as milo
# Run the milo differential-abundance workflow (disease vs healthy) on every result.
for k, d in results.items():
    adata = d['data']
    # every status other than 'normal' is treated as disease
    # The category order needs to be set explicitly in order to enforce the
    # disease vs healthy comparison; otherwise alphabetical order is used and
    # the test is healthy vs disease.
    # The column is assigned directly instead of through `.loc[:, 'condition']`:
    # with pandas >= 2.0 a `.loc[:, col] = values` assignment writes into the
    # existing column in place and keeps its dtype, so the categorical (and with
    # it the category order) would silently be dropped.
    adata.obs['condition'] = pd.Categorical(
        np.where(adata.obs.status != 'normal', 'disease', 'healthy'),
        categories = ['healthy', 'disease']
    )
    # build neighbourhoods, count cells per sample and neighbourhood,
    # then test for differential abundance between the two conditions
    milo.make_nhoods(adata)
    milo.count_nhoods(adata, sample_col = 'sample_id')
    milo.DA_nhoods(adata, design="~ condition")
    # attach the majority-voted SAT1 status to every neighbourhood
    milopy.utils.annotate_nhoods(adata, anno_col = 'sat1_status_majority_vote')
    # adata.uns['nhood_adata'].obs['sat1_status'] = adata.uns['nhood_adata'].obs[['nhood_annotation', 'nhood_annotation_frac']].apply(
    #     lambda x: x[0] if x[1] > 0.7 else 'mixed',
    #     axis = 1
    # )

In [None]:
# For every result draw the SAT1 UMAP next to the milo neighbourhood graph.
for k, d in results.items():
    fig, axs = plt.subplots(1, 2)
    adata = d['data']
    # left panel: UMAP coloured by SAT1 (vmax = 15)
    sc.pl.umap(
        adata,
        color = 'SAT1',
        frameon = False,
        show = False,
        size = 50,
        ax = axs[0],
        vmax = 15
    )

    # right panel: neighbourhood graph built from the DA results
    milopy.utils.build_nhood_graph(adata)
    milopy.plot.plot_nhood_graph(
        adata,
        alpha=0.5,
        min_size=5,
        ax = axs[1],
        show = False
    )
    fig.set_figheight(4)
    fig.set_figwidth(10)
    fig.tight_layout()
    # the file name is derived from the result key so that several results do
    # not overwrite each other; for 'tregs' this is still tregs_DA.umap.pdf
    fig.savefig(f'../plots/{k}_DA.umap.pdf')

In [None]:
# display the per-neighbourhood obs table stored under uns['nhood_adata']
# NOTE: `adata` is not defined in this cell; it is whatever the last executed
# loop over `results` left bound
adata.uns['nhood_adata'].obs

In [None]:
import seaborn as sns
def annotate_enrichment(x):
    """
    Classify one neighbourhood from its 'SpatialFDR' and 'logFC' entries.

    Returns 'not_significant' unless SpatialFDR is below 0.1; otherwise
    'enriched' for a strictly positive logFC and 'depleted' for anything else.
    `x` only needs to support item access by these two keys.
    """
    is_significant = x['SpatialFDR'] < 0.1
    if not is_significant:
        return 'not_significant'

    if x['logFC'] > 0:
        return 'enriched'

    return 'depleted'

def plot_nhood_violin(adata):
    """
    Plot the neighbourhood logFC distribution per SAT1 annotation.

    Reads the neighbourhood table from ``adata.uns['nhood_adata'].obs`` and uses
    its 'logFC', 'SpatialFDR' and 'nhood_annotation' columns. A grey violin shows
    the logFC distribution per annotation; the individual neighbourhoods are
    overlaid as points coloured by the result of annotate_enrichment.

    The group order is fixed explicitly ('lo' before 'hi') and the tick labels
    are derived from that order, so the 'SAT1 low' / 'SAT1 high' labels cannot
    end up on the wrong group when the order of first appearance in the table
    differs. Annotation values without a pretty label keep their raw value.

    Returns the tuple (fig, ax).
    """
    tick_labels = {'lo': 'SAT1 low', 'hi': 'SAT1 high'}

    fig, ax = plt.subplots()
    # work on a copy so the table stored in adata is not modified
    df = adata.uns['nhood_adata'].obs.copy()
    #df['nhood_size'] = np.array(adata.uns['nhood_adata'].X.sum(1)).flatten()
    df['enriched'] = df[['logFC', 'SpatialFDR']].apply(
        annotate_enrichment, axis = 1
    )

    # known groups first (in label order), any other annotation value afterwards
    present = set(df['nhood_annotation'].dropna())
    order = [group for group in tick_labels if group in present]
    order += sorted(present.difference(tick_labels))

    # both layers are drawn on the same explicit axes instead of relying on
    # pyplot's current axes
    sns.violinplot(
        y = 'nhood_annotation',
        x = 'logFC',
        data = df,
        order = order,
        color = '#f2f2f2',
        ax = ax
    )
    sns.stripplot(
        y = 'nhood_annotation',
        x = 'logFC',
        data = df,
        order = order,
        hue = 'enriched',
        palette = {
            'enriched': '#f44336',
            'not_significant': '#bcbcbc',
            'depleted': '#6fa8dc'
        },
        ax = ax
    )
    ax.set_ylabel('SAT1 status', fontsize = 20)
    ax.set_xlabel('logFC', fontsize = 20)
    # categorical groups sit at positions 0..n-1 in the given order
    ax.set_yticks(range(len(order)))
    ax.set_yticklabels(
        [tick_labels.get(group, group) for group in order]
    )
    ax.tick_params(
        labelsize = 15
    )
    # legend is placed outside the axes, to the right
    ax.legend(
        loc='upper left',
        bbox_to_anchor=(1, 1),
        frameon=False,
        fontsize = 15
    )
    fig.set_figwidth(10)
    fig.set_figheight(5)
    fig.tight_layout()
    return fig, ax

# draw and save one differential-abundance violin plot per result
for k, d in results.items():
    fig, ax = plot_nhood_violin(d['data'])
    violin_pdf = f'../plots/{k}.diffabundance.violin.pdf'
    fig.savefig(violin_pdf)

In [None]:
# Dotplot of polyamine-related markers, healthy vs disease, one row per result.
# The grid has three rows; zip() stops at the shorter of axs and results.
# NOTE(review): this cell reads obs.Status / 'Normal' while the milo cell reads
# obs.status / 'normal' - confirm which column the loaded data actually carries.
markers = ['SAT1', 'ODC1', 'SMS', 'SRM', 'ARG2']
fig, axs = plt.subplots(3, 1)
dotplot_settings = dict(
    layer = 'scvi_normalized',
    groupby = 'Status2',
    var_names = markers,
    expression_cutoff = 0.5,
    show = False
)
for (k, d), ax in zip(results.items(), axs):
    # work on a copy so the helper column does not end up in the stored object
    adata = d['data'].copy()
    adata.obs['Status2'] = adata.obs.Status.apply(
        lambda status: 'disease' if status != 'Normal' else 'healthy'
    )
    sc.pl.dotplot(adata, ax = ax, **dotplot_settings)
    ax.set_title(k, fontsize = 17)
    # restyle the last four axes of the figure
    # NOTE(review): presumably the sub-axes created by the dotplot call above - confirm
    for subax in fig.axes[-4:]:
        single_line_title = subax.get_title().replace('\n', ' ')
        subax.tick_params(labelsize = 15)
        subax.set_title(single_line_title, fontsize = 17)

fig.set_figwidth(10)
fig.set_figheight(10)
fig.tight_layout()
fig.savefig('../plots/skin_inflammatory_disease_tregs_polyamine_hvd_dotplot.pdf')

In [None]:
# Dotplot of the marker panel below, healthy vs disease, one row per result.
# The grid has three rows; zip() stops at the shorter of axs and results.
# NOTE(review): this cell reads obs.Status / 'Normal' while the milo cell reads
# obs.status / 'normal' - confirm which column the loaded data actually carries.
markers = [
    'S1PR4', 'CXCR4', 'CXCR6', 'FABP4', 'LGALS3', 'ITGAE',
    'RUNX3', 'PRDM1', 'CD69', 'ICOS', 'CD28'
]
fig, axs = plt.subplots(3, 1)
dotplot_settings = dict(
    layer = 'scvi_normalized',
    groupby = 'Status2',
    var_names = markers,
    expression_cutoff = 1,
    show = False
)
for (k, d), ax in zip(results.items(), axs):
    # work on a copy so the helper column does not end up in the stored object
    adata = d['data'].copy()
    adata.obs['Status2'] = adata.obs.Status.apply(
        lambda status: 'disease' if status != 'Normal' else 'healthy'
    )
    sc.pl.dotplot(adata, ax = ax, **dotplot_settings)
    ax.set_title(k, fontsize = 17)
    # restyle the last four axes of the figure
    # NOTE(review): presumably the sub-axes created by the dotplot call above - confirm
    for subax in fig.axes[-4:]:
        single_line_title = subax.get_title().replace('\n', ' ')
        subax.tick_params(labelsize = 15)
        subax.set_title(single_line_title, fontsize = 17)

fig.set_figwidth(10)
fig.set_figheight(10)
fig.tight_layout()
fig.savefig('../plots/skin_inflammatory_disease_tregs_trm_status_dotplot.pdf')

In [None]:
# Dotplot of `markers` in non-'Normal' cells only, grouped by the majority-voted
# SAT1 status, one row per result.
# NOTE: `markers` is not defined in this cell; it is whatever an earlier
# executed cell left bound.
fig, axs = plt.subplots(3, 1)
dotplot_settings = dict(
    layer = 'scvi_normalized',
    groupby = 'sat1_status_majority_vote',
    var_names = markers,
    expression_cutoff = 1,
    show = False
)
for (k, d), ax in zip(results.items(), axs):
    adata = d['data'].copy()
    # keep every cell whose Status is not 'Normal'
    not_normal = adata.obs.Status != 'Normal'
    sc.pl.dotplot(adata[not_normal], ax = ax, **dotplot_settings)
    ax.set_title(k, fontsize = 17)
    # restyle the last four axes of the figure
    # NOTE(review): presumably the sub-axes created by the dotplot call above - confirm
    for subax in fig.axes[-4:]:
        single_line_title = subax.get_title().replace('\n', ' ')
        subax.tick_params(labelsize = 15)
        subax.set_title(single_line_title, fontsize = 17)

fig.set_figwidth(10)
fig.set_figheight(10)
fig.tight_layout()
fig.savefig('../plots/skin_inflammatory_disease_tregs_disease_trm_sat1_dotplot.pdf')

# A simple way to pick apart CD4 and CD8 T cells

In [None]:
# naive_CD4 and naive_CD8 T cells are very similar except for the expression of
# the CD4 / CD8 receptor, and CellAssign does not seem to be able to pick them
# apart based on the given markers.
# We therefore inspect the difference of the two genes manually, here within
# the cells that CellAssign called naive_CD8.

# np.diff over the columns ['CD8A', 'CD4'] yields CD4 - CD8A for every cell
# NOTE(review): assumes .X of the subset is a dense array - confirm
naive_cd8_expr = bdata[bdata.obs.CellAssign == 'naive_CD8', ['CD8A', 'CD4']].X
diff = np.sort(np.diff(naive_cd8_expr).flatten())
# position of the first strictly positive difference in the sorted values
cutidx = np.nonzero(diff > 0)[0][0]
fig, ax = plt.subplots()
ax.plot(np.arange(len(diff)), diff)
# mark where the sorted differences become positive
ax.axvline(cutidx)

In [None]:
def reassign_cells(x):
    """
    Return the cell label for one row with 'CellAssign' and 'diff' entries.

    Rows labelled 'naive_CD8' whose 'diff' is strictly positive are relabelled
    'naive_CD4'; every other row keeps its 'CellAssign' value. 'diff' is only
    read for 'naive_CD8' rows.
    """
    label = x['CellAssign']
    if label == 'naive_CD8' and x['diff'] > 0:
        return 'naive_CD4'

    return label
    
# per-cell difference over the columns ['CD8A', 'CD4'] (i.e. CD4 - CD8A), stored in obs['diff']
# NOTE(review): assumes bdata.X is a dense array - confirm
bdata.obs['diff'] = np.diff(bdata[:, ['CD8A', 'CD4']].X).flatten()
# overwrite CellAssign row by row with the label returned by reassign_cells
# NOTE(review): if CellAssign is categorical, 'naive_CD4' presumably has to be
# an existing category for this assignment to succeed - confirm
bdata.obs.loc[:, ['CellAssign']] = bdata.obs[['CellAssign', 'diff']].apply(
    reassign_cells,
    axis = 1
)