In [None]:
import pandas as pd
# Read the quantification table (first CSV column becomes the index) and
# transpose it, so the labels that were CSV columns become the row index.
data = pd.read_csv('../data/GeneGroup_Quant_224Samples.csv', index_col=0).T
data

In [None]:
# Read the sample sheet: the first spreadsheet row is skipped and the first
# remaining column is used as the index.
obs = pd.read_excel(
    '../data/Sample Info_D0-98.xlsx',
    index_col=0,
    skiprows=1
)
# Replace the column as a whole. Assigning through `obs.loc[:, col] = ...`
# makes recent pandas versions write the values into the existing column and
# keep its previous dtype, which silently discards the categorical conversion.
obs['Age [weeks]'] = pd.Categorical(obs['Age [weeks]'])
# Deliberately NOT done, since it does not make sense biologically:
# obs.loc[
#     (obs['Timepoint Isolation'] == 'D7') &
#     (obs['Condition'] == 'Naive'),
#     'Timepoint Isolation'
# ] = 'D0'
# Short, lower-case column names are what the rest of the notebook refers to.
obs = obs.rename(
    columns = {
        'Condition': 'condition',
        'Sex': 'sex',
        'Timepoint Isolation': 'time',
        'Age [weeks]': 'age'
    }
)
obs

In [None]:
# One sample was renamed by mistake during data processing, so its label in
# the metadata no longer matches the label in the data frame.
# The mismatch is patched by hand here.
sample_id_fixes = {'Sham_M_4w_D98_267': 'Sham_M_4w_D98_268'}
obs = obs.rename(index = sample_id_fixes)

# Check normality assumption for data

In [None]:
import anndata as ad
import scanpy as sc
import seaborn as sns

from mefistotools import preprocess


# Drop features with too many missing values.
# NOTE(review): 0.25 is presumably the tolerated NaN fraction within each
# ('condition', 'time') group - confirm against mefistotools.preprocess.
filtered_df = preprocess.filter_high_nan_features_by_group(
    data,
    obs,
    ['condition', 'time'],
    0.25
)
# Wrap the filtered table in an AnnData object; the metadata rows are aligned
# to the rows that survived filtering.
adata = ad.AnnData(
    X = filtered_df,
    var = pd.DataFrame(index = filtered_df.columns),
    obs = obs.reindex(filtered_df.index)
)
# Histogram of every (not yet log-transformed) entry of the matrix, pooled
# over all rows and columns. The matrix is flattened once and reused.
raw_values = adata.X.flatten()
sns.histplot(
    x = raw_values,
    bins = 50,
    binrange = [0, 1e6]
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Draw 100 distinct feature (column) indices. A seeded generator is used so
# that the same features are shown on every run and in the later grid that
# reuses `non_normal_sample`.
rng = np.random.default_rng(2023)
non_normal_sample = rng.choice(
    len(adata.var),
    size = 100,
    replace = False
)

# One histogram per sampled feature on a 10 x 10 grid.
fig, axs = plt.subplots(10, 10, figsize = (20, 20))
for ax, col_idx in zip(axs.reshape(100), non_normal_sample):
    sns.histplot(
        adata.X[:, col_idx],
        ax = ax
    )

fig.tight_layout()

In [None]:
# sc.pp.log1p rewrites adata.X in place, so running this cell twice would
# transform the data twice. scanpy records a completed transform under
# adata.uns['log1p']; use that marker to keep the cell idempotent.
if 'log1p' not in adata.uns:
    sc.pp.log1p(adata)

# Histogram of all log-transformed entries, pooled over rows and columns.
log_values = adata.X.flatten()
sns.histplot(
    x = log_values,
    bins = 50
)

In [None]:
import numpy as np

from scipy.stats import shapiro
from statsmodels.stats.multitest import fdrcorrection

# Shapiro-Wilk test for every column of adata.X, ignoring NaN entries.
pvalues = [
    shapiro(feature_values[~np.isnan(feature_values)]).pvalue
    for feature_values in adata.X.T
]

# Benjamini-Hochberg correction; the cell shows how many features reject
# the normality hypothesis.
reject, padj = fdrcorrection(pvalues)
reject.sum()

In [None]:
fig, axs = plt.subplots(10, 10)
# adata.X has normally been log-transformed by an earlier cell already; an
# unconditional second call would plot log1p(log1p(x)). scanpy records a
# completed transform in adata.uns['log1p'], so only transform when missing.
if 'log1p' not in adata.uns:
    sc.pp.log1p(adata)
# Same randomly chosen features as in the raw-value grid.
for ax, col_idx in zip(axs.reshape(100), non_normal_sample):
    sns.histplot(
        adata.X[:, col_idx],
        ax = ax
    )

fig.set_figwidth(20)
fig.set_figheight(20)
fig.tight_layout()

# Filtering for time and condition varying data

In [None]:
from mefistotools import preprocess
import warnings
import numpy as np
        

# Samples entering the ANOVA: everything that is neither 'Naive' nor taken
# at 'D98'.
not_naive = obs.condition != 'Naive'
not_d98 = obs.time != 'D98'
subset = obs[not_naive & not_d98]

# All warnings raised while transforming and fitting are suppressed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    log_subset = np.log1p(data.loc[subset.index, :])
    anova_results = preprocess.featurewise_two_way_anova(
        log_subset,
        obs,
        ['condition', 'time'],
        allowed_nan_fraction = 0.25
    )

anova_results

In [None]:
# A feature counts as variable when at least one of the three adjusted
# p-values (condition, time, interaction) is below 0.1.
padj_columns = [
    'padj_C(condition)',
    'padj_C(time)',
    'padj_C(condition):C(time)'
]
is_variable = (anova_results[padj_columns] < 0.1).any(axis = 1)
variable_features_index = anova_results[is_variable].index

# Export for variance partitioning

In [None]:
# Column selectors per export: an all-True mask for the full table and the
# ANOVA-selected feature labels for the filtered one.
export_selectors = {
    'full': np.ones(shape = len(data.columns), dtype = bool),
    'anova': variable_features_index,
}
for suffix, index in export_selectors.items():
    selected = data.loc[:, index]
    df = preprocess.filter_high_nan_features_by_group(
        selected,
        obs,
        ['condition', 'time'],
        0
    )
    df.to_csv(f'../data/data_{suffix}.csv')
    # The metadata file is rewritten on every pass; the last export wins.
    obs.reindex(df.index).to_csv('../data/metadata.csv')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

# Load both variance-partition tables and tag each with its data set label.
var_part_full = pd.read_csv('../data/var_part_full.csv', index_col = 0).assign(data = 'full')
var_part_anova = pd.read_csv('../data/var_part_anova.csv', index_col = 0).assign(data = 'anova_filtered')

# Long format: one row per (data set, covariate) value.
df = pd.concat([var_part_full, var_part_anova]).melt(
    id_vars = ['data'],
    value_vars = ['condition', 'sex', 'time', 'age', 'Residuals'],
    value_name = 'variance_explained',
    var_name = 'covariate'
)

# Violin per covariate, split by data set, written to a PDF.
fig, ax = plt.subplots(figsize = (6, 4))
sns.violinplot(
    data = df,
    x = 'covariate',
    y = 'variance_explained',
    hue = 'data',
    ax = ax,
    cut = 0
)
fig.tight_layout()
fig.savefig('varpart.pdf')

In [None]:
# Mean variance explained per (data set, covariate).
mean_var_expl = df.groupby(['data', 'covariate']).mean()
# Rows of the ANOVA-filtered data set only, indexed by covariate.
anova_mean_var_expl = mean_var_expl.xs('anova_filtered', level = 'data')
# Sum over every covariate except the residuals.
is_not_residual = anova_mean_var_expl.index != 'Residuals'
anova_mean_var_expl[is_not_residual].sum()

# Generate datasets

In [None]:
# Same NaN filter (threshold 0.25 over 'condition' x 'time') as used for the
# normality check; only the surviving feature labels are needed afterwards.
nan_filter_groups = ['condition', 'time']
filtered_df = preprocess.filter_high_nan_features_by_group(data, obs, nan_filter_groups, 0.25)
nan_filtered_features_index = filtered_df.columns

In [None]:
import anndata as ad

only_early = (obs.time != 'D98')
only_non_naive = (obs.condition != 'Naive')
only_early_index = obs[only_early].index
only_non_naive_and_early_index = obs[only_early & only_non_naive].index


def _build_adata(sample_index, feature_index):
    """Return an AnnData holding the given samples (rows) and features (columns)."""
    return ad.AnnData(
        X = data.loc[sample_index, feature_index],
        obs = obs.loc[sample_index, :].reindex(sample_index),
        var = pd.DataFrame(index = feature_index)
    )


# Author's note: anndata converts float64 to float32, which loses some
# precision. It should not change anything, but it explains why values differ
# slightly between the original DataFrame and the AnnData.
feature_sets = {
    'anova_filtered': variable_features_index,
    'nan_filtered': nan_filtered_features_index,
}
sample_sets = {
    'with_naive': only_early_index,
    'without_naive': only_non_naive_and_early_index,
}
# Every feature set combined with every sample set, keyed '<features>_<samples>'.
datasets = {
    f'{feature_key}_{sample_key}': _build_adata(sample_index, feature_index)
    for feature_key, feature_index in feature_sets.items()
    for sample_key, sample_index in sample_sets.items()
}

In [None]:
# Keep an untouched copy of X in the 'raw' layer of every data set so that
# later cells can restore it after in-place transformations.
for dataset_name in datasets:
    adata = datasets[dataset_name]
    adata.layers['raw'] = adata.X.copy()

datasets

# Export for crosscheck with R version
Skipped because the R version uses the Python implementation underneath. The cell is kept anyway for documentation purposes.

In [None]:
# adata = datasets['anova_filtered_with_naive']
# data_df = adata.to_df()
# data_df = data_df.merge(
#     adata.obs[['condition', 'time']].rename(
#         columns = {'condition': 'group'}
#     ),
#     left_index = True,
#     right_index = True,
#     how = 'inner'
# )

# time_to_int = {t: i for i, t in enumerate(['D7', 'D10', 'D14'])}
# data_df.loc[:, 'time'] = data_df.time.apply(
#     lambda x: time_to_int[x]
# )
# data_df.reset_index(
#     names = ['sample'],
#     inplace = True
# )
# data_df_long = data_df.melt(
#     id_vars = ['sample', 'group', 'time'],
#     value_vars = data_df.columns,
#     value_name = 'value',
#     var_name = 'feature'
# )
# data_df_long.to_csv(
#     '../data/data_for_mefisto_r.csv',
#     index = False
# )

# Train MEFISTO in different configurations

In [None]:
import muon as mu
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib as mpl

from mefistotools import plot

import os


mpl.rcParams['pdf.fonttype'] = 42


time_ordering = ['D7', 'D10', 'D14']
# MEFISTO can only interpret numbers for the smooth covariate, so each time
# label is mapped to its position in `time_ordering`. The mapping does not
# depend on the loop variables and is therefore built once.
factors = {
    k: i for i, k in enumerate(time_ordering)
}
# mu.tl.mofa writes its model file below; make sure the directory exists.
os.makedirs('../models', exist_ok = True)

for n_factors in [20, 50]:
    for k, adata in datasets.items():
        print(k, n_factors)
        plot_dir = f'../plots/{k}_{n_factors}'
        # makedirs also creates a missing '../plots' parent and does not fail
        # when the directory already exists.
        os.makedirs(plot_dir, exist_ok = True)

        # Start from the untouched values for the dimensionality reductions.
        adata.X = adata.layers['raw'].copy()
        figs = plot.plot_reduced_dimensions(
            adata,
            n_cols = 5,
            n_rows = 3
        )

        for fig, dim_red_type in zip(figs, ['pca', 'umap']):
            fig.savefig(
                f'{plot_dir}/{k}_{n_factors}.{dim_red_type}.pdf'
            )
            plt.close(fig)

        adata.obs['timefactor'] = adata.obs['time'].apply(
            lambda x: factors[x]
        )

        # Restore the raw values again before the log transform, so the
        # transform is applied exactly once per configuration.
        adata.X = adata.layers['raw'].copy()
        sc.pp.log1p(adata)
        adata.write(f'../data/{k}_{n_factors}.h5ad')

        mu.tl.mofa(
            adata,
            n_factors = n_factors, # number of factors to fit
            groups_label = 'condition', # column of adata.obs to use for data grouping
            center_groups = False,
            n_iterations = 2000,
            smooth_covariate = 'timefactor', # column to use as time variable
            smooth_kwargs = {"n_grid": 50, "start_opt": 50, "opt_freq": 50}, # additional arguments for MEFISTO
            outfile = f'../models/{k}_{n_factors}.h5ad',
            seed = 2023,
            convergence_mode = 'fast'
        )

In [None]:
import muon as mu
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib as mpl

from mefistotools import plot

import os


mpl.rcParams['pdf.fonttype'] = 42


time_ordering = ['D7', 'D10', 'D14']
# MEFISTO can only interpret numbers for the smooth covariate, so each time
# label is mapped to its position in `time_ordering`. The mapping does not
# depend on the loop variables and is therefore built once.
factors = {
    k: i for i, k in enumerate(time_ordering)
}
# mu.tl.mofa writes its model file below; make sure the directory exists.
os.makedirs('../models', exist_ok = True)

for n_factors in [20, 50]:
    for k, adata in datasets.items():
        print(k, n_factors)
        plot_dir = f'../plots/{k}_{n_factors}_nogroup'
        # makedirs also creates a missing '../plots' parent and does not fail
        # when the directory already exists.
        os.makedirs(plot_dir, exist_ok = True)

        # Start from the untouched values for the dimensionality reductions.
        adata.X = adata.layers['raw'].copy()
        figs = plot.plot_reduced_dimensions(
            adata,
            n_cols = 5,
            n_rows = 3
        )

        for fig, dim_red_type in zip(figs, ['pca', 'umap']):
            fig.savefig(
                f'{plot_dir}/{k}_{n_factors}.{dim_red_type}.pdf'
            )
            plt.close(fig)

        adata.obs['timefactor'] = adata.obs['time'].apply(
            lambda x: factors[x]
        )

        # Restore the raw values again before the log transform, so the
        # transform is applied exactly once per configuration.
        adata.X = adata.layers['raw'].copy()
        sc.pp.log1p(adata)
        adata.write(f'../data/{k}_{n_factors}_nogroup.h5ad')

        # Same configuration as the grouped run, but without `groups_label`,
        # i.e. no grouping column is passed to the model.
        mu.tl.mofa(
            adata,
            n_factors = n_factors, # number of factors to fit
            # groups_label = 'condition', # column of adata.obs to use for data grouping
            center_groups = False,
            n_iterations = 2000,
            smooth_covariate = 'timefactor', # column to use as time variable
            smooth_kwargs = {"n_grid": 50, "start_opt": 50, "opt_freq": 50}, # additional arguments for MEFISTO
            outfile = f'../models/{k}_{n_factors}_nogroup.h5ad',
            seed = 2023,
            convergence_mode = 'fast'
        )

# Evaluate models

In [None]:
import os

from mefistotools import plot


# Numeric time axis on the full metadata table (here including 'D98'):
# every label is replaced by its position in `time_ordering`.
time_ordering = ['D7', 'D10', 'D14', 'D98']
factors = dict(zip(time_ordering, range(len(time_ordering))))
time_column = 'time'
obs['timefactor'] = obs[time_column].apply(factors.__getitem__)

# Only the ANOVA-filtered data sets are evaluated here; the
# 'nan_filtered_with_naive' / 'nan_filtered_without_naive' ones are left out.
dataset_keys = [
    'anova_filtered_with_naive',
    'anova_filtered_without_naive',
]

# Factor-panel layout per model size: n_factors -> (columns, rows).
# may consume a lot of memory for large numbers of factors
panel_layouts = {20: (4, 5), 50: (5, 10)}
for n_factors, (c, r) in panel_layouts.items():
    for k in dataset_keys:
        plot_dir = f'../plots/{k}_{n_factors}/'

        print(k, n_factors)
        model_file = f'../models/{k}_{n_factors}.h5ad'
        # File name without its '.h5ad' extension.
        plot_prefix = os.path.splitext(os.path.basename(model_file))[0]
        plot.plot_model_evaluations(
            model_file,
            obs,
            plot_dir + plot_prefix,
            n_factors = n_factors,
            n_rows_factors = r,
            n_cols_factors = c,
            groups = False,
            group_column = 'condition'
        )

In [None]:
# `io` is otherwise only imported by cells further down, so on a fresh
# kernel this cell would fail with a NameError. Import it where it is used.
from mefistotools import io

# One factor-over-time figure per grouping column, model size and data set.
for group_column in ['sex', 'age', 'condition']:
    for (c, r), n_factors in zip([(4, 5), (5, 10)], [20, 50]):
        for k in dataset_keys:
            plot_dir = f'../plots/{k}_{n_factors}/'
            print(k, n_factors, group_column)
            model_file = f'../models/{k}_{n_factors}.h5ad'
            plot_prefix = os.path.basename(model_file)[:-5] + f'_{group_column}'
            m = io.read_model(model_file, obs)
            # Kept under its own name so the time-label -> integer mapping
            # stored in `factors` is not overwritten by a list.
            factor_indices = list(range(n_factors))
            fig = plot.plot_factors(
                m,
                'timefactor',
                factor_indices,
                group_column,
                r,
                c,
                0.5
            )
            fig.savefig(plot_dir + plot_prefix + '_factors.pdf')

In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
from mefistotools import plot, io

# Numeric time axis on the metadata: label -> position in `time_ordering`.
time_ordering = ['D7', 'D10', 'D14', 'D98']
factors = dict(zip(time_ordering, range(len(time_ordering))))
time_column = 'time'
obs['timefactor'] = obs[time_column].apply(factors.__getitem__)

model = io.read_model('../models/anova_filtered_with_naive_50.h5ad', obs)
# From here on `factors` holds the 50 factor indices of the model.
factors = list(range(50))
requested_values = factors + model.metadata.columns[:-1].to_list()
factors_and_metadata = model.fetch_values(requested_values)
# Keep the factor columns plus the four annotation columns used below.
fetched_columns = factors_and_metadata.columns
is_factor_column = fetched_columns.str.startswith('Factor')
is_annotation_column = fetched_columns.isin(['condition', 'sex', 'age', 'time'])
retain_columns = is_factor_column | is_annotation_column
factors_and_metadata = factors_and_metadata.loc[:, retain_columns]

fig, axs = plt.subplots(2, 5)

# Annotation used to colour the time course of each highlighted factor.
factor_to_color = {
    9: 'sex',
    14: 'condition',
    36: 'age',
    37: 'condition'
}
factor_pairs = [(14, 37), (9, 36)]
for i, (f1, f2) in enumerate(factor_pairs):
    # First three panels of the row: the factor pair, drawn without 'time'.
    without_time = factors_and_metadata.drop(columns = ['time'], errors = 'ignore')
    plot.plot_annotated_factor_combination(
        without_time,
        f'Factor{f1}',
        f'Factor{f2}',
        axs[i, :3]
    )

    # Last two panels: each factor over time as points plus a mean line.
    for j, factor in zip([3, 4], [f1, f2]):
        color = factor_to_color[factor]
        shared_kwargs = dict(
            data = factors_and_metadata,
            x = 'time',
            y = f'Factor{factor}',
            hue = color,
            ax = axs[i, j],
            palette = 'husl'
        )
        sns.scatterplot(**shared_kwargs)
        sns.lineplot(estimator = 'mean', **shared_kwargs)

        axs[i, j].set_title(factor)

fig.set_figwidth(20)
fig.set_figheight(8)
fig.tight_layout()
fig.savefig('factor_combination.pdf')

In [None]:
# load model
from mefistotools import io

# Model inspected in the following cells: ANOVA-filtered data including the
# naive samples, fitted with 50 factors.
k, n_factors = 'anova_filtered_with_naive', 50

model_path = f'../models/{k}_{n_factors}.h5ad'
model = io.read_model(model_path, obs)
variance_explained = model.get_variance_explained()

In [None]:
# Wide table of R2 values: one row per factor, one column per group.
variance_explained_pivot = variance_explained.pivot(
    index = ['Factor'],
    columns = ['Group'],
    values = 'R2'
)
# Export only the four highlighted factors as a tab-separated file.
highlighted_factors = [f'Factor{i}' for i in [9, 14, 36, 37]]
highlighted_r2 = variance_explained_pivot.loc[highlighted_factors, :]
highlighted_r2.to_csv('../data/variance_explained.tsv', sep = '\t')

In [None]:
import mofax
# Ask mofax for the underlying tables (return_data = True) of both summaries,
# sharedness first and smoothness second.
sharedness, smoothness = (
    summarise(model, return_data = True)
    for summarise in (mofax.plot_sharedness, mofax.plot_smoothness)
)

In [None]:
# Display the smoothness table returned by mofax.plot_smoothness.
smoothness

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

mpl.rcParams['pdf.fonttype'] = 42


# Factors highlighted in both panels, in plotting order.
highlighted_factors = [f'Factor{i}' for i in [9, 14, 36, 37]]
# (table, axis label, value column) for each panel.
panels = [
    (sharedness, 'Sharedness', 'shared'),
    (smoothness, 'Smoothness', 'smooth')
]

fig, axs = plt.subplots(1, 2)
for (scores, label, value_column), ax in zip(panels, axs):
    # Select on a temporary index instead of overwriting the index of the
    # mofax tables: they stay unchanged for other cells, and the selection
    # does not end up with 'factor' as both index name and column.
    selected = scores.set_index('factor').loc[highlighted_factors].reset_index()
    sns.barplot(
        data = selected,
        x = value_column,
        y = 'factor',
        color = 'grey',
        ax = ax
    )
    ax.set(xlabel = label, ylabel = "Factor")
    sns.despine(offset=10, trim=True, ax=ax)

fig.set_figwidth(10)
fig.set_figheight(5)
fig.tight_layout()
fig.savefig('../plots/shared_and_smooth.pdf')

In [None]:
from mefistotools import io, plot


dataset_keys = [
    'anova_filtered_with_naive',
    'anova_filtered_without_naive',
    'nan_filtered_with_naive',
    'nan_filtered_without_naive'
]

# Every trained configuration: model size x data set x grouped / ungrouped.
configurations = [
    (size, key, variant)
    for size in [20, 50]
    for key in dataset_keys
    for variant in ['', '_nogroup']
]

# Load each model, keep it under '<dataset>_<n_factors><suffix>' and save a
# factor-value overview into the matching plot directory.
models = {}
for n_factors, k, suffix in configurations:
    print(n_factors, k, suffix)

    model_file = f'../models/{k}_{n_factors}{suffix}.h5ad'
    plot_dir = f'../plots/{k}_{n_factors}{suffix}/'
    model = io.read_model(model_file, obs)
    models[f'{k}_{n_factors}{suffix}'] = model
    fig = plot.plot_factor_values(model)
    fig.savefig(plot_dir + 'factor_values.pdf')

# GSEA

In [None]:
import gseapy as gp

# Biomart attributes: human gene id/name together with the mouse homolog
# id/name, queried from the human gene data set.
homolog_attributes = [
    'ensembl_gene_id',
    'external_gene_name',
    'mmusculus_homolog_ensembl_gene',
    'mmusculus_homolog_associated_gene_name'
]
bm = gp.Biomart()
h2m = bm.query(dataset='hsapiens_gene_ensembl', attributes=homolog_attributes)

# NOTE(review): presumably 'gs' and 'symbol' name the gene-set and gene-symbol
# columns of the CSV and `h2m` is used for human-to-mouse mapping - confirm
# against mefistotools.io.read_gene_sets.
pain_gene_sets = io.read_gene_sets('../resources/pain_gene_sets.csv', 'gs', 'symbol', h2m)

In [None]:
# Display the dictionary of loaded models, keyed '<dataset>_<n_factors><suffix>'.
models

In [None]:
from scipy.stats import zscore
import os

def gp_prerank(expression, gene_sets, fdr = 0.05, identifier = ''):
    """Run gseapy's preranked GSEA and return the terms below an FDR cutoff.

    Parameters
    ----------
    expression
        Ranking handed to ``gp.prerank`` as ``rnk``.
    gene_sets
        Gene sets handed to ``gp.prerank`` unchanged.
    fdr
        Only rows whose ``fdr`` value is strictly below this cutoff are kept.
    identifier
        Label included in the message printed when the run fails.

    Returns
    -------
    pandas.DataFrame
        One row per retained term with the term name in column ``Term`` and
        without the ``RES`` column. An empty frame is returned when the run
        raises or yields no result rows.
    """
    try:
        prerank_run = gp.prerank(
            rnk = expression,
            gene_sets=gene_sets,
            outdir = None,
            min_size = 5,
            max_size = 2000
        )

    except Exception as e:
        # Best effort: report the failure and let the caller continue with
        # an empty table instead of aborting the whole loop.
        print(
            'prerank {identifier} {exception}'.format(
                identifier = identifier,
                exception = str(e)
            )
        )
        return pd.DataFrame()

    # `from_dict` is a classmethod, so no throwaway instance is needed.
    results = pd.DataFrame \
        .from_dict(prerank_run.results, orient = 'index') \
        .reset_index(names = 'Term')

    # Without any result rows there is no 'fdr' column to filter on.
    if 'fdr' not in results.columns:
        return pd.DataFrame()

    results = results.loc[results.fdr < fdr, :]
    # 'RES' is dropped when present; a missing column is not an error.
    return results.drop(columns = ['RES'], errors = 'ignore')


def gp_enrich(gene_list, gene_sets, fdr = 0.05, identifier = '', background = None):
    """Run gseapy's over-representation test and keep the significant terms.

    Parameters
    ----------
    gene_list
        Genes handed to ``gp.enrich`` as ``gene_list``.
    gene_sets
        Gene sets handed to ``gp.enrich`` unchanged.
    fdr
        Only rows whose ``Adjusted P-value`` is strictly below this cutoff
        are kept.
    identifier
        Label included in the message printed when the run fails.
    background
        Background handed to ``gp.enrich`` unchanged.

    Returns
    -------
    pandas.DataFrame
        Independent copy of the retained rows, or an empty frame when the
        run raises.
    """
    try:
        enrich_run = gp.enrich(
            gene_list = gene_list,
            gene_sets = gene_sets,
            background = background,
            outdir = None
        )

    except Exception as e:
        # Best effort: report the failure and let the caller continue with
        # an empty table instead of aborting the whole loop.
        print(
            'enrich {identifier} {exception}'.format(
                identifier = identifier,
                exception = str(e)
            )
        )
        return pd.DataFrame()

    results = enrich_run.results
    # Callers add columns to the returned frame; handing back an explicit
    # copy avoids pandas' SettingWithCopyWarning on the filtered slice.
    return results.loc[results['Adjusted P-value'] < fdr, :].copy()



# Names of the Enrichr libraries used for the enrichment analyses.
gene_sets_to_retrieve = [
    'GO_Biological_Process_2023', 
    'MSigDB_Hallmark_2020'
]

# this is to download the gmt files from enrichr
gene_sets = []
for gene_set_name in gene_sets_to_retrieve:
    gene_sets.append(
        gp.get_library(
            name = gene_set_name,
            organism = 'mouse'
        )
    )

# The custom pain gene sets are appended to the downloaded libraries.
gene_sets.append(pain_gene_sets)


# results: model key -> {'enrich': DataFrame, 'prerank': DataFrame}
results = {}
# Per model key: the factor names for which an enrichment dot plot is saved.
plot_factors = {
    'anova_filtered_with_naive_20': set(
        [f'Factor{i}' for i in [4, 5, 10, 12, 15, 16, 19, 20]]
    ), 
    'anova_filtered_with_naive_20_nogroup': set(
        [f'Factor{i}' for i in [4, 16, 18]]
    ), 
    'anova_filtered_without_naive_20': set(
        [f'Factor{i}' for i in [16, 17, 18, 19]]
    ), 
    'anova_filtered_with_naive_50': set(
        [f'Factor{i}' for i in [9, 14, 36, 37]]
    ), 
    'anova_filtered_with_naive_50_nogroup': set(
        [f'Factor{i}' for i in [7, 11, 36, 37]]
    ), 
    'anova_filtered_without_naive_50': set(
        [f'Factor{i}' for i in [23, 24, 25, 26, 27, 29]]
    )
}

# Cutoff handed to gp_prerank / gp_enrich and to the dot plot. Both helpers
# keep rows strictly below the cutoff, so 1 only removes rows equal to 1.
fdr = 1
for k, model in models.items():
    # Only this one model is analysed; every other entry is skipped.
    if k != 'anova_filtered_with_naive_50':
        continue

    print(k)
    factors_to_plot = plot_factors[k] if k in plot_factors else set()
    factor_weights = plot.expand_gene_names(
        model.get_weights(df = True)
    )
    # zscore is applied along axis 1, i.e. separately to every row.
    # NOTE(review): rows are presumably features, so each feature is
    # standardised across factors - confirm this is the intended direction.
    factor_weights_zscore = factor_weights.apply(zscore, axis = 1)
    model_results_enrich, model_results_prerank = [], []
    for factor in factor_weights_zscore.columns:
        print(factor)
        ranks = factor_weights_zscore[factor].sort_values()
        for association in ['associated', 'antiassociated']:
            # Preranked analysis on the top (or bottom) 50 % of the features.
            n = int(len(ranks) * 0.5)
            features = ranks.nlargest(n) if association == 'associated' else ranks.nsmallest(n)

            # NOTE(review): prerank receives the library *names* plus the
            # pain gene sets, while gp_enrich below receives the downloaded
            # `gene_sets` - confirm the difference is intended.
            result_prerank = gp_prerank(
                features,
                gene_sets_to_retrieve + [pain_gene_sets],
                identifier = k,
                fdr = fdr
            )
            result_prerank['factor'] = factor
            result_prerank['association'] = association
            model_results_prerank.append(result_prerank)

            # Over-representation test on the top (or bottom) 10 % of the
            # features; gene names are upper-cased and the number of ranked
            # features is passed as background.
            n = int(len(ranks) * 0.1)
            features = ranks.nlargest(n) if association == 'associated' else ranks.nsmallest(n)
            result_enrich = gp_enrich(
                features.index.str.upper().to_list(),
                gene_sets,
                background = len(ranks),
                identifier = k,
                fdr = fdr
            )

            if factor in factors_to_plot:
                try:
                    ax = gp.dotplot(
                        result_enrich,
                        top_term = 10,
                        cutoff = fdr

                    )
                    fig = ax.get_figure()
                    fig.savefig(f'../plots/{k}/enrichment_{factor}_{association}.pdf')

                except ValueError as e:
                    # Plotting failed: show the table instead and continue.
                    print(result_enrich)
                    pass


            result_enrich['factor'] = factor
            result_enrich['association'] = association
            model_results_enrich.append(result_enrich)

    # One concatenated table per analysis type for this model.
    results[k] = {
        'enrich': pd.concat(model_results_enrich),
        'prerank': pd.concat(model_results_prerank)
    }

In [None]:
# Write one tab-separated file per model and analysis type, sorted by factor,
# association and the significance column of the respective analysis.
for k, result_dict in results.items():
    for result_name, result_frame in result_dict.items():
        if result_name == 'prerank':
            sort_col = 'fdr'
        else:
            sort_col = 'Adjusted P-value'
        ordered = result_frame.sort_values(['factor', 'association', sort_col])
        ordered.to_csv(
            f'../enrichments/{k}_{result_name}.tsv',
            index = False,
            sep = '\t'
        )