# Mistral Chat Completions with Elasticsearch Inference API

This notebook demonstrates how to set up a Mistral chat completion inference endpoint in Elasticsearch and stream chat responses through the inference API.

## Prerequisites
- Elasticsearch cluster 
- Elasticsearch API key
- Mistral API key

In [None]:
%pip install requests tqdm elasticsearch

In [30]:
import requests
import json
from typing import Generator
from tqdm import tqdm
from elasticsearch import Elasticsearch
import getpass

## Configuration

Set up your Elasticsearch and Mistral API credentials. For security, consider using environment variables.
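
A minimal sketch of the environment-variable approach, assuming you export the secrets before starting the notebook. The helper name `read_secret` and the variable names used with it are hypothetical; it falls back to a hidden prompt when the variable is not set.

```python
import getpass
import os


def read_secret(env_name: str, prompt: str) -> str:
    """Return the value of env_name if it is set, otherwise prompt without echo."""
    value = os.environ.get(env_name, "").strip()
    if value:
        return value
    # Fall back to an interactive, non-echoing prompt
    return getpass.getpass(prompt).strip()
```

With this in place, a credential could be loaded as `MISTRAL_API_KEY = read_secret("MISTRAL_API_KEY", "Enter your Mistral API key: ")`, so the key never has to be typed into the notebook when the variable is already exported.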

In [31]:
# Credentials - Enter your API keys securely
ELASTICSEARCH_URL = input("Enter your Elasticsearch URL: ").strip()
ELASTICSEARCH_API_KEY = getpass.getpass("Enter your Elasticsearch API key: ")
MISTRAL_API_KEY = getpass.getpass("Enter your Mistral API key: ")

In [None]:
# Configurations, no need to change these values
MISTRAL_MODEL = "mistral-large-latest"  # Mistral model to use
INFERENCE_ENDPOINT_NAME = (
    "mistral-embeddings-chat-completion"  # Name for the inference endpoint
)

ELASTICSEARCH_HEADERS = {
    "Authorization": f"ApiKey {ELASTICSEARCH_API_KEY}",
    "Content-Type": "application/json",
}

In [None]:
# Initialize Elasticsearch client
es_client = Elasticsearch(hosts=[ELASTICSEARCH_URL], api_key=ELASTICSEARCH_API_KEY)

## Create the Inference Endpoint

Create the Mistral chat completion endpoint if it doesn't exist.

In [None]:
print(
    f"Creating Mistral inference endpoint: {INFERENCE_ENDPOINT_NAME} at {ELASTICSEARCH_URL}"
)

try:
    # Create the inference endpoint using the Elasticsearch client
    response = es_client.inference.put(
        task_type="chat_completion",
        inference_id=INFERENCE_ENDPOINT_NAME,
        body={
            "service": "mistral",
            "service_settings": {"api_key": MISTRAL_API_KEY, "model": MISTRAL_MODEL},
        },
    )

    print("Inference endpoint created successfully!")
    print(f"Response: {json.dumps(response.body, indent=2)}")

except Exception as e:
    # If the endpoint already exists, that's okay; report anything else as an error
    if "already exists" in str(e).lower():
        print("✅ Inference endpoint already exists, continuing...")
    else:
        print(f"❌ Error creating inference endpoint: {str(e)}")

## Chat Streaming Functions

Let's create functions to handle streaming chat responses from the inference endpoint.
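
The parsing step can be tried offline, without a cluster. The sketch below assumes the endpoint emits Server-Sent Events whose `data:` lines carry OpenAI-style chunks with the text under `choices[0].delta.content`, which is the shape the streaming function in this section reads. The name `iter_delta_content` and the sample lines are hypothetical, made up for illustration.

```python
import json
from typing import Iterable, Iterator


def iter_delta_content(lines: Iterable[str]) -> Iterator[str]:
    """Yield the text fragments carried by streamed chat completion chunks."""
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue  # skip blank lines and "event:" lines
        data = line[len("data:"):].strip()
        if not data or data == "[DONE]":
            continue  # skip empty payloads and the end-of-stream marker
        try:
            chunk = json.loads(data)
        except json.JSONDecodeError:
            continue  # ignore payloads that are not valid JSON
        if not isinstance(chunk, dict):
            continue
        choices = chunk.get("choices") or []
        if not choices:
            continue
        content = (choices[0].get("delta") or {}).get("content")
        if content:
            yield content


# Made-up sample stream, not captured from a real endpoint
sample = [
    "event: message",
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    "",
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    "data: [DONE]",
]
print("".join(iter_delta_content(sample)))  # prints: Hello
```

Keeping the parser separate from the HTTP call makes it possible to check the chunk handling against recorded lines before pointing it at a live endpoint.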

In [None]:
def stream_chat_completion(
    host: str, endpoint_name: str, messages: list
) -> Generator[str, None, None]:
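    """Stream a chat completion from the inference endpoint, yielding text fragments as they arrive."""
    url = f"{host}/_inference/chat_completion/{endpoint_name}/_stream"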

    payload = {"messages": messages}

    try:
        response = requests.post(
            url, json=payload, headers=ELASTICSEARCH_HEADERS, stream=True
        )
        response.raise_for_status()

        for line in response.iter_lines(decode_unicode=True):
            if line:
                line = line.strip()

                # Handle Server-Sent Events format
                # Skip event lines like "event: message"
                if line.startswith("event:"):
                    continue

                # Process data lines
                if line.startswith("data: "):
                    data_content = line[6:]  # Remove "data: " prefix

                    # Skip empty data or special markers
                    if not data_content.strip() or data_content.strip() == "[DONE]":
                        continue

                    try:
                        chunk_data = json.loads(data_content)

                        # Extract the content from the Mistral response structure
                        if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
                            choice = chunk_data["choices"][0]
                            if "delta" in choice and "content" in choice["delta"]:
                                content = choice["delta"]["content"]
                                if content:  # Only yield non-empty content
                                    yield content

                    except json.JSONDecodeError as json_err:
                        # If JSON parsing fails, log the error but continue
                        print(f"\nJSON decode error: {json_err}")
                        print(f"Problematic data: {data_content}")
                        continue

    except requests.exceptions.RequestException as e:
        yield f"Error: {str(e)}"


print("✅ Streaming function defined!")

## Testing the Inference Endpoint 

Now let's test our inference endpoint with a simple question to confirm that responses stream back from Elasticsearch as expected.

In [None]:
user_question = "What SNES games had a character on a skateboard throwing axes?"

messages = [
    {
        "role": "system",
        "content": "You are a helpful gaming expert that provides concise answers about video games.",
    },
    {"role": "user", "content": user_question},
]

print(f"User: {user_question}")
print("Assistant: \n")

for chunk in stream_chat_completion(
    ELASTICSEARCH_URL, INFERENCE_ENDPOINT_NAME, messages
):
    print(chunk, end="", flush=True)

# Context Engineering with Elasticsearch

In this section, we'll demonstrate how to:
1. Index documents into Elasticsearch 
2. Search for relevant context
3. Use retrieved documents to enhance our chat completions with contextual information

This approach combines retrieval-augmented generation (RAG) with Mistral's chat capabilities through Elasticsearch.

## Step 1: Index some documents

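First, let's create an Elasticsearch index to store our documents. The `title`, `category`, and `description` fields are copied into `description_semantic`, a `semantic_text` field, so Elasticsearch generates the embeddings used for semantic search at index time.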

In [None]:
INDEX_NAME = "snes-games"
snes_mapping = {
    "mappings": {
        "properties": {
            "id": {"type": "keyword"},
            "title": {"type": "text", "copy_to": "description_semantic"},
            "publishers": {"type": "keyword"},
            "year_US": {"type": "keyword"},
            "year_JP": {"type": "keyword"},
            "category": {"type": "keyword", "copy_to": "description_semantic"},
            "description": {"type": "text", "copy_to": "description_semantic"},
            "description_semantic": {"type": "semantic_text"},
        }
    }
}

try:
    # Create the index using the Elasticsearch client
    response = es_client.indices.create(index=INDEX_NAME, body=snes_mapping)

    print(f"✅ Index '{INDEX_NAME}' created successfully!")
    print(f"Response: {json.dumps(response.body, indent=2)}")

except Exception as e:
    # If the index already exists, that's okay; report anything else as an error
    if (
        "already exists" in str(e).lower()
        or "resource_already_exists_exception" in str(e).lower()
    ):
        print(f"✅ Index '{INDEX_NAME}' already exists, continuing...")
    else:
        print(f"❌ Error creating index '{INDEX_NAME}': {str(e)}")

In [None]:
def bulk_index_games(games_batch):
    if not games_batch:
        return 0
    bulk_body = ""
    for game_doc in games_batch:
        index_meta = {"index": {"_index": INDEX_NAME, "_id": game_doc["id"]}}
        bulk_body += json.dumps(index_meta) + "\n" + json.dumps(game_doc) + "\n"
    bulk_url = f"{ELASTICSEARCH_URL}/_bulk"
    bulk_headers = {**ELASTICSEARCH_HEADERS, "Content-Type": "application/x-ndjson"}
    try:
        response = requests.post(bulk_url, data=bulk_body, headers=bulk_headers)
        response.raise_for_status()
        result = response.json()
        return sum(
            1
            for item in result.get("items", [])
            if item.get("index", {}).get("status") in [200, 201]
        )
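    except (requests.exceptions.RequestException, ValueError) as e:
        # Report the failure instead of silently dropping the batch
        print(f"\n❌ Bulk indexing failed: {e}")
        return 0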


csv_file_path = "snes_games.csv"
BATCH_SIZE = 50
try:
    with open(csv_file_path, "r", encoding="utf-8") as file:
        file.readline()
        actual_headers = [
            "ID",
            "Title",
            "Publishers",
            "Year_North_America",
            "Year_JP",
            "Category",
            "Description",
        ]
        total_indexed, current_batch = 0, []
        lines = [line for line in file if line.strip()]

        for line in tqdm(lines, desc="Indexing SNES games"):
            line = line.strip()
            parts, current_part, in_quotes = [], "", False

            for char in line:
                if char == '"':
                    in_quotes = not in_quotes
                    current_part += char
                elif char == "|" and not in_quotes:
                    parts.append(current_part)
                    current_part = ""
                else:
                    current_part += char
            if current_part:
                parts.append(current_part)

            row = {}
            for i, header in enumerate(actual_headers):
                value = parts[i].strip() if i < len(parts) else ""
                if value.startswith('"') and value.endswith('"'):
                    value = value[1:-1]
                row[header] = value

            game_doc = {
                "id": row.get("ID", ""),
                "title": row.get("Title", ""),
                "publishers": row.get("Publishers", ""),
                "year_US": row.get("Year_North_America", ""),
                "year_JP": row.get("Year_JP", ""),
                "category": row.get("Category", ""),
                "description": row.get("Description", ""),
            }
            current_batch.append(game_doc)
            if len(current_batch) >= BATCH_SIZE:
                total_indexed += bulk_index_games(current_batch)
                current_batch = []
        if current_batch:
            total_indexed += bulk_index_games(current_batch)
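except OSError as e:
    # Surface file problems (missing or unreadable CSV) instead of hiding them
    print(f"❌ Could not read '{csv_file_path}': {e}")
else:
    print(f"✅ Indexed {total_indexed} games into '{INDEX_NAME}'")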
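
As an alternative to the hand-written character loop, the same pipe-delimited, quote-aware parsing can be sketched with Python's standard `csv` module. This is a swapped-in technique, not the notebook's method: it assumes the file uses `|` as the delimiter and `"` as the quote character, and unlike the loop above it treats a doubled `""` inside a quoted field as an escaped quote. The function name `parse_games` and the sample row are hypothetical.

```python
import csv
import io
from typing import Dict, List

# Target field names, in the same column order as the CSV header
FIELDS = ["id", "title", "publishers", "year_US", "year_JP", "category", "description"]


def parse_games(text: str) -> List[Dict[str, str]]:
    """Parse pipe-delimited game rows into documents keyed by FIELDS."""
    reader = csv.reader(io.StringIO(text, newline=""), delimiter="|", quotechar='"')
    next(reader, None)  # skip the header row
    games = []
    for row in reader:
        if not any(cell.strip() for cell in row):
            continue  # skip blank lines
        # Pad short rows so missing trailing columns become empty strings
        padded = [cell.strip() for cell in row] + [""] * len(FIELDS)
        games.append(dict(zip(FIELDS, padded)))
    return games


# Made-up sample data for illustration only
sample = (
    "ID|Title|Publishers|Year_North_America|Year_JP|Category|Description\n"
    '1|Example Game|Example Publisher|1993|1992|Platformer|"A made-up row with a | inside quotes"\n'
)
for game in parse_games(sample):
    print(game["title"], "->", game["description"])
```

The resulting dictionaries have the same keys as the `game_doc` objects built above, so they could be batched and passed to `bulk_index_games` unchanged.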

## Step 2: Search for Relevant Context

Now let's create a function to search our indexed documents for relevant context based on a user's query.

In [None]:
def search_documents(query: str, max_results: int = 3) -> list:
    search_body = {
        "size": max_results,
        "query": {"semantic": {"field": "description_semantic", "query": query}},
    }

    try:
        response = es_client.search(index=INDEX_NAME, body=search_body)

        return response.body["hits"]["hits"]

    except Exception as e:
        print(f"❌ Error searching documents: {str(e)}")
        return []

In [None]:
test_query = "What SNES games had a character on a skateboard throwing axes?"
print(f"🔍 Searching for: '{test_query}'")

search_results = search_documents(test_query, 5)

for i, doc in enumerate(search_results, 1):
    print(
        f"\n{i}. {doc['_source']['title']} - {doc['_source']['description']} (Score: {doc['_score']:.2f})"
    )

## Step 3: RAG-Enhanced Chat Function

Now let's create a function that combines document retrieval with our Mistral chat completion for contextual responses.
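
One design choice in this step is how the retrieved hits are turned into prompt text. The sketch below shows one possible formatting, assuming each hit has the `_source.title` and `_source.description` fields from the mapping in Step 1. The helper `format_context`, its `max_chars` cutoff, and the sample hit are hypothetical and only illustrate the idea of passing a compact context instead of the whole document.

```python
from typing import Dict, List


def format_context(hits: List[Dict], max_chars: int = 300) -> str:
    """Build a numbered, compact context block from search hits."""
    lines = []
    for i, hit in enumerate(hits, 1):
        source = hit.get("_source", {})
        title = source.get("title", "Unknown title")
        # Truncate long descriptions to keep the prompt small
        description = source.get("description", "")[:max_chars]
        lines.append(f"{i}. {title}: {description}")
    return "\n".join(lines)


# Made-up hit for illustration only
sample_hits = [
    {"_score": 1.0, "_source": {"title": "Example Game", "description": "A made-up description."}}
]
print(format_context(sample_hits))  # prints: 1. Example Game: A made-up description.
```

Limiting the context to the fields the model needs keeps the prompt shorter when many documents are retrieved.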

In [None]:
def rag_chat(user_question: str, max_context_docs: int = 10) -> str:
    context_docs = search_documents(user_question, max_context_docs)

    context_text = ""
    if context_docs:
        context_text = "\n\nRelevant context information:\n"
        for i, doc in enumerate(context_docs, 1):
            context_text += f"\n{i}. {doc['_source']}\n"

    system_prompt = """
        You are a helpful assistant that answers questions about Super Nintendo games.
        Use the provided context information to answer the user's question accurately. 
        If the context doesn't contain relevant information, you can use your general knowledge.
        """

    user_prompt = user_question
    if context_text:
        user_prompt = f"{context_text}\n\nQuestion: {user_question}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    full_response = ""
    for chunk in stream_chat_completion(
        ELASTICSEARCH_URL, INFERENCE_ENDPOINT_NAME, messages
    ):
        print(chunk, end="", flush=True)
        full_response += chunk

    return full_response

In [None]:
test_question = "What SNES games had a character on a skateboard throwing axes?"
# Assign the result so the notebook does not echo the full response a second time
answer = rag_chat(test_question)