Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved sliding window - trim old context #12

Merged
merged 9 commits into from Sep 7, 2023
Expand Up @@ -16,6 +16,7 @@ import com.aallam.openai.api.image.ImageCreation
import com.aallam.openai.api.image.ImageURL
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.example.compose.jetchat.data.CustomChatMessage
import com.example.compose.jetchat.data.DroidconContract
import com.example.compose.jetchat.data.DroidconDbHelper
import com.example.compose.jetchat.data.DroidconSessionData
Expand All @@ -25,6 +26,7 @@ import com.example.compose.jetchat.functions.AskDatabaseFunction
import com.example.compose.jetchat.functions.ListFavoritesFunction
import com.example.compose.jetchat.functions.RemoveFavoriteFunction
import com.example.compose.jetchat.functions.SessionsByTimeFunction
import kotlinx.serialization.json.JsonNull.content
import kotlinx.serialization.json.jsonPrimitive
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
Expand All @@ -47,7 +49,7 @@ infix fun DoubleArray.dot(other: DoubleArray): Double {
@OptIn(BetaOpenAI::class)
class DroidconEmbeddingsWrapper(val context: Context?) {
private val openAIToken: String = Constants.OPENAI_TOKEN
private var conversation: MutableList<ChatMessage> = mutableListOf()
private var conversation: MutableList<CustomChatMessage> = mutableListOf()
private var openAI: OpenAI = OpenAI(openAIToken)
/** Sqlite access for favorites, embeddings, and SQL queries */
private val dbHelper = DroidconDbHelper(context)
Expand All @@ -56,11 +58,11 @@ class DroidconEmbeddingsWrapper(val context: Context?) {
* then loaded into memory on first use */
private var vectorCache: MutableMap<String, DoubleArray> = mutableMapOf()

private var systemMessage: ChatMessage
private var systemMessage: CustomChatMessage
init {
systemMessage = ChatMessage(
systemMessage = CustomChatMessage (
role = ChatRole.System,
content = """You are a personal assistant called JetchatAI.
grounding = """You are a personal assistant called JetchatAI.
You will answer questions about the speakers and sessions at the droidcon SF conference.
The conference is on June 8th and 9th, 2023 on the UCSF campus in Mission Bay.
It starts at 8am and finishes by 6pm.
Expand All @@ -81,9 +83,10 @@ class DroidconEmbeddingsWrapper(val context: Context?) {

// add the user's message to the chat history
conversation.add(
ChatMessage(
CustomChatMessage(
role = ChatRole.User,
content = messagePreamble + message
grounding = messagePreamble,
userContent = message
)
)

Expand Down Expand Up @@ -134,9 +137,9 @@ class DroidconEmbeddingsWrapper(val context: Context?) {
// no function, add the response to the conversation history
Log.i("LLM", "No function call was made, showing LLM response")
conversation.add(
ChatMessage(
CustomChatMessage(
role = ChatRole.Assistant,
content = chatResponse
userContent = chatResponse
)
)
} else { // handle function call
Expand Down Expand Up @@ -196,19 +199,19 @@ class DroidconEmbeddingsWrapper(val context: Context?) {
handled = false
chatResponse += " " + function.name + " " + function.arguments
conversation.add(
ChatMessage(
CustomChatMessage(
role = ChatRole.Assistant,
content = chatResponse
userContent = chatResponse
)
)
}
}
if (handled) {
// add the 'call a function' response to the history
conversation.add(
ChatMessage(
CustomChatMessage(
role = completionMessage.role,
content = completionMessage.content
userContent = completionMessage.content
?: "", // required to not be empty in this case
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still true with the custom chat message class? or can we remove the null check now

functionCall = completionMessage.functionCall
)
Expand All @@ -217,10 +220,10 @@ class DroidconEmbeddingsWrapper(val context: Context?) {
// add the response to the 'function' call to the history
// so that the LLM can form the final user-response
conversation.add(
ChatMessage(
CustomChatMessage(
role = ChatRole.Function,
name = function.name,
content = functionResponse
userContent = functionResponse
)
)

Expand All @@ -239,9 +242,9 @@ class DroidconEmbeddingsWrapper(val context: Context?) {
chatResponse = functionCompletion.choices.first().message?.content!!
// ignore trimmedConversation, will be recreated
conversation.add(
ChatMessage(
CustomChatMessage(
role = ChatRole.Assistant,
content = chatResponse
userContent = chatResponse
)
)
}
Expand Down Expand Up @@ -296,17 +299,20 @@ class DroidconEmbeddingsWrapper(val context: Context?) {
Log.i("LLM", "Top match was ${sortedVectors.lastKey()} which was below 0.8 and failed to meet criteria for grounding data")
}

// ALWAYS add the date and time to every prompt
var date = Constants.TEST_DATE
var time = Constants.TEST_TIME
if (date == "") {
date = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd"))
}
if (time == "") {
time = LocalDateTime.now().format(DateTimeFormatter.ofPattern("HH:mm"))
if (messagePreamble.isNullOrEmpty()) {
// ONLY show date/time when embeddings are empty, as it triggers the SessionsByTime function (I THINK)
// ALWAYS add the date and time to every prompt
Comment on lines +303 to +304
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// ONLY show date/time when embeddings are empty, as it triggers the SessionsByTime function (I THINK)
// ALWAYS add the date and time to every prompt
// ONLY show date/time when embeddings are empty, as it triggers the SessionsByTime function

hoping the "i think" has been resolved 😜

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 not sure!

var date = Constants.TEST_DATE
var time = Constants.TEST_TIME
if (date == "") {
date = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd"))
}
if (time == "") {
time = LocalDateTime.now().format(DateTimeFormatter.ofPattern("HH:mm"))
}
messagePreamble =
"The current date is $date and the time (in 24 hour format) is $time.\n\n$messagePreamble"
}
messagePreamble =
"The current date is $date and the time (in 24 hour format) is $time.\n\n$messagePreamble"
return messagePreamble
}

Expand Down
Expand Up @@ -5,14 +5,17 @@ import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.FunctionCall



/**
* Wrapper for the final `ChatMessage` class so that we
* can count tokens and split the grounding from the
* user query for embedding-supported (RAG) requests
*/
class CustomChatMessage @OptIn(BetaOpenAI::class) constructor(
val role: ChatRole,
val grounding: String? = null,
val userContent: String? = null,
val name: String? = null,
val functionCall: FunctionCall? = null
public val role: ChatRole,
public val grounding: String? = null,
public val userContent: String? = null,
public val name: String? = null,
public val functionCall: FunctionCall? = null
conceptdev marked this conversation as resolved.
Show resolved Hide resolved
) {

public fun summary(): String? {
conceptdev marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -26,8 +29,26 @@ class CustomChatMessage @OptIn(BetaOpenAI::class) constructor(
}
}

public fun getTokenCount() : Int {
var messageContent = grounding + userContent ?: ""
/**
* Count the number of tokens in the user query PLUS the additional
* embedding data for grounding OR the functions when userContent is empty
*
* `userContent` can never be larger than `tokensAllowed` - if this happens in
* message history, the message will be dropped from the window. In theory
* it could happen if a user entered maxTokens worth of text in their
* chat query, but in practice that seems unlikely (and should probably
* be a validation error on the UI)
*/
public fun getTokenCount(includeGrounding: Boolean = true, tokensAllowed: Int = -1) : Int {
conceptdev marked this conversation as resolved.
Show resolved Hide resolved
var messageContent = userContent ?: ""
if (includeGrounding) {
messageContent = if (tokensAllowed < 0) {
grounding + userContent ?: ""
} else { // only include as much of the grounding as will fit
Tokenizer.trimToTokenLimit(grounding, tokensAllowed) + userContent ?: ""
}
conceptdev marked this conversation as resolved.
Show resolved Hide resolved
}

if (userContent.isNullOrEmpty()) {
messageContent = "" + functionCall?.name + functionCall?.arguments
}
Expand All @@ -36,9 +57,30 @@ class CustomChatMessage @OptIn(BetaOpenAI::class) constructor(
return messageTokens
}

/**
 * Whether this message fits within the given token allowance.
 *
 * @param includeGrounding when true, the grounding text (trimmed to fit
 *        `tokensAllowed`) is counted alongside the user content
 * @param tokensAllowed the token budget to check against; a negative
 *        value means "no limit" and always fits
 */
public fun canFitInTokenLimit(includeGrounding: Boolean = true, tokensAllowed: Int = -1): Boolean {
    // Negative allowance is the "unlimited" sentinel — nothing to enforce.
    if (tokensAllowed < 0) return true
    return getTokenCount(includeGrounding, tokensAllowed) <= tokensAllowed
}

/**
 * Create `ChatMessage` instance to add to completion request
 *
 * @param includeGrounding when true, the grounding (RAG/embedding) text is
 *        appended to the user content
 * @param tokensAllowed when >= 0, only as much grounding as fits in this
 *        token budget is included; -1 means include all of it
 */
@OptIn(BetaOpenAI::class)
public fun getChatMessage (includeGrounding: Boolean = true, tokensAllowed: Int = -1) : ChatMessage {
    var content: String? = userContent
    if (includeGrounding) {
        val groundingText = if (tokensAllowed < 0) {
            grounding
        } else {
            // only include as much of the grounding as will fit
            // TODO: preserve leading and trailing grounding instructions
            Tokenizer.trimToTokenLimit(grounding, tokensAllowed)
        }
        // Concatenate via elvis defaults: `String?.plus` on a null receiver
        // would otherwise bake the literal text "null" into the message.
        content = (content ?: "") + (groundingText ?: "")
    }
    return ChatMessage(role = role, content = content, name = name, functionCall = functionCall)
}
}
Expand Up @@ -11,9 +11,12 @@ class SlidingWindow {
/**
* Takes the conversation history and trims older ChatMessage
* objects (except for System message) from the start
*
* Only includes the most recent embedding, omits the additional
* grounding information from older messages
*/
@OptIn(BetaOpenAI::class)
fun chatHistoryToWindow (conversation: MutableList<ChatMessage>): MutableList<ChatMessage> {
fun chatHistoryToWindow (conversation: MutableList<CustomChatMessage>): MutableList<ChatMessage> {
Log.v("LLM-SW", "-- chatHistoryToWindow() max tokens ${Constants.OPENAI_MAX_TOKENS}")
// set parameters for sliding window
val tokenLimit = Constants.OPENAI_MAX_TOKENS
Expand All @@ -24,6 +27,7 @@ class SlidingWindow {
Log.v("LLM-SW", "-- tokens reserved for response $expectedResponseSizeTokens and functions $reservedForFunctionsTokens")
var tokensUsed = 0
var systemMessage: ChatMessage? = null
var includeGrounding = true

/** maximum tokens for chat, after hardcoded functions and allowing for a given response size */
val tokenMax = tokenLimit - expectedResponseSizeTokens - reservedForFunctionsTokens
Expand All @@ -33,23 +37,30 @@ class SlidingWindow {

// check for system message
if (conversation[0].role == ChatRole.System) {
systemMessage = conversation[0]
systemMessage = conversation[0].getChatMessage()
var systemMessageTokenCount = Tokenizer.countTokensIn(systemMessage.content)
tokensUsed += systemMessageTokenCount
Log.v("LLM-SW", "-- tokens used by system message: $tokensUsed")
}

// loop through other messages
for (message in conversation.reversed()) {
if (message.role != ChatRole.System) {
var m = CustomChatMessage(message.role, "", message.content, message.name, message.functionCall)
for (m in conversation.reversed()) {
if (m.role != ChatRole.System) {

Log.v("LLM-SW", "-- message (${m.role.role}) ${m.summary()}")
Log.v("LLM-SW", " contains tokens: ${m.getTokenCount()}")
if ((tokensUsed + m.getTokenCount()) < tokenMax) {
messagesInWindow.add(message)
tokensUsed += m.getTokenCount()
Log.v("LLM-SW", " added. Still available: ${tokenMax - tokensUsed}")
Log.v("LLM-SW", " contains tokens: ${m.getTokenCount(includeGrounding)}")
val tokensRemaining = tokenMax - tokensUsed
if (m.canFitInTokenLimit(includeGrounding, tokensRemaining)) {
messagesInWindow.add(m.getChatMessage(includeGrounding, tokensRemaining))
tokensUsed += m.getTokenCount(includeGrounding, tokensRemaining)

if (m.role == ChatRole.User) {
Log.v("LLM-SW", " added (grounding:$includeGrounding). Still available: ${tokenMax - tokensUsed}")
// stop subsequent user messages from including grounding
includeGrounding = false
} else {
Log.v("LLM-SW", " added. Still available: ${tokenMax - tokensUsed}")
}
} else {
Log.v("LLM-SW", " NOT ADDED. Still available: ${tokenMax - tokensUsed} (inc response quota ${expectedResponseSizeTokens}) ")
break // could optionally keep adding subsequent, smaller messages to context up until token limit
Expand Down
Expand Up @@ -22,5 +22,21 @@ class Tokenizer {

return tokens
}

/**
 * Trim the input text to be under the number of tokens specified
 *
 * @param text the text to trim; may be null
 * @param tokenLimit maximum number of tokens the result may occupy
 * @return substring of the text that's no longer than the
 *         specified number of tokens, or null when [text] is null
 */
fun trimToTokenLimit (text: String?, tokenLimit: Int): String? {
    // TODO: limit by tokens instead of the rough character approximation
    val charLimit = tokenLimit * 4
    // Guard explicitly against null: the previous `text?.length!!`
    // threw an NPE for null input despite the nullable signature.
    return if (text == null || text.length <= charLimit) {
        text
    } else {
        text.substring(0, charLimit)
    }
}
}
}
Expand Up @@ -19,7 +19,7 @@ class SessionsByTimeFunction {
return "sessionsByTime"
}
fun description(): String {
return "Given a date and specific time or time range, return the sessions that start on that date, during that time."
return "ONLY WHEN a date and time is specified, return the sessions that start on that date, during that time."
}
fun params(): Parameters {
val params = Parameters.buildJsonObject {
Expand All @@ -31,7 +31,7 @@ class SessionsByTimeFunction {
}
putJsonObject("earliestTime") {
put("type", "string")
put("description", "The earliest time that the conference sessions might start, eg. 09:00. Defaults to the current time.")
put("description", "The earliest time that the conference sessions might start, eg. 09:00.")
}
putJsonObject("latestTime") {
put("type", "string")
Expand All @@ -40,6 +40,7 @@ class SessionsByTimeFunction {
}
putJsonArray("required") {
add("date")
add("earliestTime")
}
}
return params
Expand Down