# Prompt Cache Lab - Multi-turn Chat

![graph](./output.png)

# 사전 설정

In [None]:
%load_ext autoreload
%autoreload 2
%pip install ipywidgets
%pip install boto3 botocore --upgrade
%pip install pandas
%pip install matplot

In [None]:
import boto3
# Import the submodule explicitly: `import botocore` alone does not guarantee
# that `botocore.config` is loaded.
import botocore.config

# botocore's Config documents `max_attempts` as the number of retries after the
# initial request, so this allows at most one retry per call.
retry_config = botocore.config.Config(
    retries={"max_attempts": 1, "mode": "standard"}
)
session = boto3.Session(
    region_name='us-west-2'
)

# Runtime client used for the Converse calls.
bedrock_client = session.client("bedrock-runtime", config=retry_config)

# List the Anthropic foundation model IDs available in this region.
print("\n== FM lists ==")
model_list = session.client("bedrock").list_foundation_models()['modelSummaries']
print('\n'.join(model['modelId'] for model in model_list if model['modelId'].startswith('anth')))

### Multi-turn 채팅에서 Prompt Cache 활용하기

In [None]:
import pandas as pd
import copy
class ConversationManager:
    """Multi-turn chat about one document via the Bedrock Converse API.

    The document is embedded in the system prompt, followed by a
    ``cachePoint`` block, so the system prompt can be served from the prompt
    cache on later turns. Each request also appends a ``cachePoint`` to the
    most recent user turns so the growing conversation prefix can be cached.

    Args:
        model_id: Bedrock model identifier passed to ``converse``.
        document: Text inserted into the system prompt.
    """

    # How many of the most recent user turns get a cachePoint per request.
    # The system prompt already carries one checkpoint.
    # NOTE(review): Bedrock caps cache checkpoints per request; confirm the
    # limit for the model in use before raising this.
    _MESSAGE_CACHE_POINTS = 2

    # Column names for get_usage(), in the order the usage tuples are stored.
    _USAGE_COLUMNS = ["LatencyMs", "CacheRead", "CacheWrite", "Inputs", "Outputs"]

    def __init__(self, model_id, document):
        # System prompt: instruction, the document, then a cache checkpoint
        # marking everything above it as cacheable.
        self._system_prompt = [
            {
                "text": "주어진 문서의 내용을 바탕으로 답변을 합니다.."
            },
            {
                "text": f"## document:\n{document} "
            },
            {
                "cachePoint": {
                    "type": "default"
                }
            }
        ]
        self._model_id = model_id
        self._history = []  # Converse-format messages (user / assistant turns)
        self._usage = []    # one tuple per request, ordered as _USAGE_COLUMNS

    def query(self, query):
        """Send ``query`` as a user turn and append the assistant reply to history.

        If the model call raises, the user turn is removed again before the
        exception propagates. This keeps the history alternating between user
        and assistant, which Converse expects, so a later query is not sent
        with two consecutive user messages.
        """
        self._history.append({
            'role': 'user',
            'content': [
                {
                    'text': query
                }
            ]
        })
        try:
            response = self._chat()
        except Exception:
            self._history.pop()
            raise
        self._history.append(response)

    def get_converstaion(self):
        """Return the history as display lines ``"User: ..."`` / ``"Bot: ..."``.

        The text blocks of each message are concatenated. Blocks without a
        ``text`` key are skipped instead of raising.
        """
        conversation = []
        for message in self._history:
            speaker = "User" if message.get('role') == 'user' else "Bot"
            text = "".join(
                block['text'] for block in message.get('content', []) if 'text' in block
            )
            conversation.append(f"{speaker}: {text}")
        return conversation

    # Correctly spelled alias. The misspelled name is kept for existing callers.
    get_conversation = get_converstaion

    def get_usage(self):
        """Return per-request latency and token counts, one row per turn.

        Passing ``columns=`` up front means an empty usage list yields an empty
        DataFrame with the expected columns, instead of a length-mismatch error.
        """
        return pd.DataFrame(self._usage, columns=self._USAGE_COLUMNS)

    def _chat(self):
        """Call Converse with the current history and return the assistant message.

        Also appends one usage tuple to ``self._usage``.
        """
        message_list = self._with_cache_points(self._history, self._MESSAGE_CACHE_POINTS)

        print(message_list)

        # Inference hyperparameters.
        inference_config = {
            'maxTokens': 4096,
            'temperature': 0,
            'topP': 1
        }

        # Converse API call.
        response = bedrock_client.converse(
            system=self._system_prompt,
            messages=message_list,
            modelId=self._model_id,
            inferenceConfig=inference_config
        )

        usage = response['usage']
        print(usage)

        # Cache token fields may be absent from the usage dict, hence .get(..., 0).
        self._usage.append((
            response['metrics']['latencyMs'],
            usage.get('cacheReadInputTokens', 0),
            usage.get('cacheWriteInputTokens', 0),
            usage['inputTokens'],
            usage['outputTokens'],
        ))
        return response['output']['message']

    @staticmethod
    def _with_cache_points(history, max_points):
        """Return a deep copy of ``history`` with a cachePoint on the last ``max_points`` user turns.

        The copy keeps the per-request checkpoints out of the stored history.
        """
        message_list = copy.deepcopy(history)
        remaining = max_points
        for message in reversed(message_list):
            if remaining == 0:
                break
            if message.get('role') == 'user':
                message['content'].append({"cachePoint": {"type": "default"}})
                remaining -= 1
        return message_list
        



In [None]:
# Read the article that is embedded in the system prompt; the final
# expression shows its character count as the cell output.
document_path = 'documents/prompt_caching_article.md'
with open(document_path, 'r', encoding='utf-8') as doc_file:
    document = doc_file.read()
len(document)

In [None]:
# Follow-up questions asked in order within a single conversation, so later
# turns can reuse the cached prefix of earlier ones.
questions = [
    "이 글의 전체 내용을 요약해주세요.",  # summarize the whole article
    "본문의 기술을 활용하여 해결할 수 있는 과제는 무엇인가요?",  # problems the technique can solve
    "본문을 이해하기 위해 필요한 배경 지식은 무엇인가요?",  # background knowledge needed
    "이 배경지식을 갖추었는지 확인하기 위한 질문을 만들어주세요.",  # quiz on that background
    "전체 내용을 이해하였는지 확인할 수 있는 질문을 다섯 개 만들어주세요.",  # five comprehension questions
]

In [None]:
# Run every question through one ConversationManager so each turn sees the
# full prior conversation.
model_id = 'anthropic.claude-3-5-haiku-20241022-v1:0'
conversation = ConversationManager(model_id=model_id, document=document)

for question in questions:
    conversation.query(question)

In [None]:
# Print each turn separated by a blank line.
print(*conversation.get_converstaion(), sep="\n\n")

In [None]:
# Per-turn latency and token usage, displayed as the cell output.
usage_table = conversation.get_usage()
usage_table

In [None]:
import matplotlib.pyplot as plt

df = conversation.get_usage()
columns_to_stack = ["CacheRead", "CacheWrite", "Inputs", "Outputs"]

# Token counts are drawn as stacked bars on the left axis. Latency is drawn as
# a line on a twin right axis that shares the x-axis (question turn).
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()

# Each bar segment starts at the running total of the columns before it.
bottoms = df[columns_to_stack].cumsum(axis=1).shift(1, axis=1, fill_value=0)
for column in columns_to_stack:
    ax1.bar(df.index, df[column], bottom=bottoms[column], label=column)

ax2.plot(df.index, df['LatencyMs'], color='red', marker='o', label='Latency')

ax1.set_xlabel("Question Turn")
ax1.set_ylabel("Token Usage")
ax2.set_ylabel("Latency (ms)")
ax1.set_title("Cache Read/Write and Latency")

# One legend covering the artists of both axes.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower right')

# Target each axis explicitly. After twinx(), pyplot's "current axes" is ax2,
# so plt.ylim(0) previously constrained only the latency axis.
ax1.set_xticks(df.index)
ax1.set_xticklabels(df.index + 1)  # 1-based turn numbers
ax1.set_ylim(bottom=0)
ax2.set_ylim(bottom=0)
plt.show()