pre-commit
dlqqq committed Jun 25, 2024
1 parent 322f353 commit 5efb6f2
Showing 11 changed files with 67 additions and 52 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -134,4 +134,3 @@ dev.sh
.yarn

.conda/

7 changes: 3 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -41,9 +41,8 @@
Together,
)
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

from langchain_core.language_models.llms import BaseLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
@@ -456,14 +455,14 @@ def _supports_sync_streaming(self):
return not (self.__class__._stream is BaseChatModel._stream)
else:
return not (self.__class__._stream is BaseLLM._stream)

@property
def _supports_async_streaming(self):
if self.is_chat_provider:
return not (self.__class__._astream is BaseChatModel._astream)
else:
return not (self.__class__._astream is BaseLLM._astream)

@property
def supports_streaming(self):
return self._supports_sync_streaming or self._supports_async_streaming
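The `_supports_sync_streaming` and `_supports_async_streaming` properties above detect an override by comparing class attributes by identity: if a subclass never redefines `_stream`, attribute lookup resolves to the very same function object stored on the base class. A minimal, self-contained sketch of the pattern (class names here are illustrative, not from the codebase):

    class Base:
        def _stream(self):
            raise NotImplementedError

    class Streaming(Base):
        def _stream(self):
            yield "chunk"

    class NonStreaming(Base):
        pass

    def overrides_stream(cls: type) -> bool:
        # Attribute lookup walks the MRO; with no override it returns
        # the exact function object defined on `Base`.
        return cls._stream is not Base._stream

    assert overrides_stream(Streaming)
    assert not overrides_stream(NonStreaming)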
2 changes: 1 addition & 1 deletion packages/jupyter-ai-test/jupyter_ai_test/_version.py
@@ -1,4 +1,4 @@
# This file is auto-generated by Hatchling. As such, do not:
# - modify
# - track in version control e.g. be sure to add to .gitignore
__version__ = VERSION = '0.1.0'
__version__ = VERSION = "0.1.0"
12 changes: 7 additions & 5 deletions packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
@@ -1,10 +1,11 @@
import time
from typing import Any, List, Optional, Iterator
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs.generation import GenerationChunk


class TestLLM(LLM):
model_id: str = "test"

@@ -21,7 +22,8 @@ def _call(
) -> str:
time.sleep(3)
return f"Hello! This is a dummy response from a test LLM."



class TestLLMWithStreaming(LLM):
model_id: str = "test"

@@ -47,9 +49,9 @@ def _stream(
**kwargs: Any,
) -> Iterator[GenerationChunk]:
time.sleep(5)
yield GenerationChunk(text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n")
yield GenerationChunk(
text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
)
for i in range(1, 101):
time.sleep(0.5)
yield GenerationChunk(text=f"{i}, ")
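
A quick way to exercise this model, assuming the `jupyter_ai_test` package is installed: LangChain's `LLM` base class exposes `stream()`, which wraps `_stream()` and yields the text of each `GenerationChunk` (for models like `TestLLM` that implement only `_call()`, it falls back to a single chunk):

    from jupyter_ai_test.test_llms import TestLLMWithStreaming

    llm = TestLLMWithStreaming()
    # Blocks ~5s, prints the greeting, then "1, 2, ..." every half second.
    for token in llm.stream("count to 100"):
        print(token, end="", flush=True)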


10 changes: 3 additions & 7 deletions packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
@@ -1,4 +1,5 @@
from typing import ClassVar, List

from jupyter_ai import AuthStrategy, BaseProvider, Field

from .test_llms import TestLLM, TestLLMWithStreaming
@@ -11,9 +12,7 @@ class TestProvider(BaseProvider, TestLLM):
name: ClassVar[str] = "Test Provider"
"""User-facing name of this provider."""

models: ClassVar[List[str]] = [
"test"
]
models: ClassVar[List[str]] = ["test"]
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

@@ -49,9 +48,7 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
name: ClassVar[str] = "Test Provider (streaming)"
"""User-facing name of this provider."""

models: ClassVar[List[str]] = [
"test"
]
models: ClassVar[List[str]] = ["test"]
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

@@ -78,4 +75,3 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
fields: ClassVar[List[Field]] = []
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""

@@ -1,18 +1,20 @@
from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
from jupyter_ai.models import HumanChatMessage


class TestSlashCommand(BaseChatHandler):
"""
A test slash command implementation that developers should build from. The
string used to invoke this command is set by the `slash_id` keyword argument
in the `routing_type` attribute. The command is mainly implemented in the
`process_message()` method. See built-in implementations under
`jupyter_ai/handlers` for further reference.
The provider is made available to Jupyter AI by the entry point declared in
`pyproject.toml`. If this class or parent module is renamed, make sure the
update the entry point there as well.
"""

id = "test"
name = "Test"
help = "A test slash command."
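Concretely, per the docstring above, a new slash command subclasses `BaseChatHandler`, sets its routing metadata, and implements `process_message()`. A hedged sketch (the `/echo` command is invented for illustration; the `reply()` signature mirrors the one visible in `base.py` below):

    from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
    from jupyter_ai.models import HumanChatMessage

    class EchoSlashCommand(BaseChatHandler):
        """Hypothetical /echo command that repeats the sender's text back."""

        id = "echo"
        name = "Echo"
        help = "Repeat the message back to the sender."
        routing_type = SlashCommandRoutingType(slash_id="echo")

        async def process_message(self, message: HumanChatMessage):
            # `reply()` broadcasts an agent message to all connected clients.
            self.reply(f"You said: {message.body}", message)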
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -197,7 +197,7 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):

handler.broadcast_message(agent_msg)
break

@property
def persona(self):
return self.config_manager.persona
31 changes: 17 additions & 14 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,15 +1,19 @@
import time
from typing import Dict, Type
from uuid import uuid4
import time

from jupyter_ai.models import HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage
from jupyter_ai.models import (
AgentStreamChunkMessage,
AgentStreamMessage,
HumanChatMessage,
)
from jupyter_ai_magics.providers import BaseProvider
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.history import RunnableWithMessageHistory

from .base import BaseChatHandler, SlashCommandRoutingType
from ..history import BoundedChatHistory
from .base import BaseChatHandler, SlashCommandRoutingType


class DefaultChatHandler(BaseChatHandler):
@@ -48,9 +52,8 @@ def create_llm_chain(
input_messages_key="input",
history_messages_key="history",
)

self.llm_chain = runnable

self.llm_chain = runnable

def _start_stream(self, human_msg: HumanChatMessage) -> str:
"""
@@ -64,7 +67,7 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False
complete=False,
)

for handler in self._root_chat_handlers.values():
@@ -75,16 +78,14 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
break

return stream_id

def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id,
content=content,
stream_complete=complete
id=stream_id, content=content, stream_complete=complete
)

for handler in self._root_chat_handlers.values():
@@ -93,7 +94,6 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):

handler.broadcast_message(stream_chunk_msg)
break


async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
@@ -105,7 +105,10 @@ async def process_message(self, message: HumanChatMessage):
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
async for chunk in self.llm_chain.astream({ "input": message.body }, config={"configurable": {"session_id": "static_session"}}):
async for chunk in self.llm_chain.astream(
{"input": message.body},
config={"configurable": {"session_id": "static_session"}},
):
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
# start the stream.
@@ -120,6 +123,6 @@ async def process_message(self, message: HumanChatMessage):
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)
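For reference, the two branches in the loop above exist because `astream()` yields different chunk types by provider: chat models emit `AIMessageChunk` objects whose text lives in `.content`, while completion-style LLMs emit plain strings. The same dispatch in isolation, assuming `chunks` is any async iterator of such values:

    from langchain_core.messages import AIMessageChunk

    async def drain(chunks) -> str:
        """Accumulate a stream that may yield AIMessageChunk or str."""
        parts = []
        async for chunk in chunks:
            if isinstance(chunk, AIMessageChunk):
                parts.append(chunk.content)
            elif isinstance(chunk, str):
                parts.append(chunk)
            else:
                raise TypeError(f"unrecognized chunk type: {type(chunk)}")
        return "".join(parts)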
37 changes: 23 additions & 14 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -17,13 +17,14 @@

from .models import (
AgentChatMessage,
AgentStreamMessage,
AgentStreamChunkMessage,
AgentStreamMessage,
ChatClient,
ChatHistory,
ChatMessage,
ChatRequest,
ChatUser,
ClosePendingMessage,
ConnectionMessage,
HumanChatMessage,
ListProvidersEntry,
Expand All @@ -32,7 +33,6 @@
ListSlashCommandsResponse,
Message,
PendingMessage,
ClosePendingMessage,
UpdateConfigRequest,
)

Expand All @@ -49,16 +49,15 @@ class ChatHistoryHandler(BaseAPIHandler):
@property
def chat_history(self) -> List[ChatMessage]:
return self.settings["chat_history"]

@property
def pending_messages(self) -> List[PendingMessage]:
return self.settings["pending_messages"]

@tornado.web.authenticated
async def get(self):
history = ChatHistory(
messages=self.chat_history,
pending_messages=self.pending_messages
messages=self.chat_history, pending_messages=self.pending_messages
)
self.finish(history.json())

@@ -106,7 +105,7 @@ def loop(self) -> AbstractEventLoop:
@property
def pending_messages(self) -> List[PendingMessage]:
return self.settings["pending_messages"]

@pending_messages.setter
def pending_messages(self, new_pending_messages):
self.settings["pending_messages"] = new_pending_messages
@@ -186,10 +185,14 @@ def open(self):
self.root_chat_handlers[client_id] = self
self.chat_clients[client_id] = ChatClient(**current_user, id=client_id)
self.client_id = client_id
self.write_message(ConnectionMessage(
client_id=client_id,
history=ChatHistory(messages=self.chat_history, pending_messages=self.pending_messages)
).dict())
self.write_message(
ConnectionMessage(
client_id=client_id,
history=ChatHistory(
messages=self.chat_history, pending_messages=self.pending_messages
),
).dict()
)

self.log.info(f"Client connected. ID: {client_id}")
self.log.debug("Clients are : %s", self.root_chat_handlers.keys())
@@ -208,7 +211,9 @@ def broadcast_message(self, message: Message):
client.write_message(message.dict())

# append all messages of type `ChatMessage` directly to the chat history
if isinstance(message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)):
if isinstance(
message, (HumanChatMessage, AgentChatMessage, AgentStreamMessage)
):
self.chat_history.append(message)
elif isinstance(message, AgentStreamChunkMessage):
# for stream chunks, modify the corresponding `AgentStreamMessage`
@@ -217,16 +222,20 @@ def broadcast_message(self, message: Message):

# iterate backwards from the end of the list
for i in range(len(self.chat_history) - 1, -1, -1):
if self.chat_history[i].type == 'agent-stream' and self.chat_history[i].id == chunk.id:
if (
self.chat_history[i].type == "agent-stream"
and self.chat_history[i].id == chunk.id
):
stream_message: AgentStreamMessage = self.chat_history[i]
stream_message.body += chunk.content
stream_message.complete = chunk.stream_complete
break
elif isinstance(message, PendingMessage):
self.pending_messages.append(message)
elif isinstance(message, ClosePendingMessage):
self.pending_messages = list(filter(lambda m: m.id != message.id, self.pending_messages))

self.pending_messages = list(
filter(lambda m: m.id != message.id, self.pending_messages)
)

async def on_message(self, message):
self.log.debug("Message received: %s", message)
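The reverse iteration in `broadcast_message()` above is a small optimization: the open `agent-stream` message is almost always at or near the end of the history, so scanning from the back finds it quickly. The same merge logic in isolation, with a simplified stand-in for the message type:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class StreamMsg:
        id: str
        body: str
        complete: bool = False

    def apply_chunk(history: List[StreamMsg], chunk_id: str, content: str, done: bool) -> None:
        # Scan newest-to-oldest and append the chunk to the matching message.
        for msg in reversed(history):
            if msg.id == chunk_id:
                msg.body += content
                msg.complete = done
                return

    history = [StreamMsg(id="s1", body="Hello")]
    apply_chunk(history, "s1", ", world!", done=True)
    assert history[0].body == "Hello, world!" and history[0].complete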
6 changes: 4 additions & 2 deletions packages/jupyter-ai/jupyter_ai/history.py
@@ -1,7 +1,9 @@
from langchain_core.messages import BaseMessage
from typing import List, Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Sequence


class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
"""
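The hunk cuts off at the class docstring, so only the reordered imports for `BoundedChatHistory` are visible. Going by the name and imports, it is an in-memory history capped at a fixed number of messages; a hedged sketch of such a class (the `k` field and the truncation rule are assumptions, not the project's actual implementation):

    from typing import List, Sequence

    from langchain_core.chat_history import BaseChatMessageHistory
    from langchain_core.messages import BaseMessage
    from langchain_core.pydantic_v1 import BaseModel, Field

    class BoundedChatHistorySketch(BaseChatMessageHistory, BaseModel):
        """In-memory chat history retaining only the most recent `k` messages."""

        k: int
        messages: List[BaseMessage] = Field(default_factory=list)

        def add_messages(self, messages: Sequence[BaseMessage]) -> None:
            self.messages = (self.messages + list(messages))[-self.k :]

        def clear(self) -> None:
            self.messages = []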
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -68,13 +68,15 @@ class AgentChatMessage(BaseModel):
this defaults to a description of `JupyternautPersona`.
"""


class AgentStreamMessage(AgentChatMessage):
type: Literal['agent-stream'] = 'agent-stream'
type: Literal["agent-stream"] = "agent-stream"
complete: bool
# other attrs inherited from `AgentChatMessage`


class AgentStreamChunkMessage(BaseModel):
type: Literal['agent-stream-chunk'] = 'agent-stream-chunk'
type: Literal["agent-stream-chunk"] = "agent-stream-chunk"
id: str
content: str
stream_complete: bool
@@ -118,6 +120,7 @@ class ClosePendingMessage(BaseModel):

class ChatHistory(BaseModel):
"""History of chat messages"""

messages: List[ChatMessage]
pending_messages: List[PendingMessage]

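Taken together, the stream protocol in these models is small: one `agent-stream` message opens the stream, then `agent-stream-chunk` messages append to it by `id`. A quick construction check, assuming the fields shown above are the complete set:

    from jupyter_ai.models import AgentStreamChunkMessage

    chunk = AgentStreamChunkMessage(id="s1", content="Hello", stream_complete=False)
    # Pydantic v1-style serialization, as sent over the websocket by
    # `broadcast_message()` in handlers.py.
    print(chunk.dict())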
