In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import matplotlib
from matplotlib import cm

import hydra.utils as utils
from hydra.utils import split_dupes, write
from hydra.plotly_charts import PlotlyPriceChart, supersample_data


# Slider for parent window
# Slider for child window
# Slider for visible range


# SYMBOLS = ["circle", "square", "diamond-tall", "diamond-wide"]
# COLORS = ["red", "blue", "lime", "goldenrod"]
# Trading pair handed to the price chart.
pair = "BTCUSD"

# Colormap name; only referenced by commented-out heat map code in this notebook.
HEATMAP_COLORS = "afmhot_r"

# Feature toggles read by the charting cell.
chart_settings = {
    "trendline": False,
    "ifft_price": False,
    "extrapolated": True,
    "showHeatMap": False,
}


# Selects which run's output folder is read (presumably "<window>-<timespan>").
window_and_timespan = "90-365"
outputDir = Path("../output", "enviro-chunky-" + window_and_timespan)


In [2]:
import re

# Collect every per-root-window output file below the year=/month=/day= folders.
# "[!.xtrp]" is a glob character class (one character that is none of ".", "x",
# "t", "r", "p"), so only names whose stem ends in some other character match.
fileList = list(outputDir.rglob("year=*/**/*[!.xtrp].parq"))
if not fileList:
    # Fail early with a readable message instead of an unpacking error below.
    raise FileNotFoundError(f"No .parq output files found under {outputDir}")

# Raw string: "\d" inside a plain literal is an invalid escape sequence.
pattern = r"year=(\d+)/month=(\d+)/day=(\d+)"

dateList = []
for file in fileList:
    # as_posix() always uses "/" separators, so the pattern also matches on Windows.
    match = re.search(pattern, file.as_posix())
    if match is None:
        raise ValueError(f"Cannot read year=/month=/day= from output path {file}")
    year, month, day = match.groups()
    # The file stem is the root window number the file belongs to.
    dateList.append((pd.to_datetime(f"{year}-{month}-{day}"), int(file.stem)))
dates, rootWindows = zip(*dateList)

# Date -> root window lookup, sorted so it can be sliced by date range.
dateRootMap = pd.Series(rootWindows, index=dates).sort_index()

# Every distinct root window that has at least one output file.
rootWindowList = sorted(set(rootWindows))
print("rootWindowList", rootWindowList)

# Root window shown when the chart is first rendered.
ROOT_WINDOW = 1


def loadOutputs(rootWindow="*"):
    """Read and concatenate the parquet outputs for ``rootWindow``.

    Args:
        rootWindow: Root window whose ``<rootWindow>.parq`` files are read from
            below ``outputDir``. The default ``"*"`` reads every file whose
            stem does not end in one of the characters ``.``, ``x``, ``t``,
            ``r``, ``p``.

    Returns:
        A single DataFrame made of all matching files, in discovery order.
    """
    if rootWindow == "*":
        fname = f"{rootWindow}[!.xtrp]"
    else:
        fname = rootWindow

    # Discover all files first, then read them one by one.
    matches = list(outputDir.rglob(f"year=*/**/{fname}.parq"))
    frames = list(map(pd.read_parquet, matches))
    return pd.concat(frames)


def getOutputs(rootWindow="*", chartSubset=False):
    """Load outputs for ``rootWindow`` and add the columns the chart needs.

    The loaded frame is sorted by its index, cut down to a fixed column list,
    given ``text``/``symbol``/``color``/``visible`` columns and re-indexed by
    ``first_extrapolated_date`` (which is also kept as a column).

    Args:
        rootWindow: Passed straight to ``loadOutputs``.
        chartSubset: When true, return only the six columns used for plotting.

    Returns:
        The prepared DataFrame, or its plotting subset.
    """
    keptColumns = [
        "minPerCycle",
        "deviance",
        "ifft_extrapolated_wavelength",
        "ifft_extrapolated_amplitude",
        "ifft_extrapolated_deviance",
        "first_extrapolated",
        "first_extrapolated_date",
        "first_extrapolated_isup",
        "startDate",
        "endDate",
        "window",
        "window_original",
        "trend_deviance",
        "trend_slope",
        "trend_intercept",
        "rootNumber",
    ]
    outputs = loadOutputs(rootWindow).sort_index()[keptColumns]

    isUp = outputs["first_extrapolated_isup"]
    # Gap between the end of the analysed window and the first extrapolated point.
    leadTime = outputs["first_extrapolated_date"] - outputs["endDate"]

    # Text column, later mapped to the trace's "text" field (HTML line breaks).
    outputs["text"] = (
        outputs["endDate"].astype(str)
        + "("
        + leadTime.astype(str)
        + ")<br>isup = "
        + isUp.astype(str)
        + "<br>🌊"
        + outputs["ifft_extrapolated_wavelength"].astype(str)
        + "<br>🔊"
        + outputs["ifft_extrapolated_amplitude"].astype(str)
    )
    # Marker shape and colour encode the predicted direction.
    outputs["symbol"] = np.where(isUp, "triangle-up", "triangle-down")
    outputs["color"] = np.where(isUp, "green", "red")
    outputs["visible"] = True
    outputs = outputs.set_index("first_extrapolated_date", drop=False)

    if not chartSubset:
        return outputs

    plotColumns = [
        "first_extrapolated",
        "text",
        "endDate",
        "ifft_extrapolated_wavelength",
        "symbol",
        "color",
    ]
    return outputs[plotColumns]


# Load the default root window once; its date span drives the chart set-up.
outputs = getOutputs(ROOT_WINDOW)
print(outputs.columns)

# Earliest window start and latest window end across the loaded rows.
startDate, endDate = outputs["startDate"].min(), outputs["endDate"].max()
# The index is first_extrapolated_date, so this is the earliest extrapolated point.
extrapolationStart = outputs.index.min()


rootWindowList [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 78, 80, 81, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 134, 135, 136, 138, 140, 141, 142, 144, 145, 146, 147, 148, 150, 152, 153, 154, 155, 156, 159, 160, 161, 162, 164, 165, 166, 168, 170, 171, 172, 174, 175, 176, 177, 178, 180, 182, 183, 184, 185, 186, 188, 189, 190, 192, 195, 196, 198, 200, 201, 202, 204, 205, 206, 207, 208, 210, 212, 213, 215, 216, 217, 218, 220, 222, 224, 225, 226, 228, 230, 232, 234, 235, 236, 238, 240, 242, 243, 244, 245, 246, 248, 250, 252, 254, 255, 256, 258, 260, 261, 262, 264

In [3]:
# Peek at the columns that feed the chart for root window 1.
print(
    getOutputs(1).loc[
        :, ["first_extrapolated", "text", "endDate", "ifft_extrapolated_wavelength"]
    ]
)
# print(getOutputs(2)[["first_extrapolated", "text", "endDate"]])

                            first_extrapolated  \
first_extrapolated_date                          
2018-01-01 00:00:00.000000        16245.763578   
2018-01-03 02:35:00.000000        16787.234250   
2018-01-03 16:33:49.090909        16979.636113   
2018-01-04 09:41:00.000000        17142.169613   
2018-02-09 17:20:00.000000        20752.032988   
...                                        ...   
2018-12-31 16:02:00.000000         3053.017219   
2019-01-02 00:59:31.764705         2979.125842   
2018-12-31 04:28:00.000000         3082.042220   
2019-01-02 02:05:00.000000         2988.347700   
2019-01-02 07:58:13.846153         3000.343520   

                                                                         text  \
first_extrapolated_date                                                         
2018-01-01 00:00:00.000000  2018-01-01(0 days 00:00:00)<br>isup = True<br>...   
2018-01-03 02:35:00.000000  2018-01-01(2 days 02:35:00)<br>isup = True<br>...   
2018-01-03 16:33:49.09090

In [4]:
import traceback
import random


# startDate = "2017-05-15 00:00"
# endDate = "2020-05-16 00:00"
print(startDate, endDate)


# print(trendlines)
# With the heat map enabled it is added at loc (1, 1), so the price chart uses (2, 1).
if chart_settings["showHeatMap"]:
    priceLoc = (2, 1)
else:
    priceLoc = (1, 1)
priceChart = PlotlyPriceChart(pair, startDate, endDate, loc=priceLoc)


def add_trace(
    data,
    chartParams=None,
    interval=1,
    useIntervals=None,
    alwaysUseData=False,
    approximateIntervals=False,
    onlySlice=False,
    fields=None,
    metaHandler=None,
):
    """Add ``data`` to the module-level ``priceChart`` as a scatter trace.

    Args:
        data: Object with ``.index`` and ``.values`` (pandas Series/DataFrame).
        chartParams: Extra keyword arguments for ``go.Scatter``. Defaults to
            none.
        interval: Forwarded to ``supersample_data`` as its second argument.
        useIntervals: Intervals forwarded to ``supersample_data``. Pass the
            literal ``False`` to skip supersampling and plot ``data`` as-is.
            Defaults to an empty list.
        alwaysUseData: Forwarded to ``supersample_data``.
        approximateIntervals: Forwarded to ``supersample_data`` as
            ``approximate``.
        onlySlice: Forwarded to ``priceChart.add_trace``.
        fields: Mapping of trace attributes to data fields, forwarded to
            ``priceChart.add_trace``. Defaults to
            ``{"x": "index", "y": "values"}``.
        metaHandler: Forwarded to ``priceChart.add_trace``.

    Returns:
        None. The trace is registered on ``priceChart`` at ``priceLoc``.
    """
    # Build fresh containers per call: mutable default arguments are shared
    # between calls, and these objects are handed on to priceChart.add_trace.
    if chartParams is None:
        chartParams = {}
    if fields is None:
        fields = {"x": "index", "y": "values"}

    # Identity check: only the literal False disables supersampling. (An
    # equality test would also be ambiguous for array-like interval lists.)
    if useIntervals is False:
        priceChart.add_trace(
            go.Scatter(x=data.index, y=data.values, **chartParams),
            onlySlice=onlySlice,
            fields=fields,
            metaHandler=metaHandler,
            loc=priceLoc,
        )
        return

    if useIntervals is None:
        useIntervals = []

    supersampled = supersample_data(
        data,
        interval,
        useIntervals,
        alwaysUseData=alwaysUseData,
        approximate=approximateIntervals,
    )
    # NOTE(review): assumes the result of supersample_data is keyed by interval
    # and always contains 1440, which is used for the initial trace — confirm.
    initial = supersampled[1440]
    priceChart.add_trace(
        go.Scatter(x=initial.index, y=initial.values, **chartParams),
        data=supersampled,
        onlySlice=onlySlice,
        fields=fields,
        metaHandler=metaHandler,
        loc=priceLoc,
    )


# if chart_settings["ifft_price"]:
#     for index in range(traceCount):
#         if index in chart_data:
#             data = pd.concat(chart_data[index]["ifft_price"])
#             add_trace(data, chartParams={"mode": "lines", "name": f"ifft-{index}"})
# if chart_settings["trendline"]:
#     for index in range(traceCount):
#         if index in chart_data:
#             data = pd.concat(chart_data[index]["trendline"])
#             add_trace(
#                 data,
#                 chartParams={
#                     "mode": "markers",
#                     "name": f"trend-{index}",
#                     # "marker_symbol":SYMBOLS[index]
#                 },
#             )

# color = cm.get_cmap(HEATMAP_COLORS, len(agg_data_by_distance))
# useIntervals = [
#     srate
#     for srate in [1, 5, 15, 60, 720, 1440]
#     if pd.Timedelta(f"{srate * 250 / 30}min") < window
# ]
# Sampling intervals handed to supersample_data (presumably minutes — TODO confirm).
useIntervals = [1, 5, 15, 60, 12 * 60, 24 * 60]
print(useIntervals)


if chart_settings["extrapolated"]:

    def extrapolated_metaHandlerFactory():
        """Create the meta handler that keeps the prediction trace in sync with the sliders.

        The returned closure remembers the root-window range and wavelength
        range it last acted on, so each call can work out which slider moved
        and reload or re-filter the data only when necessary.
        """
        lastRootWindowStart = ROOT_WINDOW
        lastRootWindowEnd = ROOT_WINDOW
        lastWavelengthStart = 0
        lastWavelengthEnd = 0
        first = True

        def loadRootWindow(rootWindow):
            """Load the chart columns for a single root window and log it."""
            loaded = getOutputs(rootWindow, chartSubset=True)
            write("extrapolated_metaHandler", "Data Loaded", rootWindow)
            return loaded

        def loadRootWindowRange(rootWindowStart, rootWindowEnd):
            """Load and concatenate every known root window within [start, end]."""
            loaded = pd.concat(
                [
                    getOutputs(rw, chartSubset=True)
                    for rw in rootWindowList
                    if rw >= rootWindowStart and rw <= rootWindowEnd
                ]
            )
            write(
                "extrapolated_metaHandler",
                f"Data Range Loaded {rootWindowStart}-{rootWindowEnd}",
            )
            return loaded

        def extrapolated_metaHandler(lastData, figure, meta):
            """Return the data the prediction trace should show after a layout change.

            Args:
                lastData: The value this handler returned previously; it is
                    indexed by interval and ``lastData[1]`` is treated as the
                    finest granularity.
                figure: Figure whose ``layout.sliders`` and
                    ``layout.xaxis.range`` are read and whose sliders are
                    updated in place.
                meta: Mapping that may contain ``rootWindowStart``,
                    ``rootWindowEnd``, ``wavelengthStart`` and
                    ``wavelengthEnd``; missing keys fall back to the
                    remembered values.

            Returns:
                ``lastData`` when nothing needs reloading, otherwise the newly
                supersampled data.

            Raises:
                Exception: Any error is logged with its traceback through
                    ``write`` and then re-raised.
            """
            nonlocal first
            nonlocal lastRootWindowStart
            nonlocal lastRootWindowEnd
            nonlocal lastWavelengthStart
            nonlocal lastWavelengthEnd

            try:
                nextData = lastData

                nextRootWindowStart = meta.get("rootWindowStart", lastRootWindowStart)
                nextRootWindowEnd = meta.get("rootWindowEnd", lastRootWindowEnd)
                nextWavelengthStart = float(
                    meta.get("wavelengthStart", lastWavelengthStart)
                )
                nextWavelengthEnd = float(meta.get("wavelengthEnd", lastWavelengthEnd))

                write(f" === extrapolated_meta")
                write(
                    f"Root Window start {lastRootWindowStart}->{nextRootWindowStart} end {lastRootWindowEnd}=>{nextRootWindowEnd}"
                )
                write(
                    f"Wavelength start {lastWavelengthStart}=>{nextWavelengthStart} end {lastWavelengthEnd}=>{nextWavelengthEnd}"
                )

                if first:
                    first = False
                    write("FIRST!!!")
                    # Load the same window the closure state and the initial
                    # trace start from (this used to be a hard-coded 1).
                    nextData = getOutputs(ROOT_WINDOW, chartSubset=True)

                # Update to rootWindowStart
                elif nextRootWindowStart != lastRootWindowStart:
                    # if nextRootWindowEnd is invalid, only load the startRootWindow
                    if (
                        nextRootWindowEnd is None
                        or nextRootWindowEnd <= nextRootWindowStart
                    ):
                        nextData = loadRootWindow(nextRootWindowStart)
                    # rootWindowEnd is valid so load the range
                    else:
                        nextData = loadRootWindowRange(
                            nextRootWindowStart, nextRootWindowEnd
                        )
                    lastRootWindowStart = nextRootWindowStart

                # Update to rootWindowEnd
                elif nextRootWindowEnd != lastRootWindowEnd:
                    if nextRootWindowStart is None:
                        return lastData

                    # if next end is before or equal to start
                    if nextRootWindowEnd <= nextRootWindowStart:
                        # and previous end was before start: reuse existing data
                        if lastRootWindowEnd <= nextRootWindowStart:
                            return lastData

                        # and previous end was after start:
                        # moved end from after start to before start
                        elif lastRootWindowEnd > nextRootWindowStart:
                            nextData = loadRootWindow(nextRootWindowStart)
                    # else next end is after start; load range
                    else:
                        nextData = loadRootWindowRange(
                            nextRootWindowStart, nextRootWindowEnd
                        )

                    lastRootWindowEnd = nextRootWindowEnd

                # NOTE(review): the remembered wavelengths are only updated when
                # the data was NOT reloaded above, so a reload combined with a
                # wavelength change looks like a new change on the next call —
                # confirm whether that is intended.
                if (
                    lastWavelengthStart != nextWavelengthStart
                    or lastWavelengthEnd != nextWavelengthEnd
                ) and not isinstance(nextData, pd.DataFrame):
                    # reset to smallest granularity
                    nextData = nextData[1]
                    lastWavelengthStart = nextWavelengthStart
                    lastWavelengthEnd = nextWavelengthEnd

                # Process Data
                # nextData is a dataframe when it has changed
                if isinstance(nextData, pd.DataFrame):
                    wavelengthColumn = nextData["ifft_extrapolated_wavelength"]
                    # filter by wavelength: a non-increasing range selects the
                    # single start wavelength, otherwise the inclusive range
                    if nextWavelengthEnd <= nextWavelengthStart:
                        nextData["visible"] = np.isclose(
                            wavelengthColumn, nextWavelengthStart
                        )
                    else:
                        nextData["visible"] = wavelengthColumn.between(
                            nextWavelengthStart, nextWavelengthEnd
                        ).to_numpy()

                    nextData = supersample_data(
                        nextData,
                        1,
                        useIntervals,
                        alwaysUseData=True,
                    )
                    write(
                        "extrapolated_metaHandler",
                        "Data Processed",
                        lastData[1].equals(nextData[1]),
                    )

                # write("LAYOUT", figure.layout)
                sliders = list(figure.layout.sliders)

                # if sliders exist, update them
                if len(sliders) and figure.layout.xaxis.range is not None:
                    [start, end] = figure.layout.xaxis.range
                    # chart panned/zoomed
                    if nextData[1].equals(lastData[1]):
                        # Invalidate certain root windows if not within range
                        try:
                            subset = dateRootMap.loc[start:end]

                            validRoots = sorted(set(subset.values))
                            # write(len(sliders[1].steps))

                            # hide invalid root windows on both root-window
                            # sliders (index 1 = start, index 2 = end)
                            for rootSliderIdx in (1, 2):
                                for step in sliders[rootSliderIdx].steps:
                                    step.visible = (
                                        step.label is not None
                                        and int(step.label) in validRoots
                                    )
                        except KeyError:
                            write(f"Key Error {start=} {end=}")
                            write(dateRootMap.index)
                            raise

                    # TODO: if old frequencies == new frequencies then don't update?
                    # may provide minor performance boost

                    # Sometimes there are no wavelengths for the selected window
                    # What should the slider look like?
                    try:
                        wavelengths = sorted(
                            set(nextData[1].loc[start:end, "ifft_extrapolated_wavelength"])
                        )

                        # Rebuild both wavelength sliders (index 3 = start,
                        # index 4 = end) from the wavelengths now in view.
                        for sliderIdx, nextWavelength, metaName in [
                            (3, nextWavelengthStart, "wavelengthStart"),
                            (4, nextWavelengthEnd, "wavelengthEnd"),
                        ]:
                            # Pad to the old step count so the slider never shrinks.
                            newSteps = generateMetaSteps(
                                metaName,
                                wavelengths,
                                minimumSteps=len(sliders[sliderIdx].steps),
                            )
                            write(f"~~~ {metaName} {len(wavelengths)=} {len(newSteps)=}")

                            lastIdx = sliders[sliderIdx].active
                            closestIdx = (
                                utils.getClosestIndex_np(wavelengths, nextWavelength)
                                if len(wavelengths)
                                else 0
                            )

                            sliders[sliderIdx] = go.layout.Slider(
                                active=closestIdx,
                                currentvalue=sliders[sliderIdx].currentvalue,
                                pad=sliders[sliderIdx].pad,
                                steps=newSteps,
                            )
                            if len(wavelengths):
                                write(
                                    f"Slider {sliderIdx} step {lastIdx}=>{closestIdx} waves {nextWavelength}=>{wavelengths[closestIdx]}"
                                )

                        figure.update_layout(sliders=sliders)
                    except KeyError:
                        write(f"Key Error 2 {start=} {end=} {1 in nextData}")
                        write(nextData[1].index)
                        raise

                return nextData
            except Exception as err:
                # Log the full traceback, then let the error propagate unchanged.
                msg = "".join(
                    traceback.format_exception(type(err), err, err.__traceback__)
                )
                write("error", msg)
                raise

        return extrapolated_metaHandler

    # print(outputs_mod["symbol"], outputs_mod[["symbol"]])
    # Initial prediction trace for the default root window; the handler above
    # swaps its data whenever a slider changes the figure's meta.
    chartOutputs = getOutputs(ROOT_WINDOW, chartSubset=True)
    add_trace(
        chartOutputs,
        chartParams={
            "mode": "markers",
            "name": f"prediction",
            "marker_size": 10,
            "marker_symbol": chartOutputs["symbol"],
            "marker_color": chartOutputs["color"],
            "text": "text",
        },
        useIntervals=useIntervals,
        alwaysUseData=True,
        # approximateIntervals=True,
        onlySlice=True,
        fields={
            "x": "index",
            "y": "first_extrapolated",
            "text": "text",
            "marker_symbol": "symbol",
            "marker_color": "color",
        },
        metaHandler=extrapolated_metaHandlerFactory(),
    )

# Optional heat map, added at loc (1, 1); priceLoc is (2, 1) when this is enabled.
# NOTE(review): `complexity` is not defined in any cell visible here, so turning
# showHeatMap on presumably raises NameError on a fresh kernel — confirm where
# it is meant to come from.
if chart_settings["showHeatMap"]:
    # Heat map rows come from the frame's index, columns from its column labels.
    y = complexity.index.tolist()
    # y = [str(float(val).round(4)) for val in y]
    priceChart.add_trace(
        go.Heatmap(
            z=complexity.values.tolist(),
            x=complexity.columns.tolist(),
            y=y,
            showscale=False,
        ),
        loc=(1, 1),
        traceArgs={},
    )
    # priceChart.figure.add_vline(x=startDate, line_width=3)


def generateMetaSteps(name, iterator, stepArgs={}, asObject=False, minimumSteps=0):
    """Build slider steps that write ``{name: value}`` into the layout's meta.

    Args:
        name: Meta key each step sets.
        iterator: Values to offer; duplicates are dropped and the rest sorted.
        stepArgs: Extra keyword arguments applied to every real step.
        asObject: Build ``go.layout.slider.Step`` objects instead of dicts.
        minimumSteps: Pad the result with hidden steps up to this length.

    Returns:
        A list with one visible "relayout" step per distinct value, followed
        by any hidden padding steps.
    """
    if asObject:
        makeStep = go.layout.slider.Step
    else:
        makeStep = dict

    result = []
    for value in sorted(set(iterator)):
        result.append(
            makeStep(
                visible=True,
                method="relayout",
                # layout attribute updated when the step is selected
                args=[{"meta": {name: value}}],
                label=value,
                **stepArgs,
            )
        )

    # Top up with invisible placeholder steps when fewer values than requested.
    missing = minimumSteps - len(result)
    if missing > 0:
        result = result + [makeStep(execute=True, visible=False) for _ in range(missing)]
    return result


# frequency_slider_steps = []
# for i in range(len(priceChart.figure.data)):
#     step = dict(
#         method="update",
#         args=[{"visible": [False] * len(priceChart.figure.data)},
#               {"title": "Slider switched to step: " + str(i)}],  # layout attribute
#     )
#     step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
#     time_slider_steps.append(step)


# Extra top padding for the first slider when the heat map is shown.
if chart_settings["showHeatMap"]:
    sliderPad = 150
else:
    sliderPad = 0

# Slider order matters: the meta handler addresses the root-window sliders as
# indices 1 and 2 and the wavelength sliders as indices 3 and 4.
sliders = [
    {
        "active": 50,
        "pad": {"t": sliderPad},
        "currentvalue": {"prefix": "%Time: "},
        "steps": generateMetaSteps("timeSlider", range(100)),
    },
    {
        "pad": {"t": 150},
        "currentvalue": {"prefix": "Root Window Start: "},
        "steps": generateMetaSteps("rootWindowStart", rootWindowList),
    },
    {
        "pad": {"t": 225},
        "currentvalue": {"prefix": "Root Window End: ", "xanchor": "right"},
        "steps": generateMetaSteps("rootWindowEnd", rootWindowList),
    },
    {
        "pad": {"t": 300},
        "currentvalue": {"prefix": "Wavelength Start: "},
        "steps": generateMetaSteps(
            "wavelengthStart", outputs["ifft_extrapolated_wavelength"]
        ),
    },
    {
        "pad": {"t": 375},
        "currentvalue": {"prefix": "Wavelength End: ", "xanchor": "right"},
        "steps": generateMetaSteps(
            "wavelengthEnd", outputs["ifft_extrapolated_wavelength"]
        ),
    },
    # dict(
    #     active=0,
    #     currentvalue={"prefix": "Root Window Start:"},
    #     pad={"t": 150},
    #     steps=rootWindowSliderStepsStart,
    # ),
    # dict(
    #     active=0,
    #     currentvalue={"prefix": "Root Window End:", "xanchor": "right"},
    #     pad={"t": 200},
    #     steps=rootWindowSliderStepsEnd,
    # ),
]
# ,dict(
#     active=10,
#     currentvalue={"prefix": "Frequency: "},
#     pad={"t": 250},
#     steps=frequency_slider_steps
# )


# Open the chart zoomed to one day either side of the first extrapolated point.
windowSize = pd.to_timedelta("1d")
priceChart.render(
    sliders=sliders,
    zoomStart=extrapolationStart - windowSize,
    zoomEnd=extrapolationStart + windowSize,
)


2017-10-03 00:01:00 2018-12-31 00:00:00
registering handler 0 False
[1, 5, 15, 60, 720, 1440]
registering handler 1 True


FigureWidget({
    'data': [{'close': array([ 4309. ,  4223.6,  4324.2,  4363.1,  4432. ,  4592. ,  4772.1,  4…