<a href="https://colab.research.google.com/github/josetraderx/mean_reversion_OU/blob/main/mean_reversion_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mean Reversion Trading Strategy Demo 📈\n",
    "\n",
    "## Interactive demonstration of the Ornstein-Uhlenbeck mean reversion strategy\n",
    "\n",
    "This notebook walks you through:\n",
    "1. **Data Collection** - Fetching real market data\n",
    "2. **Parameter Estimation** - Using Maximum Likelihood Estimation\n",
    "3. **Signal Generation** - Creating buy/sell signals\n",
    "4. **Backtesting** - Evaluating strategy performance\n",
    "5. **Visualization** - Interactive plots and analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🚀 Setup and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages if running in Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    !pip install yfinance\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import yfinance as yf\n",
    "from scipy.optimize import minimize\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Set style for better plots\n",
    "plt.style.use('seaborn-v0_8')\n",
    "plt.rcParams['figure.figsize'] = (12, 8)\n",
    "\n",
    "print(\"✅ All imports successful!\")\n",
    "print(\"🚀 Ready to start the demo!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Step 1: Data Collection\n",
    "\n",
    "Let's fetch real market data for a stock. You can change the ticker and date range below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "TICKER = \"AAPL\"  # Change this to any ticker you want\n",
    "START_DATE = \"2022-01-01\"\n",
    "END_DATE = \"2024-01-01\"\n",
    "\n",
    "print(f\"📈 Fetching data for {TICKER} from {START_DATE} to {END_DATE}...\")\n",
    "\n",
    "# Fetch data\n",
    "data = yf.download(TICKER, start=START_DATE, end=END_DATE)\n",
    "prices = data['Close'].values\n",
    "\n",
    "print(f\"✅ Successfully fetched {len(prices)} data points\")\n",
    "print(f\"📊 Price range: ${prices.min():.2f} - ${prices.max():.2f}\")\n",
    "\n",
    "# Plot the raw data\n",
    "plt.figure(figsize=(14, 6))\n",
    "plt.plot(data.index, prices, linewidth=1.5, alpha=0.8)\n",
    "plt.title(f'{TICKER} Stock Price Over Time', fontsize=16, fontweight='bold')\n",
    "plt.xlabel('Date', fontsize=12)\n",
    "plt.ylabel('Price ($)', fontsize=12)\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔢 Step 2: Parameter Estimation Using Maximum Likelihood\n",
    "\n",
    "Now we'll estimate the Ornstein-Uhlenbeck process parameters:\n",
    "- **μ (mu)**: Long-term mean\n",
    "- **θ (theta)**: Speed of mean reversion\n",
    "- **σ (sigma)**: Volatility"
   ]
  },
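  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for the estimation cell below, here is a sketch of the model it fits. The notation is introduced here for illustration: $X_i$ is the $i$-th closing price, $W_t$ a standard Brownian motion, and $\\Delta t = 1/252$ one trading day in years.\n",
    "\n",
    "The OU process is\n",
    "\n",
    "$$dX_t = \\theta\\,(\\mu - X_t)\\,dt + \\sigma\\,dW_t.$$\n",
    "\n",
    "An Euler discretization over one step gives approximately Gaussian increments:\n",
    "\n",
    "$$X_{i+1} - X_i \\approx \\theta\\,(\\mu - X_i)\\,\\Delta t + \\varepsilon_i, \\qquad \\varepsilon_i \\sim \\mathcal{N}(0,\\ \\sigma^2 \\Delta t).$$\n",
    "\n",
    "With $N$ increments, the log-likelihood under this approximation is\n",
    "\n",
    "$$\\ell(\\mu, \\theta, \\sigma) = -\\frac{N}{2}\\,\\log\\!\\left(2\\pi\\sigma^2 \\Delta t\\right) - \\frac{1}{2\\sigma^2 \\Delta t}\\sum_{i=0}^{N-1}\\bigl(X_{i+1} - X_i - \\theta\\,(\\mu - X_i)\\,\\Delta t\\bigr)^2.$$\n",
    "\n",
    "The code minimizes $-\\ell$. This is the likelihood of the discretized process, which approximates the exact OU transition density when $\\theta\\,\\Delta t$ is small."
   ]
  },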
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_ou_parameters(prices):\n",
    "    \"\"\"\n",
    "    Estimate Ornstein-Uhlenbeck parameters using Maximum Likelihood Estimation\n",
    "    \"\"\"\n",
    "    n = len(prices)\n",
    "    dt = 1 / 252  # Daily data, 252 trading days per year\n",
    "    \n",
    "    def ou_log_likelihood(params):\n",
    "        mu, theta, sigma = params\n",
    "        \n",
    "        if theta <= 0 or sigma <= 0:\n",
    "            return 1e6\n",
    "            \n",
    "        X_diff = np.diff(prices)\n",
    "        X_lag = prices[:-1]\n",
    "        \n",
    "        # Expected change according to OU process\n",
    "        expected_diff = theta * (mu - X_lag) * dt\n",
    "        \n",
    "        # Residuals\n",
    "        residuals = X_diff - expected_diff\n",
    "        \n",
    "        # Log likelihood\n",
    "        variance = sigma**2 * dt\n",
    "        log_likelihood = -0.5 * len(residuals) * np.log(2 * np.pi * variance) - \\\n",
    "                        0.5 * np.sum(residuals**2) / variance\n",
    "                        \n",
    "        return -log_likelihood  # Return negative for minimization\n",
    "    \n",
    "    # Initial parameter guess\n",
    "    mu_init = np.mean(prices)\n",
    "    theta_init = 0.1\n",
    "    sigma_init = np.std(np.diff(prices)) / np.sqrt(dt)\n",
    "    \n",
    "    initial_params = [mu_init, theta_init, sigma_init]\n",
    "    \n",
    "    # Optimize\n",
    "    result = minimize(ou_log_likelihood, initial_params, \n",
    "                     method='L-BFGS-B', \n",
    "                     bounds=[(None, None), (1e-6, None), (1e-6, None)])\n",
    "    \n",
    "    return result.x\n",
    "\n",
    "# Estimate parameters\n",
    "print(\"🔍 Estimating Ornstein-Uhlenbeck parameters...\")\n",
    "mu_est, theta_est, sigma_est = estimate_ou_parameters(prices)\n",
    "\n",
    "print(f\"\\n📊 Estimated Parameters:\")\n",
    "print(f\"   μ (Long-term mean): ${mu_est:.2f}\")\n",
    "print(f\"   θ (Mean reversion speed): {theta_est:.4f}\")\n",
    "print(f\"   σ (Volatility): ${sigma_est:.2f}\")\n",
    "print(f\"\\n💡 Interpretation:\")\n",
    "print(f\"   - The stock tends to revert to ${mu_est:.2f}\")\n",
    "print(f\"   - Half-life of mean reversion: {np.log(2)/theta_est:.1f} days\")\n",
    "print(f\"   - Daily volatility: ${sigma_est * np.sqrt(1/252):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📈 Step 3: Generate Trading Signals\n",
    "\n",
    "Based on the estimated parameters, we'll generate buy and sell signals:"
   ]
  },
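  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the rule used in the next cell is the following. The symbols are introduced here for illustration: $k$ stands for `threshold_multiplier`, $P_t$ for the closing price, and $\\Delta t = 1/252$.\n",
    "\n",
    "$$\\tau = k\\,\\sigma\\sqrt{\\Delta t}, \\qquad \\text{buy if } P_t < \\mu - \\tau, \\qquad \\text{sell if } P_t > \\mu + \\tau.$$\n",
    "\n",
    "Here $\\sigma\\sqrt{\\Delta t}$ is the standard deviation of a one-day move, so the band is narrow relative to the long-run spread of an OU process, whose stationary standard deviation is $\\sigma / \\sqrt{2\\theta}$. Sizing the band from that quantity instead would be one possible variation, not something this notebook does."
   ]
  },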
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_signals(prices, mu, sigma, threshold_multiplier=1.5):\n",
    "    \"\"\"\n",
    "    Generate buy and sell signals based on deviation from mean\n",
    "    \"\"\"\n",
    "    threshold = threshold_multiplier * sigma * np.sqrt(1/252)  # Daily threshold\n",
    "    \n",
    "    buy_signals = prices < (mu - threshold)\n",
    "    sell_signals = prices > (mu + threshold)\n",
    "    \n",
    "    return buy_signals, sell_signals, threshold\n",
    "\n",
    "# Generate signals\n",
    "buy_signals, sell_signals, threshold = generate_signals(prices, mu_est, sigma_est)\n",
    "\n",
    "print(f\"📊 Signal Summary:\")\n",
    "print(f\"   Buy signals: {np.sum(buy_signals)} ({np.sum(buy_signals)/len(prices)*100:.1f}% of days)\")\n",
    "print(f\"   Sell signals: {np.sum(sell_signals)} ({np.sum(sell_signals)/len(prices)*100:.1f}% of days)\")\n",
    "print(f\"   Threshold: ${threshold:.2f}\")\n",
    "\n",
    "# Plot signals\n",
    "plt.figure(figsize=(14, 8))\n",
    "plt.plot(data.index, prices, label='Price', linewidth=1.5, alpha=0.8)\n",
    "plt.axhline(mu_est, color='black', linestyle='--', label=f'Mean (${mu_est:.2f})', alpha=0.7)\n",
    "plt.axhline(mu_est + threshold, color='red', linestyle=':', label=f'Sell Threshold', alpha=0.7)\n",
    "plt.axhline(mu_est - threshold, color='green', linestyle=':', label=f'Buy Threshold', alpha=0.7)\n",
    "\n",
    "# Plot signals\n",
    "buy_dates = data.index[buy_signals]\n",
    "sell_dates = data.index[sell_signals]\n",
    "\n",
    "if len(buy_dates) > 0:\n",
    "    plt.scatter(buy_dates, prices[buy_signals], color='green', marker='^', \n",
    "               s=50, label=f'Buy Signals ({len(buy_dates)})', zorder=5)\n",
    "               \n",
    "if len(sell_dates) > 0:\n",
    "    plt.scatter(sell_dates, prices[sell_signals], color='red', marker='v', \n",
    "               s=50, label=f'Sell Signals ({len(sell_dates)})', zorder=5)\n",
    "\n",
    "plt.title(f'{TICKER} - Mean Reversion Trading Signals', fontsize=16, fontweight='bold')\n",
    "plt.xlabel('Date', fontsize=12)\n",
    "plt.ylabel('Price ($)', fontsize=12)\n",
    "plt.legend(loc='best')\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🎯 Step 4: Backtest the Strategy\n",
    "\n",
    "Now let's see how our strategy would have performed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def backtest_strategy(prices, buy_signals, sell_signals, initial_capital=10000):\n",
    "    \"\"\"\n",
    "    Backtest the mean reversion strategy\n",
    "    \"\"\"\n",
    "    portfolio_value = np.zeros(len(prices))\n",
    "    cash = initial_capital\n",
    "    position = 0  # 0: no position, 1: long, -1: short\n",
    "    shares = 0\n",
    "    trades = []\n",
    "    \n",
    "    portfolio_value[0] = initial_capital\n",
    "    \n",
    "    for i in range(1, len(prices)):\n",
    "        current_price = prices[i]\n",
    "        \n",
    "        # Check for entry signals\n",
    "        if position == 0:\n",
    "            if buy_signals[i]:  # Enter long position\n",
    "                shares = cash / current_price\n",
    "                position = 1\n",
    "                trades.append(('BUY', data.index[i], current_price, shares))\n",
    "                cash = 0\n",
    "            elif sell_signals[i]:  # Enter short position (simplified)\n",
    "                position = -1\n",
    "                trades.append(('SELL_SHORT', data.index[i], current_price, cash/current_price))\n",
    "        \n",
    "        # Check for exit signals\n",
    "        elif position == 1 and sell_signals[i]:  # Exit long\n",
    "            cash = shares * current_price\n",
    "            trades.append(('SELL', data.index[i], current_price, shares))\n",
    "            shares = 0\n",
    "            position = 0\n",
    "            \n",
    "        elif position == -1 and buy_signals[i]:  # Exit short\n",
    "            position = 0\n",
    "            trades.append(('COVER', data.index[i], current_price, 0))\n",
    "        \n",
    "        # Calculate portfolio value\n",
    "        if position == 1:\n",
    "            portfolio_value[i] = shares * current_price\n",
    "        else:\n",
    "            portfolio_value[i] = cash\n",
    "    \n",
    "    return portfolio_value, trades\n",
    "\n",
    "# Run backtest\n",
    "print(\"🔄 Running backtest...\")\n",
    "portfolio_values, trades = backtest_strategy(prices, buy_signals, sell_signals)\n",
    "\n",
    "# Calculate performance metrics\n",
    "initial_capital = 10000\n",
    "final_value = portfolio_values[-1]\n",
    "total_return = (final_value - initial_capital) / initial_capital * 100\n",
    "buy_hold_return = (prices[-1] - prices[0]) / prices[0] * 100\n",
    "\n",
    "print(f\"\\n📊 Backtest Results:\")\n",
    "print(f\"   Initial Capital: ${initial_capital:,.2f}\")\n",
    "print(f\"   Final Portfolio Value: ${final_value:,.2f}\")\n",
    "print(f\"   Total Return: {total_return:.2f}%\")\n",
    "print(f\"   Buy & Hold Return: {buy_hold_return:.2f}%\")\n",
    "print(f\"   Strategy vs Buy & Hold: {total_return - buy_hold_return:+.2f}%\")\n",
    "print(f\"   Total Trades: {len(trades)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Step 5: Visualization and Analysis\n",
    "\n",
    "Let's create comprehensive visualizations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create comprehensive performance visualization\n",
    "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))\n",
    "\n",
    "# 1. Price and signals\n",
    "ax1.plot(data.index, prices, label='Price', alpha=0.8)\n",
    "ax1.axhline(mu_est, color='black', linestyle='--', alpha=0.7, label='Mean')\n",
    "if np.sum(buy_signals) > 0:\n",
    "    ax1.scatter(data.index[buy_signals], prices[buy_signals], \n",
    "               color='green', marker='^', s=30, label='Buy')\n",
    "if np.sum(sell_signals) > 0:\n",
    "    ax1.scatter(data.index[sell_signals], prices[sell_signals], \n",
    "               color='red', marker='v', s=30, label='Sell')\n",
    "ax1.set_title('Price and Trading Signals', fontweight='bold')\n",
    "ax1.set_ylabel('Price ($)')\n",
    "ax1.legend()\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# 2. Portfolio performance\n",
    "buy_hold_portfolio = initial_capital * (prices / prices[0])\n",
    "ax2.plot(data.index, portfolio_values, label='Strategy', linewidth=2)\n",
    "ax2.plot(data.index, buy_hold_portfolio, label='Buy & Hold', alpha=0.7)\n",
    "ax2.set_title('Portfolio Performance Comparison', fontweight='bold')\n",
    "ax2.set_ylabel('Portfolio Value ($)')\n",
    "ax2.legend()\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "# 3. Price distribution\n",
    "ax3.hist(prices, bins=50, alpha=0.7, density=True, color='skyblue')\n",
    "ax3.axvline(mu_est, color='red', linestyle='--', linewidth=2, label=f'Estimated Mean: ${mu_est:.2f}')\n",
    "ax3.axvline(np.mean(prices), color='green', linestyle=':', linewidth=2, label=f'Sample Mean: ${np.mean(prices):.2f}')\n",
    "ax3.set_title('Price Distribution', fontweight='bold')\n",
    "ax3.set_xlabel('Price ($)')\n",
    "ax3.set_ylabel('Density')\n",
    "ax3.legend()\n",
    "\n",
    "# 4. Rolling returns\n",
    "strategy_returns = np.diff(portfolio_values) / portfolio_values[:-1] * 100\n",
    "buy_hold_returns = np.diff(buy_hold_portfolio) / buy_hold_portfolio[:-1] * 100\n",
    "\n",
    "window = 30  # 30-day rolling window\n",
    "strategy_rolling = pd.Series(strategy_returns).rolling(window).mean()\n",
    "buy_hold_rolling = pd.Series(buy_hold_returns).rolling(window).mean()\n",
    "\n",
    "ax4.plot(data.index[1:], strategy_rolling, label='Strategy (30-day avg)', linewidth=2)\n",
    "ax4.plot(data.index[1:], buy_hold_rolling, label='Buy & Hold (30-day avg)', alpha=0.7)\n",
    "ax4.axhline(0, color='black', linestyle='-', alpha=0.3)\n",
    "ax4.set_title('Rolling Average Daily Returns', fontweight='bold')\n",
    "ax4.set_xlabel('Date')\n",
    "ax4.set_ylabel('Daily Return (%)')\n",
    "ax4.legend()\n",
    "ax4.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📋 Summary and Next Steps\n",
    "\n",
    "Let's create a comprehensive summary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate additional metrics\n",
    "strategy_returns = np.diff(portfolio_values) / portfolio_values[:-1]\n",
    "strategy_volatility = np.std(strategy_returns) * np.sqrt(252)  # Annualized\n",
    "strategy_sharpe = (total_return/100) / (strategy_volatility) if strategy_volatility > 0 else 0\n",
    "\n",
    "buy_hold_returns_calc = np.diff(buy_hold_portfolio) / buy_hold_portfolio[:-1]\n",
    "buy_hold_volatility = np.std(buy_hold_returns_calc) * np.sqrt(252)\n",
    "buy_hold_sharpe = (buy_hold_return/100) / (buy_hold_volatility) if buy_hold_volatility > 0 else 0\n",
    "\n",
    "print(\"=\"*60)\n",
    "print(f\"📊 COMPREHENSIVE STRATEGY ANALYSIS - {TICKER}\")\n",
    "print(\"=\"*60)\n",
    "print(f\"📅 Period: {START_DATE} to {END_DATE}\")\n",
    "print(f\"📈 Trading Days: {len(prices)}\")\n",
    "print()\n",
    "print(\"🔢 ORNSTEIN-UHLENBECK PARAMETERS:\")\n",
    "print(f\"   μ (Long-term mean): ${mu_est:.2f}\")\n",
    "print(f\"   θ (Mean reversion speed): {theta_est:.4f}\")\n",
    "print(f\"   σ (Volatility): ${sigma_est:.2f}\")\n",
    "print(f\"   Half-life: {np.log(2)/theta_est:.1f} days\")\n",
    "print()\n",
    "print(\"📊 TRADING ACTIVITY:\")\n",
    "print(f\"   Buy signals: {np.sum(buy_signals)} ({np.sum(buy_signals)/len(prices)*100:.1f}% of days)\")\n",
    "print(f\"   Sell signals: {np.sum(sell_signals)} ({np.sum(sell_signals)/len(prices)*100:.1f}% of days)\")\n",
    "print(f\"   Total trades: {len(trades)}\")\n",
    "print()\n",
    "print(\"💰 PERFORMANCE COMPARISON:\")\n",
    "print(f\"   Strategy Return: {total_return:+.2f}%\")\n",
    "print(f\"   Buy & Hold Return: {buy_hold_return:+.2f}%\")\n",
    "print(f\"   Excess Return: {total_return - buy_hold_return:+.2f}%\")\n",
    "print()\n",
    "print(\"⚡ RISK METRICS:\")\n",
    "print(f\"   Strategy Volatility: {strategy_volatility:.2f}\")\n",
    "print(f\"   Buy & Hold Volatility: {buy_hold_volatility:.2f}\")\n",
    "print(f\"   Strategy Sharpe Ratio: {strategy_sharpe:.2f}\")\n",
    "print(f\"   Buy & Hold Sharpe Ratio: {buy_hold_sharpe:.2f}\")\n",
    "print()\n",
    "print(\"✅ STRATEGY ASSESSMENT:\")\n",
    "if total_return > buy_hold_return:\n",
    "    print(\"   🎯 Strategy OUTPERFORMED buy & hold\")\n",
    "else:\n",
    "    print(\"   📉 Strategy UNDERPERFORMED buy & hold\")\n",
    "    \n",
    "if strategy_sharpe > buy_hold_sharpe:\n",
    "    print(\"   🏆 Strategy has BETTER risk-adjusted returns\")\n",
    "else:\n",
    "    print(\"   ⚠️  Strategy has WORSE risk-adjusted returns\")\n",
    "    \n",
    "print(\"=\"*60)\n",
    "print(\"⚠️  DISCLAIMER: This is for educational purposes only.\")\n",
    "print(\"   Past performance does not guarantee future results.\")\n",
    "print(\"=\"*60)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🎮 Try Different Parameters!\n",
    "\n",
    "Want to experiment? Go back and change:\n",
    "- **TICKER**: Try different stocks (MSFT, GOOGL, TSLA, etc.)\n",
    "- **Date ranges**: Different time periods\n",
    "- **threshold_multiplier**: More or less sensitive signals\n",
    "\n",
    "## 🚀 Next Steps\n",
    "\n",
    "1. **Improve the strategy**: Add transaction costs, better risk management\n",
    "2. **Test on more assets**: Portfolio of mean-reverting stocks\n",
    "3. **Real-time implementation**: Connect to live data feeds\n",
    "4. **Machine learning**: Use ML to optimize parameters\n",
    "\n",
    "## 📚 Learn More\n",
    "\n",
    "- [Ornstein-Uhlenbeck Process Theory](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process)\n",
    "- [Mean Reversion in Finance](https://www.investopedia.com/terms/m/meanreversion.asp)\n",
    "- [Quantitative Trading Strategies](https://www.quantstart.com/)\n",
    "\n",
    "---\n",
    "\n",
    "**🎉 Congratulations! You've successfully implemented and backtested a mean reversion trading strategy!**"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
