Skip to content

Commit

Permalink
Improved QuestDBJournal for multi-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed May 7, 2024
1 parent 3e98997 commit f260591
Show file tree
Hide file tree
Showing 20 changed files with 241 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.roboquant.alpaca
import net.jacobpeterson.alpaca.AlpacaAPI
import net.jacobpeterson.alpaca.model.util.apitype.MarketDataWebsocketSourceType
import net.jacobpeterson.alpaca.model.util.apitype.TraderAPIEndpointType
import net.jacobpeterson.alpaca.openapi.marketdata.model.StockFeed
import org.roboquant.common.Config
import org.roboquant.common.Exchange

Expand Down Expand Up @@ -48,6 +49,7 @@ data class AlpacaConfig(
var secretKey: String = Config.getProperty("alpaca.secret.key", ""),
var accountType: AccountType = AccountType.PAPER,
var dataType: DataType = DataType.IEX,
var stockFeed: StockFeed = StockFeed.IEX,
var extendedHours: Boolean = Config.getProperty("alpaca.extendedhours", false),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,13 @@ class AlpacaBroker(
*/
private fun syncOrders() {
_account.orders.forEach {
val aOrderId = orderPlacer.get(it.order)
if (aOrderId != null) {
if (it.open) {
val aOrderId = it.order.id
logger.info { "open order id=$aOrderId" }
val alpacaOrder = alpacaAPI.trader().orders().getOrderByOrderID(UUID.fromString(aOrderId), false)
updateIAccountOrder(it.order, alpacaOrder)
} else {
logger.warn("cannot find order ${it.order} in orderMap")
logger.warn("cannot find order=${it.order} in orderMap")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import net.jacobpeterson.alpaca.openapi.marketdata.api.StockApi
import net.jacobpeterson.alpaca.openapi.marketdata.model.Sort
import net.jacobpeterson.alpaca.openapi.marketdata.model.StockAdjustment
import net.jacobpeterson.alpaca.openapi.marketdata.model.StockBar
import net.jacobpeterson.alpaca.openapi.marketdata.model.StockFeed
import org.roboquant.common.Asset
import org.roboquant.common.Logging
import org.roboquant.common.TimeSpan
Expand All @@ -41,7 +40,7 @@ class AlpacaHistoricFeed(
configure: AlpacaConfig.() -> Unit = {}
) : HistoricPriceFeed() {

private val limit = 10_000L
// private val limit = 10_000L
private val config = AlpacaConfig()
private val stockData: StockApi
private val alpacaAPI: AlpacaAPI
Expand Down Expand Up @@ -69,7 +68,7 @@ class AlpacaHistoricFeed(
var nextPageToken: String? = null
do {
val resp = stockData.stockQuotes(
symbols, start, end, limit, "", StockFeed.IEX, "USD", nextPageToken, Sort.ASC
symbols, start, end, null, "", config.stockFeed, "USD", nextPageToken, Sort.ASC
)
for ((symbol, quotes) in resp.quotes) {
val asset = Asset(symbol)
Expand Down Expand Up @@ -100,7 +99,7 @@ class AlpacaHistoricFeed(
var nextPageToken: String? = null
do {
val resp = stockData.stockTrades(
symbols, start, end, limit, "", StockFeed.IEX, "USD", nextPageToken, Sort.ASC
symbols, start, end, null, "", config.stockFeed, "USD", nextPageToken, Sort.ASC
)
for ((symbol, trades) in resp.trades) {
val asset = Asset(symbol)
Expand All @@ -120,8 +119,8 @@ class AlpacaHistoricFeed(
val asset = Asset(symbol)
for (bar in bars) {
val action = PriceBar(asset, bar.o, bar.h, bar.l, bar.c, bar.v.toDouble(), timeSpan)
val now = bar.t.toInstant()
add(now, action)
val time = bar.t.toInstant()
add(time, action)
}
}

Expand All @@ -131,7 +130,8 @@ class AlpacaHistoricFeed(
fun retrieveStockPriceBars(
symbols: String,
timeframe: Timeframe,
frequency: String = "1Day"
frequency: String = "1Day",
adjustment: StockAdjustment = StockAdjustment.ALL
) {
val (start, end) = toOffset(timeframe)

Expand All @@ -142,17 +142,16 @@ class AlpacaHistoricFeed(
frequency,
start,
end,
limit,
StockAdjustment.ALL,
"",
StockFeed.IEX,
null,
adjustment,
null,
config.stockFeed,
"USD",
nextPageToken,
Sort.ASC
)
for ((symbol, bars) in resp.bars) {
processBars(symbol, bars, null)

}
nextPageToken = resp.nextPageToken
} while (nextPageToken != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@
package org.roboquant.alpaca

import net.jacobpeterson.alpaca.AlpacaAPI
import net.jacobpeterson.alpaca.model.websocket.marketdata.streams.crypto.model.bar.CryptoBarMessage
import net.jacobpeterson.alpaca.model.websocket.marketdata.streams.crypto.model.quote.CryptoQuoteMessage
import net.jacobpeterson.alpaca.model.websocket.marketdata.streams.crypto.model.trade.CryptoTradeMessage
import net.jacobpeterson.alpaca.model.websocket.marketdata.streams.stock.model.bar.StockBarMessage
import net.jacobpeterson.alpaca.model.websocket.marketdata.streams.stock.model.quote.StockQuoteMessage
import net.jacobpeterson.alpaca.model.websocket.marketdata.streams.stock.model.trade.StockTradeMessage
import net.jacobpeterson.alpaca.websocket.marketdata.streams.crypto.CryptoMarketDataListenerAdapter
import net.jacobpeterson.alpaca.websocket.marketdata.streams.crypto.CryptoMarketDataWebsocketInterface
import net.jacobpeterson.alpaca.websocket.marketdata.streams.stock.StockMarketDataListenerAdapter
import net.jacobpeterson.alpaca.websocket.marketdata.streams.stock.StockMarketDataWebsocketInterface
import org.roboquant.common.Asset
Expand Down Expand Up @@ -67,8 +72,6 @@ class AlpacaLiveFeed(
private val config = AlpacaConfig()
private val alpacaAPI: AlpacaAPI
private val logger = Logging.getLogger(AlpacaLiveFeed::class)
private val listener = createStockHandler()


init {
config.configure()
Expand All @@ -83,7 +86,7 @@ class AlpacaLiveFeed(
*/
private fun connect() {
connectMarket(alpacaAPI.stockMarketDataStream())
// connectMarket(alpacaAPI.cryptoMarketDataStream())
connectCryptoMarket(alpacaAPI.cryptoMarketDataStream())
}

/**
Expand All @@ -98,7 +101,25 @@ class AlpacaLiveFeed(
if (!connection.isValid) {
throw ConfigurationException("couldn't establish $connection")
} else {
connection.setListener(listener)
val stockListener = createStockHandler()
connection.setListener(stockListener)
}
}

/**
* Connect to ta market data provider and start listening. This can be the stocks or crypto market data feeds.
*/
private fun connectCryptoMarket(connection: CryptoMarketDataWebsocketInterface) {
require(!connection.isConnected) { "already connected, disconnect first" }
val timeoutMillis: Long = 5_000
connection.setAutomaticallyReconnect(true)
connection.connect()
connection.waitForAuthorization(timeoutMillis, TimeUnit.MILLISECONDS)
if (!connection.isValid) {
throw ConfigurationException("couldn't establish $connection")
} else {
val cryptoListener = createCryptoHandler()
connection.setListener(cryptoListener)
}
}

Expand All @@ -125,12 +146,11 @@ class AlpacaLiveFeed(
* Subscribe to stock market data based on the passed [symbols] and [type]
*/
fun subscribeStocks(vararg symbols: String, type: PriceActionType = PriceActionType.PRICE_BAR) {
// validateSymbols(symbols, availableStocksMap)
val s = symbols.toList()
val s = symbols.toSet()
when (type) {
PriceActionType.TRADE -> alpacaAPI.stockMarketDataStream().tradeSubscriptions.addAll(s)
PriceActionType.QUOTE -> alpacaAPI.stockMarketDataStream().quoteSubscriptions.addAll(s)
PriceActionType.PRICE_BAR -> alpacaAPI.stockMarketDataStream().minuteBarSubscriptions.addAll(s)
PriceActionType.TRADE -> alpacaAPI.stockMarketDataStream().tradeSubscriptions= s
PriceActionType.QUOTE -> alpacaAPI.stockMarketDataStream().quoteSubscriptions= s
PriceActionType.PRICE_BAR -> alpacaAPI.stockMarketDataStream().minuteBarSubscriptions= s
}
}

Expand All @@ -139,12 +159,12 @@ class AlpacaLiveFeed(
*/
@Suppress("unused")
fun subscribeCrypto(vararg symbols: String, type: PriceActionType = PriceActionType.PRICE_BAR) {
// validateSymbols(symbols, availableCryptoMap)
val s = symbols.toList()

val s = symbols.toSet()
when (type) {
PriceActionType.TRADE -> alpacaAPI.cryptoMarketDataStream().tradeSubscriptions.addAll(s)
PriceActionType.QUOTE -> alpacaAPI.cryptoMarketDataStream().quoteSubscriptions.addAll(s)
PriceActionType.PRICE_BAR -> alpacaAPI.cryptoMarketDataStream().minuteBarSubscriptions.addAll(s)
PriceActionType.TRADE -> alpacaAPI.cryptoMarketDataStream().tradeSubscriptions=s
PriceActionType.QUOTE -> alpacaAPI.cryptoMarketDataStream().quoteSubscriptions=s
PriceActionType.PRICE_BAR -> alpacaAPI.cryptoMarketDataStream().minuteBarSubscriptions=s
}
}

Expand Down Expand Up @@ -192,6 +212,47 @@ class AlpacaLiveFeed(
}


}
}

private fun createCryptoHandler(): CryptoMarketDataListenerAdapter {
return object : CryptoMarketDataListenerAdapter() {

override fun onTrade(trade: CryptoTradeMessage) {
val asset = Asset(trade.symbol)
val item = TradePrice(asset, trade.price)
val time = trade.timestamp.toInstant()
send(time, item)
}

override fun onQuote(quote: CryptoQuoteMessage) {
val asset = Asset(quote.symbol)
val item = PriceQuote(
asset,
quote.askPrice,
quote.askSize.toDouble(),
quote.bidPrice,
quote.bidSize.toDouble(),
)
val time = quote.timestamp.toInstant()
send(time, item)
}

override fun onMinuteBar(bar: CryptoBarMessage) {
val asset = Asset(bar.symbol)
val item = PriceBar(
asset,
bar.open,
bar.high,
bar.low,
bar.close,
bar.tradeCount.toDouble()
)
val time = bar.timestamp.toInstant()
send(time, item)
}


}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,12 @@ internal class AlpacaSamples {
@Test
@Ignore
internal fun alpacaHistoricFeed3() {

val feed = AlpacaHistoricFeed()

// We get minute data
val tf = Timeframe.parse("2024-01-04", "2024-01-05")
val tf = Timeframe.parse("2016-01-01", "2024-05-05")
feed.retrieveStockPriceBars("AAPL", timeframe = tf, "1Min")
val events = feed.toList()

with(events) {
println("events=$size start=${first().time} last=${last().time} symbols=${feed.assets.symbols.toList()}")
}

println(feed)
}

@Test
Expand All @@ -141,7 +135,7 @@ internal class AlpacaSamples {
}
val order = MarketOrder(Asset("IBM"), Size.ONE)
broker.place(listOf(order))
Thread.sleep(5000)
Thread.sleep(10_000)
val account = broker.sync()
println(account.fullSummary())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.roboquant.common.Size
import org.roboquant.common.USD
import org.roboquant.feeds.random.RandomWalkFeed
import org.roboquant.feeds.util.HistoricTestFeed
import org.roboquant.journals.MetricsJournal
import org.roboquant.journals.MemoryJournal
import org.roboquant.metrics.AccountMetric
import org.roboquant.orders.MarketOrder
import org.roboquant.strategies.EMAStrategy
Expand Down Expand Up @@ -66,7 +66,7 @@ object TestData {

val data by lazy {
val feed = HistoricTestFeed(50..150)
val journal = MetricsJournal(AccountMetric())
val journal = MemoryJournal(AccountMetric())
org.roboquant.run(feed, EMAStrategy(), journal)
journal.getMetric("account.equity")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.roboquant.common.Logging
import org.roboquant.common.Observation
import org.roboquant.common.TimeSeries
import org.roboquant.feeds.Event
import org.roboquant.journals.Journal
import org.roboquant.journals.MetricsJournal
import org.roboquant.metrics.Metric
import org.roboquant.orders.Order
import org.roboquant.strategies.Signal
Expand All @@ -49,7 +49,7 @@ class QuestDBJournal(
workers: Int = 1,
private val partition: String = QuestDBRecorder.NONE,
private val truncate: Boolean = false
) : Journal {
) : MetricsJournal {

private val logger = Logging.getLogger(this::class)
private var engine: CairoEngine
Expand All @@ -60,19 +60,34 @@ class QuestDBJournal(
engine = getEngine(dbPath)
ctx = SqlExecutionContextImpl(engine, workers)
createTable(table)
logger.info { "db=$dbPath table=$table" }
}


companion object {

private var engines = mutableMapOf<Path, CairoEngine>()

@Synchronized
fun getEngine(dbPath: Path): CairoEngine {
if (Files.notExists(dbPath)) {
Files.createDirectories(dbPath)
if (dbPath !in engines) {
if (Files.notExists(dbPath)) {
Files.createDirectories(dbPath)
}
require(dbPath.isDirectory()) { "dbPath needs to be a directory" }
val config = DefaultCairoConfiguration(dbPath.toString())
engines[dbPath] = CairoEngine(config)
}
require(dbPath.isDirectory()) { "dbPath needs to be a directory" }
val config = DefaultCairoConfiguration(dbPath.toString())
val engine = CairoEngine(config)
return engine
return engines.getValue(dbPath)
}

fun getRuns(dbPath: Path): Set<String> {
val engine = getEngine(dbPath)
return engine.tables().toSet()
}

fun close(dbPath: Path) {
engines[dbPath]?.close()
}

}
Expand All @@ -88,10 +103,10 @@ class QuestDBJournal(
/**
* Get a metric for a specific [table]
*/
fun getMetric(metricName: String): TimeSeries {
override fun getMetric(name: String): TimeSeries {
val result = mutableListOf<Observation>()

engine.query("select time, value from '$table' where metric='$metricName'") {
engine.query("select time, value from '$table' where metric='$name'") {
while (hasNext()) {
val r = this.record
val o = Observation(ofEpochMicro(r.getTimestamp(0)), r.getDouble(1))
Expand All @@ -112,8 +127,8 @@ class QuestDBJournal(
}


fun getMetricNames(run: String): Set<String> {
return engine.distictSymbol(run, "name").toSortedSet()
override fun getMetricNames(): Set<String> {
return engine.distictSymbol(table, "name").toSortedSet()
}

/**
Expand Down Expand Up @@ -142,10 +157,10 @@ class QuestDBJournal(
}

/**
* Close the engine and context
* Close the underlying context
*/
fun close() {
engine.close()
// engine.close()
ctx.close()
}

Expand Down
Loading

0 comments on commit f260591

Please sign in to comment.