Skip to content

Commit

Permalink
WIP: redoing the way messages are sorted
Browse files Browse the repository at this point in the history
  • Loading branch information
hunterjackson committed Mar 26, 2024
1 parent 3f7fb3d commit b79cf49
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 62 deletions.
3 changes: 2 additions & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/main/java/com/meta/cp4m/Service.java
Expand Up @@ -70,16 +70,16 @@ public MessageHandler<T> messageHandler() {
}

private void execute(ThreadState<T> thread) {
T llmResponse;
ThreadState<T> updatedThread;
try {
llmResponse = llmPlugin.handle(thread);
updatedThread = llmPlugin.handle(thread);
} catch (IOException e) {
LOGGER.error("failed to communicate with LLM", e);
return;
}
store.add(thread,llmResponse);
store.add(updatedThread);
try {
handler.respond(llmResponse);
handler.respond(updatedThread.tail());
} catch (Exception e) {
// we log in the handler where we have the body context
// TODO: create transactional store add
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Expand Up @@ -37,7 +37,7 @@ public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
}

@Override
public T handle(ThreadState<T> threadState) throws IOException {
public ThreadState<T> handle(ThreadState<T> threadState) throws IOException {
ObjectNode body = MAPPER.createObjectNode();
ObjectNode params = MAPPER.createObjectNode();

Expand All @@ -49,7 +49,7 @@ public T handle(ThreadState<T> threadState) throws IOException {

Optional<String> prompt = promptCreator.createPrompt(threadState);
if (prompt.isEmpty()) {
return threadState.newMessageFromBot(
return threadState.withNewMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
}

Expand All @@ -72,6 +72,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
String llmResponse = allGeneratedText.strip().replace(prompt.get().strip(), "");
Instant timestamp = Instant.now();

return threadState.newMessageFromBot(timestamp, llmResponse);
return threadState.withNewMessageFromBot(timestamp, llmResponse);
}
}
2 changes: 1 addition & 1 deletion src/main/java/com/meta/cp4m/llm/LLMPlugin.java
Expand Up @@ -14,5 +14,5 @@

public interface LLMPlugin<T extends Message> {

T handle(ThreadState<T> threadState) throws IOException;
ThreadState<T> handle(ThreadState<T> threadState) throws IOException;
}
6 changes: 3 additions & 3 deletions src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java
Expand Up @@ -126,7 +126,7 @@ private Optional<ArrayNode> pruneMessages(ArrayNode messages, @Nullable JsonNode
}

@Override
public T handle(ThreadState<T> threadState) throws IOException {
public ThreadState<T> handle(ThreadState<T> threadState) throws IOException {
T fromUser = threadState.tail();

ObjectNode body = MAPPER.createObjectNode();
Expand Down Expand Up @@ -161,7 +161,7 @@ public T handle(ThreadState<T> threadState) throws IOException {

Optional<ArrayNode> prunedMessages = pruneMessages(messages, null);
if (prunedMessages.isEmpty()) {
return threadState.newMessageFromBot(
return threadState.withNewMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
}
body.set("messages", prunedMessages.get());
Expand All @@ -182,6 +182,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
JsonNode choice = responseBody.get("choices").get(0);
String messageContent = choice.get("message").get("content").textValue();
return threadState.newMessageFromBot(timestamp, messageContent);
return threadState.withNewMessageFromBot(timestamp, messageContent);
}
}
45 changes: 22 additions & 23 deletions src/main/java/com/meta/cp4m/message/MessageNode.java
Expand Up @@ -8,28 +8,27 @@

package com.meta.cp4m.message;

public class MessageNode <T extends Message>{
T message;
T parentMessage;
import org.checkerframework.checker.nullness.qual.Nullable;

public MessageNode(T message){
this.message = message;
this.parentMessage = null;
}
public MessageNode(T message, T parentMessage){
this.message = message;
this.parentMessage = parentMessage;
}
public T getMessage() {
return message;
}
public T getParentMessage() {
return parentMessage;
}
public void setMessage(T message) {
this.message = message;
}
public void setParentMessage(T parentMessage) {
this.parentMessage = parentMessage;
}
public class MessageNode<T extends Message> {
T message;
@Nullable MessageNode<T> parentMessage;

public MessageNode(T message) {
this.message = message;
this.parentMessage = null;
}

public MessageNode(T message, @Nullable MessageNode<T> parentMessage) {
this.message = message;
this.parentMessage = parentMessage;
}

public T message() {
return message;
}

public @Nullable MessageNode<T> parentMessage() {
return parentMessage;
}
}
105 changes: 86 additions & 19 deletions src/main/java/com/meta/cp4m/message/ThreadState.java
Expand Up @@ -16,6 +16,7 @@
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.checkerframework.common.reflection.qual.NewInstance;

public class ThreadState<T extends Message> {
private final List<MessageNode<T>> messageNodes;
Expand All @@ -25,34 +26,83 @@ private ThreadState(T message) {
Objects.requireNonNull(message);
Preconditions.checkArgument(
message.role() != Role.SYSTEM, "ThreadState should never hold a system message");
MessageNode<T> messageNode = new MessageNode<>(message,null);
MessageNode<T> messageNode = new MessageNode<>(message, null);
this.messageNodes = ImmutableList.of(messageNode);
messageFactory = MessageFactory.instance(message);
}

private ThreadState(List<MessageNode<T>> nodes, MessageFactory<T> factory) {
this.messageNodes = nodes;
this.messageFactory = factory;
}

/** Constructor that exists to support the with method */
private ThreadState(ThreadState<T> current, ThreadState<T> old, T newMessage) {
Objects.requireNonNull(newMessage);
Preconditions.checkArgument(
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
messageFactory = current.messageFactory;
Preconditions.checkArgument(
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<MessageNode<T>> messageNodes = current.messageNodes;
MessageNode<T> mWithParentMessage = new MessageNode<>(newMessage,old.tail());
// MessageNode<T> mWithParentMessage = new MessageNode<>(newMessage, old.tail());
MessageNode<T> mWithParentMessage = new MessageNode<>(newMessage);
this.messageNodes =
Stream.concat(messageNodes.stream(), Stream.of(mWithParentMessage))
.sorted((m1,m2) -> m1.getParentMessage() == m2.getParentMessage() ? compare(m1.getMessage().role().priority(),m2.getMessage().role().priority()) : (m1.getMessage().timestamp().compareTo(m2.getMessage().timestamp())))
.collect(Collectors.toUnmodifiableList());

Stream.concat(messageNodes.stream(), Stream.of(mWithParentMessage))
.sorted(
(m1, m2) ->
m1.parentMessage() == m2.parentMessage()
? compare(m1.message().role().priority(), m2.message().role().priority())
: (m1.message().timestamp().compareTo(m2.message().timestamp())))
.collect(Collectors.toUnmodifiableList());
Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this thread state");
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this thread state");
}

private int compare(int priority1, int priority2){
return Integer.compare(priority1, priority2);
public static <T extends Message> ThreadState<T> merge(
ThreadState<T> first, ThreadState<T> second) {
List<MessageNode<T>> firstMessages = first.messageNodes;
List<MessageNode<T>> secondMessages = second.messageNodes;
Preconditions.checkState(
firstMessages.get(0).equals(secondMessages.get(0)),
"attempting to merge disconnected instances of " + ThreadState.class.getCanonicalName());

List<MessageNode<T>> result =
new ArrayList<>(Math.max(firstMessages.size(), secondMessages.size()));

int firstLocation = 0;
int secondLocation = 0;

while (firstLocation < firstMessages.size() && secondLocation < secondMessages.size()) {
// if (firstLocation >= firstMessages.size() && secondLocation <)
MessageNode<T> firstNode = firstMessages.get(firstLocation);
MessageNode<T> secondNode = secondMessages.get(secondLocation);
if (firstNode.message.threadId().equals(secondNode.message.threadId())) {
if (firstNode.parentMessage() == null) {
Preconditions.checkState(secondNode.parentMessage() == null);
result.add(new MessageNode<>(firstNode.message()));
} else {
Preconditions.checkState(
Objects.equals(firstNode.parentMessage(), secondNode.parentMessage()));
Identifier threadId =
Objects.requireNonNull(firstNode.parentMessage()).message().threadId();

// search backward through the array for the parent message node
// the parent will almost always be in one of the last two positions of the array
for (int i = result.size() - 1; i >= 0; i--) {
MessageNode<T> mn = result.get(i);
if (mn.message().threadId().equals(threadId)) {
result.add(new MessageNode<>(firstNode.message(), mn));
}
}
}
firstLocation += 1;
secondLocation += 1;
}
}
return new ThreadState<>(result, first.messageFactory);
}

public static <T extends Message> ThreadState<T> of(T message) {
Expand Down Expand Up @@ -84,24 +134,41 @@ public T newMessageFromBot(Instant timestamp, String message) {
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
}

private int compare(int priority1, int priority2) {
return Integer.compare(priority1, priority2);
}

public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
}

public ThreadState<T> with(T message) {
return new ThreadState<>(this,this, message);
public @NewInstance ThreadState<T> withNewMessageFromBot(Instant timestamp, String message) {
T newMessage =
messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
return with(newMessage);
}

public ThreadState<T> withNewMessageFromUser(
Instant timestamp, String message, Identifier instanceId) {
T newMessage =
messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
return with(newMessage);
}

public ThreadState<T> with(ThreadState<T> thread,T message) {
return new ThreadState<>(this,thread, message);
public ThreadState<T> with(T message) {
return new ThreadState<>(this, this, message);
}

public List<T> messages() {
return messageNodes.stream().map(MessageNode::getMessage).collect(Collectors.toList());
return messageNodes.stream().map(MessageNode::message).collect(Collectors.toList());
}

public T tail() {
return messageNodes.get(messageNodes.size() - 1).getMessage();
return messageNodes.get(messageNodes.size() - 1).message();
}

public Identifier threadId() {
return tail().threadId();
}
}
2 changes: 0 additions & 2 deletions src/main/java/com/meta/cp4m/message/WAMessage.java
Expand Up @@ -9,8 +9,6 @@
package com.meta.cp4m.message;

import com.meta.cp4m.Identifier;
import org.checkerframework.checker.lock.qual.NewObject;

import java.time.Instant;

public record WAMessage(
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/com/meta/cp4m/store/ChatStore.java
Expand Up @@ -8,7 +8,6 @@

package com.meta.cp4m.store;

import com.meta.cp4m.Identifier;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.util.List;
Expand All @@ -26,7 +25,7 @@ public interface ChatStore<T extends Message> {

ThreadState<T> add(T message);

ThreadState<T> add(ThreadState<T> thread,T message);
ThreadState<T> add(ThreadState<T> thread);

long size();

Expand Down
6 changes: 4 additions & 2 deletions src/main/java/com/meta/cp4m/store/MemoryStore.java
Expand Up @@ -45,8 +45,10 @@ public ThreadState<T> add(T message) {
}

@Override
public ThreadState<T> add(ThreadState<T> thread, T message){
return this.store.asMap().compute(message.threadId(), (k,v) -> {return v.with(thread,message);});
public ThreadState<T> add(ThreadState<T> thread) {
return this.store
.asMap()
.compute(thread.threadId(), (k, v) -> v == null ? thread : ThreadState.merge(v, thread));
}

@Override
Expand Down
5 changes: 3 additions & 2 deletions src/test/java/com/meta/cp4m/llm/DummyLLMPlugin.java
Expand Up @@ -44,8 +44,9 @@ public String dummyResponse() {
}

@Override
public T handle(ThreadState<T> threadState) {
public ThreadState<T> handle(ThreadState<T> threadState) {
receivedThreadStates.add(threadState);
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse);
T message = threadState.newMessageFromBot(Instant.now(), dummyLLMResponse);
return threadState.with(message);
}
}

0 comments on commit b79cf49

Please sign in to comment.