Add e2e test for a ranklens tutorial (#311)
* normalize url usage in cmdline args

* add e2e to CI

* make e2e test not hang
shuttie committed Mar 16, 2022
1 parent 465e1ee commit 4823c9a
Showing 14 changed files with 180 additions and 62 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/scala.yml
@@ -31,5 +31,9 @@ jobs:
with:
distribution: 'adopt-hotspot'
java-version: '11'

- name: Run tests
run: sbt -mem 2048 test
run: sbt -mem 3000 test assembly

- name: Run e2e test
run: ./run_e2e.sh target/scala-2.12/metarank.jar
1 change: 1 addition & 0 deletions build.sbt
@@ -103,3 +103,4 @@ ThisBuild / assemblyMergeStrategy := {
val oldStrategy = (ThisBuild / assemblyMergeStrategy).value
oldStrategy(x)
}
assembly / assemblyJarName := "metarank.jar"
39 changes: 39 additions & 0 deletions run_e2e.sh
@@ -0,0 +1,39 @@
#!/bin/sh
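# End-to-end test of the ranklens tutorial: bootstrap the dataset, train a model,
# then start the inference API and check its /health endpoint.
# Usage: ./run_e2e.sh <path to metarank assembly jar>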

JAR=$1
TMPDIR=`mktemp -d /tmp/ranklens-XXXXXX`

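# 1. replay the historical events: writes feature state, feature values and the training dataset to $TMPDIR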
java -jar $JAR bootstrap \
--events src/test/resources/ranklens/events/ \
--out $TMPDIR \
--config src/test/resources/ranklens/config.yml

echo "Boostrap done into dir $TMPDIR"

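# 2. train the lambdamart-lightgbm ranking model on the bootstrapped dataset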
java -jar $JAR train \
--input $TMPDIR/dataset \
--config src/test/resources/ranklens/config.yml \
--model-type lambdamart-lightgbm \
--model-file $TMPDIR/metarank.model

echo "Training done"

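# 3. start the inference API in the background, restoring the bootstrap savepoint, and remember its PID for shutdown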
java -jar $JAR inference \
--config src/test/resources/ranklens/config.yml \
--model $TMPDIR/metarank.model \
--embedded-redis-features-dir $TMPDIR/features \
--format json \
--savepoint-dir $TMPDIR/savepoint & echo $! > $TMPDIR/inference.pid

PID=`cat $TMPDIR/inference.pid`

echo "Waiting for http server with pid=$PID to come online..."

while ! nc -z localhost 8080; do
sleep 5
echo "Trying to connect to :8080"
done

curl -v http://localhost:8080/health

kill -TERM $PID
14 changes: 8 additions & 6 deletions src/main/scala/ai/metarank/config/Config.scala
@@ -27,11 +27,13 @@ object Config extends Logging {
config
}

def load(contents: String): IO[Config] = for {
yaml <- IO.fromEither(parseYaml(contents))
decoded <- IO.fromEither(yaml.as[Config])
_ <- IO(logger.info(s"features: ${decoded.features.map(_.name)}"))
} yield {
decoded
def load(contents: String): IO[Config] = {
for {
yaml <- IO.fromEither(parseYaml(contents))
decoded <- IO.fromEither(yaml.as[Config])
_ <- IO(logger.info(s"features: ${decoded.features.map(_.name)}"))
} yield {
decoded
}
}
}
25 changes: 25 additions & 0 deletions src/main/scala/ai/metarank/mode/AsyncFlinkJob.scala
@@ -0,0 +1,25 @@
package ai.metarank.mode

import ai.metarank.mode.inference.FeedbackFlow.logger
import ai.metarank.mode.inference.FlinkMinicluster
import ai.metarank.util.Logging
import cats.effect.IO
import cats.effect.kernel.Resource
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.util.TestStreamEnvironment

object AsyncFlinkJob extends Logging {
import ai.metarank.flow.DataStreamOps._
def execute(cluster: FlinkMinicluster, savepoint: Option[String] = None)(job: (StreamExecutionEnvironment) => Unit) =
Resource.make(IO.fromCompletableFuture {
IO {
val env = new StreamExecutionEnvironment(new TestStreamEnvironment(cluster.cluster.getMiniCluster, 1))
job(env)
val graph = env.getStreamGraph.getJobGraph
savepoint.foreach(s => graph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(s, false)))
logger.info(s"submitted job ${graph} to local cluster")
cluster.client.submitJob(graph)
}
})(job => IO.fromCompletableFuture(IO { cluster.client.cancel(job) }).map(_ => Unit))
}
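For orientation, a hedged usage sketch of the new helper, mirroring how FeedbackFlow consumes it below; the minicluster value and the toy topology are placeholders, not part of this commit:

import ai.metarank.mode.AsyncFlinkJob
import ai.metarank.mode.inference.FlinkMinicluster
import cats.effect.IO
import org.apache.flink.streaming.api.scala._

object AsyncFlinkJobSketch {
  // `cluster` is assumed to be an already-acquired FlinkMinicluster
  def example(cluster: FlinkMinicluster): IO[Unit] =
    AsyncFlinkJob
      .execute(cluster, savepoint = Some("/tmp/bootstrap/savepoint")) { env =>
        // build any topology on the provided environment; the job graph is
        // submitted to the minicluster on acquire and cancelled on release
        env.fromElements(1, 2, 3).map(_ * 2).print()
      }
      .use(jobId => IO(println(s"job $jobId is running")))
}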
14 changes: 12 additions & 2 deletions src/main/scala/ai/metarank/mode/FileLoader.scala
@@ -11,7 +11,7 @@ import org.apache.commons.io.IOUtils
object FileLoader {
val s3Pattern = "s3://([a-zA-Z0-9\\-_]+)/(.*)".r
val filePattern = "file://(.*)".r
def load(path: String, env: Map[String, String]): IO[Array[Byte]] = path match {
def loadLocal(path: String, env: Map[String, String]): IO[Array[Byte]] = path match {
case s3Pattern(bucket, prefix) =>
for {
key <- IO.fromOption(env.get("AWS_ACCESS_KEY_ID"))(
@@ -36,6 +36,16 @@ object FileLoader {
client.shutdown()
bytes
}
case filePattern(local) => IO { File(local).byteArray }
case filePattern(local) => IO { File(local).byteArray }
case other if other.startsWith("/") => IO { File(other).byteArray }
case other => IO { (File.currentWorkingDirectory / other).byteArray }
}

def makeURL(path: String): String = path match {
case s3 @ s3Pattern(_, _) => s3
case file @ filePattern(_) => file
case absolute if absolute.startsWith("/") => "file://" + absolute
case relative => "file://" + (File.currentWorkingDirectory / relative).toString
}

}
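A minimal usage sketch of the two entry points; the sample paths are illustrative, and the expected values in the comments are inferred from the pattern matches above:

import ai.metarank.mode.FileLoader
import cats.effect.IO

object FileLoaderSketch {
  // makeURL only normalizes the string form and never touches the filesystem
  val s3       = FileLoader.makeURL("s3://bucket/events") // unchanged: "s3://bucket/events"
  val file     = FileLoader.makeURL("file:///data/out")   // unchanged: "file:///data/out"
  val absolute = FileLoader.makeURL("/data/out")          // -> "file:///data/out"
  val relative = FileLoader.makeURL("events")             // -> "file://" + <cwd> + "/events"

  // loadLocal reads the bytes inside IO: plain paths resolve against the working
  // directory, s3:// URLs use the AWS credentials found in the env map
  val config: IO[Array[Byte]] =
    FileLoader.loadLocal("src/test/resources/ranklens/config.yml", sys.env)
}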
30 changes: 18 additions & 12 deletions src/main/scala/ai/metarank/mode/bootstrap/Bootstrap.scala
@@ -9,7 +9,7 @@ import ai.metarank.flow.{
EventStateJoin,
ImpressionInjectFunction
}
import ai.metarank.mode.FlinkS3Configuration
import ai.metarank.mode.{FileLoader, FlinkS3Configuration}
import ai.metarank.model.{Clickthrough, Event, EventId, EventState}
import ai.metarank.model.Event.{FeedbackEvent, InteractionEvent, RankingEvent}
import ai.metarank.source.{EventSource, FileEventSource}
@@ -47,39 +47,45 @@ object Bootstrap extends IOApp with Logging {
}

override def run(args: List[String]): IO[ExitCode] = for {
cmd <- BootstrapCmdline.parse(args, System.getenv().asScala.toMap)
config <- Config.load(cmd.config)
_ <- run(config, cmd)
env <- IO { System.getenv().asScala.toMap }
cmd <- BootstrapCmdline.parse(args, env)
_ <- IO { logger.info("Performing bootstrap.") }
_ <- IO { logger.info(s" events URL: ${cmd.eventPathUrl}") }
_ <- IO { logger.info(s" output dir URL: ${cmd.outDirUrl}") }
_ <- IO { logger.info(s" config: ${cmd.config}") }
configContents <- FileLoader.loadLocal(cmd.config, env)
config <- Config.load(new String(configContents))
_ <- run(config, cmd)
} yield {
ExitCode.Success
}

def run(config: Config, cmd: BootstrapCmdline) = IO {
File(cmd.outDir).createDirectoryIfNotExists(createParents = true)
if (cmd.outDirUrl.startsWith("file://")) { File(cmd.outDir).createDirectoryIfNotExists(createParents = true) }
val mapping = FeatureMapping.fromFeatureSchema(config.features, config.interactions)
val streamEnv =
StreamExecutionEnvironment.createLocalEnvironment(cmd.parallelism, FlinkS3Configuration(System.getenv()))
streamEnv.setRuntimeMode(RuntimeExecutionMode.BATCH)

logger.info("starting historical data processing")
val raw: DataStream[Event] = FileEventSource(cmd.eventPath).eventStream(streamEnv).id("load")
val raw: DataStream[Event] = FileEventSource(cmd.eventPathUrl).eventStream(streamEnv).id("load")
val grouped = groupFeedback(raw)
val (state, updates) = makeUpdates(raw, grouped, mapping)

Featury.writeState(state, new Path(s"${cmd.outDir}/state"), Compress.NoCompression).id("write-state")
Featury.writeState(state, new Path(s"${cmd.outDirUrl}/state"), Compress.NoCompression).id("write-state")
Featury
.writeFeatures(updates, new Path(s"${cmd.outDir}/features"), Compress.NoCompression)
.writeFeatures(updates, new Path(s"${cmd.outDirUrl}/features"), Compress.NoCompression)
.id("write-features")
val computed = joinFeatures(updates, grouped, mapping)
computed.sinkTo(DatasetSink.json(mapping, s"${cmd.outDir}/dataset")).id("write-train")
computed.sinkTo(DatasetSink.json(mapping, s"${cmd.outDirUrl}/dataset")).id("write-train")
streamEnv.execute("bootstrap")

logger.info("processing done, generating savepoint")
val batch = ExecutionEnvironment.getExecutionEnvironment
batch.setParallelism(cmd.parallelism)
val stateSource = Featury.readState(batch, new Path(s"${cmd.outDir}/state"), Compress.NoCompression)
val stateSource = Featury.readState(batch, new Path(s"${cmd.outDirUrl}/state"), Compress.NoCompression)

val valuesPath = s"${cmd.outDir}/features"
val valuesPath = s"${cmd.outDirUrl}/features"
val valuesSource = batch
.readFile(
new BulkInputFormat[FeatureValue](
@@ -111,7 +117,7 @@
.withOperator("process-stateless-writes", transformStateless)
.withOperator("process-stateful-writes", transformStateful)
.withOperator("join-state", transformStateJoin)
.write(s"${cmd.outDir}/savepoint")
.write(s"${cmd.outDirUrl}/savepoint")

batch.execute("savepoint")
logger.info("done")
12 changes: 8 additions & 4 deletions src/main/scala/ai/metarank/mode/bootstrap/BootstrapCmdline.scala
@@ -1,11 +1,15 @@
package ai.metarank.mode.bootstrap

import ai.metarank.mode.FileLoader
import ai.metarank.util.Logging
import better.files.File
import cats.effect.IO
import scopt.{OParser, OptionParser}

case class BootstrapCmdline(eventPath: String, outDir: String, config: File, parallelism: Int)
case class BootstrapCmdline(eventPath: String, outDir: String, config: String, parallelism: Int) {
lazy val outDirUrl = FileLoader.makeURL(outDir)
lazy val eventPathUrl = FileLoader.makeURL(eventPath)
}

object BootstrapCmdline extends Logging {

@@ -14,7 +18,7 @@ object BootstrapCmdline extends Logging {
head("Metarank", "v0.x")

opt[String]("events")
.text("full path to directory containing historical events, with file:// or s3:// prefix")
.text("full URL path to directory containing historical events (optionally with file:// or s3:// prefix)")
.required()
.action((m, cmd) => cmd.copy(eventPath = m))
.withFallback(() => env.getOrElse("METARANK_EVENTS", ""))
@@ -24,7 +28,7 @@ object BootstrapCmdline extends Logging {
}

opt[String]("out")
.text("output directory")
.text("output directory (optionally with file:// or s3:// prefix)")
.required()
.action((m, cmd) => cmd.copy(outDir = m))
.withFallback(() => env.getOrElse("METARANK_OUT", ""))
@@ -36,7 +40,7 @@ object BootstrapCmdline extends Logging {
opt[String]("config")
.required()
.text("config file")
.action((m, cmd) => cmd.copy(config = File(m)))
.action((m, cmd) => cmd.copy(config = m))
.withFallback(() => env.getOrElse("METARANK_CONFIG", ""))
.validate {
case "" => Left("config is required")
15 changes: 6 additions & 9 deletions src/main/scala/ai/metarank/mode/inference/FeedbackFlow.scala
@@ -1,6 +1,7 @@
package ai.metarank.mode.inference

import ai.metarank.FeatureMapping
import ai.metarank.mode.AsyncFlinkJob
import ai.metarank.mode.bootstrap.Bootstrap
import ai.metarank.model.Event
import ai.metarank.source.LocalDirSource
@@ -28,10 +29,9 @@ object FeedbackFlow extends Logging {
eti: TypeInformation[Event],
valti: TypeInformation[FeatureValue],
intti: TypeInformation[Int]
) =
Resource.make(IO.fromCompletableFuture {
IO {
val env = new StreamExecutionEnvironment(new TestStreamEnvironment(cluster.cluster.getMiniCluster, 1))
) = {
AsyncFlinkJob.execute(cluster, Some(cmd.savepoint)) { env =>
{
val source = env.addSource(new LocalDirSource(path)).id("local-source")
val grouped = Bootstrap.groupFeedback(source)
val (_, updates) = Bootstrap.makeUpdates(source, grouped, mapping)
@@ -40,10 +40,7 @@
FeatureStoreSink(RedisStore(RedisConfig(redisHost, cmd.redisPort, cmd.format)), cmd.batchSize)
)
.id("write-redis")
val graph = env.getStreamGraph.getJobGraph
graph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(cmd.savepoint, false))
logger.info(s"submitted job ${graph} to local cluster")
cluster.client.submitJob(graph)
}
})(job => IO.fromCompletableFuture(IO { cluster.client.cancel(job) }).map(_ => Unit))
}
}
}
17 changes: 10 additions & 7 deletions src/main/scala/ai/metarank/mode/inference/Inference.scala
@@ -16,13 +16,11 @@ import org.http4s.circe.{jsonEncoderOf, jsonOf}
import io.circe.syntax._
import org.http4s.circe._
import cats.syntax.all._
import fs2.concurrent.SignallingRef
import io.findify.featury.connector.redis.RedisStore
import io.findify.featury.values.ValueStoreConfig.RedisConfig
import org.http4s.blaze.server.BlazeServerBuilder
import io.findify.flinkadt.api._
import org.apache.flink.configuration.Configuration

import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._

object Inference extends IOApp {
@@ -32,11 +30,13 @@ object Inference extends IOApp {
for {
env <- IO { System.getenv().asScala.toMap }
cmd <- InferenceCmdline.parse(args, env)
confContents <- FileLoader.load(cmd.config, env).map(new String(_))
confContents <- FileLoader.loadLocal(cmd.config, env).map(new String(_))
config <- Config.load(confContents)
mapping <- IO.pure { FeatureMapping.fromFeatureSchema(config.features, config.interactions) }
model <- FileLoader.load(cmd.model, env).map(new String(_))
result <- cluster(dir, config, mapping, cmd, model).use { _.serve.compile.drain.as(ExitCode.Success) }
model <- FileLoader.loadLocal(cmd.model, env).map(new String(_))
result <- cluster(dir, config, mapping, cmd, model).use {
_.serve.compile.drain.as(ExitCode.Success)
}
} yield result
}

@@ -47,7 +47,9 @@ object Inference extends IOApp {
_ <- Resource.eval(redis.upload)
_ <- FeedbackFlow.resource(cluster, dir.toString(), mapping, cmd, redis.host)
s <- server(cmd, config, dir, redis.host, model)
} yield s
} yield {
s
}
}

def server(cmd: InferenceCmdline, config: Config, dir: File, redisHost: String, model: String) = {
@@ -61,5 +63,6 @@ object Inference extends IOApp {
} yield BlazeServerBuilder[IO]
.bindHttp(cmd.port, cmd.host)
.withHttpApp(httpApp)

}
}
5 changes: 3 additions & 2 deletions src/main/scala/ai/metarank/mode/inference/RedisEndpoint.scala
@@ -19,8 +19,9 @@ object RedisEndpoint {
}

case class EmbeddedRedis(host: String, service: RedisServer, dir: String) extends RedisEndpoint {
override def upload: IO[Unit] = Upload.run(UploadCmdline(host, 6379, JsonCodec, dir, 1024, 1)).map(_ => {})
override def close: IO[Unit] = IO { service.close() }
override def upload: IO[Unit] =
Upload.run(UploadCmdline(host, 6379, JsonCodec, dir, 1024, 1)).allocated.map(_ => {})
override def close: IO[Unit] = IO { service.close() }
}

def create(dir: Option[String], host: Option[String], port: Int): Resource[IO, RedisEndpoint] = (dir, host) match {
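The switch to .allocated suggests the embedded-redis upload is modelled as a cats-effect Resource; below is a general sketch (assuming cats-effect 3, not this project's Upload API) of how .use and .allocated differ — acquiring without immediately scheduling the release is presumably related to the "make e2e test not hang" fix in this commit:

import cats.effect.{IO, Resource}
import cats.effect.unsafe.implicits.global

object AllocatedSketch extends App {
  val upload: Resource[IO, String] =
    Resource.make(IO.println("acquire").as("redis-upload"))(_ => IO.println("release"))

  // use: the release action runs as soon as the body finishes
  upload.use(u => IO.println(s"running $u")).unsafeRunSync()

  // allocated: acquire now and get the release action back as a plain IO value;
  // nothing is torn down until the caller explicitly runs it
  val (value, release) = upload.allocated.unsafeRunSync()
  IO.println(s"still holding $value").unsafeRunSync()
  release.unsafeRunSync()
}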