refactor: ♻️ refactoring LLM to avoid repetition
sestinj committed Sep 2, 2023
1 parent 2f792f4 commit 8967e2d
Showing 30 changed files with 242 additions and 513 deletions.
36 changes: 24 additions & 12 deletions .github/ISSUE_TEMPLATE/bug-report-🐛.md
@@ -1,43 +1,55 @@
---
name: "Bug report \U0001F41B"
about: Create a report to help us fix your bug
title: ''
title: ""
labels: bug
assignees: ''

assignees: ""
---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:

1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Environment**

- Operating System: [e.g. MacOS]
- Python Version: [e.g. 3.10.6]
- Continue Version: [e.g. v0.0.207]

**Console logs**
**Logs**

```
REPLACE THIS SECTION WITH CONSOLE LOGS OR A SCREENSHOT...
```

To get the Continue server logs:

1. cmd+shift+p (MacOS) / ctrl+shift+p (Windows)
2. Search for and then select "Continue: View Continue Server Logs"
3. Scroll to the bottom of `continue.log` and copy the last 100 lines or so

To get the VS Code console logs:

To get the console logs in VS Code:
1. cmd+shift+p (MacOS) / ctrl+shift+p (Windows)
2. Search for and then select "Developer: Toggle Developer Tools"
3. Select Console
4. Read the console logs
```

If the problem is related to LLM prompting:

1. Hover the problematic response in the Continue UI
2. Click the "magnifying glass" icon
3. Copy the contents of the `continue_logs.txt` file that opens

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Additional context**
Add any other context about the problem here.
6 changes: 1 addition & 5 deletions .github/ISSUE_TEMPLATE/feature-request-💪.md
@@ -1,10 +1,9 @@
---
name: "Feature request \U0001F4AA"
about: Suggest an idea for this project
title: ''
title: ""
labels: enhancement
assignees: TyDunn

---

**Is your feature request related to a problem? Please describe.**
@@ -13,8 +12,5 @@ A clear and concise description of what the problem is. Ex. I'm always frustrate
**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
4 changes: 0 additions & 4 deletions continuedev/src/continuedev/core/abstract_sdk.py
@@ -71,10 +71,6 @@ async def add_directory(self, path: str):
async def delete_directory(self, path: str):
pass

@abstractmethod
async def get_user_secret(self, env_var: str) -> str:
pass

config: ContinueConfig

@abstractmethod
2 changes: 1 addition & 1 deletion continuedev/src/continuedev/core/autopilot.py
@@ -507,7 +507,7 @@ async def accept_user_input(self, user_input: str):
if self.session_info is None:

async def create_title():
title = await self.continue_sdk.models.medium.complete(
title = await self.continue_sdk.models.medium._complete(
f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ',
max_tokens=20,
)
13 changes: 1 addition & 12 deletions continuedev/src/continuedev/core/sdk.py
@@ -94,14 +94,7 @@ def write_log(self, message: str):
self.history.timeline[self.history.current_index].logs.append(message)

async def start_model(self, llm: LLM):
kwargs = {}
if llm.requires_api_key:
kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
if llm.requires_unique_id:
kwargs["unique_id"] = self.ide.unique_id
if llm.requires_write_log:
kwargs["write_log"] = self.write_log
await llm.start(**kwargs)
await llm.start(unique_id=self.ide.unique_id, write_log=self.write_log)

async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
@@ -211,10 +204,6 @@ async def delete_directory(self, path: str):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))

async def get_user_secret(self, env_var: str) -> str:
# TODO support error prompt dynamically set on env_var
return await self.ide.getUserSecret(env_var)

_last_valid_config: ContinueConfig = None

def _load_config_dot_py(self) -> ContinueConfig:
55 changes: 33 additions & 22 deletions continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,66 +1,77 @@
from abc import ABC, abstractproperty
from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from abc import ABC
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union

from ...core.main import ChatMessage
from ...models.main import ContinueBaseModel
from ..util.count_tokens import DEFAULT_ARGS, count_tokens


class LLM(ContinueBaseModel, ABC):
requires_api_key: Optional[str] = None
requires_unique_id: bool = False
requires_write_log: bool = False
title: Optional[str] = None
system_message: Optional[str] = None

context_length: int = 2048
"The maximum context length of the LLM in tokens, as counted by count_tokens."

unique_id: Optional[str] = None
"The unique ID of the user."

model: str
"The model name"

prompt_templates: dict = {}

write_log: Optional[Callable[[str], None]] = None
"A function that takes a string and writes it to the log."

api_key: Optional[str] = None
"The API key for the LLM provider."

class Config:
arbitrary_types_allowed = True
extra = "allow"

def dict(self, **kwargs):
original_dict = super().dict(**kwargs)
original_dict.pop("write_log", None)
original_dict["name"] = self.name
original_dict["class_name"] = self.__class__.__name__
return original_dict

@abstractproperty
def name(self):
"""Return the name of the LLM."""
raise NotImplementedError
def collect_args(self, **kwargs) -> Any:
"""Collect the arguments for the LLM."""
args = {**DEFAULT_ARGS.copy(), "model": self.model, "max_tokens": 1024}
args.update(kwargs)
return args

async def start(self, *, api_key: Optional[str] = None, **kwargs):
async def start(
self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None
):
"""Start the connection to the LLM."""
raise NotImplementedError
self.write_log = write_log
self.unique_id = unique_id

async def stop(self):
"""Stop the connection to the LLM."""
raise NotImplementedError
pass

async def complete(
async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError

def stream_complete(
def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the completion through generator."""
raise NotImplementedError

async def stream_chat(
async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the chat through generator."""
raise NotImplementedError

def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
raise NotImplementedError

@abstractproperty
def context_length(self) -> int:
"""Return the context length of the LLM in tokens, as counted by count_tokens."""
raise NotImplementedError
return count_tokens(self.model, text)
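
For orientation, the sketch below (not part of the commit) shows the shape the refactor gives providers: shared behaviour such as `start`, `collect_args`, and token counting lives on the base class, and a concrete provider only overrides `_complete` or the streaming variants. The simplified base class and the `EchoLLM` provider here are illustrative assumptions written for this note, not code from the repository.

```python
# Illustrative sketch only: a simplified stand-in for the refactored LLM base
# class. EchoLLM is a hypothetical provider, not part of the Continue codebase.
import asyncio
from typing import Callable, Optional


class LLMBase:
    def __init__(self, model: str, api_key: Optional[str] = None):
        self.model = model
        self.api_key = api_key
        self.context_length = 2048
        self.write_log: Optional[Callable[[str], None]] = None
        self.unique_id: Optional[str] = None

    async def start(
        self,
        write_log: Optional[Callable[[str], None]] = None,
        unique_id: Optional[str] = None,
    ):
        # Shared startup: every provider receives the logger and user id the same way.
        self.write_log = write_log
        self.unique_id = unique_id

    def collect_args(self, **kwargs) -> dict:
        # Shared request defaults that subclasses can extend per provider.
        args = {"model": self.model, "max_tokens": 1024}
        args.update(kwargs)
        return args

    async def _complete(self, prompt: str, **kwargs) -> str:
        raise NotImplementedError


class EchoLLM(LLMBase):
    async def _complete(self, prompt: str, **kwargs) -> str:
        args = self.collect_args(**kwargs)
        if self.write_log:
            self.write_log(f"Prompt ({args['model']}): {prompt}")
        return prompt.upper()  # stand-in for a real provider call


async def main():
    llm = EchoLLM(model="fake-model")
    await llm.start(write_log=print, unique_id="user-123")
    print(await llm._complete("hello", max_tokens=20))


asyncio.run(main())
```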
55 changes: 15 additions & 40 deletions continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,47 +1,36 @@
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
from typing import Any, Coroutine, Dict, Generator, List, Union

from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic

from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
from ..util.count_tokens import compile_chat_messages


class AnthropicLLM(LLM):
api_key: str
"Anthropic API key"

model: str = "claude-2"

requires_write_log = True
_async_client: AsyncAnthropic = None

class Config:
arbitrary_types_allowed = True

write_log: Optional[Callable[[str], None]] = None

async def start(
self,
*,
api_key: Optional[str] = None,
write_log: Callable[[str], None],
**kwargs,
):
self.write_log = write_log
await super().start(**kwargs)
self._async_client = AsyncAnthropic(api_key=self.api_key)

async def stop(self):
pass

@property
def name(self):
return self.model
if self.model == "claude-2":
self.context_length = 100_000

@property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.model}
def collect_args(self, **kwargs) -> Any:
args = super().collect_args(**kwargs)

def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
args = args.copy()
if "max_tokens" in args:
args["max_tokens_to_sample"] = args["max_tokens"]
del args["max_tokens"]
@@ -51,15 +40,6 @@ def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
del args["presence_penalty"]
return args

def count_tokens(self, text: str):
return count_tokens(self.model, text)

@property
def context_length(self):
if self.model == "claude-2":
return 100000
raise Exception(f"Unknown Anthropic model {self.model}")

def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""

@@ -76,13 +56,11 @@ def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt += AI_PROMPT
return prompt

async def stream_complete(
async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args = self.collect_args(**kwargs)
args["stream"] = True
args = self._transform_args(args)
prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"

self.write_log(f"Prompt: \n\n{prompt}")
@@ -95,13 +73,11 @@ async def stream_complete(

self.write_log(f"Completion: \n\n{completion}")

async def stream_chat(
async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args = self.collect_args(**kwargs)
args["stream"] = True
args = self._transform_args(args)

messages = compile_chat_messages(
args["model"],
@@ -123,11 +99,10 @@ async def stream_chat(

self.write_log(f"Completion: \n\n{completion}")

async def complete(
async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
args = self._transform_args(args)
args = self.collect_args(**kwargs)

messages = compile_chat_messages(
args["model"],
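
Taken together with the base-class change above, starting and querying a provider now looks roughly like the following usage sketch. This is a hedged example, not code from the commit: the import path is an assumption based on the file layout shown in the file headers, and running it requires a valid Anthropic API key.

```python
# Hypothetical usage sketch; the import path assumes the repository layout
# shown above and is not confirmed by this commit.
import asyncio

from continuedev.src.continuedev.libs.llm.anthropic import AnthropicLLM


async def main():
    llm = AnthropicLLM(api_key="sk-ant-...", model="claude-2")
    # The refactored start() takes the logger and user id directly, so callers
    # no longer need to inspect the removed requires_* flags.
    await llm.start(write_log=print, unique_id="example-user")
    answer = await llm._complete("Summarize this commit in one sentence.")
    print(answer)


asyncio.run(main())
```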
