Skip to content

Commit

Permalink
Centralized code to play feeds in back-ground
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Jan 19, 2024
1 parent be9ed8b commit a5686e5
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package org.roboquant.avro

import kotlinx.coroutines.*
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.MethodOrderer.Alphanumeric
import org.junit.jupiter.api.TestMethodOrder
import org.junit.jupiter.api.assertDoesNotThrow
Expand Down Expand Up @@ -54,11 +54,7 @@ internal class AvroFeedTest {

private fun play(feed: Feed, timeframe: Timeframe = Timeframe.INFINITE): EventChannel {
val channel = EventChannel(timeframe = timeframe)

CoroutineScope(Dispatchers.IO + Job()).launch {
feed.play(channel)
channel.close()
}
feed.playBackground(channel)
return channel
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.roboquant.charts

import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.hipparchus.stat.correlation.PearsonsCorrelation
import org.icepear.echarts.Heatmap
Expand Down Expand Up @@ -70,11 +69,7 @@ class CorrelationChart(
val channel = EventChannel(timeframe = timeframe)
val result = TreeMap<Asset, MutableList<Double>>()

val job = launch {
feed.play(channel)
channel.close()
}

val job = feed.playBackground(channel)
try {
while (true) {
val o = channel.receive()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import io.questdb.cairo.DefaultCairoConfiguration
import io.questdb.cairo.security.AllowAllSecurityContext
import io.questdb.griffin.SqlExecutionContextImpl
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.roboquant.common.Asset
import org.roboquant.common.Config
Expand Down Expand Up @@ -147,10 +146,7 @@ class QuestDBRecorder(dbPath: Path = Config.home / "questdb-prices" / "db") {
handler.createTable(tableName, partition, engine)
if (!append) engine.update("TRUNCATE TABLE $tableName")

val job = launch {
feed.play(channel)
channel.close()
}
val job = feed.playBackground(channel)

val ctx = SqlExecutionContextImpl(engine, 1).with(AllowAllSecurityContext.INSTANCE, null, null)
val writer = engine.getWriter(ctx.getTableToken(tableName), tableName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.roboquant.samples

import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.roboquant.Roboquant
import org.roboquant.common.*
Expand Down Expand Up @@ -101,11 +100,7 @@ internal class TiingoSamples {
) = runBlocking {
// We need a channel with enough capacity
val channel = EventChannel(10_000, timeframe = timeframe)

val job = launch {
play(channel)
channel.close()
}
val job = playBackground(channel)
var sum = 0L
var n = 0L

Expand Down
10 changes: 2 additions & 8 deletions roboquant/src/main/kotlin/org/roboquant/Roboquant.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

package org.roboquant

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.runBlocking
import org.roboquant.brokers.Account
import org.roboquant.brokers.Broker
import org.roboquant.brokers.closeSizes
Expand Down Expand Up @@ -151,11 +151,7 @@ data class Roboquant(
if (feed.timeframe.start > timeframe.end) return

val channel = EventChannel(channelCapacity, timeframe, onChannelFull)
val scope = CoroutineScope(Dispatchers.Default + Job())

val job = scope.launch {
channel.use { feed.play(it) }
}
val job = feed.playBackground(channel)

if (reset) reset(false)
start(name, timeframe)
Expand Down Expand Up @@ -186,8 +182,6 @@ data class Roboquant(
} finally {
end(name)
if (job.isActive) job.cancel()
scope.cancel()
channel.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,11 @@ class ParallelJobs {
return job
}

/**
* Add a job
*/
fun add(job: Job) {
jobs.add(job)
}

}
8 changes: 2 additions & 6 deletions roboquant/src/main/kotlin/org/roboquant/common/Timeframe.kt
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ data class Timeframe(val start: Instant, val end: Instant, val inclusive: Boolea

// Different formatters used when displaying a timeframe
private val dayFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")
private val minuteFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")
private val hourFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")
private val secondFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
private val milliFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS")

Expand Down Expand Up @@ -361,10 +359,8 @@ data class Timeframe(val start: Instant, val end: Instant, val inclusive: Boolea
fun toPrettyString(): String {
val d = duration.toSeconds()
val formatter = when {
d < 1 -> milliFormatter // less than 1 second
d < 60 -> secondFormatter // less than 1 minute
d < 3600 -> minuteFormatter // less than 1 hour
d < 3600 * 24 -> hourFormatter // less than 1 day
d < 10 -> milliFormatter // less than 10 seconds
d < 3600 * 24 -> secondFormatter // less than 1 day
else -> dayFormatter
}

Expand Down
11 changes: 1 addition & 10 deletions roboquant/src/main/kotlin/org/roboquant/feeds/AggregatorFeed.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@

package org.roboquant.feeds

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.launch
import org.roboquant.common.Asset
import org.roboquant.common.TimeSpan
import org.roboquant.common.Timeframe
Expand Down Expand Up @@ -78,12 +74,7 @@ class AggregatorFeed(
@Suppress("CyclomaticComplexMethod")
override suspend fun play(channel: EventChannel) {
val c = EventChannel(channel.capacity, channel.timeframe)
val scope = CoroutineScope(Dispatchers.Default + Job())
val job = scope.launch {
c.use {
feed.play(it)
}
}
val job = feed.playBackground(c)

val history = mutableMapOf<Asset, PriceBar>()
var expiration: Instant? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,7 @@ class AggregatorLiveFeed(
}
}

val job = scope.launch {
inputChannel.use {
feed.play(it)
}
}
val job = feed.playBackground(inputChannel)

try {
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ class CombinedFeed(vararg val feeds: Feed, private val channelCapacity: Int = 1)
for (feed in feeds) {
val feedChannel = EventChannel(channelCapacity, channel.timeframe)
channels.add(feedChannel)
jobs.add {
feed.play(feedChannel)
feedChannel.close()
}
jobs.add(feed.playBackground(feedChannel))
}

jobs.add {
Expand Down
60 changes: 22 additions & 38 deletions roboquant/src/main/kotlin/org/roboquant/feeds/Feed.kt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ interface Feed : AutoCloseable {
// default is to do nothing
}

/**
* (Re)play the events of the feed in the background without blocking the current thread and returns
* a reference to the coroutine as a [Job].
*
* The channel will be closed after the replay of events has finished
*
* @see play
*/
fun playBackground(channel: EventChannel): Job {
val scope = CoroutineScope(Dispatchers.Default)

return scope.launch {
channel.use { play(it) }
}

}

}

/**
Expand All @@ -81,11 +98,7 @@ inline fun <reified T : Action> Feed.filter(

val channel = EventChannel(timeframe = timeframe)
val result = mutableListOf<Pair<Instant, T>>()

val job = launch {
play(channel)
channel.close()
}
val job = playBackground(channel)

try {
while (true) {
Expand All @@ -112,11 +125,7 @@ inline fun <reified T : Action> Feed.apply(
) = runBlocking {

val channel = EventChannel(timeframe = timeframe)

val job = launch {
play(channel)
channel.close()
}
val job = playBackground(channel)

try {
while (true) {
Expand All @@ -127,7 +136,6 @@ inline fun <reified T : Action> Feed.apply(
} catch (_: ClosedReceiveChannelException) {
// Intentionally left empty
} finally {
channel.close()
if (job.isActive) job.cancel()
}

Expand All @@ -142,11 +150,7 @@ inline fun Feed.applyEvents(
) = runBlocking {

val channel = EventChannel(timeframe = timeframe)

val job = launch {
play(channel)
channel.close()
}
val job = playBackground(channel)

try {
while (true) {
Expand All @@ -157,7 +161,6 @@ inline fun Feed.applyEvents(
} catch (_: ClosedReceiveChannelException) {
// Intentionally left empty
} finally {
channel.close()
if (job.isActive) job.cancel()
}

Expand All @@ -173,10 +176,7 @@ fun Feed.toList(
val channel = EventChannel(timeframe = timeframe)
val result = mutableListOf<Event>()

val job = launch {
play(channel)
channel.close()
}
val job = playBackground(channel)

try {
while (true) {
Expand All @@ -203,11 +203,7 @@ fun Feed.validate(
): List<Pair<Instant, PriceAction>> = runBlocking {

val channel = EventChannel(timeframe = timeframe)

val job = launch {
play(channel)
channel.close()
}
val job = playBackground(channel)

val lastPrices = mutableMapOf<Asset, Double>()
val errors = mutableListOf<Pair<Instant, PriceAction>>()
Expand Down Expand Up @@ -242,15 +238,3 @@ fun Feed.validate(
fun Collection<PriceAction>.toDoubleArray(type: String = "DEFAULT"): DoubleArray =
this.map { it.getPrice(type) }.toDoubleArray()

/**
* Run a feed in the background using the provided [channel] and close the channel once done.
* This method returns the corresponding [Job] instance.
*/
internal fun Feed.runBackgroud(channel: EventChannel): Job {
val scope = CoroutineScope(Dispatchers.Default + Job())
return scope.launch {
channel.use {
play(it)
}
}
}
4 changes: 2 additions & 2 deletions roboquant/src/test/kotlin/org/roboquant/feeds/FeedTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package org.roboquant.feeds

import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertFalse
import kotlin.test.Test
import org.junit.jupiter.api.assertDoesNotThrow
import org.roboquant.TestData
import org.roboquant.common.Asset
import org.roboquant.common.Timeframe
import java.util.*
import kotlin.test.Test
import kotlin.test.assertContains
import kotlin.test.assertEquals
import kotlin.test.assertTrue
Expand Down Expand Up @@ -86,7 +86,7 @@ internal class FeedTest {

val channel = EventChannel(100)
assertFalse(channel.closed)
val job = feed.runBackgroud(channel)
val job = feed.playBackground(channel)
job.join()
assertTrue(channel.closed)
assertTrue(job.isCompleted)
Expand Down
5 changes: 2 additions & 3 deletions roboquant/src/test/kotlin/org/roboquant/feeds/LiveFeedTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,13 @@ internal class LiveFeedTest {
feed2.heartbeatInterval = 2
val feed = CombinedLiveFeed(feed1, feed2)

Background.job {
feed.play(EventChannel())
}
val job = feed.playBackground(EventChannel())

assertDoesNotThrow {
feed.close()
feed.close()
}
if (job.isActive) job.cancel()

}

Expand Down

0 comments on commit a5686e5

Please sign in to comment.