Skip to content

Commit

Permalink
fixed features reset
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed May 17, 2024
1 parent a1fafa5 commit 4fe7d52
Show file tree
Hide file tree
Showing 15 changed files with 159 additions and 79 deletions.
4 changes: 3 additions & 1 deletion roboquant/brokers/simbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from roboquant.account import Account, Position
from roboquant.brokers.broker import Broker, _update_positions
from roboquant.event import Event, PriceItem
from roboquant.event import Event, PriceItem, Quote
from roboquant.order import Order, OrderStatus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,6 +86,8 @@ def _get_execution_price(self, order, item) -> float:
The default implementation is a fixed slippage percentage based on the configured price_type.
"""
if isinstance(item, Quote):
return item.ask_price if order.is_buy else item.bid_price

price = item.price(self.price_type)
correction = self.slippage if order.is_buy else -self.slippage
Expand Down
8 changes: 8 additions & 0 deletions roboquant/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def price(self, price_type: str = "DEFAULT") -> float:
# Default is the mid-point price
return (self.data[0] + self.data[2]) / 2.0

@property
def ask_price(self) -> float:
return self.data[0]

@property
def bid_price(self) -> float:
return self.data[2]

@property
def ask_volume(self) -> float:
return self.data[1]
Expand Down
2 changes: 1 addition & 1 deletion roboquant/feeds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .randomwalk import RandomWalk
from .sqllitefeed import SQLFeed
from .tiingo import TiingoLiveFeed, TiingoHistoricFeed
from .feedutil import get_sp500_symbols
from .feedutil import get_sp500_symbols, print_feed_items, count_events

try:
from .yahoo import YahooFeed
Expand Down
8 changes: 0 additions & 8 deletions roboquant/feeds/csvfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,3 @@ def yahoo(cls, path, frequency="1d"):
"""Parse one or more CSV files that meet the Yahoo Finance format"""
columns = ["Date", "Open", "High", "Low", "Close", "Volume", "Adj Close"]
return cls(path, columns=columns, adj_close=True, time_offset="21:00:00+00:00", frequency=frequency)


if __name__ == "__main__":
t = datetime.strptime("210000", "%H%M%S").time()
print(t)

t = time.fromisoformat("210000")
print(t)
11 changes: 11 additions & 0 deletions roboquant/feeds/feedutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,14 @@ def print_feed_items(feed: Feed, timeframe: Timeframe | None = None, timeout: fl
print(event.time)
for item in event.items:
print("======> ", item)


def count_events(feed: Feed, timeframe: Timeframe | None = None, timeout: float | None = None, include_empty=False):
"""Count the number of events in a feed"""

channel = feed.play_background(timeframe)
events = 0
while evt := channel.get(timeout):
if evt.items or include_empty:
events += 1
return events
10 changes: 7 additions & 3 deletions roboquant/feeds/sqllitefeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def get_item(self, row) -> PriceItem:
symbol = row[1]
prices = row[2:7]
freq = row[7]
return Bar(symbol, array('f', prices), freq)
return Bar(symbol, array("f", prices), freq)

symbol = row[1]
data = row[2:6]
return Quote(symbol, array('f', data))
return Quote(symbol, array("f", data))

def play(self, channel: EventChannel):
con = sqlite3.connect(self.db_file)
Expand All @@ -79,7 +79,11 @@ def play(self, channel: EventChannel):
t_old = ""
items = []
tf = channel.timeframe
result = cur.execute(SQLFeed._sql_select_by_date, [tf.start, tf.end]) if tf else cur.execute(SQLFeed._sql_select)
result = (
cur.execute(SQLFeed._sql_select_by_date, [tf.start.isoformat(), tf.end.isoformat()])
if tf
else cur.execute(SQLFeed._sql_select)
)

for row in result:
t = row[0]
Expand Down
42 changes: 26 additions & 16 deletions roboquant/journals/basicjournal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,43 @@

@dataclass
class BasicJournal(Journal):
"""Tracks a number of basic metrics:
- total number of events, items, signals, orders and max open positions.
"""Track the following metrics:
- total number of events, items, buy and sell signals, buy and sell orders
- the max open positions.
It will also log these values at each step in the run at `info` level.
This journal adds little overhead to a run, both CPU and memory wise, and is helpful in
determining if the setup works correctly.
"""

events: int
items: int
signals: int
orders: int
max_positions: int
events: int = 0
items: int = 0
buy_signals: int = 0
sell_signals: int = 0
buy_orders: int = 0
sell_orders: int = 0
max_positions: int = 0

def __init__(self):
self.events = 0
self.signals = 0
self.items = 0
self.orders = 0
self.max_positions = 0
def __init__(self, log_level=logging.INFO):
self.__log_level = log_level

def track(self, event, account, signals, orders):
self.items += len(event.items)
self.events += 1
self.signals += len(signals)
self.orders += len(orders)
self.buy_signals += len([s for s in signals if s.is_buy])
self.sell_signals += len([s for s in signals if s.is_sell])
self.buy_orders += len([o for o in orders if o.is_buy])
self.sell_orders += len([o for o in orders if o.is_sell])
self.max_positions = max(self.max_positions, len(account.positions))

logger.info("time=%s info=%s", event.time, self)
logger.log(self.__log_level, "time=%s info=%s", event.time, self)

def reset(self):
self.events: int = 0
self.items: int = 0
self.buy_signals: int = 0
self.sell_signals: int = 0
self.buy_orders: int = 0
self.sell_orders: int = 0
self.max_positions: int = 0
4 changes: 2 additions & 2 deletions roboquant/ml/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self, symbols: list[str]):

@staticmethod
def __limit(rating):
print(rating)
rating = float(rating)
assert math.isfinite(rating), f"rating not finite rating={rating}"
return max(-1.0, min(1.0, rating))
Expand Down Expand Up @@ -116,7 +115,7 @@ def __init__(
def get_observation(self, evt: Event) -> NDArray[np.float32]:
return self.obs_feature.calc(evt, None)

def get_reward(self, evt: Event, account: Account):
def get_reward(self, evt: Event, account: Account) -> NDArray[np.float32]:
return self.reward_feature.calc(evt, account)

def step(self, action):
Expand All @@ -139,6 +138,7 @@ def step(self, action):

def reset(self, *, seed=None, options=None):
super().reset(seed=seed, options=options)
logger.info("environment resetting")
self.broker.reset()
self.trader.reset()
self.obs_feature.reset()
Expand Down
37 changes: 28 additions & 9 deletions roboquant/ml/features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime, timezone
from typing import Any

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -40,8 +41,8 @@ def returns(self, period=1):
def normalize(self, min_period=3):
return NormalizeFeature(self, min_period)

def cache(self):
return CacheFeature(self)
def cache(self, validate=False):
return CacheFeature(self, validate)

def __getitem__(self, *args):
return SlicedFeature(self, args)
Expand Down Expand Up @@ -75,6 +76,9 @@ def calc(self, evt, account):
def size(self):
return self._size

def reset(self):
return self.feature.reset()


class TrueRangeFeature(Feature):
"""Calculates the true range value for a symbol"""
Expand Down Expand Up @@ -110,9 +114,9 @@ def reset(self):

class FixedValueFeature(Feature):

def __init__(self, value: NDArray) -> None:
def __init__(self, value: Any) -> None:
super().__init__()
self.value = value
self.value = np.array(value, dtype="float32")

def size(self) -> int:
return len(self.value)
Expand Down Expand Up @@ -251,6 +255,10 @@ def calc(self, evt, account):
def size(self) -> int:
return self._size

def reset(self):
for feature in self.features:
feature.reset()


class NormalizeFeature(Feature):
"""online normalization calculator"""
Expand All @@ -271,7 +279,7 @@ def denormalize(self, value):

def __update(self, new_value):
(count, mean, m2) = self.existing_aggregate
mask = ~ np.isnan(new_value)
mask = ~np.isnan(new_value)
count[mask] += 1
delta = new_value - mean
mean[mask] += delta[mask] / count[mask]
Expand All @@ -295,6 +303,7 @@ def size(self) -> int:

def reset(self):
self.existing_aggregate = (self._zero_int(), self._zeros(), self._zeros())
self.feature.reset()


class FillFeature(Feature):
Expand Down Expand Up @@ -327,15 +336,22 @@ class CacheFeature(Feature):
Typically, this doesn't work for features that depend on account values.
"""

def __init__(self, feature: Feature) -> None:
def __init__(self, feature: Feature, validate=False) -> None:
super().__init__()
self.feature: Feature = feature
self._cache: dict[datetime, NDArray] = {}
self.validate = validate

def calc(self, evt, account):
time = evt.time
if time in self._cache:
return self._cache[time]
values = self._cache[time]
if self.validate:
calc_values = self.feature.calc(evt, account)
assert np.array_equal(
values, calc_values, equal_nan=True
), f"Wrong cache time={time} cache={values} calculated={calc_values}"
return values

values = self.feature.calc(evt, account)
self._cache[time] = values
Expand Down Expand Up @@ -449,8 +465,7 @@ def reset(self):


class MaxReturnFeature2(Feature):
"""Calculate the maximum return over a certain period.
"""
"""Calculate the maximum return over a certain period."""

def __init__(self, feature: Feature, period: int) -> None:
super().__init__()
Expand Down Expand Up @@ -540,6 +555,7 @@ def size(self) -> int:
def reset(self):
self.history = None
self.feature.reset()
self._cnt = 0


class DayOfWeekFeature(Feature):
Expand Down Expand Up @@ -617,3 +633,6 @@ def _calc(self, symbol: str, ohlcv: OHLCVBuffer) -> float:

def size(self) -> int:
return len(self.symbols)

def reset(self):
self._data = {}
2 changes: 1 addition & 1 deletion roboquant/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class SignalType(Flag):
ENTRY_EXIT = ENTRY | EXIT

def __str__(self):
return self.name
return self.name or str(self)


@dataclass(slots=True, frozen=True)
Expand Down

0 comments on commit 4fe7d52

Please sign in to comment.