Skip to content

Commit

Permalink
extract common functionality across experiments into utils
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 12, 2023
1 parent f887853 commit 03fc148
Show file tree
Hide file tree
Showing 14 changed files with 450 additions and 478 deletions.
2 changes: 1 addition & 1 deletion galoisenne
2 changes: 1 addition & 1 deletion gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-rc-2-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
zipStorePath=wrapper/dists
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,6 @@ val P_kotlin: MarkovChain<Σᐩ> by lazy {
}.fold(MarkovChain(memory = memory)) { a, b -> a + b }
}

// Output stream that rejects all lines starting with "Parser error:" or "Lex error:"
class FilteredOutputStream(out: OutputStream) : PrintStream(out) {
override fun println(x: String?) {
if (x == null) return
if (x.toString().let {
// it.startsWith("logging: ") ||
it.startsWith("Parser error:") ||
it.startsWith("Lexer error:")
}) return
super.println(x)
}
}

/*
./gradlew kotlinStatementRepair
*/
Expand Down Expand Up @@ -88,7 +75,7 @@ fun evaluateSyntheticRepairBenchmarkOn(dataset: String, postprocess: List<Repair
.forEach { (groundTruth, prompt) ->
println("Original: $groundTruth\nCorrupted: ${prettyDiffNoFrills(groundTruth, prompt)}")
val startTime = System.currentTimeMillis()
parallelRepairKotlinStatement(prompt, deck, edits + 1, scoreEdit).postprocess()
parallelRepair(prompt, deck, edits + 1, { isSyntacticallyValidKotlin() }, scoreEdit).postprocess()
.also {
// repairKotlinStatement(prompt).also {
val gtSeq = groundTruth.tokenizeByWhitespace().joinToString(" ")
Expand Down Expand Up @@ -138,88 +125,6 @@ fun fetchKotlinExamples() =
.map { it.trim() }.distinct()
// .take(10)

fun Σᐩ.coarsenAsKotlin(lex: Boolean = true): Σᐩ =
(if(lex) lexAsKotlin() else tokenizeByWhitespace()).joinToString(" ") {
when {
it.isBracket() -> it
it.none { it.isLetterOrDigit() } -> it
it in officialKotlinKeywords -> it
it.first().isUpperCase() -> "W"
else -> "w"
}
}

fun Σᐩ.uncoarsenAsKotlin(prompt: Σᐩ): Σᐩ {
val words = prompt.tokenizeByWhitespace()
.filter { it !in officialKotlinKeywords && it.any { it.isLetterOrDigit() } }.toMutableList()
val uncoarsed = tokenizeByWhitespace().joinToString(" ") { token ->
when {
token.isBracket() -> token
token.none { it.isLetterOrDigit() } -> token
token.equals("w", ignoreCase = true) -> words.removeFirst()
token in officialKotlinKeywords -> token
else -> throw Exception("Unknown token: $token")
}
} + words.joinToString(" ", " ")

// println("After uncoarsening: $uncoarsed")
return uncoarsed
}

fun parallelRepairKotlinStatement(
prompt: Σᐩ,
fillers: Set<Σᐩ>,
maxEdits: Int = 2,
scoreEdit: ((Σᐩ) -> Double)? = null,
): List<Repair> {
var bestRepair = Double.MAX_VALUE
val delim = List(prompt.length) { "-" }.joinToString("")
println("$delim\nBest repairs so far:\n$delim")
// We intersperse the prompt with empty strings to enable the repair of the first and last token
// as well as insertion of tokens by the repair algorithm, which only considers substitutions
val promptTokens = prompt.tokenizeByWhitespace().intersperse(maxEdits.coerceAtMost(2))

val clock: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
return bijectiveRepair(
promptTokens = promptTokens,
fillers = fillers,
maxEdits = maxEdits,
takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS },
admissibilityFilter = { isSyntacticallyValidKotlin() },
scoreEdit = scoreEdit ?: { 0.0 },
diagnostic =
if (scoreEdit != null) {
{
val score = scoreEdit(it.result)
if (score < bestRepair) {
println("Δ=${it.scoreStr()} repair (${it.elapsed()}): ${prettyDiffNoFrills(prompt, it.result)}")
// println("(LATEX) Δ=$score repair: ${latexDiffSingleLOC(prompt, it)}")
bestRepair = score
}
}
}
else {
{
val levDiff = it.edit.size.toDouble()
if (levDiff < bestRepair) {
println("Δ=$levDiff repair (${it.elapsed()}): ${prettyDiffNoFrills(prompt, it.result)}")
// println("(LATEX) Δ=$levDiff repair: ${latexDiffSingleLOC(prompt, it)}")
bestRepair = levDiff
}
}
}
).toList()
// .parallelStream().map {
// it.editSignatureEquivalenceClass(
// tokens = (fillers + promptTokens).shuffled().toSet() - "\"",
// filter = { it.isSyntacticallyValidKotlin() },
// score = { scoreEdit?.invoke(it) ?: 0.0 }
// ).also { it.time = clock.elapsedNow().inWholeMilliseconds }
// }.toList()
.distinctBy { it.result }
.sortedWith(compareBy({ it.edit.size }, { it.score }))
}

fun repairKotlinStatement(
prompt: Σᐩ,
clock: TimeMark = TimeSource.Monotonic.markNow()
Expand All @@ -245,10 +150,6 @@ private fun bruteForceKotlinRepair(clock: TimeMark): CFG.(List<Σᐩ>) -> Sequen
} catch (e: Exception) { e.printStackTrace(); emptySequence()}
}

fun Σᐩ.isSyntacticallyValidKotlin(): Boolean =
try { parseKotlinCode(tokenizeKotlinCode(this)).let { true } }
catch (_: Throwable) { false }

@Language("kt")
val originalKotlinLines = """
val common = results.map { it.uri }.intersect(results.map { it.uri }.toSet())
Expand Down Expand Up @@ -1087,16 +988,6 @@ val permissiveKotlinCFG = """
val ignoredKeywords =
setOf("import", "package", "//", "/*", "\"", "\'", "\\`", "data", "_")

val officialKotlinKeywords = setOf(
"as", "as?", "break", "class", "continue", "do", "else", "false", "for", "fun", "if", "in",
"!in", "interface", "is", "!is", "null", "object", "package", "return", "super", "this",
"throw", "true", "try", "typealias", "val", "var", "when", "while", "by", "catch", "constructor",
"delegate", "dynamic", "field", "file", "finally", "get", "import", "init", "param", "property",
"receiver", "set", "setparam", "where", "actual", "abstract", "annotation", "companion",
"const", "crossinline", "data", "enum", "expect", "external", "final", "infix", "inline",
"inner", "internal", "lateinit", "noinline", "open", "operator", "out", "override", "private",
"protected", "public", "reified", "sealed", "suspend", "tailrec", "vararg", "field", "it"
)

val allBuiltinTypes = setOf(
"Any", "Boolean", "Byte", "Char", "Double", "Float", "Int", "Long", "Nothing", "Short", "String",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,52 +68,4 @@ fun main() {
fun updateScore(scores: Scores, model: Model, groundTruth: () -> Boolean) =
scores[model]!!.let { (n, d) -> // numerator / denominator
if (groundTruth()) (n + 1) to (d + 1) else n to (d + 1)
}

fun Σᐩ.coarsen(): Σᐩ =
defaultTokenizer().joinToString(" ") {
when {
it.isBracket() -> it
it == MSK -> "_"
else -> "w"
}
}

fun Σᐩ.isWhitespace() = trim().isEmpty()

fun Σᐩ.uncoarsen(prompt: Σᐩ): Σᐩ {
val words = prompt.defaultTokenizer().filter { it !in brackets }.toMutableList()
return tokenizeByWhitespace().joinToString("") { token ->
when {
token.isBracket() -> token
words.isEmpty() -> { //System.err.println("IDK what happened:\nSynthesized: $this");
"" }
token == "w" -> words.removeAt(0)
else -> throw Exception("Unknown token: $token")
}
} + words.joinToString("")
}

fun Σᐩ.isBracket() = length == 1 && this in brackets

fun Σᐩ.constructPromptByMaskingRandomSyntax(
eligibleTokensToMask: Set<Σᐩ> = brackets,
numTokensToMask: Int = 1,
tokenize: Σᐩ.() -> List<Σᐩ> = Σᐩ::defaultTokenizer
): Σᐩ =
tokenize().toMutableList().let { codeTokens ->
// println("Code tokens: $codeTokens")
// println("Eligible tokens to mask: $eligibleTokensToMask")
codeTokens.indices.filter { codeTokens[it].trim() in eligibleTokensToMask }
// .also { println("Indices of eligible tokens to mask: $it") }
.shuffled().take(numTokensToMask)
.forEach { codeTokens[it] = MSK }
codeTokens
}.joinToString(" ")

val brackets = "()[]{}".map { "$it" }.toSet()
fun Σᐩ.defaultTokenizer(): List<Σᐩ> =
split(Regex("[\\(\\)\\[\\]{}]|___".let { "((?<=($it))|(?=($it)))" }))

fun Σᐩ.tokenizeGranular() =
StringUtils.splitByCharacterTypeCamelCase(this).toList()
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import kotlin.time.*

fun main() {
val json = File("bifi/data/orig_bad_code/orig.bad.json").readText()
val parsed = Klaxon().parse<Map<String, Map<String, Any>>>(json)
val parsed = Klaxon().parse<Map<Σᐩ, Map<Σᐩ, Any>>>(json)

val strbins: MutableMap<Int, MutableList<CodeSnippet>> = mutableMapOf()

Expand All @@ -35,7 +35,7 @@ fun main() {

MAX_TOKENS = 100
MAX_SAMPLE = 200
var pfxs = mutableListOf<String>()
var pfxs = mutableListOf<Σᐩ>()
for (i in listOf(10000, 30000, 60000)) {
TIMEOUT_MS = i.also { println("REEVALUATING TIMEOUT: $it ms") }
val lenbins = ConcurrentHashMap<Int, Π3A<AtomicInteger>>()
Expand All @@ -53,7 +53,7 @@ fun main() {
val t = TimeSource.Monotonic.markNow()
var totalValidSamples = 0
val repair = repair(code, cfg,
String::coarsen, String::uncoarsen,
Σᐩ::coarsen, Σᐩ::uncoarsen,
// synthesizer = { a -> synthesize(a) },
synthesizer = { a -> a.solve(this) },
score = { defaultModel.score(it) },
Expand All @@ -75,14 +75,7 @@ fun main() {
}
}

data class CodeSnippet(
val originalCode: String,
val coarsened: String,
val errorMsg: String,
val groundTruth: String? = null
) {
val tokens = coarsened.split(" ")
}


const val NO_REPAIR = "NO_REPAIR_PROPOSAL!"

Expand All @@ -93,49 +86,14 @@ val cfg: CFG =
"""S -> w | ( ) | [ ] | { } | ( S ) | [ S ] | { S } | S S"""
.parseCFG().apply { blocked.addAll(setOf("w")) }

val pythonKeywords = setOf(
"False", "None", "True", "and", "as", "assert",
"async", "await", "break", "class", "continue",
"def", "del", "elif", "else", "except", "finally",
"for", "from", "global", "if", "import", "in", "is",
"lambda", "nonlocal", "not", "or", "pass", "raise",
"return", "try", "while", "with", "yield"
)

val pythonOperators = setOf(
"=", "==", "+", "-", "*", "/", "%", "**", "//",
">>", "<<", "&", "|", "^", "~", "<", ">", "<=",
">=", "!=", "not", "in", "is", "and", "or"
)

fun String.coarsenAsPython(): String =
tokenizeAsPython().joinToString(" ") {
when {
it.isBracket() -> it
it.none { it.isLetterOrDigit() } -> it
it in pythonKeywords -> it
else -> "w"
}
}

fun String.uncoarsenAsPython(prompt: String): String {
val words = prompt.tokenizeByWhitespace()
.filter { it !in pythonKeywords && it.any { it.isLetterOrDigit() }}.toMutableList()
val uncoarsed = tokenizeByWhitespace().joinToString(" ") { token ->
when {
token.isBracket() -> token
token.none { it.isLetterOrDigit() } -> token
token == "w" -> words.removeFirst()
token in pythonKeywords -> token
else -> throw Exception("Unknown token: $token")
}
} + words.joinToString(" ", " ")

// println("After uncoarsening: $uncoarsed")
return uncoarsed
}

fun String.dispatchTo(model: Model, grammar: CFG?): List<String> =
fun Σᐩ.dispatchTo(model: Model, grammar: CFG?): List<Σᐩ> =
when (model) {
tidyparse -> repair(this, grammar!!,
String::coarsen, String::uncoarsen,
Expand All @@ -146,7 +104,7 @@ fun String.dispatchTo(model: Model, grammar: CFG?): List<String> =
else -> { if (MSK in this) listOf(model.complete(replace(MSK, model.mask))) else emptyList() }
}

fun String.parsePythonOutput(): String =
fun Σᐩ.parsePythonOutput(): Σᐩ =
ProcessBuilder("python", "parser.py", this)
.start().also { it.waitFor() }.inputStream
.bufferedReader().readText().lines().first()
Expand Down
Loading

0 comments on commit 03fc148

Please sign in to comment.