In [35]:
import json
import sagemaker as sm

# Resolve the AWS region from the default SageMaker session. The public
# `boto_region_name` property is used instead of the private `_region_name`
# attribute so the cell does not depend on SDK internals.
sagemaker_session = sm.session.Session()
region = sagemaker_session.boto_region_name

# This is the application metadata that contains any model endpoints we have access to.
# 'models' lists the text-generation endpoints, 'embedding' the embedding
# endpoints; each entry maps a display name to an endpoint / model id.
application_metadata = {
    'models': [
        {'name': 'Anthropic Claude (Bedrock)', 'endpoint': 'anthropic.claude-v2'},
        {'name': 'AI21 Labs J2 Ultra (Bedrock)', 'endpoint': 'ai21.j2-jumbo-instruct'},
        {'name': 'Amazon Titan Large (Bedrock)', 'endpoint': 'amazon.titan-tg1-large'},
        {'name': 'LLAMA-2 7B (SageMaker)', 'endpoint': 'llama-2-7b'},
        {'name': 'Falcon 7B Instruct (SageMaker)', 'endpoint': 'falcon-7b-instruct'},
    ],
    'embedding': [
        {'name': 'Amazon Titan Embed (Bedrock)', 'endpoint': 'amazon.titan-e1t-medium'},
    ],
    'region': region,
}

# Write the config through a context manager so the file is flushed and
# closed deterministically before the Streamlit app reads it.
with open('demo_fewshot_config.json', 'w') as config_file:
    json.dump(application_metadata, config_file)



In [61]:
%%writefile demo_fewshot.py
import json
import time
import boto3
import numpy as np
import streamlit as st

from nltk.translate.bleu_score import sentence_bleu
st.set_page_config(layout="wide")


def get_description():
    """Return the long-form description shown in the "tell me more" expander.

    The text is a single paragraph (no line breaks), assembled here one
    sentence per source line for readability.
    """
    return (
        "This Streamlit app demonstrates how to build effective few-shot prompts for large language models hosted on Amazon SageMaker and Bedrock. "
        "It allows the user to interactively construct a prompt, generate an initial response from the model, provide instructions to improve the response, and finalize the few-shot example. "
        "The code loads model metadata, provides helper functions to call SageMaker and Bedrock endpoints, builds the Streamlit interface, and logs the user's few-shot examples. "
        "Key steps include initializing the app state, getting user input on the prompt and instructions, querying the model endpoint with the prompt to generate responses, and checking consistency by testing similar prompts. "
        "Each few-shot example that meets consistency thresholds is added to the history. "
        "The user can iterate to build effective prompts that provide informative responses tailored to their use case."
    )


def get_about():
    """Return the short "About this demo" blurb shown in the sidebar.

    Returns:
        str: A single-paragraph description of why few-shot prompts are
        useful with smaller models.
    """
    # "compat models" in the earlier wording was a typo for "compact models".
    text = '''This demo shows the power of few-shot prompts when applied to compact models. You can leverage more powerful/flexible models when designing the example, and switch to a light-weight model when moving to production. This can be time and cost effective.'''
    return text


def download_logs():
    """Serialize the few-shot history for the download button.

    Returns:
        str: ``st.session_state['history']`` encoded as JSON with a
        4-space indent.
    """
    return json.dumps(st.session_state['history'], indent=4)


def initialize_session():
    """Reset every session-state entry the app relies on.

    All text-valued entries are set to the empty string and ``'history'``
    is set to a fresh empty list.

    Returns:
        int: Always 0.
    """
    blank_text_keys = (
        'prompt',
        'similar_prompt',
        'initial_response',
        'initial_detail',
        'full_response',
        'full_detail',
        'finished_response',
        'feasible',
        'consistent',
        'feasible_model',
    )
    for key in blank_text_keys:
        st.session_state[key] = ''
    st.session_state['history'] = []
    return 0


def get_metadata(filepath):
    """Load the demo configuration and create the AWS runtime clients.

    Args:
        filepath: Path to a JSON file with the keys ``'models'`` and
            ``'embedding'`` (each a list of ``{'name', 'endpoint'}``
            objects) and ``'region'``.

    Returns:
        dict: ``'models'`` and ``'embedding'`` as name -> endpoint mappings,
        ``'region'`` as read from the file, and ``'sagemaker'`` /
        ``'bedrock'`` holding the boto3 clients, or ``None`` for a client
        that could not be created.
    """
    # Use a context manager so the config file handle is always closed.
    with open(filepath, 'r') as config_file:
        config = json.load(config_file)

    models = {entry['name']: entry['endpoint'] for entry in config['models']}
    embedding = {entry['name']: entry['endpoint'] for entry in config['embedding']}
    region = config['region']

    # Client creation stays best-effort (the app still loads with one
    # backend missing), but only Exception is caught so KeyboardInterrupt /
    # SystemExit propagate, and the reason is reported instead of swallowed.
    try:
        sagemaker = boto3.client('sagemaker-runtime', region_name=region)
    except Exception as err:
        print(f'Could not create the SageMaker runtime client: {err}')
        sagemaker = None

    try:
        # NOTE(review): endpoint_url is pinned to us-east-1 regardless of
        # `region` -- confirm this is intended for other regions.
        bedrock = boto3.client(service_name='bedrock', region_name=region, endpoint_url='https://bedrock.us-east-1.amazonaws.com')
    except Exception as err:
        print(f'Could not create the Bedrock client: {err}')
        bedrock = None

    return {
        'models': models,
        'embedding': embedding,
        'region': region,
        'sagemaker': sagemaker,
        'bedrock': bedrock,
    }


def get_model_tags(endpoint):
    """Return the ``(human_tag, robot_tag)`` pair used to wrap a prompt.

    The endpoint id is matched by substring, in order: ``'claude'``,
    ``'ai21.j2'``, ``'titan'``. Any other endpoint gets a blank-line
    human tag and an empty robot tag.
    """
    tags_by_marker = (
        ('claude', ('\n\nHuman:', '\n\nAssistant:')),
        ('ai21.j2', ('\n##\n', '')),
        ('titan', ('\nUser:', '\nBot:')),
    )
    for marker, tags in tags_by_marker:
        if marker in endpoint:
            return tags
    return '\n\n', ''


def call_bedrock(body, endpoint, attempts=5):
    """Invoke a Bedrock model, retrying on failure.

    Args:
        body: JSON-encoded request body.
        endpoint: Bedrock model id, passed as ``modelId``.
        attempts: Maximum number of invocation attempts.

    Returns:
        The ``invoke_model`` response of the first successful attempt, or
        ``None`` when every attempt raised (or ``attempts`` is not positive).
    """
    for attempt in range(1, attempts + 1):
        try:
            return st.session_state['md']['bedrock'].invoke_model(
                body=body,
                modelId=endpoint,
                accept='application/json',
                contentType='application/json'
            )
        except Exception as e:
            print(e)
            # Back off only when another attempt follows; sleeping after
            # the final failure would just delay returning None.
            if attempt < attempts:
                time.sleep(2)
    return None


def call_sagemaker(body, endpoint, eula=False):
    """Invoke a SageMaker endpoint with a JSON body.

    Args:
        body: JSON-encoded request body.
        endpoint: Name of the SageMaker endpoint.
        eula: When exactly ``True``, ``accept_eula=true`` is sent as the
            request's custom attributes.

    Returns:
        The response of the client's ``invoke_endpoint`` call.
    """
    request = {
        'EndpointName': endpoint,
        'ContentType': "application/json",
        'Body': body,
    }
    if eula is True:
        request['CustomAttributes'] = "accept_eula=true"
    return st.session_state['md']['sagemaker'].invoke_endpoint(**request)


@st.cache_data
def query_embedding(endpoint, payload):
    """Fetch an embedding for ``payload['prompt']``.

    Only endpoints whose id contains ``'titan'`` are handled; for any other
    endpoint the function returns ``None`` without making a request.

    Returns:
        The ``'embedding'`` field of the decoded Bedrock response, or
        ``None`` for an unsupported endpoint.
    """
    if 'titan' not in endpoint:
        return None
    request_body = json.dumps({"inputText": payload['prompt']})
    raw = call_bedrock(request_body, endpoint)
    decoded = json.loads(raw.get('body').read())
    return decoded.get('embedding')


@st.cache_data
def query_endpoint(endpoint, payload):
    """Send ``payload`` to a text-generation endpoint and return its text.

    The request and response shape is chosen by substring match on
    ``endpoint``, checked in this order: ``'llama'``, ``'falcon'``
    (SageMaker), then ``'titan'``, ``'claude'``, ``'ai21.j2'`` (Bedrock).

    Args:
        endpoint: Endpoint name / model id.
        payload: Dict providing ``'prompt'``, ``'max_len'``, ``'top_p'``
            and ``'temp'``.

    Returns:
        The generated text extracted from the response, or ``None`` when
        the endpoint matches none of the known families.
    """
    if 'llama' in endpoint:
        request = {
            "inputs": [[{"role": "user", "content": payload['prompt']}]],
            "parameters": {
                "max_new_tokens": payload['max_len'],
                "top_p": payload['top_p'],
                "temperature": payload['temp'],
            },
        }
        raw = call_sagemaker(json.dumps(request), endpoint, eula=True)
        decoded = json.loads(raw['Body'].read().decode('utf-8'))
        return decoded[0]['generation']['content']

    if 'falcon' in endpoint:
        request = {
            "inputs": payload['prompt'],
            "parameters": {
                "max_new_tokens": payload['max_len'],
                # top_p is capped at 0.99 for this endpoint family.
                "top_p": min(0.99, payload['top_p']),
                "temperature": payload['temp'],
            },
        }
        raw = call_sagemaker(json.dumps(request), endpoint)
        decoded = json.loads(raw['Body'].read().decode('utf-8'))
        return decoded[0]['generated_text']

    if 'titan' in endpoint:
        request = {
            'inputText': payload['prompt'],
            'textGenerationConfig': {
                'maxTokenCount': payload['max_len'],
                'temperature': payload['temp'],
                'topP': payload['top_p'],
            },
        }
        raw = call_bedrock(json.dumps(request), endpoint)
        decoded = json.loads(raw.get('body').read())
        return decoded.get('results')[0].get('outputText')

    if 'claude' in endpoint:
        request = {
            'prompt': payload['prompt'],
            'max_tokens_to_sample': payload['max_len'],
            'temperature': payload['temp'],
            'top_p': payload['top_p'],
        }
        raw = call_bedrock(json.dumps(request), endpoint)
        decoded = json.loads(raw.get("body").read())
        return decoded.get("completion")

    if 'ai21.j2' in endpoint:
        request = {
            'prompt': payload['prompt'],
            'maxTokens': payload['max_len'],
            'temperature': payload['temp'],
            'topP': payload['top_p'],
            'stopSequences': ['##'],
        }
        raw = call_bedrock(json.dumps(request), endpoint)
        decoded = json.loads(raw.get("body").read())
        return decoded.get("completions")[0].get("data").get("text")

    return None


def query_model(endpoint, payload):
    """Wrap the prompt in the model's chat tags and query the endpoint.

    ``payload['prompt']`` is overwritten in place with
    ``human_tag + prompt + robot_tag`` before the request is made.

    Returns:
        Whatever ``query_endpoint`` returns for the tagged prompt.
    """
    payload['prompt'] = payload['human_tag'] + payload['prompt'] + payload['robot_tag']
    return query_endpoint(endpoint, payload)


def query_model_w_instructions(endpoint, payload):
    """Ask the model to rewrite a document according to instructions.

    ``payload['prompt']`` is treated as the document and
    ``payload['instructions']`` as the rewrite instructions; both are
    embedded in a tagged template that then replaces ``payload['prompt']``
    in place.

    Returns:
        Whatever ``query_endpoint`` returns for the templated prompt.
    """
    opening_tag = payload['human_tag']
    closing_tag = payload['robot_tag']
    document = payload['prompt']
    instructions = payload['instructions']
    payload['prompt'] = f'''{opening_tag} Read the following DOCUMENT and expand using the instructions provided.
\n\nDOCUMENT:<document>{document}</document>\n\n
Read the following INSTRUCTIONS below and re-write the DOCUMENT above.
\n\nINSTRUCTIONS:<instructions>{instructions}</instructions>\n\n
Re-write the DOCUMENT.{closing_tag}'''
    return query_endpoint(endpoint, payload)


def check_feasibility(endpoint, payload):
    """Score how closely the model reproduces the finished response.

    The stored prompt and finished response are given to the model as a
    single worked example, followed by the same prompt again. The new
    output is compared to the finished response with sentence-level BLEU.

    Returns:
        The BLEU score multiplied by 100 and rounded to 2 decimals.
    """
    human_tag = payload['human_tag']
    robot_tag = payload['robot_tag']
    state = st.session_state
    prompt = state['prompt']
    finished = state['finished_response']
    payload['prompt'] = f'''{human_tag} {prompt}{robot_tag} {finished}\n\n{human_tag} {prompt}{robot_tag}'''
    generated = query_endpoint(endpoint, payload)

    # str.split() with no arguments already splits on any run of whitespace
    # (newlines included), so no explicit normalisation is needed first.
    reference_tokens = finished.split()
    candidate_tokens = generated.split()
    score = sentence_bleu([reference_tokens], candidate_tokens)
    return np.round(score * 100, 2)


def check_consistency(endpoint, payload):
    """Generate a response to a similar prompt using the few-shot example.

    The stored prompt and finished response form the example. The follow-up
    request is the similar prompt joined, with spaces, to the initial and
    finishing instructions. ``payload['prompt']`` is overwritten in place.

    Returns:
        Whatever ``query_endpoint`` returns for the follow-up request.
    """
    human_tag = payload['human_tag']
    robot_tag = payload['robot_tag']
    state = st.session_state
    example_prompt = state['prompt']
    example_response = state['finished_response']
    follow_up = ' '.join([state['similar_prompt'], state['initial_detail'], state['full_detail']])
    payload['prompt'] = f'''{human_tag} {example_prompt}{robot_tag} {example_response}\n\n{human_tag} {follow_up}{robot_tag}'''
    return query_endpoint(endpoint, payload)


def sidebar():
    """Render the sidebar and collect the user's generation settings.

    Draws the about text, the clear-history and download-logs buttons, the
    model selectors, the generation sliders and a worked example.

    Returns:
        dict: ``'endpoint'`` and ``'embedding'`` (ids looked up from the
        selected names), ``'max_len'``, ``'top_p'``, ``'temp'``,
        ``'model_name'`` and ``'embedding_name'``.
    """
    panel = st.sidebar
    panel.header('About this demo')
    panel.write(get_about())
    panel.header('User Preferences')

    left, right = panel.columns(2)
    with left:
        if st.button('Clear History'):
            st.session_state['history'] = []
    with right:
        st.download_button('Download Logs', download_logs(), file_name="logs_fewshot.json", mime="application/json")

    metadata = st.session_state['md']
    embed_name = panel.selectbox('Select Embedding Model', metadata['embedding'].keys())
    model_name = panel.selectbox('Select Generation Model', metadata['models'].keys())

    # Initial temperature depends on the model family named in the display
    # name; the first matching marker wins, anything else starts at .25.
    starting_temps = (('Claude', 1.), ('AI21', .75), ('Titan', .05))
    temp_start = next(
        (value for marker, value in starting_temps if marker in model_name),
        .25,
    )

    max_len = panel.slider('Max Generation Length', 500, 9000, 2000, 500)
    top_p = panel.slider('Top p', 0., 1., 1., .01)
    temp = panel.slider('Temperature', 0.01, 1., temp_start, .01)

    panel.write('---')
    panel.subheader('Example:')
    panel.write('**Prompt:** Write a short blog about using Amazon SageMaker to create and train an image classification model.')
    panel.write('**Initial Instructions:** Add a section showing example code for creating and training a SageMaker estimator.')
    panel.write('**Finishing Instructions:** Copy-edit the blog post re-write it in the style of an AWS blog.')
    panel.write('**Similar Prompt:** Write a short blog about using Amazon SageMaker to create and train an XGBoost model.')

    return {
        'endpoint': metadata['models'][model_name],
        'embedding': metadata['embedding'][embed_name],
        'max_len': max_len,
        'top_p': top_p,
        'temp': temp,
        'model_name': model_name,
        'embedding_name': embed_name,
    }


def main(params):
    """Render the main page and drive the few-shot building workflow.

    The left column walks through the steps, each one shown once the
    session state of the previous step is populated: prompt -> initial
    response -> response updated with instructions -> finalized response ->
    feasibility / consistency checks -> add to history. The right column
    lists the few-shot examples stored in ``st.session_state['history']``.

    Args:
        params: Settings dict as returned by ``sidebar()``; this function
            reads ``'endpoint'``, ``'model_name'``, ``'max_len'``,
            ``'temp'`` and ``'top_p'``.
    """
    st.title('Create Few-Shot Prompts with Amazon Bedrock and SageMaker')
    with st.expander('Tell me more about how this works'):
        st.write(get_description())
    payload = {
        'prompt': '',
        # Key was previously misspelled 'instructons'; the helpers read
        # payload['instructions'].
        'instructions': '',
        'response': '',
        'max_len': params['max_len'],
        'temp': params['temp'],
        'top_p': params['top_p']
    }
    endpoint = params['endpoint']
    model_name = params['model_name']
    human_tag, robot_tag = get_model_tags(endpoint=endpoint)
    payload['human_tag'] = human_tag
    payload['robot_tag'] = robot_tag

    prompt = st.chat_input("Let's build a few-shot example! Write a prompt.")

    col1, col2 = st.columns(2)
    with col1:
        # Step 1: a newly submitted prompt produces the initial response.
        if prompt:
            st.session_state['prompt'] = prompt
            payload['prompt'] = prompt
            payload['instructions'] = ''
            st.session_state['initial_response'] = query_model(endpoint, payload)

        if st.session_state['prompt']:
            with st.chat_message('user'):
                st.write(f"**Prompt:** {st.session_state['prompt']}")

        # Step 2: show the (editable) initial response and let the user
        # request a rewrite with extra instructions.
        if st.session_state['initial_response']:
            with st.chat_message('assistant'):
                st.session_state['initial_response'] = st.text_area("**Initial Response:**", st.session_state['initial_response'], height=400)
            st.session_state['initial_detail'] = st.text_input("**Initial Instructions:**", "")

            if st.button('Update Response'):
                payload['prompt'] = st.session_state['initial_response']
                payload['instructions'] = st.session_state['initial_detail']
                st.session_state['full_response'] = query_model_w_instructions(endpoint, payload)

        # Step 3: show the updated response and let the user finalize it
        # with finishing instructions.
        if st.session_state['full_response']:
            with st.chat_message('assistant'):
                st.session_state['full_response'] = st.text_area("**Full Response:**", st.session_state['full_response'], height=400, key='ta_1')
            st.session_state['full_detail'] = st.text_input("**Finishing Instructions:**", "")

            if st.button('Finalize Response'):
                payload['prompt'] = st.session_state['full_response']
                payload['instructions'] = st.session_state['full_detail']
                st.session_state['finished_response'] = query_model_w_instructions(endpoint, payload)

        # Step 4: show the finished response and offer the checks.
        if st.session_state['finished_response']:
            with st.chat_message('assistant'):
                st.session_state['finished_response'] = st.text_area("**Full Response:**", st.session_state['finished_response'], height=400, key='ta_2')

            st.session_state['similar_prompt'] = st.text_input('To check the consistency of this few-shot example, type a new (similar) prompt.')
            btn1, btn2, btn3 = st.columns(3)
            with btn1:
                if st.button('Check Feasibility'):
                    st.session_state['feasible'] = check_feasibility(endpoint, payload)
                    st.session_state['feasible_model'] = model_name
            with btn2:
                if st.button('Check Consistency'):
                    st.session_state['consistent'] = check_consistency(endpoint, payload)
            with btn3:
                # An example is stored only after both checks have produced
                # a truthy result.
                if st.button('Add Few-Shot Example') and st.session_state['feasible'] and st.session_state['consistent']:
                    st.session_state['history'].append({
                        'prompt': st.session_state['prompt'],
                        'response': st.session_state['finished_response'],
                        'feasible': st.session_state['feasible'],
                        'feasible_model': st.session_state['feasible_model'],
                        'similar_prompt': st.session_state['similar_prompt'],
                        'similar_response': st.session_state['consistent'],
                        'instructions': [st.session_state['initial_detail'], st.session_state['full_detail']]
                    })

        if st.session_state['feasible']:
            with st.chat_message('assistant'):
                st.write(f"The feasibility score (BLEU score) is: {st.session_state['feasible']}%. You should interpret this as the model's ability to parrot the finished response given only the initial prompt.\n\nNOTE: Lowering the temperature may result in a higher Feasibility score. Additionally, it should be noted that this metric is somewhat independent of the model's ability to perform well with similar prompts.")

        if st.session_state['consistent']:
            with st.chat_message('assistant'):
                st.text_area("**Does this response follow the instructions you detailed above?**", st.session_state['consistent'], height=400, key='ta_3')

    with col2:
        for msg in st.session_state['history']:
            with st.chat_message('user'):
                # Show the score and model recorded with THIS entry; reading
                # the current session values would relabel every stored
                # example with the latest check's result.
                st.write(f"**Prompt ({msg['feasible']}% feasible with {msg['feasible_model']}):** {msg['prompt']}")
                st.write(f"**Response:**\n{msg['response']}")


if __name__ == '__main__':
    # Config file read by get_metadata() on every execution of this script.
    FILEPATH = 'demo_fewshot_config.json'

    # Seed the session defaults only when 'prompt' is not yet present, so
    # existing session values are left untouched.
    if 'prompt' not in st.session_state:
        initialize_session()

    # Load the endpoint metadata and AWS clients, then draw the sidebar and
    # hand its settings to the main page.
    st.session_state['md'] = get_metadata(filepath=FILEPATH)
    main(sidebar())

Overwriting aws_demo_fewshot.py
