
# 🧠 GNN-Style Correlation Analysis
This notebook builds a **correlation graph** from asset returns and applies lightweight, framework-free, **GNN-style message passing** (plain NumPy, fixed identity weights, no training) to compute smoothed node embeddings and explore cluster structure.

**What you'll do:**
1. Load returns (from CSV if available, else synthesize).
2. Compute a correlation matrix + build a weighted graph.
3. Extract communities (sector-like clusters) & the MST backbone.
4. Run a simple 2-layer message-passing network (no special libs).
5. Visualize embeddings & cluster structure.

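> Drop a CSV at `data/returns.csv` with columns: `date, TICKER1, TICKER2, ...` (wide format). The code will detect and use it automatically. If every value in the file is greater than 5, the data is treated as prices and converted to returns with `pct_change()`.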


In [None]:

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx

from typing import Tuple

# Notebook-wide plotting defaults: 10x6 figures with grid lines switched on.
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'axes.grid': True,
})


## 1) Load Returns

In [None]:

def load_returns(path: str = 'data/returns.csv', n_assets: int = 40, n_days: int = 400, seed: int = 42) -> pd.DataFrame:
    """Return a wide frame of daily returns (rows = dates, columns = assets).

    If ``path`` exists it is read as a CSV with a ``date`` column, which becomes the
    sorted index. When every value in the file is greater than 5 the data is assumed
    to be prices and is converted with ``pct_change().dropna()``.

    Otherwise returns are synthesized from one market factor, five sector factors and
    idiosyncratic noise, all drawn from a generator seeded with ``seed``. The result
    has ``n_days`` business-day rows starting at 2022-01-01 and ``n_assets`` columns
    named ``A00``, ``A01``, ...
    """
    if os.path.exists(path):
        frame = pd.read_csv(path, parse_dates=['date']).set_index('date').sort_index()
        # Heuristic only: returns are small numbers, so "everything > 5" is read as prices.
        looks_like_prices = (frame > 5).all().all()
        return frame.pct_change().dropna() if looks_like_prices else frame

    # Synthetic market + sector model. The draw order below fixes the random stream.
    n_sectors = 5
    rng = np.random.default_rng(seed)
    sector_of_asset = rng.integers(0, n_sectors, size=n_assets)
    market = rng.normal(0, 0.008, size=(n_days, 1))
    sector_factors = rng.normal(0, 0.01, size=(n_days, n_sectors))
    idio = rng.normal(0, 0.012, size=(n_days, n_assets))
    # Broadcast the market column and pick each asset's sector column in one shot.
    data = 0.4*market + 0.4*sector_factors[:, sector_of_asset] + 0.2*idio
    return pd.DataFrame(
        data,
        index=pd.date_range('2022-01-01', periods=n_days, freq='B'),
        columns=[f"A{j:02d}" for j in range(n_assets)],
    )

# Uses data/returns.csv when that file exists, otherwise the synthetic market + sector data.
rets = load_returns()
rets.head()


## 2) Correlation Matrix

In [None]:

# Pairwise correlation between the asset return columns.
corr = rets.corr()

fig_corr, ax_corr = plt.subplots(figsize=(8,6))
# Fix the colour scale to [-1, 1] so that zero correlation sits at the centre of the map.
heatmap = ax_corr.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
fig_corr.colorbar(heatmap, ax=ax_corr, label='Correlation')
ax_corr.set_title('Asset Correlation Matrix')
fig_corr.tight_layout()
plt.show()


## 3) Build Correlation Graph

In [None]:

def build_graph_from_corr(corr: pd.DataFrame, threshold: float = 0.3) -> nx.Graph:
    """Turn a correlation matrix into an undirected, weighted graph.

    Every column of ``corr`` becomes a node. Each unordered pair whose absolute
    correlation is at least ``threshold`` gets an edge whose ``weight`` attribute
    holds the signed correlation.
    """
    names = corr.columns.tolist()
    graph = nx.Graph()
    graph.add_nodes_from(names)
    # Visit each unordered pair once: pair every ticker with the tickers after it.
    for row, first in enumerate(names[:-1]):
        for col, second in enumerate(names[row + 1:], start=row + 1):
            weight = float(corr.iloc[row, col])  # type: ignore
            if abs(weight) >= threshold:
                graph.add_edge(first, second, weight=weight)
    return graph

# Keep only pairs with |corr| >= 0.35 as edges.
G = build_graph_from_corr(corr, threshold=0.35)
# nx.info() was removed in NetworkX 3.0, so report the graph size directly;
# this works on both old and current releases.
print(f"Graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")


## 4) Community Detection (Greedy Modularity)

In [None]:

# Partition the thresholded correlation graph with greedy modularity communities,
# using the 'weight' edge attribute.
communities = list(nx.algorithms.community.greedy_modularity_communities(G, weight='weight')) # type: ignore
# Look-up table: node -> index of the community that contains it.
comm_map = {node: label for label, members in enumerate(communities) for node in members}

print(f"Found {len(communities)} communities.")
# Only the first five communities are listed.
for label, members in enumerate(communities[:5]):
    print(f"Community {label}: {len(members)} nodes")


## 5) Minimum Spanning Tree (MST) Backbone

In [None]:

# Transform correlation to distance: d = sqrt(2*(1 - corr))
# Work on an explicit writable copy of the matrix: with pandas Copy-on-Write,
# `DataFrame.values` is a read-only view and np.fill_diagonal on it raises.
corr_values = corr.to_numpy(dtype=float, copy=True)
np.fill_diagonal(corr_values, 1.0)
W = pd.DataFrame(corr_values, index=corr.index, columns=corr.columns)
# Clipping at 0.999 keeps the distance of perfectly correlated pairs strictly positive.
D = np.sqrt(2.0*(1.0 - W.clip(-0.999, 0.999)))

# Build complete graph with distances
G_all = nx.Graph()
# Register the nodes up front so a single-asset universe still produces a valid graph.
G_all.add_nodes_from(corr.columns)
for i, a in enumerate(corr.columns):
    for j in range(i+1, len(corr.columns)):
        b = corr.columns[j]
        G_all.add_edge(a, b, weight=float(D.iloc[i, j])) # type: ignore

# The MST keeps, for n connected nodes, the n-1 shortest-distance links that join them all.
T = nx.minimum_spanning_tree(G_all, weight='weight')
print(f"MST edges: {T.number_of_edges()}")


## 6) Visualize Graph by Community

In [None]:

# Force-directed layout; the fixed seed keeps node positions reproducible across runs.
pos = nx.spring_layout(G, seed=1, weight='weight', k=None, iterations=100)

plt.figure(figsize=(10,7))
for i, com in enumerate(communities):
    # draw_networkx_nodes uses the same default colour on every call, so each
    # community needs an explicit colour. "C<i>" indexes matplotlib's colour cycle
    # (wrapping around when there are more communities than cycle entries).
    nx.draw_networkx_nodes(G, pos, nodelist=list(com), node_size=80, node_color=f"C{i}", label=f"C{i}")
nx.draw_networkx_edges(G, pos, alpha=0.15)
plt.title("Correlation Graph — Communities")
plt.legend(markerscale=2, fontsize=8, ncol=2)
plt.axis('off')
plt.show()


## 7) Lightweight GNN-Style Message Passing (NumPy)

In [None]:

def build_node_features(rets: pd.DataFrame, window: int = 20) -> pd.DataFrame:
    """Per-asset summary statistics over the last ``window`` rows of ``rets``.

    Returns a float frame indexed by asset (the columns of ``rets``) with columns
    ``mu`` (mean), ``vol`` (sample standard deviation, ddof=1, plus 1e-8),
    ``skew`` (mean cubed z-score) and ``kurt`` (mean fourth-power z-score minus 3).
    """
    recent = rets.tail(window)
    mean_ret = recent.mean().values
    # The small epsilon keeps the z-score division finite for constant series.
    volatility = recent.std(ddof=1).values + 1e-8  # type: ignore
    zscores = (recent - mean_ret) / volatility
    stats = {
        'mu': mean_ret,
        'vol': volatility,
        'skew': (zscores**3).mean().values,
        'kurt': (zscores**4).mean().values - 3.0,
    }
    return pd.DataFrame(stats, index=rets.columns, columns=['mu', 'vol', 'skew', 'kurt'], dtype=float)

# Initial node features: statistics over the most recent 60 rows of returns.
X0 = build_node_features(rets, window=60)
X0.head()


In [None]:

def relu(x):
    """Element-wise rectified linear unit: ``max(0, x)``."""
    return np.maximum(0, x)

def message_passing(G: nx.Graph, X: pd.DataFrame, layers: int = 2, alpha_self: float = 0.5) -> pd.DataFrame:
    """Smooth node features over the graph with a fixed, non-learned propagation rule.

    Each layer computes ``H <- ReLU(alpha_self * H + (1 - alpha_self) * A_hat @ H)``
    where ``A_hat = D^{-1/2} A D^{-1/2}`` and ``A`` holds the absolute edge weights
    of ``G``. The per-layer weight matrix is the identity, so nothing is trained.

    The rows of ``X`` define the node set and its order; every endpoint of an edge
    in ``G`` must appear in ``X.index``. Returns a frame with the same index and
    columns ``emb0``, ``emb1``, ...
    """
    nodes = list(X.index)
    position = {node: k for k, node in enumerate(nodes)}
    size = len(nodes)

    # Dense symmetric adjacency of absolute edge weights (missing weight counts as 0).
    adjacency = np.zeros((size, size), dtype=float)
    for u, v, attrs in G.edges(data=True):
        row, col = position[u], position[v]
        adjacency[row, col] = adjacency[col, row] = abs(attrs.get('weight', 0.0))

    # Symmetric normalization; isolated nodes get degree 1 to avoid dividing by zero.
    degree = adjacency.sum(axis=1)
    degree[degree == 0] = 1.0
    inv_sqrt_degree = 1.0/np.sqrt(degree)
    a_hat = inv_sqrt_degree[:, None] * adjacency * inv_sqrt_degree[None, :]

    hidden = X.values.copy()
    for _ in range(layers):
        hidden = relu(alpha_self*hidden + (1.0 - alpha_self)*(a_hat @ hidden))
    emb_cols = [f"emb{k}" for k in range(hidden.shape[1])]
    return pd.DataFrame(hidden, index=nodes, columns=emb_cols)

# Two propagation layers; each layer keeps 60% of a node's own features and
# mixes in 40% from its normalized neighbourhood.
H = message_passing(G, X0, layers=2, alpha_self=0.6)
H.head()


## 8) 2D Projection (PCA) and Plot

In [None]:

def pca_2d(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Project the rows of ``X`` onto their two leading principal components.

    Returns ``(projection, eigenvalues, components)``: the centered data projected
    onto the top-two eigenvectors of the feature covariance, all covariance
    eigenvalues in descending order, and the two eigenvectors as columns.
    """
    centered = X - X.mean(axis=0, keepdims=True)
    covariance = np.cov(centered, rowvar=False)
    # eigh returns eigenvalues in ascending order; reverse to put the largest first.
    eigenvalues, eigenvectors = np.linalg.eigh(covariance)
    descending = np.argsort(eigenvalues)[::-1]
    components = eigenvectors[:, descending[:2]]
    return centered @ components, eigenvalues[descending], components

proj, evals, evecs = pca_2d(H.values)

# Community id of every embedded node, aligned with the rows of `proj`
# (-1 marks nodes that belong to no detected community).
node_comm = np.array([comm_map.get(n, -1) for n in H.index])

plt.figure(figsize=(8,6))
# One scatter call per community so that colour and legend entry reflect membership.
for c in sorted(set(node_comm.tolist())):
    mask = node_comm == c
    plt.scatter(
        proj[mask, 0], proj[mask, 1], s=30, alpha=0.8,
        color=f"C{c}" if c >= 0 else "0.5",
        label=f"C{c}" if c >= 0 else "unassigned",
    )
plt.title("Node Embeddings (PCA of Message-Passed Features)")
plt.xlabel("PC1"); plt.ylabel("PC2")
plt.legend(markerscale=2, fontsize=8, ncol=2)
plt.show()


## 9) Link Similarity vs Correlation

In [None]:

# Cosine similarity in embedding space vs. original correlation
from numpy.linalg import norm

tickers = list(H.index)
emb = H.values
# Pairwise cosine similarity in one shot: row norms are computed once instead of
# twice per pair. The 1e-12 keeps all-zero embeddings from dividing by zero.
row_norms = norm(emb, axis=1)
sim = (emb @ emb.T) / (row_norms[:, None] * row_norms[None, :] + 1e-12)

# Compare distributions of sim for edges vs. non-edges
# Classify every unordered index pair in a single pass over the upper triangle.
edges, non_edges = [], []
for i in range(len(tickers)):
    for j in range(i+1, len(tickers)):
        (edges if G.has_edge(tickers[i], tickers[j]) else non_edges).append((i, j))

def sample_pairs(pairs, k=5000, seed=0):
    """Return at most ``k`` items of ``pairs``.

    Lists with ``k`` or fewer items are returned as-is (same object). Longer lists
    are subsampled without replacement using a generator seeded with ``seed``.
    """
    total = len(pairs)
    if total > k:
        chosen = np.random.default_rng(seed).choice(total, size=k, replace=False)
        return [pairs[pos] for pos in chosen]
    return pairs

# Cap both groups at 4000 pairs.
edges_s = sample_pairs(edges, k=4000)
non_s   = sample_pairs(non_edges, k=4000)

# Each pair is an (i, j) tuple, so it indexes `sim` directly.
edge_vals = np.array([sim[pair] for pair in edges_s])
non_vals  = np.array([sim[pair] for pair in non_s])

plt.figure(figsize=(9,4))
for vals, lab in ((edge_vals, 'Edges (|corr|>=thr)'), (non_vals, 'Non-edges')):
    plt.hist(vals, bins=40, alpha=0.6, label=lab)
plt.title("Embedding Cosine Similarity for Edges vs. Non-edges")
plt.legend()
plt.show()

print("Mean cos-sim (edges):   ", edge_vals.mean())
print("Mean cos-sim (non-edge):", non_vals.mean())


## 10) Rolling Correlation Stability

In [None]:

def rolling_corr_stability(rets: pd.DataFrame, window: int = 60, step: int = 20, threshold: float = 0.35) -> pd.DataFrame:
    """Jaccard similarity between the edge sets of consecutive rolling windows.

    For each ``start`` (advancing by ``step``), a correlation graph is built on rows
    ``[start, start + window)`` and on the window shifted by ``step``; the score is
    ``|EA & EB| / |EA | EB|`` of their edge sets (NaN when both graphs have no edges).

    Parameters
    ----------
    rets : wide frame of returns (rows = dates, columns = assets).
    window : number of rows in each window.
    step : shift between the two compared windows and between successive starts.
    threshold : absolute-correlation cut-off passed to ``build_graph_from_corr``.

    Returns
    -------
    DataFrame with a single float column ``jaccard``, indexed by the first date of
    the later window. Empty when ``rets`` is shorter than ``window + step`` rows.
    """
    def edge_set(block: pd.DataFrame) -> set:
        # Sort each endpoint pair so (a, b) and (b, a) compare equal.
        graph = build_graph_from_corr(block.corr(), threshold=threshold)
        return {tuple(sorted(edge)) for edge in graph.edges()}

    # The later window ends at start + step + window, so that sum (not 2*window)
    # bounds the last valid start; both compared windows are then always full-length.
    last_start = len(rets) - window - step
    stamps, scores = [], []
    for start in range(0, last_start + 1, step):
        earlier = edge_set(rets.iloc[start:start+window])
        later = edge_set(rets.iloc[start+step:start+step+window])
        union = len(earlier | later)
        scores.append(len(earlier & later)/union if union > 0 else np.nan)
        stamps.append(rets.index[start+step])
    # The explicit float dtype keeps the (possibly empty) result numeric and plottable.
    return pd.Series(scores, index=stamps, dtype=float).to_frame('jaccard')

# Compare 80-row windows shifted by 20 rows, using the same 0.35 threshold as the main graph.
stab = rolling_corr_stability(rets, window=80, step=20, threshold=0.35)
ax_stab = stab.plot(title='Rolling Edge-Set Stability (Jaccard)')
ax_stab.set_ylabel('Jaccard Similarity')
plt.show()
