In [None]:
# Install/upgrade stonks_analysis (imported below) from the tip of its `main` branch.
# NOTE(review): `@main` is unpinned, so results can drift as the branch changes;
# pin a tag or commit SHA (…stonks-analysis.git@<sha>) for reproducible runs.
%pip install -U "git+https://github.com/sashakolpakov/stonks-analysis.git@main"


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
from dire_jax import DiRe
from stonks_analysis import download_data, compute_generic_features


In [None]:
# -- 1) Download & align --
# Fetch each ticker with a 730-day lookback (365 * 2), keyed by ticker symbol.
tickers = ["AAPL", "AMZN", "PLTR", "BTC-USD", "ETH-USD", "DOGE-USD"]
raw = {t: download_data(t, lookback_days=365*2) for t in tickers}

# Keep only the index values (presumably dates) present for every ticker, so
# each ticker contributes rows for exactly the same days.
common_idx = raw[tickers[0]].index
for t in tickers[1:]:
    common_idx = common_idx.intersection(raw[t].index)

# Fail fast with a clear message: an empty intersection would otherwise surface
# later as a less obvious error downstream (e.g. scaling/fitting on zero rows).
if len(common_idx) == 0:
    raise ValueError(
        f"No index values are shared by all tickers {tickers}; check that "
        "download_data returns comparable indices (e.g. same timezone and "
        "date normalization) for every ticker."
    )

aligned = {t: raw[t].loc[common_idx] for t in tickers}

# -- 2) Build feature DataFrame --
# For each aligned ticker: compute its feature frame, drop rows containing NaNs
# (done per ticker, so row counts may differ), and tag every row with a short
# label ("BTC-USD" -> "BTC"). The per-ticker frames are then stacked.
dfs = [
    compute_generic_features(aligned[sym])
    .dropna()
    .assign(Ticker=sym.replace("-USD", ""))
    for sym in tickers
]
df_pts = pd.concat(dfs)

# -- 3) Prepare X and labels --
# Row-aligned ticker labels, used later to color the embedded points.
labels = df_pts["Ticker"].to_list()

# NOTE(review): "Close" is a raw price level; after pooled scaling it may mostly
# encode *which* asset a row belongs to rather than its dynamics. Confirm intended.
feature_cols = ["Close", "LogVol", "Return"]
X = df_pts[feature_cols].to_numpy()

# Standardize each feature column over the pooled rows of all tickers.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# -- 4) Set up four reducers --
# Fixed seed so the stochastic reducers give the same embedding on every
# Restart & Run All (unseeded t-SNE/UMAP layouts change from run to run).
SEED = 42

reducers = {
    # NOTE(review): DiRe is left unseeded here; check whether its constructor
    # exposes a seed/random-state argument and pass SEED if it does.
    "DiRe-JAX": DiRe(dimension=2, n_neighbors=32, init_embedding_type="spectral", max_iter_layout=64, cutoff=12),
    "PCA":      PCA(n_components=2),
    "t-SNE":    TSNE(n_components=2, init="pca", random_state=SEED),
    # Seeding UMAP makes it deterministic at the cost of disabling its parallelism.
    "UMAP":     umap.UMAP(n_components=2, random_state=SEED),
}

# -- 5) Compute embeddings --
# Fit every reducer on the same standardized matrix (in dict order) and keep
# its 2-D output under the reducer's display name.
embeddings = {
    reducer_name: reducer.fit_transform(X_scaled)
    for reducer_name, reducer in reducers.items()
}

# -- 6) Plot every embedding with discrete labels --
label_list = sorted(set(labels))  # e.g. ['AAPL', 'AMZN', 'BTC', 'DOGE', 'ETH', 'PLTR']
palette = plt.get_cmap("tab10")   # 10 distinct colors; indices are cycled if there are more labels
labels_arr = np.asarray(labels)   # enables vectorized per-label boolean masks

# Size the grid from the number of reducers instead of hard-coding 2x2
# (a 4-reducer run still gets a 2x2 grid at figsize (12, 10)).
n_panels = len(embeddings)
ncols = 2
nrows = max(1, (n_panels + ncols - 1) // ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows), squeeze=False)
axes = axes.ravel()

for ax, (name, Z) in zip(axes, embeddings.items()):
    # A reducer may return a non-NumPy array (e.g. a JAX array); normalize it
    # so boolean-mask indexing below behaves like NumPy.
    Z = np.asarray(Z)
    for color_idx, lbl in enumerate(label_list):
        mask = labels_arr == lbl
        ax.scatter(
            Z[mask, 0], Z[mask, 1],
            c=[palette(color_idx % palette.N)],
            label=lbl, s=5,
        )
    ax.set_title(name)
    ax.set_xlabel(f"{name} 1")
    ax.set_ylabel(f"{name} 2")
    # Points are tiny (s=5); enlarge legend markers so colors are readable.
    ax.legend(markerscale=3)

# Hide grid cells left empty when the number of reducers is odd.
for ax in axes[n_panels:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()
