Skip to content

Commit

Permalink
Simplified the API to just return a single ChatMessage from a complet…
Browse files Browse the repository at this point in the history
…ion call, rather than the almost-useless case of allowing multiple choices.
  • Loading branch information
btfranklin committed Apr 21, 2023
1 parent db7541f commit 7396266
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 35 deletions.
21 changes: 12 additions & 9 deletions Sources/CleverBird/chat/OpenAIChatThread.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ public class OpenAIChatThread {
private let model: Model
private let temperature: Percentage
private let top_p: Percentage?
private let numberOfCompletionsToCreate: Int?
private let stop: [String]?
private let presence_penalty: Penalty?
private let frequency_penalty: Penalty?
Expand All @@ -33,23 +32,22 @@ public class OpenAIChatThread {
self.model = model
self.temperature = temperature
self.top_p = top_p
self.numberOfCompletionsToCreate = numberOfCompletionsToCreate
self.stop = stop
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.user = user
self.logger = logger ?? Self.DEFAULT_LOGGER
}

public func addSystemMessage(content: String) -> Self {
public func addSystemMessage(_ content: String) -> Self {
addMessage(ChatMessage(role: .system, content: content))
}

public func addUserMessage(content: String) -> Self {
public func addUserMessage(_ content: String) -> Self {
addMessage(ChatMessage(role: .user, content: content))
}

public func addAssistantMessage(content: String) -> Self {
public func addAssistantMessage(_ content: String) -> Self {
addMessage(ChatMessage(role: .assistant, content: content))
}

Expand All @@ -60,7 +58,7 @@ public class OpenAIChatThread {
}

extension OpenAIChatThread {
public func complete() async -> ChatCompletionResponse? {
public func complete() async -> ChatMessage? {

var urlComponents = URLComponents()
urlComponents.scheme = "https"
Expand All @@ -80,7 +78,6 @@ extension OpenAIChatThread {
model: self.model,
temperature: self.temperature,
top_p: self.top_p,
n: self.numberOfCompletionsToCreate,
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
Expand All @@ -104,12 +101,18 @@ extension OpenAIChatThread {
let json = jsonStr.data(using: .utf8)!
let decoder = JSONDecoder()
do {
let product = try decoder.decode(ChatCompletionResponse.self, from: json)
return product
let response = try decoder.decode(ChatCompletionResponse.self, from: json)
if let choice = response.choices.first {
return choice.message
} else {
logger("Error decoding ChatCompletion OpenAI API Response: Unable to parse completion")
return nil
}
} catch {
logger("Error decoding ChatCompletion OpenAI API Response: \(error)")
return nil
}

case .failure(let error):
logger("Error executing request: \(error.localizedDescription)")
return nil
Expand Down
24 changes: 5 additions & 19 deletions Sources/CleverBird/chat/datatypes.swift
Original file line number Diff line number Diff line change
@@ -1,37 +1,23 @@
// Created by B.T. Franklin on 4/15/23

// The overall structure of these is:
// ChatCompletionResponse
// - [ChatChoice]
// - ChatMessage
// - ChatRole
//
// ChatCompletionRequest
// - [ChatMessage]
// - ChatRole

public struct ChatCompletionResponse: Codable {
public let model: Model
public let choices: [ChatChoice]
struct ChatCompletionResponse: Codable {
struct Choice: Codable {
let message: ChatMessage
}
let choices: [Choice]
}

public struct ChatCompletionRequest: Codable {
public let model: Model
public let temperature: Percentage
public let top_p: Percentage?
public let n: Int?
public let stop: [String]?
public let presence_penalty: Penalty?
public let frequency_penalty: Penalty?
public let user: String?
public let messages: [ChatMessage]
}

public struct ChatChoice: Codable {
public let message: ChatMessage
public let index: Int
}

public struct ChatMessage: Codable {
public let role: ChatRole
public let content: String
Expand Down
14 changes: 7 additions & 7 deletions Tests/CleverBirdTests/OpenAIChatThreadTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@ class OpenAIChatThreadTests: XCTestCase {

let openAIAPIConnection = OpenAIAPIConnection(apiKey: "fake_api_key", urlRequester: mockURLRequester)
let chatThread = OpenAIChatThread(connection: openAIAPIConnection)
.addSystemMessage(content: "You are a helpful assistant.")
.addUserMessage(content: "Who won the world series in 2020?")
.addSystemMessage("You are a helpful assistant.")
.addUserMessage("Who won the world series in 2020?")


let completionResponse = await chatThread.complete()
let completion = await chatThread.complete()

XCTAssertNotNil(completionResponse, "Completion response is nil")
XCTAssertEqual(completionResponse?.choices.first?.message.content.trimmingCharacters(in: .whitespacesAndNewlines),
XCTAssertNotNil(completion, "Completion is nil")
XCTAssertEqual(completion?.content.trimmingCharacters(in: .whitespacesAndNewlines),
"The 2020 World Series was won by the Los Angeles Dodgers.", "Unexpected assistant response")
}

func testTokenCount() {
let openAIAPIConnection = OpenAIAPIConnection(apiKey: "fake_api_key", urlRequester: MockURLRequester(response: ""))
let chatThread = OpenAIChatThread(connection: openAIAPIConnection)
.addSystemMessage(content: "You are a helpful assistant.")
.addUserMessage(content: "Who won the world series in 2020?")
.addSystemMessage("You are a helpful assistant.")
.addUserMessage("Who won the world series in 2020?")

let tokenCount = chatThread.tokenCount()

Expand Down

0 comments on commit 7396266

Please sign in to comment.