diff --git a/prompting/llms/hf.py b/prompting/llms/hf.py
index be46dc1b8..d2a095405 100644
--- a/prompting/llms/hf.py
+++ b/prompting/llms/hf.py
@@ -207,7 +207,11 @@ def __call__(self, messages: List[Dict[str, str]]):
return self.forward(messages=messages)
def _make_prompt(self, messages: List[Dict[str, str]]):
- return self.llm_pipeline.tokenizer.apply_chat_template(
+ # The tokenizer.tokenizer is used for integration with vLLM and the mock pipeline; for a real HF application, use:
+ # return self.llm_pipeline.tokenizer.apply_chat_template(
+ # messages, tokenize=False, add_generation_prompt=True
+ # )
+ return self.llm_pipeline.tokenizer.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py
index a8bf9e095..85ee7b974 100644
--- a/prompting/llms/vllm_llm.py
+++ b/prompting/llms/vllm_llm.py
@@ -67,7 +67,7 @@ def __init__(
self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
self.mock = mock
self.gpus = gpus
- self.tokenizer = self.llm.llm_engine.tokenizer
+ self.tokenizer = self.llm.llm_engine.tokenizer.tokenizer
def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str:
if self.mock:
diff --git a/prompting/mock.py b/prompting/mock.py
index cc8ea8167..34b172a82 100644
--- a/prompting/mock.py
+++ b/prompting/mock.py
@@ -30,14 +30,16 @@ class MockModel(torch.nn.Module):
def __init__(self, phrase):
super().__init__()
- self.tokenizer = MockTokenizer()
+ self.tokenizer = SimpleNamespace(
+ tokenizer=MockTokenizer()
+ )
self.phrase = phrase
def __call__(self, messages):
return self.forward(messages)
def forward(self, messages):
- role_tag = self.tokenizer.role_expr.format(role="assistant")
+ role_tag = self.tokenizer.tokenizer.role_expr.format(role="assistant")
return f"{role_tag} {self.phrase}"
@@ -73,7 +75,7 @@ def forward(self, messages, **kwargs):
return self.postprocess(output)
def postprocess(self, output, **kwargs):
- output = output.split(self.model.tokenizer.role_expr.format(role="assistant"))[
+ output = output.split(self.model.tokenizer.tokenizer.role_expr.format(role="assistant"))[
-1
].strip()
return output
diff --git a/prompting/utils/logging.py b/prompting/utils/logging.py
index 381d68b62..d7029b89c 100644
--- a/prompting/utils/logging.py
+++ b/prompting/utils/logging.py
@@ -92,10 +92,7 @@ def init_wandb(self, reinit=False):
tags=tags,
notes=self.config.wandb.notes,
)
- bt.logging.success(
- prefix="Started a new wandb run",
- sufix=f" {self.wandb.name} ",
- )
+ bt.logging.success(f"Started a new wandb run {self.wandb.name} ")
def reinit_wandb(self):