Skip to content

Commit

Permalink
repair procedure now works on StackOverflow dataset!
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 12, 2023
1 parent ea33545 commit 5cc3b8d
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 28 deletions.
2 changes: 1 addition & 1 deletion galoisenne
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import java.util.stream.Stream
import kotlin.math.*
import kotlin.streams.asStream
import kotlin.system.measureTimeMillis
import kotlin.time.TimeSource
import kotlin.time.*


val brokenSnippetURL =
Expand All @@ -35,9 +35,19 @@ val brokenPythonSnippets by lazy {
}

val P_seq2parse: MarkovChain<Σᐩ> by lazy {
measureTimedValue {
brokenPythonSnippets.toList().parallelStream().map { "BOS $it EOS" }
.map { it.tokenizeByWhitespace().asSequence().toMarkovChain(1) }
.reduce { t, u -> t + u }.get()
}.let { println("Trained Markov chain on ${it.value.counter.total.get()} tokens StackOverflow in ${it.duration.inWholeMilliseconds}ms"); it.value }
}

val P_stackoverflow: MarkovChain<Σᐩ> by lazy {
measureTimedValue {
readContents("parse_fixes.json").asStream().parallel()
.map { "\n$it\n".lexToStrTypesAsPython().asSequence().toMarkovChain(3) }
.reduce { t, u -> t + u }.get()
}.let { println("Trained Markov chain on ${it.value.counter.total.get()} tokens StackOverflow in ${it.duration.inWholeMilliseconds}ms"); it.value }
}

val mostCommonTokens by lazy {
Expand All @@ -63,14 +73,17 @@ fun stackOverflowEval() {
readContents("parse_errors.json") to
readContents("parse_fixes.json")

val deck = P_stackoverflow.topK(200).map { it.first }.toSet() + "ε"
println("Deck size: $deck")

brokeSnippets.zip(fixedSnippets)
.asStream().parallel()
.asStream()//.parallel()
.filter { (broke, fixed) ->
// '"' !in broke && '\'' !in broke &&
broke.tokenizeAsPython().size < 30 &&
broke != fixed &&
(!broke.isValidPython() && fixed.isValidPython()) &&
broke.lines().size < 20 &&
// broke.lines().size < 20 &&
(broke.lines().size - fixed.lines().size).absoluteValue < 4
}
.minimizeFix { tokenizeAsPython(true) }
Expand All @@ -80,37 +93,60 @@ fun stackOverflowEval() {
// (brokeTokens.size - fixedTokens.size).absoluteValue < 10 &&
multisetManhattanDistance(brokeTokens, fixedTokens).let { it in 1..5 }
}
// .forEach { (broke, fixed, minfix) ->
// broke.tokenizeAsPython()
// }
.filter { (broke, fixed, minfix) ->
val (brokeVis, fixedVis, minfixVis) = broke.visibleChars() to fixed.visibleChars() to minfix.visibleChars()
brokeVis != fixedVis && brokeVis != minfixVis// && fixedVis != minfixVis
}
.map { (broke, fixed, minfix) ->
prettyDiffs(listOf(broke, fixed), listOf("original snippet", "human patch")).let { origDiff ->
prettyDiffs(listOf(broke, minfix), listOf("original snippet", "minimized patch")).let { minDiff ->
// Compare ASCII characters for a visible difference, if same do not print two
// if (corrected.visibleChars() == minfix.visibleChars()) origDiff to "" else
origDiff to minDiff to broke to minfix
}
}
}
// .map { (broke, fixed, minfix) ->
// prettyDiffs(listOf(broke, fixed), listOf("original snippet", "human patch")).let { origDiff ->
// prettyDiffs(listOf(broke, minfix), listOf("original snippet", "minimized patch")).let { minDiff ->
// // Compare ASCII characters for a visible difference, if same do not print two
//// if (corrected.visibleChars() == minfix.visibleChars()) origDiff to "" else
// origDiff to minDiff to broke to minfix
// }
// }
// }
// .filter { (a, b) -> b.isNotEmpty() && 2 < (a.count { it == '\u001B' } - b.count { it == '\u001B' }).absoluteValue }
.forEach { (a, b, c, d) ->
.filter { extractPatch(it.first.lexToStrTypesAsPython(), it.third.lexToStrTypesAsPython()).changes().size < 3 }
.forEach { (humanError, humanFix, minimumFix) ->
// println("$a\n$b")
val coarseBroke = c.lexToStrTypesAsPython().joinToString(" ") + " | " + c.lexToIntTypesAsPython().joinToString(" ")
val coarseFixed = d.lexToStrTypesAsPython().joinToString(" ") + " | " + d.lexToIntTypesAsPython().joinToString(" ")
// val diff1 = prettyDiffNoFrills(coarseBroke, coarseFixed)
// val diff2 = prettyDiffNoFrills(coarseFixed, coarseBroke)
// val diff3 = prettyDiffNoFrills(c.lexToIntTypesAsPython().joinToString("\n"), d.lexToIntTypesAsPython().joinToString("\n"))
// val maxlen = max(diff1.visibleLen(), diff2.visibleLen())

println("Broke tokens: ${prettyDiffNoFrills(coarseFixed, coarseBroke)}")
println("Fixed tokens: ${prettyDiffNoFrills(coarseBroke, coarseFixed)}")
val coarseBrokeTks = humanError.lexToStrTypesAsPython()
val coarseFixedTks = minimumFix.lexToStrTypesAsPython()
val coarseBrokeStr = coarseBrokeTks.joinToString(" ", "", " NEWLINE")
val coarseFixedStr = coarseFixedTks.joinToString(" ", "", " NEWLINE")

println("Broke tokens: ${prettyDiffNoFrills(coarseFixedStr, coarseBrokeStr)}")
println("Fixed tokens: ${prettyDiffNoFrills(coarseBrokeStr, coarseFixedStr)}")
println("\n\n")

val startTime = System.currentTimeMillis()
// val segmentation = Segmentation.build(seq2parsePythonCFG, coarseBrokeStr)

println("Repairing: $coarseBrokeStr")

parallelRepair(
prompt = coarseBrokeStr,
fillers = deck,
maxEdits = 4,
admissibilityFilter = { tokenizeByWhitespace().map { pythonVocabBindex.getUnsafe(it) ?: it.toInt() }.isValidPython() },
// TODO: incorporate parseable segmentations into scoring mechanism to prioritize chokepoint repairs
scoreEdit = { P_stackoverflow.score(it.tokenizeByWhitespace()) }
).also { repairs ->
repairs.take(20).apply { println("\nTop $size repairs:\n") }.forEach {
println("Δ=${it.scoreStr()} repair (${it.elapsed()}): ${prettyDiffNoFrills(coarseBrokeStr, it.result)}")
// println("(LATEX) Δ=${levenshtein(prompt, it)} repair: ${latexDiffSingleLOC(prompt, it)}")
}

val contained = repairs.any { coarseFixedStr == it.result }
val elapsed = System.currentTimeMillis() - startTime

println("\nFound ${repairs.size} valid repairs in ${elapsed}ms, or roughly " +
"${(repairs.size / (elapsed/1000.0)).toString().take(5)} repairs per second.")

println("Minimized repair was ${if (contained) "#" + repairs.indexOfFirst { it.result == coarseFixedStr } else "NOT"} in repair proposals!\n")
}
}
}
}

fun seq2parseEval() {
brokenPythonSnippets.map {
Expand Down
8 changes: 7 additions & 1 deletion src/main/kotlin/edu/mcgill/cstk/utils/ParseUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fun Σᐩ.tokenizeAsPython(exhaustive: Boolean = false): List<Σᐩ> =
else throw Exception("Could not find token $t in ${toSplit.map { it.code }}").also { println("\n\n$this\n\n") }
}

fun IntArray.isValidPython(): Boolean {
fun List<Int>.isValidPython(): Boolean {
val tokenSource = ListTokenSource(map { CommonToken(it) })
val tokens = CommonTokenStream(tokenSource)
return try {
Expand All @@ -49,6 +49,12 @@ fun Σᐩ.lexToIntTypesAsPython(
lexer: Lexer = Python3Lexer(CharStreams.fromString(this + "\n"))
) = lexer.allTokens.map { it.type }

val pythonVocabBindex: Bindex<Σᐩ> =
Python3Lexer(CharStreams.fromString(""))
.vocabulary.let { vocab ->
(0..vocab.maxTokenType).associateWith { vocab.getDisplayName(it) }
}.let { Bindex(it) }//.also { println(it.toString()) }

fun Σᐩ.lexToStrTypesAsPython(
lexer: Lexer = Python3Lexer(CharStreams.fromString(this)),
vocabulary: Vocabulary = lexer.vocabulary
Expand Down
1 change: 1 addition & 0 deletions src/main/kotlin/edu/mcgill/cstk/utils/RepairUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ fun <T> deltaDebug(elements: List<T>, n: Int = 2, checkValid: (List<T>) -> Boole
else deltaDebug(elements, n * 2, checkValid)
}

// TODO: Generify to accept List<T>
fun parallelRepair(
prompt: Σᐩ,
fillers: Set<Σᐩ>,
Expand Down

0 comments on commit 5cc3b8d

Please sign in to comment.