# Plot PCA and UMAP of language model embedding of test sets ("unsupervised embedding")

- Data was scaled using the training set's scaler
- PCA was already computed, using the training set's PCA
- UMAP is computed from scratch here (fit on the training set, then applied to the test set)

This is different from the supervised embedding, which comes out of the trained classifiers.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
import seaborn as sns
import pandas as pd
import genetools
from genetools.palette import HueValueStyle
import gc
from slugify import slugify
from typing import Optional

import malid.external.genetools_plots
import malid.external.genetools_scanpy_helpers

sns.set_style("dark")

import scanpy as sc

from malid import config, io, logger, helpers
from malid.datamodels import TargetObsColumnEnum, healthy_label
from malid.external import genetools_plots
import choosegpu

# Enable GPU use, with memory_pool=True, before any GPU-backed work runs.
# The UMAP step later in this notebook requests RAPIDS (use_rapids=True).
choosegpu.configure_gpu(enable=True, memory_pool=True)

['GPU-cc60845a-1c98-0d24-b5e4-0ce9c58dd2bb']

In [2]:
# Work on a copy of the shared disease palette so the dict in `helpers` is not mutated.
disease_color_palette = helpers.disease_color_palette.copy()

# Restyle the healthy class: keep its color, but give it zorder=-15 and alpha=0.5.
# The intent is presumably that healthy points sit behind and do not obscure the
# disease groups (confirm against how genetools consumes HueValueStyle).
disease_color_palette.update(
    {
        healthy_label: HueValueStyle(
            color=disease_color_palette[healthy_label],
            zorder=-15,
            alpha=0.5,
        )
    }
)

In [3]:
def plot_background(
    all_points, representation, fold_id, fold_label, gene_locus, plt_quantile=0.01
):
    """Scatter every point of a 2D embedding, colored by disease.

    Args:
        all_points: DataFrame holding ``X_{representation}1`` and
            ``X_{representation}2`` coordinate columns plus a ``disease`` column.
        representation: Embedding name used to build the column keys
            (this notebook passes "umap" or "pca").
        fold_id, fold_label, gene_locus: Only used in the plot title.
        plt_quantile: Axes are cropped to the [plt_quantile, 1 - plt_quantile]
            quantile range of each coordinate, hiding extreme outliers.

    Returns:
        The ``(fig, ax)`` pair created by ``genetools.plots.scatterplot``.
    """
    x_key = f"X_{representation}1"
    y_key = f"X_{representation}2"

    def _quantile_window(column):
        # (low, high) quantiles of one coordinate column across all points.
        values = all_points[column]
        return (
            np.quantile(values, plt_quantile),
            np.quantile(values, 1 - plt_quantile),
        )

    # Colors come from the module-level disease_color_palette, whose healthy
    # entry carries its own zorder/alpha overrides.
    fig, ax = genetools.plots.scatterplot(
        data=all_points,
        x_axis_key=x_key,
        y_axis_key=y_key,
        hue_key="disease",
        discrete_palette=disease_color_palette,
        alpha=0.5,
        marker=".",
        marker_size=malid.external.genetools_plots.get_point_size(len(all_points)),
        figsize=(5, 5),
        legend_title="Disease",
        enable_legend=True,
        remove_x_ticks=True,
        remove_y_ticks=True,
    )
    ax.set_title(
        f"Fold {fold_id} {fold_label}, {representation}, {gene_locus} - all diseases"
    )

    # Zoom in on the central part of the distribution.
    ax.set_xlim(*_quantile_window(x_key))
    ax.set_ylim(*_quantile_window(y_key))

    # Annotate legend entries with per-disease sample sizes.
    genetools.plots.add_sample_size_to_legend(
        ax=ax, data=all_points, hue_key="disease"
    )

    return fig, ax

In [4]:
def plot_within_disease(
    all_points,
    disease,
    representation,
    fold_id,
    fold_label,
    gene_locus,
    plt_quantile=0.01,
):
    """Scatter one disease's points, colored by study name (i.e. by batch).

    Args:
        all_points: DataFrame with ``X_{representation}1`` / ``X_{representation}2``
            coordinate columns and ``disease`` and ``study_name`` columns.
        disease: Value of ``all_points["disease"]`` to plot.
        representation: Embedding name used to build the column keys.
        fold_id, fold_label, gene_locus: Only used in the plot title.
        plt_quantile: Axes limits are the [plt_quantile, 1 - plt_quantile]
            quantiles computed over *all* points, not only this disease.

    Returns:
        ``(fig, ax)`` from a new 5x5 figure.
    """
    x_key = f"X_{representation}1"
    y_key = f"X_{representation}2"
    subset = all_points[all_points["disease"] == disease]

    fig, ax = plt.subplots(figsize=(5, 5))

    # Point size is derived from the total number of points (all diseases),
    # not just the size of this subset.
    genetools.plots.scatterplot(
        data=subset,
        x_axis_key=x_key,
        y_axis_key=y_key,
        hue_key="study_name",
        discrete_palette=helpers.study_name_color_palette,
        ax=ax,
        enable_legend=True,
        alpha=0.5,
        marker="o",
        marker_size=malid.external.genetools_plots.get_point_size(
            all_points.shape[0]
        ),
        legend_title="study name",
        remove_x_ticks=False,
        remove_y_ticks=False,
    )
    ax.set_title(
        f"{disease}, fold {fold_id} {fold_label}, {representation}, {gene_locus}"
    )

    # Zoom: limits come from all points, so this disease's framing is not data-driven.
    for set_limits, key in ((ax.set_xlim, x_key), (ax.set_ylim, y_key)):
        set_limits(
            np.quantile(all_points[key], plt_quantile),
            np.quantile(all_points[key], 1 - plt_quantile),
        )

    # Enforce a 1:1 aspect ratio by adjusting the data limits.
    ax.set_aspect("equal", "datalim")

    # Annotate legend entries with per-study sample sizes for this disease.
    genetools.plots.add_sample_size_to_legend(
        ax=ax, data=subset, hue_key="study_name"
    )

    return fig, ax

In [5]:
def plot_within_disease_overall_density(
    all_points,
    disease,
    representation,
    fold_id,
    fold_label,
    gene_locus,
    plt_quantile=0.01,
    mincnt=10,
    gridsize=25,
):
    """Hexbin density plot of a single disease's points in a 2D embedding.

    The points are cropped to the zoom window and then binned into hexagons.
    The colorbar is drawn in a separate inset axis to the right of the plot.

    Args:
        all_points: DataFrame with ``X_{representation}1`` / ``X_{representation}2``
            coordinate columns and a ``disease`` column.
        disease: Value of ``all_points["disease"]`` to plot.
        representation: Embedding name used to build the column keys.
        fold_id, fold_label, gene_locus: Only used in the plot title.
        plt_quantile: The zoom window is the [plt_quantile, 1 - plt_quantile]
            quantile range of each coordinate over *all* points (all diseases).
            This disease's points outside that window are dropped before binning.
        mincnt: Minimum number of points for a hexagon to be drawn. Forwarded
            to ``Axes.hexbin``. Default 10.
        gridsize: Number of hexagons in the x direction. Forwarded to
            ``Axes.hexbin``. Default 25.

    Returns:
        ``(fig, ax)``, where ``ax`` is the main (non-colorbar) axes.
    """
    # Local import: only needed for the colorbar inset below.
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    xcol = f"X_{representation}1"
    ycol = f"X_{representation}2"
    xlims = (
        np.quantile(all_points[xcol], plt_quantile),
        np.quantile(all_points[xcol], 1 - plt_quantile),
    )
    ylims = (
        np.quantile(all_points[ycol], plt_quantile),
        np.quantile(all_points[ycol], 1 - plt_quantile),
    )

    foreground_points = all_points[all_points["disease"] == disease]

    # Crop to the zoom window up front. Hexbin's default extent follows the data,
    # so this keeps outliers from stretching the grid.
    # TODO: do this early in other methods too?
    foreground_points = foreground_points[
        (foreground_points[xcol] >= xlims[0])
        & (foreground_points[xcol] <= xlims[1])
        & (foreground_points[ycol] >= ylims[0])
        & (foreground_points[ycol] <= ylims[1])
    ]

    fig, ax = plt.subplots(figsize=(5, 5))

    # No C array is passed, so each hexagon's value is the number of points in it.
    # mincnt hides sparse cells: https://stackoverflow.com/a/5405654/130164
    hexplotted = ax.hexbin(
        foreground_points[xcol],
        foreground_points[ycol],
        mincnt=mincnt,
        gridsize=gridsize,
        cmap="Blues",
    )

    # Put the colorbar in its own inset axis, positioned just outside the right edge
    # of the main axes (in axes coordinates), so the main plot is not distorted.
    # see also https://stackoverflow.com/a/44642014/130164
    colorbar_ax = inset_axes(
        ax,
        width="5%",
        height="80%",
        loc="center left",
        bbox_to_anchor=(1.05, 0.0, 1, 1),
        bbox_transform=ax.transAxes,
        borderpad=0,
    )
    fig.colorbar(hexplotted, cax=colorbar_ax, label="Density")

    # Reset the global "current axes" to the main axes, so later calls like
    # plt.title target ax rather than the inset colorbar_ax.
    plt.sca(ax)

    ax.set_title(
        f"{disease}, fold {fold_id} {fold_label}, {representation}, {gene_locus}"
    )

    # Enforce a 1:1 aspect ratio by adjusting the data limits.
    ax.set_aspect("equal", "datalim")

    return fig, ax

In [6]:
def plot_within_disease_relative_density(
    all_points,
    disease,
    representation,
    positive_class,  # which study name to use as numerator in proportion
    fold_id,
    fold_label,
    gene_locus,
    n_bins=25,  # per dimension
    minimal_bin_density_quantile: Optional[float] = 0.50,  # drop bins with low counts
    plt_quantile=0.01,  # zoom in
):
    """Plot, per 2D bin, the share of one study among a disease's points.

    Args:
        all_points: DataFrame with ``X_{representation}1`` / ``X_{representation}2``
            coordinate columns and ``disease`` and ``study_name`` columns.
        disease: Value of ``all_points["disease"]`` to plot.
        representation: Embedding name used to build the column keys.
        positive_class: The ``study_name`` value used as the numerator.
        fold_id, fold_label, gene_locus: Only used in the plot title.
        n_bins: Number of bins per dimension.
        minimal_bin_density_quantile: Forwarded as ``quantile`` to
            ``two_class_relative_density_plot`` to drop low-count bins; ``None``
            is passed through unchanged.
        plt_quantile: Both the binning range and the axes limits are the
            [plt_quantile, 1 - plt_quantile] quantiles over *all* points.

    Returns:
        ``(fig, ax)`` from ``genetools_plots.two_class_relative_density_plot``.
    """
    x_key = f"X_{representation}1"
    y_key = f"X_{representation}2"
    low_q, high_q = plt_quantile, 1 - plt_quantile

    # Zoom window computed over every disease's points.
    xlims = (np.quantile(all_points[x_key], low_q), np.quantile(all_points[x_key], high_q))
    ylims = (np.quantile(all_points[y_key], low_q), np.quantile(all_points[y_key], high_q))

    disease_points = all_points[all_points["disease"] == disease]

    fig, ax, description = genetools_plots.two_class_relative_density_plot(
        data=disease_points,
        x_key=x_key,
        y_key=y_key,
        hue_key="study_name",
        positive_class=positive_class,
        colorbar_label=f"Proportion of {positive_class} vs all {disease}",
        quantile=minimal_bin_density_quantile,
        figsize=(5, 5),
        n_bins=n_bins,
        range=(xlims, ylims),  # bin only within the zoom window
    )

    header = f"{disease}, fold {fold_id} {fold_label}, {representation}, {gene_locus}"
    ax.set_title(f"{header}\n{description}")

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)

    # Enforce a 1:1 aspect ratio by adjusting the data limits.
    ax.set_aspect("equal", "datalim")

    return fig, ax

In [7]:
# This notebook visualizes only the test split of each cross-validation fold.
fold_label: str = "test"

In [8]:
def _savefig_and_close(fig, output_path):
    """Save ``fig`` to ``output_path`` at 72 dpi, then close it so open figures don't pile up."""
    genetools.plots.savefig(fig, output_path, dpi=72)
    plt.close(fig)


# (minimal_bin_density_quantile, filename tag) for the two relative-density variants:
# one with no low-count filtering, and one dropping bins below the median count.
_RELATIVE_DENSITY_VARIANTS = [
    (None, "relative_density"),
    (0.50, "relative_density_with_min_density_requirement"),
]

for gene_locus in config.gene_loci_used:
    for fold_id in config.cross_validation_fold_ids:
        logger.info(f"Processing fold {fold_id}-{fold_label}, {gene_locus}")

        # Load test set
        adata = io.load_fold_embeddings(
            fold_id=fold_id,
            fold_label=fold_label,
            gene_locus=gene_locus,
            target_obs_column=TargetObsColumnEnum.disease,
        )
        assert not adata.obs["study_name"].isna().any()

        # Construct UMAP for the test set. The UMAP depends on the training set:
        # fit on the training set, then apply to the test set. Both are already scaled.
        # (PCA was already computed this way.)
        (
            adata_train,
            adata,
        ) = malid.external.genetools_scanpy_helpers.umap_train_and_test_anndatas(
            adata_train=io.load_fold_embeddings(
                fold_id=fold_id,
                fold_label="train_smaller",
                gene_locus=gene_locus,
                target_obs_column=TargetObsColumnEnum.disease,
            ),
            adata_test=adata,
            n_neighbors=15,
            n_components=2,
            inplace=True,
            random_state=0,
            use_rapids=True,
            use_pca=True,
        )
        # The training-set object is only needed for the fit. Drop our reference
        # now rather than keeping it alive through all the plotting below.
        # NOTE(review): io may still cache it until clear_cached_fold_embeddings().
        del adata_train

        # Add the UMAP and first two PCA coordinates to obs as plain columns.
        obsm_df = adata.obsm.to_df()
        adata.obs = genetools.helpers.horizontal_concat(
            adata.obs,
            obsm_df[obsm_df.columns[obsm_df.columns.str.startswith("X_umap")]],
        )
        adata.obs = genetools.helpers.horizontal_concat(
            adata.obs,
            obsm_df[["X_pca1", "X_pca2"]],
        )
        del obsm_df  # holds every obsm column; no longer needed

        all_points = adata.obs

        # Filename suffix shared by every plot for this fold and locus.
        file_suffix = f"fold_{fold_id}_{fold_label}.{gene_locus.name}.png"

        for representation in ["umap", "pca"]:
            common_kwargs = dict(
                all_points=all_points,
                representation=representation,
                fold_id=fold_id,
                fold_label=fold_label,
                gene_locus=gene_locus,
            )

            # Plot all diseases
            fig, _ax = plot_background(**common_kwargs)
            _savefig_and_close(
                fig,
                config.paths.output_dir
                / f"language_model_embedding.all_diseases.{representation}.{file_suffix}",
            )

            # Compare by batch
            for disease in all_points["disease"].unique():
                by_batch_prefix = f"language_model_embedding.by_batch.{slugify(disease)}.{representation}"

                fig, _ax = plot_within_disease(disease=disease, **common_kwargs)
                _savefig_and_close(
                    fig,
                    config.paths.high_res_outputs_dir
                    / f"{by_batch_prefix}.{file_suffix}",
                )

                # Plot overall density
                fig, _ax = plot_within_disease_overall_density(
                    disease=disease, **common_kwargs
                )
                _savefig_and_close(
                    fig,
                    config.paths.high_res_outputs_dir
                    / f"{by_batch_prefix}.overall_density.{file_suffix}",
                )

                # Plot relative density, only if this disease has exactly two batches.
                # Sort the study names so the numerator (positive) class is chosen
                # deterministically. .unique() returns values in order of first
                # appearance, which can differ between folds and flip the proportion.
                study_names = sorted(
                    all_points.loc[all_points["disease"] == disease, "study_name"].unique()
                )
                if len(study_names) == 2:
                    for min_quantile, plot_tag in _RELATIVE_DENSITY_VARIANTS:
                        fig, _ax = plot_within_disease_relative_density(
                            disease=disease,
                            # which study name to use as numerator in proportion
                            positive_class=study_names[1],
                            minimal_bin_density_quantile=min_quantile,
                            **common_kwargs,
                        )
                        _savefig_and_close(
                            fig,
                            config.paths.high_res_outputs_dir
                            / f"{by_batch_prefix}.{plot_tag}.{file_suffix}",
                        )

        # Release this iteration's data before loading the next fold.
        del adata, all_points
        io.clear_cached_fold_embeddings()
        gc.collect()

2022-12-27 00:52:55,158 - unsupervised_embedding_visualize.ipynb - INFO - Processing fold 0-test, GeneLocus.BCR


2022-12-27 00:52:55,162 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/BCR/fold.0.test.h5ad -> /srv/scratch/maximz/cache/d23bdbcb1cb8d1c3007da595727fbbab8cc86779ab8967868314f2da.0.test.h5ad


Only considering the two last: ['.test', '.h5ad'].


Only considering the two last: ['.test', '.h5ad'].




2022-12-27 00:54:43,611 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/BCR/fold.0.train_smaller.h5ad -> /srv/scratch/maximz/cache/7f04f4bafb0ade43dcc30c658cb88992fd0cd5728e955bb146f9115a.0.train_smaller.h5ad


Only considering the two last: ['.train_smaller', '.h5ad'].


Only considering the two last: ['.train_smaller', '.h5ad'].




2022-12-27 03:10:46,907 - unsupervised_embedding_visualize.ipynb - INFO - Processing fold 1-test, GeneLocus.BCR


2022-12-27 03:10:46,911 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/BCR/fold.1.test.h5ad -> /srv/scratch/maximz/cache/ed1f2608e168c24c508b6bfbbe3a18ba1a4680e5c010e34cd184cae6.1.test.h5ad


Only considering the two last: ['.test', '.h5ad'].


Only considering the two last: ['.test', '.h5ad'].




2022-12-27 03:12:50,573 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/BCR/fold.1.train_smaller.h5ad -> /srv/scratch/maximz/cache/4110bb3aa0fb718bb62fede8d6a704d8165d8d6bc349377491f87944.1.train_smaller.h5ad


Only considering the two last: ['.train_smaller', '.h5ad'].


Only considering the two last: ['.train_smaller', '.h5ad'].




2022-12-27 05:58:58,518 - unsupervised_embedding_visualize.ipynb - INFO - Processing fold 2-test, GeneLocus.BCR


2022-12-27 05:58:58,523 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/BCR/fold.2.test.h5ad -> /srv/scratch/maximz/cache/83dcd409138574af7b7b712ce14967e926c14170bff7801b141edb49.2.test.h5ad


Only considering the two last: ['.test', '.h5ad'].


Only considering the two last: ['.test', '.h5ad'].




2022-12-27 06:01:23,858 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/BCR/fold.2.train_smaller.h5ad -> /srv/scratch/maximz/cache/578e6ec41823516592527ab2d5010994f0b640e6b5c16cfd655e655d.2.train_smaller.h5ad


Only considering the two last: ['.train_smaller', '.h5ad'].


Only considering the two last: ['.train_smaller', '.h5ad'].




2022-12-27 08:38:33,097 - unsupervised_embedding_visualize.ipynb - INFO - Processing fold 0-test, GeneLocus.TCR


2022-12-27 08:38:33,102 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/TCR/fold.0.test.h5ad -> /srv/scratch/maximz/cache/a6c3ab8bb9960154ff8dbb5d04eaf7a1c395f034f82f9649aefc9d35.0.test.h5ad


Only considering the two last: ['.test', '.h5ad'].


Only considering the two last: ['.test', '.h5ad'].




2022-12-27 08:42:17,759 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/TCR/fold.0.train_smaller.h5ad -> /srv/scratch/maximz/cache/6b796871077fee80a869f8764118b0fdb69bf9ca37a5d9a043cf68c5.0.train_smaller.h5ad


Only considering the two last: ['.train_smaller', '.h5ad'].


Only considering the two last: ['.train_smaller', '.h5ad'].




2022-12-27 13:26:56,668 - unsupervised_embedding_visualize.ipynb - INFO - Processing fold 1-test, GeneLocus.TCR


2022-12-27 13:26:56,679 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/TCR/fold.1.test.h5ad -> /srv/scratch/maximz/cache/d3eb18b805f6482478ad1ee75a5249a449864d796469d3e144db1326.1.test.h5ad


Only considering the two last: ['.test', '.h5ad'].


Only considering the two last: ['.test', '.h5ad'].




2022-12-27 13:30:25,586 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/TCR/fold.1.train_smaller.h5ad -> /srv/scratch/maximz/cache/117eb95eb3cad5ad91c5749ddfa0db422db06a1c91f41e0fdc99057d.1.train_smaller.h5ad


Only considering the two last: ['.train_smaller', '.h5ad'].


Only considering the two last: ['.train_smaller', '.h5ad'].




2022-12-27 18:22:10,070 - unsupervised_embedding_visualize.ipynb - INFO - Processing fold 2-test, GeneLocus.TCR


2022-12-27 18:22:10,074 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/TCR/fold.2.test.h5ad -> /srv/scratch/maximz/cache/d74a0f1a51ad1e35ec4654f627621ecdc4abf48e3481e399db8af384.2.test.h5ad


Only considering the two last: ['.test', '.h5ad'].


Only considering the two last: ['.test', '.h5ad'].




2022-12-27 18:25:38,858 - malid.external.scratch_cache - INFO - Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20221224/embedded/unirep_fine_tuned/anndatas_scaled/TCR/fold.2.train_smaller.h5ad -> /srv/scratch/maximz/cache/99bd53c7c35bce43fbec0229ffe1fc112a054c1e0ca9816d6fa80185.2.train_smaller.h5ad


Only considering the two last: ['.train_smaller', '.h5ad'].


Only considering the two last: ['.train_smaller', '.h5ad'].


