Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a websocket for streaming from the chat UI #679

Merged
merged 32 commits into from Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9b88976
Initial pass at backend changes to support agents
sabaimran Mar 11, 2024
352168d
Customize default behaviors for conversations without agents or with …
sabaimran Mar 11, 2024
6ab6493
Add a new web client route for viewing all agents
sabaimran Mar 11, 2024
8e1445b
Use agent_id for getting correct agent
sabaimran Mar 11, 2024
290712c
Add web UI views for agents
sabaimran Mar 13, 2024
c45030a
Fix agent view
sabaimran Mar 14, 2024
3caf0a7
Spruce up the 404 page and improve the overall layout for agents pages
sabaimran Mar 14, 2024
345afec
Resolve merge conflicts/ use agent_slug instead of agent_id for lookup
sabaimran Mar 14, 2024
d734be6
Rename agents_page -> agent_page
sabaimran Mar 15, 2024
1b3fc68
Fix unit test for adding the filename to the compiled markdown entry
sabaimran Mar 15, 2024
416feb1
Fix layout of agent, agents pages
sabaimran Mar 15, 2024
7fc484b
Merge branch 'master' of github.com:khoj-ai/khoj into features/custom…
sabaimran Mar 15, 2024
724557f
Merge branch 'master' of github.com:khoj-ai/khoj into features/add-ag…
sabaimran Mar 15, 2024
36af977
Add the websockets dependency to pyproject.toml
sabaimran Mar 20, 2024
a346f79
Add support for chatting via the web socket connection
sabaimran Mar 20, 2024
d4e83b0
Update the web UI for the chat interface to establish a connection vi…
sabaimran Mar 20, 2024
70ad789
Use a common method for sending a generic message to the client from …
sabaimran Mar 20, 2024
d84188b
Scroll down when a message is added in the chat interface's handle st…
sabaimran Mar 20, 2024
255b69d
Add a comma delimeter between outputted search queries
sabaimran Mar 20, 2024
6ba0d8e
Add a connected notification if the websocket is connected
sabaimran Mar 20, 2024
d38089a
Merge with origin
sabaimran Mar 22, 2024
2399d91
Merge migrations
sabaimran Mar 22, 2024
2061761
Merge branch 'features/customize-chat-with-agents' of github.com:khoj…
sabaimran Mar 23, 2024
6b4c4f1
Merge branch 'features/add-agents-ui' of github.com:khoj-ai/khoj into…
sabaimran Mar 23, 2024
8edbd70
Let the name, slug of the default agent be Khoj, khoj
sabaimran Mar 23, 2024
4deb849
Merge branch 'features/add-agents-ui' of github.com:khoj-ai/khoj into…
sabaimran Mar 23, 2024
47fc7e1
Rebase with matser
sabaimran Apr 2, 2024
228ad68
Merge with origin/master
sabaimran Apr 2, 2024
867e100
Remove superfluous newline
sabaimran Apr 2, 2024
bf1187f
Use new online/websearch logic and add agent to chat_metadata
sabaimran Apr 2, 2024
f484266
resolve merge conflict in chat.html
sabaimran Apr 2, 2024
b4f71e0
Add timeout after 10 minutes of inactivity on socket
sabaimran Apr 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -76,6 +76,7 @@ dependencies = [
"django-phonenumber-field == 7.3.0",
"phonenumbers == 8.13.27",
"markdownify ~= 0.11.6",
"websockets == 12.0",
]
dynamic = ["version"]

Expand Down
8 changes: 8 additions & 0 deletions src/khoj/configure.py
Expand Up @@ -21,6 +21,7 @@
from starlette.requests import HTTPConnection

from khoj.database.adapters import (
AgentAdapters,
ClientApplicationAdapters,
ConversationAdapters,
SubscriptionState,
Expand Down Expand Up @@ -229,11 +230,16 @@ def configure_server(

state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent()
initialize_content(regenerate, search_type, init, user)
except Exception as e:
raise e


def setup_default_agent():
AgentAdapters.create_default_agent()


def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
# Initialize Content from Config
if state.search_models:
Expand Down Expand Up @@ -262,13 +268,15 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
def configure_routes(app):
# Import APIs here to setup search types before while configuring server
from khoj.routers.api import api
from khoj.routers.api_agents import api_agents
from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config
from khoj.routers.indexer import indexer
from khoj.routers.web_client import web_client

app.include_router(api, prefix="/api")
app.include_router(api_chat, prefix="/api/chat")
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_config, prefix="/api/config")
app.include_router(indexer, prefix="/api/v1/index")
app.include_router(web_client)
Expand Down
95 changes: 91 additions & 4 deletions src/khoj/database/adapters/__init__.py
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor

from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
Expand All @@ -37,6 +38,7 @@
UserRequests,
UserSearchModelConfig,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
Expand Down Expand Up @@ -391,6 +393,79 @@ async def aget_client_application_by_id(client_id: str, client_secret: str):
return await ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret).afirst()


class AgentAdapters:
DEFAULT_AGENT_NAME = "khoj"
DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png"

@staticmethod
async def aget_agent_by_id(agent_id: int):
return await Agent.objects.filter(id=agent_id).afirst()

@staticmethod
async def aget_agent_by_slug(agent_slug: str):
return await Agent.objects.filter(slug__iexact=agent_slug.lower()).afirst()

@staticmethod
def get_agent_by_slug(slug: str, user: KhojUser = None):
agent = Agent.objects.filter(slug=slug).first()
# Check if agent is public or created by the user
if agent and (agent.public or agent.creator == user):
return agent
return None

@staticmethod
def get_all_accessible_agents(user: KhojUser = None):
return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at")

@staticmethod
async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]:
get_all_accessible_agents = sync_to_async(
lambda: Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at").all(),
thread_sensitive=True,
)
agents = await get_all_accessible_agents()
return await sync_to_async(list)(agents)

@staticmethod
def get_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first()
if agent == AgentAdapters.get_default_agent():
# If the agent is set to the default agent, then return None and let the default application code be used
return None
return agent

@staticmethod
def get_default_agent():
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()

@staticmethod
def create_default_agent():
default_conversation_config = ConversationAdapters.get_default_conversation_config()
default_personality = prompts.personality.format(current_date="placeholder")

if Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).exists():
agent = Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
agent.tuning = default_personality
agent.chat_model = default_conversation_config
agent.save()
return agent

# The default agent is public and managed by the admin. It's handled a little differently than other agents.
return Agent.objects.create(
name=AgentAdapters.DEFAULT_AGENT_NAME,
public=True,
managed_by_admin=True,
chat_model=default_conversation_config,
tuning=default_personality,
tools=["*"],
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
)

@staticmethod
async def aget_default_agent():
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()


class ConversationAdapters:
@staticmethod
def get_conversation_by_user(
Expand Down Expand Up @@ -431,7 +506,14 @@ def get_conversation_by_id(conversation_id: int):
return Conversation.objects.filter(id=conversation_id).first()

@staticmethod
async def acreate_conversation_session(user: KhojUser, client_application: ClientApplication = None):
async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None
):
if agent_slug:
agent = await AgentAdapters.aget_agent_by_slug(agent_slug)
if agent is None:
raise HTTPException(status_code=400, detail="Invalid agent id")
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)
return await Conversation.objects.acreate(user=user, client=client_application)

@staticmethod
Expand All @@ -446,7 +528,7 @@ async def aget_conversation_by_user(
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")

if await conversation.aexists():
return await conversation.afirst()
return await conversation.prefetch_related("agent").afirst()

return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)

Expand Down Expand Up @@ -606,9 +688,14 @@ async def aget_conversation_starters(user: KhojUser, max_results=3):
return random.sample(all_questions, max_results)

@staticmethod
def get_valid_conversation_config(user: KhojUser):
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
conversation_config = ConversationAdapters.get_conversation_config(user)

if conversation.agent and conversation.agent.chat_model:
conversation_config = conversation.agent.chat_model
else:
conversation_config = ConversationAdapters.get_conversation_config(user)

if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config()

Expand Down
2 changes: 2 additions & 0 deletions src/khoj/database/admin.py
Expand Up @@ -6,6 +6,7 @@
from django.http import HttpResponse

from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
Expand Down Expand Up @@ -50,6 +51,7 @@ class KhojUserAdmin(UserAdmin):
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(Agent)


@admin.register(Entry)
Expand Down
53 changes: 53 additions & 0 deletions src/khoj/database/migrations/0031_agent_conversation_agent.py
@@ -0,0 +1,53 @@
# Generated by Django 4.2.10 on 2024-03-13 07:38

import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0030_conversation_slug_and_title"),
]

operations = [
migrations.CreateModel(
name="Agent",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=200)),
("tuning", models.TextField()),
("avatar", models.URLField(blank=True, default=None, max_length=400, null=True)),
("tools", models.JSONField(default=list)),
("public", models.BooleanField(default=False)),
("managed_by_admin", models.BooleanField(default=False)),
("slug", models.CharField(blank=True, default=None, max_length=200, null=True)),
(
"chat_model",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodeloptions"),
),
(
"creator",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="conversation",
name="agent",
field=models.ForeignKey(
blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, to="database.agent"
),
),
]
57 changes: 46 additions & 11 deletions src/khoj/database/models/__init__.py
@@ -1,7 +1,11 @@
import uuid
from random import choice

from django.contrib.auth.models import AbstractUser
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import pre_save
from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField

Expand Down Expand Up @@ -69,6 +73,47 @@ class Type(models.TextChoices):
renewal_date = models.DateTimeField(null=True, default=None, blank=True)


class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"

max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)


class Agent(BaseModel):
creator = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
name = models.CharField(max_length=200)
tuning = models.TextField()
avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search
public = models.BooleanField(default=False)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)


@receiver(pre_save, sender=Agent)
def verify_agent(sender, instance, **kwargs):
# check if this is a new instance
if instance._state.adding:
if Agent.objects.filter(name=instance.name, public=True).exists():
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
if Agent.objects.filter(name=instance.name, creator=instance.creator).exists():
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")

slug = instance.name.lower().replace(" ", "-")
observed_random_numbers = set()
while Agent.objects.filter(slug=slug).exists():
random_number = choice([i for i in range(0, 10000) if i not in observed_random_numbers])
observed_random_numbers.add(random_number)
slug = f"{slug}-{random_number}"
instance.slug = slug


class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
Expand Down Expand Up @@ -153,17 +198,6 @@ class ModelType(models.TextChoices):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)


class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"

max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)


class UserConversationConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
Expand All @@ -180,6 +214,7 @@ class Conversation(BaseModel):
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
title = models.CharField(max_length=200, default=None, null=True, blank=True)
agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)


class ReflectiveQuestion(BaseModel):
Expand Down