Skip to content

Commit

Permalink
made signals a list
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 27, 2024
1 parent 9cc2b7b commit 45b68d5
Show file tree
Hide file tree
Showing 21 changed files with 106 additions and 94 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ jobs:
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
18 changes: 18 additions & 0 deletions bin/local_install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

[[ ! -f "LICENSE" ]] && echo "run the script from the project root directory like this: ./bin/publish.sh" && exit 1

source .venv/bin/activate

rm -rf ./runs

# QA
flake8 roboquant tests || exit 1
pylint roboquant tests || exit 1
python -m unittest discover -s tests/unit || exit 1

# Build
rm -rf dist
python -m build || exit 1

# Install
pip install .
2 changes: 1 addition & 1 deletion roboquant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
from .event import Event, PriceItem, Bar, Trade, Quote
from .order import Order, OrderStatus
from .run import run
from .signal import SignalType, Signal, BUY, SELL
from .signal import SignalType, Signal
from .timeframe import Timeframe
23 changes: 13 additions & 10 deletions roboquant/brokers/alpacabroker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import time
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import OrderSide, TimeInForce
Expand All @@ -13,16 +12,17 @@
from roboquant.account import Account, Position
from roboquant.config import Config
from roboquant.event import Event
from roboquant.brokers.broker import Broker
from roboquant.brokers.broker import LiveBroker
from roboquant.order import Order, OrderStatus


logger = logging.getLogger(__name__)


class AlpacaBroker(Broker):
class AlpacaBroker(LiveBroker):

def __init__(self, api_key=None, secret_key=None) -> None:
super().__init__()
self.__account = Account()
config = Config()
api_key = api_key or config.get("alpaca.public.key")
Expand Down Expand Up @@ -59,13 +59,7 @@ def _sync_positions(self):
self.__account.positions[p.symbol] = new_pos

def sync(self, event: Event | None = None) -> Account:
now = datetime.now(timezone.utc)

if event:
# Let make sure we don't use IBKRBroker by mistake during a back-test.
if now - event.time > timedelta(minutes=30):
logger.critical("received event from the past, now=%s event-time=%s", now, event.time)
raise ValueError(f"received event too far in the past now={now} event-time={event.time}")
now = self.guard(event)

client = self.__client
acc: TradeAccount = client.get_account() # type: ignore
Expand Down Expand Up @@ -122,8 +116,17 @@ def _get_replace_request(self, order: Order):
broker = AlpacaBroker()
account = broker.sync()
print(account)

tsla_order = Order("TSLA", 10)
broker.place_orders([tsla_order])
time.sleep(5)
account = broker.sync()
print(account)

tesla_size = account.get_position_size("TSLA")
if tesla_size:
tsla_order = Order("TSLA", -tesla_size)
broker.place_orders([tsla_order])
time.sleep(5)
account = broker.sync()
print(account)
9 changes: 6 additions & 3 deletions roboquant/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@ def __init__(self) -> None:
super().__init__()
self.max_delay = timedelta(minutes=30)

def guard(self, event: Event | None = None):
if not event:
return
def guard(self, event: Event | None = None) -> datetime:

now = datetime.now(timezone.utc)

if not event:
return now

if now - event.time > self.max_delay:
raise ValueError(f"received event too far in the past now={now} event-time={event.time}")

return now
18 changes: 5 additions & 13 deletions roboquant/brokers/ibkr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import threading
import time
from datetime import datetime, timezone, timedelta
from datetime import datetime, timedelta
from decimal import Decimal

from ibapi import VERSION
Expand All @@ -14,7 +14,7 @@
from roboquant.account import Account, Position
from roboquant.event import Event
from roboquant.order import Order, OrderStatus
from roboquant.brokers.broker import Broker, _update_positions
from roboquant.brokers.broker import LiveBroker, _update_positions

assert VERSION["major"] == 10 and VERSION["minor"] == 19, "Wrong version of the IBAPI found"

Expand Down Expand Up @@ -116,7 +116,7 @@ def orderStatus(
logger.warning("received status for unknown order id=%s status=%s", orderId, status)


class IBKRBroker(Broker):
class IBKRBroker(LiveBroker):
"""
Attributes
==========
Expand All @@ -137,6 +137,7 @@ class IBKRBroker(Broker):
"""

def __init__(self, host="127.0.0.1", port=4002, client_id=123) -> None:
super().__init__()
self.__account = Account()
self.contract_mapping: dict[str, Contract] = {}
api = _IBApi()
Expand Down Expand Up @@ -170,15 +171,7 @@ def _should_sync(self, now: datetime):

def sync(self, event: Event | None = None) -> Account:
"""Sync with the IBKR account"""

logger.debug("start sync")
now = datetime.now(timezone.utc)

if event:
# Let make sure we don't use IBKRBroker by mistake during a back-test.
if now - event.time > timedelta(minutes=30):
logger.critical("received event from the past, now=%s event-time=%s", now, event.time)
raise ValueError(f"received event too far in the past now={now} event-time={event.time}")
now = self.guard(event)

api = self.__api
acc = self.__account
Expand All @@ -196,7 +189,6 @@ def sync(self, event: Event | None = None) -> Account:
acc.cash = api.get_cash()

_update_positions(acc, event)
logger.debug("end sync")
return acc

def place_orders(self, orders):
Expand Down
2 changes: 1 addition & 1 deletion roboquant/journals/journal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Journal(Protocol):
It serves as a tool to track and analyze their performance, decisions, and outcomes over time
"""

def track(self, event: Event, account: Account, signals: dict[str, Signal], orders: list[Order]):
def track(self, event: Event, account: Account, signals: list[Signal], orders: list[Order]):
"""invoked at each step of a run that provides the journal with the opportunity to
track and log various metrics."""
...
4 changes: 2 additions & 2 deletions roboquant/ml/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, symbols: list[str]):
self.symbols = symbols

def get_signals(self, action, _):
return {symbol: Signal(rating) for symbol, rating in zip(self.symbols, action)}
return [Signal(symbol, float(rating)) for symbol, rating in zip(self.symbols, action)]

def get_action_space(self):
return spaces.Box(-1.0, 1.0, shape=(len(self.symbols),), dtype=np.float32)
Expand Down Expand Up @@ -137,7 +137,7 @@ def reset(self, *, seed=None, options=None):
self.event = self.channel.get()
assert self.event is not None, "feed empty during warmup"
self.account = self.broker.sync(self.event)
self.trader.create_orders({}, self.event, self.account)
self.trader.create_orders([], self.event, self.account)
observation = self.get_observation(self.event)
self.get_reward(self.event, self.account)
if not np.any(np.isnan(observation)):
Expand Down
16 changes: 8 additions & 8 deletions roboquant/ml/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from roboquant.ml.envs import Action2Signals, StrategyEnv, TraderEnv
from roboquant.ml.features import Feature, NormalizeFeature
from roboquant.order import Order
from roboquant.signal import BUY, SELL, Signal
from roboquant.signal import Signal
from roboquant.strategies.strategy import Strategy
from roboquant.traders.trader import Trader

Expand All @@ -30,7 +30,7 @@ def __init__(self, obs_feature: Feature, action_2_signals: Action2Signals, polic
def from_env(cls, env: StrategyEnv, policy):
return cls(env.obs_feature, env.action_2_signals, policy)

def create_signals(self, event) -> dict[str, Signal]:
def create_signals(self, event):
obs = self.obs_feature.calc(event, None)
if np.any(np.isnan(obs)):
return {}
Expand Down Expand Up @@ -81,17 +81,17 @@ def __init__(self, input_feature: Feature, label_feature: Feature, history: int,
self._hist = deque(maxlen=history)
self._dtype = dtype

def create_signals(self, event: Event) -> dict[str, Signal]:
def create_signals(self, event: Event):
h = self._hist
row = self.input_feature.calc(event, None)
h.append(row)
if len(h) == h.maxlen:
x = np.asarray(h, dtype=self._dtype)
return self.predict(x)
return {}
return []

@abstractmethod
def predict(self, x: NDArray) -> dict[str, Signal]: ...
def predict(self, x: NDArray) -> list[Signal]: ...

def _get_xy(self, feed, timeframe=None, warmup=0) -> tuple[NDArray, NDArray]:
channel = feed.play_background(timeframe)
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
self.symbol = symbol
self.prediction_results = []

def predict(self, x) -> dict[str, Signal]:
def predict(self, x):
x = torch.asarray(x)
x = torch.unsqueeze(x, dim=0) # add the batch dimension

Expand All @@ -168,9 +168,9 @@ def predict(self, x) -> dict[str, Signal]:

self.prediction_results.append(p)
if p >= self.buy_pct:
return {self.symbol: BUY}
return [Signal.buy(self.symbol)]
if p <= self.sell_pct:
return {self.symbol: SELL}
return [Signal.sell(self.symbol)]

return {}

Expand Down
2 changes: 1 addition & 1 deletion roboquant/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run(

while event := channel.get(heartbeat_timeout):
account = broker.sync(event)
signals = strategy.create_signals(event) if strategy else {}
signals = strategy.create_signals(event) if strategy else []
orders = trader.create_orders(signals, event, account)
broker.place_orders(orders)
if journal:
Expand Down
23 changes: 9 additions & 14 deletions roboquant/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def __str__(self):

@dataclass(slots=True, frozen=True)
class Signal:
"""Signal that a strategy can create.
It contains both a rating between -1.0 and 1.0 and the type of signal.
"""Signal that a strategy can create.It contains both a rating and the type of signal.
A rating is a float normally between -1.0 and 1.0, where -1.0 is a strong sell and 1.0 is a strong buy.
But in cases it can exceed these values. It is up to the used trader to handle these values
Examples:
```
Expand All @@ -25,19 +27,19 @@ class Signal:
Signal("XYZ", 0.5, SignalType.ENTRY)
```
"""

symbol: str
rating: float
type: SignalType = SignalType.ENTRY_EXIT

@staticmethod
def buy(signal_type=SignalType.ENTRY_EXIT):
def buy(symbol, signal_type=SignalType.ENTRY_EXIT):
"""Create a BUY signal with a rating of 1.0"""
return Signal(1.0, signal_type)
return Signal(symbol, 1.0, signal_type)

@staticmethod
def sell(signal_type=SignalType.ENTRY_EXIT):
def sell(symbol, signal_type=SignalType.ENTRY_EXIT):
"""Create a SELL signal with a rating of -1.0"""
return Signal(-1.0, signal_type)
return Signal(symbol, -1.0, signal_type)

@property
def is_buy(self):
Expand All @@ -54,10 +56,3 @@ def is_entry(self):
@property
def is_exit(self):
return SignalType.EXIT in self.type


BUY = Signal.buy(SignalType.ENTRY_EXIT)
"""BUY signal with a rating of 1.0 and valid for both entry and exit signals"""

SELL = Signal.sell(SignalType.ENTRY_EXIT)
"""SELL signal with a rating of -1.0 and valid for both entry and exit signals"""
6 changes: 3 additions & 3 deletions roboquant/strategies/barstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def __init__(self, size: int) -> None:
self._data: dict[str, OHLCVBuffer] = {}
self.size = size

def create_signals(self, event) -> dict[str, Signal]:
signals = {}
def create_signals(self, event):
signals = []
for item in event.items:
if isinstance(item, Bar):
symbol = item.symbol
Expand All @@ -31,7 +31,7 @@ def create_signals(self, event) -> dict[str, Signal]:
if ohlcv.is_full():
signal = self._create_signal(symbol, ohlcv)
if signal is not None:
signals[symbol] = signal
signals.append(signal)
return signals

@abstractmethod
Expand Down
9 changes: 5 additions & 4 deletions roboquant/strategies/emacrossover.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from roboquant.event import Event
from roboquant.signal import Signal, BUY, SELL
from roboquant.signal import Signal
from roboquant.strategies.strategy import Strategy


Expand All @@ -14,8 +14,8 @@ def __init__(self, fast_period=13, slow_period=26, smoothing=2.0, price_type="DE
self.price_type = price_type
self.min_steps = max(fast_period, slow_period)

def create_signals(self, event: Event) -> dict[str, Signal]:
signals: dict[str, Signal] = {}
def create_signals(self, event: Event):
signals = []
for symbol, price in event.get_prices(self.price_type).items():

if symbol not in self._history:
Expand All @@ -28,7 +28,8 @@ def create_signals(self, event: Event) -> dict[str, Signal]:
if step > self.min_steps:
new_rating = calculator.is_above()
if old_rating != new_rating:
signals[symbol] = BUY if new_rating else SELL
signal = Signal.buy(symbol) if new_rating else Signal.sell(symbol)
signals.append(signal)

return signals

Expand Down
Loading

0 comments on commit 45b68d5

Please sign in to comment.