In [1]:
from jax.numpy import array, arange, sin
from jax.typing import ArrayLike
import bokeh as bk
import polars as pl
from typing import NamedTuple
from bokeh.plotting import figure, show

In [3]:
from vega_datasets import data

# Load the "stocks" sample dataset shipped with vega_datasets.
# NOTE(review): `stocks` is not referenced by any later cell visible here — confirm it is still needed.
stocks = data.stocks()

In [4]:
# Scratch check of `arange`: the integers 0..19.
# (`positional_encoding` further down builds its own local `model_seq`.)
model_seq = arange(20)
# Last expression of the cell, so it is displayed: the even integers 0, 2, ..., 18.
arange(0, 20, 2)

Array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)

In [5]:
from jax.numpy import hstack, tile

class PlotSize(NamedTuple):
    """Width/height pair used to size the figures drawn by the plotting helpers."""
    # Forwarded as `width=` to the plotting call.
    width: int
    # Forwarded as `height=` to the plotting call.
    height: int

def _to_list(values: ArrayLike) -> list:
    """Return ``values`` as plain Python data, using the array's own ``tolist`` when it has one."""
    if hasattr(values, "tolist"):
        return values.tolist()
    # Plain sequences (list, tuple, range) are valid ArrayLike inputs but have no
    # `.tolist()`; materialise them directly instead of raising AttributeError.
    return list(values)


def plot_xy(x: ArrayLike, y: ArrayLike, title: str = "", size: PlotSize = PlotSize(900, 300)):
    """Scatter-plot ``y`` against ``x`` through the polars ``DataFrame.plot`` accessor.

    Args:
        x: Values for the "x" column (array or plain sequence).
        y: Values for the "y" column (array or plain sequence).
        title: Title forwarded to the scatter call.
        size: Plot dimensions forwarded as ``width``/``height``; defaults to 900x300.

    Returns:
        Whatever ``df.plot.scatter`` returns, so a notebook cell ending in this
        call displays the plot.
    """
    df = pl.DataFrame({"x": _to_list(x), "y": _to_list(y)})
    return df.plot.scatter("x", "y", title=title, width=size.width, height=size.height)

def plot_xys(x: ArrayLike, ys: list[ArrayLike], title: str = "", size: PlotSize = PlotSize(900, 300), color_title: str = "color"):
    """Draw every series in ``ys`` against the shared ``x`` as lines on one bokeh figure.

    Each series takes the next colour of the 20-entry ``Category20`` palette
    (wrapping around after the last one) and gets the legend label
    ``"<color_title> <index>"``. The figure is passed to ``show`` and also returned.
    """
    fig = figure(title=title, width=size.width, height=size.height)
    palette = bk.palettes.Category20[20]
    n_colors = len(palette)
    for idx, series in enumerate(ys):
        fig.line(
            x,
            series,
            legend_label=f"{color_title} {idx}",
            line_width=2,
            color=palette[idx % n_colors],
        )
    fig.legend.location = "top_left"
    show(fig)

    return fig

    

In [6]:
from typing import Callable
from IPython.display import display



def view_fn_1d(fn: Callable[[ArrayLike], ArrayLike], title: str = "", size: PlotSize = PlotSize(900, 300)):
    """Plot ``fn`` evaluated on the grid ``arange(0, 10, 0.1)`` and return the ``plot_xy`` result."""
    grid = arange(0, 10, 0.1)
    return plot_xy(grid, fn(grid), title, size)

# Try the helper on jax.numpy's `sin`; the returned plot is the cell output.
view_fn_1d(sin, "sin(x)")


In [7]:
from jax import vmap, jit
from jax.numpy import ones
from jax.random import PRNGKey, normal, split

from better_partial import partial as f, _ as __

def mul(x: float, a: float) -> ArrayLike:
    """Scale ``x`` by the factor ``a`` and return ``a * x``.

    Although annotated as floats, the notebook also passes a whole array as
    ``x`` (via ``view_fn_1d``), in which case the product is an array.
    """
    product = a * x
    return product

# `f` is better_partial's `partial`: `f(mul)(..., a=2)` presumably binds a=2 and leaves `x`
# open, giving the one-argument function x -> 2 * x that the title describes.
view_fn_1d(f(mul)(..., a=2), "2 * x")


In [8]:
from jax.numpy import meshgrid, pi, power, hstack, linspace, float32, zeros, cos


def positional_encoding(pos: float, d_model: int):
    """Sinusoidal encoding of one position as a length-``d_model`` vector.

    Indices ``2k`` and ``2k + 1`` share the angle ``pos / 10000 ** (2k / d_model)``.
    The even entry is the sine of that angle; the odd entry is the sine of the
    angle shifted by ``pi / 2``.
    """
    idx = arange(d_model)
    # Both members of a pair (2k, 2k + 1) get the same exponent 2k / d_model.
    pair_exponent = 2 * (idx // 2) / float(d_model)
    angle = pos / power(10000, pair_exponent)
    # 0 on even indices, pi/2 on odd ones, so a single sin() call covers both cases.
    quarter_turn = (idx % 2) * (pi / 2)
    return sin(angle + quarter_turn)


# plot heatmap using bokeh
def heatmap(X: ArrayLike, title: str = "", size: PlotSize | None = None):
    """Show the 2-D array ``X`` as a bokeh image with a colour bar and return the figure.

    The image is given ``X.T`` and spans ``X.shape[0]`` units along x and
    ``X.shape[1]`` units along y, anchored at the origin. When ``size`` is
    ``None`` the figure keeps bokeh's default dimensions.
    """
    # Only pass width/height when the caller asked for a specific size.
    dims = {} if size is None else {"width": size.width, "height": size.height}
    fig = bk.plotting.figure(title=title, **dims)

    extent_x, extent_y = X.shape[0], X.shape[1]
    renderer = fig.image(
        image=[X.T], x=0, y=0, dw=extent_x, dh=extent_y, palette="Blues256"
    )
    fig.match_aspect = True
    fig.add_layout(renderer.construct_color_bar(padding=1), "right")

    show(fig)
    return fig


# Root PRNG key for the notebook. Cells follow the pattern
# `key, current_seed_ = split(current_seed_)`: `key` feeds the next random draw and
# `current_seed_` carries the chain forward.
current_seed_ = PRNGKey(0)
key, current_seed_ = split(current_seed_)
T = 256  # number of positions to encode
dim = 64  # length of each encoding (passed as `d_model`)
# Map over the position argument only; `d_model` is shared by every call (in_axes=None).
positional_encoding_v = vmap(positional_encoding, in_axes=(0, None))
X = positional_encoding_v(arange(T), dim)  # one encoding per row: shape (T, dim)

heatmap(X, "X", PlotSize(1280, 400))

In [9]:
from jax import make_jaxpr
from jax.numpy import sum


# Project the encodings onto five random weight vectors and plot each projection
# against the position index.
ys = []
for i in range(5):
    # Split off a fresh subkey for every draw so each `w` uses its own key.
    key, current_seed_ = split(current_seed_)
    w = normal(key, (dim, ))
    ys.append(X @ w)
plot_xys(arange(T), ys, "X @ w")

In [10]:

# Random linear map taking the `dim`-dimensional encodings to `out_dim` features.
out_dim = 32
# Split off a fresh subkey before sampling. The `key` left over from the previous
# cell was already consumed by its last `normal(key, ...)` draw, and JAX keys must
# not be reused: the same key can yield statistically dependent samples.
key, current_seed_ = split(current_seed_)
W = normal(key, (dim, out_dim))

# The figure shows the mixed encodings, so title it with what is actually plotted.
heatmap(X @ W, "X @ W", size=PlotSize(1024, 400))



In [33]:
from jax.numpy.linalg import svd, eigh

def PCA(A: ArrayLike) -> ArrayLike:
    """Project the rows of ``A`` onto the right singular vectors of its standardized form.

    Each column is centred and divided by its standard deviation; the
    standardized matrix is then multiplied by ``Vt.T`` from ``svd``.

    Args:
        A: 2-D array; statistics are taken down axis 0, i.e. per column.

    Returns:
        The standardized matrix expressed in the basis of its right singular
        vectors, with columns in the order ``svd`` returns them.
    """
    centered = A - A.mean(axis=0)
    scale = centered.std(axis=0)
    # A constant column has zero spread; dividing by it gives 0/0 = NaN, which then
    # contaminates the entire SVD. Adding the boolean mask replaces those zeros with 1,
    # leaving the (already all-zero) centred column untouched.
    scale = scale + (scale == 0)
    standardized = centered / scale
    _, _, Vt = svd(standardized)
    return standardized @ Vt.T

# Heatmap of the PCA output for the randomly mixed encodings `X @ W`.
heatmap(PCA(X @ W), "PCA", size=PlotSize(1224, 400))
    


In [42]:
import optax
from jax import grad
from jax.numpy import mean, std, log

def ICA(A: ArrayLike, n_components, n_iterations, key) -> ArrayLike:
    """Fit ``n_components`` linear components of ``A`` by SGD on a fourth-moment objective.

    The columns of ``A`` are standardized, a weight matrix ``W`` of shape
    ``(n_components, n_features)`` is drawn with ``normal(key, ...)``, and ``W`` is
    updated for ``n_iterations`` steps of ``optax.sgd`` (learning rate 0.01) to
    minimise ``-log(mean(E[Y**4]))`` where ``Y = X @ W.T``.

    NOTE(review): ``W`` is not normalised or decorrelated between steps and the
    data are standardized but not whitened — confirm this simplified scheme is
    what is intended.

    Args:
        A: 2-D array; statistics are taken down axis 0, i.e. per column.
        n_components: Number of rows of ``W``, hence of output columns.
        n_iterations: Number of SGD updates to run.
        key: JAX PRNG key used to initialise ``W``.

    Returns:
        The standardized data multiplied by the final ``W.T``; shape
        ``(A.shape[0], n_components)``.
    """
    centered = A - mean(A, axis=0)
    scale = std(centered, axis=0)
    # A constant column has zero spread; dividing by it gives 0/0 = NaN, which would
    # propagate into the loss and every gradient. Adding the boolean mask turns those
    # zeros into 1 and leaves the all-zero centred column as is.
    scale = scale + (scale == 0)
    X = centered / scale

    # Initialize weights
    W = normal(key, (n_components, X.shape[1]))

    # Loss: negative log of the average fourth moment of the projected components
    # (a kurtosis-style non-Gaussianity measure).
    def loss_fn(W, X):
        Y = X @ W.T
        fourth_moment = mean(Y**4, axis=0)
        return -log(mean(fourth_moment))

    # Gradient with respect to W (the first argument).
    grad_loss = grad(loss_fn)

    # Setup optimizer
    optimizer = optax.sgd(learning_rate=0.01)
    opt_state = optimizer.init(W)

    # One jit-compiled SGD step.
    @jit
    def update(W, opt_state, X):
        grads = grad_loss(W, X)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_W = optax.apply_updates(W, updates)
        return new_W, opt_state

    # Training loop
    for _ in range(n_iterations):
        W, opt_state = update(W, opt_state, X)

    return X @ W.T

# Compare ICA and PCA on five encoding columns mixed through the first five rows of W,
# next to the unmixed columns themselves.
key, current_seed_ = split(current_seed_)
sources = X[:, 5:10]
mixed = sources @ W[:5, :]
heatmap(ICA(mixed, 5, 15000, key), "ICA", size=PlotSize(1224, 200))
heatmap(PCA(mixed), "PCA", size=PlotSize(1224, 400))
heatmap(sources, "orig", size=PlotSize(1224, 200))
