Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AI Tutor] CT-562: Add s3 system prompt #58486

Merged
merged 15 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 20 additions & 9 deletions apps/src/aiTutor/chatApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ const logViolationDetails = (response: OpenaiChatCompletionMessage) => {
export async function postOpenaiChatCompletion(
messagesToSend: OpenaiChatCompletionMessage[],
levelId?: number,
tutorType?: AITutorTypesValue
tutorType?: AITutorTypesValue,
levelInstructions?: string
): Promise<OpenaiChatCompletionMessage | null> {
const payload = levelId
? {levelId: levelId, messages: messagesToSend, type: tutorType}
: {messages: messagesToSend, type: tutorType};
? {
levelId: levelId,
messages: messagesToSend,
type: tutorType,
levelInstructions,
}
: {messages: messagesToSend, type: tutorType, levelInstructions};

const response = await HttpClient.post(
CHAT_COMPLETION_URL,
Expand All @@ -57,7 +63,7 @@ export async function postOpenaiChatCompletion(
if (response.ok) {
return await response.json();
} else {
return null;
throw new Error('Error getting chat completion response');
}
}

Expand All @@ -74,14 +80,13 @@ const formatForChatCompletion = (
* to `postOpenaiChatCompletion`, then returns the status of the response and assistant message if successful.
*/
export async function getChatCompletionMessage(
systemPrompt: string,
formattedQuestion: string,
chatMessages: ChatCompletionMessage[],
levelId?: number,
tutorType?: AITutorTypesValue
tutorType?: AITutorTypesValue,
levelInstructions?: string
): Promise<ChatCompletionResponse> {
const messagesToSend = [
{role: Role.SYSTEM, content: systemPrompt},
...formatForChatCompletion(chatMessages),
{role: Role.USER, content: formattedQuestion},
];
Expand All @@ -90,7 +95,8 @@ export async function getChatCompletionMessage(
response = await postOpenaiChatCompletion(
messagesToSend,
levelId,
tutorType
tutorType,
levelInstructions
);
} catch (error) {
MetricsReporter.logError({
Expand All @@ -100,7 +106,12 @@ export async function getChatCompletionMessage(
});
}

if (!response) return {status: Status.ERROR};
if (!response)
return {
status: Status.ERROR,
assistantResponse:
'There was an error processing your request. Please try again.',
};

switch (response.status) {
case ShareFilterStatus.Profanity:
Expand Down
3 changes: 0 additions & 3 deletions apps/src/aiTutor/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ import {
AITutorInteractionStatus as Status,
} from '@cdo/apps/aiTutor/types';

export const systemPrompt =
'You are a tutor in a high school classroom where the students are learning Java using the Code.org curriculum. Answer their questions in plain, easy-to-understand English. Do not write any code. Do not answer the question if it is not about Java or computer programming. Please format all responses in Markdown where appropriate. Use four spaces for indentation instead of tabs.';

// Initial messages we set when the user selects a tutor type.
// General Chat
export const generalChatMessage = {
Expand Down
14 changes: 2 additions & 12 deletions apps/src/aiTutor/redux/aiTutorRedux.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import _ from 'lodash';
import {getChatCompletionMessage} from '@cdo/apps/aiTutor/chatApi';
import {createSlice, PayloadAction, createAsyncThunk} from '@reduxjs/toolkit';
import {systemPrompt as baseSystemPrompt} from '@cdo/apps/aiTutor/constants';
import {savePromptAndResponse} from '../interactionsApi';
import {
Role,
Expand All @@ -20,7 +19,6 @@ export interface AITutorState {
aiResponse: string | undefined;
chatMessages: ChatCompletionMessage[];
isWaitingForChatResponse: boolean;
chatMessageError: boolean;
isChatOpen: boolean;
}

Expand All @@ -42,7 +40,6 @@ const initialState: AITutorState = {
aiResponse: '',
chatMessages: initialChatMessages,
isWaitingForChatResponse: false,
chatMessageError: false,
isChatOpen: false,
};

Expand Down Expand Up @@ -72,15 +69,8 @@ export const askAITutor = createAsyncThunk(
scriptId: aiTutorState.aiTutor.scriptId,
};

let systemPrompt = baseSystemPrompt;
const levelInstructions = instructionsState.instructions.longInstructions;

if (levelInstructions.length > 0) {
systemPrompt +=
'\n Here are the student instructions for this level: ' +
levelInstructions;
}

const storedMessages = aiTutorState.aiTutor.chatMessages;
const newMessage: ChatCompletionMessage = {
role: Role.USER,
Expand All @@ -91,11 +81,11 @@ export const askAITutor = createAsyncThunk(

const formattedQuestion = formatQuestionForAITutor(chatContext);
const chatApiResponse = await getChatCompletionMessage(
systemPrompt,
formattedQuestion,
storedMessages,
levelContext.levelId,
chatContext.actionType
chatContext.actionType,
levelInstructions
);
thunkAPI.dispatch(
updateLastChatMessage({
Expand Down
1 change: 1 addition & 0 deletions apps/src/aiTutor/views/AssistantMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const AssistantMessage: React.FC<AssistantMessageProps> = ({message}) => {

const shouldRenderFeedbackButtons =
message.id &&
message.status !== Status.ERROR &&
message.status !== Status.PROFANITY_VIOLATION &&
message.status !== Status.PII_VIOLATION;

Expand Down
5 changes: 1 addition & 4 deletions apps/src/aiTutor/views/UserMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ const PROFANITY_VIOLATION_USER_MESSAGE =
'This chat has been hidden because it is inappropriate.';
const PII_VIOLATION_USER_MESSAGE =
'This chat has been hidden because it contains personal information.';
const ERROR_USER_MESSAGE =
'There was an error getting a response. Please try again.';

const statusToStyleMap = {
[Status.OK]: style.userMessage,
[Status.UNKNOWN]: style.userMessage,
[Status.PROFANITY_VIOLATION]: style.profaneMessage,
[Status.PII_VIOLATION]: style.piiMessage,
[Status.ERROR]: style.errorMessage,
[Status.ERROR]: style.userMessage,
};

const getMessageText = (status: string, chatMessageText: string) => {
Expand All @@ -34,7 +32,6 @@ const getMessageText = (status: string, chatMessageText: string) => {
case Status.PII_VIOLATION:
return PII_VIOLATION_USER_MESSAGE;
case Status.ERROR:
return ERROR_USER_MESSAGE;
default:
return chatMessageText;
}
Expand Down
70 changes: 64 additions & 6 deletions dashboard/app/controllers/openai_chat_controller.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
class OpenaiChatController < ApplicationController
S3_AI_BUCKET = 'cdo-ai'.freeze
S3_TUTOR_SYSTEM_PROMPT_PATH = 'tutor/system_prompt.txt'.freeze

include OpenaiChatHelper
authorize_resource class: false

def s3_client
@s3_client ||= AWS::S3.create_client
end

# POST /openai/chat_completion
def chat_completion
unless has_required_messages_param?
Expand All @@ -16,16 +23,19 @@ def chat_completion
# If the content is inappropriate, we skip sending to OpenAI and instead hardcode a warning response on the front-end.
return render(status: :ok, json: {status: filter_result.type, flagged_content: filter_result.content}) if filter_result

# The system prompt is stored server-side so we need to prepend it to the student's messages
system_prompt = read_file_from_s3(S3_TUTOR_SYSTEM_PROMPT_PATH)

# Determine if the level is validated and fetch test file contents if it is
test_file_contents = ""
if validated_level?
level_id = params[:levelId]
test_file_contents = get_validated_level_test_file_contents(level_id)
messages = params[:messages]
messages.first["content"] = messages.first["content"] + " The contents of the test file are: #{test_file_contents}"
messages.second["content"] = "The student's code is: " + messages.second["content"]
else
messages = params[:messages]
end

updated_system_prompt = add_content_to_system_prompt(system_prompt, params[:levelInstructions], test_file_contents)
messages = prepend_system_prompt(updated_system_prompt, params[:messages])

response = OpenaiChatHelper.request_chat_completion(messages)
chat_completion_return_message = OpenaiChatHelper.get_chat_completion_response_message(response)
return render(status: chat_completion_return_message[:status], json: chat_completion_return_message[:json])
Expand All @@ -39,7 +49,55 @@ def validated_level?
params[:type].present? && params[:type] == 'validation'
end

def get_validated_level_test_file_contents(level_id)
def add_content_to_system_prompt(system_prompt, level_instructions, test_file_contents)
if level_instructions.present?
system_prompt += "\n Here are the student instructions for this level: #{level_instructions}"
end

if test_file_contents.present?
system_prompt += "\n The contents of the test file are: #{test_file_contents}"
end

system_prompt
end

private def prepend_system_prompt(system_prompt, messages)
system_prompt_message = {
content: system_prompt,
role: "system"
}

messages.unshift(system_prompt_message)
messages
end

private def read_file_from_s3(key_path)
full_s3_path = "#{S3_AI_BUCKET}/#{key_path}"
cache_key = "s3_file:#{full_s3_path}"
unless Rails.env.development?
cached_content = CDO.shared_cache.read(cache_key)
return cached_content if cached_content.present?
end

if Rails.env.development?
local_path = File.join("local-aws", S3_AI_BUCKET, key_path)
if File.exist?(local_path)
puts "Note: Reading AI prompt from local file: #{key_path}"
return File.read(local_path)
end
end

# Note: We will hit this codepath in dev if the file is not found locally
content = s3_client.get_object(bucket: S3_AI_BUCKET, key: key_path).body.read

# In production and test, cache the content after fetching it from S3
unless Rails.env.development?
CDO.shared_cache.write(cache_key, content, expires_in: 1.hour)
end
return content
end

private def get_validated_level_test_file_contents(level_id)
level = Level.find(level_id)

unless level
Expand Down