In [None]:
!pip install matplotlib
!pip install ipympl
!pip install mplcursors

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
from typing import List, Dict, Tuple, Any

# datamodel.py and template2.py must be importable: same directory as the notebook,
# or on sys.path (e.g. `import sys; sys.path.append('path/to/your/files')` before this cell).
try:
    from datamodel import (
        TradingState, OrderDepth, Listing, Trade, Observation,
        ConversionObservation, Order, ProsperityEncoder, Symbol, Product, Position
    )
    from template2 import Trader, Status  # Status is imported so its methods can be inspected if needed
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Make sure datamodel.py and template2.py are accessible.")
    # Re-raise: every later cell needs these names. Continuing would only surface
    # later as a confusing NameError (e.g. `Trader`) in the backtest cell.
    raise

# Configure plot style (optional). matplotlib >= 3.6 names the style
# 'seaborn-v0_8-darkgrid'; older versions call it 'seaborn-darkgrid'.
# If neither exists, the default style is kept.
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

In [None]:
# Cell 2: Load Data
# Single source of truth for the input file (relative to the notebook's working
# directory). Change it here to backtest another round/day; the messages follow.
PRICES_CSV = "data/round1/prices_round_1_day_0.csv"

try:
    df = pd.read_csv(PRICES_CSV, delimiter=";")
    print(f"Data loaded successfully. Shape: {df.shape}")
except FileNotFoundError:
    print(f"Error: {PRICES_CSV.rsplit('/', 1)[-1]} not found.")
    print(f"Make sure the path '{PRICES_CSV}' is correct relative to the notebook.")
    df = None  # Downstream cells check for None and skip their work
except Exception as e:
    print(f"An error occurred loading the CSV: {e}")
    df = None

# One group per timestamp; the backtest cell replays them in timestamp order.
if df is not None:
    grouped_by_timestamp = df.groupby('timestamp')
    print(f"Data grouped into {len(grouped_by_timestamp)} timestamps.")
else:
    grouped_by_timestamp = None

In [None]:
# Cell 3: Backtesting Loop - Collecting Trade Signals and Mid-Prices


def _build_order_depth(row: pd.Series) -> OrderDepth:
    """Build an OrderDepth from one CSV row's three bid/ask levels.

    Bids are stored as {price: +volume}, asks as {price: -volume}. Levels with a
    missing price/volume or a non-positive volume are skipped.
    """
    buy_orders: Dict[int, int] = {}
    sell_orders: Dict[int, int] = {}
    for level in range(1, 4):
        bid_price, bid_vol = row.get(f'bid_price_{level}'), row.get(f'bid_volume_{level}')
        ask_price, ask_vol = row.get(f'ask_price_{level}'), row.get(f'ask_volume_{level}')
        if pd.notna(bid_price) and pd.notna(bid_vol) and bid_vol > 0:
            buy_orders[int(bid_price)] = int(bid_vol)
        if pd.notna(ask_price) and pd.notna(ask_vol) and ask_vol > 0:
            sell_orders[int(ask_price)] = -int(ask_vol)
    depth = OrderDepth()
    depth.buy_orders = buy_orders
    depth.sell_orders = sell_orders
    return depth


def _mid_price(buy_orders: Dict[int, int], sell_orders: Dict[int, int]) -> float:
    """Midpoint of best bid/ask; falls back to the one side present, NaN if the book is empty."""
    best_bid = max(buy_orders) if buy_orders else None
    best_ask = min(sell_orders) if sell_orders else None
    if best_bid is not None and best_ask is not None:
        return (best_bid + best_ask) / 2.0
    if best_bid is not None:
        return float(best_bid)
    if best_ask is not None:
        return float(best_ask)
    return np.nan


def _fill_against_book(order, bid_volume: Dict[int, int], ask_volume: Dict[int, int]) -> int:
    """Return the signed quantity of `order` that crosses the visible book.

    A buy lifts asks priced <= order.price, and a sell hits bids priced >= order.price,
    best price first. Both volume dicts hold POSITIVE volumes and are decremented in
    place, so later orders in the same tick cannot reuse the same liquidity.
    Assumption (simplified model): resting/passive orders never fill.
    """
    if order.quantity > 0:
        book = ask_volume
        prices = sorted(p for p in book if p <= order.price)
    elif order.quantity < 0:
        book = bid_volume
        prices = sorted((p for p in book if p >= order.price), reverse=True)
    else:
        return 0
    remaining = abs(order.quantity)
    filled = 0
    for price in prices:
        if remaining == 0:
            break
        take = min(remaining, book[price])
        book[price] -= take
        remaining -= take
        filled += take
    return filled if order.quantity > 0 else -filled


# Defined unconditionally so the plotting cell can run (and report "no data")
# even when the CSV failed to load.
# trade_signals: { "PRODUCT_NAME": [ {"timestamp": ts, "type": "BUY/SELL", "price": p, "quantity": q}, ... ], ... }
trade_signals: Dict[Product, List[Dict[str, Any]]] = {}
# market_context: { "PRODUCT_NAME": {"timestamps": [ts1, ...], "mid_prices": [p1, ...]}, ... }
market_context: Dict[Product, Dict[str, List]] = {}

if grouped_by_timestamp is not None:
    trader = Trader()  # Initialize your Trader

    # Simulation state carried between timestamps
    current_position: Dict[Product, Position] = {}
    traderData = ""
    n_timestamps = len(grouped_by_timestamp)

    print("Starting backtest simulation and data collection...")

    for processed_timestamps, (timestamp, group) in enumerate(grouped_by_timestamp, start=1):
        ts = int(timestamp)
        listings: Dict[Symbol, Listing] = {}
        order_depths: Dict[Symbol, OrderDepth] = {}
        market_trades: Dict[Symbol, List[Trade]] = {}  # not replayed from data: always empty
        own_trades: Dict[Symbol, List[Trade]] = {}     # not simulated: always empty

        # --- Build TradingState: one book per product (first row wins on duplicates) ---
        for _, row in group.drop_duplicates(subset='product').iterrows():
            product = row['product']
            symbol = product
            listings[symbol] = Listing(symbol=symbol, product=product, denomination="SEASHELLS")
            depth = _build_order_depth(row)
            order_depths[symbol] = depth

            current_position.setdefault(product, 0)
            market_trades[symbol] = []
            own_trades[symbol] = []

            trade_signals.setdefault(product, [])
            context = market_context.setdefault(product, {"timestamps": [], "mid_prices": []})
            context["timestamps"].append(ts)
            context["mid_prices"].append(_mid_price(depth.buy_orders, depth.sell_orders))

        observations = Observation(plainValueObservations={}, conversionObservations={})
        state = TradingState(
            traderData=traderData, timestamp=ts, listings=listings,
            order_depths=order_depths, own_trades=own_trades, market_trades=market_trades,
            position=current_position.copy(), observations=observations
        )

        # --- Run Trader Logic ---
        orders_dict, conversions, next_traderData = trader.run(state)

        # --- Record signals and apply simulated fills ---
        for symbol, order_list in orders_dict.items():
            product = listings[symbol].product
            # Working copies (positive volumes) consumed by this tick's fills.
            bid_volume = dict(order_depths[symbol].buy_orders)
            ask_volume = {p: -v for p, v in order_depths[symbol].sell_orders.items()}
            for order in order_list:
                if order.quantity == 0:
                    continue  # zero-size orders are neither buys nor sells
                trade_signals[product].append({
                    "timestamp": ts,
                    "type": "BUY" if order.quantity > 0 else "SELL",
                    "price": order.price,
                    "quantity": order.quantity,
                })
                current_position[product] += _fill_against_book(order, bid_volume, ask_volume)

        traderData = next_traderData

        if processed_timestamps % 1000 == 0:  # Print progress
            print(f"Processed {processed_timestamps}/{n_timestamps} timestamps...")

    print("Backtest simulation and data collection finished.")

else:
    print("Skipping simulation as data was not loaded.")

In [None]:
# Cell 4: Plotting Trade Signals with Interactivity (Zoom + Hover)

# --- IMPORTANT: Enable Interactive Backend ---
# Put this magic command at the VERY START of the cell
%matplotlib widget
# ------------------------------------------

import matplotlib.pyplot as plt
import mplcursors # Import the library for hover effects
import numpy as np

# Make sure this cell is run after Cell 3 where 'trade_signals'
# and 'market_context' are populated.

if not trade_signals and not market_context:
    print("No data collected for visualization.")
else:
    print("Generating interactive plots...")
    print("NOTE: Use the toolbar icons above the plot for zooming and panning.")

    all_products = set(trade_signals.keys()) | set(market_context.keys())

    # Store references to cursor objects if needed later (usually not necessary)
    # active_cursors = {}

    for product in all_products:
        print(f"Plotting for {product}...")

        # Create a NEW figure for each plot when using %matplotlib widget
        # to avoid plots interfering with each other.
        fig, ax = plt.subplots(figsize=(12, 6)) # Create figure and axes

        # --- Plot Mid-Price Context ---
        if product in market_context and market_context[product]["timestamps"]:
            timestamps_ctx = market_context[product]["timestamps"]
            mid_prices_ctx = market_context[product]["mid_prices"]
            # Plotting the line
            line, = ax.plot(timestamps_ctx, mid_prices_ctx, label=f'{product} Mid-Price', color='grey', alpha=0.6, linestyle='-')
            # Add basic hover for the line (optional)
            mplcursors.cursor(line, hover=True) # Shows x, y by default
        else:
            print(f"  - No market context (mid-price) data for {product}.")

        # --- Plot Buy/Sell Signals ---
        if product in trade_signals and trade_signals[product]:
            signals = trade_signals[product]
            # Separate signals for easier processing by mplcursors
            buy_signals = [s for s in signals if s['type'] == 'BUY']
            sell_signals = [s for s in signals if s['type'] == 'SELL']

            buy_scatter, sell_scatter = None, None # Initialize scatter objects

            # Plot Buy Signals
            if buy_signals:
                buy_times = [s['timestamp'] for s in buy_signals]
                buy_prices = [s['price'] for s in buy_signals]
                buy_scatter = ax.scatter(buy_times, buy_prices, label='Buy Signal', marker='^', color='lime', s=100, edgecolors='black', alpha=0.9, zorder=5) # zorder puts points on top

            # Plot Sell Signals
            if sell_signals:
                sell_times = [s['timestamp'] for s in sell_signals]
                sell_prices = [s['price'] for s in sell_signals]
                sell_scatter = ax.scatter(sell_times, sell_prices, label='Sell Signal', marker='v', color='red', s=100, edgecolors='black', alpha=0.9, zorder=5)

            # --- Add Custom Hover Tooltips using mplcursors ---

            # Define a function to format the hover text
            def create_annotation_text(signal_data):
                # signal_data will be one dictionary from buy_signals or sell_signals
                 return (f"{signal_data['type']} @ {signal_data['price']}\n"
                         f"Qty: {signal_data['quantity']}\n"
                         f"Time: {signal_data['timestamp']}")

            # Apply tooltips to BUY points
            if buy_scatter:
                # Use a lambda function to link the selection index to the correct signal data
                cursor_buy = mplcursors.cursor(buy_scatter, hover=True)
                @cursor_buy.connect("add")
                def on_add_buy(sel):
                     signal_info = buy_signals[sel.index] # Get the corresponding dict
                     sel.annotation.set_text(create_annotation_text(signal_info))
                     sel.annotation.get_bbox_patch().set(alpha=0.8) # Make tooltip slightly transparent


            # Apply tooltips to SELL points
            if sell_scatter:
                cursor_sell = mplcursors.cursor(sell_scatter, hover=True)
                @cursor_sell.connect("add")
                def on_add_sell(sel):
                     signal_info = sell_signals[sel.index] # Get the corresponding dict
                     sel.annotation.set_text(create_annotation_text(signal_info))
                     sel.annotation.get_bbox_patch().set(alpha=0.8) # Make tooltip slightly transparent


            if not buy_signals and not sell_signals:
                 print(f"  - No Buy/Sell signals generated by the trader for {product}.")
        else:
             print(f"  - No trade signal data collected for {product}.")


        ax.set_xlabel("Timestamp")
        ax.set_ylabel("Price")
        ax.set_title(f"Interactive Trade Signals for {product}")
        ax.legend()
        ax.ticklabel_format(useOffset=False, style='plain', axis='y') # Prevent scientific notation on price axis
        fig.tight_layout() # Use fig.tight_layout() when using subplots

        # No plt.show() needed with %matplotlib widget, the figure appears automatically
        # plt.show() # Usually not required with the widget backend

    print("Plot generation finished.")
    # Clean up references if needed (usually not critical unless managing many plots)
    # mplcursors.cursor().disconnect_all()