-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
184 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from gymnasium.envs.registration import register | ||
|
||
register(id="roboquant/Trading-v0", entry_point="roboquant.ml.envs:TradingEnv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import logging | ||
import time | ||
import unittest | ||
from statistics import mean, stdev | ||
|
||
from roboquant import Timeframe | ||
from roboquant.feeds.alpacafeed import AlpacaLiveFeed | ||
|
||
|
||
class TestAlpacaDelay(unittest.IsolatedAsyncioTestCase): | ||
|
||
async def test_alpaca_delay(self): | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
feed = AlpacaLiveFeed(market="iex") | ||
|
||
# subscribe to popular IEX stocks for Quotes | ||
feed.subscribe_quotes( | ||
"TSLA", "MSFT", "NVDA", "AMD", "AAPL", "AMZN", "META", "GOOG", "XOM", "JPM", "NLFX", "BA", "INTC", "V" | ||
) | ||
|
||
timeframe = Timeframe.next(minutes=1) | ||
channel = feed.play_background(timeframe, 1000) | ||
|
||
delays = [] | ||
while event := channel.get(70): | ||
if event.items: | ||
delays.append(time.time() - event.time.timestamp()) | ||
|
||
if delays: | ||
t = ( | ||
f"mean={mean(delays):.3f} stdev={stdev(delays):.3f} " | ||
+ f"max={max(delays):.3f} min={min(delays):.3f} n={len(delays)}" | ||
) | ||
print(t) | ||
else: | ||
print("didn't receive any items, is it perhaps outside trading hours?") | ||
|
||
await feed.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
|
||
from datetime import timedelta | ||
from roboquant.feeds.aggregate import AggregatorFeed | ||
from roboquant.feeds.alpacafeed import AlpacaLiveFeed | ||
from roboquant.feeds.feedutil import get_sp500_symbols | ||
|
||
|
||
def run(): | ||
alpaca_feed = AlpacaLiveFeed() | ||
# feed.subscribe_trades("BTC/USD", "ETH/USD") | ||
stocks = get_sp500_symbols()[:30] | ||
alpaca_feed.subscribe_quotes(*stocks) | ||
|
||
# feed.subscribe("SPXW240312C05190000") | ||
feed = AggregatorFeed(alpaca_feed, timedelta(seconds=15), item_type="quote") | ||
|
||
channel = feed.play_background() | ||
while event := channel.get(30.0): | ||
print(event) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
from sb3_contrib import RecurrentPPO | ||
from roboquant.feeds.yahoo import YahooFeed | ||
from roboquant.ml.features import PriceFeature, VolumeFeature, PositionPNLFeature | ||
from roboquant.ml.envs import TradingEnv | ||
|
||
|
||
def run(): | ||
# pylint: disable=unused-variable | ||
|
||
symbols = ["IBM", "JPM"] | ||
yahoo = YahooFeed(*symbols, start_date="2000-01-01", end_date="2020-12-31") | ||
|
||
features = [PriceFeature(*symbols).returns(), VolumeFeature(*symbols).returns(), PositionPNLFeature(*symbols)] | ||
|
||
env = TradingEnv(features, yahoo, symbols, warmup=50) | ||
env.calc_normalization(1000) | ||
print(env) | ||
|
||
model = RecurrentPPO("MlpLstmPolicy", env, verbose=1) | ||
|
||
# Train the model | ||
model.learn(log_interval=10, total_timesteps=10_000) | ||
|
||
# Run the trained model on out of sample data | ||
venv = model.get_env() | ||
assert venv is not None | ||
env.feed = YahooFeed(*symbols, start_date="2021-01-01") | ||
obs = venv.reset() | ||
done = np.zeros((1,), dtype=bool) | ||
state = None | ||
while not done: | ||
action, state = model.predict(obs, state=state, episode_start=done, deterministic=True) # type: ignore | ||
print(action) | ||
obs, reward, done, info = venv.step(action) | ||
|
||
print(env) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |