# AI Stock Investment Tool - Colab Training

Train and compare all model types (including Hybrid Multi-Modal) on GPU, then serve the API for your frontend.

**Before starting:**
1. Go to **Runtime > Change runtime type > T4 GPU**
2. Run cells **in order** from top to bottom
3. If the repo is private, you'll need a [GitHub Personal Access Token](https://github.com/settings/tokens) with `repo` scope

**Sections:**
- 1-2: Setup & environment
- 3: Fetch data & build features (50-ticker universe)
- 4-6: Configure, train, and compare all models
- 7: Save results to Drive
- 8: Optuna hyperparameter search (optional)
- 9: Export production artifacts (best model)
- 10: Serve API via ngrok (connect to frontend)
- 11: Download artifacts to local machine
- 12: Cleanup (stop the API server and ngrok tunnel)

## 1. Setup

In [None]:
# Install the third-party packages this notebook imports (-q keeps pip output short).
!pip install -q yfinance lightgbm torch optuna pyarrow scikit-learn scipy pandas numpy matplotlib feedparser pyngrok

In [None]:
import os
os.chdir("/content")

# Clean previous clone if any
!rm -rf AI-stock-investment-tool

# Try public clone first, fall back to token auth
REPO = "https://github.com/kevin6598/AI-stock-investment-tool.git"
ret = os.system("git clone %s 2>/dev/null" % REPO)

if ret != 0:
    from getpass import getpass
    print("Public clone failed -- repo is private.")
    print("Create a token at: https://github.com/settings/tokens (repo scope)")
    token = getpass("Paste your GitHub token: ")
    os.system("git clone https://%s@github.com/kevin6598/AI-stock-investment-tool.git" % token)
    del token

os.chdir("/content/AI-stock-investment-tool")
print("Working dir: %s" % os.getcwd())
!git log --oneline -3

In [None]:
import sys
import torch

# Report the interpreter, the PyTorch build and whether CUDA is usable.
for label, value in (
    ("Python", sys.version),
    ("PyTorch", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

if not torch.cuda.is_available():
    print("WARNING: No GPU detected. Go to Runtime > Change runtime type > T4 GPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Mount Google Drive & Configure Paths

In [None]:
# Mount Google Drive at /content/drive; the paths configured in the next cell live under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# All persistent files live under one Drive folder so they survive across sessions.
DRIVE_DIR = "/content/drive/MyDrive/ai_stock_tool"
DATA_PATH = os.path.join(DRIVE_DIR, "dataset.parquet")
OUTPUT_DIR = os.path.join(DRIVE_DIR, "models_registry")
ARTIFACT_DIR = os.path.join(DRIVE_DIR, "artifacts")

# Create the root folder first, then the two sub-folders.
for _folder in (DRIVE_DIR, OUTPUT_DIR, ARTIFACT_DIR):
    os.makedirs(_folder, exist_ok=True)

# Echo the resolved locations, with labels padded so the paths line up.
for _label, _location in (
    ("Drive dir:    ", DRIVE_DIR),
    ("Data path:    ", DATA_PATH),
    ("Output dir:   ", OUTPUT_DIR),
    ("Artifact dir: ", ARTIFACT_DIR),
):
    print("%s%s" % (_label, _location))

## 3. Fetch Data & Build Features

Fetches 50-ticker universe with ticker embeddings (needed for Hybrid Multi-Modal).
If you already built the dataset, skip Option A and use Option B to load from Drive.

In [None]:
# === Option A: Build dataset from scratch (50-ticker universe) ===

TICKERS = [
    "AAPL", "MSFT", "GOOGL", "AMZN", "META", "NVDA", "TSLA", "BRK-B", "JPM", "JNJ",
    "V", "PG", "UNH", "HD", "MA", "DIS", "PYPL", "BAC", "NFLX", "ADBE",
    "CRM", "CMCSA", "XOM", "VZ", "KO", "INTC", "PEP", "ABT", "CSCO", "TMO",
    "COST", "MRK", "WMT", "AVGO", "ACN", "CVX", "NKE", "LLY", "MCD", "TXN",
    "QCOM", "DHR", "UPS", "BMY", "PM", "LIN", "NEE", "ORCL", "RTX", "HON",
]
PERIOD = "5y"
FORWARD_HORIZONS = [21, 63, 126]  # 1M, 3M, 6M
MIN_HISTORY_ROWS = 300  # tickers with this many rows or fewer are skipped

from data.stock_api import get_historical_data, get_stock_info
from training.feature_engineering import (
    build_panel_dataset, cross_sectional_normalize, add_ticker_embedding_column,
)

print("Fetching stock data for %d tickers..." % len(TICKERS))
stock_dfs = {}
stock_infos = {}
for ticker in TICKERS:
    # Price history is mandatory: a failed download drops the ticker.
    try:
        df = get_historical_data(ticker, period=PERIOD)
    except Exception as e:
        print("  WARN: %s failed: %s" % (ticker, e))
        continue
    if df is None or df.empty or len(df) <= MIN_HISTORY_ROWS:
        print("  SKIP: %s has too little price history" % ticker)
        continue

    # Company info is optional (an empty result already falls back to {}), so a
    # failed lookup must not leave the ticker in stock_dfs without an info entry.
    try:
        info = get_stock_info(ticker) or {}
    except Exception as e:
        print("  WARN: %s info lookup failed, using empty info: %s" % (ticker, e))
        info = {}

    # Register both entries together so the two dicts always share the same keys.
    stock_dfs[ticker] = df
    stock_infos[ticker] = info
    print("  %s: %d rows" % (ticker, len(df)))

if not stock_dfs:
    raise RuntimeError("No ticker returned enough data -- cannot build a dataset.")

market_df = get_historical_data("SPY", period=PERIOD)
if market_df is None or market_df.empty:
    raise RuntimeError("Market benchmark (SPY) download returned no data.")
print("  SPY (market): %d rows" % len(market_df))

# Sorted so the ticker list is deterministic regardless of download order.
valid_tickers = sorted(stock_dfs.keys())
print("\nValid tickers: %d / %d" % (len(valid_tickers), len(TICKERS)))

print("Building features...")
panel = build_panel_dataset(stock_dfs, stock_infos, market_df, FORWARD_HORIZONS)
panel = cross_sectional_normalize(panel)
panel, ticker_to_id = add_ticker_embedding_column(panel, valid_tickers)
print("Panel shape: %s" % str(panel.shape))

# Save to Drive
panel.to_parquet(DATA_PATH)
print("\nSaved to %s" % DATA_PATH)

In [None]:
# === Option B: Load existing dataset from Drive ===

import pandas as pd

panel = pd.read_parquet(DATA_PATH)

# Index level 0 is read as the date and level 1 as the ticker.
# Sort the ticker list so it has the same order as the list Option A builds
# with sorted(); later cells count it and write it to ticker_list.json.
valid_tickers = sorted(panel.index.get_level_values(1).unique().tolist())
panel_dates = panel.index.get_level_values(0)

print("Loaded panel: %s" % str(panel.shape))
print("Tickers (%d): %s" % (len(valid_tickers), valid_tickers))
print("Date range: %s to %s" % (panel_dates.min(), panel_dates.max()))

## 4. Configure Models

In [None]:
from training.model_config import ModelConfig, ConfigGrid

# The comparison covers 5 model types; lightgbm is tried with two settings,
# which gives 6 configs in total. List order is the training order.
lightgbm_variants = [(0.05, 31, 6), (0.01, 63, 8)]  # (learning_rate, num_leaves, max_depth)
sequence_variants = [("lstm_attention", 1e-3), ("transformer", 3e-4)]  # (model_type, learning_rate)

configs = [ModelConfig(model_type="elastic_net", learning_rate=0.1, epochs=1)]

configs += [
    ModelConfig(
        model_type="lightgbm",
        learning_rate=lgbm_lr,
        epochs=500,
        extra_params={"num_leaves": n_leaves, "max_depth": depth},
    )
    for lgbm_lr, n_leaves, depth in lightgbm_variants
]

configs += [
    ModelConfig(
        model_type=seq_type,
        learning_rate=seq_lr,
        epochs=100,
        dropout=0.2,
        sequence_length=60,
    )
    for seq_type, seq_lr in sequence_variants
]

# The hybrid model receives the size of the current ticker universe.
hybrid_params = {
    "n_tickers": len(valid_tickers),
    "hidden_dim": 128,
    "fusion_dim": 128,
    "vae_latent_dim": 16,
    "patience": 10,
}
configs.append(
    ModelConfig(
        model_type="hybrid_multimodal",
        learning_rate=1e-3,
        epochs=50,
        dropout=0.2,
        batch_size=64,
        extra_params=hybrid_params,
    )
)

print("Total configs to train: %d" % len(configs))
for position, cfg in enumerate(configs):
    summary = (position, cfg.model_type, cfg.learning_rate, cfg.dropout, cfg.epochs)
    print("  [%d] %s | lr=%s | dropout=%s | epochs=%s" % summary)

## 5. Train with Walk-Forward Validation

In [None]:
from training.model_config import MultiConfigRunner
from training.model_selection import WalkForwardConfig

# Define walk-forward settings
HORIZON = "1M"  # Change to "3M" or "6M" as needed

# The target column is derived from the horizon so the two cannot drift apart:
# 21d=1M, 63d=3M, 126d=6M.
HORIZON_TO_DAYS = {"1M": 21, "3M": 63, "6M": 126}
if HORIZON not in HORIZON_TO_DAYS:
    raise ValueError(
        "Unknown HORIZON %r; expected one of %s" % (HORIZON, sorted(HORIZON_TO_DAYS))
    )
TARGET_COL = "fwd_return_%dd" % HORIZON_TO_DAYS[HORIZON]
if TARGET_COL not in panel.columns:
    raise KeyError("Target column %r is not present in the panel" % TARGET_COL)

wf_config = WalkForwardConfig(
    train_start="2015-01-01",
    test_end="2025-01-01",
    train_min_months=36,
    val_months=6,
    test_months=6,
    step_months=6,
    embargo_days=21,
    expanding=True,
)

# Feature columns: everything except target-like columns (matched by prefix)
# and the two bookkeeping columns _close and ticker_id.
TARGET_PREFIXES = ("fwd_return_", "residual_return_", "ranked_target_")
NON_FEATURE_COLS = ("_close", "ticker_id")
feature_cols = [
    c for c in panel.columns
    if not c.startswith(TARGET_PREFIXES) and c not in NON_FEATURE_COLS
]
print("Features: %d" % len(feature_cols))
print("Target: %s" % TARGET_COL)
print("Horizon: %s" % HORIZON)

In [None]:
# Optional: estimate which features matter so training can be sped up.
from training.feature_engineering import prune_features

# Feature selection only looks at the first 5000 rows of the panel.
sample = panel.head(5000)
sample_features = sample[feature_cols]
sample_target = sample[TARGET_COL]

_, selected_cols = prune_features(
    sample_features,
    sample_target,
    importance_threshold=0.005,
)
print("Pruned: {} -> {} features".format(len(feature_cols), len(selected_cols)))

# Uncomment to use pruned features:
# feature_cols = selected_cols

In [None]:
import time

# Train every config with walk-forward validation; nothing is written to the registry.
runner = MultiConfigRunner(save_to_registry=False)
separator = "=" * 60

print("Starting training...")
print(separator)

t0 = time.time()
results = runner.run(
    configs=configs,
    panel=panel,
    target_col=TARGET_COL,
    feature_cols=feature_cols,
    wf_config=wf_config,
    horizon=HORIZON,
)
total_time = time.time() - t0

print(separator)
print("Training complete in {:.1f}s".format(total_time))
print("Configs evaluated: {}".format(len(results)))

## 6. Compare Results

In [None]:
import pandas as pd


def _summary_row(result):
    """Flatten one training result into a display row (metrics rounded for readability)."""
    cfg = result.config
    ev = result.evaluation
    return {
        "Model": cfg.model_type,
        "LR": cfg.learning_rate,
        "Dropout": cfg.dropout,
        "IC": round(ev.mean_ic, 4),
        "ICIR": round(ev.icir, 2),
        "Sharpe": round(ev.mean_sharpe, 2),
        "Max DD": round(ev.mean_mdd, 4),
        "Calmar": round(ev.mean_calmar, 2),
        "Hit Ratio": round(ev.mean_hit_ratio, 4),
        "Folds": len(ev.fold_results),
        "Time (s)": round(result.training_time, 1),
    }


# One row per trained config, best IC first.
rows = [_summary_row(result) for result in results]
df_results = (
    pd.DataFrame(rows)
    .sort_values("IC", ascending=False)
    .reset_index(drop=True)
)

print("\nModel Comparison (sorted by IC):")
print("=" * 80)
display(df_results)

In [None]:
# Compare the walk-forward evaluations of all trained configs.
from training.model_comparison import ModelComparisonEngine

evaluations = [res.evaluation for res in results]
engine = ModelComparisonEngine()
report = engine.compare(evaluations)

# IC ranking (prints nothing under the heading if the report has no "ic" entry).
print("Rankings by IC:")
ic_ranking = report.rankings.get("ic", [])
for model_name, ic_value in ic_ranking:
    print("  {}: {:.4f}".format(model_name, ic_value))

print("\nStability scores:")
for model_name, stability in report.stability_scores.items():
    print("  {}: {:.4f}".format(model_name, stability))

print("\nBest per horizon: {}".format(report.best_per_horizon))

# The significance section is only printed when tests were produced.
if report.significance_tests:
    print("\nSignificance tests:")
    for pair_key, outcome in report.significance_tests.items():
        if outcome["ttest_significant_5pct"]:
            verdict = "YES"
        else:
            verdict = "no"
        print("  {}: p={:.4f} (significant: {})".format(
            pair_key, outcome["ttest_p_value"], verdict))

In [None]:
# Plot IC, Sharpe and training time side by side, one horizontal bar per config.
import matplotlib.pyplot as plt

# Bar label per row, e.g. "lightgbm (lr=0.05)"; shared by all three panels.
bar_labels = df_results["Model"] + " (lr=" + df_results["LR"].astype(str) + ")"

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (column to plot, x-axis label, panel title) in left-to-right order.
panel_specs = (
    ("IC", "Mean IC", "Information Coefficient"),
    ("Sharpe", "Mean Sharpe", "Sharpe Ratio"),
    ("Time (s)", "Seconds", "Training Time"),
)
for ax, (column, x_label, title) in zip(axes, panel_specs):
    ax.barh(bar_labels, df_results[column])
    ax.set_xlabel(x_label)
    ax.set_title(title)

plt.tight_layout()
# Persist the figure next to the saved results before showing it.
plt.savefig(os.path.join(OUTPUT_DIR, "comparison.png"), dpi=150, bbox_inches="tight")
plt.show()

## 7. Save Results to Drive

In [None]:
import json


def _write_json(path, payload):
    """Write payload to path as indented JSON."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)


# Each result gets its own folder "<model_type>_v<index>" holding its config and
# metrics; a combined results.json is written afterwards.
results_data = []
for idx, res in enumerate(results):
    model_dir = os.path.join(OUTPUT_DIR, "{}_v{}".format(res.config.model_type, idx))
    os.makedirs(model_dir, exist_ok=True)

    _write_json(os.path.join(model_dir, "config.json"), res.config.to_json())

    evaluation = res.evaluation
    metrics = {
        "mean_ic": evaluation.mean_ic,
        "icir": evaluation.icir,
        "mean_sharpe": evaluation.mean_sharpe,
        "mean_mdd": evaluation.mean_mdd,
        "mean_calmar": evaluation.mean_calmar,
        "mean_hit_ratio": evaluation.mean_hit_ratio,
        "n_folds": len(evaluation.fold_results),
    }
    _write_json(os.path.join(model_dir, "metrics.json"), metrics)

    results_data.append({
        "config": res.config.to_json(),
        "training_time": res.training_time,
        "best_params": res.best_params,
        "metrics": metrics,
    })

    print("Saved: {}".format(model_dir))

# Combined file with every config's summary (for local import).
results_path = os.path.join(OUTPUT_DIR, "results.json")
_write_json(results_path, results_data)

print("\nAll results saved to {}".format(OUTPUT_DIR))
print("Import locally with: DataExporter.import_results('{}')".format(results_path))

## 8. (Optional) Hyperparameter Search with Optuna

In [None]:
# Run Optuna HP search for the best model type
import math

from training.hyperparameter_search import HyperparameterSearcher

if not results:
    raise RuntimeError("No training results available -- run section 5 first.")


def _mean_ic_or_worst(result):
    """Sort key: the result's mean IC, with a missing/NaN IC ranked last."""
    ic = result.evaluation.mean_ic
    if ic is None or math.isnan(ic):
        return float("-inf")
    return ic


# Nothing in this notebook sorts `results` (only the df_results table is
# sorted), so pick the best config explicitly instead of assuming results[0].
best_model_type = max(results, key=_mean_ic_or_worst).config.model_type
print("Running Optuna HP search for: %s" % best_model_type)

searcher = HyperparameterSearcher(
    model_type=best_model_type,
    panel=panel,
    target_col=TARGET_COL,
    feature_cols=feature_cols,
    outer_config=wf_config,
    n_trials=20,
    inner_folds_count=3,
)

search_results = searcher.search()

print("\nBest params per fold:")
for fold_idx, params in search_results.get("best_params_per_fold", {}).items():
    print("  Fold %s: %s" % (fold_idx, params))

# The evaluation summary is optional in the returned mapping.
eval_result = search_results.get("evaluation")
if eval_result:
    print("\nEvaluation after HP search:")
    print("  IC: %.4f" % eval_result.mean_ic)
    print("  ICIR: %.2f" % eval_result.icir)
    print("  Sharpe: %.2f" % eval_result.mean_sharpe)

## 9. Export Production Artifacts

Export the best model from section 6 as production artifacts for the FastAPI backend.

In [None]:
import json
import math
import pickle
import numpy as np
from datetime import datetime
from training.models import create_model
from training.model_selection import compute_prediction_metrics

if not results:
    raise RuntimeError("No training results available -- run section 5 first.")


def _mean_ic_or_worst(result):
    """Sort key: the result's mean IC, with a missing/NaN IC ranked last."""
    ic = result.evaluation.mean_ic
    if ic is None or math.isnan(ic):
        return float("-inf")
    return ic


# Nothing in this notebook sorts `results` (only the df_results table is
# sorted), so select the highest-IC result explicitly rather than results[0].
best_result = max(results, key=_mean_ic_or_worst)
best_config = best_result.config
print("Best model: %s (IC=%.4f, Sharpe=%.2f)" % (
    best_config.model_type, best_result.evaluation.mean_ic, best_result.evaluation.mean_sharpe))

# Retrain best model on full data for production
print("\nRetraining %s on full dataset for production..." % best_config.model_type)

X = panel[feature_cols].values.astype(np.float32)
y = panel[TARGET_COL].values.astype(np.float32)
# Replace NaN/inf in place with 0.0 so the arrays are finite.
# NOTE(review): this also turns missing targets into 0.0 -- confirm that is
# intended rather than dropping those rows.
np.nan_to_num(X, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
np.nan_to_num(y, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

# Positional split: first 85% train, next 10% validation, last 5% test.
# NOTE(review): this is only a chronological hold-out if the panel rows are
# ordered by date -- confirm the panel's row order.
split = int(len(X) * 0.85)
val_split = int(len(X) * 0.95)

model = create_model(best_config.model_type, best_config.to_dict())
model.fit(X[:split], y[:split], X[split:val_split], y[split:val_split], feature_names=feature_cols)

# Evaluate on held-out test set, ignoring rows where the model returned NaN.
test_preds = model.predict(X[val_split:])
valid_mask = ~np.isnan(test_preds)
test_metrics = compute_prediction_metrics(y[val_split:][valid_mask], test_preds[valid_mask])
print("Production model test IC: %.4f, Hit ratio: %.4f" % (test_metrics.ic, test_metrics.hit_ratio))

In [None]:
# Save all artifacts to Drive and repo
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# 1. Model
model_path = os.path.join(ARTIFACT_DIR, "model.pkl")
with open(model_path, "wb") as f:
    pickle.dump(model, f)
print("Model saved: %s" % model_path)

if hasattr(model, 'net'):
    import torch
    model.net.eval()
    torch.save(model.net.state_dict(), os.path.join(ARTIFACT_DIR, "model.pt"))
    print("State dict saved")

# 2. Feature scaler
if hasattr(model, 'scaler'):
    with open(os.path.join(ARTIFACT_DIR, "feature_scaler.pkl"), "wb") as f:
        pickle.dump(model.scaler, f)
    print("Scaler saved")

# 3. Config
config_data = {
    "model_type": best_config.model_type,
    "horizons": ["1M", "3M", "6M"],
    "horizon_days": [21, 63, 126],
    "n_features": len(feature_cols),
    "n_tickers": len(valid_tickers),
}
config_data.update(best_config.extra_params)
with open(os.path.join(ARTIFACT_DIR, "config.json"), "w") as f:
    json.dump(config_data, f, indent=2)

# 4. Feature columns
with open(os.path.join(ARTIFACT_DIR, "feature_columns.json"), "w") as f:
    json.dump(feature_cols, f)

# 5. Ticker list
with open(os.path.join(ARTIFACT_DIR, "ticker_list.json"), "w") as f:
    json.dump(valid_tickers, f)

# 6. Training metadata
metadata = {
    "version": "%s_v%s" % (best_config.model_type, datetime.now().strftime("%Y%m%d_%H%M%S")),
    "trained_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "model_type": best_config.model_type,
    "n_tickers": len(valid_tickers),
    "n_features": len(feature_cols),
    "n_samples": len(X),
    "train_size": split,
    "test_ic": float(test_metrics.ic),
    "test_hit_ratio": float(test_metrics.hit_ratio),
    "walkforward_ic": float(best_result.evaluation.mean_ic),
    "walkforward_sharpe": float(best_result.evaluation.mean_sharpe),
    "tickers": valid_tickers,
}
with open(os.path.join(ARTIFACT_DIR, "training_metadata.json"), "w") as f:
    json.dump(metadata, f, indent=2)

print("\nAll artifacts exported to: %s" % ARTIFACT_DIR)
print("Version: %s" % metadata["version"])

# Copy to repo for API serving
LOCAL_ARTIFACTS = "/content/AI-stock-investment-tool/artifacts"
os.makedirs(LOCAL_ARTIFACTS, exist_ok=True)
!cp -r {ARTIFACT_DIR}/* {LOCAL_ARTIFACTS}/
print("Copied to repo artifacts/ for API serving")

## 10. Serve API via ngrok (Connect to Frontend)

Start the FastAPI backend on Colab and expose it via ngrok so your local Next.js frontend can call it.

**Prerequisites:** You need a free [ngrok auth token](https://dashboard.ngrok.com/get-started/your-authtoken). Free tier gives 1 tunnel.

In [None]:
import subprocess
import time
from pyngrok import ngrok

# --- Configure ngrok ---
# Get your free auth token at: https://dashboard.ngrok.com/get-started/your-authtoken
NGROK_AUTH_TOKEN = ""  # <-- Paste your token here

if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
else:
    print("WARNING: No ngrok auth token set. Tunnel may not work.")
    print("Get one at: https://dashboard.ngrok.com/get-started/your-authtoken")

# Make sure artifacts are in the repo dir
LOCAL_ARTIFACTS = "/content/AI-stock-investment-tool/artifacts"
assert os.path.exists(os.path.join(LOCAL_ARTIFACTS, "model.pkl")), \
    "No model.pkl found! Run sections 9-10 first to train and export."

# Install uvicorn if not present
!pip install -q uvicorn fastapi pydantic python-multipart

# Start FastAPI in background
os.chdir("/content/AI-stock-investment-tool")
server_proc = subprocess.Popen(
    ["python", "-m", "uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
time.sleep(3)

# Check if server started successfully
if server_proc.poll() is not None:
    print("ERROR: Server failed to start!")
    print(server_proc.stderr.read().decode())
else:
    print("FastAPI server started (PID: %d)" % server_proc.pid)

    # Open ngrok tunnel
    public_url = ngrok.connect(8000)
    print("\n" + "=" * 60)
    print("PUBLIC API URL: %s" % public_url)
    print("=" * 60)
    print("\nAPI Docs:  %s/docs" % public_url)
    print("Health:    %s/api/v1/health" % public_url)
    print("Predict:   %s/api/v1/predict" % public_url)
    print("\n--- To connect your local frontend ---")
    print("Option A (PowerShell): set env var before npm run dev:")
    print('  $env:NEXT_PUBLIC_API_URL="%s"' % public_url)
    print("  cd frontend; npm run dev")
    print("\nOption B: Create frontend/.env.local with:")
    print("  NEXT_PUBLIC_API_URL=%s" % public_url)
    print("\nKeep this cell running! The tunnel closes when the runtime stops.")

In [None]:
# Quick test: verify API is responding
import json
import urllib.request

test_url = "http://localhost:8000/api/v1/health"
try:
    # Context manager so the HTTP response is always closed.
    with urllib.request.urlopen(test_url, timeout=5) as resp:
        data = json.loads(resp.read().decode())
    print("Health check OK:")
    for k, v in data.items():
        print("  %s: %s" % (k, v))
except Exception as e:
    print("Health check failed: %s" % e)
    print("Check server logs:")
    if server_proc.poll() is None:
        # Reading stderr of a live process would block, so only give a hint.
        print("  Server process is still running -- it may still be starting; retry in a few seconds.")
    elif server_proc.stderr is not None:
        # Process has exited: show the tail of what it wrote to stderr.
        print(server_proc.stderr.read().decode()[-500:])
    else:
        print("  Server exited with code %s (no stderr captured)." % server_proc.returncode)

## 11. Download Artifacts to Local Machine

If you prefer to run the API locally instead of via ngrok, download the trained artifacts.

In [None]:
# Zip artifacts and download to your local machine
!cd {ARTIFACT_DIR} && zip -r /content/artifacts.zip .

print("Artifact contents:")
!ls -lh {ARTIFACT_DIR}

print("\nTotal zip size:")
!ls -lh /content/artifacts.zip

# Download via browser
from google.colab import files
files.download("/content/artifacts.zip")

print("\nAfter downloading, on your local machine:")
print("  1. Unzip into your project root:")
print("     unzip artifacts.zip -d artifacts/")
print("  2. Start the API:")
print("     python -m api.main")
print("  3. Start the frontend:")
print("     cd frontend && npm run dev")

## 12. Cleanup

Stop the API server and ngrok tunnel when done.

In [None]:
# Stop server and tunnel
import subprocess

# Both steps are best-effort: a problem is reported but never aborts the cell.
try:
    try:
        # public_url may be a plain URL string or a tunnel object carrying
        # the URL in .public_url; disconnect() is given the string.
        ngrok.disconnect(getattr(public_url, "public_url", public_url))
    finally:
        # Always stop the ngrok process, even if the disconnect call failed.
        ngrok.kill()
    print("ngrok tunnel closed")
except NameError:
    # The serving cell was never run (or failed before opening a tunnel).
    print("No ngrok tunnel to close")
except Exception as e:
    print("WARNING: problem while closing ngrok tunnel: %s" % e)

try:
    server_proc.terminate()
    try:
        server_proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # The server ignored the polite request; force it down.
        server_proc.kill()
        server_proc.wait()
    print("FastAPI server stopped")
except NameError:
    print("No FastAPI server to stop")
except Exception as e:
    print("WARNING: problem while stopping FastAPI server: %s" % e)