In [2]:
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go

In [3]:
# Load the four input tables. Each CSV carries a "Date" column that is parsed
# to datetime on read; the files are read in the order listed.
highs, lows, market_caps, volumes = (
    pd.read_csv(csv_path, parse_dates=["Date"])
    for csv_path in (
        "highs.csv",
        "lows.csv",
        "monthly_market_caps.csv",
        "monthly_average_volumes.csv",
    )
)

In [213]:
# Put every table in chronological order and rebuild a 0..n-1 index so that
# row labels and row positions agree. Without the reset, label-based lookups
# further down (e.g. `m[n]`) would address pre-sort rows whenever a CSV is not
# already stored in date order.
highs, lows, market_caps, volumes = (
    frame.sort_values("Date").reset_index(drop=True)
    for frame in (highs, lows, market_caps, volumes)
)

# Cross-sectional market-cap statistics, one entry per row of `market_caps`
# (position i here corresponds to market_caps.iloc[i]). Missing values are
# skipped, and the standard deviation is the population form (ddof=0), which
# is what np.std produced in the original row-by-row loop.
# Assumes every column other than "Date" is a numeric ticker column.
cap_values = market_caps.drop(columns="Date")
mcap_mean = cap_values.mean(axis=1).tolist()
mcap_std = cap_values.std(axis=1, ddof=0).tolist()

# Ticker symbols are the non-date columns of the highs table.
tickers = [c for c in highs.columns if c != "Date"]

def compute_slope_and_vol(series):
    """Estimate the per-period growth rate and return volatility of a price series.

    Missing values are dropped and the remaining observations are treated as
    consecutive, equally spaced periods (0, 1, 2, ...).

    Parameters
    ----------
    series : pandas.Series
        Price-like observations in chronological order.

    Returns
    -------
    tuple of (float, float)
        ``pct_slope`` is ``exp(b) - 1`` where ``b`` is the least-squares slope
        of ``log(price)`` against the period index, i.e. the fitted
        per-period percentage growth rate.
        ``vol`` is the population standard deviation (ddof=0) of the simple
        period-over-period returns.
        Both values are NaN when fewer than two observations remain after
        dropping missing values, or when any remaining value is non-finite or
        not strictly positive (the log is undefined there).
    """
    prices = series.dropna().to_numpy(dtype=float)

    # Return NaN instead of letting the regression raise on unusable input;
    # callers test np.isnan(pct_slope) to decide whether to fall back to
    # another series.
    if prices.size < 2 or not np.all(np.isfinite(prices)) or np.any(prices <= 0):
        return np.nan, np.nan

    y = np.log(prices)
    X = np.arange(len(y)).reshape(-1, 1)  # sklearn expects a 2-D feature matrix

    slope = LinearRegression().fit(X, y).coef_[0]
    pct_slope = np.expm1(slope)  # convert log-slope to a simple growth rate

    rets = np.diff(prices) / prices[:-1]
    vol = np.std(rets)

    return pct_slope, vol


# Monthly cross-sectional mean / std as arrays, aligned by row position with
# `market_caps` (entry i was computed from market_caps.iloc[i]).
month_mean = np.asarray(mcap_mean, dtype=float)
month_std = np.asarray(mcap_std, dtype=float)

records = []

for t in tickers:

    # Growth and volatility come from the highs; fall back to the lows when the
    # highs yield no usable slope. compute_slope_and_vol drops missing values
    # itself, so the raw column is passed straight through.
    pct_slope, vol = compute_slope_and_vol(highs[t])
    if np.isnan(pct_slope):
        pct_slope, vol = compute_slope_and_vol(lows[t])

    # Average monthly z-score of this ticker's market cap against the
    # cross-section. Rows are matched by position, so this works for any number
    # of months and for gaps anywhere in the series (not only leading ones).
    # Months where the cross-sectional std is zero or NaN are skipped to avoid
    # dividing by zero.
    caps = market_caps[t].to_numpy(dtype=float)
    usable = ~np.isnan(caps) & (month_std > 0)

    # Sign convention: (cap - mean) / std, so above-average caps score positive.
    # That is what the downstream bucketing (`x <= 0` -> "Low Cap") expects;
    # the earlier (mean - cap) form inverted the labels.
    # NOTE(review): the bucket thresholds may have been tuned on the inverted
    # scores - re-check the bucket sizes after this change.
    z_scores = (caps[usable] - month_mean[usable]) / month_std[usable]
    mcap = z_scores.mean() if z_scores.size else np.nan

    records.append({
        "Ticker": t,
        "Slope": pct_slope,
        "Vol": vol,
        "MarketCap": mcap,
    })

df = pd.DataFrame(records)

# Attach the subsector of each ticker. The default inner join drops any ticker
# that is missing from the sector sheet.
sector_df = pd.read_csv('Tech Subsectors - Sheet1.csv')
df = df.merge(sector_df, on="Ticker")



HIGH_GROWTH = 0.02
DECLINING = 0


def classify_growth(row):
    """Map a record's fitted slope to a growth label.

    Parameters
    ----------
    row : mapping or pandas.Series
        Must provide a "Slope" entry.

    Returns
    -------
    str
        "High-growth" when the slope is at least HIGH_GROWTH, "Declining" when
        it is at most DECLINING, and "Stable" for anything else (which includes
        a NaN slope, since both comparisons are then false).
    """
    slope = row["Slope"]

    if slope >= HIGH_GROWTH:
        label = "High-growth"
    elif slope <= DECLINING:
        label = "Declining"
    else:
        label = "Stable"

    return label

# Tag each row with its growth label (classify_growth is called once per row).
df["Growth"] = df.apply(classify_growth, axis="columns")

# 33rd and 66th percentiles of the normalized market-cap score.
quart1, quart2 = df["MarketCap"].quantile(q=[0.33, 0.66]).tolist()

q1 = 0
q2 = .35


def mcap_bucket(x):
    """Assign a market-cap bucket label to a score.

    The inclusive upper bounds are read from the module-level q1 and q2 at
    call time: scores up to q1 are "Low Cap", scores up to q2 are "Mid Cap",
    and everything else (including NaN, which fails both comparisons) is
    "High Cap".
    """
    for upper_bound, label in ((q1, "Low Cap"), (q2, "Mid Cap")):
        if x <= upper_bound:
            return label
    return "High Cap"

# Bucket every ticker by its normalized market-cap score.
df["Performance"] = df["MarketCap"].apply(mcap_bucket)


# Node labels for the three Sankey columns.
industry_nodes = list(df['Industry'].unique())
growth_nodes = ["High-growth", "Stable", "Declining"]
perf_nodes = ["High Cap", "Mid Cap", "Low Cap"]

# Every label used as a link endpoint has to be registered here. The industry
# labels are included explicitly: the links below look them up in node_index,
# and the previously referenced `sentiment_nodes` is not defined anywhere in
# this notebook.
all_nodes = industry_nodes + growth_nodes + perf_nodes
node_index = {name: i for i, name in enumerate(all_nodes)}
# A label shared between two columns would silently merge their nodes.
assert len(node_index) == len(all_nodes), "Sankey node labels must be unique"

source = []
target = []
value = []

# Two link layers: Growth -> Industry, then Industry -> Performance.
# Each link's value is the number of tickers sharing both labels; empty
# combinations are omitted.
link_layers = [
    ("Growth", growth_nodes, "Industry", industry_nodes),
    ("Industry", industry_nodes, "Performance", perf_nodes),
]

for from_col, from_nodes, to_col, to_nodes in link_layers:
    for from_label in from_nodes:
        for to_label in to_nodes:
            count = int(((df[from_col] == from_label) & (df[to_col] == to_label)).sum())
            if count > 0:
                source.append(node_index[from_label])
                target.append(node_index[to_label])
                value.append(count)

fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=20,
        thickness=18,
        line=dict(color="black", width=0.5),
        label=all_nodes
    ),
    link=dict(
        source=source,
        target=target,
        value=value
    )
)])


fig.update_layout(
    title="Tech Stock Sankey: → Growth → Subsector → Performance (Market Cap)",
    annotations=[
        dict(
            text="Size is number of companies, Stable is stocks with a slope greater than 0 and high growth is stocks with greater than .02. Market cap is based on the normalization of each month.",
            x=0,
            y=-0.25,  # place the caption below the plot area
            xref="paper",
            yref="paper",
            showarrow=False,
            xanchor="left",
            yanchor="top",
            font=dict(size=12)
        )
    ],
    margin=dict(b=120)  # extra bottom margin so the caption is not clipped
)



fig.show()
