Skip to content

Commit

Permalink
Add ConfigManager for chat models
Browse files Browse the repository at this point in the history
- Introduce ConfigManager in devchat/config.py.
- Enable creation, loading, validation, and updating of configs.
- Add tests for ConfigManager in tests/test_config.py.
  • Loading branch information
basicthinker committed Sep 4, 2023
1 parent 317a693 commit 3d7c49b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
72 changes: 72 additions & 0 deletions devchat/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
from typing import List, Optional
from pydantic import BaseModel
import yaml
from devchat.openai import OpenAIChatConfig


class ModelConfig(BaseModel):
id: str
max_input_tokens: Optional[int]
parameters: Optional[OpenAIChatConfig]


class ChatConfig(BaseModel):
models: Optional[List[ModelConfig]]


class ConfigManager:
def __init__(self, dir_path: str):
self.config_path = os.path.join(dir_path, 'config.yml')
if not os.path.exists(self.config_path):
self._create_sample_config()
self.config = self._load_and_validate_config()

def _create_sample_config(self):
sample_config = ChatConfig(models=[
ModelConfig(
id="gpt-4",
max_input_tokens=6000,
parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0, stream=True)
),
ModelConfig(
id="gpt-3.5-turbo-16k",
max_input_tokens=12000,
parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0, stream=True)
),
ModelConfig(
id="gpt-3.5-turbo",
max_input_tokens=3000,
parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0, stream=True)
)
])
with open(self.config_path, 'w', encoding='utf-8') as file:
yaml.dump(sample_config.dict(), file)

def _load_and_validate_config(self) -> ChatConfig:
with open(self.config_path, 'r', encoding='utf-8') as file:
config_data = yaml.safe_load(file)
return ChatConfig(**config_data)

def get_model_config(self, model_id: Optional[str] = None) -> ModelConfig:
if not model_id:
if len(self.config.models) > 0:
return self.config.models[0]
return None
for model in self.config.models:
if model.id == model_id:
return model
return None

def update_model_config(self, model_config: ModelConfig) -> ModelConfig:
model = self.get_model_config(model_config.id)
if not model:
return None
if model_config.max_input_tokens is not None:
model.max_input_tokens = model_config.max_input_tokens
if model_config.parameters is not None:
updated_parameters = model.parameters.dict()
updated_parameters.update(
{k: v for k, v in model_config.parameters.dict().items() if v is not None})
model.parameters = OpenAIChatConfig(**updated_parameters)
return model
34 changes: 34 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from devchat.config import ConfigManager, ModelConfig, ChatConfig, OpenAIChatConfig


def test_create_sample_config():
ConfigManager('/tmp')
assert os.path.exists('/tmp/config.yml')


def test_load_and_validate_config():
config_manager = ConfigManager('/tmp')
assert isinstance(config_manager.config, ChatConfig)


def test_get_model_config():
config_manager = ConfigManager('/tmp')
model_config = config_manager.get_model_config('gpt-4')
assert model_config.id == 'gpt-4'
assert model_config.max_input_tokens == 6000
assert model_config.parameters.temperature == 0
assert model_config.parameters.stream is True


def test_update_model_config():
config_manager = ConfigManager('/tmp')
new_model_config = ModelConfig(
id='gpt-4',
max_input_tokens=7000,
parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0.5)
)
updated_model_config = config_manager.update_model_config(new_model_config)
assert updated_model_config.max_input_tokens == 7000
assert updated_model_config.parameters.temperature == 0.5
assert updated_model_config.parameters.stream is True

0 comments on commit 3d7c49b

Please sign in to comment.