Skip to content

Commit

Permalink
added some __str__
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 15, 2024
1 parent 28af910 commit 657d7dc
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 58 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ tensorboard>=2.15.1
nautilus-ibapi~=10.19.2
alpaca-py~=0.18.1
stable-baselines3~=2.2.1
sb3-contrib~=2.2.1
gymnasium~=0.29.1

# Build tools
Expand Down
14 changes: 7 additions & 7 deletions roboquant/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ def __repr__(self) -> str:
o_str = ", ".join(o) or "none"

result = (
f"""buying power : {self.buying_power:_.2f}\n"""
f"""cash : {self.cash:_.2f}\n"""
f"""equity : {self.equity():_.2f}\n"""
f"""positions : {p_str}\n"""
f"""mkt value : {self.mkt_value():_.2f}\n"""
f"""open orders : {o_str}\n"""
f"""last update : {self.last_update}"""
f"buying power : {self.buying_power:_.2f}\n"
f"cash : {self.cash:_.2f}\n"
f"equity : {self.equity():_.2f}\n"
f"positions : {p_str}\n"
f"mkt value : {self.mkt_value():_.2f}\n"
f"open orders : {o_str}\n"
f"last update : {self.last_update}"
)
return result
4 changes: 4 additions & 0 deletions roboquant/brokers/simbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,7 @@ def sync(self, event: Event | None = None) -> Account:
acc.buying_power = acc.cash
acc.orders = list(self._create_orders.values())
return acc

def __str__(self) -> str:
attrs = " ".join([f"{k}={v}" for k, v in self.__dict__.items() if not k.startswith("_")])
return f"SimBroker({attrs})"
35 changes: 10 additions & 25 deletions roboquant/feeds/alpacafeed.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from array import array
from datetime import timedelta
import threading
import time
from typing import Literal

from alpaca.data.live.crypto import CryptoDataStream
from alpaca.data.live.stock import StockDataStream
from alpaca.data.live.option import OptionDataStream
from alpaca.data import DataFeed

from roboquant.feeds import AggregatorFeed
from roboquant.feeds.feedutil import get_sp500_symbols
from roboquant.config import Config
from roboquant.event import Event, Quote, Trade, Bar
from roboquant.feeds.eventchannel import EventChannel
Expand All @@ -19,20 +17,22 @@

class AlpacaLiveFeed(Feed):

def __init__(self, market: Literal["stock", "crypto", "option"] = "stock") -> None:
def __init__(self, market: Literal["iex", "sip", "crypto", "option"] = "iex") -> None:
super().__init__()
config = Config()
api_key = config.get("alpaca.public.key")
secret_key = config.get("alpaca.secret.key")
match market:
case "stock":
self.stream = StockDataStream(api_key, secret_key)
case "sip":
self.stream = StockDataStream(api_key, secret_key, feed=DataFeed.SIP)
case "iex":
self.stream = StockDataStream(api_key, secret_key, feed=DataFeed.IEX)
case "crypto":
self.stream = CryptoDataStream(api_key, secret_key)
case "option":
self.stream = OptionDataStream(api_key, secret_key)
case _:
raise ValueError(f"unsupported value market is {market}")
raise ValueError(f"unsupported value market={market}")

thread = threading.Thread(None, self.stream.run, daemon=True)
thread.start()
Expand All @@ -44,6 +44,9 @@ def play(self, channel: EventChannel):
time.sleep(1)
self._channel = None

async def close(self):
await self.stream.close()

async def __handle_trades(self, data):
if self._channel:
item = Trade(data.symbol, data.price, data.size)
Expand All @@ -70,21 +73,3 @@ def subscribe_quotes(self, *symbols: str):

def subscribe_bars(self, *symbols: str):
self.stream.subscribe_bars(self.__handle_bars, *symbols)


def run():
alpaca_feed = AlpacaLiveFeed()
# feed.subscribe_trades("BTC/USD", "ETH/USD")
stocks = get_sp500_symbols()[:30]
alpaca_feed.subscribe_quotes(*stocks)

# feed.subscribe("SPXW240312C05190000")
feed = AggregatorFeed(alpaca_feed, timedelta(seconds=15), item_type="quote")

channel = feed.play_background()
while event := channel.get(30.0):
print(event)


if __name__ == "__main__":
run()
3 changes: 3 additions & 0 deletions roboquant/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from gymnasium.envs.registration import register

register(id="roboquant/Trading-v0", entry_point="roboquant.ml.envs:TradingEnv")
38 changes: 26 additions & 12 deletions roboquant/ml/gymenv.py → roboquant/ml/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from roboquant.account import Account

from roboquant.brokers.broker import Broker
from roboquant.brokers.simbroker import SimBroker
from roboquant.event import Event
from roboquant.feeds.eventchannel import EventChannel
Expand All @@ -12,6 +13,7 @@
from roboquant.ml.features import Feature
from roboquant.ml.torch import Normalize
from roboquant.traders.flextrader import FlexTrader
from roboquant.traders.trader import Trader


logger = logging.getLogger(__name__)
Expand All @@ -23,18 +25,24 @@ class TradingEnv(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(
self, features: list[Feature], feed: Feed, rating_symbols: list[str], warmup: int = 0, broker=None, trader=None
self,
features: list[Feature],
feed: Feed,
rating_symbols: list[str],
warmup: int = 0,
broker: Broker | None = None,
trader: Trader | None = None,
):
self.broker = broker or SimBroker()
self.trader = trader or FlexTrader()
self.broker: Broker = broker or SimBroker()
self.trader: Trader = trader or FlexTrader()
self.channel = EventChannel()
self.feed = feed
self.event: Event | None = None
self.account = self.broker.sync()
self.account: Account = self.broker.sync()
self.symbols = rating_symbols
self.features = features
self.warmup = warmup
self.last_equity = self.account.equity()
self.last_equity: float = self.account.equity()
self.obs_normalizer = None
self.reward_normalizer = None

Expand All @@ -44,13 +52,9 @@ def __init__(
self.observation_space = spaces.Box(-1.0, 1.0, shape=(obs_size,), dtype=np.float32)
self.action_space = spaces.Box(-1.0, 1.0, shape=(action_size,), dtype=np.float32)

self.render_mode = None

def get_broker(self):
return SimBroker()
logger.info("observation_space=%s action_space=%s", self.observation_space, self.action_space)

def get_trader(self):
return FlexTrader()
self.render_mode = None

def calc_normalization(self, steps: int):
obs, _ = self.reset()
Expand Down Expand Up @@ -87,6 +91,7 @@ def step(self, action):
assert self.event is not None
assert self.account is not None
signals = {symbol: Signal(rating) for symbol, rating in zip(self.symbols, action)}
logger.debug("time=%s signals=%s", self.event.time, signals)

orders = self.trader.create_orders(signals, self.event, self.account)
self.broker.place_orders(orders)
Expand Down Expand Up @@ -116,9 +121,18 @@ def reset(self, *, seed=None, options=None):
self.account = self.broker.sync(self.event)
self.trader.create_orders({}, self.event, self.account)
observation = self._get_obs(self.event, self.account)
self._get_reward(self.event, self.account)
i += 1
self.last_equity = self.account.equity()
return observation, {}

def render(self):
pass

def __str__(self):
result = (
f"TradingEnv(\n\tbroker={self.broker}\n\ttrader={self.trader}\n\tfeed={self.feed}"
f"\n\tfeatures={len(self.features)}\n\twarmup={self.warmup}"
f"\n\tobservation_space={self.observation_space}\n\taction_space={self.action_space}"
"\n)"
)
return result
4 changes: 4 additions & 0 deletions roboquant/traders/flextrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,7 @@ def _get_orders(self, symbol: str, size: Decimal, item: PriceItem, rating: float
gtd = time + timedelta(days=3)
limit = item.price(self.price_type)
return [Order(symbol, size, limit, gtd)]

def __str__(self) -> str:
attrs = " ".join([f"{k}={v}" for k, v in self.__dict__.items() if not k.startswith("_")])
return f"FlexTrader({attrs})"
44 changes: 44 additions & 0 deletions tests/performance/test_alpacadelay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
import time
import unittest
from statistics import mean, stdev

from roboquant import Timeframe
from roboquant.feeds.alpacafeed import AlpacaLiveFeed


class TestAlpacaDelay(unittest.IsolatedAsyncioTestCase):

async def test_alpaca_delay(self):

logging.basicConfig(level=logging.INFO)

feed = AlpacaLiveFeed(market="iex")

# subscribe to popular IEX stocks for Quotes
feed.subscribe_quotes(
"TSLA", "MSFT", "NVDA", "AMD", "AAPL", "AMZN", "META", "GOOG", "XOM", "JPM", "NLFX", "BA", "INTC", "V"
)

timeframe = Timeframe.next(minutes=1)
channel = feed.play_background(timeframe, 1000)

delays = []
while event := channel.get(70):
if event.items:
delays.append(time.time() - event.time.timestamp())

if delays:
t = (
f"mean={mean(delays):.3f} stdev={stdev(delays):.3f} "
+ f"max={max(delays):.3f} min={min(delays):.3f} n={len(delays)}"
)
print(t)
else:
print("didn't receive any items, is it perhaps outside trading hours?")

await feed.close()


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/performance/test_tiingodelay.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_tiingo_delay(self):
channel = feed.play_background(timeframe, 10_000)

delays = []
while event := channel.get():
while event := channel.get(70):
if event.items:
delays.append(time.time() - event.time.timestamp())

Expand Down
23 changes: 23 additions & 0 deletions tests/samples/alpaca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

from datetime import timedelta
from roboquant.feeds.aggregate import AggregatorFeed
from roboquant.feeds.alpacafeed import AlpacaLiveFeed
from roboquant.feeds.feedutil import get_sp500_symbols


def run():
alpaca_feed = AlpacaLiveFeed()
# feed.subscribe_trades("BTC/USD", "ETH/USD")
stocks = get_sp500_symbols()[:30]
alpaca_feed.subscribe_quotes(*stocks)

# feed.subscribe("SPXW240312C05190000")
feed = AggregatorFeed(alpaca_feed, timedelta(seconds=15), item_type="quote")

channel = feed.play_background()
while event := channel.get(30.0):
print(event)


if __name__ == "__main__":
run()
33 changes: 20 additions & 13 deletions tests/samples/sb3.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
from gymnasium.wrappers.frame_stack import FrameStack
import logging
from stable_baselines3 import A2C

from roboquant.feeds.yahoo import YahooFeed
from roboquant.ml.features import PriceFeature, VolumeFeature, SMAFeature, PositionPNLFeature
from roboquant.ml.gymenv import TradingEnv
from roboquant.ml.envs import TradingEnv


def run():
# pylint: disable=unused-variable
yahoo = YahooFeed("IBM", "JPM", start_date="2000-01-01", end_date="2020-12-31")
logging.basicConfig(level=logging.WARNING)

symbols = ["IBM", "JPM"]
yahoo = YahooFeed(*symbols, start_date="2000-01-01", end_date="2020-12-31")

features = [
PriceFeature("IBM", "JPM").returns(),
VolumeFeature("IBM", "JPM").returns(),
SMAFeature(PriceFeature("JPM"), 10).returns(),
PositionPNLFeature("IBM", "JPM"),
PriceFeature(*symbols).returns(),
VolumeFeature(*symbols).returns(),
SMAFeature(PriceFeature(*symbols), 5).returns(),
SMAFeature(PriceFeature(*symbols), 10).returns(),
SMAFeature(PriceFeature(*symbols), 20).returns(),
PositionPNLFeature(*symbols),
]

trading = TradingEnv(features, yahoo, yahoo.symbols, warmup=20)
trading.calc_normalization(1000)
env = TradingEnv(features, yahoo, symbols, warmup=50)
env.calc_normalization(1000)
print(env)

env = FrameStack(trading, 10)
model = A2C("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=1_000_000)
model.learn(log_interval=100, total_timesteps=10_000)

# Run the trained model on out of sample data
venv = model.get_env()
assert venv is not None
trading.feed = YahooFeed("IBM", "JPM", start_date="2021-01-01")
env.feed = YahooFeed(*symbols, start_date="2021-01-01")
obs = venv.reset()
done = False

logging.getLogger("roboquant.ml.envs").setLevel(logging.DEBUG)
while not done:
action, _state = model.predict(obs, deterministic=True) # type: ignore
obs, reward, done, info = venv.step(action)

print(trading.last_equity)
print(env)


if __name__ == "__main__":
Expand Down
41 changes: 41 additions & 0 deletions tests/samples/sb3_recurrent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
from sb3_contrib import RecurrentPPO
from roboquant.feeds.yahoo import YahooFeed
from roboquant.ml.features import PriceFeature, VolumeFeature, PositionPNLFeature
from roboquant.ml.envs import TradingEnv


def run():
# pylint: disable=unused-variable

symbols = ["IBM", "JPM"]
yahoo = YahooFeed(*symbols, start_date="2000-01-01", end_date="2020-12-31")

features = [PriceFeature(*symbols).returns(), VolumeFeature(*symbols).returns(), PositionPNLFeature(*symbols)]

env = TradingEnv(features, yahoo, symbols, warmup=50)
env.calc_normalization(1000)
print(env)

model = RecurrentPPO("MlpLstmPolicy", env, verbose=1)

# Train the model
model.learn(log_interval=10, total_timesteps=10_000)

# Run the trained model on out of sample data
venv = model.get_env()
assert venv is not None
env.feed = YahooFeed(*symbols, start_date="2021-01-01")
obs = venv.reset()
done = np.zeros((1,), dtype=bool)
state = None
while not done:
action, state = model.predict(obs, state=state, episode_start=done, deterministic=True) # type: ignore
print(action)
obs, reward, done, info = venv.step(action)

print(env)


if __name__ == "__main__":
run()

0 comments on commit 657d7dc

Please sign in to comment.