diff --git a/prompting/tasks/inference.py b/prompting/tasks/inference.py index 2c83865ff..be8bba7d9 100644 --- a/prompting/tasks/inference.py +++ b/prompting/tasks/inference.py @@ -30,21 +30,36 @@ class InferenceRewardConfig(BaseRewardConfig): Ask a question about the text and nothing else:""" +SYSTEM_PROMPTS = [ + "", + "You are a helpful AI assistant. Provide concise, accurate answers to any questions asked.", + "You are a friendly and patient assistant. Communicate your responses in a clear, easy-to-understand way, ensuring the user feels supported.", + "You are a creative helper. Offer engaging, imaginative responses that keep the user interested, while maintaining accuracy and clarity.", +] + class InferenceTask(BaseTextTask): name: ClassVar[str] = "inference" # TODO: Once we want to enable the 'actual' inference task with exact models query: str | None = None reference: str | None = None + system_prompt: str | None = None llm_model: ModelConfig | None = None llm_model_id: ModelConfig | None = random.choice(ModelZoo.models_configs).llm_model_id seed: int = Field(default_factory=lambda: random.randint(0, 1_000_000), allow_mutation=False) - sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS + sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS.copy() @model_validator(mode="after") def random_llm_model_id(self): if self.query: # If we are already defining query, as in the case of organics, we also specify model. return self + # Choose system prompt and randomize inference settings + self.system_prompt = random.choice(SYSTEM_PROMPTS) + self.messages = [] + if self.system_prompt: + self.messages.append({"role": "system", "content": self.system_prompt}) + self.sampling_params["temperature"] = random.randint(0, 10) / 10 + self.sampling_params["max_new_tokens"] = random.choice([256, 512, 1024, 2048]) if np.random.rand() < 0.2: self.llm_model_id = None @@ -55,8 +70,9 @@ def random_llm_model_id(self): def make_query(self, dataset_entry: ChatEntry) -> str: if self.query: return self.query - self.query = dataset_entry.messages - self.messages = dataset_entry.messages + self.messages.extend(dataset_entry.messages) + self.query = self.messages + return self.query def make_reference(self, dataset_entry: ChatEntry) -> str: diff --git a/pyproject.toml b/pyproject.toml index b4adb1e16..5eeb0f612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ max-line-length = 120 extend-ignore = "D203,E203,E251,E266,E302,E305,E401,E402,E501,F401,F403,W503" exclude = ".git,__pycache__,dist,.venv,venv,*/lib/python*/site-packages,*/lib64/python*/site-packages" # TODO: Decrease to at least 10 (prompting/weight_setting/weight_setter.py). -max-complexity = 13 +max-complexity = 14 [tool.isort] atomic = true