Skip to content

Commit

Permalink
test CT speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Apr 13, 2024
1 parent bce217e commit e6afbed
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
3 changes: 2 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import de.undercouch.gradle.tasks.download.Download
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
kotlin("jvm") version "2.0.0-RC1"
Expand Down Expand Up @@ -236,7 +237,7 @@ tasks {
}

compileKotlin {
kotlinOptions.jvmTarget = "17"
compilerOptions.jvmTarget = JvmTarget.JVM_17
dependsOn("downloadKotlinGrammarTools")
}

Expand Down
22 changes: 11 additions & 11 deletions src/main/kotlin/edu/mcgill/cstk/experiments/probing/ProbeLLaMA.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,13 @@ import java.nio.charset.StandardCharsets
./gradlew probeLLaMA
*/
fun main() {
LlamaModel.setLogger { level: LogLevel?, message: String? -> print(message) }
// LlamaModel.setLogger { level: LogLevel?, message: String? -> print(message) }

val modelPath = File("").absolutePath + "/models/ggml-model-Q6_K.gguf"
val modelParams = ModelParameters().setNGpuLayers(43)
val inferParams = InferenceParameters()
.setTemperature(0.7f)
.setPenalizeNl(true) // .setNProbs(10)
.setMirostat(InferenceParameters.MiroStat.V2)
.setAntiPrompt("User:")
.setModelFilePath(modelPath)

val modelPath = File("").absolutePath +
"/models/ggml-model-Q6_K.gguf"

LlamaModel(modelPath, modelParams).use { model ->
LlamaModel(modelParams).use { model ->
invalidPythonStatements.lines().forEach { invalidCodeSnippet ->
val prompt = """
The following line of Python code contains a syntax error:
Expand All @@ -40,7 +34,13 @@ fun main() {

BufferedReader(InputStreamReader(System.`in`, StandardCharsets.UTF_8))
val sb = StringBuilder()
for (output in model.generate(prompt, inferParams)) {
val inferParams = InferenceParameters(prompt)
.setTemperature(0.7f)
.setPenalizeNl(true) // .setNProbs(10)
// .setMirostat(InferenceParameters.MiroStat.V2)
// .setAntiPrompt("User:")

for (output in model.generate(inferParams)) {
sb.append(output)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fun main() {
// MAX_UNIQUE = 1_000
TIMEOUT_MS = 30_000
MIN_TOKENS = 3
MAX_TOKENS = 50
MAX_TOKENS = 80
MAX_RADIUS = 3
CFG_THRESH = 10_000
evaluateBarHillelRepairOnStackOverflow()
Expand Down Expand Up @@ -279,8 +279,8 @@ val sizeAndDistBalancedRepairsUnminimized: Sequence<Π2A<Σᐩ>> by lazy {
levDist <= MAX_RADIUS
}.toList()
.groupBy { it.π3 to it.π4 }.let { map ->
val minSize = map.values.minOf { it.size }
println("Size of smallest group: $minSize")
val minSize = map.entries.minBy { it.value.size }
println("Size of smallest group: ${minSize.key}, ${minSize.value.size}")
map.mapValues { (_, v) -> v.shuffled().take(100) }
}
.values.asSequence().flatten()
Expand Down

0 comments on commit e6afbed

Please sign in to comment.