Adding AWS SDK and using for LlamaAWSHandler
colinmccloskey committed Aug 31, 2023
1 parent 2af0d77 commit 4eb2a07
Showing 6 changed files with 172 additions and 32 deletions.
53 changes: 53 additions & 0 deletions pom.xml
@@ -21,12 +21,65 @@
<maven.compiler.target>20</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>2.20.45</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb-enhanced</artifactId>
<version>2.20.26</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sagemakerruntime</artifactId>
<version>2.20.26</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>lambda</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sqs</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>iam</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sagemakergeospatial</artifactId>
<version>2.20.78</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>secretsmanager</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>32.1.1-jre</version>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk</artifactId>
<version>1.11.1000</version>
</dependency>
<dependency>
<groupId>org.jetbrains</groupId>
<artifactId>annotations</artifactId>
@@ -55,16 +55,19 @@ public static void calculateAuthorizationHeaders(
String awsIdentity, String awsSecret, String awsRegion, String awsService
) {
try {
String bodySha256 = hex(sha256(body));
String bodySha256 = "UNSIGNED-PAYLOAD"; // hex(sha256(body));
String isoJustDate = isoDateTime.substring(0, 8); // Cut the date portion of a string like '20150830T123600Z';

headers.put("Host", host);
headers.put("X-Amz-Content-Sha256", bodySha256);
headers.put("X-Amz-Date", isoDateTime);

// (1) https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
// String canon = "POST" +
// System.out.println(canon.getBytes(StandardCharsets.UTF_8));

List<String> canonicalRequestLines = new ArrayList<>();
canonicalRequestLines.add(method);
// canonicalRequestLines.add(method);
canonicalRequestLines.add(path);
canonicalRequestLines.add(query);
List<String> hashedHeaders = new ArrayList<>();
@@ -88,6 +91,7 @@ public static void calculateAuthorizationHeaders(
stringToSignLines.add(credentialScope);
stringToSignLines.add(canonicalRequestHash);
String stringToSign = stringToSignLines.stream().collect(Collectors.joining("\n"));
System.out.println(stringToSign);

// (3) https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html
byte[] kDate = hmac(("AWS4" + awsSecret).getBytes(StandardCharsets.UTF_8), isoJustDate);
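The hunk above swaps the computed body hash for the literal "UNSIGNED-PAYLOAD" sentinel in X-Amz-Content-Sha256, which SigV4 accepts when the request body is left unsigned. For reference, a minimal sketch of the signed-payload variant using only the JDK; the helper name mirrors the hex(sha256(body)) call it replaces and is an assumption, not code from this commit:

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;

// Sketch: hex-encoded SHA-256 of the request body, the value X-Amz-Content-Sha256
// carries when the payload is signed; "UNSIGNED-PAYLOAD" skips this step entirely.
static String hashedPayload(String body) throws Exception {
    byte[] digest = MessageDigest.getInstance("SHA-256")
        .digest(body.getBytes(StandardCharsets.UTF_8));
    StringBuilder hex = new StringBuilder(digest.length * 2);
    for (byte b : digest) {
        hex.append(Character.forDigit((b >> 4) & 0xF, 16));
        hex.append(Character.forDigit(b & 0xF, 16));
    }
    return hex.toString();
}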
65 changes: 58 additions & 7 deletions src/main/java/com/meta/chatbridge/llm/LlamaAWSHandler.java
@@ -8,27 +8,78 @@

package com.meta.chatbridge.llm;

import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.store.LLMContextManager;
import com.meta.chatbridge.store.MessageStack;
import java.util.*;
import java.time.Instant;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import java.nio.charset.Charset;

public class LlamaAWSHandler implements LLMHandler {

private final LLMContextManager context;
private final LlamaTokenizer tokenizer;
private final String endpoint;
private final String contentType;
private final Region region;

public LlamaAWSHandler(LLMContextManager context) {
public LlamaAWSHandler(LLMContextManager context, String endpoint, Region region) {
this.context = context;
this.tokenizer = new LlamaTokenizer();
this.endpoint = endpoint;
this.contentType = "application/json";
this.region = region;
}

@Override
public Message handle(MessageStack messageStack) {
// Take history
// get number of tokens and truncate as needed
// Pass to LLM
// Return response
Message message = (Message) messageStack.messages().get(0);
List<Message> messagesToPass = tokenizer.getCappedMessages(context.getContext(), messageStack);

return null;
StringBuilder messagesString = new StringBuilder("{" + "\"inputs\": [[ { \"role\": \"system\", \"content\": \"" + context.getContext() + "\" },");
for (Message m : messagesToPass){
messagesString.append("{\"role\": \"").append(m.role().toString().toLowerCase()).append("\", \"content\": \"").append(m.message()).append("\"}");
}

String payload = messagesString + "]],\n\"parameters\": {\"max_new_tokens\": 256, \"top_p\": 0.9, \"temperature\": 0.6}}";

SageMakerRuntimeClient runtimeClient = SageMakerRuntimeClient.builder()
.region(region)
.build();

String responseString = invokeSpecificEndpoint(runtimeClient, endpoint, payload, contentType);
Message last = messagesToPass.get(messagesToPass.size() - 1);
Identifier tempMessageID = Identifier.from("-1"); // Replace with MID once sent by messageHandler

Message response =
new FBMessage(
Instant.now(),
tempMessageID,
last.recipientId(),
last.senderId(),
responseString,
Message.Role.ASSISTANT);

return response;

}

public static String invokeSpecificEndpoint(SageMakerRuntimeClient runtimeClient, String endpointName, String payload, String contentType) {

InvokeEndpointRequest endpointRequest = InvokeEndpointRequest.builder()
.endpointName(endpointName)
.contentType(contentType)
.body(SdkBytes.fromString(payload, Charset.defaultCharset()))
.customAttributes("accept_eula=true")
.build();

InvokeEndpointResponse response = runtimeClient.invokeEndpoint(endpointRequest);
return (response.body().asString(Charset.defaultCharset()));
}
}
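A note on the payload built in handle(): string concatenation leaves quotes and newlines inside message content unescaped, and it omits separators between consecutive history entries when more than one message is passed. A hedged sketch of the same {"inputs": [[...]], "parameters": {...}} structure built with Jackson — assuming Jackson is on the classpath; Message is the project type used above:

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.List;

// Sketch: build the Llama 2 chat payload with Jackson so message content is
// escaped automatically. Names mirror the handler above; Jackson is an assumption.
static String buildPayload(String systemContext, List<Message> messages) throws Exception {
    ObjectMapper mapper = new ObjectMapper();
    ObjectNode root = mapper.createObjectNode();
    ArrayNode dialog = root.putArray("inputs").addArray();

    ObjectNode system = dialog.addObject();
    system.put("role", "system");
    system.put("content", systemContext);

    for (Message m : messages) {
        ObjectNode turn = dialog.addObject();
        turn.put("role", m.role().toString().toLowerCase());
        turn.put("content", m.message());
    }

    ObjectNode parameters = root.putObject("parameters");
    parameters.put("max_new_tokens", 256);
    parameters.put("top_p", 0.9);
    parameters.put("temperature", 0.6);

    return mapper.writeValueAsString(root);
}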
27 changes: 8 additions & 19 deletions src/main/java/com/meta/chatbridge/llm/LlamaTokenizer.java
@@ -12,14 +12,14 @@
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.store.MessageStack;
import com.meta.chatbridge.message.Message;

import java.util.*;

public class LlamaTokenizer {

private final int MAX_TOKENS = 4096;
private final int MAX_TOKENS = 4000; // We subtract 96 to be safe when accounting for different counting of tokens between models
private final int MAX_RESPONSE_TOKENS = 256;
private final int MAX_CONTEXT_TOKENS = 1536;

@@ -49,23 +49,12 @@ private int getContextStringTokens(String contextString) {
*
* @param context The "system message" context string, possibly needs to be updated to be more complicated and different across Llama and ChatGPT
* @param history The history of messages. The last message is the user question, do not remove it.
* @return The capped messages that can be sent to the OpenAI API.
* @return The capped messages that can be sent to the Llama endpoint.
*/
private List<Message> capMessages(String context,
List<Message> history) {
var availableTokens = MAX_TOKENS - getContextStringTokens(context);
var cappedHistory = new ArrayList<>(history);
// Update history to use messagestack

// Message contextMessage =
// new Message(
// timestamp,
// messageId,
// senderId,
// recipientId,
// context,
// Message.Role.SYSTEM);
// Where should this be set? Do we want to bother with initializing a context message for every user?
public List<Message> getCappedMessages(String context,
MessageStack history) {
var availableTokens = MAX_TOKENS - MAX_RESPONSE_TOKENS - getContextStringTokens(context);
var cappedHistory = new ArrayList<>(history.messages());

var tokens = getTokenCount(cappedHistory);

@@ -105,7 +94,7 @@ private int getTokenCount(List<Message> messages) {
* @return The number of tokens in the message
*/
private int getMessageTokenCount(Message message) {
var tokens = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokens = 4; // Also figure out the basic token overhead here

tokens += tokenizer.encode(message.role().toString()).size();
tokens += tokenizer.encode(message.message()).size();
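The middle of getCappedMessages is collapsed in the diff above. A hedged sketch of how the truncation loop could look under the budget arithmetic shown — the loop body is an assumption, not the committed code:

// Sketch: drop the oldest messages until the remaining history fits within the
// token budget, always keeping the final message (the current user question).
public List<Message> getCappedMessages(String context, MessageStack history) {
    var availableTokens = MAX_TOKENS - MAX_RESPONSE_TOKENS - getContextStringTokens(context);
    var cappedHistory = new ArrayList<>(history.messages());

    while (cappedHistory.size() > 1 && getTokenCount(cappedHistory) > availableTokens) {
        cappedHistory.remove(0);
    }
    return cappedHistory;
}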
@@ -9,13 +9,13 @@
package com.meta.chatbridge.store;

public class LLMContextManager {
private static String context = "";
private final String context;

public static void setContext(String newContext) {
context = newContext;
public LLMContextManager(String context) {
this.context = context;
}

public static String getContext() {
public String getContext() {
return context;
}
}
43 changes: 43 additions & 0 deletions src/test/java/com/meta/chatbridge/llm/LlamaAWSHandlerTest.java
@@ -0,0 +1,43 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.chatbridge.llm;

import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.store.LLMContextManager;
import com.meta.chatbridge.store.MessageStack;

import java.time.Instant;

import software.amazon.awssdk.regions.Region;

public class LlamaAWSHandlerTest {

public static void main(String[] args) {
Region region = Region.US_EAST_2;
String endpoint = "jumpstart-dft-meta-textgeneration-llama-2-7b-f";
LLMContextManager context = new LLMContextManager("Always answer with emojis");

LlamaAWSHandler llmhandler = new LlamaAWSHandler(context, endpoint, region);

FBMessage userMessage = new FBMessage(
Instant.now(),
Identifier.from("123"),
Identifier.from("456"),
Identifier.from("789"),
"How to go from San Francisco to NY?",
Message.Role.USER
);
MessageStack<FBMessage> messageStack = MessageStack.of(userMessage);

Message response = llmhandler.handle(messageStack);
System.out.println(response.message());
}
}
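LlamaAWSHandlerTest drives the handler from a main method against a live endpoint. A hedged sketch of the same check as a test method, if the project's test runner is preferred — JUnit 5 on the classpath, a reachable endpoint, and default AWS credentials are assumptions here:

import static org.junit.jupiter.api.Assertions.assertNotNull;

import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.store.LLMContextManager;
import com.meta.chatbridge.store.MessageStack;
import java.time.Instant;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.regions.Region;

class LlamaAWSHandlerIT {

  @Test
  void handleReturnsAssistantMessage() {
    LLMContextManager context = new LLMContextManager("Always answer with emojis");
    LlamaAWSHandler handler =
        new LlamaAWSHandler(context, "jumpstart-dft-meta-textgeneration-llama-2-7b-f", Region.US_EAST_2);

    FBMessage userMessage =
        new FBMessage(
            Instant.now(),
            Identifier.from("123"),
            Identifier.from("456"),
            Identifier.from("789"),
            "How to go from San Francisco to NY?",
            Message.Role.USER);

    Message response = handler.handle(MessageStack.of(userMessage));
    assertNotNull(response.message());
  }
}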
