Skip to content

Commit

Permalink
Better CSV tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 27, 2024
1 parent c8dcc06 commit 9cc2b7b
Show file tree
Hide file tree
Showing 13 changed files with 17,706 additions and 19,355 deletions.
20 changes: 2 additions & 18 deletions roboquant/brokers/alpacabroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,9 @@ def __init__(self, api_key=None, secret_key=None) -> None:
api_key = api_key or config.get("alpaca.public.key")
secret_key = secret_key or config.get("alpaca.secret.key")
self.__client = TradingClient(api_key, secret_key)
self.__has_new_orders_since_sync = False
self.price_type = "DEFAULT"
self.sleep_after_cancel = 0.0

def _should_sync(self, now: datetime):
"""Avoid too many API calls"""
return self.__has_new_orders_since_sync or now - self.__account.last_update > timedelta(seconds=1)

def _sync_orders(self):
for order in self.__account.open_orders():
assert order.id is not None
Expand Down Expand Up @@ -64,8 +59,6 @@ def _sync_positions(self):
self.__account.positions[p.symbol] = new_pos

def sync(self, event: Event | None = None) -> Account:

logger.debug("start sync")
now = datetime.now(timezone.utc)

if event:
Expand All @@ -82,17 +75,11 @@ def sync(self, event: Event | None = None) -> Account:

self._sync_positions()
self._sync_orders()
logger.debug("end sync")
return self.__account

def place_orders(self, orders):

self.__has_new_orders_since_sync = len(orders) > 0

for idx, order in enumerate(orders, start=1):
if idx % 25 == 0:
# avoid to many API calls
time.sleep(1)
for order in orders:

assert order.is_open, "can only place open orders"
if order.size.is_zero():
Expand Down Expand Up @@ -122,10 +109,7 @@ def _get_order_request(self, order: Order):
)
else:
result = MarketOrderRequest(
symbol=order.symbol,
qty=abs(float(order.size)),
side=side,
time_in_force=TimeInForce.GTC
symbol=order.symbol, qty=abs(float(order.size)), side=side, time_in_force=TimeInForce.GTC
)
return result

Expand Down
17 changes: 17 additions & 0 deletions roboquant/brokers/broker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone

from roboquant.account import Account
from roboquant.event import Event
Expand Down Expand Up @@ -51,3 +52,19 @@ def _update_positions(account: Account, event: Event | None, price_type: str = "
for symbol, position in account.positions.items():
if price := event.get_price(symbol, price_type):
position.mkt_price = price


class LiveBroker(Broker):

def __init__(self) -> None:
super().__init__()
self.max_delay = timedelta(minutes=30)

def guard(self, event: Event | None = None):
if not event:
return

now = datetime.now(timezone.utc)

if now - event.time > self.max_delay:
raise ValueError(f"received event too far in the past now={now} event-time={event.time}")
4 changes: 2 additions & 2 deletions roboquant/feeds/csvfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
frequency="",
):
super().__init__()
columns = columns or ["Date", "Open", "High", "Low", "Close", "Volume", "AdjClose", "Time"]
columns = columns or ["Date", "Open", "High", "Low", "Close", "Volume", "Adj Close", "Time"]
self.ohlcv_columns = columns[1:6]
self.adj_close_column = columns[6] if adj_close else None
self.date_column = columns[0]
Expand Down Expand Up @@ -147,7 +147,7 @@ def stooq_us_intraday(cls, path):

class StooqIntradayFeed(CSVFeed):
def __init__(self):
# from Python 3.11 onwards we can use the fast standard ISO parsing
# from Python 3.11 onwards we can use the faster standard ISO parsing
if sys.version_info >= (3, 11):
super().__init__(path, columns=columns, has_time_column=True, endswith=".txt")
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


def get_feed() -> CSVFeed:
root = pathlib.Path(__file__).parent.resolve().joinpath("data", "csv")
return CSVFeed(str(root), time_offset="21:00:00+00:00", date_fmt="%Y%m%d")
root = pathlib.Path(__file__).parent.resolve().joinpath("data", "yahoo")
return CSVFeed(str(root), time_offset="21:00:00+00:00")


def get_recent_start_date(days=10):
Expand Down
9,682 changes: 0 additions & 9,682 deletions tests/data/csv/aapl.csv

This file was deleted.

6,469 changes: 0 additions & 6,469 deletions tests/data/csv/amzn.csv

This file was deleted.

3,174 changes: 0 additions & 3,174 deletions tests/data/csv/tsla.csv

This file was deleted.

10,913 changes: 10,913 additions & 0 deletions tests/data/yahoo/AAPL.csv

Large diffs are not rendered by default.

6,761 changes: 6,761 additions & 0 deletions tests/data/yahoo/AMZN.csv

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions tests/data/yahoo/TSLA.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Date,Open,High,Low,Close,Adj Close,Volume
2022-08-22,291.913330,292.399994,286.296661,289.913330,289.913330,55843200
2022-08-23,291.453339,298.826660,287.923340,296.453339,296.453339,63984900
2022-08-24,297.563324,303.646667,296.500000,297.096680,297.096680,57259800
2022-08-25,302.359985,302.959991,291.600006,296.070007,296.070007,53230000
2022-08-26,297.429993,302.000000,287.470001,288.089996,288.089996,57163900
7 changes: 1 addition & 6 deletions tests/unit/test_csvfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,10 @@ def _get_root_dir(*paths):
root = pathlib.Path(__file__).parent.resolve().joinpath("..", "data", *paths)
return str(root)

def test_csv_feed_generic(self):
root = self._get_root_dir("csv")
feed = CSVFeed(root, time_offset="21:00:00+00:00")
run_price_item_feed(feed, ["AAPL", "AMZN", "TSLA"], self)

def test_csv_feed_yahoo(self):
root = self._get_root_dir("yahoo")
feed = CSVFeed.yahoo(root)
run_price_item_feed(feed, ["META"], self)
run_price_item_feed(feed, ["META", "AAPL", "AMZN", "TSLA"], self)

def test_csv_feed_stooq_daily(self):
root = self._get_root_dir("stooq", "daily")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_flextrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class _MyTrader(FlexTrader):

def _get_orders(self, symbol, size, item, rating, time):
def _get_orders(self, symbol, size, item, signal, time):
price = item.price("CLOSE")
if price:
limit_price = price * 0.99 if size > 0 else price * 1.01
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_lstm_model(self):
tf = rq.Timeframe.fromisoformat("2020-01-01", "2024-01-01")
rq.run(feed, strategy, timeframe=tf)
predictions = strategy.prediction_results
self.assertEqual(760, len(predictions))
self.assertEqual(987, len(predictions))
self.assertNotEqual(max(predictions), min(predictions))


Expand Down

0 comments on commit 9cc2b7b

Please sign in to comment.