Skip to content

Commit

Permalink
Added support for Trade
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed May 29, 2024
1 parent eb10d29 commit 7565718
Showing 1 changed file with 47 additions and 33 deletions.
80 changes: 47 additions & 33 deletions roboquant/feeds/avrofeed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os.path
from datetime import datetime
from array import array
Expand All @@ -6,31 +7,34 @@
from fastavro import writer, reader, parse_schema

from roboquant.alpaca.feed import AlpacaHistoricStockFeed
from roboquant.event import Quote, Bar
from roboquant.event import Quote, Bar, Trade
from roboquant.event import Event
from roboquant.feeds.eventchannel import EventChannel
from roboquant.feeds.feed import Feed
from roboquant.feeds.feedutil import count_events

logger = logging.getLogger(__name__)


class AvroFeed(Feed):

_schema = {
"namespace": "org.roboquant.avro.schema",
"type": "record",
"name": "PriceItemV2",
"fields": [
{"name": "timestamp", "type": "string"},
{"name": "symbol", "type": "string"},
{"name": "type", "type": {"type": "enum", "name": "item_type", "symbols": ["BAR", "TRADE", "QUOTE", "BOOK"]}},
{"name": "values", "type": {"type": "array", "items": "double"}},
{"name": "meta", "type": ["null", "string"], "default": None},
],
}

def __init__(self, avro_file) -> None:
super().__init__()
self.avro_file = avro_file

self.schema = {
"namespace": "org.roboquant.avro.schema",
"type": "record",
"name": "PriceItemV2",
"fields": [
{"name": "timestamp", "type": "string"},
{"name": "symbol", "type": "string"},
{"name": "type", "type": {"type": "enum", "name": "item_type", "symbols": ["BAR", "TRADE", "QUOTE", "BOOK"]}},
{"name": "values", "type": {"type": "array", "items": "double"}},
{"name": "meta", "type": ["null", "string"], "default": None},
],
}
logger.info("avro feed file=%s", avro_file)

def exists(self):
return os.path.exists(self.avro_file)
Expand All @@ -45,9 +49,6 @@ def play(self, channel: EventChannel):

if t != t_old:
if items:
# time_us = int(t_old) // 1_000
# dt = datetime.fromtimestamp(time_us // 1_000_000, tz=timezone.utc)
# dt = dt.replace(microsecond=time_us % 1_000_000)
dt = datetime.fromisoformat(t)
event = Event(dt, items)
channel.put(event)
Expand All @@ -62,31 +63,44 @@ def play(self, channel: EventChannel):
case "BAR":
item = Bar(row["symbol"], array("f", row["values"]), row["other"]) # type: ignore
items.append(item)
case "TRADE":
prices = row["values"] # type: ignore
item = Trade(row["symbol"], prices[0], prices[1]) # type: ignore
items.append(item)
case _:
raise ValueError(f"Unsupported priceItem type={price_type}")

def record(self, feed: Feed, timeframe=None):
schema = parse_schema(self.schema)
schema = parse_schema(AvroFeed._schema)
channel = feed.play_background(timeframe)
records = []
while event := channel.get():
t = event.time.isoformat()
for item in event.items:
if isinstance(item, Quote):
data = {"timestamp": t, "type": "QUOTE", "symbol": item.symbol, "values": list(item.data)}
records.append(data)
if isinstance(item, Bar):
data = {
"timestamp": t,
"type": "BAR",
"symbol": item.symbol,
"values": list(item.ohlcv),
"meta": item.frequency,
}
records.append(data)

with open(self.avro_file, "wb") as out:
match item:
case Quote():
data = {"timestamp": t, "type": "QUOTE", "symbol": item.symbol, "values": list(item.data)}
records.append(data)
case Trade():
data = {
"timestamp": t,
"type": "TRADE",
"symbol": item.symbol,
"values": [item.trade_price, item.trade_volume],
}
records.append(data)
case Bar():
data = {
"timestamp": t,
"type": "BAR",
"symbol": item.symbol,
"values": list(item.ohlcv),
"meta": item.frequency,
}
records.append(data)

with open(self.avro_file, "wb") as out:
writer(out, schema, records)

def __str__(self) -> str:
Expand All @@ -98,9 +112,9 @@ def __str__(self) -> str:
avroFeed = AvroFeed("/tmp/test.avro")
if not avroFeed.exists():
alpaca_feed = AlpacaHistoricStockFeed()
alpaca_feed.retrieve_quotes("AAPL", start="2024-05-24T20:00:00Z")
alpaca_feed.retrieve_quotes("AAPL", start="2024-05-24T00:00:00Z", end="2024-05-25T00:00:00Z")
avroFeed.record(alpaca_feed)

start = time.time()
print(count_events(avroFeed), time.time() - start)
print("events=", count_events(avroFeed), "time=", time.time() - start)
# print_feed_items(feed)

0 comments on commit 7565718

Please sign in to comment.