In [None]:
from collections import deque

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display

from deep_orderbook.test_learn import train_and_predict


# Define your asynchronous function to update the figure
def _build_figure_widget():
    """Create, display and return the five-row live-plot widget used by ``update_figure``.

    Rows, top to bottom: book heatmap, ``t2l`` heatmap, bid/ask price lines,
    prediction heatmap, loss curve. ``update_figure`` addresses traces by
    position, so the order of the ``add_trace`` calls below is part of the
    contract:
    data[0] book, data[1] t2l, data[2] bid, data[3] ask,
    data[4] prediction, data[5] loss.
    """
    fig = make_subplots(
        rows=5, cols=1,
        vertical_spacing=0.05,
        row_heights=[0.2, 0.2, 0.2, 0.2, 0.2],  # five equal-height rows
    )
    fig.update_layout(
        height=1200,  # total height, shared by all five rows
        width=1200,
        margin=dict(t=50, b=50, l=50, r=50),
        showlegend=True,
    )

    fig_widget = go.FigureWidget(fig)
    display(fig_widget)

    def _heatmap(colorscale, zmin, zmax):
        # The 10x10 zero grid is only a placeholder; real data arrives on the first update.
        return go.Heatmap(
            z=np.zeros((10, 10)),
            colorscale=colorscale,
            zmin=zmin, zmax=zmax,
            showscale=False,
        )

    def _line(color):
        # Empty line trace; x/y are filled in by the update loop.
        return go.Scatter(
            x=[], y=[],
            mode='lines',
            line=dict(color=color),
            showlegend=False,
        )

    # NOTE(review): the t2l and prediction grids are clipped to [-1, 1] in
    # _prepare_frame, but these heatmaps use zmin=0, so negative values all
    # render as the lowest colour. Confirm this is intended.
    fig_widget.add_trace(_heatmap('RdBu', -1, 1), row=1, col=1)  # data[0]: book
    fig_widget.add_trace(_heatmap('Turbo', 0, 1), row=2, col=1)  # data[1]: t2l
    fig_widget.add_trace(_line('green'), row=3, col=1)           # data[2]: bid prices
    fig_widget.add_trace(_line('red'), row=3, col=1)             # data[3]: ask prices
    fig_widget.add_trace(_heatmap('Turbo', 0, 1), row=4, col=1)  # data[4]: prediction
    fig_widget.add_trace(_line('green'), row=5, col=1)           # data[5]: loss
    return fig_widget


def _prepare_frame(shaped, t2l, pxar, prediction):
    """Convert one sample yielded by ``train_and_predict`` into plot-ready arrays.

    Returns:
        ``(book_z, t2l_z, times, bid_prices, ask_prices, pred_z)``.
        - The three ``*_z`` values are 2-D heatmap grids clipped to [-1, 1].
        - ``bid_prices`` / ``ask_prices`` are columns 0 / 1 of ``pxar``.
        - ``times`` is ``0 .. len(pxar) - 1``, one x position per row of ``pxar``.
    """
    # Copy first so the caller's array is not mutated by the in-place scaling.
    # transpose(1, 0, 2) requires ``shaped`` to be 3-D; it swaps the first two
    # axes. The meaning of each axis is not visible here — TODO confirm.
    book = shaped.copy().transpose(1, 0, 2)
    book[:, :, 0] *= -0.5   # last-axis index 0: negate and halve
    book[:, :, 1:3] *= 1e6  # last-axis indices 1-2: scale up before clipping
    # Clip every slice to [-1, 1], then average over the last axis into one grid.
    book_z = np.clip(book, -1, 1).mean(axis=2)

    # First slice of the last axis, with the first two axes swapped (same as above).
    t2l_z = np.clip(t2l[:, :, 0].T, -1, 1)

    bid_prices = pxar[:, 0]
    ask_prices = pxar[:, 1]
    times = np.arange(len(bid_prices))

    # Reshape the prediction to t2l's shape so it is displayed exactly like t2l.
    pred_z = np.clip(prediction.reshape(t2l.shape)[:, :, 0].T, -1, 1)

    return book_z, t2l_z, times, bid_prices, ask_prices, pred_z


async def update_figure(max_samples=100, epoch=1, loss_window=512):
    """Stream samples from ``train_and_predict`` and redraw a live five-row plot for each one.

    Intended for a Jupyter/IPython session: the plot is a plotly
    ``FigureWidget`` shown via ``IPython.display.display``.

    Args:
        max_samples: forwarded to ``train_and_predict``.
        epoch: forwarded to ``train_and_predict``.
        loss_window: how many of the most recent losses the loss curve keeps.
            Must be a non-negative int; defaults to the previous hard-coded 512.
    """
    fig_widget = _build_figure_widget()

    # Bounded buffer: old losses drop off the left automatically once full.
    # Assumes each yielded ``loss`` is a scalar plotly can serialise — TODO confirm.
    losses = deque(maxlen=loss_window)

    async for shaped, t2l, pxar, prediction, loss in train_and_predict(max_samples=max_samples, epoch=epoch):
        book_z, t2l_z, times, bid_prices, ask_prices, pred_z = _prepare_frame(
            shaped, t2l, pxar, prediction
        )
        losses.append(loss)

        # batch_update sends all trace changes to the front end in one message.
        with fig_widget.batch_update():
            fig_widget.data[0].z = book_z
            fig_widget.data[1].z = t2l_z

            fig_widget.data[2].x = times
            fig_widget.data[2].y = bid_prices
            fig_widget.data[3].x = times
            fig_widget.data[3].y = ask_prices

            fig_widget.data[4].z = pred_z

            fig_widget.data[5].x = np.arange(len(losses))
            fig_widget.data[5].y = list(losses)

# Run the asynchronous function
await update_figure(max_samples=100000, epoch=10)