# TAD Week 6: Dimensionality Reduction

To include in the mini-lecture
* geometric description of multi-d data
* intuitive explanation of PCA
* mathematical equation for PCA reconstruction


Today is a bit of a change from the previous weeks. Instead of focusing on statistical tests or models, we will focus on another important technique in data analysis: dimensionality reduction. 

As you saw in today's mini-lecture, principal component analysis (PCA) is a powerful tool for visualizing your data in a low-dimensional space for two reasons: it is guaranteed to find the orthogonal axes along which your data vary most (i.e., it is optimal and objective), and it is a linear transform (i.e., it is easy to compute and interpret). UMAP is also a powerful tool, but for very different reasons: it attempts to find a low-dimensional, possibly curved manifold along which your data are scattered, and to embed that manifold in a few dimensions while preserving the neighborhood relationships among the original data points. UMAP is neither optimal nor objective, opting instead for flexibility; as we will see, this has its pros and cons.
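The claim that PCA finds the axis of maximum variance can be checked numerically. The sketch below uses synthetic 2-D Gaussian data (our own toy example, not part of the exercise), computes PC1 with NumPy's SVD of the mean-centered data, and compares the variance along PC1 to the variance along 1,000 random unit directions; none of the random directions should beat PC1.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic correlated 2-D data (our own toy example, not MNIST)
cov_toy = np.array([[3.0, 1.2],
                    [1.2, 1.0]])
X_toy = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov_toy, size=5000)

# PCA via SVD of the mean-centered data: the rows of Vt are the principal axes
Xc = X_toy - X_toy.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
pc1 = Vt[0]
var_pc1 = np.var(Xc @ pc1, ddof=1)

# Variance of the same data projected onto 1,000 random unit directions
angles = rng.uniform(0.0, np.pi, size=1000)
directions = np.column_stack([np.cos(angles), np.sin(angles)])  # each row has unit length
random_vars = np.var(Xc @ directions.T, axis=0, ddof=1)

print(f"variance along PC1:                  {var_pc1:.3f}")
print(f"largest variance, random directions: {random_vars.max():.3f}")
```

The random-projection section later in this notebook makes the same kind of comparison in MNIST's 784 dimensions.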

To illustrate these basic properties of PCA and UMAP, we start with MNIST, a classic machine learning dataset. Then we do a deep dive into UMAP with some scRNA-seq data, looking at its dependence on key parameters, some of its failure modes, and whether or not it really does what it claims to do!

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import StandardScaler
import umap
import umap.plot
from sklearn.datasets import fetch_openml
from itertools import repeat
from scipy.spatial.distance import euclidean

from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import scale
from sklearn.decomposition import TruncatedSVD
from scipy import stats
import scipy.io as sio

In [None]:
# Imports for interactive UMAP plotting
from io import BytesIO
from PIL import Image
import base64
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10


In [None]:
# Some helper functions

def plot_MNIST_sample(X):
    """
    Plots 9 images in the MNIST dataset.
    Credit: https://compneuro.neuromatch.io/tutorials/W1D4_DimensionalityReduction/student/W1D4_Tutorial3.html
    
    Args:
     X (numpy array) : Data matrix in which each row is one unwrapped
                       28x28 image (784 pixel values)
    Returns:
    Nothing.
    """

    fig, ax = plt.subplots()
    k = 0
    for k1 in range(3):
        for k2 in range(3):
            k = k + 1
            plt.imshow(np.reshape(X[k, :], (28, 28)),
                     extent=[(k1 + 1) * 28, k1 * 28, (k2+1) * 28, k2 * 28],
                     vmin=0, vmax=255, cmap='viridis')
            plt.xlim((3 * 28, 0))
            plt.ylim((3 * 28, 0))
            plt.tick_params(axis='both', which='both', bottom=False, top=False,
                      labelbottom=False)
            plt.clim([0, 250])
            ax.set_xticks([])
            ax.set_yticks([])
    plt.show()
    
    
def plot_MNIST_weights(weights):
    """
    Visualize PCA basis vector weights for MNIST. Red = positive weights,
    blue = negative weights, white = zero weight.
    Credit: https://compneuro.neuromatch.io/tutorials/W1D4_DimensionalityReduction/student/W1D4_Tutorial3.html

    Args:
    weights (numpy array of floats) : PCA basis vector

    Returns:
    Nothing.
    """
    fig, ax = plt.subplots()
    # Pass the colormap by name (plt.cm.get_cmap was removed in matplotlib 3.9)
    plt.imshow(np.real(np.reshape(weights, (28, 28))), cmap='seismic')
    plt.tick_params(axis='both', which='both', bottom=False, top=False,
                  labelbottom=False)
    plt.clim(-.15, .15)
    plt.colorbar(ticks=[-.15, -.1, -.05, 0, .05, .1, .15])
    ax.set_xticks([])
    ax.set_yticks([])
    
    
def get_variance_explained(X):
    """
    Use the covariance matrix to calculate variance explained
    along each axis (variable) for the data in X (n obs x n vars).
    
    NB, uses np.cov(rowvar=False)! 
    
    Args:
        X (np array): n obs x n vars matrix
        
    Returns:
        Variance explained by each variable of X (diags of cov)
    """
    cov = np.cov(X, rowvar=False)
    return cov[np.eye(X.shape[1]) == 1]


# Utility for the bokeh function
def embeddable_image(data):
    """See: https://www.kaggle.com/code/parulpandey/part3-visualising-kannada-mnist-with-umap/notebook
    """
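    # MNIST pixels range from 0 to 255: invert them directly (dark digit on a light
    # background); multiplying uint8 values by 15 would overflow and wrap around
    img_data = (255 - np.clip(data, 0, 255)).astype(np.uint8)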
    image = Image.fromarray(img_data, mode='L').resize((28,28), Image.BICUBIC)
    buffer = BytesIO()
    image.save(buffer, format='png')
    for_encoding = buffer.getvalue()
    return 'data:image/png;base64,' + base64.b64encode(for_encoding).decode()


def hoverable_2d_scatter_with_images(data, labels, unwrapped_images, image_dims, s=4):
    """
    Takes x/y data and makes a scatter plot that is 1) colored by label,
    2) hoverable with Bokeh, and 3) optionally has images associated with each point
    
    Lightly modified from: https://www.kaggle.com/code/parulpandey/part3-visualising-kannada-mnist-with-umap/notebook
    
    data (np.array): n obs x 2 (x,y)
    labels (np.array): n obs
    unwrapped_images (np.array): n obs x n pixels
    image_dims (tuple or list): width x height to re-wrap images
    """
    
    assert data.ndim == 2
    assert data.shape[1] == 2
    
    assert data.shape[0] == labels.shape[0] == unwrapped_images.shape[0]
    
    assert len(image_dims) == 2
    assert np.prod(image_dims) == unwrapped_images.shape[1]

    rewrapped_images = unwrapped_images.reshape(unwrapped_images.shape[0], image_dims[0], image_dims[1])

    digits_df = pd.DataFrame(data, columns=('x', 'y'))
    digits_df['digit'] = [str(lab) for lab in labels]
    digits_df['image'] = list(map(embeddable_image, rewrapped_images))


    datasource = ColumnDataSource(digits_df)
    # The factors must be the unique category names, not one entry per data point
    color_mapping = CategoricalColorMapper(factors=[str(lab) for lab in np.unique(labels)],
                                           palette=Spectral10)

    plot_figure = figure(
        title='UMAP projection of the MNIST dataset',
        width=600,
        height=600,
        tools=('pan, wheel_zoom, reset')
    )

    plot_figure.add_tools(HoverTool(tooltips="""
    <div>
        <div>
            <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
        </div>
        <div>
            <span style='font-size: 16px; color: #224499'>Digit:</span>
            <span style='font-size: 18px'>@digit</span>
        </div>
    </div>
    """))

    plot_figure.scatter(
        'x',
        'y',
        source=datasource,
        color=dict(field='digit', transform=color_mapping),
        line_alpha=0.6,
        fill_alpha=0.6,
        size=s
    )
    show(plot_figure)

    

## Loading the data

We will start with the MNIST dataset for the first few exercises. This is a set of 70,000 low-resolution images of handwritten digits. Each image is 28x28 pixels and is unwrapped into a single row of the data, so we'll be working with a 70,000 x 784 (i.e., 70,000 x 28^2) matrix. MNIST is a classic machine learning dataset: many ML researchers have spent years trying to gain tenths of a percentage point of accuracy on it! See, e.g., https://benchmarks.ai/mnist. It is generally considered an easy problem in machine learning; even very simple classifiers achieve >90% accuracy.
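To make "unwrapped" concrete: each 28x28 image becomes one row of 784 numbers, so every image is a single point in a 784-dimensional space. The sketch below uses a made-up random image (not real MNIST data) and assumes row-major order, which is NumPy's default for `reshape` and is how `plot_MNIST_sample` re-wraps rows into images.

```python
import numpy as np

rng = np.random.default_rng(0)
# A made-up 28x28 "image" with pixel intensities in 0-255 (stands in for one digit)
image = rng.integers(0, 256, size=(28, 28))

# Unwrap the 2-D grid into a single row of 784 values (row-major order)
row = image.reshape(-1)
print(row.shape)  # (784,)

# Pixel (r, c) of the image lands in column r*28 + c of the unwrapped row
r, c = 5, 7
assert row[r * 28 + c] == image[r, c]

# Stacking unwrapped images row by row gives an n_images x 784 data matrix
X_fake = np.vstack([row, 255 - row])
print(X_fake.shape)  # (2, 784)

# Re-wrapping a row recovers the original image
assert np.array_equal(X_fake[0].reshape(28, 28), image)
```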

Here, we won't bother with classification, but the semantic meaning of the data points as numbers will help us to understand what PCA is doing with the data.

In [None]:
# Get mnist data (takes ~30 seconds)
mnist = fetch_openml(name='mnist_784', as_frame = False)

In [None]:
X = mnist.data.astype('int')
y = mnist.target.astype('int')  # 70,000-integer vector; tells which digit is represented by each image

In [None]:
# Visualize nine of the images
plot_MNIST_sample(X)

## PCA Visualization
This PCA tutorial is partly borrowed from, and entirely inspired by, NeuroMatch Academy's Dimensionality Reduction tutorial: https://compneuro.neuromatch.io/tutorials/W1D4_DimensionalityReduction/student/W1D4_Tutorial3.html

Let's run PCA on the MNIST dataset and see what we get. Scikit-learn has a [handy PCA object](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html) that takes care of the math for us. Run the following cells and then answer questions 1-3.
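Two properties of scikit-learn's `PCA` are worth seeing explicitly. It subtracts the column means itself (they are stored in `mean_`), so centering the data beforehand is harmless, and `transform` is simply a projection of the centered data onto the rows of `components_`. Its `explained_variance_` values are the eigenvalues of the sample covariance matrix, sorted from largest to smallest. A sketch on small synthetic data (not MNIST):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(1)
# Synthetic data: 500 observations of 5 correlated variables, with a nonzero mean
X_demo = rng.normal(size=(500, 5)) @ rng.normal(size=(5, 5)) + 10.0

pca_demo = PCA().fit(X_demo)

# transform() = (X - column means) projected onto the principal axes
X_demo_centered = X_demo - X_demo.mean(axis=0)
manual_scores = X_demo_centered @ pca_demo.components_.T
assert np.allclose(pca_demo.transform(X_demo), manual_scores)

# explained_variance_ = eigenvalues of the sample covariance (ddof=1), largest first
eigvals = np.linalg.eigvalsh(np.cov(X_demo, rowvar=False))[::-1]
assert np.allclose(pca_demo.explained_variance_, eigvals)

# Fraction of the total variance captured by each PC
print(pca_demo.explained_variance_ratio_)
```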

In [None]:
from sklearn.decomposition import PCA

# Initialize a pca model
pca = PCA(random_state=2022)

# TODO: Center each variable (i.e., each column, i.e., each pixel position) by subtracting its mean across the dataset
X_norm = ...

# TODO: Run PCA on the data ("fit"), and return the representation of the data in the new coordinates ("transform")
pca.fit(...)
X_transform = ...

# Extract PCs and var explained from the model object
components = pca.components_
explained_variance = pca.explained_variance_

In [None]:
# Plot the data in the new coordinates
x_pc = 0
y_pc = 1

plt.figure(figsize=(12,8))
plt.scatter(X_transform[:,x_pc], X_transform[:,y_pc], s=2, c=y, cmap='tab10')
plt.xlabel('PC1', fontsize=18)
plt.ylabel('PC2', fontsize=18)
plt.title('MNIST in the reduced PCA space', fontsize=18)
plt.colorbar(ticks=range(10))
plt.clim(-0.5, 9.5)

In [None]:
# Visualize the components (each 1 x num pixels) as if they were also digits
for i in range(3):
    plot_MNIST_weights(components[i,:])
    plt.title(f'MNIST PC{i + 1}')

In [None]:
# Plot the explained variance vs the i-th PC
plt.plot(explained_variance)
plt.xlabel('Component number')
plt.ylabel('Variance')
plt.title('Scree plot for MNIST PCA (variance explained by i-th component)')

In [None]:
#TODO (Q2): what percentage of the total variance does the first principal component explain?
percent_explained_PC1 = ...

print(percent_explained_PC1)

**QUESTION 1:** Which of the following questions does PCA analysis help to answer?
* Which set of new vectors best reconstruct the original data?
* Which set of predictors are maximally uncorrelated with each other?
* Which set of predictors describe the axes along which the data vary most?

**QUESTION 2:** What percentage of the total variance (rounded to one decimal) does PC1 explain?

**QUESTION 3:** In the scatter plot of PC1 vs. PC2, what does each point represent?
* One picture from the dataset
* One pixel from one picture from the dataset
* One component of the PCA results

**QUESTION 4:** How would you interpret PC1's meaning, in the context of MNIST? What does it tell you about the digits in the dataset?

**QUESTION 5:** About how many principal components would you use to reconstruct the MNIST dataset with high accuracy but substantially reduced dimensionality? 
* 2
* 20
* 100
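For reference, the reconstruction asked for in the next cell can be written out explicitly (the notation here is ours, not from the mini-lecture). Let $X$ be the $n \times p$ data matrix, $\boldsymbol{\mu}$ the length-$p$ vector of column means, $\mathbf{1}_n$ a length-$n$ vector of ones, and $W_k$ the $p \times k$ matrix whose columns are the first $k$ principal components (`components[:k].T` above). Because the components are orthonormal ($W_k^\top W_k = I_k$), multiplying the scores by $W_k^\top$ maps them back into pixel space:

$$
Z_k = \left(X - \mathbf{1}_n \boldsymbol{\mu}^\top\right) W_k,
\qquad
\hat{X} = Z_k W_k^\top + \mathbf{1}_n \boldsymbol{\mu}^\top
$$

Here $Z_k$ is the first $k$ columns of the transformed data (`X_transform[:, :k]`), and the mean squared error is $\frac{1}{np}\lVert X - \hat{X} \rVert_F^2$, the average squared difference over all $n \times p$ pixel values.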


In [None]:
# TODO: using the number of PCs from question 5, reconstruct the original MNIST dataset 
# from the transformed data and the principal components. Then determine the mean squared error for your reconstruction.

# Hint: you'll need to use matrix multiplication, and don't forget to add back the column means!

n_pcs_to_use = 20
X_reconstructed = ...
reconstruction_pca_mse = np.mean((X - X_reconstructed)**2)
print(reconstruction_pca_mse)

plot_MNIST_sample(X_reconstructed)


**QUESTION 6:** What is the mean squared error for your reconstruction?
    
    

You should see readable, if slightly blurry, digits in the plot above!

## Comparing PCA to random projections
Another common technique for dimensionality reduction (though it is becoming less common as computers get faster) is random projection. Instead of calculating the optimal set of projection axes (i.e., the principal components), we simply project onto random axes. These have useful statistical properties on average: by the Johnson–Lindenstrauss lemma, projecting onto enough random directions approximately preserves the pairwise distances between data points. What random axes do not do is line up with the directions along which the data vary most.

Here, we compare one type of random projection to PCA, based on the variance explained by the axes of each projection.
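The distance-preserving property can be seen on synthetic data. The sketch below (300 random points in 1,000 dimensions, our own choice, not MNIST) projects onto 200 sparse random directions with scikit-learn's `SparseRandomProjection` and compares every pairwise distance before and after. It also prints scikit-learn's `johnson_lindenstrauss_min_dim` for 70,000 points at `eps=0.1`, which exceeds MNIST's 784 pixels; this is why `n_components` has to be set explicitly for MNIST.

```python
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.random_projection import SparseRandomProjection, johnson_lindenstrauss_min_dim

rng = np.random.default_rng(0)

# Synthetic stand-in data (NOT MNIST): 300 points in 1,000 dimensions
X_hd = rng.normal(size=(300, 1000))

# Project onto 200 sparse random directions; n_components is set explicitly
srp = SparseRandomProjection(n_components=200, random_state=0)
X_low = srp.fit_transform(X_hd)

# Compare every pairwise distance before and after the projection
d_high = pairwise_distances(X_hd)
d_low = pairwise_distances(X_low)
off_diag = ~np.eye(X_hd.shape[0], dtype=bool)
ratios = d_low[off_diag] / d_high[off_diag]
print(f"low-d / high-d distance ratio: mean {ratios.mean():.3f}, std {ratios.std():.3f}")

# JL bound on the number of random dimensions for 70,000 points at eps=0.1
print(johnson_lindenstrauss_min_dim(70000, eps=0.1))
```

A mean ratio near 1 with a small spread is the "decent statistical properties" promised above; nothing in the construction, however, favors high-variance directions.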

In [None]:
from sklearn.random_projection import SparseRandomProjection

n_iter = 100
np.random.seed(2022)

# TODO: simulate many different random projections of the MNIST dataset. 
# Get the maximum variance explained by any single dimension of the transformed data, 
# and compare the distribution of these values to the maximum variance explained from PCA (ie PC1).

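# Hint: use sklearn's SparseRandomProjection class and the helper function get_variance_explained(), defined above.
# Set n_components explicitly: with the default n_components='auto', the Johnson-Lindenstrauss bound
# for 70,000 samples exceeds MNIST's 784 dimensions, and sklearn raises an error.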

max_var_explained = np.zeros(n_iter,)
for i in range(n_iter):
    sparse_rand_proj = ...
    X_rand = ...
    max_var_explained[i] = ...

In [None]:
plt.hist(max_var_explained, label='Max of random')
plt.axvline(explained_variance[0], color='r', label='PC1')
plt.xlabel('Variance')
plt.legend()

**QUESTION 7:** On average, how many times more/less variance does the best sparse projection component explain than the first principal component?

In [None]:
# TODO: answer question 7
variance_ratio = ...
variance_ratio

In [None]:
# TODO: similar to how you did above, reconstruct the original MNIST data from a sparse random projection.
# Then calculate the MSE, and visualize it.

n_sparse_projs_to_use = int(variance_ratio * n_pcs_to_use) # Multiply these so that the reconstructions explain roughly similar amounts of variance.
sparse_rand_proj = ...
X_rand = ...
X_reconstructed_from_sparse = ...


plot_MNIST_sample(X_reconstructed_from_sparse)

**QUESTION 8:** How does the sparse random reconstruction differ from the PCA-based reconstruction? Explain why it differs in this way.

Before we leave PCA, it is worth acknowledging that there are lots of other linear dimensionality reduction algorithms out there. If you're interested in applying dimensionality reduction to neural activity, I highly recommend [this video](https://youtu.be/zeBFyRaoVnQ?t=1394) by Byron Yu; this link points you to the part of the video where he covers a bunch of different techniques and when you might use each one. Very useful!

## UMAP for clustering

Now, let's run UMAP on the same dataset and see how the results differ from PCA. Run the cells, then answer the questions below.

In [None]:
# This cell may take 1-2 minutes

# Initialize the UMAP model for 2d 
umap_model = umap.UMAP(n_components=2, random_state=2022)

# Run UMAP on the data ("fit"), and return the representation of the data in the new coordinates ("transform")
umap_model.fit(...)
embedding = ...

In [None]:
# Plot the data with UMAP's built-in plotting utility; this uses datashader on the backend to make a nice tidy plot
umap.plot.points(umap_model, labels=y)
hover_data = pd.DataFrame({'index':np.arange(X.shape[0]),
                           'label':y})

In [None]:
# We can also make an interactive plot with Bokeh

# Start Bokeh
output_notebook()

# Make the plot
hoverable_2d_scatter_with_images(embedding, y, X, (28,28), s=2)

**QUESTION 9:** UMAP appears to separate the digits far better than PCA did. This suggests that the data are separable in what way?
* Linearly
* Non-linearly
* Uniformly

**QUESTION 10:** UMAP attempts to preserve local structure in the data. Using the hover tooltip, examine the way that the digits change within each digit cluster (i.e. from the top to bottom of all the 1's). Did UMAP appear to preserve local structure in this case?

**QUESTION 11:** Zoom into the left edge of the 6's cluster. You should see a few points that are not 6's (colored differently). Examine the corresponding images. What happened here -- did UMAP make a mistake?

**QUESTION 12:** Compare and contrast the insights into the MNIST dataset that PCA and UMAP give us. Which do you feel reveals more about the dataset?
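To put a number on "preserves local structure", one option (not used elsewhere in this notebook) is scikit-learn's `trustworthiness` score, which lies between 0 and 1 and penalizes points that become nearest neighbors in the embedding without being near neighbors in the original space. The sketch uses scikit-learn's small bundled 8x8 digits dataset as a stand-in for MNIST, so it needs no download.

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.manifold import trustworthiness

# sklearn's small 8x8 digits dataset (bundled, no download) stands in for MNIST
X_digits, y_digits = load_digits(return_X_y=True)

# A 2-D PCA embedding of the digits
X_pca2 = PCA(n_components=2, random_state=0).fit_transform(X_digits)

# 1.0 would mean each point's 10 nearest neighbors in the embedding were
# also among its 10 nearest neighbors in the original 64-D space
t_pca = trustworthiness(X_digits, X_pca2, n_neighbors=10)
print(f"PCA (2-D) trustworthiness: {t_pca:.3f}")
```

`trustworthiness` computes the full pairwise distance matrix, so for the 70,000-image MNIST data you would score a random subsample (a few thousand rows of `X` together with the matching rows of `embedding` or `X_transform`).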

## "The Specious Art of Single-Cell Genomics"

This exercise recreates one analysis from Chari et al. 2021, a paper from Lior Pachter's lab that examines the use of dimensionality reduction methods like PCA and UMAP in the analysis of single-cell RNA sequencing data. The paper's main argument is that investigators should not perform quantifications in UMAP output space (e.g., cell-to-cell distances, relative spacing of different cell types, "trajectories") to draw conclusions about biological variables of interest, because UMAP distorts the data too heavily for such quantities to be interpretable. 

Their line of reasoning has two parts. First, it is impossible to reduce high-dimensional data to low-dimensional data without distortion; this is a mathematical fact and is unavoidable. Second, given these inevitable distortions, the authors ask how well UMAP performs the reduction. They argue that it does a very poor job, and to prove their point, they design an autoencoder that, by their metrics, preserves the characteristics of the original data just as well as UMAP does, yet shapes the 2D embedding however you like (for example, as an elephant or a world map). 

Here, we recreate one piece of data from their Figure 2b: the correlation between all inter-cell-type distances as measured in ambient space (log-normalized counts), PCA space, or PCA-UMAP space. Put another way: given a gene expression dataset with $T$ cell types, first find the centroid (mean) of each cell type, then find the set of $T(T-1)/2$ inter-centroid distances. Repeat that calculation for the ambient, PCA, and PCA-UMAP spaces. Finally, correlate the length-$T(T-1)/2$ distance vectors in pairs (ambient vs. PCA, ambient vs. PCA-UMAP) and see how the UMAP correlation compares to the PCA correlation.
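The logic of this comparison can be seen on a toy example. The sketch below uses six made-up centroids in 2-D and Euclidean distance (the exercise itself uses L1 distance in far more dimensions). A rotation plus uniform rescaling keeps every relative spacing, so the correlation between the distance vectors is 1; a nonlinear warp, standing in loosely for the kind of distortion a nonlinear embedding can introduce, pushes it below 1. Because Pearson correlation ignores an overall change of scale, the arbitrary units of an embedding do not matter.

```python
import numpy as np
from scipy.spatial.distance import pdist

rng = np.random.default_rng(0)

# Six hypothetical cell-type centroids in a 2-D "ambient" space
T = 6
centroids_demo = rng.normal(size=(T, 2))

# pdist returns the T(T-1)/2 unique pairwise (Euclidean) distances as one vector
ambient = pdist(centroids_demo)
print(len(ambient))  # 15

# Rotation plus uniform scaling preserves all relative spacings
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
r_rot = np.corrcoef(ambient, pdist(5.0 * centroids_demo @ R.T))[0, 1]

# A nonlinear warp changes the relative spacings
warped = np.column_stack([np.exp(centroids_demo[:, 0]), centroids_demo[:, 1] ** 3])
r_warp = np.corrcoef(ambient, pdist(warped))[0, 1]

print(f"rotation + scaling: r = {r_rot:.3f}")  # r = 1.000
print(f"nonlinear warp:     r = {r_warp:.3f}")
```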

A technical note: in single-cell analysis, it is standard practice to perform UMAP *on top of* PCA. That is, one first runs PCA on the data, then runs UMAP on the output of the PCA. This is because UMAP tends to perform poorly on very high-dimensional inputs (e.g., many thousands of genes, which is typical of single-cell data), and PCA can reduce the number of dimensions to ~50 or so.
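The `get_pca_and_umap` function below implements its PCA step with `TruncatedSVD` on the output of `scale`. `TruncatedSVD` does not center the data itself, but `scale` already centers every column, and on centered data a truncated SVD gives the same scores as PCA up to a sign flip per component. A sketch checking this on a synthetic cells x genes matrix (made up: 5 latent factors plus noise, not the real data):

```python
import numpy as np
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA, TruncatedSVD

rng = np.random.default_rng(0)

# Synthetic stand-in for a cells x genes matrix (NOT the real data):
# 300 "cells", 100 "genes", driven by 5 latent factors of decreasing strength plus noise
latent = rng.normal(size=(300, 5)) * np.array([10.0, 8.0, 6.0, 4.0, 2.0])
loadings = rng.normal(size=(5, 100))
data = latent @ loadings + rng.normal(size=(300, 100))

# scale() centers every column (and gives it unit variance)
scaled = scale(data)
print(np.abs(scaled.mean(axis=0)).max())  # ~0: columns are centered

# TruncatedSVD (no centering of its own) vs. PCA (which centers internally)
svd_scores = TruncatedSVD(n_components=5, algorithm='arpack').fit_transform(scaled)
pca_scores = PCA(n_components=5, svd_solver='full').fit_transform(scaled)

# Each component is only defined up to a sign flip, so compare absolute values
print(np.allclose(np.abs(svd_scores), np.abs(pca_scores), atol=1e-6))
```

`TruncatedSVD` also accepts sparse input, which is a common reason to prefer it for count matrices; here the data are dense either way because `scale` must center them.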

In [None]:
# Load the data
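# NB: update data_path to the folder on your machine that contains tenx.mtx and metadata.csv
data_path = '/Users/jonahpearl/Documents/PiN/G3/TAD/TAD_python/data'
count_mat = sio.mmread(os.path.join(data_path, 'tenx.mtx'))
# mmread typically returns a sparse COO matrix; convert it to a dense array so that
# boolean row selection (e.g. for centroids) works and scale() can center the columns
count_mat = count_mat.toarray() if hasattr(count_mat, 'toarray') else count_mat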
count_mat.shape

In [None]:
meta = pd.read_csv(os.path.join(data_path, 'metadata.csv'), index_col=0)
print(meta.shape)
meta.head()

Let's briefly examine the data.

**QUESTION 13**: How many genes are present in this dataset? How many cells? (Hint: think about what count_mat and meta each represent.) 

**QUESTION 14**: Which of these represents the dimensionality of the data?

**QUESTION 15**: How many unique cell types ("clusters") are there in this data set?

In [None]:
#TODO: find the number of unique cell types in the data set (hint -- this can be done in one simple line of code!)
...

Let's find the centroids for each cell type.

In [None]:
#TODO: write a function to find the centroids for each cluster.
# We define the centroid of a cluster to be the average of all cells in that cell type.

def get_centroids(counts, cluster_labels):
    """Get centroids of clusters in a set of points
    Arguments:
        counts {np.array} -- a cells x genes matrix of rna counts
        cluster_labels {array or list} -- cluster labels for each cell
    
    Returns:
        clusters {np.array} -- list of all unique clusters
        centroids {np.array} -- the centroid for each corresponding cluster
    """
    unique_clusters = np.unique(cluster_labels)
    centroids = np.zeros((len(unique_clusters), counts.shape[1]))
    
    for i in range(len(unique_clusters)):
        cells_in_cluster = ...
        centroids[i, :] = ...
    
    return unique_clusters, centroids

# Use your function to get the centroids for each cluster in the ambient space (count_mat)
clusters, centroids = get_centroids(count_mat, meta.cluster)

**QUESTION 16:** What is the value of the second dimension of the centroid of the Esr1_6 cells?

In [None]:
# TODO: find the value of the second dimension of the centroid of the Esr1_6 cells
esr1_6_idx = ...
centroids[esr1_6_idx][0,1]

In [None]:
#TODO: write a function that gets the T(T-1)/2 inter-centroid distances, and returns them as a single vector.
# We will use L1 distance, which is more suitable for sparse high dimensional data than L2.
def get_pairwise_dists(points):
    """Get pairwise L1 distances of a set of points
    Arguments:
        points {np.array} -- a points x dimensions matrix
    Returns:
        dists {np.array} -- a vector of the N(N-1)/2 unique pairwise distances
    """
    
    all_dists = ...  # hint: check out sklearn.metrics.pairwise_distances. Make sure to use the L1 norm!
    
    # Remove redundant pairs: keep only the entries strictly above the diagonal
    dists = all_dists[np.triu_indices(all_dists.shape[0], k=1)]
    
    return dists
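Selecting the entries strictly above the diagonal keeps each unordered pair exactly once. SciPy's `pdist` returns the same values in the same order (SciPy calls L1 distance `'cityblock'`; scikit-learn accepts `'manhattan'` or `'l1'`), which makes a handy cross-check for `get_pairwise_dists`. A sketch on five made-up points:

```python
import numpy as np
from scipy.spatial.distance import pdist
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
pts = rng.normal(size=(5, 3))  # 5 hypothetical centroids in 3-D

# Full 5x5 L1 ("manhattan" / "cityblock") distance matrix
full = pairwise_distances(pts, metric='manhattan')

# Keep each unordered pair once: entries strictly above the diagonal
upper = full[np.triu_indices(pts.shape[0], k=1)]

# scipy's pdist returns the same values, in the same order, directly
condensed = pdist(pts, metric='cityblock')

print(upper.shape)                    # (10,)
print(np.allclose(upper, condensed))  # True
```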
    

**QUESTION 17**: What is the mean pairwise distance between clusters in the ambient space?

In [None]:
pairwise_dist_vector = get_pairwise_dists(centroids)
pairwise_dist_vector.mean()

In [None]:
# Check that there are the expected number of inter-centroid distance values
assert pairwise_dist_vector.shape[0] == centroids.shape[0]*(centroids.shape[0] - 1)/2

In [None]:

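def get_pca_and_umap(counts, n_pcs=50, n_umap_components=2):
    """This function returns PCA and PCA-UMAP representations of the input data."""
    # scale() cannot center a sparse matrix, so convert to a dense array if needed
    if hasattr(counts, 'toarray'):
        counts = counts.toarray()
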
    print('Scale')
    # Scale the data for PCA
    scaled_mat = scale(counts)
    
    print('PCA')
    # Get the PCA object
    tsvd = TruncatedSVD(n_components=n_pcs)
    
    # Run PCA
    x_pca = tsvd.fit_transform(scaled_mat)
    
    print('UMAP')
    # Get the UMAP object
    reducer = umap.UMAP(n_components=n_umap_components)
    
    # Run UMAP. Note that we purposely run UMAP on the output of the PCA!
    x_pca_UMAP = reducer.fit_transform(x_pca)
    
    return x_pca, x_pca_UMAP

In [None]:
# Get PCA and UMAP representations of the data
counts_pca, counts_pca_UMAP = get_pca_and_umap(count_mat)

In [None]:
# TODO: get pairwise distance vectors and correlate to ambient
...
pairwise_dist_vector_pca = ...

...
pairwise_dist_vector_UMAP = ...

print(np.corrcoef(pairwise_dist_vector, pairwise_dist_vector_pca)[0,1])
print(np.corrcoef(pairwise_dist_vector, pairwise_dist_vector_UMAP)[0,1])

**QUESTION 18:** What is the correlation between the ambient and UMAP space?

Now, use bootstrapping to estimate the standard error of these correlations.
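As a reminder of the mechanics: the bootstrap standard error of a statistic is the standard deviation of that statistic across datasets resampled with replacement from the original. The sketch below applies this to a Pearson correlation on simulated paired data (not the scRNA-seq data). In the exercise, the resampled units would be cells (rows of `count_mat`), with their cluster labels from `meta` kept aligned, and the statistic would be recomputed through the full PCA/UMAP pipeline on each resample.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulated paired measurements (an assumption for illustration only)
n = 200
a = rng.normal(size=n)
b = 0.6 * a + rng.normal(scale=0.8, size=n)

observed_r = np.corrcoef(a, b)[0, 1]

# Bootstrap: resample observations WITH replacement, recompute the statistic each time
n_boot = 1000
boot_r = np.empty(n_boot)
for i in range(n_boot):
    idx = rng.integers(0, n, size=n)  # row indices drawn with replacement
    boot_r[i] = np.corrcoef(a[idx], b[idx])[0, 1]

# The bootstrap standard error is the standard deviation of the replicates
se = boot_r.std(ddof=1)
print(f"r = {observed_r:.3f}, bootstrap SE = {se:.3f}")
```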

In [None]:
# TODO: bootstrap the correlations
n_iters = 10  # each iteration takes ~10-20 seconds, so don't do too many!
amb_pca_corrs = np.zeros(n_iters)
amb_UMAP_corrs = np.zeros(n_iters)
np.random.seed(10)
for i in range(n_iters):
    print(i)
    
    # Resample from the data
    ...
    
    # Get PCA and UMAP representations of the re-sampled data
    ...
    
    # Find cluster centroids in the re-sampled data
    ...
    
    # Get pairwise distances between centroid pairs in each latent space
    ambient_dists = ...
    pca_dists = ...
    umap_dists = ...
    
    # Correlate the distances in ambient space to PCA and UMAP
    amb_pca_corrs[i] = ...
    amb_UMAP_corrs[i] = ...
    

In [None]:
# Plot the results of the bootstrapping
from seaborn import pointplot
corrs_df = pd.DataFrame({'pca': amb_pca_corrs, 'pca_UMAP': amb_UMAP_corrs}).melt(var_name='latent', value_name='corr_to_ambient')
pointplot(data=corrs_df, y='corr_to_ambient', x='latent')

**QUESTION 19**: What is the standard error for the correlation between the ambient and UMAP data?

In [None]:
std = ...
print(std)

**QUESTION 20**: Generate a p-value for the hypothesis that the correlation between the ambient and the UMAP representations is significantly lower than the correlation between the ambient and the PCA representations. Is there a real effect?

In [None]:
stats.ranksums(..., ..., alternative='less')  # one-sided: tests whether the first sample tends to be smaller than the second
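A sketch of a one-sided rank-sum test on made-up numbers (simulated here, not results from this notebook). With `scipy.stats.ranksums`, `alternative='less'` tests whether the first sample tends to be smaller than the second, so the order of the arguments matters; the two-sided p-value is twice the one-sided one when the effect is in the hypothesized direction.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Simulated bootstrap correlations for two methods (made-up values, for illustration only)
corrs_method_a = rng.normal(loc=0.70, scale=0.03, size=10)
corrs_method_b = rng.normal(loc=0.90, scale=0.03, size=10)

# One-sided Wilcoxon rank-sum test: is method A's correlation lower than method B's?
res = stats.ranksums(corrs_method_a, corrs_method_b, alternative='less')
print(res.statistic, res.pvalue)
```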