In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
!pip install matplotlib
!pip install --upgrade ipympl
!pip install mplcursors
!pip install plotly pandas nbformat --upgrade
!pip install jsonpickle





[notice] A new release of pip is available: 24.3 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.3 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.3 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.3 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.3 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
from typing import List, Dict, Tuple, Any

# Project modules: datamodel.py and template2.py are expected to sit next to
# this notebook or somewhere on the Python path.
try:
    from datamodel import (
        ConversionObservation,
        Listing,
        Observation,
        Order,
        OrderDepth,
        Position,
        Product,
        ProsperityEncoder,
        Symbol,
        Trade,
        TradingState,
    )
    # Status is imported alongside Trader so its methods are reachable if needed
    from template2 import Status, Trader
except ImportError as import_error:
    # Report the problem and carry on; cells that use these names will fail
    # later if the import did not succeed.
    print(f"Error importing modules: {import_error}")
    print("Make sure datamodel.py and template2.py are accessible.")
    # If the files live elsewhere, extend the search path and retry, e.g.:
    #   import sys
    #   sys.path.append('path/to/your/files')
    #   from datamodel import ...
    #   from template2 import ...

# Optional: seaborn-flavoured matplotlib style for any static plots
plt.style.use('seaborn-v0_8-darkgrid')

In [2]:
# Cell 2: Load Data
# Reads the semicolon-separated price file and buckets its rows by timestamp.
# On any load failure both `df` and `grouped_by_timestamp` are left as None so
# that later cells can detect it and skip their work.
try:
    df = pd.read_csv("data/round1/prices_round_1_day_0.csv", sep=";")
    print(f"Data loaded successfully. Shape: {df.shape}")
except FileNotFoundError:
    df = None  # sentinel checked below and by later cells
    print("Error: prices_round_1_day_0.csv not found.")
    print("Make sure the path 'data/round1/prices_round_1_day_0.csv' is correct relative to the notebook.")
except Exception as load_error:
    df = None
    print(f"An error occurred loading the CSV: {load_error}")

# One group per distinct value of the 'timestamp' column (only if the load worked)
grouped_by_timestamp = None if df is None else df.groupby('timestamp')
if grouped_by_timestamp is not None:
    print(f"Data grouped into {len(grouped_by_timestamp)} timestamps.")

Data loaded successfully. Shape: (30000, 17)
Data grouped into 10000 timestamps.


In [3]:
# Cell 3: Backtesting Loop - Collecting Trade Signals, Mid-Prices, AND EMA
#
# Feeds each timestamp's order-book snapshot to Trader.run() and records, per
# product, every order the trader returns plus the indicator values exposed on
# the trader's per-product status objects.
#
# NOTE(review): there is no matching engine here. own_trades and market_trades
# are always passed as empty lists, and the full quantity of every returned
# order is added to the position, so positions are an approximation.

if grouped_by_timestamp is None:
    print("Skipping simulation as data was not loaded.")
else:
    trader = Trader()  # the strategy being replayed

    # product -> one dict per order returned by the trader
    # (keys: "timestamp", "type", "price", "quantity")
    trade_signals: Dict[Product, List[Dict[str, Any]]] = {}

    # State carried from one timestamp to the next
    current_position: Dict[Product, Position] = {}
    traderData = ""  # string passed into, and returned from, trader.run() each step

    print("Starting backtest simulation and data collection (including EMA)...")

    # Product name -> status object held by the Trader for that product.
    # Adjust these names if they differ in your Trader class!
    product_to_status_obj_map = {
        "SQUID_INK": trader.state_SQUID_INK,
        "RAINFOREST_RESIN": trader.state_RAINFOREST_RESIN,
        "KELP": trader.state_KELP,
        "CROISSANTS": trader.state_CROISSANTS,
        "JAMS": trader.state_JAMS,
        "DJEMBES": trader.state_DJEMBES,
        "PICNIC_BASKET1": trader.state_PICNIC_BASKET1,
        "PICNIC_BASKET2": trader.state_PICNIC_BASKET2,
        # Add mappings for any other products your trader handles
    }

    # product -> parallel lists; one value is appended to each list for every
    # timestamp in which the trader's returned orders dict has a key for that product
    context_fields = ("timestamps", "mid_prices", "top", "bottom",
                      "ema_mid_35", "ema_mid_100", "RSI", "rsi_ema")
    market_context: Dict[Product, Dict[str, List]] = {}

    processed_timestamps = 0
    for timestamp, group in grouped_by_timestamp:
        # --- Build the TradingState for this timestamp ---
        listings: Dict[Symbol, Listing] = {}
        order_depths: Dict[Symbol, OrderDepth] = {}
        market_trades: Dict[Symbol, List[Trade]] = {}
        own_trades: Dict[Symbol, List[Trade]] = {}
        products_in_timestamp = group['product'].unique()

        for product in products_in_timestamp:
            # First row for this product within the timestamp
            row = group[group['product'] == product].iloc[0]
            symbol = product  # the product name doubles as the symbol
            listings[symbol] = Listing(symbol=symbol, product=product, denomination="SEASHELLS")

            # Up to three book levels per side, read from the CSV columns
            buy_orders: Dict[int, int] = {}
            sell_orders: Dict[int, int] = {}
            for level in (1, 2, 3):
                bid_price = row.get(f'bid_price_{level}')
                bid_vol = row.get(f'bid_volume_{level}')
                ask_price = row.get(f'ask_price_{level}')
                ask_vol = row.get(f'ask_volume_{level}')
                # Levels with a missing price/volume or a non-positive volume are skipped
                if pd.notna(bid_price) and pd.notna(bid_vol) and bid_vol > 0:
                    buy_orders[int(bid_price)] = int(bid_vol)
                if pd.notna(ask_price) and pd.notna(ask_vol) and ask_vol > 0:
                    sell_orders[int(ask_price)] = -int(ask_vol)  # ask volume stored negated

            depth = OrderDepth()
            depth.buy_orders = buy_orders
            depth.sell_orders = sell_orders
            order_depths[symbol] = depth

            market_trades[symbol] = []
            own_trades[symbol] = []

            # First sighting of a product: start flat and create its collectors
            current_position.setdefault(product, 0)
            trade_signals.setdefault(product, [])
            market_context.setdefault(product, {field: [] for field in context_fields})

        observations = Observation(plainValueObservations={}, conversionObservations={})
        state = TradingState(
            traderData=traderData,
            timestamp=int(timestamp),
            listings=listings,
            order_depths=order_depths,
            own_trades=own_trades,
            market_trades=market_trades,
            position=current_position.copy(),  # the trader receives a copy
            observations=observations,
        )

        # --- Run the trader; the status objects are read right after this call ---
        orders_dict, conversions, next_traderData = trader.run(state)

        # --- Record indicator values and emitted orders ---
        for symbol, order_list in orders_dict.items():
            product = state.listings[symbol].product
            context = market_context[product]
            status = product_to_status_obj_map[product]

            context["timestamps"].append(status.time[-1])
            context["mid_prices"].append(status.mid)
            context["top"].append(status.best_ask)
            context["bottom"].append(status.best_bid)
            context["ema_mid_35"].append(status.ema_mid_35[-1])
            context["ema_mid_100"].append(status.ema_mid_100[-1])
            context["RSI"].append(status.rsi[-1])
            context["rsi_ema"].append(status.rsi_ema[-1])

            for order in order_list:
                # Positive quantity is recorded as BUY, anything else as SELL
                trade_signals[product].append({
                    "timestamp": int(timestamp),
                    "type": "BUY" if order.quantity > 0 else "SELL",
                    "price": order.price,
                    "quantity": order.quantity,
                })

        # --- Carry state into the next iteration ---
        traderData = next_traderData
        for symbol, orders in orders_dict.items():
            product = listings[symbol].product
            for order in orders:
                # Every order's full quantity is added to the position
                current_position[product] = current_position.get(product, 0) + order.quantity

        processed_timestamps += 1
        if not processed_timestamps % 1000:  # progress report every 1000 timestamps
            print(f"Processed {processed_timestamps}/{len(grouped_by_timestamp)} timestamps...")

    print("Backtest simulation and data collection finished.")

Starting backtest simulation and data collection (including EMA)...
Processed 1000/10000 timestamps...
(SQUID_INK, 1961, 2)
(SQUID_INK, 1962, 31)
(SQUID_INK, 1961, -27)
(SQUID_INK, 1960, 27)
(SQUID_INK, 1958, 17)
(SQUID_INK, 1957, -1)
(SQUID_INK, 1955, -27)
(SQUID_INK, 1954, -28)
(SQUID_INK, 1955, -22)
(SQUID_INK, 1957, -1)
(SQUID_INK, 1955, -21)
Processed 2000/10000 timestamps...
(SQUID_INK, 1949, 50)
(SQUID_INK, 1945, 3)
(SQUID_INK, 1946, 2)
(SQUID_INK, 1947, 27)
(SQUID_INK, 1947, 2)
(SQUID_INK, 1948, 16)
(SQUID_INK, 1931, -23)
(SQUID_INK, 1933, -29)
(SQUID_INK, 1935, -27)
(SQUID_INK, 1934, -21)
Processed 3000/10000 timestamps...
(SQUID_INK, 1912, 50)
(SQUID_INK, 1914, 23)
(SQUID_INK, 1913, 1)
(SQUID_INK, 1914, 4)
(SQUID_INK, 1916, 22)
(SQUID_INK, 1924, -23)
(SQUID_INK, 1924, -29)
(SQUID_INK, 1923, -25)
(SQUID_INK, 1924, -1)
(SQUID_INK, 1922, -22)
(SQUID_INK, 1933, 6)
(SQUID_INK, 1934, 23)
(SQUID_INK, 1933, 28)
(SQUID_INK, 1930, 8)
(SQUID_INK, 1931, 27)
(SQUID_INK, 1934, 8)
(SQUID_IN

In [4]:
# Cell 4: Plotting Trade Signals (Interactive using Plotly - WITH EMA)
#
# Run this after Cell 3 has populated `trade_signals` and `market_context`.
# For each product (in sorted order) two kinds of figure are produced:
#   1. mid / top / bottom lines and both EMAs on the price axis, RSI on a
#      secondary right-hand axis, plus BUY/SELL order markers;
#   2. the "rsi_ema" series on its own (all price figures are shown first).

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd # Required for easier data handling
import numpy as np # Make sure numpy is imported


def _add_context_line(fig, product, timestamps, values, *, label, name, line, yaxis=None):
    """Add one non-hoverable line trace to ``fig`` when the data is usable.

    The trace is added only if ``values`` has the same length as
    ``timestamps`` and contains at least one non-null entry. Otherwise, when
    ``timestamps`` is non-empty, a "missing or length mismatch" message naming
    ``label`` and ``product`` is printed.

    Args:
        fig: Plotly figure to add the trace to.
        product: Product name, used only in the printed message.
        timestamps: x values.
        values: y values.
        label: Series description used in the printed message.
        name: Legend entry for the trace.
        line: Dict of Plotly line style options.
        yaxis: Optional Plotly axis id (e.g. ``"y2"``); ``None`` keeps the
            default y axis.

    Returns:
        True if a trace was added, False otherwise.
    """
    if len(timestamps) == len(values) and any(pd.notna(values)):
        fig.add_trace(go.Scatter(
            x=timestamps,
            y=values,
            mode='lines',
            name=name,
            line=line,
            yaxis=yaxis,
            hoverinfo='skip',  # context lines do not show hover labels
        ))
        return True
    if len(timestamps) > 0:
        print(f"  - {label} data missing or length mismatch for {product}.")
    return False


def _add_signal_markers(fig, signals, side, marker_symbol, marker_color):
    """Add markers for the orders in ``signals`` whose ``'type'`` equals ``side``.

    Args:
        fig: Plotly figure to add the trace to.
        signals: List of dicts with "timestamp", "type", "price", "quantity".
        side: "BUY" or "SELL".
        marker_symbol: Plotly marker symbol name.
        marker_color: Marker fill colour.

    Nothing is added when no signal matches ``side``.
    """
    selected = [s for s in signals if s['type'] == side]
    if not selected:
        return
    frame = pd.DataFrame(selected)
    fig.add_trace(go.Scatter(
        x=frame['timestamp'], y=frame['price'], mode='markers', name=f'{side.title()} Signal',
        marker=dict(symbol=marker_symbol, color=marker_color, size=10, line=dict(width=1, color='black')),
        customdata=frame[['quantity', 'timestamp']],
        hovertemplate='<b>' + side + '</b><br>Price: %{y}<br>Qty: %{customdata[0]}<br>Time: %{customdata[1]}<extra></extra>'
    ))


if 'trade_signals' not in globals() or 'market_context' not in globals():
    print("Error: Data dictionaries ('trade_signals', 'market_context') not found.")
    print("Please ensure Cell 3 has been run successfully before running this cell.")
elif not trade_signals and not market_context:
    print("No data collected for visualization in Cell 3.")
else:
    print("Generating interactive plots using Plotly (with EMA)...")

    # Lines of the price figure, in legend order:
    # (market_context key, legend name, label for the "missing" message, line style, y axis)
    # RSI goes on a secondary axis ("y2") because it is an oscillator rather
    # than a price (presumably 0-100 - confirm against the Status class);
    # sharing the price axis would stretch the range and flatten the price lines.
    price_figure_lines = (
        ("mid_prices", "{product} Mid-Price", "Mid-price", dict(color='grey', width=1), None),
        ("top", "{product} top", "top", dict(color='red', width=1), None),
        ("bottom", "{product} bottom", "bottom", dict(color='green', width=1), None),
        ("RSI", "RSI", "RSI", dict(color='black', width=1), "y2"),
        ("ema_mid_35", "EMA (35)", "EMA (35)", dict(color='orange', width=1.5, dash='dot'), None),
        ("ema_mid_100", "EMA (100)", "EMA (100)", dict(color='blue', width=1.5, dash='dot'), None),
    )

    # Union of both dictionaries so every product is attempted; sorted so the
    # figure order is the same on every run.
    all_products = sorted(set(trade_signals.keys()) | set(market_context.keys()))

    # --- Figure 1 per product: prices, EMAs, RSI and order markers ---
    for product in all_products:
        context = market_context.get(product, {})
        timestamps_ctx = context.get("timestamps", [])
        signals = trade_signals.get(product, [])

        if not timestamps_ctx and not signals:
            print(f"  - Skipping plot for {product}: No timestamp or signal data.")
            continue

        print(f"Plotting for {product}...")
        fig = go.Figure()

        uses_secondary_axis = False
        for key, name_template, label, line, yaxis in price_figure_lines:
            added = _add_context_line(
                fig, product, timestamps_ctx, context.get(key, []),
                label=label, name=name_template.format(product=product),
                line=line, yaxis=yaxis,
            )
            if added and yaxis is not None:
                uses_secondary_axis = True

        if signals:
            _add_signal_markers(fig, signals, 'BUY', 'triangle-up', 'lime')
            _add_signal_markers(fig, signals, 'SELL', 'triangle-down', 'red')
        else:
            print(f"  - No trade signal data collected for {product}.")

        fig.update_layout(
            title=f"Interactive Trade Signals for {product} (with EMA)",
            xaxis_title="Timestamp",
            yaxis_title="Price",
            yaxis_tickformat=',.0f',
            hovermode="closest",
            legend_title_text="Signals & Indicators"
        )
        if uses_secondary_axis:
            # Right-hand axis overlaid on the price axis; the legend is nudged
            # right so it does not sit on top of the secondary tick labels.
            fig.update_layout(
                yaxis2=dict(title=dict(text="RSI"), overlaying="y", side="right", showgrid=False),
                legend=dict(x=1.08),
            )

        fig.show()

    # --- Figure 2 per product: the rsi_ema series on its own ---
    for product in all_products:
        context = market_context.get(product, {})
        timestamps_ctx = context.get("timestamps", [])

        fig = go.Figure()
        plotted = _add_context_line(
            fig, product, timestamps_ctx, context.get("rsi_ema", []),
            label="RSI EMA", name='RSI_ema', line=dict(color='black', width=1),
        )
        if not plotted:
            continue  # nothing to draw, so no empty figure is shown

        fig.update_layout(
            title=f"RSI EMA for {product}",
            xaxis_title="Timestamp",
            yaxis_title="RSI EMA",
            hovermode="closest",
            legend_title_text="Indicators"
        )

        fig.show()

    print("Plotly plot generation finished.")

Generating interactive plots using Plotly (with EMA)...
  - Skipping plot for RAINFOREST_RESIN: No timestamp or signal data.
  - Skipping plot for KELP: No timestamp or signal data.
Plotting for SQUID_INK...


Plotly plot generation finished.
