In [None]:
from alpaca.data.historical import StockHistoricalDataClient
from alpaca.data.requests import StockBarsRequest
from alpaca.data.timeframe import TimeFrame
from dotenv import load_dotenv
import os
import pandas as pd


In [None]:
# Pull the Alpaca credentials out of the project's .env file into the process environment.
load_dotenv(r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\venv\.env")  # or just ".env" depending on your setup
API_KEY, API_SECRET = (os.getenv(name) for name in ("ALPACA_API_KEY", "ALPACA_SECRET"))


In [None]:
# REST client for historical US-stock market data, authenticated with the keys loaded above.
client = StockHistoricalDataClient(
    API_KEY,
    API_SECRET,
)

In [None]:
from datetime import datetime, timedelta

def get_price_data(symbol: str, timeframe: TimeFrame = TimeFrame.Day, limit: int = 30) -> pd.DataFrame:
    """
    Fetch the most recent historical OHLCV bars for a symbol using alpaca-py.

    :param symbol: Ticker symbol, e.g. "AAPL"
    :param timeframe: TimeFrame object, e.g. TimeFrame.Minute, TimeFrame.Day
    :param limit: Maximum number of (most recent) bars to return
    :return: pandas DataFrame with OHLCV data indexed by America/New_York timestamps;
             an empty frame with OHLCV columns when Alpaca returns no bars
    """
    end = datetime.now()
    # Calendar-day lookback, doubled to leave room for weekends/holidays.
    # NOTE(review): the window is sized in days regardless of `timeframe`, so intraday
    # timeframes download far more bars than needed before the tail() below.
    start = end - timedelta(days=limit * 2)

    request = StockBarsRequest(
        symbol_or_symbols=[symbol],
        timeframe=timeframe,
        start=start,
        end=end
    )

    bars = client.get_stock_bars(request).df
    if bars.empty:
        # An empty BarSet frame has no (symbol, timestamp) index, so the level lookup
        # below would raise KeyError.
        return pd.DataFrame(columns=["open", "high", "low", "close", "volume"])

    symbol_df = bars.xs(symbol, level="symbol").copy()
    symbol_df.index = symbol_df.index.tz_convert("America/New_York")

    # The lookback window usually holds more than `limit` bars; honour the limit.
    return symbol_df.tail(limit)

In [None]:
# Fetch recent daily AAPL bars (limit=30) and preview the OHLCV columns.
df = get_price_data("AAPL", limit=30, timeframe=TimeFrame.Day)
df.loc[:, ["open", "high", "low", "close", "volume"]].head()

In [None]:
from alpaca.data.requests import StockLatestQuoteRequest, StockLatestTradeRequest

# Snapshot of the most recent quote for AAPL...
quote = client.get_stock_latest_quote(
    StockLatestQuoteRequest(symbol_or_symbols=["AAPL"])
)
print(quote)

# ...and of the most recent trade.
trade = client.get_stock_latest_trade(
    StockLatestTradeRequest(symbol_or_symbols=["AAPL"])
)
print(trade)

In [None]:
from alpaca.trading.client import TradingClient

# Paper-trading client; show the account's buying power and status, one per line.
trading_client = TradingClient(API_KEY, API_SECRET, paper=True)
account = trading_client.get_account()
print(account.buying_power, account.status, sep="\n")

In [None]:
# One line per open position: size and average entry price.
positions = trading_client.get_all_positions()
for pos in positions:
    print(f"{pos.symbol}: {pos.qty} shares @ avg price {pos.avg_entry_price}")

In [None]:
# NOTE(review): with no filter, Alpaca's orders endpoint defaults to *open* orders only;
# pass GetOrdersRequest(status=QueryOrderStatus.ALL) to include closed ones as well.
orders = trading_client.get_orders()
for order_item in orders:
    print(f"{order_item.symbol} | {order_item.side} | {order_item.qty} | {order_item.status}")

In [None]:
from alpaca.trading.requests import MarketOrderRequest
from alpaca.trading.enums import OrderSide, TimeInForce

# Buy 1 share of AAPL at market, valid for the current day only.
# NOTE(review): every run of this cell submits a new order to the account.
order = MarketOrderRequest(symbol="AAPL", qty=1, side=OrderSide.BUY, time_in_force=TimeInForce.DAY)
response = trading_client.submit_order(order)
print(response)


In [None]:
from datetime import date, timedelta

from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetCalendarRequest

# Without a date range the calendar endpoint returns its entire history (it starts
# in 1970), so `calendar[:5]` printed decades-old sessions. Ask for the next two weeks.
calendar_request = GetCalendarRequest(start=date.today(), end=date.today() + timedelta(days=14))
calendar = trading_client.get_calendar(calendar_request)
for day in calendar[:5]:
    print(f"{day.date} | Open: {day.open} | Close: {day.close}")

In [None]:
import asyncio
import nest_asyncio
from alpaca.data.live import StockDataStream

stream = StockDataStream(API_KEY, API_SECRET)


async def handle_quote(data):
    """Print each quote update pushed by the stream (handlers must be async)."""
    print("Quote:", data)


# Handlers are registered with subscribe_quotes, not with a decorator.
stream.subscribe_quotes(handle_quote, "SPY")

# StockDataStream.run() is a blocking *synchronous* call that starts its own event
# loop via asyncio.run(). It must be called, not awaited: `await stream.run()`
# awaits its None return value and raises TypeError once the stream stops.
# nest_asyncio makes that nested asyncio.run() work inside Jupyter's running loop.
nest_asyncio.apply()
stream.run()

In [None]:
# this is a working single stock prototype
import sys
from pathlib import Path

# Make the project root (parent of the notebook's working directory) importable,
# so the local `strategies`, `core` and `indicators` packages resolve.
project_root = Path.cwd().parent
if all(entry != str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

import os
import asyncio
import logging
import pandas as pd
from datetime import datetime, date
from collections import deque
from dotenv import load_dotenv
import requests
import json
from pytz import timezone

from alpaca.trading.client import TradingClient
from alpaca.trading.requests import MarketOrderRequest
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.data.live import StockDataStream

from strategies.strategy_registry.momentum_strategy import MomentumStrategy
from core.position_sizer import DynamicPositionSizer
from indicators.atr import ATRIndicator

load_dotenv(r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\venv\.env")

# Credentials and the Discord webhook come from the .env file loaded above.
API_KEY, API_SECRET, DISCORD_WEBHOOK = (
    os.getenv(var) for var in ("ALPACA_API_KEY", "ALPACA_SECRET", "DISCORD_WEBHOOK")
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("LiveRunner")

# Paper-trading REST client plus the live market-data websocket.
trading_client = TradingClient(API_KEY, API_SECRET, paper=True)
stream = StockDataStream(API_KEY, API_SECRET)

# Strategy setup for the single traded symbol.
symbol = "AAPL"
daily_buffer = deque(maxlen=100)   # rolling window of finalized daily bars
current_day_bar = None             # today's bar, aggregated from minute updates
strategy = MomentumStrategy()
sizer = DynamicPositionSizer(risk_percentage=0.07)

# Run-time bookkeeping shared with the handlers below.
portfolio_history = []
total_fees = 0
peak_portfolio_value = 0
last_trade_date = None
market_closed = False

# Seed the rolling daily window with up to the last 100 preprocessed bars on disk.
processed_path = Path(rf"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\data\data_storage\proc_data\proc_{symbol}_file.json")
if not processed_path.exists():
    logger.warning(f"Processed data not found for {symbol}.")
else:
    with open(processed_path, "r") as f:
        processed_data = json.load(f)
    for item in processed_data[-100:]:
        daily_buffer.append({
            # Dates are stored as epoch milliseconds.
            "Date": pd.to_datetime(item["Date"], unit='ms'),
            # OHLCV fields are required (KeyError if absent).
            **{field: item[field] for field in ("Open", "High", "Low", "Close", "Volume")},
            # Indicator columns are optional; missing ones become None.
            **{field: item.get(field) for field in (
                "ATR", "RSI", "Momentum", "VWAP", "OBV", "MACD",
                "MACD_Signal", "Price_Change", "Daily_Return", "Lag_Close_1",
            )},
        })
    logger.info(f"Loaded {len(daily_buffer)} preprocessed bars for {symbol}.")

def is_market_close(timestamp_utc):
    """Return True when `timestamp_utc` is at or after 16:00 US/Eastern (the regular close).

    Matching only the exact 16:00 minute missed the close whenever no bar carried
    that timestamp (no trades in that minute, or bars labelled by their start time),
    so end-of-day processing never ran. on_market_close() guards itself with the
    `market_closed` flag, so returning True for every later bar is harmless.
    """
    eastern_time = timestamp_utc.astimezone(timezone("US/Eastern"))
    return eastern_time.hour >= 16

def on_market_close():
    """Run the end-of-day routine once: export bars and portfolio history to CSV,
    then post a summary (plus optional local-LLM commentary) to Discord.

    The in-progress `current_day_bar` is included in the exported CSV but is NOT
    appended to `daily_buffer` here: on_bar() already appends it when the next
    session's first bar arrives, so appending here as well stored that day twice.
    """
    global market_closed
    if market_closed:
        return
    logger.info("Market closed. Running end-of-day actions...")

    bars = list(daily_buffer)
    if current_day_bar:
        bars.append(current_day_bar)

    # DataFrame.to_csv does not create missing directories.
    out_dir = Path("data/processed")
    out_dir.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(bars).to_csv(out_dir / f"final_{symbol}_bars.csv", index=False)
    pd.DataFrame(portfolio_history).to_csv(out_dir / f"portfolio_{symbol}.csv", index=False)

    final_value = portfolio_history[-1]["Portfolio_Value"] if portfolio_history else 0
    drawdown = portfolio_history[-1]["Drawdown"] if portfolio_history else 0
    summary = f"📈 Market Closed\nFinal Value: ${final_value:.2f} | Drawdown: {drawdown:.2%}"
    send_discord_message(summary)

    prompt = f"The market has closed. Here's the end-of-day summary: {summary}. Provide risk commentary."
    insight = query_local_llm(prompt)
    if insight:
        send_discord_message(f"🤖 LLM Insight (EOD):\n{insight}")
    market_closed = True

async def send_heartbeat(interval_seconds=3600):
    """Post a startup notice after a 10 s delay, then a heartbeat every `interval_seconds`."""
    await asyncio.sleep(10)
    send_discord_message(":rocket: LLM Trading Bot is now running.")

    heartbeat_text = ":heartbeat: LLM Trading Bot is still running..."
    while True:
        await asyncio.sleep(interval_seconds)
        send_discord_message(heartbeat_text)

def submit_order(symbol, qty, side):
    """Submit a DAY market order for a whole number of shares.

    :param symbol: ticker to trade.
    :param qty: requested quantity; truncated to an int.
    :param side: OrderSide.BUY or OrderSide.SELL.
    :return: the order returned by Alpaca, or None when the quantity is below one
             share or the request fails.
    """
    shares = int(qty)
    if shares < 1:
        # Zero/negative quantities (e.g. "sell the position" while flat or short)
        # can only be rejected; skip the round trip and say why.
        logger.warning(f"Skipping {side} order for {symbol}: quantity {qty} is below 1 share.")
        return None
    try:
        # Built inside the try so request-validation errors are logged, not raised into on_bar.
        order = MarketOrderRequest(
            symbol=symbol,
            qty=shares,
            side=side,
            time_in_force=TimeInForce.DAY
        )
        return trading_client.submit_order(order)
    except Exception as e:
        logger.error(f"Order submission failed: {e}")
        return None

def send_discord_message(content, timeout=10, max_length=2000):
    """Best-effort post of `content` to the configured Discord webhook.

    :param content: message text; truncated to `max_length` characters because
        Discord rejects webhook messages whose content exceeds 2000 characters
        (long LLM replies otherwise vanished silently).
    :param timeout: seconds to wait for Discord. This is called from async handlers,
        so an unbounded request would stall the whole event loop.
    Failures are logged, never raised.
    """
    if not DISCORD_WEBHOOK:
        return
    try:
        response = requests.post(
            DISCORD_WEBHOOK,
            json={"content": content[:max_length]},
            timeout=timeout,
        )
        # Surface HTTP errors (bad webhook, oversize payload) in the log.
        response.raise_for_status()
    except Exception as e:
        logger.error(f"Failed to send Discord message: {e}")

def query_local_llm(prompt: str, timeout: float = 120.0):
    """Ask the local Ollama server (`/api/generate`, non-streaming) for a completion.

    :param prompt: text sent to the model.
    :param timeout: seconds to wait for the reply. Without it a stalled server would
        block the caller, and the asyncio loop running on_bar, indefinitely.
    :return: the stripped `response` text, or "" on any failure.
    """
    try:
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={"model": "deepseek", "prompt": prompt, "stream": False},
            timeout=timeout,
        )
        # An error status (e.g. unknown model) used to silently yield "".
        response.raise_for_status()
        return response.json().get("response", "").strip()
    except Exception as e:
        logger.error(f"Failed to query LLM: {e}")
        return ""

async def on_bar(bar):
    """Handle one streamed minute bar for `symbol`.

    Folds the bar into today's running daily bar, recomputes ATR and the momentum
    signal over the daily history, then trades:

    * signal == 1  -> buy a risk-sized quantity, capped by available cash;
    * signal == -1 -> sell the whole open long position (only if one exists);
    * otherwise    -> sell the open long if price falls 2 ATR below its average entry.

    Every processed bar appends a portfolio snapshot to `portfolio_history`.
    """
    global current_day_bar, peak_portfolio_value, last_trade_date

    timestamp = pd.to_datetime(bar.timestamp)
    bar_date = timestamp.date()

    if is_market_close(timestamp):
        on_market_close()

    # The first bar of a new session finalizes the previous day's aggregate.
    if current_day_bar and current_day_bar['Date'].date() != bar_date:
        daily_buffer.append(current_day_bar)
        logger.info(f"Finalized bar for {current_day_bar['Date'].date()}")
        current_day_bar = None

    if current_day_bar is None:
        current_day_bar = {
            "Date": timestamp,
            "Open": bar.open,
            "High": bar.high,
            "Low": bar.low,
            "Close": bar.close,
            "Volume": bar.volume
        }
    else:
        current_day_bar["High"] = max(current_day_bar["High"], bar.high)
        current_day_bar["Low"] = min(current_day_bar["Low"], bar.low)
        current_day_bar["Close"] = bar.close
        current_day_bar["Volume"] += bar.volume

    logger.info(f"Minute update: {timestamp} | O: {bar.open:.2f} H: {bar.high:.2f} L: {bar.low:.2f} C: {bar.close:.2f} V: {bar.volume}")

    df = pd.DataFrame(list(daily_buffer) + [current_day_bar])
    if len(df) < 20:
        return

    # ATR is recomputed over the last 20 daily rows only; older rows keep stored values.
    df_tail = df.tail(20).copy()
    df_tail = ATRIndicator(df_tail).compute()
    df.update(df_tail)

    df = strategy.generate_signal(df)
    latest = df.iloc[-1]

    signal = latest['Signal']
    price = latest['Close']
    atr_value = latest['ATR']

    if pd.isna(atr_value) or atr_value <= 0:
        return

    atr_25 = df['ATR'].quantile(0.25)
    atr_75 = df['ATR'].quantile(0.75)

    market_conditions = (
        "low_volatility" if atr_value < atr_25 else
        "high_volatility" if atr_value > atr_75 else
        "normal"
    )

    try:
        account = trading_client.get_account()
        cash = float(account.cash)
    except Exception as e:
        logger.error(f"Failed to fetch account info: {e}")
        return

    # Any failure here (e.g. no open position for the symbol) is treated as flat.
    position = 0
    avg_entry_price = None
    try:
        position_data = trading_client.get_open_position(symbol)
        position = int(float(position_data.qty))
        avg_entry_price = float(position_data.avg_entry_price)
    except Exception:
        pass

    # Stop used to size *new* entries: 2 ATR below the current price.
    stop_loss_price = price - (atr_value * 2)
    quantity = sizer.calculate_position_size(price, stop_loss_price, cash, market_conditions, signal)
    # The 0.1% fee is per share, so each share costs price * 1.001. (Dividing by
    # price + total fee for the whole order under-counted the affordable shares.)
    max_affordable_qty = cash // (price * 1.001)
    quantity = min(quantity, max_affordable_qty)

    # Stop for the *held* position: 2 ATR below its average entry. The old check,
    # price <= price - 2*ATR, could never be true for ATR > 0.
    exit_stop_price = avg_entry_price - (atr_value * 2) if avg_entry_price is not None else None

    print("\n=== LIVE BAR UPDATE ===")
    print(df.tail(5).to_string(index=False))
    print(f"Signal: {signal} | ATR: {atr_value:.2f} | Market: {market_conditions} | Stop-loss: {stop_loss_price:.2f}")

    # Trade messages are only produced when Alpaca actually accepted the order.
    llm_message = None
    if signal == 1 and quantity > 0:
        if submit_order(symbol, quantity, OrderSide.BUY) is not None:
            llm_message = f"[LIVE] BUY {quantity} {symbol} @ {price:.2f}"
    elif signal == -1 and position > 0:
        if submit_order(symbol, position, OrderSide.SELL) is not None:
            llm_message = f"[LIVE] SELL {position} {symbol} @ {price:.2f}"
    elif position > 0 and exit_stop_price is not None and price <= exit_stop_price:
        if submit_order(symbol, position, OrderSide.SELL) is not None:
            llm_message = f"[LIVE] STOP-LOSS SELL {position} {symbol} @ {price:.2f}"
    elif signal == 0:
        llm_message = f"[LIVE] HOLD"

    if llm_message:
        logger.info(llm_message)
        prompt = f"Analyze the following trade activity and provide a short summary with risk commentary:\n{llm_message}"
        llm_response = query_local_llm(prompt)
        if llm_response:
            send_discord_message(f"🤖 LLM Insight:\n{llm_response}")

    portfolio_value = cash + position * price
    peak_portfolio_value = max(peak_portfolio_value, portfolio_value)
    # The peak starts at 0, so an empty account would otherwise divide by zero.
    drawdown = (
        (portfolio_value - peak_portfolio_value) / peak_portfolio_value
        if peak_portfolio_value > 0 else 0.0
    )
    portfolio_history.append({
        "Date": datetime.now(),
        "Portfolio_Value": portfolio_value,
        "Cash": cash,
        "Position": position,
        "Price": price,
        "Drawdown": drawdown
    })

stream.subscribe_bars(on_bar, symbol)

stream_started = False
retry_cooldown_seconds = 300

async def run_stream_with_retry():
    """Start the bar stream, retrying after `retry_cooldown_seconds` if it raises.

    StockDataStream.run() is synchronous: it blocks while driving its own
    asyncio.run() loop (re-entrant here because nest_asyncio is applied), so it is
    called rather than awaited. Awaiting its None return raised TypeError whenever
    the stream stopped, and that was then misreported as a start failure.
    """
    global stream_started
    while not stream_started:
        try:
            logger.info("Attempting to start Alpaca stream...")
            stream.run()
            stream_started = True
        except Exception as e:
            logger.error(f"Stream failed to start: {e}")
            send_discord_message(
                f"⚠️ Stream failed to start. Retrying in {retry_cooldown_seconds // 60} minutes.\nError: {e}"
            )
            await asyncio.sleep(retry_cooldown_seconds)

async def main():
    """Run the Discord heartbeat alongside the (retrying) market-data stream."""
    heartbeat_task = asyncio.create_task(send_heartbeat())
    await asyncio.gather(heartbeat_task, run_stream_with_retry())

if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    asyncio.run(main())



In [None]:
# this is a working single stock prototype
import sys
from pathlib import Path

# Make the project root (parent of the notebook's working directory) importable,
# so the local `strategies`, `core` and `indicators` packages resolve.
project_root = Path.cwd().parent
if all(entry != str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

import os
import asyncio
import logging
import pandas as pd
from datetime import datetime, date
from collections import deque
from dotenv import load_dotenv
import requests
import json
from pytz import timezone

from alpaca.trading.client import TradingClient
from alpaca.trading.requests import MarketOrderRequest
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.data.live import StockDataStream

from strategies.strategy_registry.momentum_strategy import MomentumStrategy
from core.position_sizer import DynamicPositionSizer
from indicators.atr import ATRIndicator

load_dotenv(r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\venv\.env")

# Credentials and the Discord webhook come from the .env file loaded above.
API_KEY, API_SECRET, DISCORD_WEBHOOK = (
    os.getenv(var) for var in ("ALPACA_API_KEY", "ALPACA_SECRET", "DISCORD_WEBHOOK")
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("LiveRunner")

# Paper-trading REST client plus the live market-data websocket.
trading_client = TradingClient(API_KEY, API_SECRET, paper=True)
stream = StockDataStream(API_KEY, API_SECRET)

# Strategy setup for the single traded symbol.
symbol = "AAPL"
daily_buffer = deque(maxlen=100)   # rolling window of finalized daily bars
current_day_bar = None             # today's bar, aggregated from minute updates
strategy = MomentumStrategy()
sizer = DynamicPositionSizer(risk_percentage=0.07)

# Account / reporting bookkeeping.
portfolio_history = []
total_fees = 0
peak_portfolio_value = 0
last_trade_date = None
market_closed = False

# Locally tracked position state.
local_position_qty = 0
local_position_side = None  # "long" or "short"
current_stop_loss = None
portfolio_value = 0
MAX_PYRAMID_LAYERS = 12  # You can tune this
pyramid_layer = 0        # Move this to a global or symbol-level state if you go multi-symbol


# Seed the rolling daily window with up to the last 100 preprocessed bars on disk.
processed_path = Path(rf"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\data\data_storage\proc_data\proc_{symbol}_file.json")
if not processed_path.exists():
    logger.warning(f"Processed data not found for {symbol}.")
else:
    with open(processed_path, "r") as f:
        processed_data = json.load(f)
    for item in processed_data[-100:]:
        daily_buffer.append({
            # Dates are stored as epoch milliseconds.
            "Date": pd.to_datetime(item["Date"], unit='ms'),
            # OHLCV fields are required (KeyError if absent).
            **{field: item[field] for field in ("Open", "High", "Low", "Close", "Volume")},
            # Indicator columns are optional; missing ones become None.
            **{field: item.get(field) for field in (
                "ATR", "RSI", "Momentum", "VWAP", "OBV", "MACD",
                "MACD_Signal", "Price_Change", "Daily_Return", "Lag_Close_1",
            )},
        })
    logger.info(f"Loaded {len(daily_buffer)} preprocessed bars for {symbol}.")

def is_market_close(timestamp_utc):
    """Return True when `timestamp_utc` is at or after 16:00 US/Eastern (the regular close).

    Matching only the exact 16:00 minute missed the close whenever no bar carried
    that timestamp (no trades in that minute, or bars labelled by their start time),
    so end-of-day processing never ran. on_market_close() guards itself with the
    `market_closed` flag, so returning True for every later bar is harmless.
    """
    eastern_time = timestamp_utc.astimezone(timezone("US/Eastern"))
    return eastern_time.hour >= 16

def on_market_close():
    """Run the end-of-day routine once: export bars and portfolio history to CSV,
    then post a summary (plus optional local-LLM commentary) to Discord.

    The in-progress `current_day_bar` is included in the exported CSV but is NOT
    appended to `daily_buffer` here: on_bar() already appends it when the next
    session's first bar arrives, so appending here as well stored that day twice.
    """
    global market_closed
    if market_closed:
        return
    logger.info("Market closed. Running end-of-day actions...")

    bars = list(daily_buffer)
    if current_day_bar:
        bars.append(current_day_bar)

    # DataFrame.to_csv does not create missing directories.
    out_dir = Path("data/processed")
    out_dir.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(bars).to_csv(out_dir / f"final_{symbol}_bars.csv", index=False)
    pd.DataFrame(portfolio_history).to_csv(out_dir / f"portfolio_{symbol}.csv", index=False)

    final_value = portfolio_history[-1]["Portfolio_Value"] if portfolio_history else 0
    drawdown = portfolio_history[-1]["Drawdown"] if portfolio_history else 0
    summary = f"📈 Market Closed\nFinal Value: ${final_value:.2f} | Drawdown: {drawdown:.2%}"
    send_discord_message(summary)

    prompt = f"The market has closed. Here's the end-of-day summary: {summary}. Provide risk commentary."
    insight = query_local_llm(prompt)
    if insight:
        send_discord_message(f"🤖 LLM Insight (EOD):\n{insight}")
    market_closed = True

async def send_heartbeat(interval_seconds=3600):
    """Post a startup notice after a 10 s delay, then a heartbeat every `interval_seconds`."""
    await asyncio.sleep(10)
    send_discord_message(":rocket: LLM Trading Bot is now running.")

    heartbeat_text = ":heartbeat: LLM Trading Bot is still running..."
    while True:
        await asyncio.sleep(interval_seconds)
        send_discord_message(heartbeat_text)

def submit_order(symbol, qty, side, price):
    """Submit a DAY market order, capping position-opening orders by buying power.

    :param symbol: ticker to trade.
    :param qty: requested share quantity.
    :param side: OrderSide.BUY or OrderSide.SELL.
    :param price: reference price used to estimate cost (with a 0.1% fee buffer).
    :return: True if Alpaca accepted the order, False otherwise.

    Orders that only shrink the existing position (selling part or all of a long,
    covering part or all of a short) are not capped. Previously every order was
    capped, so stop-loss / flatten orders were reduced or skipped exactly when the
    account was fully invested.
    """
    try:
        # Net signed position; any lookup failure (e.g. no position) counts as flat,
        # which falls back to the buying-power cap.
        try:
            held_qty = float(trading_client.get_open_position(symbol).qty)
        except Exception:
            held_qty = 0.0

        reduces_position = (
            (side == OrderSide.SELL and held_qty > 0) or (side == OrderSide.BUY and held_qty < 0)
        ) and qty <= abs(held_qty)

        if not reduces_position:
            account = trading_client.get_account()
            buying_power = float(account.buying_power)
            max_affordable_qty = int(buying_power // (price * 1.001))

            if qty > max_affordable_qty:
                logger.warning(f"Reducing {symbol} order from {qty} to {max_affordable_qty} due to buying power.")
                qty = max_affordable_qty

        if qty < 1:
            logger.error(f"Insufficient buying power to trade even 1 share of {symbol}. Skipping.")
            return False

        order = MarketOrderRequest(
            symbol=symbol,
            qty=qty,
            side=side,
            time_in_force=TimeInForce.DAY
        )
        trading_client.submit_order(order)
        return True

    except Exception as e:
        logger.error(f"Order submission failed for {symbol} ({side}, qty={qty}): {e}")
        return False


def send_discord_message(content, timeout=10, max_length=2000):
    """Best-effort post of `content` to the configured Discord webhook.

    :param content: message text; truncated to `max_length` characters because
        Discord rejects webhook messages whose content exceeds 2000 characters
        (long LLM replies otherwise vanished silently).
    :param timeout: seconds to wait for Discord. This is called from async handlers,
        so an unbounded request would stall the whole event loop.
    Failures are logged, never raised.
    """
    if not DISCORD_WEBHOOK:
        return
    try:
        response = requests.post(
            DISCORD_WEBHOOK,
            json={"content": content[:max_length]},
            timeout=timeout,
        )
        # Surface HTTP errors (bad webhook, oversize payload) in the log.
        response.raise_for_status()
    except Exception as e:
        logger.error(f"Failed to send Discord message: {e}")

def query_local_llm(prompt: str, timeout: float = 120.0):
    """Ask the local Ollama server (`/api/generate`, non-streaming) for a completion.

    :param prompt: text sent to the model.
    :param timeout: seconds to wait for the reply. Without it a stalled server would
        block the caller, and the asyncio loop running on_bar, indefinitely.
    :return: the stripped `response` text, or "" on any failure.
    """
    try:
        response = requests.post(
            "http://localhost:11434/api/generate",
            # The tag previously had a trailing space ("deepseek-r1:latest "), which
            # does not name an installed model, so every query returned "".
            json={"model": "deepseek-r1:latest", "prompt": prompt, "stream": False},
            timeout=timeout,
        )
        # An error status (e.g. unknown model) used to silently yield "".
        response.raise_for_status()
        return response.json().get("response", "").strip()
    except Exception as e:
        logger.error(f"Failed to query LLM: {e}")
        return ""

def _open_position(direction, quantity, price, atr, tp_mult):
    """Open a new `direction` ("long" or "short") position and set its stop / target.

    Returns the trade message when the order is accepted; returns None and leaves
    all tracked state untouched when it is not.
    """
    global local_position_qty, local_position_side, current_stop_loss
    global current_take_profit, entry_price, pyramid_layer

    is_long = direction == "long"
    if not submit_order(symbol, quantity, OrderSide.BUY if is_long else OrderSide.SELL, price):
        return None

    local_position_qty = quantity
    local_position_side = direction
    pyramid_layer = 1  # the opening order counts as the first layer
    entry_price = price
    if is_long:
        current_stop_loss = price - (atr * 2)
        current_take_profit = price * tp_mult
        action = "LONG ENTRY: BUY"
    else:
        current_stop_loss = price + (atr * 2)
        current_take_profit = price * (2 - tp_mult)  # mirror of the long target, below entry
        action = "SHORT ENTRY: SELL"
    return (f"[LIVE] {action} {quantity} {symbol} @ {price:.2f} | "
            f"SL: {current_stop_loss:.2f} | TP: {current_take_profit:.2f}")


def _reduce_position(qty, price):
    """Send the order that shrinks the tracked position by `qty` shares.

    Returns True when the order is accepted. Once the tracked quantity reaches zero
    the side and the pyramid layer count are reset to "flat".
    """
    global local_position_qty, local_position_side, pyramid_layer

    closing_side = OrderSide.SELL if local_position_side == "long" else OrderSide.BUY
    if not submit_order(symbol, qty, closing_side, price):
        return False
    local_position_qty -= qty
    if local_position_qty <= 0:
        local_position_qty = 0
        local_position_side = None
        pyramid_layer = 0
    return True


async def on_bar(bar):
    """Handle one streamed minute bar: update the daily bar, recompute signals, trade.

    Exactly one action is taken per bar, evaluated in this order:

    1. flat              -> open long (signal 1) or short (signal -1);
    2. signal reversal   -> flatten, then open the opposite side if affordable;
    3. take-profit hit   -> scale out `exit_fraction` of the position;
    4. stop-loss hit     -> exit the whole position;
    5. same-side signal  -> pyramid add, up to MAX_PYRAMID_LAYERS;
    6. otherwise         -> ratchet the trailing stop.

    Fixes relative to the previous version:
    * the "signal flip" branches were unreachable;
    * a fresh entry was immediately followed by a pyramid add in the same bar;
    * pyramid adds updated local state even when the order failed;
    * a failed flatten still opened the opposite side;
    * an early return when no new share was affordable also skipped TP/SL exits.
    """
    global current_day_bar, daily_buffer
    global local_position_qty, local_position_side, portfolio_value
    global current_stop_loss, current_take_profit, entry_price, peak_portfolio_value, pyramid_layer

    timestamp = pd.to_datetime(bar.timestamp)
    bar_date = timestamp.date()

    if is_market_close(timestamp):
        on_market_close()

    # The first bar of a new session finalizes the previous day's aggregate.
    if current_day_bar and current_day_bar['Date'].date() != bar_date:
        daily_buffer.append(current_day_bar)
        logger.info(f"Finalized bar for {current_day_bar['Date'].date()}")
        current_day_bar = None

    if current_day_bar is None:
        current_day_bar = {
            "Date": timestamp,
            "Open": bar.open,
            "High": bar.high,
            "Low": bar.low,
            "Close": bar.close,
            "Volume": bar.volume
        }
    else:
        current_day_bar["High"] = max(current_day_bar["High"], bar.high)
        current_day_bar["Low"] = min(current_day_bar["Low"], bar.low)
        current_day_bar["Close"] = bar.close
        current_day_bar["Volume"] += bar.volume

    logger.info(f"Minute update: {timestamp} | O: {bar.open:.2f} H: {bar.high:.2f} L: {bar.low:.2f} C: {bar.close:.2f} V: {bar.volume}")

    df = pd.DataFrame(list(daily_buffer) + [current_day_bar])
    if len(df) < 20:
        return

    df = ATRIndicator(df).compute()
    df = strategy.generate_signal(df)
    latest = df.iloc[-1]

    signal = latest["Signal"]
    price = latest["Close"]
    atr = latest["ATR"]
    logger.info(f"[SIGNAL] {bar.timestamp} | {symbol} | Signal: {signal} | Price: {price:.2f} | ATR: {atr:.2f}")

    if pd.isna(atr) or atr <= 0:
        return

    # Volatility regimes
    atr_25 = df["ATR"].quantile(0.25)
    atr_75 = df["ATR"].quantile(0.75)
    regime = (
        "low_volatility" if atr < atr_25 else
        "high_volatility" if atr > atr_75 else
        "normal"
    )

    try:
        account = trading_client.get_account()
        cash = float(account.cash)
    except Exception as e:
        logger.error(f"Failed to fetch account info: {e}")
        return

    # Any failure here (e.g. no open position for the symbol) is treated as flat.
    try:
        position_data = trading_client.get_open_position(symbol)
        position = int(float(position_data.qty))
    except Exception:
        position = 0

    # Position sizing for new entries / pyramid adds.
    stop_price = price - (atr * 2)
    quantity = sizer.calculate_position_size(price, stop_price, cash, regime, signal)
    trade_fee = 0.001 * price * quantity
    estimated_cost = price * quantity + trade_fee

    if estimated_cost > cash:
        max_affordable_qty = int(cash // (price * 1.001))  # 0.1% trade fee
        logger.warning(f"Reducing quantity from {quantity} to {max_affordable_qty} due to buying power.")
        quantity = max(max_affordable_qty, 0)

    # Only *new* exposure needs >= 1 affordable share. Exits must still run when
    # cash is exhausted (e.g. fully invested), so there is no early return here.
    can_enter = quantity >= 1
    if not can_enter and signal != 0:
        logger.warning(f"Insufficient buying power to trade even 1 share of {symbol}. No new exposure this bar.")

    TP_MULTIPLIERS = {
        "low_volatility": 1.003,
        "normal": 1.005,
        "high_volatility": 1.10,
    }

    tp_mult = TP_MULTIPLIERS.get(regime, 1.05)
    exit_fraction = 0.25
    min_qty = 1
    llm_message = None
    side = local_position_side

    if side is None:
        # === ENTRY ===
        if signal == 1 and can_enter:
            llm_message = _open_position("long", quantity, price, atr, tp_mult)
        elif signal == -1 and can_enter:
            llm_message = _open_position("short", quantity, price, atr, tp_mult)

    elif side == "long" and signal == -1:
        # === REVERSAL: flatten the long first; only then open the short ===
        flatten_qty = local_position_qty
        if _reduce_position(flatten_qty, price):
            logger.info(f"[LIVE] FLATTEN LONG BEFORE SHORT: SELL {flatten_qty} {symbol} @ {price:.2f}")
            if can_enter:
                llm_message = _open_position("short", quantity, price, atr, tp_mult)
            if llm_message is None:
                llm_message = f"[LIVE] SIGNAL FLIP (LONG): FULL EXIT {flatten_qty} @ {price:.2f}"

    elif side == "short" and signal == 1:
        # === REVERSAL: cover the short first; only then open the long ===
        flatten_qty = local_position_qty
        if _reduce_position(flatten_qty, price):
            logger.info(f"[LIVE] FLATTEN SHORT BEFORE LONG: BUY {flatten_qty} {symbol} @ {price:.2f}")
            if can_enter:
                llm_message = _open_position("long", quantity, price, atr, tp_mult)
            if llm_message is None:
                llm_message = f"[LIVE] SIGNAL FLIP (SHORT): FULL COVER {flatten_qty} @ {price:.2f}"

    elif (side == "long" and price >= current_take_profit) or (side == "short" and price <= current_take_profit):
        # === TP: scale out a fraction (never more than is held) ===
        qty = min(max(int(local_position_qty * exit_fraction), min_qty), local_position_qty)
        if _reduce_position(qty, price):
            action = "PARTIAL SELL" if side == "long" else "PARTIAL COVER"
            llm_message = f"[LIVE] TP HIT ({side.upper()}): {action} {qty} @ {price:.2f}"

    elif (side == "long" and price <= current_stop_loss) or (side == "short" and price >= current_stop_loss):
        # === SL: full exit ===
        qty = local_position_qty
        if _reduce_position(qty, price):
            llm_message = f"[LIVE] SL HIT ({side.upper()}): FULL EXIT {qty} @ {price:.2f}"

    elif (side == "long" and signal == 1) or (side == "short" and signal == -1):
        # === PYRAMID: add a layer in the direction of the open position ===
        # NOTE(review): the signal comes from daily data and can stay unchanged all day,
        # so layers may be added on consecutive minute bars until the cap is reached.
        if pyramid_layer < MAX_PYRAMID_LAYERS and can_enter:
            add_side = OrderSide.BUY if side == "long" else OrderSide.SELL
            if submit_order(symbol, quantity, add_side, price):
                local_position_qty += quantity
                pyramid_layer += 1
                verb = "BUY" if side == "long" else "SELL"
                llm_message = (f"[LIVE] PYRAMID ADD ({side.upper()}): {verb} {quantity} {symbol} "
                               f"@ {price:.2f} | Layer: {pyramid_layer}")

    elif side == "long":
        # === TRAILING STOP: only ever tightens ===
        current_stop_loss = max(current_stop_loss, price * 0.97)
    else:
        current_stop_loss = min(current_stop_loss, price * 1.03)

    # === DISCORD + LLM ===
    if llm_message:
        logger.info(llm_message)
        prompt = f"Analyze this trade event and provide a short summary with risk commentary:\n{llm_message}"
        response = query_local_llm(prompt)
        if response:
            send_discord_message(f"🤖 LLM Insight:\n{response}")

    portfolio_value = cash + position * price
    peak_portfolio_value = max(peak_portfolio_value, portfolio_value)
    # The peak starts at 0, so an empty account would otherwise divide by zero.
    drawdown = (
        (portfolio_value - peak_portfolio_value) / peak_portfolio_value
        if peak_portfolio_value > 0 else 0.0
    )
    portfolio_history.append({
        "Date": datetime.now(),
        "Portfolio_Value": portfolio_value,
        "Cash": cash,
        "Position": position,
        "Price": price,
        "Drawdown": drawdown
    })


stream.subscribe_bars(on_bar, symbol)


stream_started = False
retry_cooldown_seconds = 60

async def run_stream_with_retry():
    """Start the bar stream, retrying after `retry_cooldown_seconds` if it raises.

    StockDataStream.run() is synchronous: it blocks while driving its own
    asyncio.run() loop (re-entrant here because nest_asyncio is applied), so it is
    called rather than awaited. Awaiting its None return raised TypeError whenever
    the stream stopped, and that was then misreported as a start failure.
    """
    global stream_started
    while not stream_started:
        try:
            logger.info("Attempting to start Alpaca stream...")
            stream.run()
            stream_started = True
        except Exception as e:
            logger.error(f"Stream failed to start: {e}")
            # The cooldown is configurable; the old text always claimed "5 minutes".
            send_discord_message(
                f"⚠️ Stream failed to start. Retrying in {retry_cooldown_seconds} seconds.\nError: {e}"
            )
            try:
                # close() is a coroutine and drops the websocket. stop() is a plain
                # (non-async) method, so the old `await stream.stop()` always failed.
                await stream.close()
                logger.info("Closed existing stream before retrying.")
            except Exception as close_err:
                logger.warning(f"Failed to close stream cleanly: {close_err}")
            await asyncio.sleep(retry_cooldown_seconds)

async def main():
    """Run the Discord heartbeat alongside the (retrying) market-data stream."""
    heartbeat_task = asyncio.create_task(send_heartbeat())
    await asyncio.gather(heartbeat_task, run_stream_with_retry())

if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    asyncio.run(main())

In [None]:
# multi stock prototype
import sys
from pathlib import Path

# Make the project root (parent of the notebook's working directory) importable,
# so the local `strategies`, `core` and `indicators` packages resolve.
project_root = Path.cwd().parent
if all(entry != str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

import os
import asyncio
import logging
import pandas as pd
from datetime import datetime, date
from collections import deque
from dotenv import load_dotenv
import requests
import json
from pytz import timezone

from alpaca.trading.client import TradingClient
from alpaca.trading.requests import MarketOrderRequest
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.data.live import StockDataStream

from strategies.strategy_registry.momentum_strategy import MomentumStrategy
from core.position_sizer import DynamicPositionSizer
from indicators.atr import ATRIndicator

# NOTE(review): absolute, machine-specific .env path; consider a path relative
# to project_root or an environment variable instead.
load_dotenv(r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\venv\.env")

# Alpaca credentials and optional Discord webhook, read from the environment.
API_KEY = os.getenv("ALPACA_API_KEY")
API_SECRET = os.getenv("ALPACA_SECRET")
DISCORD_WEBHOOK = os.getenv("DISCORD_WEBHOOK")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("LiveRunner")

# Paper-trading order client and the shared live market-data stream.
trading_client = TradingClient(API_KEY, API_SECRET, paper=True)
stream = StockDataStream(API_KEY, API_SECRET)

# Symbols traded by this prototype; one SymbolContext is built per entry.
symbols = ["AAPL", "MSFT"]

class SymbolContext:
    """Per-symbol live-trading state for the multi-stock prototype.

    Holds a rolling buffer of up to 100 daily bars, seeded from the processed
    JSON file when it exists. Also holds today's in-progress bar, the strategy
    and position-sizer instances, and the running portfolio history.
    """

    # Indicator fields copied from the processed file; absent keys become None.
    _INDICATOR_KEYS = (
        "ATR", "RSI", "Momentum", "VWAP", "OBV", "MACD",
        "MACD_Signal", "Price_Change", "Daily_Return", "Lag_Close_1",
    )

    def __init__(self, symbol):
        self.symbol = symbol
        self.daily_buffer = deque(maxlen=100)
        self.current_day_bar = None
        self.strategy = MomentumStrategy()
        self.sizer = DynamicPositionSizer(risk_percentage=0.07)
        self.portfolio_history = []
        self.peak_portfolio_value = 0
        self.last_trade_date = None
        self.market_closed = False

        history_file = Path(f"data/data_storage/proc_data/proc_{symbol}_file.json")
        if not history_file.exists():
            logger.warning(f"[{symbol}] Historical data not found.")
            return

        with open(history_file, "r") as fh:
            records = json.load(fh)
        self.daily_buffer.extend(self._to_buffer_row(rec) for rec in records[-100:])
        logger.info(f"[{symbol}] Loaded historical bars.")

    @classmethod
    def _to_buffer_row(cls, rec):
        """Convert one processed-file record (Date in epoch ms) into a buffer row."""
        row = {
            "Date": pd.to_datetime(rec["Date"], unit='ms'),
            "Open": rec["Open"], "High": rec["High"], "Low": rec["Low"],
            "Close": rec["Close"], "Volume": rec["Volume"],
        }
        row.update({key: rec.get(key) for key in cls._INDICATOR_KEYS})
        return row

# One SymbolContext (buffer, strategy, sizer, history) per traded symbol.
symbol_contexts = {sym: SymbolContext(sym) for sym in symbols}

def is_market_close(timestamp):
    """Return True when ``timestamp`` is exactly 16:00 (hour and minute) US/Eastern.

    Expects a timezone-aware timestamp; it is converted with ``astimezone``.
    """
    local_time = timestamp.astimezone(timezone("US/Eastern"))
    return (local_time.hour, local_time.minute) == (16, 0)

def send_discord_message(msg, timeout=10, max_length=2000):
    """Post ``msg`` to the configured Discord webhook (best effort).

    Does nothing when ``DISCORD_WEBHOOK`` is unset. Failures are logged, never
    raised, so notification problems cannot break the trading loop.

    :param msg: Message text. It is truncated to ``max_length`` characters,
        because Discord rejects longer webhook content (long LLM replies used to
        fail silently).
    :param timeout: Seconds before the HTTP request is abandoned. Without it a
        stalled webhook blocks the caller, and the event loop, indefinitely.
    :param max_length: Maximum content length sent.
    """
    if not DISCORD_WEBHOOK:
        return
    content = str(msg)
    if len(content) > max_length:
        content = content[: max_length - 3] + "..."
    try:
        resp = requests.post(DISCORD_WEBHOOK, json={"content": content}, timeout=timeout)
        # Surface HTTP errors (e.g. 400/429) instead of ignoring them.
        resp.raise_for_status()
    except Exception as e:
        logger.error(f"Discord error: {e}")

def query_local_llm(prompt, model="deepseek", timeout=120):
    """Ask the local LLM server (presumably Ollama on port 11434) for a completion.

    :param prompt: Prompt text.
    :param model: Model name sent to ``/api/generate``.
    :param timeout: Seconds to wait for the non-streamed reply. Prevents a hung
        LLM server from blocking the bar handler forever.
    :return: The stripped ``response`` text, or ``""`` on any error (logged).
    """
    try:
        resp = requests.post(
            "http://localhost:11434/api/generate",
            json={"model": model, "prompt": prompt, "stream": False},
            timeout=timeout,
        )
        resp.raise_for_status()
        return (resp.json().get("response") or "").strip()
    except Exception as e:
        logger.error(f"LLM error: {e}")
        return ""

def submit_order(symbol, qty, side):
    """Submit a DAY market order for ``int(qty)`` shares.

    :return: The submitted order, or None if anything failed (error is logged).
    """
    try:
        request = MarketOrderRequest(
            symbol=symbol,
            qty=int(qty),
            side=side,
            time_in_force=TimeInForce.DAY,
        )
        return trading_client.submit_order(request)
    except Exception as e:
        logger.error(f"Order error: {e}")
        return None

def _finalize_symbol_day(ctx: SymbolContext, bar_date):
    """Close out ``ctx``'s trading day on the market-close bar.

    Steps:
    * Move the aggregated daily bar into the buffer and clear it. Previously it
      stayed set and was appended a second time when the next day's first bar
      arrived.
    * Save the buffer and portfolio history as CSV under ``data/processed``,
      creating the folder if needed.
    * Post an end-of-day summary plus LLM commentary to Discord.
    """
    if ctx.current_day_bar:
        ctx.daily_buffer.append(ctx.current_day_bar)
        ctx.current_day_bar = None

    out_dir = Path("data/processed")
    out_dir.mkdir(parents=True, exist_ok=True)  # to_csv raises if the folder is missing
    pd.DataFrame(list(ctx.daily_buffer)).to_csv(out_dir / f"final_{ctx.symbol}_bars.csv", index=False)
    pd.DataFrame(ctx.portfolio_history).to_csv(out_dir / f"portfolio_{ctx.symbol}.csv", index=False)

    if ctx.portfolio_history:
        pv = ctx.portfolio_history[-1]['Portfolio_Value']
        dd = ctx.portfolio_history[-1]['Drawdown']
        msg = f"📈 [{ctx.symbol}] Market closed. Value: ${pv:.2f} | Drawdown: {dd:.2%}"
        send_discord_message(msg)
        insight = query_local_llm(f"Market closed. Summary: {msg}. Risk commentary?")
        if insight:
            send_discord_message(f"🤖 LLM Insight (EOD):\n{insight}")

    ctx.market_closed = True
    # Remember which date was closed so the flag can be cleared next session.
    ctx.market_closed_date = bar_date


async def handle_bar_for_symbol(bar, ctx: SymbolContext):
    """Process one streamed bar for ``ctx.symbol``.

    On the market-close bar (see ``is_market_close``) the day is finalized.
    Otherwise the handler:
    1. folds the bar into today's aggregate;
    2. recomputes ATR on the last 20 daily rows and runs the strategy;
    3. classifies volatility from the ATR quartiles;
    4. sizes a position and places at most one market order per calendar day;
    5. records a portfolio snapshot.

    Fixes over the previous version:
    * the market-closed flag is cleared when a new date starts (it used to stay
      set forever, so only the first day got an EOD report), and further bars on
      an already-closed date are ignored;
    * the stop-loss exit is measured from the position's average entry price,
      since the old test ``price <= price - 2*ATR`` could never be true;
    * failed orders are no longer announced or counted as the day's trade;
    * drawdown is 0 instead of raising ZeroDivisionError while the peak is 0.

    NOTE(review): dates come from the UTC bar timestamp, not the US/Eastern
    session date.
    """
    ts = pd.to_datetime(bar.timestamp)
    bar_date = ts.date()

    if ctx.market_closed and getattr(ctx, "market_closed_date", None) != bar_date:
        ctx.market_closed = False  # a new session has started
    if ctx.market_closed:
        return  # today already finalized; skip remaining (after-hours) bars

    if is_market_close(ts):
        _finalize_symbol_day(ctx, bar_date)
        return

    # Roll over: a bar from a new date closes out the previous day's aggregate.
    if ctx.current_day_bar and ctx.current_day_bar['Date'].date() != bar_date:
        ctx.daily_buffer.append(ctx.current_day_bar)
        ctx.current_day_bar = None

    if ctx.current_day_bar is None:
        ctx.current_day_bar = {"Date": ts, "Open": bar.open, "High": bar.high, "Low": bar.low, "Close": bar.close, "Volume": bar.volume}
    else:
        day = ctx.current_day_bar
        day["High"] = max(day["High"], bar.high)
        day["Low"] = min(day["Low"], bar.low)
        day["Close"] = bar.close
        day["Volume"] += bar.volume

    logger.info(f"[{ctx.symbol}] {ts} O: {bar.open} H: {bar.high} L: {bar.low} C: {bar.close} V: {bar.volume}")

    df = pd.DataFrame(list(ctx.daily_buffer) + [ctx.current_day_bar])
    if len(df) < 20:
        return
    # ATR is recomputed only on the last 20 rows and merged back into the frame.
    df_tail = df.tail(20).copy()
    df_tail = ATRIndicator(df_tail).compute()
    df.update(df_tail)
    df = ctx.strategy.generate_signal(df)

    latest = df.iloc[-1]
    signal = latest['Signal']
    price = latest['Close']
    atr = latest['ATR']

    if pd.isna(atr) or atr <= 0:
        return

    atr_25, atr_75 = df['ATR'].quantile(0.25), df['ATR'].quantile(0.75)
    regime = "low_volatility" if atr < atr_25 else "high_volatility" if atr > atr_75 else "normal"

    try:
        cash = float(trading_client.get_account().cash)
    except Exception as e:
        logger.error(f"Acct fetch fail: {e}")
        return

    position, entry_price = 0, None
    try:
        open_position = trading_client.get_open_position(ctx.symbol)
        position = int(float(open_position.qty))
        entry_price = float(open_position.avg_entry_price)
    except Exception:
        # Treated as flat / unknown entry; the call presumably raises when no
        # position is open for the symbol.
        pass

    stop = price - (2 * atr)  # sizing stop for a new entry
    qty = ctx.sizer.calculate_position_size(price, stop, cash, regime)
    qty = min(qty, cash // (price * 1.001))  # leave ~0.1% headroom for fees
    # Exit stop for an open long, anchored at its average entry price.
    position_stop = entry_price - (2 * atr) if entry_price is not None else None

    print(f"\n=== [{ctx.symbol}] LIVE BAR ===\n{df.tail(3)}")
    print(f"Signal: {signal} | ATR: {atr:.2f} | Market: {regime} | SL: {stop:.2f}")

    today = date.today().isoformat()
    if ctx.last_trade_date == today:
        return

    # Only a successfully submitted order counts as the day's trade.
    if signal == 1 and qty > 0 and position == 0:
        if submit_order(ctx.symbol, qty, OrderSide.BUY) is not None:
            ctx.last_trade_date = today
            send_discord_message(f"[LIVE] BUY {qty} {ctx.symbol} @ {price:.2f}")
    elif signal == -1 and position > 0:
        if submit_order(ctx.symbol, position, OrderSide.SELL) is not None:
            ctx.last_trade_date = today
            send_discord_message(f"[LIVE] SELL {position} {ctx.symbol} @ {price:.2f}")
    elif position > 0 and position_stop is not None and price <= position_stop:
        if submit_order(ctx.symbol, position, OrderSide.SELL) is not None:
            ctx.last_trade_date = today
            send_discord_message(f"[LIVE] STOP-LOSS SELL {position} {ctx.symbol} @ {price:.2f}")

    portfolio_value = cash + position * price
    ctx.peak_portfolio_value = max(ctx.peak_portfolio_value, portfolio_value)
    drawdown = (
        (portfolio_value - ctx.peak_portfolio_value) / ctx.peak_portfolio_value
        if ctx.peak_portfolio_value > 0 else 0.0
    )
    ctx.portfolio_history.append({
        "Date": datetime.now(), "Portfolio_Value": portfolio_value, "Cash": cash,
        "Position": position, "Price": price, "Drawdown": drawdown
    })

def make_on_bar(symbol):
    """Build an async bar callback bound to ``symbol``.

    The context is looked up in ``symbol_contexts`` each time a bar arrives.
    """
    async def _bar_callback(bar):
        context = symbol_contexts[symbol]
        await handle_bar_for_symbol(bar, context)

    return _bar_callback

# Register one bar callback per symbol with the shared stream.
for sym in symbols:
    stream.subscribe_bars(make_on_bar(sym), sym)

async def send_heartbeat():
    """Announce startup after 10 s, then post an hourly liveness ping forever."""
    startup_delay, ping_interval = 10, 3600
    await asyncio.sleep(startup_delay)
    send_discord_message(":rocket: Bot running.")
    while True:
        await asyncio.sleep(ping_interval)
        send_discord_message(":heartbeat: Bot alive.")

async def run_stream(retry_delay_seconds=300):
    """Run the Alpaca data stream forever, restarting it after it stops or fails.

    ``stream.run()`` is awaited only if it returns an awaitable. NOTE(review):
    in the alpaca-py versions checked it is a blocking method returning None.
    Awaiting that None raised a TypeError that was logged as a stream error.
    The cooldown also runs after a normal return, so a stream that exits
    immediately cannot spin the loop.

    :param retry_delay_seconds: Pause before each restart.
    """
    import inspect

    while True:
        try:
            logger.info("Starting stream...")
            result = stream.run()
            if inspect.isawaitable(result):
                await result
            logger.warning("Stream returned; restarting after cooldown.")
        except Exception as e:
            logger.error(f"Stream error: {e}")
        await asyncio.sleep(retry_delay_seconds)

async def main():
    """Run the Discord heartbeat and the self-restarting stream concurrently."""
    jobs = (send_heartbeat(), run_stream())
    await asyncio.gather(*jobs)

# Entry point; nest_asyncio allows asyncio.run() while a loop is already running.
if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    asyncio.run(main())


In [None]:
# Modularized version: a SymbolContext per symbol for multi-symbol management.
# Make the parent of the working directory (assumed to be the project root)
# importable so the project packages resolve.
import sys
from pathlib import Path
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

import os
import asyncio
import logging
import pandas as pd
from datetime import datetime, date
from collections import deque
from dotenv import load_dotenv
import requests
import json
from pytz import timezone

from alpaca.trading.client import TradingClient
from alpaca.trading.requests import MarketOrderRequest
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.data.live import StockDataStream

from strategies.strategy_registry.momentum_strategy import MomentumStrategy
from strategies.strategy_registry.mean_reversion_strategy import MeanReversionStrategy
from core.position_sizer import DynamicPositionSizer
from indicators.atr import ATRIndicator

# NOTE(review): absolute, machine-specific .env path.
load_dotenv(r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\venv\.env")

# Alpaca credentials and optional Discord webhook, read from the environment.
API_KEY = os.getenv("ALPACA_API_KEY")
API_SECRET = os.getenv("ALPACA_SECRET")
DISCORD_WEBHOOK = os.getenv("DISCORD_WEBHOOK")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("MultiSymbolLiveRunner")

# Paper-trading order client and the shared market-data stream for all symbols.
trading_client = TradingClient(API_KEY, API_SECRET, paper=True)
stream = StockDataStream(API_KEY, API_SECRET)

# Strategies applied to each symbol, in list order. NOTE(review): they are run
# one after another on the same frame, so if each writes a 'Signal' column the
# last strategy listed decides the trade.
symbol_config = {
    "AAPL": [MomentumStrategy()],
    "MSFT": [MomentumStrategy(), MeanReversionStrategy()],
}

class SymbolContext:
    """Live-trading state and decision logic for one symbol.

    Keeps a rolling buffer of up to 100 daily bars, seeded from the processed
    JSON file. Aggregates incoming intraday bars into the current day's bar.
    On each update it runs the configured strategies, sizes a position and
    submits market orders through the module-level helpers.
    """

    # Fraction of notional reserved for fees when capping quantity by cash.
    fee_rate = 0.001
    # Stop distance in ATRs, used for sizing and for the stop-loss exit.
    stop_atr_multiple = 2

    def __init__(self, symbol, strategies):
        self.symbol = symbol
        self.strategies = strategies
        self.daily_buffer = deque(maxlen=100)
        self.current_day_bar = None
        self.sizer = DynamicPositionSizer(risk_percentage=0.07)
        self.portfolio_history = []
        self.peak_portfolio_value = 0
        self.last_trade_date = None
        self.market_closed = False
        # Date of the most recently finalized daily bar. Later bars stamped
        # with that date are ignored, so one date yields one daily bar.
        self.last_finalized_date = None
        self.load_historical_bars()

    def load_historical_bars(self):
        """Seed the buffer with the last 100 records of the processed JSON file, if present."""
        path = Path(f"data/data_storage/proc_data/proc_{self.symbol}_file.json")
        if path.exists():
            with open(path, "r") as f:
                processed_data = json.load(f)
                for item in processed_data[-100:]:
                    self.daily_buffer.append({
                        "Date": pd.to_datetime(item["Date"], unit='ms'),
                        "Open": item["Open"],
                        "High": item["High"],
                        "Low": item["Low"],
                        "Close": item["Close"],
                        "Volume": item["Volume"],
                        "ATR": item.get("ATR"),
                        "RSI": item.get("RSI"),
                        "Momentum": item.get("Momentum"),
                        "VWAP": item.get("VWAP"),
                        "OBV": item.get("OBV"),
                        "MACD": item.get("MACD"),
                        "MACD_Signal": item.get("MACD_Signal"),
                        "Price_Change": item.get("Price_Change"),
                        "Daily_Return": item.get("Daily_Return"),
                        "Lag_Close_1": item.get("Lag_Close_1")
                    })
            logger.info(f"Loaded {len(self.daily_buffer)} preprocessed bars for {self.symbol}.")
        else:
            logger.warning(f"Processed data not found for {self.symbol}.")

    def finalize_bar(self, timestamp):
        """Move the in-progress daily bar into the buffer and clear it.

        :param timestamp: Timestamp of the triggering bar (used in the log line).
        """
        if self.current_day_bar:
            self.daily_buffer.append(self.current_day_bar)
            self.last_finalized_date = self.current_day_bar['Date'].date()
            logger.info(f"Finalized bar for {self.symbol} on {timestamp.date()}")
            self.current_day_bar = None

    def update_bar(self, bar):
        """Fold an intraday bar into today's aggregate, rolling over on a new date.

        Bars dated on an already-finalized day are ignored. Examples are further
        bars in the 16:00 ET hour, each of which triggers finalize_bar. Without
        this each one started a new daily bar for the same date.
        """
        timestamp = pd.to_datetime(bar.timestamp)
        bar_date = timestamp.date()

        if bar_date == self.last_finalized_date:
            return

        if self.current_day_bar and self.current_day_bar['Date'].date() != bar_date:
            self.finalize_bar(timestamp)

        if self.current_day_bar is None:
            self.current_day_bar = {
                "Date": timestamp,
                "Open": bar.open,
                "High": bar.high,
                "Low": bar.low,
                "Close": bar.close,
                "Volume": bar.volume
            }
        else:
            self.current_day_bar["High"] = max(self.current_day_bar["High"], bar.high)
            self.current_day_bar["Low"] = min(self.current_day_bar["Low"], bar.low)
            self.current_day_bar["Close"] = bar.close
            self.current_day_bar["Volume"] += bar.volume

    def _fetch_position(self):
        """Return ``(qty, avg_entry_price)`` for the open position, or ``(0, None)``.

        Any error fetching or parsing the quantity is treated as flat. The call
        presumably raises when no position is open. The entry price is None if
        it cannot be parsed.
        """
        try:
            position_data = trading_client.get_open_position(self.symbol)
            qty = int(float(position_data.qty))
        except Exception:
            return 0, None
        try:
            entry_price = float(position_data.avg_entry_price)
        except (AttributeError, TypeError, ValueError):
            entry_price = None
        return qty, entry_price

    def _record_portfolio(self, cash, position, price):
        """Append a portfolio snapshot, with drawdown measured from the running peak."""
        portfolio_value = cash + position * price
        self.peak_portfolio_value = max(self.peak_portfolio_value, portfolio_value)
        # The peak starts at 0; guard so a non-positive peak cannot divide by zero.
        drawdown = (
            (portfolio_value - self.peak_portfolio_value) / self.peak_portfolio_value
            if self.peak_portfolio_value > 0 else 0.0
        )
        self.portfolio_history.append({
            "Date": datetime.now(),
            "Portfolio_Value": portfolio_value,
            "Cash": cash,
            "Position": position,
            "Price": price,
            "Drawdown": drawdown
        })

    def process_signal(self):
        """Run the strategies on the latest data and place at most one trade per day.

        Needs at least 20 rows (buffer plus today's bar). ATR is recomputed on the
        last 20 rows. The volatility regime compares the latest ATR with the
        frame's 25th/75th ATR percentiles.

        Entry requires ``Signal == 1`` and a flat position. A long is exited on
        ``Signal == -1``, or when the price falls ``stop_atr_multiple`` ATRs
        below its average entry price. The old check against ``price - 2*ATR``
        could never trigger. Only successfully submitted orders are announced
        and counted as the day's trade.
        """
        if not self.current_day_bar:
            return

        df = pd.DataFrame(list(self.daily_buffer) + [self.current_day_bar])
        if len(df) < 20:
            return

        df_tail = df.tail(20).copy()
        df_tail = ATRIndicator(df_tail).compute()
        df.update(df_tail)

        # NOTE(review): strategies run in sequence on the same frame; if each
        # writes 'Signal', the last one in the list decides.
        for strategy in self.strategies:
            df = strategy.generate_signal(df)

        latest = df.iloc[-1]
        signal = latest['Signal']
        price = latest['Close']
        atr_value = latest['ATR']

        if pd.isna(atr_value) or atr_value <= 0:
            return

        atr_25 = df['ATR'].quantile(0.25)
        atr_75 = df['ATR'].quantile(0.75)

        market_conditions = (
            "low_volatility" if atr_value < atr_25 else
            "high_volatility" if atr_value > atr_75 else
            "normal"
        )

        try:
            account = trading_client.get_account()
            cash = float(account.cash)
        except Exception as e:
            logger.error(f"[{self.symbol}] Failed to fetch account info: {e}")
            return

        position, entry_price = self._fetch_position()

        stop_distance = self.stop_atr_multiple * atr_value
        stop_loss_price = price - stop_distance  # sizing stop for a new entry
        quantity = self.sizer.calculate_position_size(price, stop_loss_price, cash, market_conditions)
        # Per-share cost including fees. The old code added the fee for the
        # whole order to the per-share price, understating what is affordable.
        max_affordable_qty = cash // (price * (1 + self.fee_rate))
        quantity = min(quantity, max_affordable_qty)

        today_str = date.today().isoformat()
        if self.last_trade_date == today_str:
            return

        position_stop = entry_price - stop_distance if entry_price is not None else None

        msg = None
        if signal == 1 and quantity > 0 and position == 0:
            if submit_order(self.symbol, quantity, OrderSide.BUY) is not None:
                msg = f"BUY {quantity} {self.symbol} @ {price:.2f}"
        elif signal == -1 and position > 0:
            if submit_order(self.symbol, position, OrderSide.SELL) is not None:
                msg = f"SELL {position} {self.symbol} @ {price:.2f}"
        elif position > 0 and position_stop is not None and price <= position_stop:
            if submit_order(self.symbol, position, OrderSide.SELL) is not None:
                msg = f"STOP-LOSS SELL {position} {self.symbol} @ {price:.2f}"

        if msg:
            self.last_trade_date = today_str
            logger.info(msg)
            insight = query_local_llm(f"Analyze this trade: {msg}")
            send_discord_message(f"{msg}\n🤖 {insight}" if insight else msg)

        self._record_portfolio(cash, position, price)

def submit_order(symbol, qty, side):
    """Submit a DAY market order for ``int(qty)`` shares.

    The request is built inside the ``try``. Converting ``qty`` (e.g. a NaN
    size) or validating the request can raise, and that previously escaped to
    the caller instead of being logged like other order failures.

    :return: The submitted order, or None if building or submitting it failed.
    """
    try:
        order = MarketOrderRequest(symbol=symbol, qty=int(qty), side=side, time_in_force=TimeInForce.DAY)
        return trading_client.submit_order(order)
    except Exception as e:
        logger.error(f"Order failed: {e}")
        return None

def query_local_llm(prompt, model="deepseek", timeout=120):
    """Ask the local LLM server (presumably Ollama on port 11434) for a completion.

    :param prompt: Prompt text.
    :param model: Model name sent to ``/api/generate``.
    :param timeout: Seconds to wait for the reply, so a hung server cannot block
        the bar handler forever.
    :return: The stripped ``response`` text, or ``""`` on any error (logged).
    """
    try:
        res = requests.post(
            "http://localhost:11434/api/generate",
            json={"model": model, "prompt": prompt, "stream": False},
            timeout=timeout,
        )
        res.raise_for_status()
        return (res.json().get("response") or "").strip()
    except Exception as e:
        logger.error(f"LLM query failed: {e}")
        return ""

def send_discord_message(msg, timeout=10, max_length=2000):
    """Post ``msg`` to the configured Discord webhook (best effort).

    Does nothing when ``DISCORD_WEBHOOK`` is unset. Failures are logged, never
    raised.

    :param msg: Message text, truncated to ``max_length`` characters because
        Discord rejects longer webhook content.
    :param timeout: Seconds before the HTTP request is abandoned, so a stalled
        webhook cannot block the event loop.
    :param max_length: Maximum content length sent.
    """
    if not DISCORD_WEBHOOK:
        return
    content = str(msg)
    if len(content) > max_length:
        content = content[: max_length - 3] + "..."
    try:
        resp = requests.post(DISCORD_WEBHOOK, json={"content": content}, timeout=timeout)
        resp.raise_for_status()
    except Exception as e:
        logger.error(f"Discord send failed: {e}")

# One SymbolContext per configured symbol, carrying that symbol's strategy list.
symbol_contexts = {symbol: SymbolContext(symbol, strategies) for symbol, strategies in symbol_config.items()}

def is_market_close(timestamp_utc):
    """Return True for any timestamp within the 16:00-16:59 US/Eastern hour.

    Expects a timezone-aware timestamp; it is converted with ``astimezone``.
    """
    eastern_hour = timestamp_utc.astimezone(timezone("US/Eastern")).hour
    return eastern_hour == 16

async def on_bar(bar):
    """Route a streamed bar to its SymbolContext: aggregate, trade, then finalize in the close hour."""
    ctx = symbol_contexts.get(bar.symbol)
    if ctx is None:
        return
    ctx.update_bar(bar)
    ctx.process_signal()
    bar_time = pd.to_datetime(bar.timestamp)
    if is_market_close(bar_time):
        ctx.finalize_bar(bar_time)

# The single on_bar dispatcher handles every subscribed symbol.
for symbol in symbol_contexts:
    stream.subscribe_bars(on_bar, symbol)

async def send_heartbeat():
    """Post a startup notice after 5 s, then an hourly "still running" ping forever."""
    await asyncio.sleep(5)
    send_discord_message(":rocket: Bot running with multi-symbol support")
    hourly = 3600
    while True:
        await asyncio.sleep(hourly)
        send_discord_message(":heartbeat: Still running...")

async def run():
    """Run the Discord heartbeat alongside the Alpaca stream.

    Previously ``asyncio.gather(send_heartbeat(), stream.run())`` evaluated
    ``stream.run()`` while building gather's arguments. NOTE(review): in the
    alpaca-py versions checked that method blocks and returns None. So the
    heartbeat coroutine was never scheduled while the stream ran, and gather
    then raised on the None. Now the heartbeat is started as a task first, the
    stream result is awaited only if awaitable, and the heartbeat is cancelled
    once the stream ends so it does not report a dead bot as running.
    """
    import inspect

    heartbeat = asyncio.create_task(send_heartbeat())
    try:
        result = stream.run()
        if inspect.isawaitable(result):
            await result
    finally:
        heartbeat.cancel()

# Entry point; nest_asyncio allows asyncio.run() while a loop is already running.
if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    asyncio.run(run())


In [None]:
#--------------------------------------------------------------------------------------------------------------------------#
# System-level setup: add the project root (parent of the working directory) to sys.path
# so the project packages can be imported.
import sys
from pathlib import Path
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
#--------------------------------------------------------------------------------------------------------------------------#
# Working Schwab version
import json
import requests
import websockets
import asyncio
import nest_asyncio
import pandas as pd
from collections import defaultdict, deque
from datetime import datetime
from data.streaming.authenticator import Authenticator
from utils.logger import Logger
from utils.configloader import ConfigLoader
from data.streaming.schwab_client import SchwabClient
from core.eventhandler import EventHandler  

# ---------------------- Historical Loader ---------------------- #
class HistoricalBarLoader:
    """Read the tail of a symbol's processed-bar JSON file as a list of dicts."""

    def __init__(self, path):
        self.path = path

    def load_last_n_bars(self, symbol: str, n=10):
        """Return the last ``n`` complete OHLCV records for ``symbol``, ordered by Date.

        Records missing any of Open/High/Low/Close/Volume are dropped first.
        Any error is printed and results in an empty list.
        """
        source = f"{self.path}/proc_{symbol}_file.json"
        required = ["Open", "High", "Low", "Close", "Volume"]
        try:
            with open(source, 'r') as fh:
                frame = pd.DataFrame(json.load(fh))
            complete = frame.dropna(subset=required).sort_values(by="Date")
            return complete.iloc[-n:].to_dict("records")
        except Exception as e:
            print(f"[ERROR] Failed to load historical bars for {symbol}: {e}")
            return []

# ---------------------- Strategy ---------------------- #
class MomentumStrategy:
    """Toy strategy: once 5+ bars exist, follow the direction of the last bar's body."""

    def generate_signal(self, bars):
        """Return BUY/SELL at 0.75 confidence for an up/down last bar; otherwise None."""
        if len(bars) < 5:
            return None
        latest = bars[-1]
        last_open, last_close = latest["open"], latest["close"]
        direction = None
        if last_close > last_open:
            direction = "BUY"
        elif last_close < last_open:
            direction = "SELL"
        return None if direction is None else {"action": direction, "confidence": 0.75}

# ---------------------- Executor ---------------------- #
class MockExecutor:
    """Paper executor: records and logs each order instead of sending it."""

    def __init__(self):
        self.config = ConfigLoader().load_config()
        # str(...) replaces f'{self.config['folders']['logs']}'. Reusing the
        # f-string's own quote character inside it is a SyntaxError before
        # Python 3.12.
        log_dir = str(self.config["folders"]["logs"])
        self.logger = Logger('app.log', 'MockExecutor', log_dir=log_dir).get_logger()
        self.orders = []

    def execute(self, symbol: str, signal: dict):
        """Record ``signal`` (its action and confidence, default 0) for ``symbol`` and log it."""
        action = signal.get("action")
        confidence = signal.get("confidence", 0)
        order_log = {"symbol": symbol, "action": action, "confidence": confidence}
        self.orders.append(order_log)
        self.logger.info(f"[MOCK ORDER] {action} {symbol} (confidence: {confidence})")

# ---------------------- Strategy Handler ---------------------- #
def strategy_on_bar(event, strategy, executor, handler, window_size=5):
    """Append a BAR_CREATED bar to its symbol's window, then run the strategy.

    ``event.payload`` is read as ``{"symbol": ..., "bar": ...}``, matching what
    SchwabStreamingClient emits. The window is trimmed to the newest
    ``window_size`` bars *before* the strategy runs. Previously one bar was
    popped afterwards, so the strategy could see ``window_size + 1`` bars, and a
    window seeded with more history never shrank back.

    :param window_size: Bars kept per symbol, and required before trading.
    """
    data = event.payload
    symbol = data["symbol"]
    window = handler.bar_windows[symbol]
    window.append(data["bar"])

    while len(window) > window_size:
        window.popleft()

    if len(window) < window_size:
        handler.logger.info(f"[Handler] Waiting for full {window_size}-bar window for {symbol}")
        return

    signal = strategy.generate_signal(list(window))
    if signal:
        executor.execute(symbol, signal)


# ---------------------- Streaming Client ---------------------- #
class SchwabStreamingClient():
    """Schwab LEVELONE_EQUITIES websocket client that buffers trades into periodic bars.

    ``websocket_client`` logs in, subscribes to the symbols, and buffers
    ``(timestamp, last price, last size)`` ticks. ``aggregate_bars`` turns each
    symbol's buffered ticks into an OHLCV bar every ``bar_interval`` seconds and
    emits it as a ``BAR_CREATED`` event on ``self.event_handler``.
    """

    USER_PREFERENCE_URL = "https://api.schwabapi.com/trader/v1/userPreference"

    def __init__(self, apikey, secretkey, request_timeout=30):
        self.authenticator = Authenticator()
        self.config = ConfigLoader().load_config()
        self.apikey = apikey
        self.secretkey = secretkey
        self.streamer_info = None
        self.connection = None
        self.event_handler = EventHandler()
        # str(...) avoids reusing the f-string's own quotes (a SyntaxError before Python 3.12).
        self.streaming_logger = Logger('app.log', 'SchwabStreamingClient', log_dir=str(self.config['folders']['logs'])).get_logger()
        self.tick_buffer = defaultdict(list)
        self.bar_data = defaultdict(deque)
        self.bar_interval = 5  # seconds between bar aggregations
        self.request_timeout = request_timeout  # seconds for the user-preference HTTP call

    @staticmethod
    def _extract_trades(parsed_message):
        """Return ``(symbol, price, volume)`` for every quote item that carries a trade.

        Field "3" is read as the last price and "9" as the last size. Every entry
        of ``data`` is scanned. The previous ``get('data', [{}])[0]`` ignored all
        but the first entry and raised IndexError on an empty list. Items with
        missing/zero or unparsable values are skipped.
        """
        trades = []
        if not isinstance(parsed_message, dict):
            return trades
        for entry in parsed_message.get('data') or []:
            if not isinstance(entry, dict):
                continue
            for item in entry.get('content') or []:
                symbol = item.get('key')
                trade_price = item.get('3')
                trade_size = item.get('9')
                if not (symbol and trade_price and trade_size):
                    continue
                try:
                    trades.append((symbol, float(trade_price), int(trade_size)))
                except (TypeError, ValueError):
                    continue
        return trades

    @staticmethod
    def _build_bar(symbol, ticks, now):
        """Collapse ``(timestamp, price, volume)`` ticks into one OHLCV bar dict labelled at ``now``'s minute."""
        prices = [price for _, price, _ in ticks]
        return {
            'timestamp': now.replace(second=0, microsecond=0),
            'symbol': symbol,
            'open': prices[0],
            'high': max(prices),
            'low': min(prices),
            'close': prices[-1],
            'volume': sum(volume for _, _, volume in ticks),
        }

    async def websocket_client(self, symbols):
        """Fetch streamer info, log in, subscribe to ``symbols`` and buffer trade ticks until the socket closes."""
        headers = {'Authorization': f"Bearer {self.authenticator.access_token()}"}
        try:
            response = requests.get(headers=headers, url=self.USER_PREFERENCE_URL, timeout=self.request_timeout)
            response.raise_for_status()
            user_preference = response.json()
            self.streamer_info = user_preference['streamerInfo'][0]
            self.streaming_logger.info("Retrieved user preferences successfully")
        except Exception as e:
            self.streaming_logger.error(f"Failed to retrieve user preferences: {e}")
            return

        login_request = {
            'service': 'ADMIN', 'requestid': 0, 'command': 'LOGIN',
            'SchwabClientCustomerId': self.streamer_info['schwabClientCustomerId'],
            'SchwabClientCorrelId': self.streamer_info['schwabClientCorrelId'],
            'parameters': {
                'Authorization': self.authenticator.access_token(),
                'SchwabClientChannel': self.streamer_info['schwabClientChannel'],
                'SchwabClientFunctionId': self.streamer_info['schwabClientFunctionId']
            }
        }

        symbol_request = {
            'service': 'LEVELONE_EQUITIES', 'requestid': 1, 'command': 'SUBS',
            'SchwabClientCustomerId': self.streamer_info['schwabClientCustomerId'],
            'SchwabClientCorrelId': self.streamer_info['schwabClientCorrelId'],
            'parameters': {
                'keys': ','.join(symbols),
                'fields': ','.join(str(field) for field in range(0, 42))
            }
        }

        async with websockets.connect(self.streamer_info['streamerSocketUrl']) as ws:
            await ws.send(json.dumps(login_request))
            # The first reply is taken as the login acknowledgement (logged, not checked).
            try:
                message = await ws.recv()
                self.streaming_logger.info("Login Message Received:")
                self.streaming_logger.info(message)
            except websockets.ConnectionClosed:
                self.streaming_logger.error("Connection closed during login")
                return

            await ws.send(json.dumps(symbol_request))

            while True:
                try:
                    message = await ws.recv()
                    parsed_message = json.loads(message)
                except websockets.ConnectionClosed:
                    self.streaming_logger.error("WebSocket connection closed")
                    break
                except json.JSONDecodeError:
                    self.streaming_logger.warning("Received non-JSON message")
                    continue

                for symbol, price, volume in self._extract_trades(parsed_message):
                    self.tick_buffer[symbol].append((datetime.utcnow(), price, volume))

    async def aggregate_bars(self):
        """Every ``bar_interval`` seconds, emit one bar per symbol with buffered ticks."""
        while True:
            await asyncio.sleep(self.bar_interval)
            now = datetime.utcnow()

            # Snapshot: emit() runs handlers synchronously and they could add keys.
            for symbol, ticks in list(self.tick_buffer.items()):
                if not ticks:
                    continue
                bar = self._build_bar(symbol, ticks, now)
                self.bar_data[symbol].append(bar)
                self.streaming_logger.info(f"[BAR] {symbol} - {bar}")
                self.tick_buffer[symbol] = []
                self.event_handler.emit("BAR_CREATED", {"symbol": symbol, "bar": bar})

    async def run(self, symbols):
        """Run the websocket reader and the bar aggregator concurrently."""
        await asyncio.gather(
            self.websocket_client(symbols),
            self.aggregate_bars()
        )

# ---------------------- Main Entry ---------------------- #
if __name__ == "__main__":
    symbols = ['AAPL', 'MSFT', 'TSLA']
    # NOTE(review): placeholder credentials. The client stores them, but the
    # visible code authenticates through Authenticator instead.
    streamer = SchwabStreamingClient(apikey="your_api", secretkey="your_secret")

    # NOTE(review): machine-specific absolute path.
    loader = HistoricalBarLoader("C:/Users/kwasi/OneDrive/Documents/Personal Projects/schwab_trader/data/data_storage/proc_data")
    handler = EventHandler()
    strategy = MomentumStrategy()
    executor = MockExecutor()

    # Seed each symbol's window with 4 historical bars, so the first live bar
    # completes the 5-bar window that strategy_on_bar waits for.
    # NOTE(review): historical records use capitalized keys ("Close") while live
    # bars use lowercase ("close"). The strategy only reads the newest (live)
    # bar, so this works but is fragile.
    for symbol in symbols:
        history = loader.load_last_n_bars(symbol, n=4)
        handler.bar_windows[symbol].extend(history)

    # Route every BAR_CREATED event through the strategy and the mock executor,
    # and make the streamer emit on this handler.
    handler.subscribe("BAR_CREATED", lambda event: strategy_on_bar(event, strategy, executor, handler))
    streamer.event_handler = handler

    nest_asyncio.apply()
    asyncio.run(streamer.run(symbols))



In [None]:
import asyncio
import websockets
import json
import requests
from datetime import datetime, UTC, timedelta
from collections import defaultdict
from core.eventhandler import EventHandler  # Adjust import as needed

class SchwabStreamingClient():
    def __init__(self, apikey, secretkey):
        self.authenticator = Authenticator()
        self.config = ConfigLoader().load_config()
        self.apikey = apikey
        self.secretkey = secretkey
        self.streamer_info = None
        self.connection = None
        self.event_handler = EventHandler()
        self.streaming_logger = Logger('app.log', 'SchwabStreamingClient', log_dir=f'{self.config['folders']['logs']}').get_logger()
        self.tick_buffer = defaultdict(list)
        self.bar_data = defaultdict(deque)
        self.bar_interval = 5

    async def websocket_client(self, symbols):
        url = r"https://api.schwabapi.com/trader/v1/userPreference"
        headers = {'Authorization': f"Bearer {self.authenticator.access_token()}"}
        self.handler = EventHandler()
        self.last_emit_time = defaultdict(lambda: datetime.now(UTC))
        self.live_bars = defaultdict(lambda: {
            "open": None, "high": None, "low": None, "close": None, "volume": 0
        })

        try:
            response = requests.get(headers=headers, url=url)
            response.raise_for_status()
            user_preference = response.json()
            self.streamer_info = user_preference['streamerInfo'][0]
            self.streaming_logger.info("Retrieved user preferences successfully")
        except Exception as e:
            self.streaming_logger.error(f"Failed to retrieve user preferences: {e}")
            return

        login_request = {
            'service': 'ADMIN', 'requestid': 0, 'command': 'LOGIN',
            'SchwabClientCustomerId': self.streamer_info['schwabClientCustomerId'],
            'SchwabClientCorrelId': self.streamer_info['schwabClientCorrelId'],
            'parameters': {
                'Authorization': self.authenticator.access_token(),
                'SchwabClientChannel': self.streamer_info['schwabClientChannel'],
                'SchwabClientFunctionId': self.streamer_info['schwabClientFunctionId']
            }
        }

        symbol_request = {
            'service': 'LEVELONE_EQUITIES', 'requestid': 1, 'command': 'SUBS',
            'SchwabClientCustomerId': self.streamer_info['schwabClientCustomerId'],
            'SchwabClientCorrelId': self.streamer_info['schwabClientCorrelId'],
            'parameters': {
                'keys': ','.join(symbols),
                'fields': ','.join(str(field) for field in range(0, 42))
            }
        }

        async with websockets.connect(self.streamer_info['streamerSocketUrl']) as ws:
            await ws.send(json.dumps(login_request))

            # Wait for login acknowledgment
            while True:
                try:
                    message = await ws.recv()
                    self.streaming_logger.info("Login Message Received:")
                    self.streaming_logger.info(message)
                    break
                except websockets.ConnectionClosed:
                    self.streaming_logger.error("Connection closed during login")
                    return

            await ws.send(json.dumps(symbol_request))

            # Begin processing tick messages
            while True:
                try:
                    message = await ws.recv()
                    parsed_message = json.loads(message)
                    content = parsed_message.get('data', [{}])[0].get('content', [])

                    for item in content:
                        symbol = item.get('key')
                        trade_price = item.get('3')
                        trade_size = item.get('9')

                        if symbol and trade_price and trade_size:
                            try:
                                price = float(trade_price)
                                volume = int(trade_size)
                                now = datetime.now(UTC)

                                bar = self.live_bars[symbol]

                                if bar["open"] is None:
                                    bar["open"] = price
                                    bar["high"] = price
                                    bar["low"] = price
                                else:
                                    bar["high"] = max(bar["high"], price)
                                    bar["low"] = min(bar["low"], price)

                                bar["close"] = price
                                bar["volume"] += volume

                                # Emit the bar every 60 seconds
                                if (now - self.last_emit_time[symbol]).total_seconds() >= 5:
                                    bar_payload = {
                                        "timestamp": now,
                                        "symbol": symbol,
                                        "open": bar["open"],
                                        "high": bar["high"],
                                        "low": bar["low"],
                                        "close": bar["close"],
                                        "volume": bar["volume"]
                                    }
                                    self.handler.emit("BAR_CREATED", {
                                        "symbol": symbol,
                                        "bar": bar_payload
                                    })
                                    self.last_emit_time[symbol] = now
                                    self.live_bars[symbol] = {
                                        "open": None, "high": None, "low": None, "close": None, "volume": 0
                                    }

                            except ValueError:
                                continue

                except websockets.ConnectionClosed:
                    self.streaming_logger.error("WebSocket connection closed")
                    break
                except json.JSONDecodeError:
                    self.streaming_logger.warning("Received non-JSON message")
                    continue

    async def run(self, symbols):
        await asyncio.gather(
        self.websocket_client(symbols),
                )
        

# NOTE(review): HistoricalBarLoader, MomentumStrategy, MockExecutor,
# strategy_on_bar and nest_asyncio come from the previous cell; this cell only
# works if that cell has already run.
if __name__ == "__main__":
    symbols = ['AAPL', 'MSFT', 'TSLA']
    # NOTE(review): placeholder credentials.
    streamer = SchwabStreamingClient(apikey="your_api", secretkey="your_secret")

    # NOTE(review): machine-specific absolute path.
    loader = HistoricalBarLoader("C:/Users/kwasi/OneDrive/Documents/Personal Projects/schwab_trader/data/data_storage/proc_data")
    handler = EventHandler()
    strategy = MomentumStrategy()
    executor = MockExecutor()

    # History is loaded but not used (the extend is commented out), so
    # strategy_on_bar waits for a full window of live bars before trading.
    for symbol in symbols:
        history = loader.load_last_n_bars(symbol, n=4)
        #handler.bar_windows[symbol].extend(history)

    # NOTE(review): assigning event_handler only routes bars here if the client
    # emits BAR_CREATED on self.event_handler — confirm.
    handler.subscribe("BAR_CREATED", lambda event: strategy_on_bar(event, strategy, executor, handler))
    streamer.event_handler = handler

    nest_asyncio.apply()
    asyncio.run(streamer.run(symbols))

In [None]:
#--------------------------------------------------------------------------------------------------------------------------#
# system level stuff to make sure we get the right root and can import the stuff we want # THIS VERSION IS WORKING!
import sys
from pathlib import Path
project_root = Path.cwd().parent  # assumes the kernel's cwd is one level below the project root
_root_entry = str(project_root)
if not any(entry == _root_entry for entry in sys.path):
    sys.path.append(_root_entry)
#--------------------------------------------------------------------------------------------------------------------------#
# Mock streamer: emits synthetic bars (random price changes) into the system so we can see how the executor handles them.
import asyncio
import nest_asyncio
import logging
import random
import json
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, UTC
from collections import defaultdict, deque
import matplotlib.pyplot as plt
from indicators.atr import ATRIndicator
from strategies.strategy_registry.stochastic_strategy import StochasticStrategy
from strategies.strategy_registry.rsi_strategy import RSIStrategy
from strategies.strategy_registry.macd_strategy import MACDStrategy
from strategies.strategy_registry.momentum_strategy import MomentumStrategy
from core.position_sizer import DynamicPositionSizer
from core.eventhandler import EventHandler

# ---------------------------- Logger Setup ----------------------------
logging.basicConfig(level=logging.INFO)  # root handler: INFO and above by default
logger = logging.getLogger(name="LiveRunner")
logger.setLevel(level=logging.DEBUG)  # this runner's own logger is more verbose than the root

# ---------------------------- Historical Loader ----------------------------
class HistoricalBarLoader:
    """Reads processed historical bars from ``<path>/proc_<SYMBOL>_file.json`` files."""

    def __init__(self, path):
        self.path = path

    def load_last_n_bars(self, symbol: str, n=99):
        """
        Return the last ``n`` bars stored for ``symbol``.

        :param symbol: ticker used to build the file name.
        :param n: number of trailing bars to return. ``n <= 0`` returns an empty list
            (a bare ``data[-0:]`` would return the whole file).
        :return: the trailing slice of the JSON file's top-level sequence, or [] when the file
            is missing. NOTE(review): assumes the file holds a JSON list of bar dicts.
        """
        file_path = Path(self.path) / f"proc_{symbol}_file.json"
        if not file_path.exists():
            logger.warning(f"No historical file found for {symbol} at {file_path}")
            return []
        if n <= 0:
            return []
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return data[-n:]

    def get_latest_close_price(self, symbol: str, default: float = 300.0) -> float:
        """
        Return the most recent close for ``symbol``.

        Reads the ``"Close"`` key, then ``"close"``. Returns ``default`` when no bar is available
        or both keys are missing or falsy.
        """
        bars = self.load_last_n_bars(symbol, n=1)
        if not bars:
            return default  # fallback if missing
        last_bar = bars[0]
        return last_bar.get("Close") or last_bar.get("close") or default
#--------------------------------------- Live Plotter -------------------------------------------------#
class LivePlotter:
    """
    Live-updating matplotlib chart with one subplot row per symbol.

    Each row shows the rolling close price, volume (on a secondary y-axis with hidden ticks),
    strategy signals, executed trades, and an optional PnL series. Rolling buffers keep at most
    ``window`` points per symbol.
    """

    def __init__(self, symbols, window=100):
        self.symbols = symbols
        self.window = window

        self.price_buffers = {s: deque(maxlen=window) for s in symbols}
        self.timestamps = {s: deque(maxlen=window) for s in symbols}
        self.volume_buffers = {s: deque(maxlen=window) for s in symbols}
        self.signal_buffers = {s: deque(maxlen=window) for s in symbols}
        self.pnl_buffers = {s: deque(maxlen=window) for s in symbols}
        self.trade_markers = {s: [] for s in symbols}  # list of (timestamp, price, type)

        plt.ion()
        self.fig, self.axes = plt.subplots(len(symbols), 1, figsize=(10, 3.5 * len(symbols)), sharex=True)
        if len(symbols) == 1:
            self.axes = [self.axes]
        # The volume twin axis of each row from the previous draw(), removed before redrawing.
        self._volume_axes = [None] * len(symbols)

    def update_bar(self, symbol, bar):
        """Append one bar's close, timestamp and volume (default 0) to the rolling buffers."""
        self.price_buffers[symbol].append(bar["close"])
        self.timestamps[symbol].append(bar["timestamp"])
        self.volume_buffers[symbol].append(bar.get("volume", 0))

    def record_signal(self, symbol, timestamp, signal):
        """Record a strategy signal (1 = long, -1 = short; other values are not drawn)."""
        self.signal_buffers[symbol].append((timestamp, signal))

    def record_trade(self, symbol, timestamp, price, side):
        """Record an executed trade marker."""
        self.trade_markers[symbol].append((timestamp, price, side))  # side = "BUY" or "SELL"

    def record_pnl(self, symbol, timestamp, pnl):
        """Record a PnL sample to plot as a dashed line."""
        self.pnl_buffers[symbol].append((timestamp, pnl))

    def draw(self):
        """Redraw every symbol's subplot from the current buffers and let the GUI refresh."""
        for i, symbol in enumerate(self.symbols):
            ax = self.axes[i]
            ax.clear()

            timestamps = list(self.timestamps[symbol])
            prices = list(self.price_buffers[symbol])

            # Plot close price
            ax.plot(timestamps, prices, label="Price", color="black")

            # Markers are filtered against the oldest visible bar, so skip them until a bar exists.
            if timestamps:
                window_start = timestamps[0]

                # Trades: drop markers that scrolled out of the window (keeps the list bounded).
                self.trade_markers[symbol] = [m for m in self.trade_markers[symbol] if m[0] >= window_start]
                for ts, px, side in self.trade_markers[symbol]:
                    color = "green" if side == "BUY" else "red"
                    ax.scatter(ts, px, color=color, marker="^" if side == "BUY" else "v", s=50, zorder=5)

                # Signals: anchor each marker at the close recorded for its own timestamp,
                # falling back to the latest close when no bar shares that timestamp.
                price_at = dict(zip(timestamps, prices))
                for ts, sig in self.signal_buffers[symbol]:
                    if ts >= window_start and sig in (1, -1):
                        marker = "^" if sig == 1 else "v"
                        ax.plot(ts, price_at.get(ts, prices[-1]), marker, color="blue", alpha=0.3)

            # Volume on a secondary y-axis. ax.clear() does not delete twinned axes, so remove the
            # previous twin explicitly; otherwise every redraw stacks another Axes on the figure.
            if self._volume_axes[i] is not None:
                self._volume_axes[i].remove()
            ax2 = ax.twinx()
            self._volume_axes[i] = ax2
            ax2.bar(timestamps, list(self.volume_buffers[symbol]), color="gray", alpha=0.2, width=0.001)
            ax2.set_yticks([])

            # Plot cumulative PnL (optional)
            pnl_vals = self.pnl_buffers[symbol]
            if pnl_vals:
                pnl_ts, pnl_series = zip(*pnl_vals)
                ax.plot(pnl_ts, pnl_series, label="PnL", color="orange", linestyle="--")

            ax.set_title(symbol)
            ax.grid(True)
            ax.legend()

        self.fig.tight_layout()
        plt.pause(0.001)
# ---------------------------- GBM Simulator ----------------------------
import numpy as np
import random
import logging
from datetime import datetime
from pytz import UTC

import math


class GBMSimulator:
    """
    Simulates Geometric Brownian Motion (GBM) with cyclical drift and occasional price shocks.
    Generates synthetic OHLCV bars with symbol-specific sinusoidal drift and volatility.
    """

    def __init__(self, symbols, base_price=300.0, log_prices=False):
        self.symbols = symbols
        self.price_state = {
            symbol: base_price[symbol] if isinstance(base_price, dict) else base_price + random.uniform(-50, 50)
            for symbol in symbols
        }
        self.volatility = {
            symbol: random.uniform(0.01, 0.03) for symbol in symbols
        }

        # ⬇️ REPLACE random drift with time-tracked sinusoidal drift
        self.t = {symbol: 0 for symbol in symbols}              # bar counter per symbol
        self.cycle_length = 1000                                # bars per full up/down cycle
        self.max_drift = 0.0006                                 # max daily drift ~15% annualized

        self.dt = 1 / 390  # 1-minute bar
        self.log_prices = log_prices
        self.logger = logging.getLogger("MockStream")
        self.logger.setLevel(logging.DEBUG)

        # Shock parameters
        self.shock_probability = 0.002
        self.shock_magnitude_range = (0.05, 0.10)
        self.shock_vol_boost = 1.5

    def maybe_apply_shock(self, symbol, price):
        if random.random() < self.shock_probability:
            direction = random.choice([-1, 1])
            magnitude = random.uniform(*self.shock_magnitude_range)
            shock_factor = 1 + direction * magnitude
            shocked_price = max(1.0, price * shock_factor)
            self.volatility[symbol] *= self.shock_vol_boost
            self.logger.warning(f"[{symbol}] *** PRICE SHOCK *** | {'UP' if direction > 0 else 'DOWN'} {magnitude:.2%} → {shocked_price:.2f}")
            return shocked_price
        return price

    def generate_bar(self, symbol):
        prev_price = self.price_state[symbol]
        sigma = self.volatility[symbol]
        Z = np.random.normal(0, 1)

        # ⬇️ Update time and compute sinusoidal drift
        self.t[symbol] += 1
        t = self.t[symbol]
        mu = self.max_drift * math.sin(2 * math.pi * t / self.cycle_length)

        # GBM step
        change = (mu - 0.5 * sigma ** 2) * self.dt + sigma * np.sqrt(self.dt) * Z
        new_price = max(1.0, prev_price * np.exp(change))

        # Possibly apply a shock
        close = self.maybe_apply_shock(symbol, new_price)

        # OHLCV construction
        open_ = prev_price
        wick_range = abs(np.random.normal(0, 0.002))
        high = max(open_, close) * (1 + wick_range)
        low = min(open_, close) * (1 - wick_range)
        volume = random.randint(500, 5000)

        self.price_state[symbol] = close

        bar = {
            "timestamp": datetime.now(UTC),
            "symbol": symbol,
            "open": round(open_, 2),
            "high": round(high, 2),
            "low": round(low, 2),
            "close": round(close, 2),
            "volume": volume
        }

        if self.log_prices:
            self.logger.info(f"[{symbol}] Bar - O: {bar['open']} | H: {bar['high']} | L: {bar['low']} | C: {bar['close']} | Vol: {bar['volume']}")

        return bar

    def update_all(self):
        return {symbol: self.generate_bar(symbol) for symbol in self.symbols}


# ---------------------------- Strategy Router ----------------------------
class StrategyRouter:
    """
    Per-symbol strategy lookup.

    Symbols listed in ``custom_strategies`` use their own strategy instance; every other
    symbol falls back to ``default_strategy``. This lets several strategies run side by side
    across different stocks.
    """
    def __init__(self, default_strategy, custom_strategies=None):
        self.default_strategy = default_strategy
        # e.g., {"AAPL": RSIStrategy(), "TSLA": StochasticStrategy()}
        self.custom_strategies = custom_strategies if custom_strategies else {}

    def get_strategy(self, symbol):
        """Return the strategy assigned to ``symbol``, or the default strategy."""
        if symbol in self.custom_strategies:
            return self.custom_strategies[symbol]
        return self.default_strategy

# ---------------------------- Drawdown Monitor ----------------------------
from collections import defaultdict
import logging

class DrawdownMonitor:
    """
    Tracks per-symbol drawdowns and disables trading if drawdown exceeds a threshold.
    Also supports logging, unlocking, and manual reset.
    Useful for capital preservation during live trading or simulation.

    :param max_drawdown: fractional drop from the peak (e.g. 0.35 = 35%) that locks a symbol.
    :param recovery_ratio: in ``update()``, a locked symbol unlocks once its value is back to at
        least ``recovery_ratio * peak``.
    """

    def __init__(self, max_drawdown=0.35, recovery_ratio=0.85):
        self.max_drawdown = max_drawdown
        self.recovery_ratio = recovery_ratio
        # symbol -> highest value seen. The key is absent until the first update(), which
        # seeds the peak with that first value instead of 0.
        self.peak = {}
        self.locked = defaultdict(lambda: False)
        self.logger = logging.getLogger("DrawdownMonitor")
        self.logger.setLevel(logging.DEBUG)

    def _drawdown(self, symbol, portfolio_value):
        """Fractional change from the stored peak (<= 0 once the peak is updated); 0.0 if the peak is not positive."""
        peak = self.peak[symbol]
        if peak <= 0:
            return 0.0  # relative drawdown is undefined without a positive peak
        return (portfolio_value - peak) / peak

    def update(self, symbol, portfolio_value):
        """
        Feed the latest value for ``symbol``. Returns True if trading is allowed, False if locked.
        """
        if symbol not in self.peak:
            self.peak[symbol] = portfolio_value
            self.logger.debug(f"[InitPeak] {symbol} initialized at {portfolio_value:.2f}")
            return True

        if self.locked[symbol]:
            if portfolio_value >= self.recovery_ratio * self.peak[symbol]:
                self.locked[symbol] = False
                self.logger.info(f"[UNLOCKED] {symbol} trading re-enabled after drawdown recovery.")
            return not self.locked[symbol]

        # Update peak if new high
        if portfolio_value > self.peak[symbol]:
            self.peak[symbol] = portfolio_value

        drawdown = self._drawdown(symbol, portfolio_value)
        if drawdown < -self.max_drawdown:
            self.locked[symbol] = True
            self.logger.warning(f"[LOCKED] {symbol} drawdown exceeded: {drawdown:.2%}")
            return False

        self.logger.debug(
            f"[DrawdownCheck] {symbol} | Peak: {self.peak[symbol]:.2f} | Current: {portfolio_value:.2f} | Drawdown: {drawdown:.2%} | Locked: {self.locked[symbol]}"
        )
        return True

    def is_locked(self, symbol):
        """
        Check if trading is locked for a given symbol.
        """
        return self.locked[symbol]

    def reset(self, symbol):
        """
        Fully reset the lock and peak value for a given symbol.

        The peak is forgotten rather than set to 0, so the next ``update()`` re-seeds it from
        the first value seen (a 0 peak would make the relative drawdown undefined).
        """
        self.logger.info(f"[RESET] {symbol} lock and peak reset.")
        self.locked[symbol] = False
        self.peak.pop(symbol, None)

    def unlock(self, symbol):
        """
        Unlock trading without resetting the peak.
        Use when drawdown improves above safety buffer.
        """
        if self.locked[symbol]:
            self.locked[symbol] = False
            self.logger.info(f"[UNLOCKED] {symbol} trading re-enabled after drawdown recovery.")

    def record_drawdown(self, symbol, drawdown):
        """
        Pass in current drawdown to check against threshold.
        Should be called after drawdown calculation. Has no effect while already locked.
        """
        if self.locked[symbol]:
            return

        if drawdown < -self.max_drawdown:
            self.locked[symbol] = True
            self.logger.warning(f"[LOCKED] {symbol} drawdown exceeded: {drawdown:.2%}")

# ---------------------------- Mock Executor ----------------------------
class MockExecutor:
    def __init__(self):
        self.logger = logging.getLogger("MockExecutor")
        self.logger.setLevel(logging.DEBUG)
        self.peak_portfolio_value = defaultdict(lambda: 0)
        self.portfolio_history = defaultdict(list)
        self.cash = 100_000  # Shared pool across all symbols
        self.position = defaultdict(int)
        self.total_fees = defaultdict(float)
        self.trailing_stop = defaultdict(lambda: None)
        self.sizer = DynamicPositionSizer(risk_percentage=0.07)
        self.MIN_TRADE_QTY = 1
        self.MIN_FORCE_EXIT_QTY = 10
        self.MAX_PARTIAL_EXITS = 5
        self.partial_exit_counts = defaultdict(int)
        self.drawdown_monitor = DrawdownMonitor(max_drawdown=0.3)
        self.entry_prices = defaultdict(float)
        self.realized_pnl = defaultdict(float)
        self.unrealized_pnl = defaultdict(float)
        self.market_regimes = defaultdict(lambda: "unknown")
        self.last_price = {}
        self.take_profit = defaultdict(lambda: None)
        # Adaptive TP multipliers per regime
        self.TP_MULTIPLIERS = {
            "low_volatility": 1.03,
            "medium_volatility": 1.05,
            "high_volatility": 1.08,
            "unknown": 1.04
        }

    def log_trade_details(self, symbol, action_type, qty, price, sl=None, ts=None, fee=0.0, pnl=None, cash_before=None, cash_after=None, pos_before=None, pos_after=None, atr=None, regime=None):
        now = datetime.now(UTC)
        self.logger.info(
            f"[{now}] [{symbol}] {action_type} | "
            f"Qty: {qty:.1f} | "
            f"Price: ${price:.2f} | "
            f"{f'Stop Loss: ${sl:.2f} | ' if sl else ''}"
            f"{f'Trail Stop: ${ts:.2f} | ' if ts else ''}"
            f"Cash: ${cash_before:.2f} → ${cash_after:.2f} | "
            f"Fee: ${fee:.2f} | "
            f"Pos: {pos_before:.1f} → {pos_after:.1f} | "
            f"{f'PnL: ${pnl:.2f} | ' if pnl is not None else ''}"
            f"ATR: {atr:.2f} | "
            f"Regime: {regime}"
        )

    def force_exit_position(self, symbol, price, atr_value, stop_loss_price, market_conditions):
        qty = abs(self.position[symbol])
        if qty < self.MIN_TRADE_QTY:
            self.logger.debug(f"[{symbol}] FORCE EXIT SKIPPED: Qty={qty} < MIN_TRADE_QTY={self.MIN_TRADE_QTY}")
            return

        cash_before = self.cash
        pos_before = self.position[symbol]
        trade_fee = 0.001 * price * qty

        if self.position[symbol] > 0:
            proceeds = price * qty - trade_fee
            entry_price = stop_loss_price + (2 * atr_value)
            pnl = proceeds - (entry_price * qty)
            self.cash += proceeds
            self.position[symbol] = 0
            action = "FORCED SELL"
        else:
            cost = price * qty + trade_fee
            entry_price = stop_loss_price - (2 * atr_value)
            pnl = -(cost - (entry_price * qty))
            self.cash -= cost
            self.position[symbol] = 0
            action = "FORCED COVER"

        self.total_fees[symbol] += trade_fee
        self.realized_pnl[symbol] += pnl
        self.log_trade_details(symbol, action, qty, price, fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                               pos_before=pos_before, pos_after=0, pnl=pnl, atr=atr_value, regime=market_conditions)
        self.partial_exit_counts[symbol] = 0
    
    def update_unrealized_pnl(self, latest_prices: dict):
        for symbol in self.positions:
            qty = self.positions[symbol]
            entry_price = self.entry_prices.get(symbol, 0)
            current_price = latest_prices.get(symbol, entry_price)
            
            if qty != 0:
                direction = 1 if qty > 0 else -1
                pnl = (current_price - entry_price) * abs(qty) * direction
                self.unrealized_pnl[symbol] = pnl
            else:
                self.unrealized_pnl[symbol] = 0.0
    
    def log_portfolio_status(self, symbol, price):
        self.last_price[symbol] = price
        pos = self.position[symbol]
        entry_price = self.entry_prices.get(symbol, price)

        # === Unrealized PnL ===
        if pos > 0:
            unrealized_pnl = (price - entry_price) * pos
        elif pos < 0:
            unrealized_pnl = (entry_price - price) * abs(pos)
        else:
            unrealized_pnl = 0.0
        self.unrealized_pnl[symbol] = unrealized_pnl

        # === Realized PnL ===
        realized = self.realized_pnl.get(symbol, 0.0)

        # === Portfolio Value (symbol-specific) ===
        portfolio_value = self.cash + abs(pos) * price

        # === Drawdown BEFORE updating peak ===
        prev_peak = self.peak_portfolio_value[symbol]
        drawdown = (portfolio_value - prev_peak) / prev_peak if prev_peak != 0 else 0.0
        self.peak_portfolio_value[symbol] = max(prev_peak, portfolio_value)

        # === Total portfolio value ===
        total_value = self.cash
        for sym, qty in self.position.items():
            current_price = self.last_price.get(sym, price)
            total_value += abs(qty) * current_price

        total_unrealized = sum(self.unrealized_pnl.values())
        total_realized = sum(self.realized_pnl.values())

        # === Record history ===
        self.portfolio_history[symbol].append({
            "Date": datetime.now(UTC),
            "Portfolio_Value": portfolio_value,
            "Cash": self.cash,
            "Position": pos,
            "Price": price,
            "Drawdown": drawdown,
            "Fees": self.total_fees[symbol],
            "Unrealized_PnL": unrealized_pnl,
            "Realized_PnL": realized,
            "Entry_Price": entry_price
        })

        # === Drawdown Monitoring ===
        self.drawdown_monitor.record_drawdown(symbol, drawdown)
        if self.drawdown_monitor.is_locked(symbol) and drawdown > -0.10:
            self.drawdown_monitor.unlock(symbol)

        # === Logging ===
        self.logger.info(
            f"[{symbol}] STATUS | Port: ${portfolio_value:,.2f} | Total: ${total_value:,.2f} | "
            f"Cash: ${self.cash:,.2f} | Pos: {pos:.1f} | Px: ${price:.2f} | "
            f"U-PnL: ${unrealized_pnl:,.2f} | R-PnL: ${realized:,.2f} | DD: {drawdown:.2%} | "
            f"Total U: ${total_unrealized:,.2f} | Total R: ${total_realized:,.2f}"
        )

    def execute(self, symbol, df, signal, price, atr_value):
        if self.drawdown_monitor.is_locked(symbol):
            pos = self.position[symbol]
            if pos != 0:
                self.logger.warning(f"[{symbol}] LOCKED: Force exiting open position due to drawdown.")
                self.force_exit_position(
                    symbol=symbol,
                    price=price,
                    atr_value=atr_value,
                    stop_loss_price=price - atr_value * 2 if pos > 0 else price + atr_value * 2,
                    market_conditions=self.market_regimes.get(symbol, "unknown")
                )
            else:
                self.logger.warning(f"[{symbol}] SKIPPED: Trading is locked and no position is open.")
            return

        
        exit_fraction = 0.5  # or whatever your partial exit ratio is
        market_conditions = self.market_regimes.get(symbol, "unknown")

        stop_loss_price = price - (atr_value * 2) if signal == 1 else price + (atr_value * 2)

        quantity = self.sizer.calculate_position_size(
            stock_price=price,
            stop_loss_price=stop_loss_price,
            current_cash=self.cash,
            market_conditions=market_conditions,
            signal=signal
        )

        trade_fee = 0.001 * price * quantity
        max_affordable_qty = self.cash // (price + trade_fee)
        quantity = min(quantity, max_affordable_qty)

        if quantity < self.MIN_TRADE_QTY:
            self.logger.debug(f"[{symbol}] SKIPPED: Qty={quantity:.0f} < MIN_TRADE_QTY={self.MIN_TRADE_QTY}")
            return

        now = datetime.now(UTC)
        cash_before = self.cash
        pos_before = self.position[symbol]

        # === OPEN LONG ===
        if signal == 1 and self.position[symbol] == 0:
            self.cash -= (price * quantity + trade_fee)
            self.position[symbol] += quantity
            self.total_fees[symbol] += trade_fee
            self.trailing_stop[symbol] = price * 0.97
            self.partial_exit_counts[symbol] = 0
            self.entry_prices[symbol] = price

            tp_mult = self.TP_MULTIPLIERS.get(market_conditions, 1.05)
            self.take_profit[symbol] = price * tp_mult

            self.log_trade_details(symbol, "BUY", quantity, price, sl=stop_loss_price, ts=self.trailing_stop[symbol],
                                fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                                pos_before=pos_before, pos_after=self.position[symbol],
                                atr=atr_value, regime=market_conditions)

        # === OPEN SHORT ===
        elif signal == -1 and self.position[symbol] == 0:
            self.cash += (price * quantity - trade_fee)
            self.position[symbol] -= quantity
            self.total_fees[symbol] += trade_fee
            self.trailing_stop[symbol] = price * 1.03
            self.partial_exit_counts[symbol] = 0
            self.entry_prices[symbol] = price

            tp_mult = self.TP_MULTIPLIERS.get(market_conditions, 1.05)
            self.take_profit[symbol] = price * (2 - tp_mult)

            self.log_trade_details(symbol, "SHORT SELL", quantity, price, sl=stop_loss_price, ts=self.trailing_stop[symbol],
                                fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                                pos_before=pos_before, pos_after=self.position[symbol],
                                atr=atr_value, regime=market_conditions)

        # === TAKE PROFIT HIT - LONG ===
        if self.position[symbol] > 0 and self.take_profit[symbol] and price >= self.take_profit[symbol]:
            qty = int(self.position[symbol] * exit_fraction)
            if qty >= self.MIN_TRADE_QTY:
                trade_fee = 0.001 * price * qty
                entry_price = self.entry_prices[symbol]
                pnl = (price - entry_price) * qty - trade_fee

                self.cash += price * qty - trade_fee
                self.position[symbol] -= qty
                self.total_fees[symbol] += trade_fee
                self.partial_exit_counts[symbol] += 1
                self.realized_pnl[symbol] += pnl

                self.log_trade_details(symbol, "TAKE PROFIT - SELL (PARTIAL)", qty, price,
                                    fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                                    pos_before=pos_before, pos_after=self.position[symbol],
                                    pnl=pnl, atr=atr_value, regime=market_conditions)

                if abs(self.position[symbol]) <= self.MIN_FORCE_EXIT_QTY or self.partial_exit_counts[symbol] >= self.MAX_PARTIAL_EXITS:
                    self.force_exit_position(symbol, price, atr_value, stop_loss_price, market_conditions)

        # === TAKE PROFIT HIT - SHORT ===
        elif self.position[symbol] < 0 and self.take_profit[symbol] and price <= self.take_profit[symbol]:
            qty = int(abs(self.position[symbol]) * exit_fraction)
            if qty >= self.MIN_TRADE_QTY:
                trade_fee = 0.001 * price * qty
                entry_price = self.entry_prices[symbol]
                pnl = (entry_price - price) * qty - trade_fee

                self.cash -= price * qty + trade_fee
                self.position[symbol] += qty
                self.total_fees[symbol] += trade_fee
                self.partial_exit_counts[symbol] += 1
                self.realized_pnl[symbol] += pnl

                self.log_trade_details(symbol, "TAKE PROFIT - COVER SHORT (PARTIAL)", qty, price,
                                    fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                                    pos_before=pos_before, pos_after=self.position[symbol],
                                    pnl=pnl, atr=atr_value, regime=market_conditions)

                if abs(self.position[symbol]) <= self.MIN_FORCE_EXIT_QTY or self.partial_exit_counts[symbol] >= self.MAX_PARTIAL_EXITS:
                    self.force_exit_position(symbol, price, atr_value, stop_loss_price, market_conditions)

        # === SIGNAL FLIP - PARTIAL EXIT ===
        elif (signal == -1 and self.position[symbol] > 0) or (signal == 1 and self.position[symbol] < 0):
            qty = int(abs(self.position[symbol]) * exit_fraction)
            if qty < self.MIN_TRADE_QTY:
                self.logger.debug(f"[{symbol}] SKIPPED: Partial exit qty={qty} < MIN_TRADE_QTY={self.MIN_TRADE_QTY}")
                return

            trade_fee = 0.001 * price * qty
            entry_price = self.entry_prices[symbol]
            pnl = (price - entry_price) * qty - trade_fee if self.position[symbol] > 0 else (entry_price - price) * qty - trade_fee

            self.cash += price * qty - trade_fee if self.position[symbol] > 0 else - (price * qty + trade_fee)
            self.position[symbol] += qty if self.position[symbol] < 0 else -qty
            self.total_fees[symbol] += trade_fee
            self.partial_exit_counts[symbol] += 1
            self.realized_pnl[symbol] += pnl

            action = "SELL (PARTIAL)" if self.position[symbol] > 0 else "COVER SHORT (PARTIAL)"
            self.log_trade_details(symbol, action, qty, price, fee=trade_fee,
                                cash_before=cash_before, cash_after=self.cash,
                                pos_before=pos_before, pos_after=self.position[symbol],
                                pnl=pnl, atr=atr_value, regime=market_conditions)

            if abs(self.position[symbol]) <= self.MIN_FORCE_EXIT_QTY or self.partial_exit_counts[symbol] >= self.MAX_PARTIAL_EXITS:
                self.force_exit_position(symbol, price, atr_value, stop_loss_price, market_conditions)

        # === TRAILING STOP HIT - LONG ===
        elif self.position[symbol] > 0 and price < self.trailing_stop[symbol]:
            qty = int(self.position[symbol] * exit_fraction)
            if qty < self.MIN_TRADE_QTY:
                return

            trade_fee = 0.001 * price * qty
            entry_price = self.entry_prices[symbol]
            pnl = (price - entry_price) * qty - trade_fee

            self.cash += price * qty - trade_fee
            self.position[symbol] -= qty
            self.total_fees[symbol] += trade_fee
            self.partial_exit_counts[symbol] += 1
            self.realized_pnl[symbol] += pnl

            self.log_trade_details(symbol, "TRAIL STOP HIT - SELL (PARTIAL)", qty, price, ts=self.trailing_stop[symbol],
                                fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                                pos_before=pos_before, pos_after=self.position[symbol],
                                pnl=pnl, atr=atr_value, regime=market_conditions)

            if abs(self.position[symbol]) <= self.MIN_FORCE_EXIT_QTY or self.partial_exit_counts[symbol] >= self.MAX_PARTIAL_EXITS:
                self.force_exit_position(symbol, price, atr_value, stop_loss_price, market_conditions)

        # === TRAILING STOP HIT - SHORT ===
        elif self.position[symbol] < 0 and price > self.trailing_stop[symbol]:
            qty = int(abs(self.position[symbol]) * exit_fraction)
            if qty < self.MIN_TRADE_QTY:
                return

            trade_fee = 0.001 * price * qty
            entry_price = self.entry_prices[symbol]
            pnl = (entry_price - price) * qty - trade_fee

            self.cash -= price * qty + trade_fee
            self.position[symbol] += qty
            self.total_fees[symbol] += trade_fee
            self.partial_exit_counts[symbol] += 1
            self.realized_pnl[symbol] += pnl

            self.log_trade_details(symbol, "TRAIL STOP HIT - COVER (PARTIAL)", qty, price, ts=self.trailing_stop[symbol],
                                fee=trade_fee, cash_before=cash_before, cash_after=self.cash,
                                pos_before=pos_before, pos_after=self.position[symbol],
                                pnl=pnl, atr=atr_value, regime=market_conditions)

            if abs(self.position[symbol]) <= self.MIN_FORCE_EXIT_QTY or self.partial_exit_counts[symbol] >= self.MAX_PARTIAL_EXITS:
                self.force_exit_position(symbol, price, atr_value, stop_loss_price, market_conditions)
        
        # === PYRAMIDING - Add to winning position ===
        max_pyramid_multiplier = 2  # Allow up to 2x original entry size
        if signal == 1 and self.position[symbol] > 0:
            entry_price = self.entry_prices[symbol]
            if price > entry_price * 1.01 and self.position[symbol] < max_pyramid_multiplier * quantity:
                # Reinvest (pyramid)
                reinvest_qty = int(quantity * 0.5)
                if reinvest_qty >= self.MIN_TRADE_QTY:
                    reinvest_fee = 0.001 * price * reinvest_qty
                    total_cost = price * reinvest_qty + reinvest_fee
                    if self.cash >= total_cost:
                        self.cash -= total_cost
                        self.position[symbol] += reinvest_qty
                        self.total_fees[symbol] += reinvest_fee
                        self.log_trade_details(symbol, "PYRAMID BUY", reinvest_qty, price,
                            fee=reinvest_fee, cash_before=cash_before, cash_after=self.cash,
                            pos_before=pos_before, pos_after=self.position[symbol],
                            atr=atr_value, regime=market_conditions)
        
        # === Update Trailing Stop ===
        if self.position[symbol] > 0:
            self.trailing_stop[symbol] = max(self.trailing_stop[symbol], price * 0.97)
        elif self.position[symbol] < 0:
            self.trailing_stop[symbol] = min(self.trailing_stop[symbol], price * 1.03)
        self.market_regimes[symbol] = market_conditions
        self.log_portfolio_status(symbol, price)

# ---------------------------- Event Handler ----------------------------
class Event:
    """A named event and its payload, as delivered to subscribers by ``EventHandler.emit``."""

    def __init__(self, name, payload):
        self.name = name
        self.payload = payload

class EventHandler:
    """
    Publish/subscribe hub. Every ``EventHandler()`` call returns the same instance.

    Also holds per-symbol bar state used by the strategy callback: ``bar_windows`` (a rolling
    deque of at most 100 bars) and ``current_day_bar`` (the bar being aggregated, or None).
    """
    _instance = None
    _lock = asyncio.Lock()  # NOTE(review): never acquired by this class

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(EventHandler, cls).__new__(cls)
            cls._instance.listeners = defaultdict(list)
            cls._instance.bar_windows = defaultdict(lambda: deque(maxlen=100))
            cls._instance.current_day_bar = defaultdict(lambda: None)
            cls._instance.logger = logging.getLogger("EventHandler")
        return cls._instance

    def subscribe(self, event_name, callback):
        """Register ``callback(event)`` to be called on every emit of ``event_name``."""
        self.logger.debug(f"Subscribed to '{event_name}'")
        self.listeners[event_name].append(callback)

    def emit(self, event_name, payload):
        """
        Wrap ``payload`` in an Event and call every subscriber synchronously, in subscription order.

        Dispatch iterates over a snapshot of the subscriber list, so a callback that subscribes
        during dispatch takes effect from the next emit instead of this one.
        """
        event = Event(event_name, payload)
        self.logger.debug(f"Emitting '{event_name}' with payload: {payload}")
        for callback in list(self.listeners.get(event_name, ())):
            callback(event)
#----------------------------- classify regime--------------------
def classify_regime(price: float, atr: float, low_threshold: float = 0.01, high_threshold: float = 0.03) -> str:
    """
    Bucket volatility by ATR relative to price.

    :param price: latest price; must be positive.
    :param atr: latest ATR. May be None or NaN, e.g. during an indicator's warm-up rows.
    :param low_threshold: ``atr / price`` below this is "low_volatility".
    :param high_threshold: ``atr / price`` above this is "high_volatility".
    :return: "low_volatility", "medium_volatility" or "high_volatility". Returns "unknown"
        when the ratio cannot be computed (missing/NaN/infinite ATR or a non-positive price);
        without that check, NaN comparisons would silently fall through to "medium_volatility".
    """
    if atr is None or price is None or price <= 0:
        return "unknown"
    ratio = atr / price
    if not math.isfinite(ratio):
        return "unknown"
    if ratio < low_threshold:
        return "low_volatility"
    elif ratio > high_threshold:
        return "high_volatility"
    else:
        return "medium_volatility"
# ---------------------------- Strategy Callback ----------------------------
def strategy_on_bar(event, strategy_map, executor, handler, plotter: LivePlotter = None):
    """
    Handle one BAR_CREATED event for a symbol.

    Folds the incoming bar into the symbol's running daily bar. Once at least
    20 daily bars exist, it computes ATR and the symbol's strategy signal,
    tags the market regime, and forwards everything to ``executor.execute``.

    :param event: Event whose payload is {"symbol": ..., "bar": ...}. The bar is
        read via keys "timestamp" (datetime-like), "open", "high", "low",
        "close" and "volume".
    :param strategy_map: symbol -> strategy exposing ``generate_signal(df)``
    :param executor: object with ``logger``, ``market_regimes`` and ``execute(...)``
    :param handler: EventHandler holding ``bar_windows`` / ``current_day_bar``
    :param plotter: optional LivePlotter; receives the strategy signal per bar
    """
    data = event.payload
    symbol = data["symbol"]
    bar = data["bar"]

    if symbol not in strategy_map:
        executor.logger.warning(f"[{symbol}] No strategy configured. Skipping.")
        return
    strategy = strategy_map[symbol]

    # --- Aggregate the streamed bar into today's running bar ---
    current = handler.current_day_bar[symbol]
    bar_date = bar["timestamp"].date()

    if current is None or current["Date"].date() != bar_date:
        # New day: archive the finished bar (if any) and start a fresh one.
        if current is not None:
            handler.bar_windows[symbol].append(current)
        handler.current_day_bar[symbol] = {
            "Date": bar["timestamp"],
            "Open": bar["open"],
            "High": bar["high"],
            "Low": bar["low"],
            "Close": bar["close"],
            "Volume": bar["volume"]
        }
    else:
        current["High"] = max(current["High"], bar["high"])
        current["Low"] = min(current["Low"], bar["low"])
        current["Close"] = bar["close"]
        current["Volume"] += bar["volume"]

    # --- Indicators and signal on completed bars + today's running bar ---
    full_window = list(handler.bar_windows[symbol]) + [handler.current_day_bar[symbol]]
    if len(full_window) < 20:
        return  # Not enough data
    df = pd.DataFrame(full_window)

    df = ATRIndicator(df).compute()
    df = strategy.generate_signal(df)

    latest = df.iloc[-1]
    signal = latest.get("Signal", None)
    price = latest["Close"]
    atr = latest["ATR"]
    # ATR comes back as NaN (not None) while its window fills. The old
    # `atr is not None` check let NaN through, and it was classified as a regime.
    if pd.notna(atr) and price > 0:
        regime = classify_regime(price, atr)
    else:
        regime = "unknown"

    executor.market_regimes[symbol] = regime
    if plotter:
        # Record the strategy's signal from df. The raw incoming bar is market
        # data (presumably without a "Signal" key), so the previous
        # bar.get("Signal", 0) would always have recorded 0.
        plotted_signal = 0 if signal is None or pd.isna(signal) else signal
        plotter.record_signal(symbol, bar["timestamp"], plotted_signal)

    executor.execute(symbol, df, signal, price, atr)


# ---------------------------- Main Bootstrap ----------------------------
if __name__ == "__main__":
    # nest_asyncio lets asyncio.run() be called inside Jupyter's already-running event loop.
    import nest_asyncio

    symbols = ["AAPL", "MSFT", "TSLA", "AMD", "META", "AMZN"]
    hist_path = r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\data\data_storage\proc_data"

    # Setup components
    executor = MockExecutor()
    handler = EventHandler()
    loader = HistoricalBarLoader(hist_path)
    plotter = LivePlotter(symbols, window=100)
    # Seed the simulator from each symbol's most recent stored close.
    last_prices = {symbol: loader.get_latest_close_price(symbol) for symbol in symbols}

    # One independent MomentumStrategy instance per symbol.
    strategy_map = {symbol: MomentumStrategy() for symbol in symbols}

    # Warm each symbol's window with bars from load_last_n_bars(n=99). "Date" is
    # normalised to a Timestamp; numeric values are treated as epoch milliseconds.
    for symbol in symbols:
        history = loader.load_last_n_bars(symbol, n=99)
        for bar in history:
            raw_date = bar["Date"]
            if isinstance(raw_date, (int, float)):
                bar["Date"] = pd.to_datetime(raw_date, unit='ms')
            else:
                bar["Date"] = pd.to_datetime(raw_date)
            handler.bar_windows[symbol].append(bar)

    handler.subscribe("BAR_CREATED", lambda event: strategy_on_bar(event, strategy_map, executor, handler, plotter))

    # Simulate bar emission (to be replaced with actual streamer)
    async def mock_stream(symbols, handler: EventHandler, base_price=300.0, interval_sec=1, plotter: LivePlotter = None):
        """
        Emit one synthetic bar per symbol every ``interval_sec`` seconds, forever.

        Bars come from a GBMSimulator seeded with ``base_price`` (called below
        with the per-symbol ``last_prices`` mapping). Each bar is published as a
        BAR_CREATED event and, when a plotter is given, added to the live plot.
        """
        logging.getLogger("MockStream").setLevel(logging.DEBUG)

        sim = GBMSimulator(symbols, base_price=base_price, log_prices=True)

        while True:
            bars = sim.update_all()
            for symbol, bar in bars.items():
                handler.emit("BAR_CREATED", {"symbol": symbol, "bar": bar})
                if plotter:
                    plotter.update_bar(symbol, bar)
            if plotter:
                plotter.draw()
            await asyncio.sleep(interval_sec)

    nest_asyncio.apply()
    # Forward the seed prices and plotter explicitly. Previously the plotter was
    # never passed (so nothing was drawn) and `base_price` was ignored.
    asyncio.run(mock_stream(symbols, handler, base_price=last_prices, plotter=plotter))

In [None]:
# Second prototype: live trading loop against the Alpaca paper account.
# ============================ #
#   FULL LIVE TRADING SYSTEM  #
# ============================ #

import sys
from pathlib import Path
# Make project packages (strategies, indicators, core) importable; assumes the
# notebook runs from a directory one level below the project root.
project_root = Path.cwd().parent
if not any(entry == str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

import os
import json
import asyncio
import logging
from datetime import datetime, date, timedelta
from collections import defaultdict, deque
import requests
import pandas as pd
from dotenv import load_dotenv
from pytz import timezone
from time import time

from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
from alpaca.trading.requests import MarketOrderRequest
from alpaca.data.live import StockDataStream
from alpaca.trading.stream import TradingStream

from strategies.strategy_registry.momentum_strategy import MomentumStrategy
from strategies.strategy_registry.rsi_strategy import RSIStrategy
from strategies.strategy_registry.stochastic_strategy import StochasticStrategy
from strategies.strategy_registry.macd_strategy import MACDStrategy

from indicators.atr import ATRIndicator
from core.position_sizer import DynamicPositionSizer

# --- Configuration -----------------------------------------------------------
# Credentials are read from a .env file. Set DOTENV_PATH to use a different
# file; the original machine-specific path remains the fallback.
_DEFAULT_DOTENV_PATH = r"C:\Users\kwasi\OneDrive\Documents\Personal Projects\schwab_trader\venv\.env"
load_dotenv(os.getenv("DOTENV_PATH", _DEFAULT_DOTENV_PATH))
API_KEY = os.getenv("ALPACA_API_KEY")
API_SECRET = os.getenv("ALPACA_SECRET")
DISCORD_WEBHOOK = os.getenv("DISCORD_WEBHOOK")  # optional; Discord messages are skipped when unset
# Local LLM endpoint used for trade commentary (presumably a local Ollama
# server); override with the LLM_URL environment variable.
LLM_URL = os.getenv("LLM_URL", "http://localhost:11434/api/generate")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("LiveRunner")
logger.setLevel(logging.DEBUG)

SYMBOLS = ["AAPL", "MSFT", "TSLA", "AMD", "META", "AMZN"]

# Alpaca clients: REST trading (paper account), market-data stream, trade-update stream (paper).
trading_client = TradingClient(API_KEY, API_SECRET, paper=True)
stream = StockDataStream(API_KEY, API_SECRET)
trade_stream = TradingStream(API_KEY, API_SECRET, paper=True)

# One strategy instance per symbol.
strategy_map = {
    "AAPL": MomentumStrategy(),
    "MSFT": RSIStrategy(),
    "TSLA": MACDStrategy(),
    "AMD": StochasticStrategy(),
    "META": MomentumStrategy(),
    "AMZN": RSIStrategy(),
}
sizer = DynamicPositionSizer(risk_percentage=0.07)

MIN_TRADE_QTY = 1          # smallest buy quantity on_bar will submit
# NOTE(review): FORCE_EXIT_QTY, MAX_PARTIALS and partial_exit_count are not referenced by the code in this cell.
FORCE_EXIT_QTY = 10
MAX_PARTIALS = 5
last_trade_date = None     # ISO date string of the most recent live trade
partial_exit_count = defaultdict(int)
# NOTE(review): a single buffer / running bar shared by every symbol in SYMBOLS.
daily_buffer = deque(maxlen=100)
current_day_bar = None
peak_portfolio_value = 0   # running maximum equity, used for drawdown
portfolio_history = []     # one snapshot dict appended per evaluated bar

# Cached broker state, refreshed at most once per ACCOUNT_REFRESH_INTERVAL.
ACCOUNT_REFRESH_INTERVAL = timedelta(minutes=1)
last_account_refresh_time = None
live_account_state = {
    "cash": 0.0,
    "equity": 0.0,
    "buying_power": 0.0
}
# symbol -> {"qty", "avg_entry_price"}; unknown symbols read as flat.
live_positions = defaultdict(lambda: {"qty": 0, "avg_entry_price": 0.0})

def should_refresh_account(now, last_refresh_time, interval):
    """
    Decide whether cached account state is stale.

    :param now: current time
    :param last_refresh_time: time of the last successful refresh, or None if never
    :param interval: minimum spacing between refreshes
    :return: True if no refresh has happened yet or at least ``interval`` has elapsed
    """
    if last_refresh_time is None:
        return True
    elapsed = now - last_refresh_time
    return elapsed >= interval

def send_discord_message(content, timeout=10):
    """
    Post ``content`` to the configured Discord webhook (best effort).

    Does nothing when DISCORD_WEBHOOK is unset. Network errors, timeouts and
    non-2xx responses are logged and never raised, so a notification problem
    cannot interrupt trading.

    :param content: message text; truncated to Discord's 2000-character limit
        (longer content is rejected by the webhook)
    :param timeout: seconds to wait for the webhook. Without a timeout,
        ``requests.post`` can block indefinitely.
    """
    if not DISCORD_WEBHOOK:
        return
    max_len = 2000
    if len(content) > max_len:
        content = content[:max_len - 3] + "..."
    try:
        resp = requests.post(DISCORD_WEBHOOK, json={"content": content}, timeout=timeout)
        resp.raise_for_status()
    except Exception as e:
        logger.error(f"Discord error: {e}")

def query_local_llm(prompt, model="deepseek", timeout=120):
    """
    Send ``prompt`` to the local LLM endpoint (LLM_URL) and return its reply text.

    Makes a non-streaming request and reads the "response" field of the JSON
    body. Best effort: any failure (connection error, timeout, non-2xx status,
    malformed JSON) is logged and "" is returned.

    :param prompt: prompt text
    :param model: model name sent in the request body
    :param timeout: seconds to wait for the reply. Without one, a hung LLM server
        blocks the caller (and with it the trading loop) forever.
    :return: stripped reply text, or "" on failure
    """
    try:
        resp = requests.post(
            LLM_URL,
            json={"model": model, "prompt": prompt, "stream": False},
            timeout=timeout,
        )
        resp.raise_for_status()
        return (resp.json().get("response") or "").strip()
    except Exception as e:
        logger.error(f"LLM error: {e}")
        return ""

def is_market_close(timestamp_utc, close_hour=16, close_minute=0):
    """
    Return True when ``timestamp_utc`` falls on the regular-session close minute in New York time.

    Tz-aware inputs are converted to America/New_York. Naive inputs are assumed
    to be UTC, as the parameter name says. Previously a naive ``datetime`` was
    silently interpreted in the machine's local zone, and a naive ``pd.Timestamp``
    raised.

    :param timestamp_utc: datetime or pd.Timestamp
    :param close_hour: Eastern-time hour of the close (default 16)
    :param close_minute: Eastern-time minute of the close (default 0)
    :return: True if the Eastern-time hour and minute equal the close time
    """
    ts = pd.Timestamp(timestamp_utc)
    if ts.tzinfo is None:
        ts = ts.tz_localize("UTC")
    eastern_time = ts.tz_convert("America/New_York")
    return eastern_time.hour == close_hour and eastern_time.minute == close_minute

def submit_order(symbol, qty, side):
    """
    Submit a DAY market order through the module-level ``trading_client``.

    :param symbol: ticker symbol
    :param qty: share count; truncated to int. A result below 1 is rejected
        locally instead of being sent to the API.
    :param side: OrderSide.BUY or OrderSide.SELL
    :return: the order returned by ``trading_client.submit_order``, or None if the
        quantity was invalid or building/submitting the request failed (logged)
    """
    shares = int(qty)
    if shares < 1:
        logger.warning(f"[{symbol}] Skipping {side} order: quantity {qty} truncates to {shares}.")
        return None
    try:
        # Request construction is inside the try as well, so validation errors
        # are reported like submission errors instead of escaping to the caller.
        order = MarketOrderRequest(
            symbol=symbol,
            qty=shares,
            side=side,
            time_in_force=TimeInForce.DAY
        )
        return trading_client.submit_order(order)
    except Exception as e:
        logger.error(f"Order submission failed for {side} {shares} {symbol}: {e}")
        return None

# --- Per-symbol intraday state used by on_bar ---------------------------------
# The module-level `current_day_bar` / `daily_buffer` are single objects shared
# by every subscribed symbol, so bars for different tickers were merged into one
# OHLC series. on_bar keeps one running bar and one history deque per symbol.
_DAILY_BUFFER_LEN = 100
_daily_buffers = defaultdict(lambda: deque(maxlen=_DAILY_BUFFER_LEN))
_current_day_bars = {}

_MIN_DAILY_BARS = 20       # bars required before indicators/signals are evaluated
_STOP_LOSS_ATR_MULT = 2    # stop distance in ATRs (entry sizing and stop-loss exit)
_TRADE_FEE_RATE = 0.001    # assumed proportional fee: 0.1% of price per share


def _aggregate_bar(symbol, timestamp, bar):
    """Fold one streamed bar into ``symbol``'s running daily bar; return its daily window, oldest first."""
    current = _current_day_bars.get(symbol)
    if current is not None and current["Date"].date() != timestamp.date():
        # First bar of a new day: archive the previous day's aggregate.
        _daily_buffers[symbol].append(current)
        logger.info(f"[{symbol}] Finalized bar for {current['Date'].date()}")
        current = None

    if current is None:
        current = {
            "Date": timestamp,
            "Open": bar.open,
            "High": bar.high,
            "Low": bar.low,
            "Close": bar.close,
            "Volume": bar.volume
        }
        _current_day_bars[symbol] = current
    else:
        current["High"] = max(current["High"], bar.high)
        current["Low"] = min(current["Low"], bar.low)
        current["Close"] = bar.close
        current["Volume"] += bar.volume

    return list(_daily_buffers[symbol]) + [current]


def _refresh_account_if_due(now):
    """Refresh live_account_state from the broker if ACCOUNT_REFRESH_INTERVAL has elapsed."""
    global last_account_refresh_time
    if not should_refresh_account(now, last_account_refresh_time, ACCOUNT_REFRESH_INTERVAL):
        return
    try:
        account = trading_client.get_account()
        live_account_state["cash"] = float(account.cash)
        live_account_state["equity"] = float(account.portfolio_value)
        live_account_state["buying_power"] = float(account.buying_power)
        last_account_refresh_time = now
        logger.debug("Account state refreshed.")
    except Exception as e:
        logger.warning(f"Account update failed: {e}")


def _refresh_live_positions():
    """Rebuild live_positions from the broker's open positions; return False if the request failed."""
    try:
        positions = trading_client.get_all_positions()
    except Exception as e:
        logger.warning(f"Position refresh failed: {e}")
        return False
    live_positions.clear()
    for pos in positions:
        live_positions[pos.symbol] = {
            "qty": float(pos.qty),
            "avg_entry_price": float(pos.avg_entry_price),
        }
    return True


def _notify_trade(message):
    """Blocking helper: ask the local LLM to comment on ``message`` and forward the reply to Discord."""
    prompt = f"Analyze the following trade activity and provide a short summary with risk commentary:\n{message}"
    llm_response = query_local_llm(prompt)
    if llm_response:
        send_discord_message(f"🤖 LLM Insight:\n{llm_response}")


async def on_bar(bar):
    """
    StockDataStream bar handler: aggregate, evaluate the symbol's strategy, and trade.

    Each bar is folded into that symbol's running daily bar. Once at least
    ``_MIN_DAILY_BARS`` daily bars exist, ATR and the strategy signal are
    computed. At most one live order is placed per calendar day, shared across
    all symbols:

    * signal == 1 and flat                     -> BUY; sized by ``sizer`` with a stop
      2*ATR below price, then capped by available cash
    * signal == -1 and long                    -> SELL the whole position
    * long and price <= entry price - 2*ATR    -> STOP-LOSS SELL

    Accepted orders are logged and summarised by the local LLM to Discord. An
    equity/drawdown snapshot is then appended to ``portfolio_history``.
    """
    global peak_portfolio_value, last_trade_date

    symbol = bar.symbol
    timestamp = pd.to_datetime(bar.timestamp)

    if is_market_close(timestamp):
        send_discord_message(f"📈 Market closed for {symbol}.")
        return

    window = _aggregate_bar(symbol, timestamp, bar)

    logger.info(f"[{symbol}] {timestamp} | O: {bar.open:.2f} H: {bar.high:.2f} L: {bar.low:.2f} C: {bar.close:.2f} V: {bar.volume}")
    # The old logger.info(peak_portfolio_value, ...) used a number as the format
    # string, which produced a logging error on every bar.
    logger.debug(
        f"[STATE] peak={peak_portfolio_value} last_trade_date={last_trade_date} "
        f"last_refresh={last_account_refresh_time} account={live_account_state} "
        f"positions={dict(live_positions)}"
    )

    if len(window) < _MIN_DAILY_BARS:
        return

    _refresh_account_if_due(datetime.now())

    df = ATRIndicator(pd.DataFrame(window)).compute()
    df = strategy_map[symbol].generate_signal(df)
    latest = df.iloc[-1]

    signal = latest['Signal']
    price = float(latest['Close'])
    atr_value = latest['ATR']

    if pd.isna(atr_value) or atr_value <= 0 or not price > 0:
        return

    today_str = date.today().isoformat()
    if last_trade_date == today_str:
        return  # already traded today (limit is shared across all symbols)

    # Nothing else populates live_positions, so without this refresh every
    # symbol looked flat and sells/stop-losses could never fire.
    if not _refresh_live_positions():
        return  # don't trade against an unknown position

    position_data = live_positions[symbol]
    position_qty = int(position_data["qty"])
    entry_price = position_data["avg_entry_price"]
    cash = live_account_state.get("cash", 0.0)
    portfolio_value = live_account_state.get("equity", 0.0)

    order = None
    trade_message = None
    if signal == 1 and position_qty == 0:
        stop_loss_price = price - atr_value * _STOP_LOSS_ATR_MULT
        quantity = sizer.calculate_position_size(price, stop_loss_price, cash, "unknown")
        # Cap by cash, charging the fee per share. The old code added the fee for
        # the whole order to the per-share price, over-shrinking large orders.
        affordable = int(cash // (price * (1 + _TRADE_FEE_RATE)))
        quantity = min(quantity, affordable)
        if quantity >= MIN_TRADE_QTY:
            shares = int(quantity)
            order = submit_order(symbol, shares, OrderSide.BUY)
            trade_message = f"[LIVE] BUY {shares} {symbol} @ {price:.2f}"

    elif signal == -1 and position_qty > 0:
        order = submit_order(symbol, position_qty, OrderSide.SELL)
        trade_message = f"[LIVE] SELL {position_qty} {symbol} @ {price:.2f}"

    elif (position_qty > 0 and entry_price > 0
          and price <= entry_price - atr_value * _STOP_LOSS_ATR_MULT):
        # The stop is measured from the entry price. The old stop was derived from
        # the current price (price + 2*ATR when no buy signal), so it fired on
        # every bar while a position was held.
        order = submit_order(symbol, position_qty, OrderSide.SELL)
        trade_message = f"[LIVE] STOP-LOSS SELL {position_qty} {symbol} @ {price:.2f}"

    if trade_message:
        if order is None:
            # Don't consume today's trade slot on a failed submission.
            logger.warning(f"{trade_message} was not accepted; will retry on a later bar.")
        else:
            last_trade_date = today_str
            logger.info(trade_message)
            # LLM + Discord calls are blocking HTTP; run them off the event loop.
            await asyncio.get_running_loop().run_in_executor(None, _notify_trade, trade_message)

    peak_portfolio_value = max(peak_portfolio_value, portfolio_value)
    # Equity is 0.0 until the first successful account refresh; avoid dividing by zero.
    if peak_portfolio_value > 0:
        drawdown = (portfolio_value - peak_portfolio_value) / peak_portfolio_value
    else:
        drawdown = 0.0
    portfolio_history.append({
        "Date": datetime.now(),
        "Portfolio_Value": portfolio_value,
        "Cash": cash,
        "Position": position_qty,
        "Price": price,
        "Drawdown": drawdown
    })

async def periodic_account_logger(poll_seconds=15):
    """
    Background task that keeps ``live_account_state`` fresh and logs it.

    Every ``poll_seconds`` it checks whether ACCOUNT_REFRESH_INTERVAL has elapsed
    since the last refresh. If so, it fetches the account from the broker,
    updates ``live_account_state`` / ``last_account_refresh_time``, and logs a
    snapshot. Never returns; failures are logged and retried on the next poll.

    :param poll_seconds: seconds to sleep between checks
    """
    global last_account_refresh_time
    loop = asyncio.get_running_loop()
    while True:
        now = datetime.now()
        if should_refresh_account(now, last_account_refresh_time, ACCOUNT_REFRESH_INTERVAL):
            try:
                # get_account() is a blocking HTTP call; run it in a worker thread
                # so the websocket streams on this loop keep being serviced.
                account = await loop.run_in_executor(None, trading_client.get_account)
                live_account_state["cash"] = float(account.cash)
                live_account_state["equity"] = float(account.portfolio_value)
                live_account_state["buying_power"] = float(account.buying_power)
                last_account_refresh_time = now

                # Format the parsed floats. The raw Account fields are strings
                # (hence the float() calls), so f"{account.cash:.2f}" raised and
                # every successful refresh was reported as a failure.
                logger.info(
                    f"[ACCOUNT] Cash: {live_account_state['cash']:.2f} | "
                    f"Equity: {live_account_state['equity']:.2f} | "
                    f"Buying Power: {live_account_state['buying_power']:.2f}"
                )
                logger.debug(f"[ACCOUNT STATE] {live_account_state}")
                logger.debug(f"[POSITIONS] {dict(live_positions)}")

            except Exception as e:
                logger.warning(f"Account update failed: {e}")
        await asyncio.sleep(poll_seconds)


async def main():
    """
    Start the live trading bot.

    Registers ``on_bar`` for every symbol in SYMBOLS and checks the market clock.
    When the market is closed it only warns (log + Discord) and returns.
    Otherwise it runs the market-data stream, the trading stream and the
    periodic account logger concurrently on this event loop.
    """
    for symbol in SYMBOLS:
        stream.subscribe_bars(on_bar, symbol)

    clock = trading_client.get_clock()
    if not clock.is_open:
        logger.warning("Market is currently closed. Bars may not stream.")
        send_discord_message("Market is currently closed. Bars may not stream.")
        return

    logger.info("Market is open. Starting bar stream...")
    send_discord_message(":rocket: Live Trading Bot Started")
    # StockDataStream.run() / TradingStream.run() are synchronous wrappers that
    # call asyncio.run() themselves. Writing gather(stream.run(), ...) therefore
    # ran the data stream forever while evaluating gather's arguments, and the
    # trade stream and account logger never started. Await the underlying
    # coroutines so all three tasks share this loop.
    # NOTE(review): _run_forever is alpaca-py internal API; re-check on upgrades.
    await asyncio.gather(
        stream._run_forever(),
        trade_stream._run_forever(),
        periodic_account_logger(),
    )

if __name__ == "__main__":
    import nest_asyncio

    # Jupyter already runs an event loop; nest_asyncio patches asyncio so that
    # asyncio.run() can be called from inside it.
    nest_asyncio.apply()
    entry = main()
    asyncio.run(entry)
