# Step 2: Normalize your data using sctransform

Use this notebook to normalize your data using the [sctransform](https://satijalab.org/seurat/articles/sctransform_vignette) package. 

Please note that the sctransform package is written in R. Therefore, to run this notebook you will first need to [install R]() on your system, as well as [sctransform version 0.4.1](https://github.com/satijalab/sctransform) and its dependencies (including the optional dependency [glmGamPoi](https://github.com/const-ae/glmGamPoi)). The notebook calls R from Python through the rpy2 package, so rpy2 must also be installed in your Python environment.

Also note that your raw data input should be a csv file arranged in [tidy format](https://tidyr.tidyverse.org/articles/tidy-data.html). At a minimum, your input csv should have five columns:
1. A column that corresponds to the first mode of your tensor. In metatranscriptomic data this column might indicate gene ID.
    - This first mode should generally be the longest in your tensor, and the one that corresponds to the variable you want clustered (e.g. genes in the case of metatranscriptomics data). The sparsity penalty (`lambda`) will be applied to this mode.
1. A column that corresponds to the second mode of your tensor. In metatranscriptomic data this column might indicate taxon ID.
1. A column that corresponds to the third mode of your tensor. This column should indicate sample ID.
    - **IMPORTANT: Sample IDs should be identical for different replicates of the same sample condition (see example below).**
1. A column that indicates the replicate ID of the sample.
1. A column that corresponds to the data variable. For raw metatranscriptomic data this column might contain read counts.

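Here's a snippet of how an example csv might be arranged (in this snippet the data column holds normalized residuals; in the raw input for this notebook, it would instead hold values such as read counts):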

| gene_id | taxon_id   | sample_id | replicate | residual |
|---------|------------|-----------|-----------|----------|
| K03839  | P. marinus | sample1   | A         | 3.02     |
| K03839  | P. marinus | sample1   | B         | 3.31     |
| K03839  | P. marinus | sample1   | C         | 3.18     |
| K03839  | P. marinus | sample2   | A         | -1.24    |
| ...     | ...        | ...       | ...       | ...      |
| K03320  | S. marinus | sample9   | C         | 0.05     |
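
To make the layout concrete, the sketch below builds a small tidy table and pivots one taxon (one slab) into a genes x sample-replicate matrix, which is roughly what the normalization loop later does for each slab. The column names (`gene_id`, `taxon_id`, `sample_id`, `replicate`, `count`) and the values are made up for illustration.

```python
import pandas as pd

# Toy tidy table in the layout described above (hypothetical column names, made-up counts)
tidy = pd.DataFrame({
    'gene_id': ['K03839', 'K03839', 'K03320', 'K03320', 'K03839'],
    'taxon_id': ['P. marinus'] * 5,
    'sample_id': ['sample1', 'sample1', 'sample1', 'sample1', 'sample2'],
    'replicate': ['A', 'B', 'A', 'B', 'A'],
    'count': [12, 15, 0, 3, 7],
})

# Combine sample and replicate IDs so that each replicate becomes its own column
tidy['sample_rep_id'] = tidy['sample_id'] + '_' + tidy['replicate']

# One taxon = one slab: rows are genes, columns are sample replicates;
# combinations missing from the tidy table are filled with 0
slab = (
    tidy[tidy['taxon_id'] == 'P. marinus']
    .pivot(index='gene_id', columns='sample_rep_id', values='count')
    .fillna(0)
)
print(slab)
```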

In [None]:
# imports

# python packages
import math
import numpy as np
import os
import pandas as pd
import rpy2
import seaborn as sns

from matplotlib import pyplot as plt

# rpy2 imports
from rpy2 import robjects as ro
from rpy2.robjects.packages import importr
from rpy2.ipython.ggplot import image_png
from rpy2.robjects import pandas2ri

# activate automatic conversion between pandas and R data frames
pandas2ri.activate()
# load rpy2 extension for ipython
%load_ext rpy2.ipython


In [None]:
# install & import sctransform and other r package dependencies

# check if sctransform is installed
if not ro.packages.isinstalled('sctransform'):
    # select CRAN mirror
    utils = importr('utils')
    utils.chooseCRANmirror(ind=1)
    # install sctransform (note: this installs the current CRAN release, which may not be version 0.4.1)
    utils.install_packages(ro.vectors.StrVector(['sctransform']))

# import sctransform and R Matrix package
sctransform = importr('sctransform')
rmatrix = importr('Matrix')

# check sctransform version (should be 0.4.1)
print(f'Installed sctransform version: {sctransform.__version__}')
if sctransform.__version__ != '0.4.1':
    raise Exception('Please ensure that the installed sctransform is version 0.4.1')

# check that the glmGamPoi dependency is installed
if not ro.packages.isinstalled('glmGamPoi'):
    raise Exception('Please install glmGamPoi: https://github.com/const-ae/glmGamPoi')


### Input data

In this step you will enter the variables necessary to:
1. Locate your input data
1. Configure your data tensor (i.e. which three variables correspond to the three different modes or axes)
1. Configure your normalization scheme

When it comes to normalization, there are several parameters you need to define:
- You will normalize by one of the variables that corresponds to a mode in your data tensor. You can think of this as the variable that defines the "within" groups. For example, if you normalize by taxon, the normalization procedure is applied independently within each taxon in the dataset. Put another way, each taxon is a slab (i.e. a matrix) of your data tensor, and the normalization is applied independently to each slab.
- For each slab, you will define which of the remaining two modes best corresponds to samples, and which to genes. For example, in metatranscriptomic data annotated with KEGG orthologies, the sample ID would be the sample mode, and the KEGG ID would be the gene mode. Additionally, if you want to account for batch effects in your data, the batch variable should relate to your sample mode.
- You will define two thresholds that will apply to each slab (e.g. taxon):
    1. Sample threshold: each gene must be detected in at least this many samples. Genes that fall below the threshold will be removed. Default is 3.
    1. Detection threshold: each sample must contain non-zero values for at least this proportion of genes. Samples that fall below the threshold will be removed. Default is 0.01 (1%).
    - Note that data removed due to thresholding is preserved, saved, and displayed at the end of this notebook so it can be used for troubleshooting or to refine the thresholds and normalization workflow.
- If you want the normalization model to correct for batch effects, you will need to prepare a second csv file with columns corresponding to the sample mode in your data (including both sample ID and replicate), and with an additional column indicating the batch ID. For example, you might have a csv with the following headers: `['sample_id', 'replicate', 'batch_id']`. This file should include each unique combination of sample ID and replicate in your dataset, and the batch membership of each sample replicate should be indicated in the batch ID column.
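
A minimal sketch of how the two thresholds act on a single slab, applied in the same order as the normalization loop (sample threshold first, then detection threshold). The gene and sample names are hypothetical, and the detection threshold is set to 0.75 rather than the default 0.01 so that its effect is visible on a three-gene table.

```python
import math
import pandas as pd

# Toy slab: rows are genes, columns are sample replicates (made-up counts)
slab = pd.DataFrame(
    [[5, 0, 2, 1],
     [0, 0, 4, 0],
     [3, 1, 0, 2]],
    index=['geneA', 'geneB', 'geneC'],
    columns=['s1_A', 's1_B', 's2_A', 's2_B'],
)

sample_thold = 3    # each gene must be nonzero in at least 3 sample replicates
gene_thold = 0.75   # each sample must be nonzero for at least 75% of the remaining genes

# 1. Sample threshold: count nonzero entries per gene (row); geneB is removed
slab = slab.loc[(slab != 0).sum(axis=1) >= sample_thold]

# 2. Detection threshold: the proportion becomes a minimum gene count via ceil()
min_genes = math.ceil(slab.shape[0] * gene_thold)
slab = slab.loc[:, (slab != 0).sum(axis=0) >= min_genes]
print(slab)
```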


Either store your inputs in the cell below

In [11]:
datapath = ''  # Enter the filepath of your input data file
mode0 = ''  # Enter the column name that will correspond to the first mode of your tensor
mode1 = ''  # Enter the column name that will correspond to the second mode of your tensor
mode2 = ''  # Enter the column name that will correspond to the third mode of your tensor
rep = ''  # Enter the column name that corresponds to replicate IDs
data = ''  # Enter the column name that corresponds to your data
outdir = ''  # Enter the filepath of the output directory where you want files saved

# check output directory exists
if outdir:
    if not os.path.isdir(outdir):
        raise Exception(f'Unable to find the directory "{outdir}"')
    # make normalization directory within output directory
    # (skipped if outdir already points to it, so re-running cells does not nest normalization folders)
    if os.path.basename(os.path.normpath(outdir)) != 'normalization':
        outdir = os.path.join(outdir, 'normalization')
    os.makedirs(outdir, exist_ok=True)

# check data file exists
if datapath and not os.path.isfile(datapath):
    raise Exception(f'Unable to find the file "{datapath}"')


Or enter them when prompted

In [12]:
# input data

# data file
datapath = datapath or input('Enter the filepath of your input data file:')
# check data file exists
if not os.path.isfile(datapath):
    raise Exception(f'Unable to find the file "{datapath}"')

# output directory
outdir = outdir or input('Enter the filepath of the output directory where you want files saved:')
# check output directory exists
if not os.path.isdir(outdir):
    raise Exception(f'Unable to find the directory "{outdir}"')
# make normalization directory within output directory
# (skipped if outdir already points to it, so re-running cells does not nest normalization folders)
if os.path.basename(os.path.normpath(outdir)) != 'normalization':
    outdir = os.path.join(outdir, 'normalization')
os.makedirs(outdir, exist_ok=True)

# column names
mode0 = mode0 or input('Enter the column name that will correspond to the first mode of your tensor:')
mode1 = mode1 or input('Enter the column name that will correspond to the second mode of your tensor:')
mode2 = mode2 or input('Enter the column name that will correspond to the third mode of your tensor:')
rep = rep or input('Enter the column name that corresponds to replicate IDs:')
data = data or input('Enter the column name that corresponds to your data:')

Inspect df

In [None]:
# read in csv
df = pd.read_csv(datapath)

# check column names match inputs
for column in [mode0, mode1, mode2, rep, data]:
    if column not in df.columns:
        raise Exception(f'Column name "{column}" not found in headers of file {datapath}')

# tidy up dataframe
df = df[[mode0, mode1, mode2, rep, data]].copy()
df
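
Later, each slab is reshaped with `DataFrame.pivot`, which raises a `ValueError` if the same gene/taxon/sample/replicate combination appears more than once. The sketch below shows one way to check for such duplicates before continuing. `find_duplicate_entries` is a hypothetical helper, not part of this notebook, the column names are made up, and whether duplicates should be summed depends on how your data were produced.

```python
import pandas as pd

def find_duplicate_entries(df, key_columns):
    """Return every row whose combination of key columns occurs more than once (hypothetical helper)."""
    return df[df.duplicated(subset=key_columns, keep=False)]

keys = ['gene_id', 'taxon_id', 'sample_id', 'replicate']
toy = pd.DataFrame({
    'gene_id': ['K03839', 'K03839', 'K03320'],
    'taxon_id': ['P. marinus'] * 3,
    'sample_id': ['sample1'] * 3,
    'replicate': ['A', 'A', 'A'],
    'count': [12, 4, 3],
})

dups = find_duplicate_entries(toy, keys)
print(dups)  # both K03839 rows are reported

# One option, if summing is appropriate for your data, is to aggregate the duplicates
toy = toy.groupby(keys, as_index=False)['count'].sum()
```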

Either store your normalization parameters in the cell below

In [27]:
norm_mode_ = 2  # USE INTEGER NOT STRING: which mode do you want to normalize by? 1-mode0 2-mode1 3-mode2
sample_mode_ = 3  # USE INTEGER NOT STRING: which mode corresponds to your sample variable? 1-mode0 2-mode1 3-mode2
sample_thold = 3  # USE INTEGER NOT STRING: minimum number of unique samples each gene must be detected in (default 3)
gene_thold = 0.01  # USE FLOAT NOT STRING: minimum proportion of nonzero genes each sample must contain (default 0.01)
save_data_ = 'y'  # save the data output for each normalization? ('y'/'n')
save_plots_ = 'y'  # save the diagnostic plots for each normalization? ('y'/'n')
correction_ = 'n'  # correct for batch effects? ('y'/'n')
metadata_path = ''  # filepath of your batch metadata file (only needed if correction_ = 'y')
batch_id = ''  # column name that corresponds to the batch ID (only needed if correction_ = 'y')


Or enter them when prompted

In [None]:
# normalization set up

# mode to normalize by
while norm_mode_ not in [1, 2, 3]:
    norm_mode_ = int(input(f'Which mode do you want to normalize by? 1-{mode0} 2-{mode1} 3-{mode2} (enter 1/2/3):'))

# sample mode
while sample_mode_ not in [1, 2, 3]:
    sample_mode_ = int(input(f'Which mode corresponds to your sample variable? 1-{mode0} 2-{mode1} 3-{mode2} (enter 1/2/3):'))
if sample_mode_ == norm_mode_:
    raise Exception('The sample mode must be different from the mode you normalize by.')
norm_mode = [mode0, mode1, mode2][norm_mode_-1]
sample_mode = [mode0, mode1, mode2][sample_mode_-1]

# gene mode is the remaining mode
gene_mode = [m for m in [mode0, mode1, mode2] if m not in [norm_mode, sample_mode]][0]

# make a unique identifier that combines sample ID and replicate ID
# (the separator prevents collisions such as 'sample1' + '1A' vs. 'sample11' + 'A')
df['sample_rep_id'] = df[sample_mode].astype(str) + '_' + df[rep].astype(str)

# thresholds
sample_thold = sample_thold or int(input(f'Each {gene_mode} must be detected in what minimum number of unique {sample_mode}s? (Default is 3):'))
gene_thold = gene_thold or float(input(f'Each {sample_mode} must contain what minimum proportion of nonzero {gene_mode}s? (Default is 0.01):'))

# output options
while save_data_ not in ['y', 'n']:
    save_data_ = input('Would you like to save the data output for each normalization? (Enter Y/N):').lower()
while save_plots_ not in ['y', 'n']:
    save_plots_ = input('Would you like to save the diagnostic plots for each normalization? (Enter Y/N):').lower()
# batch effects
while correction_ not in ['y', 'n']:
    correction_ = input('Would you like to correct for batch effects? (Enter Y/N):').lower()

save_data = (save_data_ == 'y')
save_plots = (save_plots_ == 'y')
correction = (correction_ == 'y')

if correction:
    metadata_path = metadata_path or input('Enter the filepath of your batch metadata file:')
    batch_id = batch_id or input('Enter the column name that corresponds to the batch ID:')
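
The combined `sample_rep_id` has to be built the same way everywhere it is used (for example, in both the data and the batch metadata), otherwise lookups between the two fail. A separator between the two parts also avoids collisions, as this small example with hypothetical IDs shows:

```python
import pandas as pd

# Two different sample/replicate combinations (hypothetical IDs)
samples = pd.DataFrame({'sample_id': ['sample1', 'sample11'], 'replicate': ['1A', 'A']})

no_sep = samples['sample_id'].astype(str) + samples['replicate'].astype(str)
with_sep = samples['sample_id'].astype(str) + '_' + samples['replicate'].astype(str)

print(no_sep.tolist())    # ['sample11A', 'sample11A']: the two replicates collide
print(with_sep.tolist())  # ['sample1_1A', 'sample11_A']: still distinct
```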

    


Load and check the batch metadata (only used if you chose to correct for batch effects)

In [29]:


if correction:
    # check metadata file exists
    if not os.path.isfile(metadata_path):
        raise Exception(f'Unable to find the file "{metadata_path}"')
    # read in metadata file
    meta_df = pd.read_csv(metadata_path)

    # check for sample, replicate, and batch columns
    if sample_mode not in meta_df.columns:
        raise Exception(f'Column name "{sample_mode}" not found in headers of file {metadata_path}. Columns found: {meta_df.columns}')
    if rep not in meta_df.columns:
        raise Exception(f'Column name "{rep}" not found in headers of file {metadata_path}. Columns found: {meta_df.columns}')
    if batch_id not in meta_df.columns:
        raise Exception(f'Column name "{batch_id}" not found in headers of file {metadata_path}. Columns found: {meta_df.columns}')

    # check that both files contain the same sample/replicate combinations (row order does not matter)
    data_pairs = set(df[[sample_mode, rep]].astype(str).itertuples(index=False, name=None))
    meta_pairs = set(meta_df[[sample_mode, rep]].astype(str).itertuples(index=False, name=None))
    if data_pairs != meta_pairs:
        raise Exception(f'Files {datapath} and {metadata_path} have different values for columns {sample_mode} and {rep}.')

    # clean up metadata dataframe
    meta_df['sample_rep_id'] = meta_df[sample_mode].astype(str) + '_' + meta_df[rep].astype(str)
    meta_df = meta_df[['sample_rep_id', sample_mode, rep, batch_id]]
    display(meta_df)
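
For reference, here is a sketch of a batch metadata table with one row per sample/replicate combination, plus an order-independent check that every combination in the data has a batch assignment. The IDs and the `batch_id` column name are hypothetical.

```python
import pandas as pd

# Hypothetical batch metadata: one row per sample/replicate combination
meta = pd.DataFrame({
    'sample_id': ['sample1', 'sample1', 'sample2', 'sample2'],
    'replicate': ['A', 'B', 'A', 'B'],
    'batch_id': ['run1', 'run1', 'run2', 'run2'],
})

# Sample/replicate combinations found in the (toy) count data, in a different row order
data_pairs = pd.DataFrame({
    'sample_id': ['sample2', 'sample1', 'sample2', 'sample1'],
    'replicate': ['B', 'A', 'A', 'B'],
})

def pair_set(frame, cols):
    # Set of (sample, replicate) tuples, ignoring row order and repeated rows
    return set(frame[cols].astype(str).itertuples(index=False, name=None))

cols = ['sample_id', 'replicate']
missing = pair_set(data_pairs, cols) - pair_set(meta, cols)
print(missing)  # set(): every combination in the data has a batch assignment
```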



### Run sctransform on each slab of data

In this step you'll run the sctransform model on each slab of your data. If you chose to save data and/or plots when setting the normalization parameters, the outputs of each normalization run are saved to a separate subdirectory (named after the slab) within the output directory.
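
sctransform fits a regularized negative binomial model for each gene and returns Pearson residuals, i.e. (observed - expected) / sqrt(variance), which are the values saved in the `residual` column. The sketch below computes simplified residuals with NumPy to illustrate the idea; it is not sctransform's implementation. It assumes expected counts proportional to gene and sample totals, a single fixed overdispersion `theta`, no parameter regularization, and a clipping bound chosen here for illustration.

```python
import numpy as np

def pearson_residuals(counts, theta=100.0, clip=None):
    """Simplified negative binomial Pearson residuals for a genes x samples count matrix.

    Assumptions (not sctransform's exact model): expected count
    mu_ij = gene_total_i * sample_total_j / grand_total, one fixed theta,
    and no regularization. Genes or samples with all-zero counts give mu = 0
    and must be removed first (which the thresholds take care of).
    """
    counts = np.asarray(counts, dtype=float)
    gene_totals = counts.sum(axis=1, keepdims=True)    # shape (genes, 1)
    sample_totals = counts.sum(axis=0, keepdims=True)  # shape (1, samples)
    mu = gene_totals @ sample_totals / counts.sum()
    # Pearson residual: (observed - expected) / sqrt(NB variance mu + mu^2 / theta)
    residuals = (counts - mu) / np.sqrt(mu + mu**2 / theta)
    if clip is None:
        clip = np.sqrt(counts.shape[1])  # illustrative bound: +/- sqrt(number of samples)
    return np.clip(residuals, -clip, clip)

toy = np.array([[10, 20, 30],
                [5, 5, 5],
                [0, 2, 8]])
res = pearson_residuals(toy)
print(res.round(2))
```

Counts that are exactly proportional to gene and sample totals give residuals of zero, so large absolute residuals flag genes whose counts deviate from what sequencing depth alone predicts.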

In [30]:
# helper functions for running normalization model

# function to threshold sparsity of pandas dataframes
def sparsity_thold_df(df, thold, axis=0):
    """Apply a sparsity threshold to a pandas.DataFrame so that only columns or rows with 
    greater than or equal to the threshold amount of nonzero values are retained.
    
    Parameters
    ----------
    df : pandas.DataFrame
        Input dataframe object.
    thold : {int, float}
        Threshold. If parameter is an int, then retained vectors must contain at least that number of 
        nonzero values. If parameter is a float, then at least this proportion of the vector must be nonzero.
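    axis : int
        Axis along which nonzero values are counted: 0 thresholds columns and 1 thresholds rows.
        A float threshold is a proportion of the values in each column or row, respectively.

    Returns
    -------
    output_df : pandas.DataFrame
        Thresholded dataframe.
    dropped_df : pandas.DataFrame
        Data removed from the input dataframe as a result of the threshold.
    """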
    if type(thold) is float:
        thold = math.ceil(df.shape[axis] * thold)
    elif type(thold) is not int:
        raise Exception('Parameter `thold` must be either an int or float type value.')
    mask = (df != 0).sum(axis).ge(thold)
    if axis == 0:
        output_df = df.loc[:, mask]
        dropped_df = df.loc[:, ~mask]
    elif axis == 1:
        output_df = df.loc[mask, :]
        dropped_df = df.loc[~mask, :]
    else:
        raise Exception('Invalid value for `axis` parameter.')
    return output_df, dropped_df


# function to calculate a geometric mean that tolerates zeros (a pseudocount is added before taking logs and subtracted afterwards)
def geometric_mean(vector, pseudocount=1):
    return np.exp(np.mean(np.log(vector + pseudocount))) - pseudocount


# function to convert pandas dataframe to r matrix
def pandas_dataframe_to_r_matrix(df, dtype=float):
    """
    Function to convert pandas DataFrame objects to R matrix objects.
    """
    if dtype is float:
        vector = ro.vectors.FloatVector(df.values.flatten().tolist())
    elif dtype is str:
        vector = ro.vectors.StrVector(df.values.flatten().tolist())
    elif dtype is int:
        # Matrix's sparse classes store numeric values as doubles, so integers are passed as floats
        vector = ro.vectors.FloatVector(df.values.flatten().tolist())
    else:
        raise ValueError('The dtype {} is not recognized'.format(dtype))
    matrix = rmatrix.Matrix(
        data=vector, 
        nrow=df.shape[0], 
        ncol=df.shape[1], 
        byrow=True, 
        dimnames=[df.index.to_list(), df.columns.to_list()], 
        sparse=True
    )
    return matrix
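
The `geometric_mean` helper above adds a pseudocount before taking logs, because a plain geometric mean is 0 whenever any value is 0 and the log of 0 is undefined. A self-contained copy with a small worked example:

```python
import numpy as np

def geometric_mean(vector, pseudocount=1):
    # Same formula as the helper above: add a pseudocount so log(0) is never taken,
    # then subtract it again after exponentiating
    vector = np.asarray(vector, dtype=float)
    return np.exp(np.mean(np.log(vector + pseudocount))) - pseudocount

print(geometric_mean([0, 0, 3]))  # (1 * 1 * 4) ** (1/3) - 1
print(geometric_mean([0, 0, 0]))  # 0.0
```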
    

In [None]:
# run the model on each slab

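# make normalization directory within output directory
# (skipped if outdir already points to it, so re-running cells does not nest normalization folders)
if os.path.basename(os.path.normpath(outdir)) != 'normalization':
    outdir = os.path.join(outdir, 'normalization')
os.makedirs(outdir, exist_ok=True)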

# initialize residuals dataframe
residuals_df = pd.DataFrame()

# keep track of filtered data
dropped_data_df = pd.DataFrame()

# iterate through slabs
slab_ids = df[norm_mode].unique()
for i, slab_id in enumerate(slab_ids):
    # separate out data
    slab_df = df[df[norm_mode].eq(slab_id)].pivot(index=gene_mode, columns=['sample_rep_id'], values=data).fillna(0)
    # apply sample threshold to filter out low-prevalence genes
    slab_df, drop_df = sparsity_thold_df(slab_df, sample_thold, axis=1)
    drop_df = drop_df.melt(value_name=data, ignore_index=False).reset_index()
    drop_df[norm_mode] = slab_id
    drop_df['drop_reason'] = f'{gene_mode} detected in fewer than {sample_thold} {sample_mode}s'
    dropped_data_df = pd.concat([dropped_data_df, drop_df])
    # apply gene threshold to filter out samples with low detection
    slab_df, drop_df = sparsity_thold_df(slab_df, gene_thold, axis=0)
    drop_df = drop_df.melt(value_name=data, ignore_index=False).reset_index()
    drop_df[norm_mode] = slab_id
    drop_df['drop_reason'] = f'{sample_mode} contains fewer than {math.ceil(slab_df.shape[0] * gene_thold)} nonzero {gene_mode}s'
    dropped_data_df = pd.concat([dropped_data_df, drop_df])
    # check for very small slabs
    if (slab_df.shape[0] < 10) or (slab_df.shape[1] < sample_thold):
        print(f'Skipping slab {i+1} of {len(slab_ids)}: {slab_id} ({slab_df.shape[1]} samples, {slab_df.shape[0]} genes)', flush=True)
        print('\tLimited nonzero data in this slab undermines the reliability of normalization with sctransform.', flush=True)
        drop_df = slab_df.melt(value_name=data, ignore_index=False).reset_index()
        drop_df[norm_mode] = slab_id
        drop_df['drop_reason'] = f'{norm_mode} encompassed fewer than 10 {gene_mode}s or fewer than {sample_thold} {sample_mode}s'
        dropped_data_df = pd.concat([dropped_data_df, drop_df])
        continue
    else:
        print(f'Normalizing slab {i+1} of {len(slab_ids)}: {slab_id} ({slab_df.shape[1]} samples, {slab_df.shape[0]} genes)', flush=True)

    # make r version of slab dataframe
    r_slab_df = pandas_dataframe_to_r_matrix(slab_df)
    # pull out sample attributes (including batch information, if requested) for the samples in this slab
    if correction:
        sample_attr_df = meta_df.set_index('sample_rep_id').loc[slab_df.columns, [sample_mode, rep, batch_id]]
    else:
        sample_attr_df = df[['sample_rep_id', sample_mode, rep]].drop_duplicates().set_index('sample_rep_id').loc[slab_df.columns]
    r_sample_attr_df = pandas2ri.py2rpy(sample_attr_df)

    # fit vst normalization model
    # Use glmgampoi
    result = sctransform.vst(
        r_slab_df, 
        cell_attr=r_sample_attr_df, 
        batch_var=(ro.vectors.StrVector([batch_id]) if correction else ro.NULL),
        min_cells=sample_thold,
        return_gene_attr=True, 
        return_cell_attr=True, 
        vst_flavor='v2', 
        verbosity=2,
        method='glmGamPoi'
    )

    # Get names from result in case genes get dropped from sctransform (not overdispersed)
    with ro.conversion.localconverter(ro.default_converter + pandas2ri.converter):
        genes = ro.conversion.rpy2py(result[11]).index.values
        cells = ro.conversion.rpy2py(result[10]).index.values
    # convert residuals result to a dataframe
    result_df = pd.DataFrame(
        np.asarray(result[0]), 
        index=genes, 
        columns=cells
    )

    # prepare to save requested outputs
    if save_data or save_plots:
        # make unique output directory per slab
        dir_path = f'{outdir}/{slab_id}'
        if not os.path.isdir(dir_path):
            os.makedirs(dir_path)
        # calculate per-gene residual variance
        residual_var = result_df.var(axis=1)

    # save data if requested
    if save_data:
        # save csv of residuals
        result_df.to_csv(f'{dir_path}/residuals_{slab_id}.csv')
        # save csv of residual variances
        res_var_df = residual_var.rename('residual_variance').rename_axis(gene_mode).reset_index()
        res_var_df = res_var_df.sort_values('residual_variance', ascending=False).reset_index(drop=True)
        res_var_df.to_csv(f'{dir_path}/residual_variances_{slab_id}.csv')

    # save plots if requested
    if save_plots:
        # plots of model parameters
        plots = sctransform.plot_model_pars(result, show_theta=True)
        img = image_png(plots)
        with open(f'{dir_path}/parameters_{slab_id}.png', 'wb') as png:
            png.write(img.data)
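        # plot of high variance genes
        # (restrict means to the genes returned by sctransform so that x and y line up)
        means = slab_df.set_axis(slab_df.index.astype(str), axis=0).apply(geometric_mean, axis=1)
        means = means.loc[residual_var.index]
        plt.figure(figsize=(10, 4))
        sns.scatterplot(x=means.values, y=residual_var.values, alpha=0.1)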
        plt.xlabel(f'geometric mean of {gene_mode} {data}')
        plt.xscale('log')
        plt.ylabel('residual variance')
        plt.title(f'normalized residual variance\nvs. mean {gene_mode} {data} ({slab_id})')
        plt.savefig(f'{dir_path}/residual_variances_{slab_id}.png')
        plt.show()

    # concatenate result with other residuals
    result_df = result_df.melt(value_name='residual', ignore_index=False).reset_index()
    result_df[norm_mode] = slab_id
    residuals_df = pd.concat([residuals_df, result_df])
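
When the residual matrix is reshaped to long format, pandas names the new columns `index` and `variable` because neither axis has a name; this is why those columns are renamed before saving. A sketch with made-up values showing the default names and, as an alternative, naming the axes up front with `rename_axis` (the names `gene_id` and `sample_rep_id` are illustrative):

```python
import numpy as np
import pandas as pd

# Toy residual matrix with unnamed axes, like the one built from the sctransform result
result_df = pd.DataFrame(
    np.array([[0.5, -1.2], [2.0, 0.1]]),
    index=['K03839', 'K03320'],
    columns=['sample1_A', 'sample1_B'],
)

# Without axis names, melt/reset_index fall back to 'variable' and 'index'
long_default = result_df.melt(value_name='residual', ignore_index=False).reset_index()
print(long_default.columns.tolist())  # ['index', 'variable', 'residual']

# Naming the axes first gives meaningful column names directly
long_named = (
    result_df.rename_axis(index='gene_id', columns='sample_rep_id')
    .melt(value_name='residual', ignore_index=False)
    .reset_index()
)
print(long_named.columns.tolist())  # ['gene_id', 'sample_rep_id', 'residual']
```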
    

In [None]:
residuals_df

In [None]:
# save normalized data as a csv

# rename the default column names ('variable', 'index') created when the residuals were melted to long format
residuals_df = residuals_df.rename(columns={'variable': 'sample_rep_id', 'index': gene_mode})

# add back sample information
residuals_df = pd.merge(residuals_df, df[['sample_rep_id', sample_mode, rep]].drop_duplicates(), on='sample_rep_id', how='left')

# add in raw count data
residuals_df = pd.merge(residuals_df, df, on=['sample_rep_id', mode0, mode1, mode2, rep], how='left').fillna(0)

# tidy up dataframe
residuals_df = residuals_df[[mode0, mode1, mode2, rep, data, 'residual']]

# save output
residuals_df.to_csv(f'{outdir}/normalized-residuals.csv')

residuals_df


In [None]:
# examine data that was removed during normalization process

# remove zero values
dropped_data_df = dropped_data_df[dropped_data_df[data] != 0.0]

# add back sample information
dropped_data_df = pd.merge(dropped_data_df, df[['sample_rep_id', sample_mode, rep]].drop_duplicates(), on='sample_rep_id', how='left')

# tidy up dataframe
dropped_data_df = dropped_data_df[[mode0, mode1, mode2, rep, data, 'drop_reason']]

# show some summary statistics
for variable in ['drop_reason', mode0, mode1, mode2]:
    print(dropped_data_df[variable].value_counts())

# save drop data
dropped_data_df.to_csv(f'{outdir}/removed-data.csv')      

dropped_data_df
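
Beyond the `value_counts` summaries above, a cross-tabulation by slab and drop reason can help decide which threshold to adjust. A sketch on a made-up removed-data table (the column names are hypothetical; in the notebook they follow your mode names):

```python
import pandas as pd

# Toy removed-data table (hypothetical column names, made-up values)
reason_gene = 'gene_id detected in fewer than 3 sample_ids'
reason_slab = 'taxon_id encompassed fewer than 10 gene_ids or fewer than 3 sample_ids'
dropped = pd.DataFrame({
    'taxon_id': ['P. marinus', 'P. marinus', 'S. marinus', 'S. marinus', 'S. marinus'],
    'drop_reason': [reason_gene, reason_gene, reason_gene, reason_slab, reason_slab],
    'count': [2, 1, 4, 7, 3],
})

# Number of removed nonzero entries per slab and reason
n_entries = pd.crosstab(dropped['taxon_id'], dropped['drop_reason'])

# Total removed signal (summed data values) per slab and reason
removed_signal = dropped.pivot_table(index='taxon_id', columns='drop_reason',
                                     values='count', aggfunc='sum', fill_value=0)
print(n_entries)
print(removed_signal)
```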
