# Interactive PCA
> Yang

In [115]:
%matplotlib qt
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from matplotlib.widgets import PolygonSelector
from sklearn.decomposition import PCA

# Synthetic data: two isotropic Gaussian blobs centred at (-2, 0) and (2, 0),
# n_samples points each, drawn from a seeded RNG so every run is identical.
# Because the blob centres differ only in the first coordinate, the largest
# spread of the combined cloud lies along the first feature.
rng = np.random.RandomState(0)
n_samples = 1000
cov = [[1, 0], [0, 1]]
X = np.concatenate(
    [
        rng.multivariate_normal(mean=centre, cov=cov, size=n_samples)
        for centre in ([-2, 0], [2, 0])
    ]
)

# Fit a single principal component and keep it as a flat 2-vector.
pca = PCA(n_components=1).fit(X)
component = pca.components_.ravel()

# Left panel: raw samples; right panel: 1-D projection (filled in later).
fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))
x_center = X.mean(axis=0)
ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label="samples")

# Bundle (direction, anchor point) so the selector set-up can unpack it.
comp_vector = [component, x_center]

ax_orig.set(
    aspect="auto",
    title="2-dimensional dataset with principal components",
    xlabel="first feature",
    ylabel="second feature",
)


def onselect(verts):
    """Redraw the 1-D projection histogram after the user edits the line.

    Used as the ``onselect`` callback of the ``PolygonSelector`` on
    ``ax_orig``. The first vertex is taken as the projection origin and the
    second as the tip of the projection direction. Any further vertices are
    ignored; they appear, for example, if the user cleared the selection and
    drew a new polygon.

    Parameters
    ----------
    verts : sequence of (x, y) pairs
        Vertices reported by the selector.

    Notes
    -----
    The direction is normalized to unit length, so the histogram shows the
    signed distance of each sample along the line. This keeps it on the same
    scale as the initial histogram, which projects onto a PCA component
    (scikit-learn returns unit-norm components). Degenerate input leaves the
    current histogram untouched. Degenerate means fewer than two vertices,
    or a first and second vertex that coincide.
    """
    points = np.asarray(verts, dtype=float)
    if points.ndim != 2 or points.shape[0] < 2:
        # Not enough vertices to define a direction.
        return

    origin, tip = points[0], points[1]
    direction = tip - origin
    length = np.linalg.norm(direction)
    if not np.isfinite(length) or length == 0:
        # Coincident vertices: the projection direction is undefined.
        return
    unit = direction / length

    ax_redim.clear()
    # Coordinate of every sample along the line, measured from `origin`.
    ax_redim.hist((X - origin) @ unit, 50)
    ax_redim.set(
        aspect="auto",
        title="1-dimensional dataset after dimension reduction",
        xlabel="Main feature",
        ylabel="Number of samples",
    )
    # Schedule a redraw instead of blocking inside the GUI event callback.
    fig.canvas.draw_idle()

# Interactive projection line: a two-vertex "polygon" on the left axes whose
# edits are reported to `onselect`. The label is meant for the legend below.
selector = PolygonSelector(
    ax_orig,
    onselect=onselect,
    props={
        "color": "r",
        "linestyle": "-",
        "linewidth": 3,
        "alpha": 0.6,
        "label": "Component",
    },
)

# Start the line at the data mean, pointing along the fitted PCA component.
component, x_center = comp_vector
selector.verts = [x_center, x_center + component]
ax_orig.legend()

# Initial right-hand histogram: each sample's coordinate along the component,
# measured relative to the data mean.
ax_redim.hist((X @ component.T - x_center @ component.T), bins=50)
ax_redim.set(
    aspect="auto",
    title="1-dimensional dataset after dimension reduction",
    xlabel="Main feature",
    ylabel="Number of samples",
)

plt.tight_layout()
plt.show()