# Using MELD to characterize chordin loss-of-function
## Introduction

In this tutorial, we will demonstrate how to use MELD to characterize the effect of Cas9-mutagenesis in the zebrafish embryo. We will use a dataset generated by the Klein and Megason labs and published in [Wagner et al. (2018) (doi: 10.1126/science.aar4362)](https://www.ncbi.nlm.nih.gov/pubmed/29700229). Here, zebrafish embryos were injected with Cas9 + gRNAs at the one-cell stage targeting either chordin (*chd*) in the experimental condition or tyrosinase (*tyr*) in the control condition. Embryos were collected in a rough time course from 14-16 hours post fertilization for scRNA-seq and 27,000 cells were recovered.

[Chordin](https://www.genecards.org/cgi-bin/carddisp.pl?gene=CHRD) is a BMP antagonist required for proper specification of dorsally-derived neural tissues ([Hammerschmidt et al. 1997](https://www.ncbi.nlm.nih.gov/pubmed/9007232)). [Tyrosinase](https://www.genecards.org/cgi-bin/carddisp.pl?gene=TYR) is a gene required for melanin production, but does not affect cell type specification at the time points considered in this study.

We will also introduce some basics of preprocessing, visualization and imputation to give an idea of how you might include MELD in a general scRNA-seq analysis workflow.

**Note:** this is a modified and abbreviated version of the original notebook available on [the MELD GitHub](https://github.com/KrishnaswamyLab/MELD). The full version includes parameter optimization and VFC on subclusters.

Here's the order we'll follow:

* [1. Loading the dataset](#1.-Loading-data)  
* [2. Embedding Data Using PHATE](#2.-Embedding-Data-Using-PHATE)
* [4. Using MELD to calculate sample-associated density estimates and relative likelihood](#4.-Using-MELD-to-calculate-sample-associated-density-estimates-and-relative-likelihood)

## 0. Installing packages

If you haven't installed MELD yet, you can do so from this notebook. We'll also install some other useful packages while we're at it.

In [None]:
!pip install --user meld phate magic-impute cmocean diffxpy seaborn scanpy

## 1. Loading data

**Standard imports**

Note: if you get an error here, you may have to restart the runtime to make sure Colab recognizes the newly installed packages. Go to _Runtime_ -> _Restart runtime_.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmocean
import phate
import scprep
import meld
import sklearn
import tempfile
import os
import scanpy as sc

# making sure plots & clusters are reproducible
np.random.seed(42)

### Load data

To facilitate running this notebook quickly, we're going to download a preprocessed AnnData object from FigShare.

The preprocessing steps to download the published data from GEO and perform filtering, library size normalization, and sqrt transformation can be found in the full tutorial on [the MELD GitHub](https://github.com/KrishnaswamyLab/MELD).

In [None]:
URL = "https://ndownloader.figshare.com/files/25687247?private_link=f194ae7d6bcec9bd11a3"

with tempfile.TemporaryDirectory() as tempdir:
    filepath = os.path.join(tempdir, "Klein2018_Zebrafish.h5ad")
    scprep.io.download.download_url(URL, filepath)
    adata = sc.read_h5ad(filepath)

### Subsample

To allow this notebook to run in Google Colab, we subsample the dataset to 10,000 cells.



In [None]:
subsample_index = np.random.choice(adata.shape[0], size=10000, replace=False)
adata = adata[subsample_index].copy()

In [None]:
data = adata.to_df()
metadata = adata.obs

### Examining the number of cells in each sample that passed filtering



First, we create a colormap for visualizing the samples. You can select colors by hex code with Google's [RGB color picker](https://www.google.com/search?client=firefox-b-1-d&q=rgb+color+picker).

In [None]:
sample_cmap = {
    'chdA': '#fb6a4a',
    'chdB': '#de2d26',
    'chdC': '#a50f15',
    'tyrA': '#6baed6',
    'tyrB': '#3182bd',
    'tyrC': '#08519c'
}

As we can see in the following plot, many more cells passed QC in the chdA sample than in the other samples. To account for this, the MELD algorithm automatically normalizes each sample so that differences in cell number do not dominate the density estimates.



In [None]:
fig, ax = plt.subplots(1)

groups, counts = np.unique(metadata['sample_labels'], return_counts=True)
for i, c in enumerate(counts):
    ax.bar(i, c, color=sample_cmap[groups[i]])
    
# One tick per sample (avoids relying on the leftover loop variable `i`)
ax.set_xticks(np.arange(len(groups)))
ax.set_xticklabels(groups)
ax.set_ylabel('# cells')

fig.tight_layout()
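
To make the normalization concrete, here is a minimal sketch. It assumes the normalization amounts to scaling each sample's one-hot indicator so that it sums to 1. The toy labels and variable names are made up for illustration, and MELD's internal implementation may differ.

```python
import numpy as np

# Toy labels: 'chdA' has three times as many cells as 'tyrA' (made-up data)
labels = np.array(['chdA'] * 6 + ['tyrA'] * 2)
samples = np.unique(labels)

# One-hot indicator matrix: one row per cell, one column per sample
indicators = (labels[:, None] == samples[None, :]).astype(float)

# Assumed normalization: divide each column by its cell count, so every
# sample carries the same total weight regardless of how many cells it has
normalized = indicators / indicators.sum(axis=0, keepdims=True)

print(dict(zip(samples, indicators.sum(axis=0))))   # raw cell counts
print(dict(zip(samples, normalized.sum(axis=0))))   # total weight per sample
```

After this step, a cell from the small sample carries more weight than a cell from the large one, so neither sample dominates purely because of its size.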

## 2. Embedding Data Using PHATE

PHATE's API follows the scikit-learn estimator convention. First, you instantiate a PHATE estimator object with the parameters for fitting the PHATE embedding to a given dataset. Next, you call its `fit` or `fit_transform` method to generate an embedding. For more information, check out [**the PHATE readthedocs page**](http://phate.readthedocs.io/).

We'll just use the default parameters for now, but the following parameters can be tuned (read our documentation at [phate.readthedocs.io](https://phate.readthedocs.io/) to learn more):

* `knn` : Number of nearest neighbors (default: 5). Increase this (e.g. to 20) if your PHATE embedding appears very disconnected. You should also consider increasing `knn` if your dataset is extremely large (e.g. >100k cells)
* `decay` : Alpha decay (default: 40 in current PHATE releases). Decreasing `decay` increases connectivity on the graph, increasing `decay` decreases connectivity. This rarely needs to be tuned. Set it to `None` for a k-nearest neighbors kernel.
* `t` : Number of times to power the operator (default: 'auto'). This is equivalent to the amount of smoothing done to the data. It is chosen automatically by default, but you can increase it if your embedding lacks structure, or decrease it if the structure looks too compact.
* `gamma` : Informational distance constant (default: 1). `gamma=1` gives the PHATE log potential, but other informational distances can be interesting. If most of the points seem concentrated in one section of the plot, you can try `gamma=0`.
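
To build intuition for how `knn`, `decay`, and `t` interact, the sketch below constructs an alpha-decay kernel and a powered diffusion operator on toy data with NumPy. It follows the kernel form described in the PHATE paper as we understand it. It is not PHATE's implementation, and the helper name `alpha_decay_kernel` is hypothetical.

```python
import numpy as np

def alpha_decay_kernel(X, knn=5, decay=15):
    """Symmetrized alpha-decay kernel on the rows of X (illustrative only)."""
    # Pairwise Euclidean distances between all points
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    # Adaptive bandwidth: distance from each point to its knn-th nearest neighbor
    # (index 0 of each sorted row is the point itself, at distance 0)
    eps = np.sort(dists, axis=1)[:, knn]
    # exp(-(d / eps)^decay); a larger `decay` gives a sharper cutoff
    K = np.exp(-(dists / eps[:, None]) ** decay)
    # Average with the transpose so the affinity matrix is symmetric
    return (K + K.T) / 2

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))  # toy data: 50 points in 3 dimensions

K = alpha_decay_kernel(X, knn=5, decay=15)
# Row-normalize the affinities into a Markov (diffusion) operator
P = K / K.sum(axis=1, keepdims=True)
# Powering the operator t times corresponds to a t-step random walk
P_t = np.linalg.matrix_power(P, 10)

print(P.sum(axis=1)[:3], P_t.sum(axis=1)[:3])  # row sums before and after powering
```

Raising `decay` makes affinities fall off more sharply beyond the `knn`-th neighbor, and raising `t` spreads each row of the operator over a wider neighborhood.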


Here's an example of applying PHATE with non-default parameters (for illustration only; the cells below use the defaults):
```python
phateop = phate.PHATE(knn=9, decay=10, gamma=0, n_jobs=-2)
Y = phateop.fit_transform(data)
```

In [None]:
data_pca = scprep.reduce.pca(data)

In [None]:
phate_op = phate.PHATE(n_jobs=-1)
data_phate = phate_op.fit_transform(data_pca)

### Coloring a PHATE plot by sample ID

We then plot the embedding using `scprep.plot.scatter2d`. For more advanced plotting, we recommend Matplotlib. If you want more help on using Matplotlib, they have [**extensive documentation**](https://matplotlib.org/tutorials/index.html) and [**many Stack Overflow threads**](https://stackoverflow.com/questions/tagged/matplotlib).

In [None]:
scprep.plot.scatter2d(data_phate, c=metadata['sample_labels'], cmap=sample_cmap, 
                      legend_anchor=(1,1), figsize=(6,5), s=10, label_prefix='PHATE', ticks=False)

### Coloring a PHATE plot by ClusterIDs

In Wagner et al. (2018), cells from the *chd* and *tyr* conditions were assigned cluster IDs by projecting them back onto a reference dataset. In the published analysis, the number of cells mapping to each cluster in the *chd* vs *tyr* condition was used as the measure of the effect of *chd* loss-of-function on that cluster. To visualize the relationships between these clusters, we will color the PHATE plot by each cell's published ClusterID.

In [None]:
scprep.plot.scatter2d(data_phate, c=metadata['cluster'], cmap=cmocean.cm.phase, 
                      legend_anchor=(1,1), figsize=(5,5), s=10, label_prefix='PHATE', ticks=False)

### Discussion Question

1. What do you notice about this PHATE plot when you compare to the distribution of sample labels above? Are there some clusters that you think are more or less suited to analysis of differential abundance?

## 4. Using MELD to calculate sample-associated density estimates and relative likelihood

Using MELD, we quantify the effect of an experimental perturbation by first estimating the density of each sample over a graph learned from all cells from all samples. This yields one density estimate per sample. We then normalize density estimates across samples from the same replicate to calculate the sample-associated relative likelihood. This relative likelihood is a ratio between the sample probability densities from each condition and indicates how much more likely we are to observe a given cell in one condition relative to another. 

We can use the relative likelihood estimates to identify which cells are the most enriched in each experimental condition and which cell types are unchanging across conditions. We can also use this value to identify the gene signature of a perturbation (*i.e.* the genes that change the most across experimental conditions).

### Separating replicates and conditions

We run the MELD algorithm on each sample independently, then normalize within each replicate. First, we're going to create vectors that indicate the genotype and the replicate of each cell.

In [None]:
metadata['genotype_name'] = np.where(metadata['sample_labels'].str.startswith("chd"), "chd", "tyr")
metadata['genotype'] = np.where(metadata['genotype_name'] == "chd", 1, 0)
metadata['replicate'] = metadata['sample_labels'].str[-1]
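
As a quick check of what these three columns contain, the same operations can be run on a handful of made-up labels. The `toy` frame is hypothetical and not part of the dataset.

```python
import numpy as np
import pandas as pd

# Toy sample labels following the '<genotype><replicate>' naming used above
toy = pd.DataFrame({'sample_labels': ['chdA', 'tyrA', 'chdB', 'tyrC']})

# Genotype name is the label prefix; genotype is its 0/1 encoding (1 = chd)
toy['genotype_name'] = np.where(toy['sample_labels'].str.startswith('chd'), 'chd', 'tyr')
toy['genotype'] = np.where(toy['genotype_name'] == 'chd', 1, 0)
# Replicate is the final character of the label
toy['replicate'] = toy['sample_labels'].str[-1]

print(toy)
```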

### Run MELD

These next two code blocks build the graph for MELD and estimate the density of each sample. The parameters for knn and beta are optimized in the [full notebook on GitHub](https://github.com/KrishnaswamyLab/MELD).

Here, we're going to create a MELD operator object, which inherits from the sklearn [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). The full documentation for MELD can be found here: https://meld-docs.readthedocs.io/en/stable/

The input to MELD is the data and the sample labels. Here, we're only using the first 100 PCs of the data. The output of `meld_op.fit_transform()` is the sample associated density estimate referenced in the [MELD paper](https://www.biorxiv.org/content/10.1101/532846v4). This is equivalent to a kernel density estimate of the sample over the graph.

In [None]:
# beta is the amount of smoothing to do for density estimation
# knn is the number of neighbors used to set the kernel bandwidth
meld_op = meld.MELD(beta=67, knn=7)
sample_densities = meld_op.fit_transform(data_pca, sample_labels=metadata['sample_labels'])
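
For intuition about what a density estimate "over the graph" means, the sketch below smooths a normalized sample indicator over a small path graph. It assumes a simple Laplacian-regularized low-pass filter, solving `(I + beta * L) y = x`, chosen for brevity. MELD's actual filter and graph construction differ, and all names here are hypothetical.

```python
import numpy as np

n_cells = 10
# Path graph: cell i is connected to cell i + 1 (a stand-in for the kNN graph)
A = np.zeros((n_cells, n_cells))
idx = np.arange(n_cells - 1)
A[idx, idx + 1] = 1
A[idx + 1, idx] = 1
# Combinatorial graph Laplacian L = D - A
L = np.diag(A.sum(axis=1)) - A

# Indicator for one sample: its cells sit at one end of the graph
indicator = np.zeros(n_cells)
indicator[:3] = 1
indicator /= indicator.sum()  # normalize so the sample has total weight 1

def smooth(signal, beta):
    """Low-pass filter the signal by solving (I + beta * L) y = signal."""
    return np.linalg.solve(np.eye(len(signal)) + beta * L, signal)

for beta in (0.1, 1, 10):
    print(beta, np.round(smooth(indicator, beta), 3))
```

Larger values of `beta` spread the indicator further along the graph, which is the sense in which `beta` controls the amount of smoothing.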

Let's look at the sample densities on a PHATE plot.



In [None]:
fig, axes = plt.subplots(2,3, figsize=(11,6))

for i, ax in enumerate(axes.flatten()):
    density = sample_densities.iloc[:,i]
    scprep.plot.scatter2d(data_phate, c=density,
                          title=density.name,
                          vmin=0, 
                          ticks=False, ax=ax)
    
fig.tight_layout()

### Discussion questions

1. What is the sum of the density of each sample over the data? How would you calculate this?

2. What do you notice about the density of each sample across replicates? Do you see more similarity between samples of the same replicate or of the same condition? Is this expected or unexpected? Does it change your interpretation of the results?


Once we have the densities, we then compare the densities between conditions within each replicate. This gives us a relative likelihood that a given cell would be observed in each condition. 

In common speech, likelihood and probability are used interchangeably, but they have distinct statistical meanings. The probability (or probability density) of an event is the chance that the event will happen under a given model (the probability of the data given the model). Here, a sample density gives, for each cell, the probability that a cell drawn at random from that sample would be that cell. When we instead fix a cell and compare its densities across samples, the values can be read as the likelihood of each sample given that the cell was observed (the likelihood of the model given the data). In this view, the configuration of experimental variables plays the role of the "model parameters" of the likelihood.

To better understand this distinction, we recommend [the Likelihood Function article on Wikipedia](https://en.wikipedia.org/wiki/Likelihood_function#Discrete_probability_distribution).
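
The two readings can be illustrated with a small table of made-up densities. This is a toy example under our own assumptions, not output from MELD.

```python
import numpy as np

# Made-up densities for 4 cells (rows) in two samples (columns: chdA, tyrA).
# Each column is a probability distribution over cells.
densities = np.array([
    [0.40, 0.10],
    [0.30, 0.20],
    [0.20, 0.30],
    [0.10, 0.40],
])
print(densities.sum(axis=0))  # column totals

# Probability view (read down a column): the chance of drawing each cell
# from the chdA sample
p_cell_given_chd = densities[:, 0]

# Likelihood view (read across a row): for one fixed cell, compare how well
# each sample explains it. A row need not sum to 1.
cell = 0
likelihood_ratio = densities[cell, 0] / densities[cell, 1]
print(likelihood_ratio)  # how many times more likely under chdA than tyrA
```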

We want to calculate the ratio between these likelihoods so that we can understand how much more likely it would be to observe a cell in the treatment condition relative to the control condition. To calculate this ratio, we apply an L1 normalization of the densities within each replicate. This normalizes the values to sum to 1 across samples within each replicate. 

In [None]:
# This is a helper function to apply L1 normalization across the densities for each replicate
def replicate_normalize_densities(sample_densities, replicate):
    # Get the unique replicates
    replicates = np.unique(replicate)
    sample_likelihoods = sample_densities.copy()
    for rep in replicates:
        # Select the columns of `sample_densities` for that replicate
        curr_cols = sample_densities.columns[[col.endswith(rep) for col in sample_densities.columns]]
        curr_densities = sample_densities[curr_cols]
        # Apply L1 normalization
        sample_likelihoods[curr_cols] = sklearn.preprocessing.normalize(curr_densities, norm='l1')
    return sample_likelihoods

In [None]:
sample_likelihoods = replicate_normalize_densities(sample_densities, metadata['replicate'])

We now have the sample associated relative likelihoods for each condition.

In [None]:
fig, axes = plt.subplots(1,3, figsize=(13,4))

experimental_samples = ['chdA', 'chdB', 'chdC']

for curr_sample, ax in zip(experimental_samples, axes):
    scprep.plot.scatter2d(data_phate, c=sample_likelihoods[curr_sample], cmap=meld.get_meld_cmap(),
                          vmin=0, vmax=1,
                          title=curr_sample, ticks=False, ax=ax)

fig.tight_layout()

We can also look at the mean and standard deviation of the relative likelihood estimates across replicates. Notice how areas that are consistently enriched or depleted across replicates have low standard deviation.

In [None]:
fig, axes = plt.subplots(1,2, figsize=(8.7,4))

scprep.plot.scatter2d(data_phate, c=sample_likelihoods[experimental_samples].mean(axis=1), 
                      cmap=meld.get_meld_cmap(), vmin=0, vmax=1,
                      title='Mean', ticks=False, ax=axes[0])
scprep.plot.scatter2d(data_phate, c=sample_likelihoods[experimental_samples].std(axis=1), vmin=0, 
                      cmap='inferno', title='St. Dev.', ticks=False, ax=axes[1])

fig.tight_layout()
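
A small numeric example may help in reading the two panels. The values below are made up. The point is that a mean near 0.5 can reflect either no change between conditions or disagreement between replicates, and the standard deviation distinguishes the two.

```python
import numpy as np

# Made-up chd relative likelihoods for 4 cells (rows) across replicates A, B, C
likelihoods = np.array([
    [0.90, 0.88, 0.92],   # consistently enriched in chd
    [0.10, 0.12, 0.08],   # consistently depleted in chd
    [0.50, 0.49, 0.51],   # consistently unchanged
    [0.20, 0.50, 0.80],   # inconsistent across replicates
])

mean = likelihoods.mean(axis=1)
# ddof=1 gives the sample standard deviation, matching pandas' default `.std()`
std = likelihoods.std(axis=1, ddof=1)

for m, s in zip(mean, std):
    print(f"mean={m:.2f}  std={s:.2f}")
```

The last two rows have the same mean, but only the final row has a large standard deviation.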

We use the average likelihood of the chordin samples as the measure of the perturbation.

In [None]:
metadata['chd_likelihood'] = sample_likelihoods[experimental_samples].mean(axis=1).values

### Discussion Questions:
1. Here, we only look at the `chd` relative likelihood. Why don't we look at the `tyr` relative likelihood?

2. What does the variation in the relative likelihood values across replicates tell you?


### Examining the distribution of _chd_ likelihood values in published clusters

Finally, we will compare using clusters based on data geometry to using MELD for quantifying the effect of an experimental perturbation. 

Let's renumber the clusters from lowest to highest average _chd_ likelihood value.

In [None]:
metadata['clusterID'] = scprep.utils.sort_clusters_by_values(metadata['clusterID'], metadata['chd_likelihood'])
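
The idea behind this relabeling can be sketched with NumPy on made-up data. This illustrates the concept only and is not the implementation of `scprep.utils.sort_clusters_by_values`.

```python
import numpy as np

# Made-up cluster labels and per-cell likelihood values
clusters = np.array([0, 0, 1, 1, 2, 2])
values = np.array([0.8, 0.9, 0.1, 0.2, 0.5, 0.6])

# Mean value of each cluster
ids = np.unique(clusters)
cluster_means = np.array([values[clusters == c].mean() for c in ids])

# Rank clusters by mean: the lowest mean gets new label 0, and so on
new_label = np.empty_like(ids)
new_label[np.argsort(cluster_means)] = np.arange(len(ids))

# Relabel every cell (np.searchsorted maps each old label to its position in `ids`)
sorted_clusters = new_label[np.searchsorted(ids, clusters)]
print(sorted_clusters)
```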

#### Create jitter plots

These show the distribution of _chd_ likelihood values within each cluster. Each point is a cell and the y-axis is the _chd_ likelihood. The slight jitter along the x-axis is only there to help show density within each cluster.

In grey, `scprep` plots a circle denoting the mean likelihood value of each cluster. Additionally, we will plot a circle in purple denoting the relative abundance of _chd_ versus _tyr_ cells in each cluster: the fraction of _chd_ cells in the cluster, shifted so that the dataset-wide fraction sits at 0.5.

In [None]:
fig, ax = plt.subplots(1, figsize=(10,10))

# See example usage: https://scprep.readthedocs.io/en/stable/examples/jitter.html
scprep.plot.jitter(metadata['clusterID'], metadata['chd_likelihood'], 
                   c=metadata['sample_labels'], 
                   cmap=sample_cmap,
                   legend=False, 
                   plot_means=True, 
                   means_s=50, 
                   xlabel=False, 
                   ylabel='Mean chd likelihood',
                   ax=ax)

# Plot the fraction of chd cells per cluster, shifted so the overall fraction sits at 0.5
means = metadata.groupby('clusterID')['genotype'].mean()
ax.scatter(means.index, means - np.mean(metadata['genotype']) + 0.5, color='#7c5295', edgecolor='k', s=50)

# Axis tick labels
ax.set_xticklabels(pd.unique(metadata.sort_values('clusterID')['cluster']), rotation=90)
ax.set_ylim(0,1)

fig.tight_layout()
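
The arithmetic behind the purple circles can be checked on a small made-up example. The arrays below are hypothetical; only the `means - overall + 0.5` step mirrors the cell above.

```python
import numpy as np

# Made-up data: genotype is 1 for chd cells and 0 for tyr cells
cluster = np.array(['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b', 'b', 'b'])
genotype = np.array([1, 1, 1, 0, 1, 0, 0, 0, 1, 1])

overall = genotype.mean()  # fraction of chd cells in the whole dataset

centered = {}
for c in np.unique(cluster):
    frac = genotype[cluster == c].mean()  # fraction of chd cells in this cluster
    # Shift so that a cluster matching the overall fraction lands at 0.5
    centered[str(c)] = frac - overall + 0.5

print(overall, centered)
```

Values above 0.5 mark clusters with more _chd_ cells than expected from the dataset as a whole, and values below 0.5 mark clusters with fewer.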

### Discussion

1. What do you notice about the distribution of relative likelihood values within each cluster? Are there clusters where the fold-change in abundance (purple circles) and average likelihood (grey circles) differ greatly? Why do you think that is?

### Activity

Here, we want to visualize the heterogeneity within some clusters using PHATE run on each subset of the data.

Pick one cluster above with a large amount of variation and one cluster with a low amount of variation. Coordinate with your group to try to get a number of different clusters chosen per group.

Use the code below to plot the PHATE embedding of that cluster and color it by the likelihood values.

In [None]:
# These are the cluster names
print(metadata['cluster'].unique())

Duplicate the following two cells to test multiple clusters.

In [None]:
# =========
# Pick a cluster to analyze
curr_cluster = ""
# =========

# Take a subset of the data
curr_subset = metadata['cluster'] == curr_cluster
curr_data = data_pca.loc[curr_subset]
curr_metadata = metadata.loc[curr_subset]
curr_data_phate = phate.PHATE(verbose=0).fit_transform(curr_data)

In [None]:
scprep.plot.scatter2d(curr_data_phate, 
                      c=curr_metadata['chd_likelihood'], 
                      cmap=meld.get_meld_cmap(), vmin=0, vmax=1,
                      ticks=False, figsize=(4,4),
                      title='{} ({} cells)'.format(curr_cluster, curr_data_phate.shape[0]), 
                      legend=False, fontsize=10)

### Discussion

What do you notice about the relationship between standard deviation of the likelihood and the amount of heterogeneity seen in the PHATE embedding? What does this suggest?

This concludes the workshop version of the notebook, but to see how vertex frequency clustering interacts with MELD to identify subpopulations of cells, please consult [the full tutorial on GitHub](https://github.com/KrishnaswamyLab/MELD).