In [6]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
import os
from nomic import embed
from supabase import create_client, Client
from openai import OpenAI
from prisma import Prisma
from groq import Groq
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder
import time
import logging
import traceback

In [7]:
# Load environment variables from a local .env file (if present) before reading them.
load_dotenv()
# Groq chat-completions client; the key is read from the environment, never hardcoded.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# Model used for every chat completion in this notebook.
MODEL = "llama-3.3-70b-versatile"

In [8]:
async def getMessages(userId, conversationId):
    """Load the stored chat history of one of a user's conversations.

    Args:
        userId: Owner of the conversation. The lookup matches on both ids, so a
            conversation id that belongs to a different user is not found.
        conversationId: Id of the conversation to load.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, one per stored
        message.

    Raises:
        Whatever ``find_first_or_raise`` raises when no conversation matches
        both ids (presumably prisma's record-not-found error).
    """
    db = Prisma()
    await db.connect()
    try:
        # NOTE(review): no order_by is given for the included messages, so their
        # order is whatever the database returns; confirm it is chronological.
        conversation = await db.conversation.find_first_or_raise(
            where={
                'id': conversationId,
                'userId': userId,
            },
            include={
                'messages': True
            }
        )
    finally:
        # Release the connection even when the conversation is not found.
        await db.disconnect()
    return [{"role": msg.role, "content": msg.content} for msg in conversation.messages]

In [9]:
async def updateMessages(newMessage, conversationId):
    """Append one message to an existing conversation.

    Args:
        newMessage: Field dict for the message row to create, e.g.
            ``{"role": "assistant", "content": "..."}``.
        conversationId: Id of the conversation the message is attached to.

    Returns:
        The conversation record returned by Prisma's ``update``.
    """
    db = Prisma()
    await db.connect()
    try:
        # Nested write: create the message as a child of the conversation.
        return await db.conversation.update(
            where={
                'id': conversationId
            },
            data={
                'messages': {
                    'create': newMessage
                }
            }
        )
    finally:
        # Release the connection even if the update fails.
        await db.disconnect()

In [10]:
async def retrieveDocs(query, userId):
    """Return the ids of the documents whose chunks best match ``query``.

    The query is embedded with Nomic (``nomic-embed-text-v1.5``, task
    ``search_query``) and compared against the user's rows in
    ``"DocumentChunks"`` using pgvector's ``<->`` (L2 distance) operator.

    Args:
        query: Free-text search query.
        userId: Id of the user whose chunks are searched.

    Returns:
        Up to 5 rows from ``db.query_raw``, each with a ``"documentId"`` key,
        nearest chunk first. Several rows may share a document id when more
        than one chunk of the same document matches.
    """
    import asyncio
    import functools

    # embed.text is synchronous (and may perform network I/O); run it in the
    # default executor so it does not block the event loop serving requests.
    loop = asyncio.get_running_loop()
    queryVector = await loop.run_in_executor(
        None,
        functools.partial(
            embed.text,
            texts=[query],
            model='nomic-embed-text-v1.5',
            task_type='search_query',
        ),
    )
    queryEmbedding = queryVector['embeddings'][0]
    # Assumes queryEmbedding is a list of floats: str() yields "[x, y, ...]",
    # the text form that CAST(... AS vector) accepts.
    embedding_array = f"{queryEmbedding}"

    db = Prisma()
    await db.connect()
    try:
        return await db.query_raw(
            '''
            SELECT "documentId"
            FROM "DocumentChunks"
            WHERE "userId" = CAST($1 AS TEXT)
            ORDER BY "embedding" <-> CAST($2 AS vector) ASC
            LIMIT 5
            ''',
            userId,
            embedding_array
        )
    finally:
        await db.disconnect()

In [11]:
async def getDocContext(docIds):
    """Build the journal-context string appended to the user's prompt.

    Args:
        docIds: Rows as returned by ``retrieveDocs``: dicts with a
            ``'documentId'`` key. Repeated ids (several matching chunks of the
            same document) are fetched and included only once, in first-seen
            order.

    Returns:
        ``""`` when no document content is found. Otherwise, a header line
        followed by each document's content, separated by blank lines.

    Raises:
        Whatever ``find_first_or_raise`` raises if a referenced document does
        not exist.
    """
    # dict.fromkeys drops repeats while keeping first-seen (nearest-first) order.
    uniqueIds = list(dict.fromkeys(row['documentId'] for row in docIds))
    if not uniqueIds:
        return ""

    db = Prisma()
    await db.connect()
    try:
        contents = []
        for docId in uniqueIds:
            doc = await db.document.find_first_or_raise(
                where={
                    'id': docId
                },
            )
            contents.append(doc.content)
    finally:
        # Release the connection even if a document lookup fails.
        await db.disconnect()

    # Separate documents so journal entries do not run together in the prompt;
    # empty contents are skipped so they add no stray separators.
    context = "\n\n".join(content for content in contents if content)
    if context:
        context = "\n\n The following is the relevant user's journal contents: " + context
    return context

In [12]:
# FastAPI application that serves the chatbot endpoint.
app = FastAPI()

# Browser origins allowed to call this API (the local frontend dev server).
origins = ["http://localhost:3000"]

# CORS: allow the listed origins with cookies/credentials and any method or header.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

In [13]:
# One chat message exchanged with the frontend; also the chat endpoint's response model.
class Message(BaseModel):
    role: str  # speaker role, e.g. "assistant"
    content: str  # message text
    
# Function to handle one step of the conversation
async def chat_with_assistant(input_str, userId, conversationId):
    """Answer one user turn using retrieved journal context, tools, and history.

    Steps:
      1. Retrieve the user's most relevant journal documents.
      2. Load the conversation history.
      3. Ask the model, with tools available.
      4. Run any requested tools.
      5. Ask the model again with the tool results.
      6. Persist the assistant's reply and return it.

    Args:
        input_str: The user's new message.
        userId: Id of the user; scopes document retrieval and the
            conversation lookup.
        conversationId: Id of the conversation to continue.

    Returns:
        A ``Message`` with role ``"assistant"`` and the model's final reply.
    """
    logger = logging.getLogger(__name__)
    start_time = time.time()

    # Get the documents and add them to the context
    docIds = await retrieveDocs(input_str, userId)
    logger.debug("Embedding query and retrieving docIds: %s", time.time() - start_time)
    userDocContext = await getDocContext(docIds)
    messages = await getMessages(userId, conversationId)
    logger.debug("Retrieving doc contents: %s", time.time() - start_time)

    # Append the user's query (plus the retrieved journal context) to the history, once.
    messages.append({
        "role": "user",
        "content": "This is the user's key query: " + input_str + "| end of query \n\n" + userDocContext
    })
    # The prompt contains the user's journal contents; log it only at debug level.
    logger.debug("Prompt messages: %s", messages)

    # NOTE(review): `th` is not defined in this part of the notebook; presumably a
    # tool-handler instance created in an earlier cell. Confirm.
    tools = th.get_tools()
    # client is the synchronous Groq client, so these calls block the event loop while they run.
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        temperature=1.2
    )
    logger.debug("Initial feed: %s", time.time() - start_time)

    # Run any tools the model requested and append the resulting messages, so the
    # second request sees: history, the user prompt (once), then the tool exchange.
    tool_run = th.run_tools(response)
    messages.extend(tool_run)
    logger.debug("With tools: %s", time.time() - start_time)

    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        temperature=1.2
    )
    reply = response.choices[0].message.content
    logger.debug("Final: %s", time.time() - start_time)

    # NOTE(review): only the assistant reply is persisted here; confirm the user's
    # message is stored elsewhere (e.g. by the frontend).
    await updateMessages({"role": "assistant", "content": reply}, conversationId)
    return Message(role="assistant", content=reply)


In [14]:
# JSON body accepted by POST /chatbot/chat.
class ChatRequest(BaseModel):
    conversationId: str  # conversation to continue
    title: str  # accepted but not used by the chat endpoint
    userId: str  # owner of the conversation and of the searched documents
    userMessage: str  # the user's new message
    

@app.post("/chatbot/chat", response_model=Message)  # POST endpoint for one chat turn
async def chat(request: ChatRequest):
    """Handle one chat turn and return the assistant's reply as JSON.

    Any failure is logged server-side with its traceback and reported to the
    client as an HTTP 500 whose detail is the exception text.
    """
    # NOTE(review): userId comes straight from the request body and is not
    # authenticated here, so a caller can supply any user's id; confirm auth is
    # enforced upstream.
    try:
        message = await chat_with_assistant(request.userMessage, request.userId, request.conversationId)
        return jsonable_encoder(message)
    except Exception as e:
        # Keep the full traceback in the server logs; the client only sees str(e).
        logging.getLogger(__name__).exception(
            "Chat request failed for conversation %s", request.conversationId
        )
        raise HTTPException(status_code=500, detail=str(e)) from e

In [15]:
import nest_asyncio
import uvicorn

# The notebook kernel already runs an event loop; nest_asyncio patches asyncio so
# uvicorn can run its own loop from inside this cell.
nest_asyncio.apply()
# Serves `app` and blocks this cell until the server is stopped.
uvicorn.run(
    app,
    log_level="debug",
    port=8000,
    host="0.0.0.0",  # listens on all network interfaces, not only localhost
)


INFO:     Started server process [33875]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [33875]
