-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Introduce ConfigManager in devchat/config.py. - Enable creation, loading, validation, and updating of configs. - Add tests for ConfigManager in tests/test_config.py.
- Loading branch information
1 parent
317a693
commit 3d7c49b
Showing
2 changed files
with
106 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
from typing import List, Optional | ||
from pydantic import BaseModel | ||
import yaml | ||
from devchat.openai import OpenAIChatConfig | ||
|
||
|
||
class ModelConfig(BaseModel):
    """Configuration for a single chat model entry in config.yml."""
    # Unique model identifier, e.g. "gpt-4".
    id: str
    # Maximum number of input tokens accepted; None means unspecified.
    # Explicit default keeps the field optional under pydantic v2 as well
    # (pydantic v1 treated Optional fields as implicitly defaulting to None).
    max_input_tokens: Optional[int] = None
    # Provider-specific chat parameters; None means provider defaults apply.
    parameters: Optional[OpenAIChatConfig] = None
|
||
|
||
class ChatConfig(BaseModel):
    """Top-level chat configuration: the list of available models."""
    # Configured models; None when the config file defines no models.
    # Explicit default keeps the field optional under pydantic v2 as well.
    models: Optional[List[ModelConfig]] = None
|
||
|
||
class ConfigManager:
    """Manage the chat configuration stored in ``<dir_path>/config.yml``.

    On construction, writes a sample configuration file if none exists,
    then loads and validates it into a :class:`ChatConfig` instance.
    """

    def __init__(self, dir_path: str):
        """Initialize the manager.

        :param dir_path: Directory that contains (or will contain) config.yml.
        """
        self.config_path = os.path.join(dir_path, 'config.yml')
        if not os.path.exists(self.config_path):
            self._create_sample_config()
        self.config = self._load_and_validate_config()

    def _create_sample_config(self):
        """Write a sample config.yml with default OpenAI model entries."""
        sample_config = ChatConfig(models=[
            ModelConfig(
                id="gpt-4",
                max_input_tokens=6000,
                parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0, stream=True)
            ),
            ModelConfig(
                id="gpt-3.5-turbo-16k",
                max_input_tokens=12000,
                parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0, stream=True)
            ),
            ModelConfig(
                id="gpt-3.5-turbo",
                max_input_tokens=3000,
                parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0, stream=True)
            )
        ])
        with open(self.config_path, 'w', encoding='utf-8') as file:
            yaml.dump(sample_config.dict(), file)

    def _load_and_validate_config(self) -> ChatConfig:
        """Load config.yml and validate it against the ChatConfig schema.

        :raises pydantic.ValidationError: if the file does not match the schema.
        """
        with open(self.config_path, 'r', encoding='utf-8') as file:
            config_data = yaml.safe_load(file)
        return ChatConfig(**config_data)

    def get_model_config(self, model_id: Optional[str] = None) -> Optional[ModelConfig]:
        """Return the model with the given id.

        :param model_id: Model id to look up; falsy selects the first model.
        :return: The matching ModelConfig, or None when there is no match
            (or no models are configured).
        """
        # ChatConfig.models is Optional — guard against a None value.
        models = self.config.models or []
        if not model_id:
            return models[0] if models else None
        for model in models:
            if model.id == model_id:
                return model
        return None

    def update_model_config(self, model_config: ModelConfig) -> Optional[ModelConfig]:
        """Merge the non-None fields of *model_config* into the stored model
        with the same id.

        :param model_config: Partial model settings to overlay; its ``id``
            selects the target model.
        :return: The updated in-memory ModelConfig, or None if no model with
            that id exists. (The change is not persisted to config.yml.)
        """
        model = self.get_model_config(model_config.id)
        if not model:
            return None
        if model_config.max_input_tokens is not None:
            model.max_input_tokens = model_config.max_input_tokens
        if model_config.parameters is not None:
            # Start from existing parameters (may be absent) and overlay only
            # the explicitly-set (non-None) values from the update.
            updated_parameters = model.parameters.dict() if model.parameters else {}
            updated_parameters.update(
                {k: v for k, v in model_config.parameters.dict().items() if v is not None})
            model.parameters = OpenAIChatConfig(**updated_parameters)
        return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
from devchat.config import ConfigManager, ModelConfig, ChatConfig, OpenAIChatConfig | ||
|
||
|
||
def test_create_sample_config():
    """Instantiating the manager creates config.yml in the target directory."""
    expected_path = os.path.join('/tmp', 'config.yml')
    ConfigManager('/tmp')
    assert os.path.exists(expected_path)
|
||
|
||
def test_load_and_validate_config():
    """The loaded configuration is parsed into a ChatConfig model."""
    manager = ConfigManager('/tmp')
    loaded = manager.config
    assert isinstance(loaded, ChatConfig)
|
||
|
||
def test_get_model_config():
    """Looking up 'gpt-4' returns its configured settings from the sample config."""
    manager = ConfigManager('/tmp')
    model = manager.get_model_config('gpt-4')
    assert (model.id, model.max_input_tokens) == ('gpt-4', 6000)
    assert model.parameters.temperature == 0
    assert model.parameters.stream is True
|
||
|
||
def test_update_model_config():
    """Updating merges non-None fields while leaving other parameters intact."""
    manager = ConfigManager('/tmp')
    override = ModelConfig(
        id='gpt-4',
        max_input_tokens=7000,
        parameters=OpenAIChatConfig(model="TO_REMOVE", temperature=0.5)
    )
    updated = manager.update_model_config(override)
    assert updated.max_input_tokens == 7000
    assert updated.parameters.temperature == 0.5
    # stream was not part of the override, so the original value survives.
    assert updated.parameters.stream is True