Skip to content

Commit

Permalink
compare seq2parse patch sizes across minimum edits
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 18, 2023
1 parent 843f40f commit 91b89d0
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ai.hypergraph.markovian.mcmc.*
import ai.hypergraph.kaliningraph.types.*
import com.beust.klaxon.*
import edu.mcgill.cstk.utils.*
import org.antlr.v4.runtime.CommonToken
import org.apache.datasketches.frequencies.ErrorType
import org.kosat.round
import java.io.*
Expand Down Expand Up @@ -69,10 +70,34 @@ fun main() {
}

fun evaluateSeq2ParseOnStackOverflowDataset() {
var syntaxPrecision = 0.0
var humanFixPrecision = 0.0
var chrMatchPrecision = 0.0
var latency = 0.0
/**
 * Running totals describing how well Seq2Parse repairs performed over the samples
 * recorded so far. Each call to [update] adds one sample to the totals and prints
 * the refreshed averages to standard output.
 */
class Seq2ParsePrecision {
  var syntaxPrecision = 0.0
  var humanFixPrecision = 0.0
  var chrMatchPrecision = 0.0
  var samples = 0
  var latency = 0.0

  /**
   * Records a single sample and prints the averages over every sample seen by this instance.
   *
   * @param seq2parseWasParseable counted as a syntactic hit when true
   * @param seq2parseFixTks compared with [minFixTks]; equality counts as a HumanEval hit
   * @param minFixTks the token list [seq2parseFixTks] is compared against
   * @param humanFix compared with [seq2parseFix]; equality counts as a CharMatch hit
   * @param seq2parseFix the string compared against [humanFix]
   * @param latency added to the accumulated latency; printed with an "ms" suffix
   */
  fun update(
    seq2parseWasParseable: Boolean,
    seq2parseFixTks: List<Σᐩ>, minFixTks: List<Σᐩ>,
    humanFix: String, seq2parseFix: String, latency: Int
  ) {
    samples++

    syntaxPrecision += seq2parseWasParseable.asScore()
    println("Average Syntactic precision@1: ${mean(syntaxPrecision)}")

    humanFixPrecision += (seq2parseFixTks == minFixTks).asScore()
    println("Average HumanEval precision@1: ${mean(humanFixPrecision)}")

    chrMatchPrecision += (seq2parseFix == humanFix).asScore()
    println("Average CharMatch precision@1: ${mean(chrMatchPrecision)}")

    // The parameter shadows the property, so the accumulator needs an explicit `this`.
    this.latency += latency
    println("Average latency to produce a single sample: ${mean(this.latency)}ms\n")
  }

  /** 1.0 for a hit, 0.0 for a miss. */
  private fun Boolean.asScore() = if (this) 1.0 else 0.0

  /** Average of [total] over the samples recorded so far, rounded to three places. */
  private fun mean(total: Double) = (total / samples).round(3)
}

val totalPrecision = Seq2ParsePrecision()
val editPrecision = (1..MAX_PATCH_SIZE).associateWith { Seq2ParsePrecision() }
var latency: Int
var percentageOfFixesShorterThanSeq2Parse = 0.0
var percentageOfFixesLongerThanSeq2Parse = 0.0
preprocessStackOverflow()
Expand All @@ -84,7 +109,7 @@ fun evaluateSeq2ParseOnStackOverflowDataset() {
val minFixSize = extractPatch(errTks, minFixTks).changes().size

val seq2parseFix = measureTimedValue { seq2parseFix(humanError) }.let {
latency += it.duration.inWholeMilliseconds
latency = it.duration.inWholeMilliseconds.toInt()
it.value
}

Expand All @@ -110,18 +135,11 @@ fun evaluateSeq2ParseOnStackOverflowDataset() {
else if (minFixSize < seq2parseEditSize) percentageOfFixesShorterThanSeq2Parse += 1.0
println("Percentage of fixes shorter than Seq2Parse: ${(percentageOfFixesShorterThanSeq2Parse / (i + 1)).round(3)}")
println("Percentage of fixes longer than Seq2Parse : ${(percentageOfFixesLongerThanSeq2Parse / (i + 1)).round(3)}")
syntaxPrecision += if (seq2parseWasParseable) 1.0 else 0.0
val avgSyntaxPrecision = (syntaxPrecision / (i + 1)).round(3)
println("Average Syntactic precision@1 (${i+1} samples): $avgSyntaxPrecision")
humanFixPrecision += if (seq2parseFixTks == minFixTks) 1.0 else 0.0
val avgHumanFixPrecision = (humanFixPrecision / (i + 1)).round(3)
println("Average HumanEval precision@1 (${i+1} samples): $avgHumanFixPrecision")
chrMatchPrecision += if (seq2parseFix == humanFix) 1.0 else 0.0
val avgChrMatchPrecision = (chrMatchPrecision / (i + 1)).round(3)
println("Average CharMatch precision@1 (${i+1} samples): $avgChrMatchPrecision")

val avgLatency = (latency / (i + 1)).round(3)
println("Average latency to produce a single sample: ${avgLatency}ms\n")
println("Ranking stats for $minFixSize-edit fixes (${editPrecision[minFixSize]!!.samples} samples):")
editPrecision[minFixSize]!!.run { update(seq2parseWasParseable, seq2parseFixTks, minFixTks, humanFix, seq2parseFix, latency) }
println("\nTotal ranking stats across all edit sizes (${totalPrecision.samples} samples):")
totalPrecision.run { update(seq2parseWasParseable, seq2parseFixTks, minFixTks, humanFix, seq2parseFix, latency) }
}
}

Expand Down Expand Up @@ -242,7 +260,7 @@ fun evaluateTidyparseOnStackoverflow() {
private fun preprocessStackOverflow(
brokeSnippets: Sequence<String> = readContents("parse_errors.json"),
fixedSnippets: Sequence<String> = readContents("parse_fixes.json")
): Sequence<Π3<Σᐩ, String, String>> =
): Sequence<Π3A<Σᐩ>> =
brokeSnippets.zip(fixedSnippets)
// .asStream()//.parallel()
.filter { (broke, fixed) ->
Expand All @@ -253,15 +271,17 @@ private fun preprocessStackOverflow(
}
.minimizeFix { tokenizeAsPython(true) }
.filter { (broke, fixed, minfix) ->
val (brokeTokens, minFixedTokens) =
broke.lexToIntTypesAsPython() to minfix.lexToIntTypesAsPython()
// val (brokeTokens, minFixedTokens) =
// broke.lexToIntTypesAsPython() to minfix.lexToIntTypesAsPython()
// (brokeTokens.size - fixedTokens.size).absoluteValue < 10 &&
minfix.isValidPython() &&
multisetManhattanDistance(brokeTokens, minFixedTokens).let { it in 1..5 }
}
.filter { (broke, fixed, minfix) ->

val minpatch = extractPatch(broke.lexToStrTypesAsPython(), minfix.lexToStrTypesAsPython())
val (brokeVis, fixedVis, minfixVis) = broke.visibleChars() to fixed.visibleChars() to minfix.visibleChars()

minfix.isValidPython() &&
minpatch.changes().size <= MAX_PATCH_SIZE &&
brokeVis != fixedVis && brokeVis != minfixVis// && fixedVis != minfixVis
// multisetManhattanDistance(brokeTokens, minFixedTokens).let { it in 1..5 }
}
// .map { (broke, fixed, minfix) ->
// prettyDiffs(listOf(broke, fixed), listOf("original snippet", "human patch")).let { origDiff ->
Expand All @@ -272,13 +292,7 @@ private fun preprocessStackOverflow(
// }
// }
// }
.filter {
extractPatch(
it.first.lexToStrTypesAsPython(),
it.third.lexToStrTypesAsPython()
).changes().size <= MAX_PATCH_SIZE
}
.distinctBy { it.third }
.distinctBy { it.π3 }
.shuffleOnline()

private fun compareSeq2ParseFix(
Expand Down
5 changes: 5 additions & 0 deletions src/main/kotlin/edu/mcgill/cstk/utils/ParseUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ fun Σᐩ.lexToStrTypesAsPython(
vocabulary: Vocabulary = lexer.vocabulary
) = lexer.allTokens.map { vocabulary.getDisplayName(it.type) }

/**
 * Returns every token produced by [lexer] as a list.
 *
 * By default the lexer is a [Python3Lexer] reading the receiver string. The receiver is
 * only consulted through that default argument, so when a caller supplies its own
 * [lexer] the tokens come from whatever input that lexer was built over.
 *
 * @param lexer the lexer to drain; defaults to a Python 3 lexer over this string
 * @param vocabulary not used by the body; defaults to the lexer's own vocabulary
 */
fun Σᐩ.lexToPythonTokens(
lexer: Lexer = Python3Lexer(CharStreams.fromString(this)),
vocabulary: Vocabulary = lexer.vocabulary
) = lexer.allTokens.toList()

/** Creates a [Python3Lexer] that reads this string through a byte input stream. */
fun Σᐩ.lexAsPython(): Python3Lexer {
  val source = CharStreams.fromStream(byteInputStream())
  return Python3Lexer(source)
}

Expand Down
24 changes: 19 additions & 5 deletions src/main/kotlin/edu/mcgill/cstk/utils/RepairUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package edu.mcgill.cstk.utils

import ai.hypergraph.kaliningraph.intersperse
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.types.*
import bijectiveRepair
import com.github.difflib.text.*
import edu.mcgill.cstk.experiments.repair.MSK
import java.util.stream.Stream
import kotlin.math.*
import kotlin.time.TimeSource

Expand All @@ -31,6 +31,7 @@ fun Σᐩ.defaultTokenizer(): List<Σᐩ> =

fun Sequence<Π2A<Σᐩ>>.minimizeFix(tokenize: Σᐩ.() -> List<Σᐩ>) =
map { (broke, fixed) ->
// val startTime = TimeSource.Monotonic.markNow()
val (brokeTokens, fixedTokens) = broke.tokenize() to fixed.tokenize()

// val brokeJoin = brokeTokens.joinToString("")
Expand All @@ -39,10 +40,19 @@ fun Sequence<Π2A<Σᐩ>>.minimizeFix(tokenize: Σᐩ.() -> List<Σᐩ>) =

val patch: Patch = extractPatch(brokeTokens, fixedTokens)
val minEdit = deltaDebug(patch.changes()) { idxs -> patch.apply(idxs).isValidPython() }
val minFix = patch.apply(minEdit)
// deltaDebug only minimizes contiguous chunks, so here we find the minimal configuration of edits
// .minimalSubpatch { patch.apply(this).isValidPython() }

// val pdiff = prettyDiffs(listOf(brokeJoin, minFix), listOf("broken", "minimized fix"))
// if(pdiff.any { it == '\u001B' } && pdiffTok.filter { !it.isWhitespace() } != pdiff.filter { !it.isWhitespace() }) println(pdiffTok + "\n\n" + pdiff)
broke to fixedJoin to minFix

// println("Reduced from ${patch.changes().size} to ${minEdit.size} edits in ${startTime.elapsedNow().inWholeMilliseconds}ms")

// if(!minFix.isValidPython()) println("Minimized fix is invalid Python: $minFix")

val minfix= patch.apply(minEdit)

broke to fixedJoin to minfix
}

typealias Edit = Π2A<Σᐩ>
Expand All @@ -52,8 +62,12 @@ val Edit.new: Σᐩ get() = second

/** Positions, in ascending order, of the edits whose old and new text differ. */
fun Patch.changes(): List<Int> =
  mapIndexedNotNull { position, edit -> position.takeIf { edit.old != edit.new } }

/**
 * Renders the patch as one string with no separator: edits whose position appears in
 * [indices] contribute their new text, all other edits contribute their old text.
 */
fun Patch.apply(indices: List<Int>) =
  // Hoisted into a set so each membership test is a hash lookup instead of a list scan.
  indices.toSet().let { chosen ->
    mapIndexed { i, edit -> if (i in chosen) edit.new else edit.old }.joinToString("")
  }
/**
 * Searches for a sub-selection of this index list that [filter] accepts, trying
 * selection sizes from 1 up to [size] in increasing order.
 *
 * For each size `k` the candidates come from `choose(k)`; the first accepted candidate
 * of the first size that has any accepted candidate is returned. When no candidate of
 * any size is accepted, the receiver itself is returned.
 *
 * NOTE(review): presumably `choose(k)` enumerates the k-element combinations of this
 * list, which would make the result a smallest accepted subset — confirm against the
 * kaliningraph implementation.
 */
fun List<Int>.minimalSubpatch(filter: List<Int>.() -> Boolean): List<Int> =
  (1..size).asSequence()
    .map { k -> choose(k).map { subset -> subset.toList() } }
    .map { candidates -> candidates.filter { candidate -> candidate.filter() } }
    .firstOrNull { accepted -> accepted.any() }
    ?.firstOrNull() ?: this

/**
 * Renders the patch as a single string: edits whose position appears in [indices]
 * contribute their new text, all other edits contribute their old text.
 *
 * @param indices positions of the edits to apply
 * @param separator text placed between consecutive pieces; empty by default
 * @return the pieces joined with [separator]
 */
fun Patch.apply(indices: List<Int>, separator: Σᐩ = ""): Σᐩ {
  // Hoisted into a set so each membership test is a hash lookup instead of a list scan.
  val chosen = indices.toSet()
  return mapIndexed { i, edit -> if (i in chosen) edit.new else edit.old }.joinToString(separator)
}

fun extractPatch(original: List<Σᐩ>, new: List<Σᐩ>): Patch =
DiffRowGenerator.create().build()
Expand Down

0 comments on commit 91b89d0

Please sign in to comment.