Skip to content

Commit

Permalink
extract similarity if it's there, nicer test
Browse files Browse the repository at this point in the history
  • Loading branch information
jillesvangurp committed Jan 5, 2024
1 parent df04b52 commit 022dc38
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 65 deletions.
43 changes: 24 additions & 19 deletions src/main/kotlin/com/tryformation/pgdocstore/DocStore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ data class DocStoreEntry(
val updatedAt: Instant,
val json: String,
val tags: List<String>?,
val text: String?
val text: String?,
val similarity: Float? = null
)

fun String.sanitizeInputForDB(): String {
Expand Down Expand Up @@ -64,6 +65,8 @@ val RowData.docStoreEntry
it as List<String>
},
text = getString(DocStoreEntry::text.name),
// only there on searches with a text
similarity = this.size.takeIf { it > 6}?.let { getFloat("rank") },
)

class DocStore<T : Any>(
Expand Down Expand Up @@ -109,9 +112,9 @@ class DocStore<T : Any>(
suspend fun create(
doc: T,
timestamp: Instant = Clock.System.now(),
onConflictUpdate:Boolean=false
onConflictUpdate: Boolean = false
): DocStoreEntry {
return create(idExtractor.invoke(doc), doc, timestamp,onConflictUpdate)
return create(idExtractor.invoke(doc), doc, timestamp, onConflictUpdate)
}

/**
Expand All @@ -138,18 +141,18 @@ class DocStore<T : Any>(
INSERT INTO $tableName (id, created_at, updated_at, json, tags, text)
VALUES (?,?,?,?,?,?)
${
if(onConflictUpdate) """
if (onConflictUpdate) """
ON CONFLICT (id) DO UPDATE SET
json = EXCLUDED.json,
tags = EXCLUDED.tags,
text = EXCLUDED.text,
updated_at = EXCLUDED.updated_at
""".trimIndent()
else ""
}
else ""
}
""".trimIndent(), listOf(id, timestamp, timestamp, txt, tags, text)
)
return DocStoreEntry(id, timestamp,timestamp,txt,tags,text)
return DocStoreEntry(id, timestamp, timestamp, txt, tags, text)
}

/**
Expand Down Expand Up @@ -409,7 +412,7 @@ class DocStore<T : Any>(
* This falls back to overwriting the document with ON CONFLICT (id) DO UPDATE in case
* the document already exists.
*/
suspend fun insertList(chunk: List<Pair<String, T>>,timestamp: Instant = Clock.System.now()) {
suspend fun insertList(chunk: List<Pair<String, T>>, timestamp: Instant = Clock.System.now()) {
// Base SQL for INSERT
val baseSql = """
INSERT INTO $tableName (id, json, tags, created_at, updated_at, text)
Expand Down Expand Up @@ -464,7 +467,7 @@ class DocStore<T : Any>(
limit: Int = 100,
offset: Int = 0,
similarityThreshold: Double = 0.1,
): List<T> {
): List<T> {
val q = constructQuery(
tags = tags,
query = query,
Expand All @@ -491,7 +494,7 @@ class DocStore<T : Any>(
limit: Int = 100,
offset: Int = 0,
similarityThreshold: Double = 0.1,
): List<DocStoreEntry> {
): List<DocStoreEntry> {
val q = constructQuery(
tags = tags,
query = query,
Expand Down Expand Up @@ -528,7 +531,7 @@ class DocStore<T : Any>(
query: String? = null,
fetchSize: Int = 100,
similarityThreshold: Double = 0.1,
): Flow<T> {
): Flow<T> {
val q = constructQuery(
tags = tags,
query = query,
Expand Down Expand Up @@ -557,7 +560,7 @@ class DocStore<T : Any>(
query: String? = null,
fetchSize: Int = 100,
similarityThreshold: Double = 0.1,
): Flow<DocStoreEntry> {
): Flow<DocStoreEntry> {
return queryFlow(
query = constructQuery(
tags = tags,
Expand All @@ -581,7 +584,7 @@ class DocStore<T : Any>(
offset: Int = 0,
similarityThreshold: Double = 0.01
): String {
val rankSelect = if(!query.isNullOrBlank()) {
val rankSelect = if (!query.isNullOrBlank()) {
// prepared statement does not work for this
", similarity(text, '${query.sanitizeInputForDB()}') AS rank"
} else {
Expand All @@ -592,9 +595,11 @@ class DocStore<T : Any>(
} else {
"WHERE " + listOfNotNull(
tags.takeIf { it.isNotEmpty() }
?.let { tags.joinToString(
if (orTags) " OR " else " AND "
) { "? = ANY(tags)" } }
?.let {
tags.joinToString(
if (orTags) " OR " else " AND "
) { "? = ANY(tags)" }
}
?.let {
// surround with parentheses
"($it)"
Expand All @@ -607,13 +612,13 @@ class DocStore<T : Any>(
}

val limitClause = if (limit != null) {
" LIMIT $limit" + if(offset>0) " OFFSET $offset" else ""
" LIMIT $limit" + if (offset > 0) " OFFSET $offset" else ""
} else ""

val orderByClause = if(query.isNullOrBlank()) {
val orderByClause = if (query.isNullOrBlank()) {
"ORDER BY updated_at DESC"
} else {
"ORDER BY rank DESC"
"ORDER BY rank DESC, updated_at DESC"
}
return "SELECT *$rankSelect FROM $tableName $whereClause $orderByClause$limitClause"
}
Expand Down
6 changes: 5 additions & 1 deletion src/test/kotlin/com/tryformation/pgdocstore/TaggingTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ class TaggingTest : DbTestBase() {
TaggedModel("three", listOf("foo","bar")),
))

ds.documentsByRecencyScrolling(listOf("foo")).count() shouldBe 2
ds.documentsByRecency(listOf("foo")).count() shouldBe 2
ds.documentsByRecency(listOf("foo")).count() shouldBe 2
ds.entriesByRecency(listOf("foo")).count() shouldBe 2
ds.entriesByRecencyScrolling(listOf("foo")).count() shouldBe 2
ds.documentsByRecencyScrolling(listOf("foo", "bar")).count() shouldBe 1
ds.documentsByRecencyScrolling(listOf("foo", "bar"), orTags = true).count() shouldBe 3
}

}
58 changes: 13 additions & 45 deletions src/test/kotlin/com/tryformation/pgdocstore/TextSearchTest.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package com.tryformation.pgdocstore

import io.kotest.matchers.collections.shouldContain
import io.kotest.matchers.collections.shouldContainAnyOf
import io.kotest.matchers.collections.shouldContainInOrder
import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.ints.shouldBeGreaterThan
import io.kotest.matchers.shouldBe
import kotlinx.serialization.Serializable
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -71,56 +68,27 @@ class TextSearchTest : DbTestBase() {
).map { SearchableModel(it) }
ds.bulkInsert(docs)

ds.documentsByRecency(query = "bar").map { it.title }.let {
ds.search(query = "bar").map { it.title }.let {
it.first() shouldBe "bar" // clearly the best match
// we also expect to find these
it shouldContainAnyOf listOf("ba", "b", "foo bar foobarred")
}
ds.documentsByRecency(query = "own", similarityThreshold = 0.01).map { it.title }.let {
ds.search(query = "own", similarityThreshold = 0.01).map { it.title }.let {
it.first() shouldBe "the quick brown fox" // clearly the best match
}
ds.documentsByRecency(query = "own", similarityThreshold = 0.5).map { it.title }.let {
ds.search(query = "own", similarityThreshold = 0.5).map { it.title }.let {
it shouldHaveSize 0
}
}
}

@Test
fun shouldDoTrigrams() = coRun {
db.sendQuery(
"""
CREATE EXTENSION IF NOT EXISTS pg_trgm;
DROP INDEX IF EXISTS idx_trigrams_text;
DROP TABLE IF EXISTS trigrams;
CREATE TABLE IF NOT EXISTS trigrams (
id text PRIMARY KEY,
text text
);
CREATE INDEX idx_trigrams_text ON trigrams USING gin (text gin_trgm_ops);
INSERT INTO trigrams (id, text) VALUES ('1', 'test@domain.com');
INSERT INTO trigrams (id, text) VALUES ('2', 'alice@aaa.com');
INSERT INTO trigrams (id, text) VALUES ('3', 'bob@bobby.com');
INSERT INTO trigrams (id, text) VALUES ('4', 'the quick brown fox');
INSERT INTO trigrams (id, text) VALUES ('5', 'the slow yellow fox');
""".trimIndent()
)

val q = "brown fox"
db.sendQuery(
"""
SELECT text, similarity(text, '$q') AS sml FROM trigrams WHERE similarity(text, '$q') > 0.01 ORDER BY sml DESC
""".trimIndent()
).rows.also {
println(
"""
RESULTS: ${it.size}
""".trimIndent()
)
it.size shouldBeGreaterThan 0
}.forEach {
println("${it["text"]} ${it["sml"]}")
private suspend fun DocStore<*>.search(query: String, similarityThreshold:Double = 0.1) =
entriesByRecency(query = query, similarityThreshold = similarityThreshold).also {
println("Found for '$query' with threshold $similarityThreshold:")
it.forEach { e ->
val d = e.document<SearchableModel>()
println("- ${d.title} (${e.similarity})")
}
}
}
}.map {
it.document<SearchableModel>()
}

0 comments on commit 022dc38

Please sign in to comment.