In [214]:
import os
from dotenv import load_dotenv
from mistralai import Mistral
from collections.abc import Sequence
from typing import cast
import psycopg2

In [215]:
# Run python-dotenv's loader before the os.getenv() lookups below
# (presumably so values from a local .env file are visible to them).
load_dotenv()
# Root directory searched for Markdown files; None when DOCS_PATH is unset.
PATH_DIR = os.getenv("DOCS_PATH")
# Upper bound on the length of one chunk, in characters.
CHUNK_SIZE = 1000
# Number of characters by which the next chunk starts before the previous split point.
OVERLAP_SIZE = 50
# Model name passed to the embeddings endpoint.
MODEL_EMBEDDING = "mistral-embed"
# Model name passed to the chat-completion endpoint.
MODEL_GENERATION = "mistral-small-2506" #mistral small 3.2
# Number of chunks sent per embeddings request.
BATCH_SIZE = 20 

In [216]:
# read the docs and return as a single string
def read_docs() -> str:
    # search for all .md files in the directory
    docs : str = ""
    for root, _, files in os.walk(str(PATH_DIR)):
        for file in files:
            if file.endswith(".md"):
                with open(os.path.join(root, file), "r") as f:
                    docs += f.read() + "\n"
    return docs

In [217]:
# small sanity check to see if the number of .md files is correct
def print_number_of_md_files() -> None:
    """Print how many ``.md`` files exist anywhere below ``PATH_DIR``."""
    # Count one per matching file name across every directory os.walk visits.
    num: int = sum(
        1
        for _, _, filenames in os.walk(str(PATH_DIR))
        for filename in filenames
        if filename.endswith(".md")
    )
    print(f"Number of Markdown files in {PATH_DIR}: {num}")
# Run the check immediately so the count appears as this cell's output.
print_number_of_md_files()

Number of Markdown files in /Users/amelieartmann/Documents/devguard-chatbot: 73


In [218]:
""" # chunk given string into smaller chunks of chosen size with overlap
def chunking(docs: str) -> list[str]:
    chunks : list = []
    for i in range(0, len(docs), CHUNK_SIZE - OVERLAP_SIZE):
        chunks.append(docs[i:i+CHUNK_SIZE])
    return chunks """

# do chunking without splitting up words
def better_chunking(docs: str) -> list[str]:
    chunks : list[str] = []
    start : int = 0
    while start < len(docs):
        end : int = start + CHUNK_SIZE
        if end >= len(docs):
            chunks.append(docs[start:])
            break
        else:
            # find last space before end
            last_space : int = docs.rfind(" ", start, end)
            # if no space found, just split at end
            if last_space == -1:
                last_space = end
            chunks.append(docs[start:last_space])
            start = last_space - OVERLAP_SIZE
    return chunks

In [219]:
# Read the Mistral API key from the environment (None when API_KEY is unset).
api_key = os.getenv("API_KEY")
# Shared client used by get_embeddings, text_embedding and generate_response.
client = Mistral(api_key=api_key)

In [220]:
# get the embeddings for a list of chunks, return a list of embeddings
def get_embeddings(chunks: list[str]) -> list[list[float]]:
    """Embed all chunks, sending ``BATCH_SIZE`` chunks per API request.

    Args:
        chunks: The texts to embed.

    Returns:
        One vector per item in the API responses, collected batch after
        batch in the order the responses list them.
    """
    vectors: list[list[float]] = []
    # Request embeddings slice by slice instead of sending everything at once.
    for offset in range(0, len(chunks), BATCH_SIZE):
        response = client.embeddings.create(
            model=MODEL_EMBEDDING,
            inputs=chunks[offset:offset + BATCH_SIZE],
        )
        # cast() only informs the type checker; list() makes a plain list copy.
        vectors.extend(
            list(cast(Sequence[float], item.embedding)) for item in response.data
        )
    return vectors

# get embedding for a single chunk of text
def text_embedding(chunk: str) -> list[float]:
    """Return the embedding vector for a single piece of text.

    Args:
        chunk: The text to embed.

    Returns:
        The embedding of the first (and only requested) input as a list.
    """
    # The endpoint takes a list of inputs, so wrap the text and read back item 0.
    result = client.embeddings.create(model=MODEL_EMBEDDING, inputs=[chunk])
    first_item = result.data[0]
    return list(cast(Sequence[float], first_item.embedding))

In [221]:
# first read the docs
# (all Markdown files below PATH_DIR, concatenated into one string)
docs : str = read_docs()

# then chunk the docs
# (overlapping pieces of at most CHUNK_SIZE characters, split at spaces where possible)
chunks : list[str] = better_chunking(docs)
# Print the chunk count as a quick sanity check.
print(f"Number of better chunks: {len(chunks)}")

Number of better chunks: 168


In [222]:
# connect to the database
# Database name, host and port are fixed here; user name and password are
# read from the DB_USER / DB_PASSWORD environment variables.
# `conn` is shared by the insert cell and by retrieve_top_k.
conn = psycopg2.connect(
    dbname="embedding_db",
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    host="localhost",
    port = 5432
)

In [223]:
# insert the chunks and their embeddings into the database
# NOTE: this is a plain INSERT, so running the cell again adds the same rows
# a second time; nothing here removes rows from an earlier run.
cur = conn.cursor()
embeddings = get_embeddings(chunks)
try:
    # strict=True raises ValueError if the API returned a different number of
    # embeddings than there are chunks, instead of silently dropping the tail.
    # The pairs are materialised first, so a mismatch is detected before any
    # row is written.
    rows = list(zip(chunks, embeddings, strict=True))
    cur.executemany(
        "INSERT INTO documents (content, embedding) VALUES (%s, %s)",
        rows,
    )
    conn.commit()
except Exception:
    # Leave the shared connection usable for later cells: without a rollback
    # a failed statement keeps the transaction in an aborted state.
    conn.rollback()
    raise

In [224]:
def retrieve_top_k(query:str, k:int=3) -> list[tuple[str,float]] | None:
    """Return the ``k`` stored chunks closest to ``query`` by cosine distance.

    The query text is embedded with ``text_embedding`` and compared against
    the ``embedding`` column of the ``documents`` table. A cursor is opened
    on ``conn`` for each call, so the function does not depend on a cursor
    created by another cell.

    Distance operators available for non-binary vectors:
    <+> L1 distance
    <-> L2 distance
    <=> cosine distance -> similarity is 1 - cosine distance
    <#> negative inner product -> multiply by -1 to get the inner product

    Args:
        query: The text to search for.
        k: Maximum number of rows to return.

    Returns:
        ``(content, similarity)`` tuples ordered from closest to farthest,
        where similarity is ``1 - cosine distance``; ``None`` if the database
        query failed (the error is printed and the transaction rolled back).
    """
    query_embedding : list[float] = text_embedding(query)
    try:
        with conn.cursor() as cursor:
            cursor.execute("""
            SELECT content,
                1 - (embedding <=> %s::vector) AS similarity
            FROM documents
            ORDER BY embedding <=> %s::vector
            LIMIT %s;
            """, (query_embedding, query_embedding, k))
            results: list[tuple[str, float]] = cursor.fetchall()
        return results
    except Exception as e:
        # Broad catch kept on purpose: callers rely on getting None rather
        # than an exception when the lookup fails.
        print("SQL ERROR:", e)
        # Reset the aborted transaction so the connection stays usable.
        conn.rollback()
        return None


In [228]:
def generate_response(query: str, context: list[tuple[str, float]]) -> str:
    """Ask the generation model to answer ``query`` from the given chunks.

    Args:
        query: The user's question.
        context: ``(content, score)`` pairs; only the content goes into the
            prompt, the score is ignored.

    Returns:
        The message content of the first completion choice, passed through
        ``str()``.
    """
    # One "- " bullet per chunk, separated by blank lines.
    bullet_entries = [f"- {content}" for content, _ in context]
    context_text = "\n\n".join(bullet_entries)

    prompt = f"""
    Use ONLY the following context to answer the question.
    If the answer cannot be answered using the context, say you don't know.
    Context:
    {context_text}

    Question:
    {query}
    """

    # Toggling the safe prompt will prepend your messages with the following
    # system prompt:
    #   Always assist with care, respect, and truth. Respond with utmost
    #   utility yet securely. Avoid harmful, unethical, prejudiced, or
    #   negative content. Ensure replies promote fairness and positivity.
    #
    # temperature=0.0: no randomness, since the same question and context
    # should give the same answer.
    completion = client.chat.complete(
        model=MODEL_GENERATION,
        messages=[{"role": "user", "content": prompt}],
        safe_prompt=True,
        temperature=0.0,
    )
    first_choice = completion.choices[0]
    return str(first_choice.message.content)

In [230]:
# Example usage:
# Retrieve the five closest chunks for a sample question, then generate an
# answer from them. retrieve_top_k returns None when the database query
# failed, in which case nothing is generated or printed.
query = "What is DevGuard?"
top_k_results = retrieve_top_k(query, k=5)

if top_k_results is not None:
    # NOTE: despite its name, test_query holds the generated answer text.
    test_query = generate_response(query, top_k_results)
    print(test_query)

DevGuard is a tool built by developers for developers to simplify vulnerability management. It aims to integrate security seamlessly into the software development lifecycle, making security practices accessible and efficient for everyone, regardless of their security expertise.
