An example of Principal Component Analysis (PCA) on the [iris](https://en.wikipedia.org/wiki/Iris_flower_data_set) dataset, based on the PCA tutorial from [scikit-learn](https://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_iris.html).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# unused but required import for doing 3d projections with matplotlib < 3.2
import mpl_toolkits.mplot3d  # noqa: F401

from sklearn import datasets, decomposition


scikit-learn (`sklearn`) ships with many toy datasets that we can load and play with. See the `sklearn.datasets` documentation for more.

In [None]:
# Load the iris dataset bundled with scikit-learn and split it into the
# feature matrix and the class labels.
iris = datasets.load_iris()
X, y = iris.data, iris.target  # X: measurements, y: integer flower type

In [None]:
# Human-readable names of the 4 measured features (column order of X).
# These strings are used as axis labels and titles in the plots below.
features = ["sepal length", "sepal width", "petal length", "petal width"]
# Flower types, indexed by the integer class label in y.
names = ["Setosa", "Versicolour", "Virginica"]
# One plotting color per flower type (same order as `names`).
colors = ["tab:orange", "tab:blue", "tab:green"]

In [None]:
# Pair of feature columns to compare (indices into `features`).
fid0, fid1 = 1, 2

# Draw each flower type separately so it gets its own color and legend entry.
for label, name in enumerate(names):
    in_class = y == label
    plt.scatter(X[in_class, fid0], X[in_class, fid1], c=colors[label],
                label=name)

x_name, y_name = features[fid0], features[fid1]
plt.xlabel(x_name)
plt.ylabel(y_name)
plt.title(f"Comparing {x_name} with {y_name}")
plt.legend()

Inspecting `X.shape` gives (150, 4): 150 samples in total, drawn from 3 types of irises (Setosa, Versicolour, and Virginica), with 4 features measured for each sample: sepal length, sepal width, petal length, and petal width.

What do we have in the dataset? How do we visualize data that has 4 dimensions?

In [None]:
# 3D scatter plot of the raw data, using three of the four features.

# Get figure 1 and wipe anything already drawn on it, then add a 3D axes
# with a fixed viewing angle.
fig = plt.figure(1, figsize=(4, 3))
plt.clf()
ax = fig.add_subplot(111, projection="3d", elev=48, azim=134)
ax.set_position([0, 0, 0.95, 1])

# Feature columns plotted on the x, y and z axes.
fidx, fidy, fidz = 0, 1, 2

# One scatter call per flower type so each gets its own color and legend entry.
for label, name in enumerate(names):
    in_class = y == label
    ax.scatter(
        X[in_class, fidx],
        X[in_class, fidy],
        X[in_class, fidz],
        c=colors[label],
        label=name,
    )

# Legend anchored to the right of the axes, without a frame.
plt.legend(bbox_to_anchor=(1.5, 1), frameon=False)

# Label each axis with its feature name.
ax.set_xlabel(features[fidx])
ax.set_ylabel(features[fidy])
ax.set_zlabel(features[fidz])


Look into plotly if you want interactive 3D plots.

Use PCA to find new axes to express our data in!

In [None]:
# Create a PCA "model" with default settings and fit it to the data;
# `fit` returns the estimator itself, so the two steps can be chained.
pca = decomposition.PCA().fit(X)
# Express every sample in the coordinates of the principal axes.
X_new = pca.transform(X)

# alternatively, fit and project in one call: X_new = pca.fit_transform(X)

In [None]:
# 2D scatter of the samples projected onto two chosen principal components.
# plt.figure(1) returns the existing figure 1 if one is still open, so clear
# it first (as the 3D cells do). Otherwise these points would be drawn into
# the 3D axes left over from an earlier cell whenever figures are not
# auto-closed between cells.
fig = plt.figure(1, figsize=(4, 3))
plt.clf()

# which PCs to look at (0-indexed columns of X_new)
pcx, pcy = 0, 2

# color the points by their label
for i, name in enumerate(names):
    idx = y == i
    plt.scatter(X_new[idx, pcx], X_new[idx, pcy], c=colors[i], label=name)

# display names as legends
plt.legend(bbox_to_anchor=(1.1, 0.65), frameon=False)

# set axes labels
plt.xlabel(f"PC{pcx}")
plt.ylabel(f"PC{pcy}")


In [None]:
# 3D scatter of the samples in the space of the first three principal components.
fig = plt.figure(1, figsize=(4, 3))
plt.clf()  # start from an empty figure 1

ax = fig.add_subplot(111, projection="3d", elev=48, azim=134)
ax.set_position([0, 0, 0.95, 1])

# One scatter call per flower type; PC0, PC1 and PC2 become x, y and z.
for label, name in enumerate(names):
    pc0, pc1, pc2 = X_new[y == label, :3].T
    ax.scatter(pc0, pc1, pc2, c=colors[label], label=name)

# Frameless legend anchored to the right of the axes.
plt.legend(bbox_to_anchor=(1.1, 1), frameon=False)

# Axis labels name the principal component on each axis.
ax.set_xlabel("PC0")
ax.set_ylabel("PC1")
ax.set_zlabel("PC2")


Hold on: what's in these PCs? Is PCA some sort of black magic/black box?

> No! We can piece together parts of the puzzle by looking at the loadings and the explained variance.


In [None]:
# import useful (home-made) helper functions
from helper import *

In [None]:
# Fraction of the total variance captured by each principal component,
# handed to the home-made plotting helper (n=3 presumably limits how many
# PCs are shown — NOTE(review): confirm against helper.py).
explained = pca.explained_variance_ratio_
plot_var_explained(explained, n=3)

In [None]:
# Bar plot of the PCA loadings, grouped by original feature.
# pca.components_ has shape (n_components, n_features): row k holds the
# weight of every original feature in principal component k.
components = pca.components_
n_pcs = components.shape[0]

feature_colors = ["tab:orange", "tab:blue", "tab:green", "tab:purple"]

count = 0  # x position of the next bar
for i in range(len(features)):
    for j in range(n_pcs):
        # Rows of components_ are PCs, so [j, i] is feature i's loading on PC j.
        # Indexing [i, j] would silently plot the transpose, since the matrix
        # is square (4x4) here.
        plt.bar(count, components[j, i], color=feature_colors[i],
                label=f"feature {i}" if j == 0 else "")
        count += 1

plt.ylabel("Component loadings")
plt.xlabel("Feature x PC")
# Each feature's group of bars is labelled with the PC index of every bar.
plt.xticks(range(count), list(range(n_pcs)) * len(features))
plt.legend(bbox_to_anchor=(1., 0.6))