Skip to content

Commit

Permalink
add warning for single cts dataset (#1034)
Browse files Browse the repository at this point in the history
  • Loading branch information
shuttie committed May 10, 2023
1 parent e960d2e commit a25c569
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
23 changes: 18 additions & 5 deletions src/main/scala/ai/metarank/main/command/train/SplitStrategy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,26 @@ object SplitStrategy {
/** Splits `queries` into train and test datasets for the time-based strategy.
  *
  * NOTE(review): this capture shows both the removed `position <- ...` match and the added
  * `split <- ...` match of the commit; the description below follows the added one.
  *
  *   - 0 queries: the IO fails with an exception, as there is nothing to split.
  *   - 1 query: a warning is logged and the same single query is used as both train and test.
  *   - 2 queries: `warnSmallDataset` is called and the split position is 1.
  *   - more: `warnSmallDataset` is called only when the size is below `MIN_SPLIT`; the split
  *     position is `round(size * ratioPercent / 100)`.
  *
  * @param desc    descriptor attached to both resulting datasets
  * @param queries queries with metadata to be split
  * @return        an `IO` yielding the resulting `Split`
  */
override def split(desc: DatasetDescriptor, queries: List[QueryMetadata]): IO[Split] = for {
_ <- info(s"using time split strategy, ratio=$ratioPercent%")
size = queries.size
position <- queries.size match {
case 0 | 1 => IO.raiseError(new Exception(s"dataset size ($size items) is too small to be split"))
case 2 => warnSmallDataset(size) *> IO.pure(1)
split <- queries.size match {
case 0 =>
IO.raiseError(
new Exception("""Metarank needs a couple of click-through events (so pairs of ranking+interaction),
|and you have zero of them.
|""".stripMargin)
)
// a single query cannot be divided, so it is placed into both the train and the test dataset
case 1 =>
warn("Only single click-through event available, and we need more for training") *> IO(
Split(Dataset(desc, queries.map(_.query)), Dataset(desc, queries.map(_.query)))
)
case 2 => warnSmallDataset(size) *> IO.pure(splitByPosition(desc, queries, 1))
// three or more queries: the split position is derived from ratioPercent
case gt =>
IO.whenA(gt < MIN_SPLIT)(warnSmallDataset(size)) *> IO(math.round(queries.size * (ratioPercent / 100.0f)))
IO.whenA(gt < MIN_SPLIT)(warnSmallDataset(size)) *> IO(
splitByPosition(desc, queries, math.round(queries.size * (ratioPercent / 100.0f)))
)
}
} yield {
} yield { split }

/** Splits the queries at the given position after ordering them by `ts.ts`.
  *
  * The first `position` queries of the sorted list form the train dataset and the remainder
  * forms the test dataset. Following `List.splitAt`, a `position` of 0 (or less) leaves the
  * train side empty, and a `position` at or beyond the list size leaves the test side empty.
  *
  * @param desc     descriptor attached to both resulting datasets
  * @param queries  queries with metadata to be split
  * @param position number of leading queries (in sorted order) that go to the train dataset
  * @return the resulting train/test split
  */
def splitByPosition(desc: DatasetDescriptor, queries: List[QueryMetadata], position: Int): Split = {
  val (train, test) = queries.sortBy(_.ts.ts).splitAt(position)
  Split(Dataset(desc, train.map(_.query)), Dataset(desc, test.map(_.query)))
}
Expand Down
10 changes: 8 additions & 2 deletions src/test/scala/ai/metarank/main/SplitStrategyTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ class SplitStrategyTest extends AnyFlatSpec with Matchers {
val now = Timestamp.now
val query = QueryMetadata(Query(desc, List(LabeledItem(1.0, 1, Array(1.0)))), now, None, Nil)

"time-split" should "handle unbalanced small inputs, size=2" in {
"time-split" should "handle unbalanced small inputs, size=1" in {
  // A single query cannot be divided: the strategy warns and reuses it for both train and test.
  val split = TimeSplit(80).split(desc, List(query)).unsafeRunSync()
  split.test.groups.size shouldBe 1
  split.train.groups.size shouldBe 1
}

"time-split" should "handle unbalanced small inputs, size=3" in {
it should "handle unbalanced small inputs, size=2" in {
  // With exactly two queries the split position is fixed at 1: one group on each side.
  val twoQueries = List.fill(2)(query)
  val result     = TimeSplit(80).split(desc, twoQueries).unsafeRunSync()
  result.test.groups.size shouldBe 1
  result.train.groups.size shouldBe 1
}

it should "handle unbalanced small inputs, size=3" in {
val split = TimeSplit(80).split(desc, List(query, query, query)).unsafeRunSync()
split.test.groups.size shouldBe 1
split.train.groups.size shouldBe 2
Expand Down

0 comments on commit a25c569

Please sign in to comment.