Skip to content

Commit

Permalink
Historic Alpaca data feeds
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Apr 14, 2024
1 parent d632c5f commit 75b2094
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 38 deletions.
113 changes: 113 additions & 0 deletions roboquant/alpaca/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,27 @@
from typing import Literal

from alpaca.data import DataFeed
from alpaca.data.historical.crypto import CryptoHistoricalDataClient
from alpaca.data.historical.stock import StockHistoricalDataClient
from alpaca.data.live.crypto import CryptoDataStream
from alpaca.data.live.stock import StockDataStream
from alpaca.data.live.option import OptionDataStream
from alpaca.data.models.bars import BarSet
from alpaca.data.models.quotes import QuoteSet
from alpaca.data.models.trades import TradeSet
from alpaca.data.requests import (
CryptoBarsRequest,
CryptoTradesRequest,
StockBarsRequest,
StockQuotesRequest,
StockTradesRequest,
)
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit

from roboquant.config import Config
from roboquant.event import Event, Quote, Trade, Bar
from roboquant.feeds.feedutil import print_feed_items
from roboquant.feeds.historic import HistoricFeed
from roboquant.feeds.live import LiveFeed


Expand Down Expand Up @@ -65,3 +80,101 @@ def subscribe_quotes(self, *symbols: str):

def subscribe_bars(self, *symbols: str):
self.stream.subscribe_bars(self.__handle_bars, *symbols)


class AlpacaHistoricFeed(HistoricFeed):

def _process_bars(self, bar_set, freq: str):
for symbol, data in bar_set.items():
for d in data:
time = d.timestamp
ohlcv = array("f", [d.open, d.high, d.low, d.close, d.volume])
item = Bar(symbol, ohlcv, freq)
super()._add_item(time, item)

def _process_trades(self, quote_set):
for symbol, data in quote_set.items():
for d in data:
time = d.timestamp
item = Trade(symbol, d.price, d.size)
super()._add_item(time, item)

def _process_quotes(self, quote_set):
for symbol, data in quote_set.items():
for d in data:
time = d.timestamp
arr = array("f", [d.ask_price, d.ask_size, d.bid_price, d.bid_size])
item = Quote(symbol, arr)
super()._add_item(time, item)


class AlpacaHistoricStockFeed(AlpacaHistoricFeed):

def __init__(self, api_key=None, secret_key=None, data_api_url=None):
super().__init__()
config = Config()
api_key = api_key or config.get("alpaca.public.key")
secret_key = secret_key or config.get("alpaca.secret.key")
self.client = StockHistoricalDataClient(api_key, secret_key, url_override=data_api_url)

def retrieve_bars(self, *symbols, start=None, end=None, resolution: TimeFrame | None = None):
resolution = resolution or TimeFrame(amount=1, unit=TimeFrameUnit.Day)
req = StockBarsRequest(symbol_or_symbols=list(symbols), timeframe=resolution, start=start, end=end)
res = self.client.get_stock_bars(req)
assert isinstance(res, BarSet)
freq = str(resolution)
self._process_bars(res.data, freq)

def retrieve_trades(self, *symbols, start=None, end=None):
req = StockTradesRequest(symbol_or_symbols=list(symbols), start=start, end=end)
res = self.client.get_stock_trades(req)
assert isinstance(res, TradeSet)
self._process_trades(res.data)

def retrieve_quotes(self, *symbols, start=None, end=None):
req = StockQuotesRequest(symbol_or_symbols=list(symbols), start=start, end=end)
res = self.client.get_stock_quotes(req)
assert isinstance(res, QuoteSet)
self._process_quotes(res.data)


class AlpacaHistoricCryptoFeed(AlpacaHistoricFeed):

def __init__(self, api_key=None, secret_key=None, data_api_url=None):
super().__init__()
config = Config()
api_key = api_key or config.get("alpaca.public.key")
secret_key = secret_key or config.get("alpaca.secret.key")
self.client = CryptoHistoricalDataClient(api_key, secret_key, url_override=data_api_url)

def retrieve_bars(self, *symbols, start=None, end=None, resolution: TimeFrame | None = None):
resolution = resolution or TimeFrame(amount=1, unit=TimeFrameUnit.Day)
req = CryptoBarsRequest(symbol_or_symbols=list(symbols), timeframe=resolution, start=start, end=end)
res = self.client.get_crypto_bars(req)
assert isinstance(res, BarSet)
freq = str(resolution)
self._process_bars(res.data, freq)

def retrieve_trades(self, *symbols, start=None, end=None):
req = CryptoTradesRequest(symbol_or_symbols=list(symbols), start=start, end=end)
res = self.client.get_crypto_trades(req)
assert isinstance(res, TradeSet)
self._process_trades(res.data)


if __name__ == "__main__":
feed = AlpacaHistoricStockFeed()
feed.retrieve_bars("AAPL", "TSLA", start="2024-03-01", end="2024-03-02")
print_feed_items(feed)

feed = AlpacaHistoricStockFeed()
feed.retrieve_trades("AAPL", "TSLA", start="2024-03-01T18:00:00", end="2024-03-01T18:01:00")
print_feed_items(feed)

feed = AlpacaHistoricStockFeed()
feed.retrieve_quotes("AAPL", "TSLA", start="2024-03-01T18:00:00", end="2024-03-01T18:01:00")
print_feed_items(feed)

feed = AlpacaHistoricCryptoFeed()
feed.retrieve_bars("BTC/USDT", start="2024-03-01", end="2024-03-02", resolution=TimeFrame.Hour) # type: ignore
print_feed_items(feed)
58 changes: 28 additions & 30 deletions roboquant/ml/strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from collections import deque
from datetime import datetime
import logging
import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -71,14 +72,12 @@ def reset(self):

class FeatureStrategy(Strategy):
"""Abstract base class for strategies wanting to use features
for their input and target.
for their input.
"""

def __init__(self, input_feature: Feature, label_feature: Feature, history: int, dtype="float32"):
self._features_x = []
self._features_y = []
def __init__(self, input_feature: Feature, history: int, dtype="float32"):
self.input_feature = input_feature
self.label_feature = label_feature
self.history = history
self._hist = deque(maxlen=history)
self._dtype = dtype

Expand All @@ -88,26 +87,11 @@ def create_signals(self, event: Event):
h.append(row)
if len(h) == h.maxlen:
x = np.asarray(h, dtype=self._dtype)
return self.predict(x)
return self.predict(x, event.time)
return []

@abstractmethod
def predict(self, x: NDArray) -> list[Signal]: ...

def _get_xy(self, feed, timeframe=None, warmup=0) -> tuple[NDArray, NDArray]:
channel = feed.play_background(timeframe)
x = []
y = []
while evt := channel.get():
if warmup:
self.label_feature.calc(evt, None)
self.input_feature.calc(evt, None)
warmup -= 1
else:
x.append(self.input_feature.calc(evt, None))
y.append(self.label_feature.calc(evt, None))

return np.asarray(x, dtype=self._dtype), np.asarray(y, dtype=self._dtype)
def predict(self, x: NDArray, time: datetime) -> list[Signal]: ...


class SequenceDataset(Dataset):
Expand Down Expand Up @@ -146,15 +130,14 @@ def __init__(
buy_pct: float = 0.01,
sell_pct=0.0,
):
super().__init__(input_feature, label_feature, sequences)
self.sequences = sequences
super().__init__(input_feature, sequences)
self.label_feature = label_feature
self.model = model
self.buy_pct = buy_pct
self.sell_pct = sell_pct
self.symbol = symbol
self.prediction_results = []

def predict(self, x):
def predict(self, x, time):
x = torch.asarray(x)
x = torch.unsqueeze(x, dim=0) # add the batch dimension

Expand All @@ -167,7 +150,7 @@ def predict(self, x):
else:
p = output.item()

self.prediction_results.append(p)
logger.info("prediction p=%s time=%s", p, time)
if p >= self.buy_pct:
return [Signal.buy(self.symbol)]
if p <= self.sell_pct:
Expand All @@ -182,18 +165,33 @@ def _get_dataloaders(self, x, y, prediction: int, validation_split: float, batch
x_train = x[: border - prediction]
y_train = y[prediction:border]

train_dataset = SequenceDataset(x_train, y_train, self.sequences)
train_dataset = SequenceDataset(x_train, y_train, self.history)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataloader = None
if validation_split > 0.0:
x_valid = x[border - prediction: -prediction]
y_valid = y[border:]
valid_dataset = SequenceDataset(x_valid, y_valid, self.sequences)
valid_dataset = SequenceDataset(x_valid, y_valid, self.history)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

return train_dataloader, valid_dataloader

def __get_xy(self, feed, timeframe=None, warmup=0) -> tuple[NDArray, NDArray]:
channel = feed.play_background(timeframe)
x = []
y = []
while evt := channel.get():
if warmup:
self.label_feature.calc(evt, None)
self.input_feature.calc(evt, None)
warmup -= 1
else:
x.append(self.input_feature.calc(evt, None))
y.append(self.label_feature.calc(evt, None))

return np.asarray(x, dtype=self._dtype), np.asarray(y, dtype=self._dtype)

@staticmethod
def describe(x):
print("shape=", x.shape, "min=", np.min(x, axis=0), "max=", np.max(x, axis=0), "mean=", np.mean(x, axis=0))
Expand Down Expand Up @@ -229,7 +227,7 @@ def fit(
optimizer = optimizer or torch.optim.Adam(self.model.parameters(), lr=0.001)
criterion = criterion or torch.nn.MSELoss()

x, y = self._get_xy(feed, timeframe, warmup=warmup)
x, y = self.__get_xy(feed, timeframe, warmup=warmup)
logger.info("x-shape=%s", x.shape)
logger.info("y-shape=%s", y.shape)

Expand Down
4 changes: 0 additions & 4 deletions samples/torch_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,4 @@ def forward(self, inputs):
# %%
# Print some results
print(journal)
predictions = strategy.prediction_results
print(max(predictions), min(predictions))
print(account)

# %%
1 change: 1 addition & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def run_price_item_feed(feed: Feed, symbols: list[str], test_case: TestCase, tim
case Quote():
for f in item.data:
test_case.assertTrue(math.isfinite(f))
test_case.assertGreaterEqual(item.data[0], item.data[2]) # ask >= bid

test_case.assertGreaterEqual(n_items, min_items)

Expand Down
29 changes: 29 additions & 0 deletions tests/integration/test_alpaca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest
from alpaca.data.timeframe import TimeFrame

from roboquant.alpaca.feed import AlpacaHistoricCryptoFeed, AlpacaHistoricStockFeed
from tests.common import run_price_item_feed


class TestAlpaca(unittest.TestCase):

def test_alpaca_feed(self):
feed = AlpacaHistoricStockFeed()
feed.retrieve_bars("AAPL", "TSLA", start="2024-03-01", end="2024-03-02")
run_price_item_feed(feed, ["AAPL", "TSLA"], self)

feed = AlpacaHistoricStockFeed()
feed.retrieve_trades("AAPL", "TSLA", start="2024-03-01T18:00:00", end="2024-03-01T18:01:00")
run_price_item_feed(feed, ["AAPL", "TSLA"], self)

feed = AlpacaHistoricStockFeed()
feed.retrieve_quotes("AAPL", "TSLA", start="2024-03-01T18:00:00", end="2024-03-01T18:01:00")
run_price_item_feed(feed, ["AAPL", "TSLA"], self)

feed = AlpacaHistoricCryptoFeed()
feed.retrieve_bars("BTC/USDT", start="2024-03-01", end="2024-03-02", resolution=TimeFrame.Hour) # type: ignore
run_price_item_feed(feed, ["BTC/USDT"], self)


if __name__ == "__main__":
unittest.main()
7 changes: 3 additions & 4 deletions tests/unit/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ def test_lstm_model(self):

# Run the trained model with the last 4 years of data
tf = rq.Timeframe.fromisoformat("2020-01-01", "2024-01-01")
rq.run(feed, strategy, timeframe=tf)
predictions = strategy.prediction_results
self.assertEqual(987, len(predictions))
self.assertNotEqual(max(predictions), min(predictions))
account = None
account = rq.run(feed, strategy, timeframe=tf)
self.assertTrue(account)


if __name__ == "__main__":
Expand Down

0 comments on commit 75b2094

Please sign in to comment.