Skip to content

Commit

Permalink
better multi strategy support
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Apr 6, 2024
1 parent 5558280 commit 24b17bf
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ __pycache__
runs/
.idea/
*.db
scratch*.py
scratch/
dist/
roboquant.egg-info/
build/
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ testpaths = [

[tool.pyright]
reportOptionalOperand = "none"
exclude = ["samples/*.py"]
exclude = ["samples/*.py", "scratch/*.py"]

[tool.pylint.MASTER]
ignore-paths = 'samples'
Expand All @@ -27,7 +27,7 @@ requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
exclude = ["docs*", "tests*", "samples*"]
exclude = ["docs*", "tests*", "samples*", "scratch*"]

[tool.setuptools.package-data]
"*" = ["*.json"]
Expand Down
1 change: 1 addition & 0 deletions roboquant/feeds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .randomwalk import RandomWalk
from .sqllitefeed import SQLFeed
from .tiingo import TiingoLiveFeed, TiingoHistoricFeed
from .feedutil import get_sp500_symbols

try:
from .alpacafeed import AlpacaLiveFeed
Expand Down
3 changes: 3 additions & 0 deletions roboquant/feeds/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def play(self, channel: EventChannel):
"""
...

def timeframe(self) -> Timeframe | None:
return None

def play_background(self, timeframe: Timeframe | None = None, channel_capacity: int = 10) -> EventChannel:
"""
Plays this feed in the background on its own thread.
Expand Down
8 changes: 0 additions & 8 deletions roboquant/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +0,0 @@
try:
from gymnasium.envs.registration import register

register(id="roboquant/StrategyEnv-v0", entry_point="roboquant.ml.envs:StrategyEnv")
register(id="roboquant/TraderEnv-v0", entry_point="roboquant.ml.envs:TraderEnv")

except ImportError:
pass
3 changes: 3 additions & 0 deletions roboquant/ml/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.registration import register
import numpy as np
from numpy.typing import NDArray
from roboquant.account import Account
Expand All @@ -19,6 +20,8 @@
from roboquant.traders.trader import Trader


register(id="roboquant/StrategyEnv-v0", entry_point="roboquant.ml.envs:StrategyEnv")
register(id="roboquant/TraderEnv-v0", entry_point="roboquant.ml.envs:TraderEnv")
logger = logging.getLogger(__name__)


Expand Down
16 changes: 8 additions & 8 deletions roboquant/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@


def run(
feed: Feed,
strategy: Strategy | None = None,
trader: Trader | None = None,
broker: Broker | None = None,
journal: Journal | None = None,
timeframe: Timeframe | None = None,
capacity: int = 10,
heartbeat_timeout: float | None = None
feed: Feed,
strategy: Strategy | None = None,
trader: Trader | None = None,
broker: Broker | None = None,
journal: Journal | None = None,
timeframe: Timeframe | None = None,
capacity: int = 10,
heartbeat_timeout: float | None = None,
) -> Account:
"""Start a new run.
Expand Down
33 changes: 33 additions & 0 deletions roboquant/strategies/emacrossover.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,36 @@ def add_price(self, price: float):
self.price2 = m2 * self.price2 + (1.0 - m2) * price
self.step += 1
return self.step


class _Calculator2:

__slots__ = "entries", "step"

def __init__(self, *momentums, price):
self.entries = [[m, price] for m in momentums]
self.step = 0

def is_above(self):
prev = None
for _, p in self.entries:
if prev is not None and p <= prev:
return False
prev = p
return True

def is_below(self):
prev = None
for _, p in self.entries:
if prev is not None and p >= prev:
return False
prev = p
return True

def add_price(self, price: float):
for entry in self.entries:
m, p = entry
entry[1] = m * p + (1 - m) * price

self.step += 1
return self.step
33 changes: 22 additions & 11 deletions roboquant/strategies/multistrategy.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,44 @@
from typing import Literal
from itertools import groupby
from statistics import mean

from roboquant.event import Event
from roboquant.signal import Signal
from roboquant.strategies.strategy import Strategy


class MultiStrategy(Strategy):
"""Combine one or more strategies. The MultiStrategy provides additional control on how to handle conflicting
signals for the same symbols:
signals for the same symbols via the signal_filter:
- first: in case of multiple signals for a symbol, the first strategy wins
- last: in case of multiple signals for a symbol, the last strategy wins. This is also the default policy
- first: in case of multiple signals for the same symbol, the first one wins
- last: in case of multiple signals for the same symbol, the last one wins.
- avg: return the avgerage of the signals. All signals will be ENTRY and EXIT.
- none: return all signals. This is also the default.
"""

def __init__(self, *strategies: Strategy, policy: Literal["last", "first", "all"] = "last"):
def __init__(self, *strategies: Strategy, signal_filter: Literal["last", "first", "avg", "none"] = "none"):
self.strategies = list(strategies)
self.policy = policy
self.signal_filter = signal_filter

def create_signals(self, event: Event):
signals = []
signals: list[Signal] = []
for strategy in self.strategies:
tmp = strategy.create_signals(event)
signals += tmp
signals += strategy.create_signals(event)

match self.policy:
match self.signal_filter:
case "none":
return signals
case "last":
s = {s.symbol: s for s in signals}
return list(s.values())
case "first":
s = {s.symbol: s for s in reversed(signals)}
return list(s.values())
case "all":
return signals
case "avg":
result = []
g = groupby(signals, lambda x: x.symbol)
for symbol, v in g:
rating = mean(s.rating for s in v)
result.append(Signal(symbol, rating))
return result
5 changes: 1 addition & 4 deletions samples/alpaca_feed.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# %%
from datetime import timedelta
from roboquant.feeds.aggregate import AggregatorFeed
from roboquant.feeds.alpacafeed import AlpacaLiveFeed
from roboquant.feeds.feedutil import get_sp500_symbols

from roboquant.feeds import AggregatorFeed, AlpacaLiveFeed, get_sp500_symbols

# %%
alpaca_feed = AlpacaLiveFeed()
Expand Down
11 changes: 6 additions & 5 deletions samples/tensorboard_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
from roboquant.journals import TensorboardJournal, PNLMetric, RunMetric, FeedMetric, PriceItemMetric, AlphaBeta

# %%
# Compare 3 runs with different parameters using tensorboard
feed = rq.feeds.YahooFeed("JPM", "IBM", "F", start_date="2000-01-01")
# Compare runs with different parameters using tensorboard
feed = rq.feeds.YahooFeed("JPM", "IBM", "F", "MSFT", "V", "GE","CSCO", "WMT", "XOM", "INTC", start_date="2010-01-01")

params = [(3, 5), (13, 26), (12, 50)]
hyper_params = [(3, 5), (13, 26), (12, 50)]

for p1, p2 in params:
for p1, p2 in hyper_params:
s = rq.strategies.EMACrossover(p1, p2)
log_dir = f"""runs/ema_{p1}_{p2}"""
writer = Writer(log_dir)
journal = TensorboardJournal(writer, PNLMetric(), RunMetric(), FeedMetric(), PriceItemMetric("JPM"), AlphaBeta(200))
rq.run(feed, s, journal=journal)
account = rq.run(feed, s, journal=journal)
print(p1, p2, account.equity())
writer.close()

0 comments on commit 24b17bf

Please sign in to comment.