Skip to content

Commit

Permalink
respect python keywords during uncoarsening
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Apr 17, 2023
1 parent 82e96a1 commit 651401f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ fun String.coarsenAsPython(): String =
}

fun String.uncoarsenAsPython(prompt: String): String {
val words = prompt.tokenizeAsPython().filter { it.any { it.isLetterOrDigit() }}.toMutableList()
return tokenizeByWhitespace().joinToString(" ") { token ->
// println("Before uncoarsening: $this")
val words = prompt.tokenizeAsPython()
.filter { it !in pythonKeywords && it.any { it.isLetterOrDigit() }}.toMutableList()
val uncoarsed = tokenizeByWhitespace().joinToString(" ") { token ->
when {
token.isBracket() -> token
token.none { it.isLetterOrDigit() } -> token
Expand All @@ -124,6 +126,9 @@ fun String.uncoarsenAsPython(prompt: String): String {
else -> throw Exception("Unknown token: $token")
}
} + words.joinToString(" ")

// println("After uncoarsening: $uncoarsed")
return uncoarsed
}

fun String.dispatchTo(model: Model, grammar: CFG?): List<String> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,43 @@ import ai.hypergraph.kaliningraph.parsing.repair
import ai.hypergraph.kaliningraph.sampling.pow
import ai.hypergraph.kaliningraph.sat.*
import org.intellij.lang.annotations.Language
import kotlin.time.*

/*
./gradlew pythonStatementRepair
*/

@OptIn(ExperimentalTime::class)
fun main() {
MAX_SAMPLE = 3
MAX_REPAIR = 2
// Synthetic error correction
// validPythonStatements.lines().filter { it.isNotBlank() }.forEach {
// val original = it.tokenizeAsPython().joinToString(" ")
// val prompt = original.constructPromptByDeletingRandomBrackets(1)
// println("Original: $original\nCorrupted: ${prettyDiffNoFrills(original, prompt)}")
// // Organic repairs
// repairPythonStatement(prompt)
// .also { println("Original string was ${if(original in it) "" else "NOT " }contained in repairs!") }
// println("\n")
// }
validPythonStatements.lines().filter { it.isNotBlank() }.forEach {
val original = it.tokenizeAsPython().joinToString(" ")
val prompt = original.constructPromptByDeletingRandomBrackets(1)
println("Original: $original\nCorrupted: ${prettyDiffNoFrills(original, prompt)}")
// Organic repairs
val clock = TimeSource.Monotonic.markNow()

repair(
prompt = prompt,
cfg = pythonStatementCFG,
coarsen = String::coarsenAsPython,
uncoarsen = String::uncoarsenAsPython,
// synthesizer = { a -> println(a.joinToString(" ")); synthesize(a) }, // SAT solver
synthesizer = { a -> a
// .also { println("Solving: ${it.joinToString(" ")}") }
.solve(this,
takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS }
) }, // Enumerative search
// synthesizer = { a -> a.parallelSolve(this).asSequence() }, // Parallel search
diagnostic = { println("Δ=${levenshtein(prompt, it) - 1} repair: ${prettyDiffNoFrills(prompt, it)}") },
// filter = { isValidPython() } ,
)
.also { println("Original string was ${if(original in it) "" else "NOT " }contained in repairs!") }

println("\n")
}

// val (vp, ip) = coarsened.lines().mapIndexed { i, it -> i to it }
// .partition { (i, it ) -> it.isValidPython() }
Expand All @@ -42,12 +62,16 @@ fun main() {
// println("Our valid, their invalid: ${ov - vp}")
// println("Their valid, our invalid: ${vp - ov}")

}

fun organicRepair() {
// Organic error correction
invalidPythonStatements.lines().shuffled().filter { it.isNotBlank() }
// .parallelStream()
.filter { !it.matches(pythonStatementCFG) }
// .filter { !it.tokenizeAsPython().joinToString(" ").matches(pythonStatementCFG) }
.forEach {
// println("${it.hasBalancedBrackets()}::${it.isValidPython()}\t\t" + it)
it.isValidPython { println("Invalid Python: $it") }
println("${it.hasBalancedBrackets()}::${it.isValidPython()}\t\t" + it)
val prompt = it.tokenizeAsPython()
.joinToString(" ") // No need to corrupt since these are already broken
println("Original: $prompt")
Expand Down Expand Up @@ -209,7 +233,7 @@ S -> not S | S or S
S -> lambda w : S | lambda w , w : S | lambda w , w , w : S | lambda w , w , w , w : S
""".trimIndent().parseCFG()
.apply { blocked.add("w") }
// .apply { blocked.addAll(terminals.filter { !it.isBracket() }) }
.apply { blocked.addAll(terminals.filter { !it.isBracket() }) }

@Language("py")
val invalidPythonStatements = """
Expand Down Expand Up @@ -252,7 +276,7 @@ val invalidPythonStatements = """

@Language("py")
val validPythonStatements = """
numValues = sum([len(i.cache.keys()) for i in _memoizedFunctions]),
numValues = sum([len(i.cache.keys()) for i in _memoizedFunctions])
expectedGroupedC = [(i, [(i, i * 3 + j) for j in range(3)]) for i in range(5)]
res2 = array(map(lambda x: int(x[1]), tmp))
val = np.array([(1698 - 10.79 * xi) *(np.abs(np.cos(- 0.11 +(xi + 1) / 6)) - 1) + 484 for xi in x])
Expand Down

0 comments on commit 651401f

Please sign in to comment.