diff --git a/prompting/llms/hf.py b/prompting/llms/hf.py index be46dc1b8..d2a095405 100644 --- a/prompting/llms/hf.py +++ b/prompting/llms/hf.py @@ -207,7 +207,11 @@ def __call__(self, messages: List[Dict[str, str]]): return self.forward(messages=messages) def _make_prompt(self, messages: List[Dict[str, str]]): - return self.llm_pipeline.tokenizer.apply_chat_template( + # The tokenizer.tokenizer is used for an integration with vllm and the mock pipeline; for a real hf application, use: + # return self.llm_pipeline.tokenizer.apply_chat_template( + # messages, tokenize=False, add_generation_prompt=True + # ) + return self.llm_pipeline.tokenizer.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index a8bf9e095..85ee7b974 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -67,7 +67,7 @@ def __init__( self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock) self.mock = mock self.gpus = gpus - self.tokenizer = self.llm.llm_engine.tokenizer + self.tokenizer = self.llm.llm_engine.tokenizer.tokenizer def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str: if self.mock: diff --git a/prompting/mock.py b/prompting/mock.py index cc8ea8167..34b172a82 100644 --- a/prompting/mock.py +++ b/prompting/mock.py @@ -30,14 +30,16 @@ class MockModel(torch.nn.Module): def __init__(self, phrase): super().__init__() - self.tokenizer = MockTokenizer() + self.tokenizer = SimpleNamespace( + tokenizer=MockTokenizer() + ) self.phrase = phrase def __call__(self, messages): return self.forward(messages) def forward(self, messages): - role_tag = self.tokenizer.role_expr.format(role="assistant") + role_tag = self.tokenizer.tokenizer.role_expr.format(role="assistant") return f"{role_tag} {self.phrase}" @@ -73,7 +75,7 @@ def forward(self, messages, **kwargs): return self.postprocess(output) def postprocess(self, output, **kwargs): 
- output = output.split(self.model.tokenizer.role_expr.format(role="assistant"))[ + output = output.split(self.model.tokenizer.tokenizer.role_expr.format(role="assistant"))[ -1 ].strip() return output diff --git a/prompting/utils/logging.py b/prompting/utils/logging.py index 381d68b62..d7029b89c 100644 --- a/prompting/utils/logging.py +++ b/prompting/utils/logging.py @@ -92,10 +92,7 @@ def init_wandb(self, reinit=False): tags=tags, notes=self.config.wandb.notes, ) - bt.logging.success( - prefix="Started a new wandb run", - sufix=f" {self.wandb.name} ", - ) + bt.logging.success(f"Started a new wandb run {self.wandb.name} ") def reinit_wandb(self):