Skip to content

Commit

Permalink
resampling optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 23, 2023
1 parent 1a9221d commit 8e60ec9
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package edu.mcgill.cstk.experiments.repair
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.*
import bijectiveRepair
import edu.mcgill.cstk.utils.*
import org.apache.datasketches.frequencies.ErrorType
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlin.spec.grammar.tools.*
import java.io.*
import kotlin.time.*

Expand Down Expand Up @@ -55,7 +53,7 @@ fun evaluateSyntheticRepairBenchmarkOn(dataset: String, postprocess: List<Repair
(commonKotlinKeywords + "ε" - "w")
.also { println("Full deck: $it") }
.sortedBy { P_kotlin[it] }.reversed().take(32)
.also { println("High frequency deck: $it") }.toSet()
.also { println("High frequency deck: $it") }

val edits = 2
// Generate synthetic error dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.kosat.round
import java.io.*
import java.net.*
import java.util.regex.Pattern
import java.util.stream.Collectors
import kotlin.math.*
import kotlin.streams.*
import kotlin.system.measureTimeMillis
Expand Down Expand Up @@ -43,11 +44,18 @@ val P_seq2parse: MarkovChain<Σᐩ> by lazy {
val P_stackoverflow: MarkovChain<Σᐩ> by lazy {
measureTimedValue {
readContents("parse_fixes.json").asStream().parallel()
.map { "\n$it\n".lexToStrTypesAsPython().asSequence().toMarkovChain(3) }
.map { "\n$it\n".lexToStrTypesAsPython().asSequence().toMarkovChain(4) }
.reduce { t, u -> t + u }.get()
}.let { println("Trained Markov chain on ${it.value.counter.total.get()} tokens StackOverflow in ${it.duration.inWholeMilliseconds}ms"); it.value }
}

// Nested list of doubles; the declared type of pythonCooccurence below.
typealias CooccurenceMatrix = List<List<Double>>

// Built lazily on first access: every entry of parse_fixes.json is wrapped in
// newlines and lexed to Python token-type ints (in parallel), then the lexed
// entries are handed to computeCooccurrenceProbs together with the size of
// pythonVocabBindex.
// NOTE(review): presumably the result is indexed by token type on both axes — confirm
// against computeCooccurrenceProbs.
val pythonCooccurence: CooccurenceMatrix by lazy {
  val lexedEntries = readContents("parse_fixes.json").asStream().parallel()
    .map { entry -> "\n$entry\n".lexToIntTypesAsPython() }

  lexedEntries.computeCooccurrenceProbs(pythonVocabBindex.size)
}

val mostCommonTokens by lazy {
P_seq2parse.counter.rawCounts
.getFrequentItems(ErrorType.NO_FALSE_NEGATIVES)
Expand All @@ -73,6 +81,7 @@ fun main() {
// evaluateTidyparseOnSeq2Parse15k()
evaluateTidyparseOnStackoverflow()
// evaluateSeq2ParseOnStackOverflowDataset()
// println(extractErrProbs().joinToString(", ", "listOf(", ")") { "\"${it.first}\" to ${it.second}" })
}

fun evaluateSeq2ParseOnStackOverflowDataset() {
Expand Down Expand Up @@ -151,12 +160,12 @@ fun evaluateSeq2ParseOnStackOverflowDataset() {

class RankStats(val name: String = "Total") {
val upperBound = TIMEOUT_MS / 1000
val step = 20
val time = (1..10).toSet()//setOf(2, 5, 10) + (20..upperBound step 20).toSet()
// Mean Reciprocal Rank
val timedMRR = (step..upperBound step step).associateWith { 0.0 }.toMutableMap()
val timedMRR = time.associateWith { 0.0 }.toMutableMap()
// Precision at K, first int is K, second is the time cutoff
val timedPAK =
(setOf(1, 5, 10, Int.MAX_VALUE) * ((step..upperBound step step).toSet()))
(setOf(1, 5, 10, Int.MAX_VALUE) * (time.toSet()))
.associateWith { 0.0 }.toMutableMap()
var samplesEvaluated = 0

Expand All @@ -172,7 +181,7 @@ class RankStats(val name: String = "Total") {
}

(timedPAK.keys).forEach { (k, sec) ->
repairProposals.filter { it.timeMS in 0..(sec * 1000) }
repairProposals.filter { it.timeMS in 0..(sec * 1_000) }
.let {
val pak = (if(k == Int.MAX_VALUE) it else it.take(k))
.count { it.matches(groundTruthRepair) }.toDouble()
Expand Down Expand Up @@ -213,8 +222,10 @@ class MultiRankStats {
}

fun evaluateTidyparseOnStackoverflow() {
val deck = P_stackoverflow.topK(200).map { it.first }.toSet() + "ε"
println("Deck size: $deck")
// val errDeck = pythonErrProbs.expandByFrequency(10)
val topTokens = P_stackoverflow.topK(200).map { it.first } + "ε" // + errDeck
println("Top tokens: $topTokens")

val multiRankStats = MultiRankStats()

preprocessStackOverflow()
Expand All @@ -241,7 +252,7 @@ fun evaluateTidyparseOnStackoverflow() {

parallelRepair(
prompt = coarseBrokeStr,
fillers = deck,
fillers = topTokens,
maxEdits = 4,
admissibilityFilter = { map { pythonVocabBindex.getUnsafe(it) ?: it.toInt() }.isValidPython() },
// TODO: incorporate parseable segmentations into scoring mechanism to prioritize chokepoint repairs
Expand Down Expand Up @@ -373,7 +384,7 @@ fun evaluateTidyparseOnSeq2Parse15k() {
maxEdits = 4,
admissibilityFilter = { this in seq2parsePythonCFG.language },
// TODO: incorporate parseable segmentations into scoring mechanism to prioritize chokepoint repairs
scoreEdit = { P_seq2parse.score(it) }
scoreEdit = { P_seq2parse.score(it) },
).also {
it.take(20).apply { println("\nTop $size repairs:\n") }.forEach {
println("Δ=${it.scoreStr()} repair (${it.elapsed()}): ${prettyDiffNoFrills(prompt, it.resToStr())}")
Expand All @@ -388,6 +399,21 @@ fun evaluateTidyparseOnSeq2Parse15k() {
}
}

// Hard-coded (token, count) pairs; "ε" is the placeholder extractErrProbs() uses
// for an empty replacement.
// NOTE(review): this looks like pasted output of extractErrProbs() (the commented-out
// println in main() prints that result in exactly this `listOf("tok" to n, ...)` shape)
// — confirm before regenerating. Entry order is kept as-is.
val pythonErrProbs = listOf(
  "'-'" to 5, "'raise'" to 2, "'import'" to 40,
  "'None'" to 3, "')'" to 495, "'else'" to 3,
  "'in'" to 10, "'%'" to 2, "'pass'" to 14,
  "'True'" to 1, "'|'" to 4, "'=='" to 18,
  "'['" to 53, "':'" to 149, "'lambda'" to 5,
  "'...'" to 11, "98" to 108, "'.'" to 21,
  "99" to 105, "NUMBER" to 15, "'*'" to 1,
  "'yield'" to 1, "'is'" to 1, "NEWLINE" to 95,
  "'&'" to 16, "'from'" to 23, "'except'" to 8,
  "NAME" to 200, "'if'" to 6, "'}'" to 135,
  "';'" to 19, "'class'" to 13, "ε" to 2225,
  "'return'" to 9, "'as'" to 4, "'def'" to 17,
  "'/'" to 2, "'+'" to 18, "'~'" to 1,
  "']'" to 152, "'global'" to 1, "','" to 292,
  "'('" to 234, "'for'" to 5, "'='" to 49,
  "'**'" to 2, "'while'" to 1, "'{'" to 60,
  "'!='" to 1, "'del'" to 1, "STRING" to 221,
)

// Taken from seq2parse's Python grammar
val seq2parsePythonCFG: CFG by lazy {
"""
Expand Down Expand Up @@ -602,6 +628,15 @@ Yield_Arg -> From_Keyword Test | Testlist_Endcomma
}
}

// Tallies the replacement tokens seen in the first 3000 items produced by
// preprocessStackOverflow(). For each item, both sides are lexed to Python token-type
// strings and diffed with extractPatch; for every changed edit the `new` side is kept,
// with "ε" substituted when it is empty. Each per-item token list is printed as it is
// produced. Returns (token, number of occurrences) pairs.
fun extractErrProbs(): List<Pair<Σᐩ, Int>> {
  val replacementTokens = preprocessStackOverflow().take(3000).asStream().parallel()
    // presumably the first and third components are the snippet before and after
    // the fix — confirm against preprocessStackOverflow()
    .flatMap { (before, _, after) ->
      val patch = extractPatch(before.lexToStrTypesAsPython(), after.lexToStrTypesAsPython())
      patch.changes()
        .map { idx -> patch[idx].new.ifEmpty { "ε" } }
        .also { tokens -> println(tokens) }
        .stream()
    }

  // Group identical tokens together, then reduce each group to its size.
  return replacementTokens
    .collect(Collectors.groupingBy { token -> token })
    .mapValues { (_, occurrences) -> occurrences.size }
    .toList()
}

fun readContents(
filename: String = "parse_errors.json",
file: File = File(File("").absolutePath +
Expand All @@ -623,5 +658,4 @@ fun readContents(
})
}
}
}

}
3 changes: 1 addition & 2 deletions src/main/kotlin/edu/mcgill/cstk/utils/ParseUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ fun Σᐩ.lexToIntTypesAsPython(
) = lexer.allTokens.map { it.type }

val pythonVocabBindex: Bindex<Σᐩ> =
Python3Lexer(CharStreams.fromString(""))
.vocabulary.let { vocab ->
Python3Lexer(CharStreams.fromString("")).vocabulary.let { vocab ->
(0..vocab.maxTokenType).associateWith { vocab.getDisplayName(it) }
}.let { Bindex(it) }//.also { println(it.toString()) }

Expand Down
10 changes: 7 additions & 3 deletions src/main/kotlin/edu/mcgill/cstk/utils/RepairUtils.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package edu.mcgill.cstk.utils

import ai.hypergraph.kaliningraph.intersperse
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.markovian.mcmc.expandByFrequency
import bijectiveRepair
import com.github.difflib.text.*
import edu.mcgill.cstk.experiments.repair.MSK
Expand Down Expand Up @@ -58,6 +59,7 @@ fun Sequence<Π2A<Σᐩ>>.minimizeFix(tokenize: Σᐩ.() -> List<Σᐩ>) =
// A single edit: a homogeneous pair of Σᐩ values, read through `old` (first)
// and `new` (second) below.
typealias Edit = Π2A<Σᐩ>
// A patch is a list of edits.
typealias Patch = List<Edit>
// First component of the edit.
val Edit.old: Σᐩ get() = first
// Second component of the edit.
// If new is empty, then this is a deletion
val Edit.new: Σᐩ get() = second

// Positions (in ascending order) of the edits whose old and new sides differ.
fun Patch.changes(): List<Int> =
  mapIndexedNotNull { idx, edit -> idx.takeIf { edit.old != edit.new } }
Expand Down Expand Up @@ -111,7 +113,7 @@ fun <T> deltaDebug(elements: List<T>, n: Int = 2, checkValid: (List<T>) -> Boole
// TODO: Generify to accept List<T>
fun parallelRepair(
prompt: Σᐩ,
fillers: Set<Σᐩ>,
fillers: Collection<Σᐩ>,
maxEdits: Int = 2,
admissibilityFilter: List<Σᐩ>.() -> Boolean,
scoreEdit: ((List<Σᐩ>) -> Double)? = null,
Expand All @@ -123,10 +125,12 @@ fun parallelRepair(
// as well as insertion of tokens by the repair algorithm, which only considers substitutions
val promptTokens = prompt.tokenizeByWhitespace().intersperse(maxEdits.coerceAtMost(2))

val deck = fillers + promptTokens.toSet() - "\""

val clock: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
return bijectiveRepair(
promptTokens = promptTokens,
fillers = fillers,
deck = deck,
maxEdits = maxEdits,
takeMoreWhile = { clock.elapsedNow().inWholeMilliseconds < TIMEOUT_MS },
admissibilityFilter = admissibilityFilter,
Expand Down

0 comments on commit 8e60ec9

Please sign in to comment.