Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #750: [Sim] Refactor plot code in sim_engine #755

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
2 changes: 1 addition & 1 deletion pdr_backend/aimodel/test/test_aimodel_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enforce_typing import enforce_types
from pytest import approx

from pdr_backend.aimodel.plot_model import plot_model
from pdr_backend.aimodel.model_plotter import plot_model
from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory
from pdr_backend.aimodel.aimodel_factory import AimodelFactory
from pdr_backend.ppss.aimodel_ss import AimodelSS, aimodel_ss_test_dict
Expand Down
4 changes: 2 additions & 2 deletions pdr_backend/cli/test/test_cli_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_do_sim(monkeypatch):
mock_f = Mock()
monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f)

with patch("pdr_backend.sim.sim_engine.plt.show"):
with patch("pdr_backend.sim.sim_plotter.plt.show"):
do_sim(MockArgParser_PPSS_NETWORK().parse_args())

mock_f.assert_called()
Expand All @@ -385,7 +385,7 @@ def test_do_main(monkeypatch, capfd):
mock_f = Mock()
monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f)

with patch("pdr_backend.sim.sim_engine.plt.show"):
with patch("pdr_backend.sim.sim_plotter.plt.show"):
with patch("sys.argv", ["pdr", "sim", "ppss.yaml"]):
_do_main()

Expand Down
206 changes: 6 additions & 200 deletions pdr_backend/sim/sim_engine.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,22 @@
import copy
import logging
import os
from typing import Dict, List

from enforce_typing import enforce_types
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import random
import polars as pl
from statsmodels.stats.proportion import proportion_confint

from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory
from pdr_backend.aimodel.aimodel_factory import AimodelFactory
from pdr_backend.aimodel.plot_model import plot_model
from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory
from pdr_backend.ppss.ppss import PPSS
from pdr_backend.util.currency_types import Eth
from pdr_backend.sim.sim_state import SimState
from pdr_backend.sim.sim_plotter import SimPlotter
from pdr_backend.util.mathutil import classif_acc
from pdr_backend.util.time_types import UnixTimeMs

logger = logging.getLogger("sim_engine")
FONTSIZE = 9


# pylint: disable=too-many-instance-attributes
class SimEngineState:
def __init__(self, init_holdings: Dict[str, Eth]):
self.holdings: Dict[str, float] = {
tok: float(amt.amt_eth) for tok, amt in init_holdings.items()
}
self.init_loop_attributes()

def init_loop_attributes(self):
self.accs_train: List[float] = []
self.ybools_test: List[float] = []
self.ybools_testhat: List[float] = []
self.probs_up: List[float] = []
self.corrects: List[bool] = []
self.trader_profits_USD: List[float] = []
self.pdr_profits_OCEAN: List[float] = []


# pylint: disable=too-many-instance-attributes
Expand All @@ -57,15 +34,15 @@ def __init__(self, ppss: PPSS):

self.ppss = ppss

self.st = SimEngineState(
self.st = SimState(
copy.copy(self.ppss.trader_ss.init_holdings),
)

self.plot_state = None
self.sim_plotter = None
if self.ppss.sim_ss.do_plot:
n = self.ppss.predictoor_ss.aimodel_ss.n # num input vars
include_contour = n == 2
self.plot_state = PlotState(include_contour)
self.sim_plotter = SimPlotter(self.ppss, self.st, include_contour)

self.logfile = ""

Expand Down Expand Up @@ -244,9 +221,7 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):

# plot
if self.do_plot(test_i, self.ppss.sim_ss.test_n):
self.plot_state.make_plot( # type: ignore[union-attr]
self.st,
self.ppss,
self.sim_plotter.make_plot( # type: ignore[union-attr]
model,
X_train,
ybool_train,
Expand Down Expand Up @@ -340,172 +315,3 @@ def do_plot(self, i: int, N: int):
return False

return True


@enforce_types
class PlotState:
def __init__(self, include_contour: bool):
self.include_contour = include_contour

fig = plt.figure()
self.fig = fig

if include_contour:
gs = gridspec.GridSpec(2, 4, width_ratios=[5, 1, 1, 5])
else:
gs = gridspec.GridSpec(2, 3, width_ratios=[5, 1, 1])

self.ax00 = fig.add_subplot(gs[0, 0])
self.ax01 = fig.add_subplot(gs[0, 1:3])
self.ax10 = fig.add_subplot(gs[1, 0])
self.ax11 = fig.add_subplot(gs[1, 1])
self.ax12 = fig.add_subplot(gs[1, 2])
if include_contour:
self.ax03 = fig.add_subplot(gs[:, 3])

self.x: List[float] = []
self.y01_est: List[float] = []
self.y01_l: List[float] = []
self.y01_u: List[float] = []
self.plotted_before: bool = False
plt.ion()
plt.show()

# pylint: disable=too-many-statements
def make_plot(self, st, ppss, model, X_train, ybool_train, colnames):
stake_amt = ppss.predictoor_ss.stake_amount.amt_eth

fig = self.fig
ax00, ax01 = self.ax00, self.ax01
ax10, ax11, ax12 = self.ax10, self.ax11, self.ax12

N = len(st.pdr_profits_OCEAN)
N_done = len(self.x) # what # points have been plotted previously

# set x
self.x = list(range(0, N))
next_x = _slice(self.x, N_done, N)
next_hx = [next_x[0], next_x[-1]] # horizontal x

# plot row 0, col 0: predictoor profit vs time
y00 = list(np.cumsum(st.pdr_profits_OCEAN))
next_y00 = _slice(y00, N_done, N)
ax00.plot(next_x, next_y00, c="g")
ax00.plot(next_hx, [0, 0], c="0.2", ls="--", lw=1)
s = f"Predictoor profit vs time. Current:{y00[-1]:.2f} OCEAN"
_set_title(ax00, s)
if not self.plotted_before:
ax00.set_ylabel("predictoor profit (OCEAN)", fontsize=FONTSIZE)
ax00.set_xlabel("time", fontsize=FONTSIZE)
_ylabel_on_right(ax00)
ax00.margins(0.005, 0.05)

# plot row 0, col 1: % correct vs time
for i in range(N_done, N):
n_correct = sum(st.corrects[: i + 1])
n_trials = len(st.corrects[: i + 1])
l, u = proportion_confint(count=n_correct, nobs=n_trials)
self.y01_est.append(n_correct / n_trials * 100)
self.y01_l.append(l * 100)
self.y01_u.append(u * 100)
next_y01_est = _slice(self.y01_est, N_done, N)
next_y01_l = _slice(self.y01_l, N_done, N)
next_y01_u = _slice(self.y01_u, N_done, N)

ax01.plot(next_x, next_y01_est, "green")
ax01.fill_between(next_x, next_y01_l, next_y01_u, color="0.9")
ax01.plot(next_hx, [50, 50], c="0.2", ls="--", lw=1)
ax01.set_ylim(bottom=40, top=60)
now_s = f"{self.y01_est[-1]:.2f}% "
now_s += f"[{self.y01_l[-1]:.2f}%, {self.y01_u[-1]:.2f}%]"
_set_title(ax01, f"% correct vs time. Current: {now_s}")
if not self.plotted_before:
ax01.set_xlabel("time", fontsize=FONTSIZE)
ax01.set_ylabel("% correct", fontsize=FONTSIZE)
_ylabel_on_right(ax01)
ax01.margins(0.01, 0.01)

# plot row 0, col 2: model contour
if self.include_contour:
ax03 = self.ax03
labels = tuple([_shift_one_earlier(colname) for colname in colnames])
plot_model(model, X_train, ybool_train, labels, (fig, ax03))
if not self.plotted_before:
ax03.margins(0.01, 0.01)

# plot row 1, col 0: trader profit vs time
y10 = list(np.cumsum(st.trader_profits_USD))
next_y10 = _slice(y10, N_done, N)
ax10.plot(next_x, next_y10, c="b")
ax10.plot(next_hx, [0, 0], c="0.2", ls="--", lw=1)
_set_title(ax10, f"Trader profit vs time. Current: ${y10[-1]:.2f}")
if not self.plotted_before:
ax10.set_xlabel("time", fontsize=FONTSIZE)
ax10.set_ylabel("trader profit (USD)", fontsize=FONTSIZE)
_ylabel_on_right(ax10)
ax10.margins(0.005, 0.05)

# reusable profits scatterplot
def _scatter_profits(ax, actor: str, denomin, mnp, mxp, st_profits):
next_probs_up = _slice(st.probs_up, N_done, N)
next_profits = _slice(st_profits, N_done, N)
c = (random(), random(), random()) # random RGB color
ax.scatter(next_probs_up, next_profits, color=c, s=1)
avg = np.average(st_profits)
s = f"{actor} profit distr'n. avg={avg:.2f} {denomin}"
_set_title(ax, s)
ax.plot([0.5, 0.5], [mnp, mxp], c="0.2", ls="-", lw=1)
if not self.plotted_before:
ax.plot([0.0, 1.0], [0, 0], c="0.2", ls="--", lw=1)
_set_xlabel(ax, "prob(up)")
_set_ylabel(ax, f"{actor} profit ({denomin})")
_ylabel_on_right(ax)
ax.margins(0.05, 0.05)

# plot row 1, col 1: 1d scatter of predictoor profits
mnp, mxp = -stake_amt, +stake_amt
_scatter_profits(ax11, "pdr", "OCEAN", mnp, mxp, st.pdr_profits_OCEAN)

# plot row 1, col 2: 1d scatter of trader profits
mnp, mxp = min(st.trader_profits_USD), max(st.trader_profits_USD)
_scatter_profits(ax12, "trader", "USD", mnp, mxp, st.trader_profits_USD)

# final pieces
HEIGHT = 7.5 # magic number
WIDTH = int(HEIGHT * 3.2) # magic number
fig.set_size_inches(WIDTH, HEIGHT)
fig.tight_layout(pad=0.5, h_pad=1.0, w_pad=1.0)
plt.pause(0.001)
self.plotted_before = True


def _shift_one_earlier(s: str):
"""eg 'binance:BTC/USDT:close:t-3' -> 'binance:BTC/USDT:close:t-2'"""
val = int(s[-1])
return s[:-1] + str(val - 1)


def _set_xlabel(ax, s: str):
ax.set_xlabel(s, fontsize=FONTSIZE)


def _set_ylabel(ax, s: str):
ax.set_ylabel(s, fontsize=FONTSIZE)


def _set_title(ax, s: str):
ax.set_title(s, fontsize=FONTSIZE, fontweight="bold")


def _slice(a: list, N_done: int, N: int) -> list:
return [a[i] for i in range(max(0, N_done - 1), N)]


def _ylabel_on_right(ax):
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")


def _del_lines(ax):
for l in ax.lines:
l.remove()
Loading
Loading