In [1]:
import os

import numpy as np

from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib import image as mpimg

from skcosmo.datasets import load_csd_1000r
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from skcosmo.decomposition import PCovR

from scipy.stats import pearsonr
from shutil import copy
from PIL import Image as im
import io
from IPython import display

def fig2img(fig):
    """Rasterize a Matplotlib figure into an in-memory PIL image.

    The figure is written with ``fig.savefig`` (default format/dpi) into a
    ``BytesIO`` buffer, which is rewound and opened with PIL. The buffer is
    intentionally left open: PIL's ``Image.open`` reads image data lazily.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to render.

    Returns
    -------
    PIL.Image.Image
        Image backed by the in-memory buffer.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)  # rewind so PIL starts reading at the first byte
    return im.open(stream)

cmapX = plt.cm.plasma  # colormap used to colour scatter points by the target value

In [2]:
X, y = load_csd_1000r(return_X_y=True)
# Force the target into a 2-D (n_samples, n_targets) array; StandardScaler
# expects 2-D input.
y = np.reshape(y, (len(X), -1))

# Standardize features and targets independently (zero mean, unit variance).
X_scaler, y_scaler = StandardScaler(), StandardScaler()
X_scaled = X_scaler.fit_transform(X)
y_scaled = y_scaler.fit_transform(y)

In [3]:
# Mixing parameters to sweep: N_ALPHAS evenly spaced values on [0, 1].
# Rounding to 6 decimals strips linspace floating-point noise
# (e.g. 0.07000000000000001 -> 0.07) so the values print cleanly in the
# per-frame annotations. The result is a plain list of Python floats,
# already sorted and free of duplicates.
N_ALPHAS = 101
alphas = np.round(np.linspace(0.0, 1.0, N_ALPHAS), 6).tolist()

In [4]:
# Fit the ridge regressor once; the same fitted instance is passed to every
# PCovR model in the sweep.
regressor = Ridge(alpha=1e-4, fit_intercept=False).fit(X_scaled, y_scaled)

# Project the data with one PCovR model per mixing value, then stack the
# projections so that T[k] holds the projection for alphas[k].
projections = []
for alpha in alphas:
    pcovr = PCovR(mixing=alpha, regressor=regressor)
    pcovr.fit(X_scaled, y_scaled)
    projections.append(pcovr.transform(X_scaled))
T = np.array(projections)

In [5]:
# Fix mirrors: projection components are only defined up to a sign, so
# consecutive frames can flip. Walk the sweep in order and negate any component
# that is anti-correlated with the same component of the previous (already
# corrected) frame, which keeps the animation visually continuous.
for frame in range(1, T.shape[0]):
    previous, current = T[frame - 1], T[frame]  # views into T; edits write through
    for component in range(current.shape[1]):
        correlation = pearsonr(current[:, component], previous[:, component])[0]
        if correlation < 0:
            current[:, component] *= -1



In [6]:
# Render one scatter plot of the first two PCovR components per mixing value
# and collect the rasterized frames for the animated GIF.
N_HOLD_FRAMES = 10  # extra copies of the alpha = 0 and alpha = 1 frames

images = []
for a, t in zip(alphas, T):
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    ax.scatter(t[:, 0], t[:, 1], c=y, marker="o", s=3, rasterized=True, cmap=cmapX)
    # Call the label setters (assigning to ax.set_xlabel would replace the
    # method and draw nothing). Column 0 is plotted on x, column 1 on y.
    ax.set_xlabel(r"$\mathbf{PCov}_1$")
    ax.set_ylabel(r"$\mathbf{PCov}_2$")

    ax.set_xticks([])
    ax.set_yticks([])
    ax.annotate(
        xy=(0.05, 0.95),
        text=r"$\alpha=$" + str(round(a, 3)),
        xycoords="axes fraction",
        ha="left",
        va="top",
        fontsize=12,
    )

    frame = fig2img(fig)
    images.append(frame)
    if a in (0, 1):
        # Repeat the end-point frames (intended to make the animation pause on
        # them); copy the already-rendered image instead of re-rendering.
        images.extend(frame.copy() for _ in range(N_HOLD_FRAMES))
    # Close this specific figure so figures do not accumulate across the loop.
    plt.close(fig)

In [7]:
# Write the frames as an endlessly looping GIF (loop=0).
# - append_images excludes images[0], which is the image being saved;
#   passing the full list would write the first frame twice.
# - An explicit per-frame duration is set because recent Pillow versions merge
#   consecutive identical frames: with a duration their display times are
#   summed (so repeated frames hold on screen); without one they are dropped.
GIF_FRAME_DURATION_MS = 100
images[0].save(
    'pcovr.gif',
    save_all=True,
    append_images=images[1:],
    duration=GIF_FRAME_DURATION_MS,
    loop=0,
)

In [8]:
# HTML snippet embedding the saved GIF in a fixed-size frame.
iframe = '<iframe src=pcovr.gif width=324 height=324></iframe>'
# Leave the HTML object as the cell's last expression so Jupyter renders it.
display.HTML(data=iframe)

