# In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 1. Quantify relationships between stocks

def compute_log_returns(price_df):
    """
    Compute per-period log returns from a price table.

    price_df: DataFrame, index = dates, columns = stock tickers
    returns: DataFrame of log returns r_t = log(P_t / P_{t-1}), same columns,
             with the first row (which has no previous price) dropped.
             Missing or non-positive prices give NaN returns for that ticker
             only; a row is dropped only if it is NaN for every ticker.
    """
    # log() is undefined for non-positive prices: np.log(0) is -inf and a
    # negative price gives NaN plus a RuntimeWarning. An infinite return
    # would poison every correlation it touches, so treat such prices as
    # missing instead.
    valid_prices = price_df.where(price_df > 0)
    log_prices = np.log(valid_prices)
    # Use log-returns: r_t = log(P_t) - log(P_{t-1})
    returns = log_prices.diff()
    # Drop only rows with no data at all (this always includes the first
    # row). A plain dropna() would discard a whole date whenever a single
    # ticker is missing; DataFrame.corr() already handles NaNs pairwise.
    return returns.dropna(how="all")

def compute_correlation_matrix(returns_df):
    """
    Pearson correlation between every pair of return columns.

    returns_df: DataFrame of returns, one column per ticker
    returns: symmetric DataFrame, with tickers as both index and columns
    """
    return returns_df.corr(method="pearson")

# Example:
# price_df = ...  # price data
# returns_df = compute_log_returns(price_df)
# corr_matrix = compute_correlation_matrix(returns_df)


# 2. MST + clustering by "closeness"

class UnionFind:
    """
    Disjoint-set (union-find) forest for Kruskal's MST.

    Each item starts in its own singleton set. ``find`` returns the root
    that represents an item's set, and ``union`` merges two sets using
    union by rank. Looking up an item that was never added raises KeyError.
    """
    def __init__(self, items):
        # parent[item]: next item up the tree; a root is its own parent.
        self.parent = {item: item for item in items}
        # rank[item]: union-by-rank counter, starting at 0 for every item.
        self.rank = dict.fromkeys(items, 0)

    def find(self, x):
        # Walk up to the root first...
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # ...then re-point every node on the walked path straight at the
        # root (path compression).
        node = x
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, x, y):
        root_a, root_b = self.find(x), self.find(y)
        if root_a == root_b:
            return False  # already in the same set

        # Keep root_a as the higher-ranked root; hang the other beneath it.
        if self.rank[root_a] < self.rank[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        # Only a merge of two equal-rank trees raises the surviving rank.
        if self.rank[root_a] == self.rank[root_b]:
            self.rank[root_a] += 1
        return True

def build_mst_from_corr(corr_matrix):
    """
    Build a minimum spanning tree over the stocks with Kruskal's algorithm.

    corr_matrix: DataFrame, correlation between stocks, with the same
                 tickers in its index and its columns.
    We define a distance = 1 - correlation, so highly correlated pairs are
    "close" and are joined first.

    NaN correlations (e.g. a constant price series) are ordered after every
    finite distance. Such edges are used only to attach stocks that no
    finite edge can reach, which keeps the tree spanning.

    returns:
        mst_edges: list of edges (stock_i, stock_j, corr_ij, dist_ij), in the
                   order Kruskal accepted them (ascending distance).
    """
    from itertools import combinations

    tickers = list(corr_matrix.columns)
    edges = []

    # Each unordered pair once (i < j): the matrix is symmetric.
    for s1, s2 in combinations(tickers, 2):
        corr_ij = corr_matrix.loc[s1, s2]
        # distance: higher correlation -> smaller distance
        dist_ij = 1.0 - corr_ij
        edges.append((dist_ij, s1, s2, corr_ij))

    # Sort edges by distance ascending. NaN compares False against
    # everything, which silently leaves list.sort's output unsorted (and
    # Kruskal would then pick non-minimal edges), so push NaN to the end.
    edges.sort(key=lambda e: (bool(pd.isna(e[0])), e[0]))

    uf = UnionFind(tickers)
    mst_edges = []
    n_needed = len(tickers) - 1  # a spanning tree on n nodes has n-1 edges

    # Kruskal's algorithm: keep an edge only if it joins two components.
    for dist_ij, s1, s2, corr_ij in edges:
        if uf.union(s1, s2):
            mst_edges.append((s1, s2, corr_ij, dist_ij))
            if len(mst_edges) == n_needed:
                break  # tree complete; remaining edges would all be rejected

    return mst_edges


def clusters_from_mst(mst_edges, corr_threshold=0.5):
    """
    Given MST edges and a correlation threshold, cut the "weak" edges
    and use the remaining connected components as clusters.

    mst_edges: iterable of (stock_i, stock_j, corr_ij, dist_ij) tuples.
    corr_threshold: if corr < threshold, we remove that edge. A NaN
                    correlation also fails the test and is removed.
    returns: list of clusters; each cluster is a list of tickers. Clusters
             appear in the order their first ticker is first seen in
             mst_edges, so the result is reproducible across runs. Only
             tickers that appear in at least one edge are included.
    """
    # Adjacency list over "strong" edges. A dict keeps insertion order, so
    # its keys double as the ordered node list. A set would not: its
    # iteration order for strings changes between interpreter runs because
    # of hash randomisation, which reshuffled cluster numbering/colours.
    adjacency = {}
    for s1, s2, corr_ij, _dist_ij in mst_edges:
        adjacency.setdefault(s1, [])
        adjacency.setdefault(s2, [])
        if corr_ij >= corr_threshold:
            adjacency[s1].append(s2)
            adjacency[s2].append(s1)

    visited = set()
    clusters = []

    # Iterative DFS to collect connected components.
    for start in adjacency:
        if start in visited:
            continue
        visited.add(start)
        stack = [start]
        current_cluster = []
        while stack:
            v = stack.pop()
            current_cluster.append(v)
            for nei in adjacency[v]:
                if nei not in visited:
                    visited.add(nei)
                    stack.append(nei)
        clusters.append(current_cluster)

    return clusters

# Example:
# mst_edges = build_mst_from_corr(corr_matrix)
# clusters = clusters_from_mst(mst_edges, corr_threshold=0.6)

# 3. Visualization
#    (a) MST graph with clusters
#    (b) Time-series within each cluster

def plot_mst_graph(mst_edges, clusters):
    """
    Draw the MST as a simple 2D graph and show it.

    Tickers are spaced evenly around the unit circle in cluster order, so
    members of one cluster sit next to each other. Each cluster gets one
    colour from a 10-entry palette, which repeats after ten clusters. Every
    MST edge is drawn as a thin light-gray line, and each node is labelled
    with its ticker.
    """
    palette = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red',
               'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray',
               'tab:olive', 'tab:cyan']

    # Cluster-ordered ticker list determines the placement on the circle.
    ordered_nodes = [ticker for group in clusters for ticker in group]
    thetas = np.linspace(0, 2 * np.pi, len(ordered_nodes), endpoint=False)

    coords = {}
    for ticker, theta in zip(ordered_nodes, thetas):
        coords[ticker] = (np.cos(theta), np.sin(theta))

    # Ticker -> colour of the cluster it belongs to.
    color_of = {}
    for idx, group in enumerate(clusters):
        shade = palette[idx % len(palette)]
        color_of.update((ticker, shade) for ticker in group)

    plt.figure(figsize=(6, 6))

    # Edges first...
    for s1, s2, _corr, _dist in mst_edges:
        (xa, ya), (xb, yb) = coords[s1], coords[s2]
        plt.plot([xa, xb], [ya, yb], linewidth=0.5, color='lightgray')

    # ...then one marker plus a centred label per ticker.
    for ticker, (x, y) in coords.items():
        plt.scatter(x, y, color=color_of[ticker], s=50)
        plt.text(x, y, ticker, fontsize=8, ha='center', va='center')

    plt.title("MST of Stocks (colored by cluster)")
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def plot_cluster_time_series(price_df, clusters):
    """
    Show one figure per cluster overlaying its members' price histories.

    Figures are titled "Cluster 1", "Cluster 2", ... following the order of
    ``clusters``. Tickers not present in price_df's columns are skipped.
    """
    for number, members in enumerate(clusters, start=1):
        plt.figure(figsize=(8, 4))
        # Single-stock clusters are fine; only tickers with price data plot.
        for ticker in (t for t in members if t in price_df.columns):
            plt.plot(price_df.index, price_df[ticker], label=ticker)

        plt.xlabel("Date")
        plt.ylabel("Price")
        plt.title(f"Cluster {number}: Price Movements")
        plt.legend()
        plt.tight_layout()
        plt.show()

# Full pipeline example

# 1) Compute relationships
# returns_df = compute_log_returns(price_df)
# corr_matrix = compute_correlation_matrix(returns_df)

# 2) Build MST and cluster
# mst_edges = build_mst_from_corr(corr_matrix)
# clusters = clusters_from_mst(mst_edges, corr_threshold=0.6)

# 3a) Visualize MST with clusters
# plot_mst_graph(mst_edges, clusters)

# 3b) Visualize time-series for each cluster
# plot_cluster_time_series(price_df, clusters)
