Skip to content

Commit

Permalink
introduce location hinting parameter to repair API
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 25, 2023
1 parent bfcbc02 commit dfeab03
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ fun evaluateSyntheticRepairBenchmarkOn(dataset: String, postprocess: List<Repair
.forEach { (groundTruth, prompt) ->
println("Original: $groundTruth\nCorrupted: ${prettyDiffNoFrills(groundTruth, prompt)}")
val startTime = System.currentTimeMillis()
parallelRepair(prompt, deck, edits + 1, { joinToString("").isSyntacticallyValidKotlin() }, scoreEdit).postprocess()
parallelRepair(
prompt = prompt,
fillers = deck,
maxEdits = edits + 1,
admissibilityFilter = { joinToString("").isSyntacticallyValidKotlin() },
scoreEdit = scoreEdit
).postprocess()
.also {
// repairKotlinStatement(prompt).also {
val gtSeq = groundTruth.tokenizeByWhitespace().joinToString(" ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ fun evaluateSeq2ParseOnStackOverflowDataset() {

class RankStats(val name: String = "Total") {
val upperBound = TIMEOUT_MS / 1000
val time = (300..3000 step 300) //(1..10).toSet()//setOf(2, 5, 10) + (20..upperBound step 20).toSet()
val time = (100..2500 step 200).toSet() //(1..10).toSet()//setOf(2, 5, 10) + (20..upperBound step 20).toSet()
// Mean Reciprocal Rank
val timedMRR = time.associateWith { 0.0 }.toMutableMap()
// Precision at K, first int is K, second is the time cutoff
Expand Down Expand Up @@ -189,18 +189,20 @@ class RankStats(val name: String = "Total") {
}
}

// Divides the receiver by 1000 and rounds the quotient to one decimal place;
// the call sites below label the result with an "s" suffix.
fun Int.roundToTenths() = (this / 1000.0).round(1)

var summary = "$name ranking statistics across $samplesEvaluated samples...\n"
val latestMRRs = timedMRR.entries.sortedByDescending { it.key }
.joinToString(", ") { (k, v) ->
"${k}s: ${(v / samplesEvaluated).round(3)}"
"${(k.roundToTenths()).round(1)}s: ${(v / samplesEvaluated).round(3)}"
}
summary += "\nMRR= $latestMRRs"

val latestPAKs = timedPAK.entries.groupBy { it.key.first }
.mapValues { (_, v) ->
v.sortedByDescending { it.key.second }
.joinToString(", ") { (p, v) ->
"${p.second}ms: ${(v / samplesEvaluated).round(3)}"
"${p.second.roundToTenths()}s: ${(v / samplesEvaluated).round(3)}"
}
}.entries.joinToString("\n") { (k, v) ->
"P@${if (k == Int.MAX_VALUE) "All" else k}=".padEnd(6) + v
Expand Down Expand Up @@ -253,6 +255,7 @@ fun evaluateTidyparseOnStackoverflow() {
parallelRepair(
prompt = coarseBrokeStr,
fillers = topTokens,
// hints = pythonErrorLocations(humanError.lexToIntTypesAsPython()),
maxEdits = 4,
admissibilityFilter = { map { pythonVocabBindex.getUnsafe(it) ?: it.toInt() }.isValidPython() },
// TODO: incorporate parseable segmentations into scoring mechanism to prioritize chokepoint repairs
Expand Down Expand Up @@ -280,6 +283,10 @@ fun evaluateTidyparseOnStackoverflow() {
}
}

/**
 * Produces location hints for a broken Python token sequence.
 *
 * The only hint currently offered is the index reported by
 * [getIndexOfFirstPythonError]. That function yields -1 when no index could be
 * determined, so negative values are discarded here: an unlocated error becomes
 * an empty hint list instead of a bogus index.
 *
 * @param coarseBrokeTks integer token types of the snippet to inspect
 * @return a list holding the first error's token index, or an empty list when
 *   no non-negative index was reported
 */
fun pythonErrorLocations(coarseBrokeTks: List<Int>): List<Int> =
  listOf(coarseBrokeTks.getIndexOfFirstPythonError())
    .filter { it >= 0 } // -1 is the "nothing found" sentinel, not a real position

// Returns a triple of: (1) the broken source, (2) the human fix, and (3) the minimized fix
private fun preprocessStackOverflow(
brokeSnippets: Sequence<String> = readContents("parse_errors.json"),
fixedSnippets: Sequence<String> = readContents("parse_fixes.json")
Expand Down
29 changes: 29 additions & 0 deletions src/main/kotlin/edu/mcgill/cstk/utils/ParseUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ val errorListener =
) { throw Exception("$msg") }
}

/**
 * Error listener that aborts on the first reported syntax error by throwing an
 * [Exception]. The exception message is the offending token's index rendered as
 * a decimal string, or the empty string when the offending symbol is not a
 * [Token].
 */
val indexReportingListener =
  object: BaseErrorListener() {
    override fun syntaxError(
      recognizer: Recognizer<*, *>?,
      offendingSymbol: Any?,
      line: Int,
      charPositionInLine: Int,
      msg: Σᐩ?,
      e: RecognitionException?
    ) {
      val culprit = offendingSymbol as? Token
      // The message doubles as the payload: callers parse it back into an Int.
      val report = if (culprit != null) culprit.tokenIndex.toString() else ""
      throw Exception(report)
    }
  }

// Exhaustive tokenization includes whitespaces
fun Σᐩ.tokenizeAsPython(exhaustive: Boolean = false): List<Σᐩ> =
if (!exhaustive) lexAsPython().allTokens.map { it.text }
Expand All @@ -32,6 +49,17 @@ fun Σᐩ.tokenizeAsPython(exhaustive: Boolean = false): List<Σᐩ> =
else throw Exception("Could not find token $t in ${toSplit.map { it.code }}").also { println("\n\n$this\n\n") }
}

/**
 * Feeds the receiver (a list of integer token types) to [Python3Parser] with
 * [indexReportingListener] as its sole error listener and reports where parsing
 * first failed.
 *
 * @return the integer parsed from the thrown exception's message, or -1 when
 *   parsing completes without an exception or the message is not an integer
 */
fun List<Int>.getIndexOfFirstPythonError(): Int {
  val stream = CommonTokenStream(ListTokenSource(map { CommonToken(it) }))
  try {
    val parser = Python3Parser(stream)
    parser.removeErrorListeners()
    parser.addErrorListener(indexReportingListener)
    parser.file_input()
  } catch (failure: Exception) {
    // The listener encodes the token index in the exception message.
    return failure.message?.toIntOrNull() ?: -1
  }
  return -1
}

fun List<Int>.isValidPython(): Boolean {
val tokenSource = ListTokenSource(map { CommonToken(it) })
val tokens = CommonTokenStream(tokenSource)
Expand Down Expand Up @@ -104,6 +132,7 @@ fun Σᐩ.javac(): Σᐩ =

/** True exactly when [javac] returns an empty result for this receiver. */
fun Σᐩ.isValidJava() = javac().let { compilerOutput -> compilerOutput.isEmpty() }


fun Σᐩ.isValidPython(onErrorAction: (Σᐩ?) -> Unit = {}): Boolean =
try {
Python3Parser(
Expand Down
7 changes: 6 additions & 1 deletion src/main/kotlin/edu/mcgill/cstk/utils/RepairUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ fun <T> deltaDebug(elements: List<T>, n: Int = 2, checkValid: (List<T>) -> Boole
fun parallelRepair(
prompt: Σᐩ,
fillers: Collection<Σᐩ>,
hints: List<Int> = emptyList(),
maxEdits: Int = 2,
admissibilityFilter: List<Σᐩ>.() -> Boolean,
scoreEdit: ((List<Σᐩ>) -> Double)? = null,
Expand All @@ -123,14 +124,18 @@ fun parallelRepair(
println("$delim\nBest repairs so far:\n$delim")
// We intersperse the prompt with empty strings to enable the repair of the first and last token
// as well as insertion of tokens by the repair algorithm, which only considers substitutions
val promptTokens = prompt.tokenizeByWhitespace().intersperse(maxEdits.coerceAtMost(2))
val spacingLength = maxEdits.coerceAtMost(2)
val promptTokens = prompt.tokenizeByWhitespace().intersperse(spacingLength)
// Remap the hints to the new indices in the interspersed prompt tokens
val remappedHints = hints.map { (spacingLength + 1) * it + 2 }

val deck = fillers + promptTokens.toSet() - "\""

val clock: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
return bijectiveRepair(
promptTokens = promptTokens,
deck = deck,
hints = remappedHints,
maxEdits = maxEdits,
takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS },
admissibilityFilter = admissibilityFilter,
Expand Down

0 comments on commit dfeab03

Please sign in to comment.