Commit f6a9b14: precommit

dlqqq committed Jun 27, 2023
1 parent ba8a327 commit f6a9b14
Showing 12 changed files with 170 additions and 85 deletions.
@@ -58,6 +58,7 @@ def __init__(self, *args, **kwargs):
 
         super().__init__(*args, **kwargs, **model_kwargs)
 
+
 class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
     id = "openai"
     name = "OpenAI"
17 changes: 13 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, Optional, Tuple, Union, Type
+from typing import Dict, Optional, Tuple, Type, Union
 
 from importlib_metadata import entry_points
 from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
@@ -12,6 +12,7 @@
 AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider]
 ProviderDict = Dict[str, AnyProvider]
 
+
 def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
     if not log:
         log = logging.getLogger()
@@ -33,6 +34,7 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
 
     return providers
 
+
 def get_em_providers(
     log: Optional[Logger] = None,
 ) -> EmProvidersDict:
@@ -78,17 +80,24 @@ def decompose_model_id(
     provider_id, local_model_id = model_id.split(":", 1)
     return (provider_id, local_model_id)
 
-def get_lm_provider(model_id: str, lm_providers: LmProvidersDict) -> Tuple[str, Type[BaseProvider]]:
+
+def get_lm_provider(
+    model_id: str, lm_providers: LmProvidersDict
+) -> Tuple[str, Type[BaseProvider]]:
     """Gets a two-tuple (<local-model-id>, <provider-class>) specified by a
     global model ID."""
     return _get_provider(model_id, lm_providers)
 
-def get_em_provider(model_id: str, em_providers: EmProvidersDict) -> Tuple[str, Type[BaseEmbeddingsProvider]]:
+
+def get_em_provider(
+    model_id: str, em_providers: EmProvidersDict
+) -> Tuple[str, Type[BaseEmbeddingsProvider]]:
     """Gets a two-tuple (<local-model-id>, <provider-class>) specified by a
     global model ID."""
     return _get_provider(model_id, em_providers)
 
+
 def _get_provider(model_id: str, providers: ProviderDict):
     provider_id, local_model_id = decompose_model_id(model_id, providers)
     provider = providers.get(provider_id, None)
-    return local_model_id, provider
\ No newline at end of file
+    return local_model_id, provider
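
Note: get_lm_provider and get_em_provider both delegate to _get_provider, which splits a global model ID of the form <provider-id>:<local-model-id> and looks the provider class up in the registry. A minimal usage sketch (the "openai:text-davinci-003" ID is illustrative, not taken from this diff):

from jupyter_ai_magics.utils import get_lm_provider, get_lm_providers

lm_providers = get_lm_providers()
# Returns (local model ID, provider class); the class is None if no
# registered provider matches the "openai" prefix.
local_model_id, provider_class = get_lm_provider("openai:text-davinci-003", lm_providers)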
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
@@ -3,4 +3,4 @@
 from .clear import ClearChatHandler
 from .default import DefaultChatHandler
 from .generate import GenerateChatHandler
-from .learn import LearnChatHandler
\ No newline at end of file
+from .learn import LearnChatHandler
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -1,11 +1,12 @@
 import argparse
 from typing import Dict, Type
 
-from .base import BaseChatHandler
 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationalRetrievalChain
 
+from .base import BaseChatHandler
+
 
 class AskChatHandler(BaseChatHandler):
     """Processes messages prefixed with /ask. This actor will
@@ -27,7 +28,9 @@ def create_llm_chain(
     ):
         self.llm = provider(**provider_params)
         self.chat_history = []
-        self.llm_chain = ConversationalRetrievalChain.from_llm(self.llm, self._retriever)
+        self.llm_chain = ConversationalRetrievalChain.from_llm(
+            self.llm, self._retriever
+        )
 
     async def _process_message(self, message: HumanChatMessage):
         args = self.parse_args(message)
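
Note: create_llm_chain wires the handler's retriever into LangChain's ConversationalRetrievalChain. A sketch of how such a chain is typically invoked, assuming llm and retriever are an instantiated LangChain LLM and retriever (the handler's own _process_message body is truncated above):

from langchain.chains import ConversationalRetrievalChain

chain = ConversationalRetrievalChain.from_llm(llm, retriever)
# The chain expects a question plus accumulated chat history, and returns
# a dict containing the generated answer.
result = chain({"question": "What does /learn index?", "chat_history": []})
print(result["answer"])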
22 changes: 14 additions & 8 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -1,16 +1,16 @@
 import argparse
-from asyncio import AbstractEventLoop
 import time
 import traceback
-from typing import Dict, Optional, Type
+from asyncio import AbstractEventLoop
+
+# necessary to prevent circular import
+from typing import TYPE_CHECKING, Dict, Optional, Type
 from uuid import uuid4
 
+from jupyter_ai.config_manager import ConfigManager, Logger
 from jupyter_ai.models import AgentChatMessage, HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
-from jupyter_ai.config_manager import ConfigManager, Logger
 
-# necessary to prevent circular import
-from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from jupyter_ai.handlers import RootChatHandler
 
@@ -19,7 +19,13 @@ class BaseChatHandler:
     """Base ChatHandler class containing shared methods and attributes used by
     multiple chat handler classes."""
 
-    def __init__(self, log: Logger, loop: AbstractEventLoop, config_manager: ConfigManager, root_chat_handlers: Dict[str, 'RootChatHandler']):
+    def __init__(
+        self,
+        log: Logger,
+        loop: AbstractEventLoop,
+        config_manager: ConfigManager,
+        root_chat_handlers: Dict[str, "RootChatHandler"],
+    ):
         self.log = log
         self.loop = loop
         self.config_manager = config_manager
@@ -40,7 +46,7 @@ async def process_message(self, message: HumanChatMessage):
             formatted_e = traceback.format_exc()
             response = f"Sorry, something went wrong and I wasn't able to index that path.\n\n```\n{formatted_e}\n```"
             self.reply(response, message)
-    
+
     async def _process_message(self, message: HumanChatMessage):
         """Processes the message passed by the `Router`"""
         raise NotImplementedError("Should be implemented by subclasses.")
@@ -96,4 +102,4 @@ def parse_args(self, message):
             response = f"{self.parser.format_usage()}"
             self.reply(response, message)
             return None
-        return args
\ No newline at end of file
+        return args
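
Note: process_message wraps _process_message with the error handling shown above, so subclasses only implement _process_message. A minimal sketch of a custom handler under that contract (the handler name is hypothetical, and message.body carrying the user's text is an assumption):

from jupyter_ai.chat_handlers.base import BaseChatHandler
from jupyter_ai.models import HumanChatMessage


class EchoChatHandler(BaseChatHandler):
    async def _process_message(self, message: HumanChatMessage):
        # reply() posts an agent message back to the chat, per the base class
        self.reply(f"You said: {message.body}", message)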
5 changes: 3 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
@@ -1,6 +1,8 @@
-from .base import BaseChatHandler
 from jupyter_ai.models import ClearMessage
 
+from .base import BaseChatHandler
+
+
 class ClearChatHandler(BaseChatHandler):
     async def _process_message(self, _):
         for handler in self._root_chat_handlers.values():
@@ -9,4 +11,3 @@ async def _process_message(self, _):
 
             handler.broadcast_message(ClearMessage())
             break
-
5 changes: 3 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,17 +1,18 @@
 from typing import Dict, List, Type
 
-from .base import BaseChatHandler
 from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
 from langchain import ConversationChain
+from langchain.memory import ConversationBufferWindowMemory
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     MessagesPlaceholder,
     SystemMessagePromptTemplate,
 )
 from langchain.schema import AIMessage
-from langchain.memory import ConversationBufferWindowMemory
 
+from .base import BaseChatHandler
+
 SYSTEM_PROMPT = """
 You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
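
Note: the imports above typically compose into a chat prompt plus a windowed-memory conversation chain. A sketch under the assumption that llm is an instantiated LangChain chat model and SYSTEM_PROMPT is the module-level constant shown above (the file's actual chain construction is truncated):

from langchain import ConversationChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)

prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="history"),  # filled from memory
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)
memory = ConversationBufferWindowMemory(return_messages=True, k=2)
chain = ConversationChain(llm=llm, prompt=prompt, memory=memory)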
61 changes: 34 additions & 27 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -2,22 +2,10 @@
 import json
 import os
 import time
-from typing import List, Coroutine, Any
+from typing import Any, Coroutine, List
 
 from dask.distributed import Client as DaskClient
-
-from jupyter_core.paths import jupyter_data_dir
-
-from langchain import FAISS
-from langchain.text_splitter import (
-    RecursiveCharacterTextSplitter, PythonCodeTextSplitter,
-    MarkdownTextSplitter, LatexTextSplitter
-)
-from langchain.schema import Document
-
-from .base import BaseChatHandler
-from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
-from jupyter_ai.document_loaders.directory import split, get_embeddings
+from jupyter_ai.document_loaders.directory import get_embeddings, split
 from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
+from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
+from jupyter_core.paths import jupyter_data_dir
@@ -30,13 +18,16 @@
     RecursiveCharacterTextSplitter,
 )
 
+from .base import BaseChatHandler
+
 INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices")
 METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json")
 
+
 def compute_delayed(delayed):
     return delayed.compute()
 
+
 class LearnChatHandler(BaseChatHandler, BaseRetriever):
     def __init__(self, root_dir: str, dask_client: DaskClient, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -54,7 +45,7 @@ def __init__(self, root_dir: str, dask_client: DaskClient, *args, **kwargs):
         self.metadata = IndexMetadata(dirs=[])
         self.prev_em_id = None
         self.embeddings = None
-    
+
         if not os.path.exists(INDEX_SAVE_DIR):
             os.makedirs(INDEX_SAVE_DIR)
 
@@ -116,11 +107,19 @@ def _build_list_response(self):
 
     async def learn_dir(self, path: str):
         start = time.time()
-        splitters={
-            '.py': PythonCodeTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
-            '.md': MarkdownTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
-            '.tex': LatexTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
-            '.ipynb': NotebookSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
+        splitters = {
+            ".py": PythonCodeTextSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            ),
+            ".md": MarkdownTextSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            ),
+            ".tex": LatexTextSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            ),
+            ".ipynb": NotebookSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            ),
         }
         splitter = ExtensionSplitter(
             splitters=splitters,
@@ -132,18 +131,24 @@ async def learn_dir(self, path: str):
         delayed = split(path, splitter=splitter)
         doc_chunks = await self.dask_client.submit(compute_delayed, delayed)
 
-        self.log.error(f"[/learn] Finished chunking documents. Time: {round((time.time() - start) * 1000)}ms")
+        self.log.error(
+            f"[/learn] Finished chunking documents. Time: {round((time.time() - start) * 1000)}ms"
+        )
 
         em = self.get_embedding_model()
         delayed = get_embeddings(doc_chunks, em)
         embedding_records = await self.dask_client.submit(compute_delayed, delayed)
-        self.log.error(f"[/learn] Finished computing embeddings. Time: {round((time.time() - start) * 1000)}ms")
+        self.log.error(
+            f"[/learn] Finished computing embeddings. Time: {round((time.time() - start) * 1000)}ms"
+        )
 
         self.index.add_embeddings(*embedding_records)
         self._add_dir_to_metadata(path)
 
-        self.log.error(f"[/learn] Complete. Time: {round((time.time() - start) * 1000)}ms")
-
+        self.log.error(
+            f"[/learn] Complete. Time: {round((time.time() - start) * 1000)}ms"
+        )
 
     def _add_dir_to_metadata(self, path: str):
         dirs = self.metadata.dirs
         index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
@@ -242,8 +247,10 @@ def get_relevant_documents(self, question: str) -> List[Document]:
             docs = self.index.similarity_search(question)
             return docs
         return []
-    
-    async def aget_relevant_documents(self, query: str) -> Coroutine[Any, Any, List[Document]]:
+
+    async def aget_relevant_documents(
+        self, query: str
+    ) -> Coroutine[Any, Any, List[Document]]:
         return self.get_relevant_documents(query)
 
     def get_embedding_model(self):
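
Note: learn_dir builds its index lazily. split and get_embeddings return Dask delayed objects, and compute_delayed is shipped to the cluster via dask_client.submit, whose result is awaitable because the client is asynchronous. A toy sketch of the same pattern (the summed list is illustrative):

import dask
from dask.distributed import Client


def compute_delayed(delayed):
    return delayed.compute()


async def main():
    client = await Client(asynchronous=True)
    delayed_sum = dask.delayed(sum)([1, 2, 3])  # lazy; nothing runs yet
    # The delayed object is computed inside a worker task.
    result = await client.submit(compute_delayed, delayed_sum)
    assert result == 6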
35 changes: 24 additions & 11 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -1,20 +1,29 @@
 import json
+import logging
 import os
 from typing import Any, Dict, Union
-import logging
 
 from jupyter_ai.models import GlobalConfig
+from jupyter_ai_magics.utils import (
+    AnyProvider,
+    EmProvidersDict,
+    LmProvidersDict,
+    get_em_provider,
+    get_lm_provider,
+)
 from jupyter_core.paths import jupyter_data_dir
-from jupyter_ai_magics.utils import get_lm_provider, get_em_provider, AnyProvider, LmProvidersDict, EmProvidersDict
 
 Logger = Union[logging.Logger, logging.LoggerAdapter]
 
+
 class ConfigManager:
     """Provides model and embedding provider id along
     with the credentials to authenticate providers.
     """
 
-    def __init__(self, log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict):
+    def __init__(
+        self, log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict
+    ):
         self.log = log
         self.save_dir = os.path.join(jupyter_data_dir(), "jupyter_ai")
         self.save_path = os.path.join(self.save_dir, "config.json")
@@ -36,20 +45,25 @@ def update(self, config: GlobalConfig, save_to_disk: bool = True):
 
     def get_config(self):
         return self.config
-    
+
     def get_lm_provider(self):
         return self.lm_provider
 
     def get_lm_provider_params(self):
         return self.lm_provider_params
 
    def get_em_provider(self):
         return self.em_provider
 
     def get_em_provider_params(self):
         return self.em_provider_params
 
-    def _authenticate_provider(self, provider: AnyProvider, provider_params: Dict[str, Any], config: GlobalConfig):
+    def _authenticate_provider(
+        self,
+        provider: AnyProvider,
+        provider_params: Dict[str, Any],
+        config: GlobalConfig,
+    ):
         auth_strategy = provider.auth_strategy
         if auth_strategy and auth_strategy.type == "env":
             api_keys = config.api_keys
@@ -79,7 +93,7 @@ def _update_lm_provider(self, config: GlobalConfig):
         self._authenticate_provider(provider, provider_params, config)
         self.lm_provider = provider
         self.lm_provider_params = provider_params
-    
+
     def _update_em_provider(self, config: GlobalConfig):
         model_id = config.embeddings_provider_id
 
@@ -93,7 +107,7 @@ def _update_em_provider(self, config: GlobalConfig):
         if not provider:
             raise ValueError(f"No provider and model found with '{model_id}'")
 
-        provider_params = { "model_id": local_model_id }
+        provider_params = {"model_id": local_model_id}
 
         self._authenticate_provider(provider, provider_params, config)
         self.em_provider = provider
@@ -115,4 +129,3 @@ def _load(self):
 
         # otherwise, create a new empty config file
         self.update(GlobalConfig(), True)
-