Skip to content

Commit

Permalink
Major refactor to use the "Get" package for all the HTTP work, which …
Browse files Browse the repository at this point in the history
…makes the code much cleaner and focused.
  • Loading branch information
btfranklin committed May 6, 2023
1 parent d8aabcb commit b1acfe7
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 253 deletions.
14 changes: 14 additions & 0 deletions Package.resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"pins" : [
{
"identity" : "get",
"kind" : "remoteSourceControl",
"location" : "https://github.com/kean/Get",
"state" : {
"revision" : "12830cc64f31789ae6f4352d2d51d03a25fc3741",
"version" : "2.1.6"
}
}
],
"version" : 2
}
5 changes: 4 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ let package = Package(
targets: ["CleverBird"]),
],
dependencies: [
.package(url: "https://github.com/kean/Get", from: "2.1.6"),
],
targets: [
.target(
name: "CleverBird",
dependencies: [],
dependencies: [
.product(name: "Get", package: "Get")
],
resources: [
.process("tokenization/resources/gpt3-encoder.json"),
.process("tokenization/resources/gpt3-vocab.bpe"),
Expand Down
42 changes: 38 additions & 4 deletions Sources/CleverBird/OpenAIAPIConnection.swift
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
// Created by B.T. Franklin on 12/23/22

import Foundation
import Get

public struct OpenAIAPIConnection {
public class OpenAIAPIConnection {

private static let CHAT_COMPLETION_PATH = "/v1/chat/completions"

let urlRequester: URLRequester
let apiKey: String
let organization: String?
let client: APIClient

private let requestHeaders: [String:String]

public init(apiKey: String, organization: String? = nil, urlRequester: URLRequester? = nil) {
public init(apiKey: String, organization: String? = nil) {
self.apiKey = apiKey
self.organization = organization
self.urlRequester = urlRequester ?? HTTPURLRequester()

var urlComponents = URLComponents()
urlComponents.scheme = "https"
urlComponents.host = "api.openai.com"
let openAIChatCompletionURL = urlComponents.url

let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase

var clientConfiguration = APIClient.Configuration(baseURL: openAIChatCompletionURL)
clientConfiguration.encoder = encoder

self.client = APIClient(configuration: clientConfiguration)

var requestHeaders = [
"Content-Type": "application/json",
"Authorization": "Bearer \(apiKey)"
]
if let organization {
requestHeaders["OpenAI-Organization"] = organization
}
self.requestHeaders = requestHeaders
}

func createRequest(for body: Encodable) async throws -> Request<ChatCompletionResponse> {
Request<ChatCompletionResponse>(
path: Self.CHAT_COMPLETION_PATH,
method: .post,
body: body,
headers: self.requestHeaders)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Created by B.T. Franklin on 4/15/23

public struct ChatCompletionRequest: Codable {
public struct ChatCompletionRequestBody: Codable {
public let model: Model
public let temperature: Percentage
public let topP: Percentage?
Expand Down Expand Up @@ -28,19 +28,3 @@ public struct ChatCompletionRequest: Codable {
self.messages = messages
}
}

public struct ChatMessage: Codable {
public enum Role: String, Codable {
case system
case user
case assistant
}
public let role: Role
public let content: String

public init(role: Role,
content: String) {
self.role = role
self.content = content
}
}
9 changes: 9 additions & 0 deletions Sources/CleverBird/chat/ChatCompletionResponse.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Created by B.T. Franklin on 5/5/23

struct ChatCompletionResponse: Codable {
struct Choice: Codable {
let message: ChatMessage
}
let choices: [Choice]
}

17 changes: 17 additions & 0 deletions Sources/CleverBird/chat/ChatMessage.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Created by B.T. Franklin on 5/5/23

public struct ChatMessage: Codable {
public enum Role: String, Codable {
case system
case user
case assistant
}
public let role: Role
public let content: String

public init(role: Role,
content: String) {
self.role = role
self.content = content
}
}
30 changes: 30 additions & 0 deletions Sources/CleverBird/chat/OpenAIChatThread+complete.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Created by B.T. Franklin on 5/5/23

import Get

extension OpenAIChatThread {
public func complete() async -> ChatMessage? {

let requestBody = ChatCompletionRequestBody(
model: self.model,
temperature: self.temperature,
topP: self.topP,
stop: self.stop,
presencePenalty: self.presencePenalty,
frequencyPenalty: self.frequencyPenalty,
user: self.user,
messages: self.messages
)

do {
let request = try await self.connection.createRequest(for: requestBody)
let response = try await self.connection.client.send(request)
let completion = response.value

return completion.choices.first?.message
} catch {
logger("Error executing request: \(error.localizedDescription)")
return nil
}
}
}
39 changes: 39 additions & 0 deletions Sources/CleverBird/chat/OpenAIChatThread+tokenCount.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Created by B.T. Franklin on 5/5/23

extension OpenAIChatThread {
public func tokenCount() -> Int {

let tokenEncoder: TokenEncoder
do {
tokenEncoder = try TokenEncoder()
} catch {
logger("Unable to create token encoder: \(error)")
return -1
}

var tokensPerMessage: Int

switch self.model {
case .gpt35Turbo:
tokensPerMessage = 4
case .gpt4:
tokensPerMessage = 3
}

var numTokens = 0
for message in messages {
do {
let roleTokens = try tokenEncoder.encode(text: message.role.rawValue).count
let contentTokens = try tokenEncoder.encode(text: message.content).count

numTokens += roleTokens + contentTokens + tokensPerMessage
} catch {
logger("Error encoding text: \(error)")
}
}

numTokens += 3 // every reply is primed with assistant

return numTokens
}
}
129 changes: 10 additions & 119 deletions Sources/CleverBird/chat/OpenAIChatThread.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,21 @@ import Foundation

public class OpenAIChatThread {

private struct ChatCompletionResponse: Codable {
struct Choice: Codable {
let message: ChatMessage
}
let choices: [Choice]
}

private static let DEFAULT_LOGGER: Logger = { message in
print(message)
}

private let connection: OpenAIAPIConnection
private let model: Model
private let temperature: Percentage
private let topP: Percentage?
private let stop: [String]?
private let presencePenalty: Penalty?
private let frequencyPenalty: Penalty?
private let user: String?
private let logger: Logger
let connection: OpenAIAPIConnection
let model: Model
let temperature: Percentage
let topP: Percentage?
let stop: [String]?
let presencePenalty: Penalty?
let frequencyPenalty: Penalty?
let user: String?
let logger: Logger

private var messages: [ChatMessage] = []
var messages: [ChatMessage] = []

public init(connection: OpenAIAPIConnection,
model: Model = .gpt4,
Expand Down Expand Up @@ -71,105 +64,3 @@ public class OpenAIChatThread {
messages.filter { $0.role != .system }
}
}

extension OpenAIChatThread {
public func complete() async -> ChatMessage? {

var urlComponents = URLComponents()
urlComponents.scheme = "https"
urlComponents.host = "api.openai.com"
urlComponents.path = "/v1/chat/completions"
let openAIChatCompletionURL = urlComponents.url

var request = URLRequest(url: openAIChatCompletionURL!)
request.httpMethod = "POST"
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
request.addValue("Bearer \(connection.apiKey)", forHTTPHeaderField: "Authorization")
if let organization = connection.organization {
request.setValue(organization, forHTTPHeaderField: "OpenAI-Organization")
}

let requestBody = ChatCompletionRequest(
model: self.model,
temperature: self.temperature,
topP: self.topP,
stop: self.stop,
presencePenalty: self.presencePenalty,
frequencyPenalty: self.frequencyPenalty,
user: self.user,
messages: self.messages
)

do {
let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
let httpBodyJson = try encoder.encode(requestBody)
request.httpBody = httpBodyJson
} catch {
logger("Unable to convert to JSON \(error)")
return nil
}

let urlRequester = self.connection.urlRequester
let result = await urlRequester.executeRequest(request, withSessionConfig: nil)
switch result {
case .success(let jsonStr):
let json = jsonStr.data(using: .utf8)!
let decoder = JSONDecoder()
do {
let response = try decoder.decode(ChatCompletionResponse.self, from: json)
if let choice = response.choices.first {
_ = addMessage(choice.message)
return choice.message
} else {
logger("Error decoding ChatCompletion OpenAI API Response: Unable to parse completion")
return nil
}
} catch {
logger("Error decoding ChatCompletion OpenAI API Response: \(error)")
return nil
}

case .failure(let error):
logger("Error executing request: \(error.localizedDescription)")
return nil
}
}}

extension OpenAIChatThread {
public func tokenCount() -> Int {

let tokenEncoder: TokenEncoder
do {
tokenEncoder = try TokenEncoder()
} catch {
logger("Unable to create token encoder: \(error)")
return -1
}

var tokensPerMessage: Int

switch self.model {
case .gpt35Turbo:
tokensPerMessage = 4
case .gpt4:
tokensPerMessage = 3
}

var numTokens = 0
for message in messages {
do {
let roleTokens = try tokenEncoder.encode(text: message.role.rawValue).count
let contentTokens = try tokenEncoder.encode(text: message.content).count

numTokens += roleTokens + contentTokens + tokensPerMessage
} catch {
logger("Error encoding text: \(error)")
}
}

numTokens += 3 // every reply is primed with assistant

return numTokens
}
}
39 changes: 0 additions & 39 deletions Sources/CleverBird/http/HTTPURLRequester.swift

This file was deleted.

0 comments on commit b1acfe7

Please sign in to comment.