Skip to content

Commit

Permalink
Added collect feed
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Apr 15, 2024
1 parent c0d5c83 commit 7460037
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 24 deletions.
2 changes: 1 addition & 1 deletion roboquant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
from .order import Order, OrderStatus
from .run import run
from .signal import SignalType, Signal
from .timeframe import Timeframe
from .timeframe import Timeframe, EMPTY_TIMEFRAME
6 changes: 3 additions & 3 deletions roboquant/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def empty(time: datetime | None = None):

def is_empty(self) -> bool:
"""return True if this is an empty event without any items, False otherwise"""
return len(self) == 0
return len(self.items) == 0

def __len__(self) -> int:
return len(self.items)
# def __len__(self) -> int:
# return len(self.items)

@cached_property
def price_items(self) -> dict[str, PriceItem]:
Expand Down
1 change: 1 addition & 0 deletions roboquant/feeds/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from roboquant.feeds import feedutil
from .aggregate import AggregatorFeed
from .collect import CollectorFeed
from .csvfeed import CSVFeed
from .eventchannel import EventChannel
from .feed import Feed
Expand Down
33 changes: 33 additions & 0 deletions roboquant/feeds/collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

from roboquant.event import Event
from .eventchannel import EventChannel
from .feed import Feed


class CollectorFeed(Feed):
"""Collect events into one new event if they occur close to eachother.
Close to eachother is defined by the timeout is seconds. If there is no new
event in the specified timeout, all previous events will be bundled together and
put on the channel.
"""

def __init__(
self,
feed: Feed,
timeout=5.0,
):
super().__init__()
self.feed = feed
self.timeout = timeout

def play(self, channel: EventChannel):
src_channel = self.feed.play_background(channel.timeframe, channel.maxsize)
items = []
while event := src_channel.get(self.timeout):
if event.is_empty() and items:
new_event = Event(event.time, items)
channel.put(new_event)
items = []

items.extend(event.items)
1 change: 1 addition & 0 deletions roboquant/feeds/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def play(self, channel: EventChannel):
...

def timeframe(self) -> Timeframe | None:
"""Return the timeframe of this feed it has one and is known, otherwise return None."""
return None

def play_background(self, timeframe: Timeframe | None = None, channel_capacity: int = 10) -> EventChannel:
Expand Down
21 changes: 13 additions & 8 deletions roboquant/feeds/historic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List

from roboquant.event import Event, PriceItem
from roboquant.timeframe import Timeframe
from roboquant.timeframe import EMPTY_TIMEFRAME, Timeframe
from .eventchannel import EventChannel
from .feed import Feed
from .feedutil import get_ohlcv
Expand All @@ -22,7 +22,9 @@ def __init__(self):
self.__symbols = []

def _add_item(self, time: datetime, item: PriceItem):
"""Add a price-item at a moment in time to this feed"""
"""Add a price-item at a moment in time to this feed.
Subclasses should invoke this method to populate the historic-feed.
"""

self.__modified = True

Expand All @@ -38,6 +40,11 @@ def symbols(self):
self.__update()
return self.__symbols

@property
def events(self):
"""Return the total number of events"""
return len(self.__data)

def timeline(self) -> List[datetime]:
"""Return the timeline of this feed as a list of datatime objects"""
self.__update()
Expand All @@ -46,10 +53,10 @@ def timeline(self) -> List[datetime]:
def timeframe(self):
"""Return the timeframe of this feed"""
tl = self.timeline()
if not tl:
raise ValueError("Feed doesn't contain any events.")
if tl:
return Timeframe(tl[0], tl[-1], inclusive=True)

return Timeframe(tl[0], tl[-1], inclusive=True)
return EMPTY_TIMEFRAME

def get_ohlcv(self, symbol: str, timeframe=None) -> dict[str, list]:
"""Get the OHLCV values for a symbol for the (optional) provided timeframe.
Expand All @@ -71,7 +78,5 @@ def play(self, channel: EventChannel):
channel.put(evt)

def __repr__(self) -> str:
events = len(self.timeline())
timeframe = self.timeframe() if events else None
feed = self.__class__.__name__
return f"{feed}(events={events} symbols={len(self.symbols)} timeframe={timeframe})"
return f"{feed}(events={self.events} symbols={len(self.symbols)} timeframe={self.timeframe()})"
28 changes: 28 additions & 0 deletions roboquant/ml/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _full_nan(self):


class SlicedFeature(Feature):
"""Calculate a slice from another feature"""

def __init__(self, feature: Feature, args) -> None:
super().__init__()
Expand Down Expand Up @@ -400,6 +401,7 @@ class MaxReturnFeature(Feature):

def __init__(self, feature: Feature, period: int) -> None:
super().__init__()
assert feature.size() == 1
self.history = deque(maxlen=period)
self.feature: Feature = feature

Expand Down Expand Up @@ -521,6 +523,7 @@ class DayOfWeekFeature(Feature):
"""Calculate a one-hot-encoded day of the week where Monday == 0 and Sunday == 6"""

def __init__(self, tz=timezone.utc) -> None:
super().__init__()
self.tz = tz

def calc(self, evt, account):
Expand All @@ -534,9 +537,34 @@ def size(self) -> int:
return 7


class TimeDifference(Feature):
"""Calculate the time difference in seconds between two consecutive events."""

def __init__(self) -> None:
super().__init__()
self._last_time: datetime | None = None

def calc(self, evt, account):
if self._last_time:
diff = evt.time - self._last_time
self._last_time = evt.time
return np.asarray([diff.total_seconds], dtype="float32")

self._last_time = evt.time
return self._full_nan()

def size(self) -> int:
return 1

def reset(self):
self._last_time = None


class TaFeature(Feature):
"""Base class for technical analysis features"""

def __init__(self, *symbols: str, history_size: int) -> None:
super().__init__()
self._data: dict[str, OHLCVBuffer] = {}
self._size = history_size
self.symbols = list(symbols)
Expand Down
8 changes: 4 additions & 4 deletions roboquant/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ class OrderStatus(Flag):
_CLOSE = REJECTED | FILLED | CANCELLED | EXPIRED

@property
def open(self):
def is_open(self):
"""Return True is the status is open, False otherwise"""
return self in OrderStatus._OPEN

@property
def closed(self):
def is_closed(self):
"""Return True is the status is closed, False otherwise"""
return self in OrderStatus._CLOSE

Expand Down Expand Up @@ -77,12 +77,12 @@ def __init__(
@property
def is_open(self) -> bool:
"""Return True is the order is open, False otherwise"""
return self.status.open
return self.status.is_open

@property
def is_closed(self) -> bool:
"""Return True is the order is closed, False otherwise"""
return self.status.closed
return self.status.is_closed

def cancel(self) -> "Order":
"""Create a cancellation order. You can only cancel orders that are still open and have an id.
Expand Down
15 changes: 11 additions & 4 deletions roboquant/timeframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Timeframe:
"""A timeframe represents a period in time with a specific start- and end-datetime.
"""A timeframe represents a period in time with a specific start- and end-datetime. Timeframes should not be mutated.
Internally it stores the start and end times as Python datetime objects with the timezone set to UTC.
"""
Expand Down Expand Up @@ -32,9 +32,9 @@ def fromisoformat(cls, start: str, end: str, inclusive=False):
e = datetime.fromisoformat(end)
return cls(s, e, inclusive)

@staticmethod
def empty():
return Timeframe.fromisoformat("1900-01-01T00:00:00+00:00", "1900-01-01T00:00:00+00:00", False)
def is_empty(self):
"""Return true if this is an empty timeframe"""
return self.start == self.end and not self.inclusive

@staticmethod
def previous(inclusive=False, **kwargs):
Expand Down Expand Up @@ -68,6 +68,9 @@ def __contains__(self, time):
return self.start <= time < self.end

def __repr__(self):
if self == EMPTY_TIMEFRAME:
return "EMPTY_TIMEFRAME"

last_char = "]" if self.inclusive else ">"
fmt_str = "%Y-%m-%d %H:%M:%S"
return f"[{self.start.strftime(fmt_str)}{self.end.strftime(fmt_str)}{last_char}"
Expand Down Expand Up @@ -126,3 +129,7 @@ def __eq__(self, other):
return self.start == other.start and self.end == other.end and self.inclusive == other.inclusive

return False


EMPTY_TIMEFRAME = Timeframe.fromisoformat("1900-01-01T00:00:00+00:00", "1900-01-01T00:00:00+00:00", False)
"""Represents an empty timeframe, one that cannot contain events"""
5 changes: 4 additions & 1 deletion roboquant/traders/sizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ def get_value(self, symbol) -> float | None:
return entry[0]
return None

def reset(self):
self.history = {}


if __name__ == "__main__":
feed = YahooFeed("IBM", "MSFT", "TSLA")
channel = feed.play_background()
e = ETR(80, 0.1)
e = ETR(20, 2.0)
while evt := channel.get():
e.add(evt)
print(e.get_value("IBM"), e.get_value("TSLA"), e.get_value("MSFT"))
26 changes: 26 additions & 0 deletions samples/forwardtest_alpaca2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# %%
from datetime import timedelta
import logging
import roboquant as rq
from roboquant.alpaca import AlpacaLiveFeed

# %%
logging.basicConfig()
logging.getLogger("roboquant").setLevel(level=logging.INFO)

# Connect to Alpaca and subscribe to some IEX stocks 1-minute bars
symbols = ["TSLA", "MSFT", "NVDA", "AMD", "AAPL", "AMZN"]
alpaca_feed = AlpacaLiveFeed(market="iex")
alpaca_feed.subscribe_bars(*symbols)

feed = rq.feeds.CollectorFeed(alpaca_feed, 10.0)
# %%
# Let run an EMACrossover strategy
strategy = rq.strategies.EMACrossover(5, 13)
timeframe = rq.Timeframe.next(minutes=60)
journal = rq.journals.BasicJournal()
account = rq.run(feed, strategy, journal=journal, timeframe=timeframe)

# %%
print(account)
print(journal)
10 changes: 9 additions & 1 deletion tests/integration/test_alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,31 @@ class TestAlpaca(unittest.TestCase):

stocks = ["AAPL", "TSLA"]

def test_alpaca_feed(self):
def test_alpaca_stock_feed_bars(self):
feed = AlpacaHistoricStockFeed()
feed.retrieve_bars(*self.stocks, start="2024-03-01", end="2024-03-02")
run_price_item_feed(feed, self.stocks, self)

def test_alpaca_stock_feed_trades(self):
feed = AlpacaHistoricStockFeed()
feed.retrieve_trades(*self.stocks, start="2024-03-01T18:00:00", end="2024-03-01T18:01:00")
run_price_item_feed(feed, self.stocks, self)

def test_alpaca_stock_feed_quotes(self):
feed = AlpacaHistoricStockFeed()
feed.retrieve_quotes(*self.stocks, start="2024-03-01T18:00:00", end="2024-03-01T18:01:00")
run_price_item_feed(feed, self.stocks, self)

def test_alpaca_crypto_feed_bars(self):
feed = AlpacaHistoricCryptoFeed()
feed.retrieve_bars("BTC/USDT", start="2024-03-01", end="2024-03-02", resolution=TimeFrame.Hour) # type: ignore
run_price_item_feed(feed, ["BTC/USDT"], self)

def test_alpaca_crypto_feed_trades(self):
feed = AlpacaHistoricCryptoFeed()
feed.retrieve_trades("BTC/USDT", start="2024-04-01", end="2024-04-10")
run_price_item_feed(feed, ["BTC/USDT"], self)

def test_alpaca_broker(self):
broker = AlpacaBroker()
account = broker.sync()
Expand Down
7 changes: 5 additions & 2 deletions tests/integration/test_yahoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ def test_yahoo_feed(self):
feed = YahooFeed("MSFT", "JPM", start_date="2018-01-01", end_date="2020-01-01")
self.assertEqual(2, len(feed.symbols))
self.assertEqual({"MSFT", "JPM"}, set(feed.symbols))
self.assertTrue(feed.timeframe().start == datetime.fromisoformat("2018-01-02T05:00:00+00:00"))
self.assertTrue(feed.timeframe().end == datetime.fromisoformat("2019-12-31T05:00:00+00:00"))

tf = feed.timeframe()
assert tf
self.assertTrue(tf.start == datetime.fromisoformat("2018-01-02T05:00:00+00:00"))
self.assertTrue(tf.end == datetime.fromisoformat("2019-12-31T05:00:00+00:00"))
self.assertEqual(503, len(feed.timeline()))

run_price_item_feed(feed, ["MSFT", "JPM"], self)
Expand Down

0 comments on commit 7460037

Please sign in to comment.