Skip to content

Commit

Permalink
Tweak similarity threshold for new openai model
Browse files (browse the repository at this point in the history)
  • Loading branch information
IanPhilips committed Jul 1, 2024
1 parent 8723d52 commit 0fc9e79
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 55 deletions.
5 changes: 3 additions & 2 deletions backend/api/src/get-related-markets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
} from 'common/supabase/groups'
import { ValidatedAPIParams } from 'common/api/schema'
import { mapValues, orderBy } from 'lodash'
import { TOPIC_SIMILARITY_THRESHOLD } from 'shared/helpers/embeddings'

export const getrelatedmarketscache: APIHandler<
'get-related-markets-cache'
Expand Down Expand Up @@ -47,9 +48,9 @@ const getRelatedMarkets = async (
select * from close_contract_embeddings(
input_contract_id := $1,
match_count := $2,
similarity_threshold := 0.7
similarity_threshold := $3
)`,
[contractId, embeddingsLimit],
[contractId, embeddingsLimit, TOPIC_SIMILARITY_THRESHOLD],
(row) => row.data as Contract
),
pg.map(
Expand Down
14 changes: 8 additions & 6 deletions backend/api/src/get-similar-groups-to-contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { createSupabaseDirectClient } from 'shared/supabase/init'
import { convertGroup } from 'common/supabase/groups'
import { orderBy, uniqBy } from 'lodash'
import { log } from 'shared/utils'
import { TOPIC_SIMILARITY_THRESHOLD } from 'shared/helpers/embeddings'

const bodySchema = z
.object({
Expand All @@ -32,16 +33,17 @@ export const getsimilargroupstocontract = authEndpoint(async (req) => {
log('Finding similar groups to' + question)
const groups = await pg.map(
`
select *, (embedding <=> ($1)::vector) as distance from groups
select groups.*, (embedding <=> ($1)::vector) as distance from groups
join group_embeddings on groups.id = group_embeddings.group_id
where (embedding <=> ($1)::vector) < 0.12
and importance_score > 0.15
where (embedding <=> ($1)::vector) < $2
and importance_score > 0.1
and (total_members > 10 or importance_score > 0.3)
and privacy_status = 'public'
and slug not in ($2:list)
order by POW(1-(embedding <=> ($1)::vector) - 0.8, 2) * importance_score desc
and slug not in ($3:list)
order by POW(1-(embedding <=> ($1)::vector), 2) * importance_score desc
limit 5
`,
[embedding, GROUPS_SLUGS_TO_IGNORE],
[embedding, TOPIC_SIMILARITY_THRESHOLD, GROUPS_SLUGS_TO_IGNORE],
(group) => {
log('group: ' + group.name + ' distance: ' + group.distance)
return convertGroup(group)
Expand Down
100 changes: 53 additions & 47 deletions backend/scripts/generate-embeddings.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,58 @@
import { initAdmin } from 'shared/init-admin'
initAdmin()

import { run } from 'common/supabase/utils'
import {
createSupabaseClient,
createSupabaseDirectClient,
} from 'shared/supabase/init'
import { generateEmbeddings } from 'shared/helpers/openai-utils'

async function main() {
const db = createSupabaseClient()
const pg = createSupabaseDirectClient()

const result = await run(db.from('contract_embeddings').select('contract_id'))

const contractIds = new Set(result.data.map((row: any) => row.contract_id))
console.log('Got', contractIds.size, 'markets with preexisting embeddings')

const { data: contracts } = await run(
db.from('contracts').select('id, data')
// doesn't work if too many contracts
// .not('id', 'in', '(' + [...contractIds].join(',') + ')')
).catch((err) => (console.error(err), { data: [] }))

console.log('Got', contracts.length, 'markets to process')

for (const contract of contracts) {
const { id, data } = contract
if (contractIds.has(id)) continue

const { question } = data as { question: string }
const embedding = await generateEmbeddings(question)
if (!embedding || embedding.length < 1500) {
console.log('No embeddings for', question)
continue
import { runScript } from './run-script'
import { upsertGroupEmbedding } from 'shared/helpers/embeddings'
import { chunk } from 'lodash'
if (require.main === module) {
runScript(async ({ pg }) => {
const contracts = await pg.map(
`select id, question from contracts
where visibility = 'public'
`,
[],
(r) => ({ id: r.id, question: r.question })
)

console.log('Got', contracts.length, 'markets to process')

const chunks = chunk(contracts, 100)
let processed = 0
for (const contracts of chunks) {
await Promise.all(
contracts.map(async (contract) => {
const { question, id } = contract
const embedding = await generateEmbeddings(question)
if (!embedding || embedding.length < 1500) {
console.log('No embeddings for', question)
return
}

await pg
.none(
'insert into contract_embeddings (contract_id, embedding) values ($1, $2) on conflict (contract_id) do update set embedding = $2',
[id, embedding]
)
.catch((err) => console.error(err))
})
)
processed += contracts.length
console.log('Processed', processed, 'contracts')
}

console.log('Generated embeddings for', id, ':', question)

await pg
.none(
'insert into contract_embeddings (contract_id, embedding) values ($1, $2) on conflict (contract_id) do nothing',
[id, embedding]
const groupIds = await pg.map(
`select id from groups`,
[],
(r) => r.id as string
)
const groupChunks = chunk(groupIds, 100)
let groupProcessed = 0
for (const groupIds of groupChunks) {
await Promise.all(
groupIds.map(async (groupId) => {
await upsertGroupEmbedding(pg, groupId)
})
)
.catch((err) => console.error(err))
}
}

if (require.main === module) {
main().then(() => process.exit())
groupProcessed += groupIds.length
console.log('Processed', groupProcessed, 'groups')
}
})
}
2 changes: 2 additions & 0 deletions backend/shared/src/helpers/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { log } from 'shared/utils'
import { getMostlyActiveUserIds } from 'shared/supabase/users'
import { HIDE_FROM_NEW_USER_SLUGS } from 'common/envs/constants'

export const TOPIC_SIMILARITY_THRESHOLD = 0.5

function magnitude(vector: number[]): number {
const vectorSum = sum(vector.map((val) => val * val))
return Math.sqrt(vectorSum)
Expand Down

0 comments on commit 0fc9e79

Please sign in to comment.