Skip to content

Commit

Permalink
simulate arbitrary syntactic corruption
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Apr 17, 2023
1 parent 651401f commit 35e4ab1
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 56 deletions.
2 changes: 1 addition & 1 deletion galoisenne
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package edu.mcgill.cstk.experiments.repair

import ai.hypergraph.kaliningraph.parsing.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.parsing.*
import edu.mcgill.cstk.disk.*
import edu.mcgill.cstk.utils.*
import org.apache.commons.lang3.StringUtils
Expand Down Expand Up @@ -53,8 +53,8 @@ fun main() {
DATA_DIR.also { println("Evaluating syntax repair using $models on $it...") }
.allFilesRecursively().allMethods()
.map { it.first.lineSequence() }.flatten()
.filter(String::isANontrivialStatementWithBalancedBrackets)
.map { it to it.constructPromptByMaskingRandomBrackets() }
.filter(Σᐩ::isANontrivialStatementWithBalancedBrackets)
.map { it to it.constructPromptByMaskingRandomSyntax() }
.runningFold(modelScores) { scores, (groundTruth, prompt) ->
models.associateWith { model ->
val repairs = prompt.dispatchTo(model, cfg)
Expand All @@ -70,19 +70,19 @@ fun updateScore(scores: Scores, model: Model, groundTruth: () -> Boolean) =
if (groundTruth()) (n + 1) to (d + 1) else n to (d + 1)
}

fun String.coarsen(): String =
tokenize().joinToString(" ") {
/**
 * Abstracts this string into a space-separated skeleton.
 *
 * Each token produced by [defaultTokenizer] is mapped as follows: bracket
 * tokens are kept verbatim, the mask token [MSK] becomes "_", and every
 * other token collapses to the placeholder "w".
 */
fun Σᐩ.coarsen(): Σᐩ {
  val skeleton = defaultTokenizer().map { token ->
    if (token.isBracket()) token
    else if (token == MSK) "_"
    else "w"
  }
  return skeleton.joinToString(" ")
}

fun String.isWhitespace() = trim().isEmpty()
/** True when this string is empty or consists solely of whitespace characters. */
fun Σᐩ.isWhitespace() = isBlank()

fun String.uncoarsen(prompt: String): String {
val words = prompt.tokenize().filter { it !in brackets }.toMutableList()
fun Σᐩ.uncoarsen(prompt: Σᐩ): Σᐩ {
val words = prompt.defaultTokenizer().filter { it !in brackets }.toMutableList()
return tokenizeByWhitespace().joinToString("") { token ->
when {
token.isBracket() -> token
Expand All @@ -94,19 +94,26 @@ fun String.uncoarsen(prompt: String): String {
} + words.joinToString("")
}

fun String.isBracket() = length == 1 && this in brackets

fun String.constructPromptByMaskingRandomBrackets(bracketsToMask: Int = 1): String =
tokenize().toMutableList().let { tokens ->
tokens.indices.filter { tokens[it] in brackets }
.shuffled().take(bracketsToMask)
.forEach { tokens[it] = MSK }
tokens
}.joinToString("")

const val brackets = "()[]{}"
fun String.tokenize(): List<String> =
/** True when this string is a member of [brackets] and exactly one character long. */
fun Σᐩ.isBracket() = this in brackets && length == 1

/**
 * Produces a corrupted prompt by overwriting randomly chosen tokens with [MSK].
 *
 * The receiver is split with [tokenize]; every token whose trimmed text is a
 * member of [eligibleTokensToMask] is a candidate. At most [numTokensToMask]
 * candidates are picked from a shuffled list and replaced, after which all
 * tokens are joined back together with a single space between them.
 *
 * @param eligibleTokensToMask token texts that may be masked (brackets by default)
 * @param numTokensToMask upper bound on how many tokens get masked
 * @param tokenize function used to split the receiver into tokens
 * @return the space-joined tokens with the chosen positions replaced by [MSK]
 */
fun Σᐩ.constructPromptByMaskingRandomSyntax(
  eligibleTokensToMask: Set<Σᐩ> = brackets,
  numTokensToMask: Int = 1,
  tokenize: Σᐩ.() -> List<Σᐩ> = Σᐩ::defaultTokenizer
): Σᐩ {
  val codeTokens = tokenize().toMutableList()
  val candidates = codeTokens.indices.filter { codeTokens[it].trim() in eligibleTokensToMask }
  // Shuffling before taking makes the choice of masked positions random on every call.
  for (index in candidates.shuffled().take(numTokensToMask)) codeTokens[index] = MSK
  return codeTokens.joinToString(" ")
}

/** The six bracket characters, each represented as a one-character string. */
val brackets = setOf("(", ")", "[", "]", "{", "}")
/**
 * Holds the compiled split pattern for [defaultTokenizer] so it is built once
 * rather than on every call. A holder object is used instead of a top-level
 * property so that initialization happens on first use, regardless of where
 * this declaration sits relative to other top-level initializers in the file.
 */
private object DefaultTokenBoundary {
  // Matches the zero-width positions directly after or directly before one of
  // ( ) [ ] { } or the literal "___".
  val regex = Regex("[\\(\\)\\[\\]{}]|___".let { "((?<=($it))|(?=($it)))" })
}

/**
 * Splits this string at every position that immediately follows or precedes
 * one of ( ) [ ] { } or the literal "___", separating those delimiters from
 * the text around them while leaving the text between delimiters intact.
 *
 * @return the pieces in their original order
 */
fun Σᐩ.defaultTokenizer(): List<Σᐩ> = split(DefaultTokenBoundary.regex)

fun String.tokenizeGranular() =
/**
 * Splits this string with Apache Commons Lang's
 * [StringUtils.splitByCharacterTypeCamelCase] and returns the resulting
 * pieces as a list.
 */
fun Σᐩ.tokenizeGranular() =
StringUtils.splitByCharacterTypeCamelCase(this).toList()
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ val cfg =
"""S -> w | ( ) | [ ] | { } | ( S ) | [ S ] | { S } | S S"""
.parseCFG().apply { blocked.addAll(setOf("w")) }

val pythonKeywords = listOf(
val pythonKeywords = setOf(
"False", "None", "True", "and", "as", "assert",
"async", "await", "break", "class", "continue",
"def", "del", "elif", "else", "except", "finally",
Expand All @@ -103,6 +103,12 @@ val pythonKeywords = listOf(
"return", "try", "while", "with", "yield"
)

/**
 * Tokens treated as Python operators, including the word operators
 * `not`, `in`, `is`, `and` and `or`.
 */
val pythonOperators = setOf(
  // assignment and equality
  "=", "==",
  // arithmetic
  "+", "-", "*", "/", "%", "**", "//",
  // shifts and bitwise
  ">>", "<<", "&", "|", "^", "~",
  // ordering and inequality
  "<", ">", "<=", ">=", "!=",
  // word operators
  "not", "in", "is", "and", "or"
)

fun String.coarsenAsPython(): String =
tokenizeAsPython().joinToString(" ") {
when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.parsing.*
import edu.mcgill.cstk.utils.*
import ai.hypergraph.kaliningraph.parsing.repair
import ai.hypergraph.kaliningraph.sampling.pow
import ai.hypergraph.kaliningraph.sat.*
import org.intellij.lang.annotations.Language
import kotlin.time.*
Expand All @@ -15,12 +14,16 @@ import kotlin.time.*

@OptIn(ExperimentalTime::class)
fun main() {
MAX_SAMPLE = 3
MAX_SAMPLE = 20
MAX_REPAIR = 2
// Synthetic error correction
validPythonStatements.lines().filter { it.isNotBlank() }.forEach {
validPythonStatements.lines().filter { it.isNotBlank() }.asSequence()
.map {
val original = it.tokenizeAsPython().joinToString(" ")
val prompt = original.constructPromptByDeletingRandomBrackets(1)
val prompt = original.constructPromptByDeletingRandomSyntax(pythonKeywords + pythonOperators + brackets, tokenizer = Σᐩ::tokenizeAsPython)
original to prompt
}
.forEach { (original, prompt) ->
println("Original: $original\nCorrupted: ${prettyDiffNoFrills(original, prompt)}")
// Organic repairs
val clock = TimeSource.Monotonic.markNow()
Expand All @@ -30,17 +33,16 @@ fun main() {
cfg = pythonStatementCFG,
coarsen = String::coarsenAsPython,
uncoarsen = String::uncoarsenAsPython,
// synthesizer = { a -> println(a.joinToString(" ")); synthesize(a) }, // SAT solver
synthesizer = { a -> a
// .also { println("Solving: ${it.joinToString(" ")}") }
.solve(this,
takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS }
) }, // Enumerative search
// synthesizer = { a -> a.parallelSolve(this).asSequence() }, // Parallel search
//synthesizer = { a -> println(a.joinToString(" ")); synthesize(a) }, // SAT solver
synthesizer = setRepair(clock), // Enumerative search
//synthesizer = { a -> a.parallelSolve(this).asSequence() }, // Parallel search
diagnostic = { println("Δ=${levenshtein(prompt, it) - 1} repair: ${prettyDiffNoFrills(prompt, it)}") },
// filter = { isValidPython() } ,
filter = { isValidPython() },
)
.also { println("Original string was ${if(original in it) "" else "NOT " }contained in repairs!") }
.also {
val contained = original in it
println("Original string was ${if(contained) "#${it.indexOf(original)}" else "NOT" } in repairs!")
}

println("\n")
}
Expand All @@ -61,9 +63,25 @@ fun main() {
// println("Both invalid: ${oi.intersect(ip)}")
// println("Our valid, their invalid: ${ov - vp}")
// println("Their valid, our invalid: ${vp - ov}")

}

/**
 * Creates a synthesizer that forwards the token list to `asCJL.synthesize`,
 * passing a `takeMoreWhile` predicate that is true only while fewer than
 * [TIMEOUT_MS] milliseconds have elapsed since [clock] was marked.
 *
 * @param clock time mark against which the elapsed time is measured
 */
@OptIn(ExperimentalTime::class)
private fun satRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> =
  { tokens: List<Σᐩ> ->
    asCJL.synthesize(
      tokens,
      takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS }
    )
  }

/**
 * Creates a synthesizer that calls `solve` on the token list with the
 * receiving [CFG], passing a `takeMoreWhile` predicate that is true only while
 * fewer than [TIMEOUT_MS] milliseconds have elapsed since [clock] was marked.
 *
 * @param clock time mark against which the elapsed time is measured
 */
@OptIn(ExperimentalTime::class)
private fun setRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> =
  { tokens: List<Σᐩ> ->
    tokens.solve(
      this,
      takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS }
    )
  }

/**
 * Creates a synthesizer that calls `parallelSolve` on the token list with the
 * receiving [CFG] and exposes the result as a sequence.
 *
 * NOTE(review): [clock] is accepted but never read here, so unlike `satRepair`
 * and `setRepair` no [TIMEOUT_MS] budget is applied — confirm this is intended.
 *
 * @param clock currently unused
 */
@OptIn(ExperimentalTime::class)
private fun parallelSetRepair(clock: TimeSource.Monotonic.ValueTimeMark): CFG.(List<Σᐩ>) -> Sequence<Σᐩ> =
  { tokens: List<Σᐩ> ->
    tokens.parallelSolve(this).asSequence()
  }

fun organicRepair() {
// Organic error correction
invalidPythonStatements.lines().shuffled().filter { it.isNotBlank() }
Expand Down Expand Up @@ -233,7 +251,7 @@ S -> not S | S or S
S -> lambda w : S | lambda w , w : S | lambda w , w , w : S | lambda w , w , w , w : S
""".trimIndent().parseCFG()
.apply { blocked.add("w") }
.apply { blocked.addAll(terminals.filter { !it.isBracket() }) }
// .apply { blocked.addAll(terminals.filter { !it.isBracket() }) }

@Language("py")
val invalidPythonStatements = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ val example_code = """
return True
else:
return False
""".tokenize()
""".defaultTokenizer()

// A very minimal CFG that can parse a subset of Python, e.g., the above code

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fun main() {
.filter(String::isANontrivialStatementWithBalancedBrackets)
.filter { it.coarsen().let { it.length in 23..69 && cfg.parse(it) != null } }
.map {
val prompt = it.constructPromptByDeletingRandomBrackets()
val prompt = it.constructPromptByDeletingRandomSyntax()
val coarsened = prompt.coarsen().also { println("Coarsened: $it") }
println("Bin progress: " + strbins.entries.sortedBy { it.key }.joinToString(", "){ "${it.key} (${it.value.size})" })
CodeSnippet(
Expand Down Expand Up @@ -97,8 +97,13 @@ fun main() {
}
}

fun String.constructPromptByDeletingRandomBrackets(bracketsToDelete: Int = 1) =
constructPromptByMaskingRandomBrackets(bracketsToDelete).replace(MSK, "")
/**
 * Corrupts this string by deleting randomly chosen tokens.
 *
 * The trimmed receiver is first masked via [constructPromptByMaskingRandomSyntax];
 * every [MSK] occurrence, together with any whitespace around it, is then
 * replaced by a single space. A deletion at either end of the string therefore
 * leaves a leading or trailing space in the result.
 *
 * @param eligibleTokensForDeletion token texts that may be deleted (brackets by default)
 * @param tokensToDelete upper bound on how many tokens are deleted
 * @param tokenizer function used to split the string into tokens
 * @return the corrupted string
 */
fun String.constructPromptByDeletingRandomSyntax(
  eligibleTokensForDeletion: Set<String> = brackets,
  tokensToDelete: Int = 1,
  tokenizer: Σᐩ.() -> List<Σᐩ> = Σᐩ::defaultTokenizer
) =
  trim().constructPromptByMaskingRandomSyntax(eligibleTokensForDeletion, tokensToDelete, tokenizer)
    // MSK is a literal token, not a pattern: quote it so that any regex
    // metacharacters it may contain are matched verbatim.
    .replace(Regex("\\s*${Regex.escape(MSK)}\\s*"), " ")

/**
 * Rounds `this + 1` down to a multiple of ten, e.g. 8 -> 0, 9 -> 10, 19 -> 20.
 */
fun Int.bin10(): Int {
  val tens = floor((this + 1) / 10.0)
  return (tens * 10).toInt()
}

Expand Down
29 changes: 15 additions & 14 deletions src/main/kotlin/edu/mcgill/cstk/utils/StringUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package edu.mcgill.cstk.utils

import ai.hypergraph.kaliningraph.types.cc
import com.github.difflib.text.*
import com.github.difflib.text.DiffRow.Tag.*
import edu.mcgill.cstk.disk.*
import edu.mcgill.cstk.experiments.probing.embeddingServer
import info.debatty.java.stringsimilarity.interfaces.MetricStringDistance
Expand Down Expand Up @@ -395,20 +396,20 @@ fun String.visibleLen() =

// Just print the new line with ASCII colors but no border
fun prettyDiffNoFrills(original: String, new: String) =
DiffRowGenerator.create()
.showInlineDiffs(true)
.inlineDiffByWord(true)
.newTag { l -> if(l) "<begin>" else "<end>" }
.oldTag { _ -> "" }
.build()
.generateDiffRows(original.split(" "), new.split(" ")).joinToString(" ") {
when (it.tag) {
DiffRow.Tag.INSERT -> it.newLine.replace("<begin>", ANSI_GREEN_BACKGROUND).replace("<end>", ANSI_RESET)
DiffRow.Tag.CHANGE -> it.newLine.replace("<begin>", ANSI_YELLOW_BACKGROUND).replace("<end>", ANSI_RESET)
DiffRow.Tag.DELETE -> it.newLine//.replace("<begin>", "$ANSI_RED_BACKGROUND _" ).replace("<end>", ANSI_RESET)
else -> it.newLine.replace("<begin>", "").replace("<end>", "")
}
}
DiffRowGenerator.create()
.showInlineDiffs(true)
.inlineDiffByWord(true)
.newTag { l -> if(l) "<begin>" else "<end>" }
.oldTag { _ -> "" }
.build()
.generateDiffRows(original.split(" "), new.split(" ")).joinToString(" ") {
when (it.tag) {
INSERT -> it.newLine.replace("<begin>", ANSI_GREEN_BACKGROUND).replace("<end>", ANSI_RESET)
CHANGE -> it.newLine.replace("<begin>", ANSI_YELLOW_BACKGROUND).replace("<end>", ANSI_RESET)
DELETE -> "$ANSI_RED_BACKGROUND${List(it.oldLine.length){ " " }.joinToString("")}$ANSI_RESET"
else -> it.newLine.replace("<begin>", "").replace("<end>", "")
}
}.replace("&lt;", "<").replace("&gt;", ">")

fun prettyDiffHorizontal(
left: String, right: String,
Expand Down

0 comments on commit 35e4ab1

Please sign in to comment.