Skip to content

Commit

Permalink
score samples using variable order markov chain
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed May 21, 2023
1 parent b818e76 commit 88d5703
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 38 deletions.
4 changes: 2 additions & 2 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ dependencies {
// exclude(group = "org.jetbrains.kotlinx", module = "multik-core")
// exclude(group = "org.jetbrains.kotlinx", module = "multik-default")
exclude(group = "org.jetbrains.lets-plot", module = "lets-plot-kotlin-jvm")
exclude(group = "org.apache.datasketches", module = "datasketches")
exclude(group = "org.apache.datasketches", module = "datasketches-java")
// exclude(group = "org.apache.datasketches", module = "datasketches")
// exclude(group = "org.apache.datasketches", module = "datasketches-java")
exclude(group = "ca.umontreal.iro.simul", module = "ssj")
exclude(group = "org.sosy-lab", module = "common")
exclude(group = "org.sosy-lab", module = "java-smt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package edu.mcgill.cstk.experiments.repair

import ai.hypergraph.kaliningraph.levenshtein
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.*
import bijectiveRepair
import edu.mcgill.cstk.utils.*
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlin.spec.grammar.tools.*
import java.io.File
import kotlin.math.ln
import kotlin.time.*

/*
Expand All @@ -19,7 +19,7 @@ import kotlin.time.*
fun main() {
// fetchKotlinExamples()

val scoreEdit: (Edit) -> Float = constructScoringFunction()
val scoreEdit: (Σᐩ) -> Double = constructScoringFunction()

// Generate synthetic error dataset
List(100) { originalKotlinLines }.joinToString("\n").lines().map {
Expand Down Expand Up @@ -56,44 +56,35 @@ fun main() {
}
}

private fun constructScoringFunction(): (Edit) -> Float {
val tokenCounts: MutableMap<Σᐩ, Int> = mutableMapOf()
coarsenedKotlinLines.tokenizeByWhitespace()
.forEach { tokenCounts[it] = (tokenCounts[it] ?: 0) + 1 }
val normalizingConstant = tokenCounts.values.sum().toFloat()
val normedTokenWeights =
tokenCounts.mapValues { (token, count) -> count.toFloat() / normalizingConstant }
private fun constructScoringFunction(): (Σᐩ) -> Double {
val P = coarsenedKotlinLines.lines().map { "BOS $it EOS" }
.map { it.tokenizeByWhitespace().asSequence().toMarkovChain(3) }
.fold(MarkovChain<Σᐩ>()) { a, b -> a + b }

println("Top 100 most common tokens: ${tokenCounts.toList().sortedByDescending { it.second }.take(100)}\n\n")
println("Top 10 most common tokens: ${P.topK(10)}\n\n")

val scoreEdit: (Edit) -> Float = {
val tokens = it.values
val tokenWeights = tokens.map { normedTokenWeights[it] ?: 0f }
// Tokens are t_1...t_n, we compute the score as log(p(t_1)*...*p(t_n))
// Edits are penalized by length, so we divide by the number of tokens
tokenWeights.sumOf { ln(it.toDouble()) }.toFloat() / tokens.size
}
return scoreEdit
return { P.score("BOS ${it.coarsenAsKotlin(false)} EOS".tokenizeByWhitespace()) }
}

// Get top level directory and all Kotlin files in all subdirectories
fun fetchKotlinExamples() =
File(File("").absolutePath)
.walkTopDown().filter { it.extension == "kt" }
.flatMap { it.readLines() }
.filter { it.isValidKotlin() } .toList()
.filter { it.isValidKotlin() }.toList()
.filter {
it.coarsenAsKotlin().let { str ->
dropKeywords.none { it in str } && str.split(" ").size in 10..40
}
}.map { it.trim() }.distinct().forEach { println(it) }

fun Σᐩ.coarsenAsKotlin(): Σᐩ =
lexAsKotlin().joinToString(" ") {
fun Σᐩ.coarsenAsKotlin(lex: Boolean = true): Σᐩ =
(if(lex) lexAsKotlin() else tokenizeByWhitespace()).joinToString(" ") {
when {
it.isBracket() -> it
it.none { it.isLetterOrDigit() } -> it
it in officialKotlinKeywords -> it
it.first().isUpperCase() -> "W"
else -> "w"
}
}
Expand All @@ -105,7 +96,7 @@ fun Σᐩ.uncoarsenAsKotlin(prompt: Σᐩ): Σᐩ {
when {
token.isBracket() -> token
token.none { it.isLetterOrDigit() } -> token
token == "w" -> words.removeFirst()
token.equals("w", ignoreCase = true) -> words.removeFirst()
token in officialKotlinKeywords -> token
else -> throw Exception("Unknown token: $token")
}
Expand All @@ -118,10 +109,10 @@ fun Σᐩ.uncoarsenAsKotlin(prompt: Σᐩ): Σᐩ {
@OptIn(ExperimentalTime::class)
fun parallelRepairKotlinStatement(
prompt: Σᐩ,
scoreEdit: ((Edit) -> Float)? = null,
clock: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
scoreEdit: ((Σᐩ) -> Double)? = null,
clock: TimeMark = TimeSource.Monotonic.markNow()
): List<Σᐩ> {
var bestRepair = Int.MAX_VALUE
var bestRepair = Double.MAX_VALUE
val delim = List(prompt.length) { "-" }.joinToString("")
println("$delim\nBest repairs so far:\n$delim")
return bijectiveRepair(
Expand All @@ -132,22 +123,35 @@ fun parallelRepairKotlinStatement(
takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS },
// updateProgress = { println(it) },
filter = { isValidKotlin() },
scoreEdit = scoreEdit,
diagnostic = {
val levDiff = levenshtein(prompt, it)
if (levDiff < bestRepair) {
println("Δ=$levDiff repair: ${prettyDiffNoFrills(prompt, it)}")
// println("(LATEX) Δ=$levDiff repair: ${latexDiffSingleLOC(prompt, it)}")
bestRepair = levDiff
scoreString = scoreEdit,
diagnostic =
if (scoreEdit!= null) {
{
val score = scoreEdit(it)
if (score < bestRepair) {
println("Δ=$score repair: ${prettyDiffNoFrills(prompt, it)}")
// println("(LATEX) Δ=$score repair: ${latexDiffSingleLOC(prompt, it)}")
bestRepair = score
}
}
}
else {
{
val levDiff = levenshtein(prompt, it).toDouble()
if (levDiff < bestRepair) {
println("Δ=$levDiff repair: ${prettyDiffNoFrills(prompt, it)}")
// println("(LATEX) Δ=$levDiff repair: ${latexDiffSingleLOC(prompt, it)}")
bestRepair = levDiff
}
}
}
},
)
}

@OptIn(ExperimentalTime::class)
fun repairKotlinStatement(
prompt: Σᐩ,
clock: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
clock: TimeMark = TimeSource.Monotonic.markNow()
): List<Σᐩ> =
// newRepair(prompt, permissiveKotlinCFG)
repair(
Expand All @@ -162,7 +166,7 @@ fun repairKotlinStatement(
)

@OptIn(ExperimentalTime::class)
private fun bruteForceKotlinRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> =
private fun bruteForceKotlinRepair(clock: TimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> =
{ a: List<Σᐩ> ->
try {
a.genCandidates(setOf(), commonKotlinKeywords + "ε" - "w" )
Expand Down

0 comments on commit 88d5703

Please sign in to comment.