From 35e4ab1277942e8d1ea21eb3dbd2c5582da12be0 Mon Sep 17 00:00:00 2001 From: breandan Date: Mon, 17 Apr 2023 15:26:58 -0400 Subject: [PATCH] simulate arbitrary syntactic corruption --- galoisenne | 2 +- .../repair/LocalizedSyntaxRepair.kt | 49 +++++++++++-------- .../experiments/repair/OrganicSyntaxRepair.kt | 8 ++- .../repair/PythonStatementRepair.kt | 48 ++++++++++++------ .../experiments/repair/RepairPrompting.kt | 2 +- .../repair/SyntheticSyntaxRepair.kt | 11 +++-- .../edu/mcgill/cstk/utils/StringUtils.kt | 29 +++++------ 7 files changed, 93 insertions(+), 56 deletions(-) diff --git a/galoisenne b/galoisenne index 75e5abd6..49bdd06b 160000 --- a/galoisenne +++ b/galoisenne @@ -1 +1 @@ -Subproject commit 75e5abd6bfbccd230b9e717182ab751b16f71155 +Subproject commit 49bdd06b883c019b8c1d81c4f2f2a6f7cdd548eb diff --git a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/LocalizedSyntaxRepair.kt b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/LocalizedSyntaxRepair.kt index b2e31639..158acc37 100644 --- a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/LocalizedSyntaxRepair.kt +++ b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/LocalizedSyntaxRepair.kt @@ -1,6 +1,6 @@ package edu.mcgill.cstk.experiments.repair -import ai.hypergraph.kaliningraph.parsing.tokenizeByWhitespace +import ai.hypergraph.kaliningraph.parsing.* import edu.mcgill.cstk.disk.* import edu.mcgill.cstk.utils.* import org.apache.commons.lang3.StringUtils @@ -53,8 +53,8 @@ fun main() { DATA_DIR.also { println("Evaluating syntax repair using $models on $it...") } .allFilesRecursively().allMethods() .map { it.first.lineSequence() }.flatten() - .filter(String::isANontrivialStatementWithBalancedBrackets) - .map { it to it.constructPromptByMaskingRandomBrackets() } + .filter(Σᐩ::isANontrivialStatementWithBalancedBrackets) + .map { it to it.constructPromptByMaskingRandomSyntax() } .runningFold(modelScores) { scores, (groundTruth, prompt) -> models.associateWith { model -> val repairs = prompt.dispatchTo(model, cfg) @@ -70,8 +70,8 @@ fun updateScore(scores: Scores, model: Model, groundTruth: () -> Boolean) = if (groundTruth()) (n + 1) to (d + 1) else n to (d + 1) } -fun String.coarsen(): String = - tokenize().joinToString(" ") { +fun Σᐩ.coarsen(): Σᐩ = + defaultTokenizer().joinToString(" ") { when { it.isBracket() -> it it == MSK -> "_" @@ -79,10 +79,10 @@ fun String.coarsen(): String = } } -fun String.isWhitespace() = trim().isEmpty() +fun Σᐩ.isWhitespace() = trim().isEmpty() -fun String.uncoarsen(prompt: String): String { - val words = prompt.tokenize().filter { it !in brackets }.toMutableList() +fun Σᐩ.uncoarsen(prompt: Σᐩ): Σᐩ { + val words = prompt.defaultTokenizer().filter { it !in brackets }.toMutableList() return tokenizeByWhitespace().joinToString("") { token -> when { token.isBracket() -> token @@ -94,19 +94,26 @@ fun String.uncoarsen(prompt: String): String { } + words.joinToString("") } -fun String.isBracket() = length == 1 && this in brackets - -fun String.constructPromptByMaskingRandomBrackets(bracketsToMask: Int = 1): String = - tokenize().toMutableList().let { tokens -> - tokens.indices.filter { tokens[it] in brackets } - .shuffled().take(bracketsToMask) - .forEach { tokens[it] = MSK } - tokens - }.joinToString("") - -const val brackets = "()[]{}" -fun String.tokenize(): List = +fun Σᐩ.isBracket() = length == 1 && this in brackets + +fun Σᐩ.constructPromptByMaskingRandomSyntax( + eligibleTokensToMask: Set<Σᐩ> = brackets, + numTokensToMask: Int = 1, + tokenize: Σᐩ.() -> List<Σᐩ> = Σᐩ::defaultTokenizer +): Σᐩ = + tokenize().toMutableList().let { codeTokens -> +// println("Code tokens: $codeTokens") +// println("Eligible tokens to mask: $eligibleTokensToMask") + codeTokens.indices.filter { codeTokens[it].trim() in eligibleTokensToMask } +// .also { println("Indices of eligible tokens to mask: $it") } + .shuffled().take(numTokensToMask) + .forEach { codeTokens[it] = MSK } + codeTokens + }.joinToString(" ") + +val brackets = "()[]{}".map { "$it" }.toSet() +fun Σᐩ.defaultTokenizer(): List<Σᐩ> = split(Regex("[\\(\\)\\[\\]{}]|___".let { "((?<=($it))|(?=($it)))" })) -fun String.tokenizeGranular() = +fun Σᐩ.tokenizeGranular() = StringUtils.splitByCharacterTypeCamelCase(this).toList() diff --git a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/OrganicSyntaxRepair.kt b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/OrganicSyntaxRepair.kt index 63e9c58b..2c7c0ea1 100644 --- a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/OrganicSyntaxRepair.kt +++ b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/OrganicSyntaxRepair.kt @@ -94,7 +94,7 @@ val cfg = """S -> w | ( ) | [ ] | { } | ( S ) | [ S ] | { S } | S S""" .parseCFG().apply { blocked.addAll(setOf("w")) } -val pythonKeywords = listOf( +val pythonKeywords = setOf( "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue", "def", "del", "elif", "else", "except", "finally", @@ -103,6 +103,12 @@ val pythonKeywords = listOf( "return", "try", "while", "with", "yield" ) +val pythonOperators = setOf( + "=", "==", "+", "-", "*", "/", "%", "**", "//", + ">>", "<<", "&", "|", "^", "~", "<", ">", "<=", + ">=", "!=", "not", "in", "is", "and", "or" +) + fun String.coarsenAsPython(): String = tokenizeAsPython().joinToString(" ") { when { diff --git a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/PythonStatementRepair.kt b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/PythonStatementRepair.kt index e6aec028..2d757d18 100644 --- a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/PythonStatementRepair.kt +++ b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/PythonStatementRepair.kt @@ -4,7 +4,6 @@ import ai.hypergraph.kaliningraph.* import ai.hypergraph.kaliningraph.parsing.* import edu.mcgill.cstk.utils.* import ai.hypergraph.kaliningraph.parsing.repair -import ai.hypergraph.kaliningraph.sampling.pow import ai.hypergraph.kaliningraph.sat.* import org.intellij.lang.annotations.Language import kotlin.time.* @@ -15,12 +14,16 @@ import kotlin.time.* @OptIn(ExperimentalTime::class) fun main() { - MAX_SAMPLE = 3 + MAX_SAMPLE = 20 MAX_REPAIR = 2 // Synthetic error correction - validPythonStatements.lines().filter { it.isNotBlank() }.forEach { + validPythonStatements.lines().filter { it.isNotBlank() }.asSequence() + .map { val original = it.tokenizeAsPython().joinToString(" ") - val prompt = original.constructPromptByDeletingRandomBrackets(1) + val prompt = original.constructPromptByDeletingRandomSyntax(pythonKeywords + pythonOperators + brackets, tokenizer = Σᐩ::tokenizeAsPython) + original to prompt + } + .forEach { (original, prompt) -> println("Original: $original\nCorrupted: ${prettyDiffNoFrills(original, prompt)}") // Organic repairs val clock = TimeSource.Monotonic.markNow() @@ -30,17 +33,16 @@ fun main() { cfg = pythonStatementCFG, coarsen = String::coarsenAsPython, uncoarsen = String::uncoarsenAsPython, -// synthesizer = { a -> println(a.joinToString(" ")); synthesize(a) }, // SAT solver - synthesizer = { a -> a -// .also { println("Solving: ${it.joinToString(" ")}") } - .solve(this, - takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS } - ) }, // Enumerative search -// synthesizer = { a -> a.parallelSolve(this).asSequence() }, // Parallel search + //synthesizer = { a -> println(a.joinToString(" ")); synthesize(a) }, // SAT solver + synthesizer = setRepair(clock), // Enumerative search + //synthesizer = { a -> a.parallelSolve(this).asSequence() }, // Parallel search diagnostic = { println("Δ=${levenshtein(prompt, it) - 1} repair: ${prettyDiffNoFrills(prompt, it)}") }, -// filter = { isValidPython() } , + filter = { isValidPython() }, ) - .also { println("Original string was ${if(original in it) "" else "NOT " }contained in repairs!") } + .also { + val contained = original in it + println("Original string was ${if(contained) "#${it.indexOf(original)}" else "NOT" } in repairs!") + } println("\n") } @@ -61,9 +63,25 @@ fun main() { // println("Both invalid: ${oi.intersect(ip)}") // println("Our valid, their invalid: ${ov - vp}") // println("Their valid, our invalid: ${vp - ov}") - } +@OptIn(ExperimentalTime::class) +private fun satRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> = + { a: List<Σᐩ> -> asCJL.synthesize(a, + takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS }) } + +@OptIn(ExperimentalTime::class) +private fun setRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> = + { a: List<Σᐩ> -> a +// .also { println("Solving: ${it.joinToString(" ")}") } + .solve(this, + takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS } + ) } + +@OptIn(ExperimentalTime::class) +private fun parallelSetRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> = + { a: List<Σᐩ> -> a.parallelSolve(this).asSequence() } + fun organicRepair() { // Organic error correction invalidPythonStatements.lines().shuffled().filter { it.isNotBlank() } @@ -233,7 +251,7 @@ S -> not S | S or S S -> lambda w : S | lambda w , w : S | lambda w , w , w : S | lambda w , w , w , w : S """.trimIndent().parseCFG() .apply { blocked.add("w") } - .apply { blocked.addAll(terminals.filter { !it.isBracket() }) } +// .apply { blocked.addAll(terminals.filter { !it.isBracket() }) } @Language("py") val invalidPythonStatements = """ diff --git a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/RepairPrompting.kt b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/RepairPrompting.kt index c28436eb..6491d70d 100644 --- a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/RepairPrompting.kt +++ b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/RepairPrompting.kt @@ -16,7 +16,7 @@ val example_code = """ return True else: return False -""".tokenize() +""".defaultTokenizer() // A very minimal CFG that can parse a subset of Python, e.g., the above code diff --git a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/SyntheticSyntaxRepair.kt b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/SyntheticSyntaxRepair.kt index 737eaf48..3c2d1ea3 100644 --- a/src/main/kotlin/edu/mcgill/cstk/experiments/repair/SyntheticSyntaxRepair.kt +++ b/src/main/kotlin/edu/mcgill/cstk/experiments/repair/SyntheticSyntaxRepair.kt @@ -40,7 +40,7 @@ fun main() { .filter(String::isANontrivialStatementWithBalancedBrackets) .filter { it.coarsen().let { it.length in 23..69 && cfg.parse(it) != null } } .map { - val prompt = it.constructPromptByDeletingRandomBrackets() + val prompt = it.constructPromptByDeletingRandomSyntax() val coarsened = prompt.coarsen().also { println("Coarsened: $it") } println("Bin progress: " + strbins.entries.sortedBy { it.key }.joinToString(", "){ "${it.key} (${it.value.size})" }) CodeSnippet( @@ -97,8 +97,13 @@ fun main() { } } -fun String.constructPromptByDeletingRandomBrackets(bracketsToDelete: Int = 1) = - constructPromptByMaskingRandomBrackets(bracketsToDelete).replace(MSK, "") +fun String.constructPromptByDeletingRandomSyntax( + eligibleTokensForDeletion: Set = brackets, + tokensToDelete: Int = 1, + tokenizer: Σᐩ.() -> List<Σᐩ> = Σᐩ::defaultTokenizer +) = + trim().constructPromptByMaskingRandomSyntax(eligibleTokensForDeletion, tokensToDelete, tokenizer) + .replace(Regex("\\s*$MSK\\s*"), " ") fun Int.bin10() = (floor((this + 1).toDouble() / 10.0) * 10).toInt() diff --git a/src/main/kotlin/edu/mcgill/cstk/utils/StringUtils.kt b/src/main/kotlin/edu/mcgill/cstk/utils/StringUtils.kt index 6de88a1b..069e7963 100644 --- a/src/main/kotlin/edu/mcgill/cstk/utils/StringUtils.kt +++ b/src/main/kotlin/edu/mcgill/cstk/utils/StringUtils.kt @@ -2,6 +2,7 @@ package edu.mcgill.cstk.utils import ai.hypergraph.kaliningraph.types.cc import com.github.difflib.text.* +import com.github.difflib.text.DiffRow.Tag.* import edu.mcgill.cstk.disk.* import edu.mcgill.cstk.experiments.probing.embeddingServer import info.debatty.java.stringsimilarity.interfaces.MetricStringDistance @@ -395,20 +396,20 @@ fun String.visibleLen() = // Just print the new line with ASCII colors but no border fun prettyDiffNoFrills(original: String, new: String) = -DiffRowGenerator.create() - .showInlineDiffs(true) - .inlineDiffByWord(true) - .newTag { l -> if(l) "" else "" } - .oldTag { _ -> "" } - .build() - .generateDiffRows(original.split(" "), new.split(" ")).joinToString(" ") { - when (it.tag) { - DiffRow.Tag.INSERT -> it.newLine.replace("", ANSI_GREEN_BACKGROUND).replace("", ANSI_RESET) - DiffRow.Tag.CHANGE -> it.newLine.replace("", ANSI_YELLOW_BACKGROUND).replace("", ANSI_RESET) - DiffRow.Tag.DELETE -> it.newLine//.replace("", "$ANSI_RED_BACKGROUND _" ).replace("", ANSI_RESET) - else -> it.newLine.replace("", "").replace("", "") - } - } + DiffRowGenerator.create() + .showInlineDiffs(true) + .inlineDiffByWord(true) + .newTag { l -> if(l) "" else "" } + .oldTag { _ -> "" } + .build() + .generateDiffRows(original.split(" "), new.split(" ")).joinToString(" ") { + when (it.tag) { + INSERT -> it.newLine.replace("", ANSI_GREEN_BACKGROUND).replace("", ANSI_RESET) + CHANGE -> it.newLine.replace("", ANSI_YELLOW_BACKGROUND).replace("", ANSI_RESET) + DELETE -> "$ANSI_RED_BACKGROUND${List(it.oldLine.length){ " " }.joinToString("")}$ANSI_RESET" + else -> it.newLine.replace("", "").replace("", "") + } + }.replace("<", "<").replace(">", ">") fun prettyDiffHorizontal( left: String, right: String,