<a href="https://colab.research.google.com/github/higherbar-ai/open-chat-studio-sim/blob/main/src/replay-conversations.ipynb" target="_parent"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"/></a>

# Replay conversations

This notebook replays conversations using an Open Chat Studio experiment.

## Running in Google Colab

Before running this notebook, you'll need to configure several secrets in Google Colab: click the key button in the left sidebar, add each secret, and turn on the toggle that gives this notebook access to it. These are the secrets used by this notebook:

- `OCS_API_KEY`: your Open Chat Studio API key
- `ATHINA_API_KEY`: your Athina API key (optional; only if you want to export results to Athina)
- `EXPERIMENT_ID`: the ID of the experiment you want to issue queries to
- `PARTICIPANT_ID`: the participant ID to use for the queries (defaults to "open-chat-studio-sim")

## Running in a local environment

The first time you run the first code cell in this notebook, it will output a template configuration file for you (at `~/.ocs/.env`, per the `config_path` setting in that cell). Edit that file to specify your configuration parameters (see above for their descriptions), then run the cell again.

## Selecting or uploading your conversations to replay

The second code cell will prompt you to select or upload a .csv file with the conversations you want to replay. This file should follow the format of experiment session exports in Open Chat Studio. The columns we rely on here are:

- `Message ID`: the unique identifier for the chat message
- `Message Type`: the chat message type (`human` or `ai`)
- `Message Content`: the chat message
- `Session ID`: the unique session ID for the conversation

## Where results go

The results of the queries will be saved to a file called `replayed_conversations.csv`. If you're running in Google Colab, click the folder button in the sidebar to view and download that file. If you're running locally, it will be written to the `ocs` subdirectory of your home directory (`~/ocs`).
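
When errors are recorded rather than raised (the `continue_on_error` setting in the first code cell), the replay cell later in this notebook writes failed steps with a response that starts with `ERROR: `. As a minimal sketch of how you might check a results file for such rows, the following uses only the standard library; the helper name `count_failed_replays` and the sample data are hypothetical and not part of this notebook.

```python
import csv
import io


def count_failed_replays(csv_text: str) -> tuple[int, int]:
    """Return (total_rows, failed_rows) for replayed-conversation CSV text."""
    reader = csv.DictReader(io.StringIO(csv_text))
    total = 0
    failed = 0
    for row in reader:
        total += 1
        # failed steps are recorded with a response beginning "ERROR: "
        if row["response"].startswith("ERROR: "):
            failed += 1
    return total, failed


# tiny hypothetical sample using the same columns as replayed_conversations.csv
sample = (
    '"message_id","session_id","replay_session_id","query","response","orig_response","context"\n'
    '1,"s1","r1","Hi","Hello!","Hi there","[]"\n'
    '3,"s1","","Help?","ERROR: timeout","Sure","[]"\n'
)
print(count_failed_replays(sample))  # (2, 1)
```

To use this on a real file, read its contents first, for example with `open(path, newline="", encoding="utf-8").read()`.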

If an Athina API key is configured, the results will also be exported to an Athina dataset.

In [None]:
# install support for Google Colab
%pip install colab-or-not

# set log level to WARNING
import logging
logging.basicConfig(level=logging.WARNING)

# Initialize our environment
from colab_or_not import NotebookBridge
env = NotebookBridge(
    github_repo="higherbar-ai/open-chat-studio-sim",
    requirements_path="requirements.txt",
    module_paths=["src/ocs_api.py", "src/ocs_simulation_support.py"],
    config_path="~/.ocs/.env",
    config_template={
        "OCS_API_KEY": "",
        "ATHINA_API_KEY": "",
        "EXPERIMENT_ID": "",
        "PARTICIPANT_ID": "open-chat-studio-sim",
    }
)
env.setup_environment()

# Internal configuration
api_timeout_seconds = 300      # how long to give API calls before timing out
api_num_retries = 3            # how many times to retry API calls before giving up
api_retry_delay_seconds = 2    # how long to wait between retries
continue_on_error = True       # whether to record errors and continue (if False, errors will halt execution)

# Get API keys from environment
ocs_api_key = env.get_setting("OCS_API_KEY")
athina_api_key = env.get_setting("ATHINA_API_KEY")
experiment_id = env.get_setting("EXPERIMENT_ID")
participant_id = env.get_setting("PARTICIPANT_ID", "open-chat-studio-sim")

# Validate required configuration
if not all([ocs_api_key, experiment_id]):
    raise ValueError("Please supply at least OCS_API_KEY and EXPERIMENT_ID in your secrets or configuration file.")

# Output files to ~/ocs directory if local, otherwise /content if Google Colab
if env.is_colab:
    output_path_prefix = "/content"
else:
    import os
    output_path_prefix = os.path.expanduser("~/ocs")
    os.makedirs(output_path_prefix, exist_ok=True)

# Initialize OCS API support
from ocs_api import OCSAPIClient    # type: ignore[import]
ocs_api_client = OCSAPIClient(
    api_key=ocs_api_key, 
    timeout_seconds=api_timeout_seconds, 
    num_retries=api_num_retries, 
    retry_wait_seconds=api_retry_delay_seconds
)

# Report results
print(f"Configuration loaded for {'Colab' if env.is_colab else 'local'} environment, OCS API initialized.")

## Select or upload your conversations to replay

The code cell below will prompt you to select or upload a .csv file. This file should follow the format of experiment session exports in Open Chat Studio. The columns we rely on here are:

- `Message ID`: the unique identifier for the chat message
- `Message Type`: the chat message type (`human` or `ai`)
- `Message Content`: the chat message
- `Session ID`: the unique session ID for the conversation
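
If you want to confirm that a file has these columns before replaying it, a check along the following lines may help. This is only a sketch: the helper name `missing_columns` and the sample text are hypothetical, and the notebook itself does not perform this check.

```python
import csv
import io

# the four columns this notebook relies on
REQUIRED_COLUMNS = ["Message ID", "Message Type", "Message Content", "Session ID"]


def missing_columns(csv_text: str) -> list[str]:
    """Return the required column names that are absent from the CSV header row."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader, [])  # empty input yields an empty header
    return [name for name in REQUIRED_COLUMNS if name not in header]


# hypothetical export that lacks the Session ID column
sample = "Message ID,Message Type,Message Content\n1,human,Hello\n"
print(missing_columns(sample))  # ['Session ID']
```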

In [None]:
# prompt for the CSV file with conversations to replay
to_replay_files = env.get_input_files("CSV file with conversations to replay")

# check for exactly one file
if len(to_replay_files) != 1:
    raise ValueError("Please select exactly one CSV file with conversations to replay.")

# convert to str first (in case a path-like object is returned), then check the extension case-insensitively
to_replay_file = str(to_replay_files[0])
if not to_replay_file.lower().endswith(".csv"):
    raise ValueError("Please select a CSV file with conversations to replay.")

## Replay conversations

The following code cell reads each human message that has an original AI response from the .csv selected or uploaded above, fetches a new AI response for each (using _the original conversation history_), and saves the results to the `replayed_conversations.csv` file.

`replayed_conversations.csv` will have the following columns:

- `message_id`: the unique identifier for the original query
- `session_id`: the unique identifier for the _original_ experiment session being replayed (links conversations)
- `replay_session_id`: the unique identifier for the _new_ experiment session created during replay
- `query`: the query sent to the AI assistant
- `response`: the response received from the AI assistant
- `orig_response`: the original response received from the AI assistant
- `context`: the raw conversation history included at the time of replay
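
To illustrate how each replay step is paired with the _original_ history, here is a simplified sketch of the pairing logic with the API calls left out. The function name `build_replay_steps` is hypothetical, and the sketch assumes that rows arrive grouped by session and in message order, as the replay cell also assumes.

```python
def build_replay_steps(rows):
    """Pair each human message with the AI reply that follows it, plus the prior history."""
    steps = []
    current_session = None
    history = []
    pending = None  # (message_id, content) of the most recent human message
    for row in rows:
        if row["Session ID"] != current_session:
            # a new conversation starts with an empty history
            current_session = row["Session ID"]
            history = []
            pending = None
        if row["Message Type"] == "human":
            pending = (row["Message ID"], row["Message Content"])
        elif row["Message Type"] == "ai" and pending is not None:
            message_id, query = pending
            steps.append({
                "message_id": message_id,
                "session_id": current_session,
                "query": query,
                "orig_response": row["Message Content"],
                "context": list(history),  # snapshot of the history before this exchange
            })
            # the history grows with the original exchange, not the replayed response
            history.append({"role": "user", "content": query})
            history.append({"role": "assistant", "content": row["Message Content"]})
    return steps


# hypothetical two-exchange conversation
rows = [
    {"Message ID": 1, "Message Type": "human", "Message Content": "Hi", "Session ID": "s1"},
    {"Message ID": 2, "Message Type": "ai", "Message Content": "Hello!", "Session ID": "s1"},
    {"Message ID": 3, "Message Type": "human", "Message Content": "Help?", "Session ID": "s1"},
    {"Message ID": 4, "Message Type": "ai", "Message Content": "Sure.", "Session ID": "s1"},
]
steps = build_replay_steps(rows)
print(len(steps), steps[1]["context"])
```

Because the history is built from the original responses, each replayed step sees the same context the original assistant saw, regardless of what earlier replayed steps returned.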

In [None]:
import csv
import pandas as pd
import json
import os

# load input file using pandas
conversations_to_replay = pd.read_csv(to_replay_file)

# run through and replay each step of each conversation
results = []
num_sessions = 0
orig_session_id = ""
orig_messages = []
user_message = ""
user_message_id = ""
session_id = ""
for index, row in conversations_to_replay.iterrows():
    if orig_session_id != row["Session ID"]:
        # initialize for new conversations
        orig_session_id = row["Session ID"]
        orig_messages = []
        user_message = ""
        user_message_id = ""
        num_sessions += 1
    
    if row["Message Type"] == "human":
        # remember user message, but only process when we get to the original AI response
        user_message = row["Message Content"]
        user_message_id = row["Message ID"]
    elif user_message and row["Message Type"] == "ai":
        # remember original AI response
        orig_response = row["Message Content"]
    
        # report out
        print(f"Replaying message {user_message_id} for session {orig_session_id}...")
        
        # replay conversation step, catching and logging any errors
        session_id = ""    # reset so a failed step doesn't record the previous step's session ID
        try:
            # create a new session for the step, including the original conversation history
            api_response = ocs_api_client.create_experiment_session(experiment_id, participant_id, orig_messages)
            session_id = api_response["id"]
        
            # send the user message to the experiment
            api_response = ocs_api_client.send_new_api_message(experiment_id, user_message, session_id)
            response = api_response["response"]
        except Exception as e:
            if continue_on_error:
                # log the error and continue to the next message
                logging.error(f"Continuing following error fetching conversation response: {str(e)}")
                response = f"ERROR: {str(e)}"
            else:
                # raise the error to halt execution
                raise

        # add to results
        results.append({
            "message_id": user_message_id,
            "session_id": orig_session_id,
            "replay_session_id": session_id,
            "query": user_message,
            "response": response,
            "orig_response": orig_response,
            "context": json.dumps(orig_messages)
        })
        
        # add original exchange to message history
        orig_messages.append({
            "role": "user",
            "content": user_message
        })
        orig_messages.append({
            "role": "assistant",
            "content": orig_response
        })

# save results to output .csv file
output_file = os.path.join(output_path_prefix, "replayed_conversations.csv")
output_rows = []
fieldnames=["message_id", "session_id", "replay_session_id", "query", "response", "orig_response", "context"]
with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC, escapechar='\\')
    writer.writeheader()
    for result in results:
        # output and record for potential next steps
        writer.writerow(result)
        output_rows.append(result)

# report results
print()
print(f"Replayed {num_sessions} conversations and saved {len(results)} results to {output_file}.")

## Optional: Export results to Athina dataset

If an Athina API key is configured, the results can be exported to an Athina dataset. The dataset will be named `replayed-conversations-{experiment_id}-{timestamp}` and will contain the rows from the `replayed_conversations.csv` file.
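
As a small illustration of the naming pattern, the sketch below builds a name with the same `%Y%m%d%H%M%S` timestamp format that the export cell uses. The helper `make_dataset_name` and the experiment ID `abc123` are hypothetical.

```python
from datetime import datetime


def make_dataset_name(experiment_id: str, now: datetime) -> str:
    """Build a dataset name of the form replayed-conversations-{experiment_id}-{timestamp}."""
    return f"replayed-conversations-{experiment_id}-{now.strftime('%Y%m%d%H%M%S')}"


# hypothetical experiment ID and a fixed timestamp for illustration
print(make_dataset_name("abc123", datetime(2025, 1, 2, 3, 4, 5)))
# replayed-conversations-abc123-20250102030405
```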

In [None]:
from ocs_simulation_support import athina_create_dataset    # type: ignore[import]

# optionally export the results to an Athina dataset
if athina_api_key:
    # push new dataset to Athina, using one timestamp for both the name and the description
    now = pd.Timestamp.now()
    dataset_name = f"replayed-conversations-{experiment_id}-{now.strftime('%Y%m%d%H%M%S')}"
    dataset_description = f"Replayed conversations for experiment {experiment_id} at {now}"
    try:
        dataset = athina_create_dataset(athina_api_key=athina_api_key, dataset_name=dataset_name, dataset_description=dataset_description, dataset_rows=output_rows)
    except Exception as e:
        print(f"Failed to create Athina dataset: {e}")
    else:
        print(f"Results exported to Athina dataset {dataset.id} (name: {dataset_name}).")