# Introduction to matplotlib and seaborn modules for visualization
* [matplotlib](https://matplotlib.org/stable/plot_types/index.html) covers basic graphs
* [seaborn](https://seaborn.pydata.org/examples/index.html) provides higher-level statistical graphs and works directly with pandas DataFrames

In [None]:
# !pip install matplotlib seaborn

### These four modules (matplotlib, seaborn, pandas, numpy) plus [scipy.stats](https://docs.scipy.org/doc/scipy/reference/stats.html) are the primary tools for data science

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import numpy as np

## Load gene expression dataset

In [None]:
# Expression table: first column becomes the row index (one row per sample), first row the header.
data = pd.read_excel(
    'CRC_sample_data.xlsx',
    sheet_name='expression',
    header=0,
    index_col=0,
)
data.head()

## matplotlib code template

In [None]:
# 1) open a figure, 2) draw, 3) decorate, 4) optionally zoom, 5) save, 6) show.
plt.figure(figsize=(4, 4))

plt.scatter(data['FAP'], data['SLC5A6'])

plt.xlabel('FAP')
plt.ylabel('SLC5A6')
plt.title('FAP vs SLC5A6')

# Optional zoom -- uncomment one of these:
# plt.axis([6, 7, 5, 6])   # [xmin, xmax, ymin, ymax] in a single call
# plt.xlim([6, 7])
# plt.ylim([7, 8])

# Raster alternative: plt.savefig('figure.jpg', dpi=300)
# savefig must come before show(), otherwise the saved figure may be blank.
plt.savefig('figure.svg', dpi=70)
plt.show()

## Multiple graphs can be placed on the same figure, with automatic coloring
Use the **label** parameter and **legend()** to distinguish the plots

In [None]:
plt.figure(figsize=(5, 5))

# Select each subtype's rows once; every scatter() call takes the next colour in the cycle.
cms1 = data[data['CMS'] == 'CMS1']
cms2 = data[data['CMS'] == 'CMS2']

plt.scatter(cms1['FAP'], cms1['SLC5A6'], label='CMS1')
plt.scatter(cms2['FAP'], cms2['SLC5A6'], label='CMS2')

plt.gca().set(xlabel='FAP', ylabel='SLC5A6', title='FAP vs SLC5A6 by CMS group')

# legend() lists every artist that was given a label.
plt.legend()
plt.show()

## Key visualization settings
* Color
* Shape
* Size
* Transparency

### Color, shape, and transparency

In [None]:
plt.figure(figsize=(5, 5))

# color = fill colour, marker = glyph shape, alpha = opacity (0 transparent .. 1 opaque).
plt.scatter(
    data['FAP'],
    data['SLC5A6'],
    alpha=0.7,
    marker='x',
    color='tab:red',
)

plt.gca().set(xlabel='FAP', ylabel='SLC5A6')

plt.show()

### Size

In [None]:
plt.figure(figsize=(5, 5))

# `s` is the marker area in points**2; here each sample's area is its GFPT2 value squared.
marker_area = data['GFPT2'] ** 2
plt.scatter(data['FAP'], data['SLC5A6'], s=marker_area,
            color='violet', marker='o', alpha=0.7)

plt.gca().set(xlabel='FAP', ylabel='SLC5A6')

plt.show()

## So many [color choices](https://matplotlib.org/stable/gallery/color/named_colors.html)

## Combining a for loop with visualization
Draw each sample group in a different color (and add a legend)

In [None]:
print(data['CMS'].unique())  # distinct CMS labels, in order of first appearance

In [None]:
plt.figure(figsize=(5, 5))

# One scatter per subtype (order of first appearance); each call gets the next default colour.
for subtype in pd.unique(data['CMS']):
    rows = data[data['CMS'] == subtype]
    plt.scatter(rows['FAP'], rows['SLC5A6'], label=subtype)

plt.xlabel('FAP')
plt.ylabel('SLC5A6')
plt.title('FAP vs SLC5A6 by CMS group')

plt.legend()
plt.show()

## Histogram

In [None]:
plt.figure(figsize=(5, 3))

agr2 = data['AGR2']
plt.hist(agr2)  # bin count comes from rcParams['hist.bins'] (10 by default)

plt.gca().set(xlabel='AGR2', ylabel='# patients')

plt.show()

### Bin sizes
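An integer **bins** sets the number of equal-width bins; a sequence sets the bin edges explicitly. `numpy.arange`, used below to build such edges, is similar to `range` but also accepts non-integer start, stop and step values.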

In [None]:
plt.figure(figsize=(5, 3))

# An integer `bins` requests that many equal-width bins spanning the data range.
n_bins = 15
plt.hist(data['AGR2'], bins=n_bins)

plt.gca().set(xlabel='AGR2', ylabel='# patients')

plt.show()

## Overlay histograms with transparency `alpha`

In [None]:
plt.figure(figsize=(5, 3))

# Each hist() call picks its own bin edges from its own subset, so the bars need not line up.
# CMS2 is drawn first, so CMS1 is layered on top; alpha keeps both visible.
cms2_agr2 = data.loc[data['CMS'] == 'CMS2', 'AGR2']
cms1_agr2 = data.loc[data['CMS'] == 'CMS1', 'AGR2']

plt.hist(cms2_agr2, facecolor='tab:orange', alpha=0.3, label='CMS2')
plt.hist(cms1_agr2, facecolor='tab:blue', alpha=0.3, label='CMS1')

plt.xlabel('AGR2')
plt.ylabel('# patients')

plt.legend()
plt.show()

### Control `bins` to match multiple histograms

In [None]:
# Shared edges 4.5, 4.9, ..., ~11.7 (arange excludes the stop value). Both histograms
# use them, so the bars align. AGR2 values outside [4.5, ~11.7] are not counted.
common_bins = np.arange(4.5, 12, 0.4)

plt.figure(figsize=(5, 3))

for cms, colour in (('CMS1', 'tab:blue'), ('CMS2', 'tab:orange')):
    plt.hist(data.loc[data['CMS'] == cms, 'AGR2'], bins=common_bins,
             facecolor=colour, alpha=0.5, label=cms)

plt.xlabel('AGR2')
plt.ylabel('# patients')

plt.legend()
plt.show()

## Convert count to density
Set **density** = True

In [None]:
common_bins = np.arange(4.5, 12, 0.4)

plt.figure(figsize=(5, 3))

# density=True rescales each histogram so its area over the binned range is 1.
# This makes groups of different sizes comparable.
subtype_colours = {'CMS1': 'tab:blue', 'CMS2': 'tab:orange', 'CMS3': 'tab:green'}
for cms, colour in subtype_colours.items():
    plt.hist(data.loc[data['CMS'] == cms, 'AGR2'], bins=common_bins, density=True,
             facecolor=colour, alpha=0.5, label=cms)

plt.xlabel('AGR2')
plt.ylabel('density')

plt.legend()
plt.show()

## Comparing histograms side by side with subplot

In [None]:
common_bins = np.arange(4, 12, 0.5)  # edges 4.0 .. 11.5

# Without subplot: each subtype gets its own figure, shown one after the other.
for cms in ('CMS1', 'CMS2'):
    plt.figure(figsize=(4, 3))

    plt.hist(data.loc[data['CMS'] == cms, 'AGR2'], bins=common_bins, density=True)
    plt.xlabel('AGR2')
    plt.ylabel('density')
    plt.title(cms)

    plt.show()

In [None]:
common_bins = np.arange(4, 12, 0.5)

plt.figure(figsize=(8, 3))

# subplot(nrows, ncols, index): one row of two panels, filled left to right (index starts at 1).
for panel, cms in enumerate(['CMS1', 'CMS2'], start=1):
    plt.subplot(1, 2, panel)
    plt.hist(data.loc[data['CMS'] == cms, 'AGR2'], bins=common_bins, density=True)
    plt.xlabel('AGR2')
    plt.ylabel('density')
    plt.title(cms)

# Re-space the panels so labels and titles do not overlap.
plt.tight_layout()

plt.show()

## Pie chart

In [None]:
# Samples per CMS subtype, most frequent first; missing labels are excluded.
prop = data['CMS'].value_counts(sort=True, dropna=True)
prop.head(5)

In [None]:
plt.figure(figsize=(5, 5))
# Plot the actual subtype counts from `prop` (not hard-coded numbers), in label order.
# pie() normalises them to fractions; autopct formats each wedge's percentage
# ('%%' prints a literal percent sign).
cms_counts = prop.sort_index()
plt.pie(cms_counts.values, labels=cms_counts.index, startangle=0, autopct='%.1f%%')
plt.show()

### Highlighting CMS2

In [None]:
plt.figure(figsize=(5, 5))

# Offset only the CMS2 wedge, by 10% of the radius. Matching on the label keeps the
# highlight on CMS2 wherever it falls in `prop`'s count-sorted order. It also keeps
# `explode` the same length as the number of wedges, as pie() requires.
explode = [0.1 if cms == 'CMS2' else 0 for cms in prop.index]
plt.pie(prop.values, explode=explode, labels=prop.index, startangle=90)

plt.show()

### Donut chart

In [None]:
plt.figure(figsize=(3, 3))

# Offset only the 2nd wedge. The list is built from len(prop), so it always has one
# entry per wedge (pie() rejects a length mismatch).
explode = [0.1 if position == 1 else 0 for position in range(len(prop))]
# A wedge 'width' smaller than the radius leaves a hole in the centre, giving a donut.
plt.pie(prop.values, explode=explode, labels=prop.index, startangle=90,
        radius=2, wedgeprops={'width': 0.5, 'edgecolor': 'white'})

plt.show()

## Box plot & violin plot

In [None]:
plt.figure(figsize=(3, 5))

plt.boxplot([data.loc[data['CMS'] == 'CMS1', 'AGR2'],
             data.loc[data['CMS'] == 'CMS2', 'AGR2'],
             data.loc[data['CMS'] == 'CMS3', 'AGR2']])
# Boxes sit at x = 1, 2, 3. Name them with xticks: boxplot's `labels=` argument was renamed
# `tick_labels` in matplotlib 3.9 and is deprecated. This form works on all versions.
plt.xticks([1, 2, 3], ['CMS1', 'CMS2', 'CMS3'])

plt.ylabel('AGR2')

plt.show()

## Horizontal boxplot
Set **vert** to False

In [None]:
plt.figure(figsize=(5, 3))

# vert=False lays the boxes out horizontally, at y = 1, 2.
plt.boxplot([data.loc[data['CMS'] == 'CMS1', 'AGR2'],
             data.loc[data['CMS'] == 'CMS2', 'AGR2']], vert=False)
# Name the boxes with yticks: boxplot's `labels=` argument is deprecated since
# matplotlib 3.9 (renamed `tick_labels`). This form works on all versions.
plt.yticks([1, 2], ['CMS1', 'CMS2'])

plt.xlabel('AGR2')

plt.show()

## List comprehension with box plot

In [None]:
# [data.loc[data['CMS'] == cms, 'AGR2'] for cms in pd.unique(data['CMS'])]

In [None]:
plt.figure(figsize=(5, 3))

# Compute the subtype list once so the boxes and their tick labels always line up.
groups = pd.unique(data['CMS'])
plt.boxplot([data.loc[data['CMS'] == cms, 'AGR2'] for cms in groups], vert=False)
# Boxes sit at y = 1..len(groups). Name them with yticks: boxplot's `labels=` argument is
# deprecated since matplotlib 3.9 (renamed `tick_labels`).
plt.yticks(range(1, len(groups) + 1), groups)

plt.xlabel('AGR2')

plt.show()

# seaborn

## Violin plot
We use seaborn's violin plot because matplotlib's violin plot offers limited customization

In [None]:
data.head(n=5)  # preview the first five samples

In [None]:
_ = sns.violinplot(x='CMS', y='AGR2', data=data)  # one violin per CMS subtype

## Let's add mutation information to the data

In [None]:
# Mutation sheet from the same workbook; first column is the sample index.
mutation = pd.read_excel(
    'CRC_sample_data.xlsx',
    sheet_name='mutation',
    index_col=0,
)
mutation.head(n=5)

In [None]:
# Column-wise concatenation aligned on the row index. join='inner' keeps only rows
# whose index appears in both tables.
merged = pd.concat(
    [data, mutation],
    axis=1,
    join='inner',
)

print(merged.shape)
merged.head(n=2)

## Violin plot with additional hue component
We can control the components of a seaborn plot via matplotlib

In [None]:
_ = sns.violinplot(x='TP53', y='AGR2', data=merged)  # AGR2 split by TP53 status

In [None]:
_ = sns.violinplot(x='CMS', y='AGR2', hue='BRAF', data=merged)  # BRAF status as colour within each CMS

In [None]:
plt.figure(figsize=(8, 4))
_ = sns.violinplot(x='CMS', y='AGR2', hue='BRAF', data=merged)
# seaborn draws on the current matplotlib axes, so plain pyplot calls can restyle it.
_ = plt.legend(loc='lower center')

## Clustermap

### Understanding the components of [clustermap](https://seaborn.pydata.org/generated/seaborn.clustermap.html)

In [None]:
# Keep only numeric (expression) columns for clustering. Selecting by dtype drops the CMS
# label column wherever it sits, rather than assuming it is the last column.
expr_data = data.select_dtypes(include='number')

In [None]:
# z_score=1 standardises each column; center=0 makes zero the midpoint of the diverging map.
# Rows and columns are both clustered with correlation distance and average linkage.
_ = sns.clustermap(
    data=expr_data,
    z_score=1,
    metric='correlation',
    method='average',
    row_cluster=True,
    col_cluster=True,
    row_colors=None,
    col_colors=None,
    cmap='RdBu',
    center=0,
    figsize=(6, 7),
)

## Color [palettes](https://seaborn.pydata.org/tutorial/color_palettes.html) in seaborn

In [None]:
# Same clustering as before; only the diverging palette changes ('vlag').
_ = sns.clustermap(
    data=expr_data,
    z_score=1,
    metric='correlation',
    method='average',
    row_cluster=True,
    col_cluster=True,
    cmap='vlag',
    center=0,
    figsize=(6, 7),
)

In [None]:
# Same clustering again with the dark-centred 'icefire' palette and a narrower figure.
_ = sns.clustermap(
    data=expr_data,
    z_score=1,
    metric='correlation',
    method='average',
    row_cluster=True,
    col_cluster=True,
    cmap='icefire',
    center=0,
    figsize=(5, 7),
)

## Summarize frequency of samples in each group with pandas's crosstab

In [None]:
pd.crosstab(merged['CMS'], merged['KRAS'])  # rows: CMS subtype, columns: KRAS status, cells: sample counts

### Visualizing frequency table with seaborn's heatmap

In [None]:
plt.figure(figsize=(3, 4))
# Sample counts per (CMS, KRAS) combination.
kras_table = pd.crosstab(index=merged['CMS'], columns=merged['KRAS'])
# fmt='d' prints the integer counts as-is. seaborn's default '.2g' would show e.g. 123 as 1.2e+02.
sns.heatmap(kras_table, cmap=plt.cm.Greens, annot=True, fmt='d')
plt.yticks(rotation=0)  # keep the CMS row labels horizontal
plt.show()

## Linear trend plot with seaborn's lmplot

In [None]:
_ = sns.lmplot(x='AGR2', y='REG4', data=data)  # scatter plus fitted linear regression line

In [None]:
_ = sns.lmplot(x='AGR2', y='REG4', hue='CMS', data=data)  # separate fit per CMS subtype