# Module 2.2 Basic Visualization of Data with `plotnine`

In this module, we will learn how to use the `plotnine` package to create basic visualizations of data.

`plotnine` is a Python implementation of the 'Grammar of Graphics', modeled on the R package `ggplot2`, a powerful and flexible visualization package. The grammar of graphics is a theory of visualization that describes how to **build** a visualization from components. It is described in detail in the book [The Grammar of Graphics](https://www.springer.com/gp/book/9780387245447) by Leland Wilkinson.

## The Grammar of Graphics

The grammar of graphics describes a visualization as a mapping between data and aesthetic attributes (e.g. color, shape, size, etc.) of geometric objects (e.g. points, lines, bars, etc.). The grammar of graphics is composed of the following components:

- **Data**: The data to be visualized (usually a `pandas` `DataFrame`).
- **Aesthetics**: The aesthetic attributes of the geometric objects (e.g. color, shape, size, etc.).
- **Geometric Objects**: The geometric objects that represent the data (e.g. points, lines, bars, etc.).
- **Scales**: The scales that map values in the data space to values in the aesthetic space.
- **Coordinate System**: The coordinate system that defines the space in which the geometric objects are drawn.
- **Statistical Transformations**: The statistical transformations that are applied to the data before plotting.
- **Facets**: The facets that are used to split the data into subsets and plot each subset on a separate panel.

Using these components, figures are built up in layers. For example, a scatter plot combines a layer of points with scales and a coordinate system. The grammar of graphics provides a *framework* for building up figures in this way.

In [None]:
import warnings
import numpy as np
import pandas as pd
import plotnine as pn

from plotnine.options import set_option

set_option("base_family", "Helvetica")
warnings.filterwarnings("ignore")

In [None]:
# Import the data as a pandas dataframe
data = pd.read_csv("data/GSE63482_Expression_matrix.tsv", sep="\t")

data

In [None]:
scatter = pn.ggplot(data, pn.aes(x="E15_cpn", y="E18_cpn"))

scatter.draw()

In [None]:
# Add a layer of points and keep the result so that later layers build on it
scatter = scatter + pn.geom_point()
scatter.draw()

In [None]:
scatter = scatter + pn.geom_smooth(method="lm")

scatter.draw()
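
`geom_smooth(method="lm")` overlays a fitted straight line on the plot. To make the idea concrete, the sketch below fits an ordinary least-squares line to a few invented points with `numpy.polyfit`. This is an illustration of the kind of fit being drawn, under the assumption of a simple linear model, and not a description of `plotnine`'s internals.

```python
import numpy as np

# Invented values that lie exactly on the line y = 2x + 1
x = np.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

# Degree-1 polynomial fit: returns [slope, intercept]
slope, intercept = np.polyfit(x, y, 1)

# Predicted values along the fitted line
y_hat = slope * x + intercept
```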

In [None]:
# Melt the 'wide' dataframe into a 'long' dataframe
data_melted = data.melt(id_vars=["gene_id"])

data_melted

In [None]:
# Split the 'variable' column into two columns ['age','celltype']
data_melted[["age", "celltype"]] = data_melted["variable"].str.split("_", expand=True)

data_melted
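
The two reshaping steps above can be checked on a tiny table. The sketch below uses an invented two-gene DataFrame whose column names follow the same `age_celltype` pattern. The gene names and values are made up, and `n=1` is added as a precaution so that only the first underscore is used for splitting.

```python
import pandas as pd

# Invented 'wide' table: one row per gene, one column per sample
wide = pd.DataFrame(
    {
        "gene_id": ["GeneA", "GeneB"],
        "E15_cpn": [1.0, 2.0],
        "E18_cpn": [3.0, 4.0],
    }
)

# 'Long' format: one row per (gene, sample) pair,
# with the old column names stored in 'variable'
long = wide.melt(id_vars=["gene_id"])

# Split 'E15_cpn' into age='E15' and celltype='cpn'
long[["age", "celltype"]] = long["variable"].str.split("_", n=1, expand=True)
```

The long format is what lets `age` and `celltype` be mapped to separate aesthetics such as `x` and `fill`.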

In [None]:
box_plot = (
    pn.ggplot(data_melted, pn.aes(x="age", y="value", fill="celltype"))
    + pn.geom_boxplot()
)
box_plot.draw()

In [None]:
# Log transform the gene expression values
data_melted["log_value"] = np.log(data_melted["value"] + 1)

box_plot = (
    pn.ggplot(data_melted, pn.aes(x="age", y="log_value", fill="celltype"))
    + pn.geom_boxplot()
)
box_plot.draw()
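
Adding 1 before taking the logarithm keeps zero expression values defined, since `log(0)` is undefined while `log(0 + 1)` is 0. A small sketch on invented values, which also shows that `np.log1p` computes the same quantity:

```python
import numpy as np

# Invented expression values, including a zero
counts = np.array([0.0, 1.0, 9.0, 99.0])

# Same transformation as in the cell above
logged = np.log(counts + 1)

# np.log1p(x) computes log(1 + x) directly
logged_alt = np.log1p(counts)
```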

In [None]:
violin_plot = (
    pn.ggplot(data_melted, pn.aes(x="age", y="log_value", fill="celltype"))
    + pn.geom_violin()
)
violin_plot.draw()

## Exploring a little deeper

### Faceting

In [None]:
box_plot = (
    pn.ggplot(data_melted, pn.aes(x="variable", y="log_value", fill="celltype"))
    + pn.geom_boxplot()
    + pn.theme(axis_text_x=pn.element_text(rotation=45, hjust=1))
)
box_plot.draw()

In [None]:
box_plot + pn.facet_wrap("age", scales="free_x")

In [None]:
(
    box_plot
    + pn.facet_wrap("celltype", scales="free_x")
    + pn.theme(axis_text_x=pn.element_text(rotation=45, hjust=1))
)

In [None]:
(
    box_plot
    + pn.facet_grid("age ~ celltype", scales="free_x")
    + pn.theme(axis_text_x=pn.element_text(rotation=45, hjust=1))
)

In [None]:
pax6_plot = pn.ggplot(
    data_melted[data_melted["gene_id"] == "Pax6"],
    pn.aes(x="age", y="value", color="celltype"),
)

pax6_plot.draw()

In [None]:
pax6_plot = pax6_plot + pn.geom_point()

pax6_plot.draw()

In [None]:
pax6_plot = pax6_plot + pn.geom_line(pn.aes(group="celltype"))

pax6_plot.draw()

In [None]:
# Create a reusable function to make a plot for a given gene
def plot_gene(gene_id: str) -> pn.ggplot:
    """Plot expression of one gene across ages, colored by cell type."""
    return (
        pn.ggplot(
            data_melted[data_melted["gene_id"] == gene_id],
            pn.aes(x="age", y="value", color="celltype"),
        )
        + pn.geom_point()
        + pn.geom_line(pn.aes(group="celltype"))
        + pn.labs(title=gene_id)
    )


plot_gene("Dlx1")

In [None]:
gene_list = ["Cux1", "Tle4", "Bcl11b"]

# Save one PDF per gene
for gene in gene_list:
    plot_gene(gene).save(f"{gene}.pdf")

### Heatmap Example


In [None]:
gene_list = data["gene_id"].sample(20).tolist()

plot_df = data_melted[data_melted["gene_id"].isin(gene_list)].copy()

# Row normalize the data (z-score each gene across samples)
grouped = plot_df.groupby("gene_id")["value"]
plot_df["value"] = (plot_df["value"] - grouped.transform("mean")) / grouped.transform("std")


heatmap_plot = (
    pn.ggplot(plot_df, pn.aes(x="variable", y="gene_id", fill="value"))
    + pn.geom_tile()
    + pn.scale_fill_gradient2(low="blue", mid="white", high="red")
    + pn.theme(axis_text_x=pn.element_text(rotation=45, hjust=1))
)

heatmap_plot.draw()
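
The row normalization used for the heatmap is a per-gene z-score: subtract each gene's mean and divide by its standard deviation. The sketch below checks this on an invented long-format table. The gene and sample names are made up, and the result is stored in a separate `zscore` column for clarity.

```python
import numpy as np
import pandas as pd

# Invented long-format table: two genes measured in three samples
toy_long = pd.DataFrame(
    {
        "gene_id": ["GeneA"] * 3 + ["GeneB"] * 3,
        "variable": ["s1", "s2", "s3"] * 2,
        "value": [1.0, 2.0, 3.0, 10.0, 20.0, 60.0],
    }
)

# transform() returns a result aligned with the original rows,
# so it can be combined with the 'value' column directly
grouped = toy_long.groupby("gene_id")["value"]
toy_long["zscore"] = (
    toy_long["value"] - grouped.transform("mean")
) / grouped.transform("std")

# After normalization every gene has mean 0 and standard deviation 1
check = toy_long.groupby("gene_id")["zscore"].agg(["mean", "std"])
```

Because each gene is rescaled independently, genes with very different absolute expression levels share one color scale in the heatmap. A gene whose values are all identical has a standard deviation of 0 and would come out as `NaN`.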