add stop feature
eastriverlee committed Jan 27, 2024
1 parent 2796bd2 commit 9b427e8
Showing 2 changed files with 11 additions and 0 deletions.
README.md: 4 additions & 0 deletions
@@ -30,6 +30,7 @@ struct ContentView: View {
     @StateObject var bot = Bot()
     @State var input = "Give me seven national flag emojis people use the most; You must include South Korea."
     func respond() { Task { await bot.respond(to: input) } }
+    func stop() { bot.stop() }
 
     var body: some View {
         VStack(alignment: .leading) {
@@ -40,6 +41,9 @@
                 Button(action: respond) {
                     Image(systemName: "paperplane.fill")
                 }
+                Button(action: stop) {
+                    Image(systemName: "stop.fill")
+                }
             }
         }.frame(maxWidth: .infinity).padding()
     }
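For context, a minimal sketch of how the new stop button wires up end to end. It assumes Bot is the README's LLM subclass and that it publishes its generated text as output; the Bot stub below and the ScrollView/TextField layout stand in for the parts of the view collapsed in the hunks above.

import SwiftUI

// Stand-in for the README's Bot (an LLM subclass); just enough to compile here.
final class Bot: ObservableObject {
    @Published var output = ""
    func respond(to input: String) async { /* streams generated text into `output` */ }
    func stop() { /* asks the LLM to stop predicting further tokens */ }
}

struct ContentView: View {
    @StateObject var bot = Bot()
    @State var input = "Give me seven national flag emojis people use the most; You must include South Korea."
    func respond() { Task { await bot.respond(to: input) } }
    func stop() { bot.stop() }

    var body: some View {
        VStack(alignment: .leading) {
            ScrollView { Text(bot.output) }   // assumes Bot publishes its text as `output`
            HStack {
                TextField("input", text: $input)
                Button(action: respond) { Image(systemName: "paperplane.fill") }
                Button(action: stop) { Image(systemName: "stop.fill") }   // the new button
            }
        }.frame(maxWidth: .infinity).padding()
    }
}
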
Sources/LLM/LLM.swift: 7 additions & 0 deletions
@@ -145,8 +145,14 @@ open class LLM: ObservableObject {
         self.template = template
     }
 
+    private var shouldContinuePredicting = false
+    public func stop() {
+        shouldContinuePredicting = false
+    }
+
     @InferenceActor
     private func predictNextToken() async -> Token {
+        guard shouldContinuePredicting else { return llama_token_eos(model) }
         let logits = llama_get_logits_ith(context.pointer, batch.n_tokens - 1)!
         var candidates: [llama_token_data] = (0..<totalTokenCount).map { token in
             llama_token_data(id: Int32(token), logit: logits[token], p: 0.0)
@@ -194,6 +200,7 @@ open class LLM: ObservableObject {
             batch.add(token, batch.n_tokens, [0], i == initialCount - 1)
         }
         context.decode(batch)
+        shouldContinuePredicting = true
         return true
     }

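Taken together, the two hunks implement a cooperative stop: decoding the initial prompt batch re-arms shouldContinuePredicting, stop() clears it, and the guard added to predictNextToken() then returns the end-of-sequence token, so generation presumably winds down the same way it would if the model had finished on its own. A standalone sketch of that pattern follows; the llama.cpp sampling call is replaced by a counter stub, and every name other than stop() and shouldContinuePredicting is illustrative rather than the library's API.

final class Generator {
    private let endToken = -1              // stand-in for llama_token_eos(model)
    private var shouldContinuePredicting = false

    func stop() {
        shouldContinuePredicting = false   // the next prediction returns endToken
    }

    private func predictNextToken() -> Int {
        // Mirrors the added guard: once stop() clears the flag, generation
        // ends as if the model had emitted its end-of-sequence token.
        guard shouldContinuePredicting else { return endToken }
        return sampleStub()
    }

    func respond() -> [Int] {
        shouldContinuePredicting = true    // re-armed once the prompt is "decoded"
        var tokens: [Int] = []
        while true {
            let token = predictNextToken()
            if token == endToken { break }
            tokens.append(token)
        }
        return tokens
    }

    // Counter stub standing in for llama_get_logits_ith + sampling.
    private var step = 0
    private func sampleStub() -> Int {
        step += 1
        return step < 8 ? step : endToken
    }
}

In the real class, predictNextToken() runs on @InferenceActor and checks the flag once per token, so a stop requested from the UI takes effect at the next token boundary.
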

0 comments on commit 9b427e8
