Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion prompting/llms/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,11 @@ def __call__(self, messages: List[Dict[str, str]]):
return self.forward(messages=messages)

def _make_prompt(self, messages: List[Dict[str, str]]):
return self.llm_pipeline.tokenizer.apply_chat_template(
# The tokenizer.tokenizer is used for an integration with vllm and the mock pipeline; for a real HF application, use:
# return self.llm_pipeline.tokenizer.apply_chat_template(
# messages, tokenize=False, add_generation_prompt=True
# )
return self.llm_pipeline.tokenizer.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

Expand Down
2 changes: 1 addition & 1 deletion prompting/llms/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
self.mock = mock
self.gpus = gpus
self.tokenizer = self.llm.llm_engine.tokenizer
self.tokenizer = self.llm.llm_engine.tokenizer.tokenizer

def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str:
if self.mock:
Expand Down
8 changes: 5 additions & 3 deletions prompting/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ class MockModel(torch.nn.Module):
def __init__(self, phrase):
super().__init__()

self.tokenizer = MockTokenizer()
self.tokenizer = SimpleNamespace(
tokenizer=MockTokenizer()
)
self.phrase = phrase

def __call__(self, messages):
return self.forward(messages)

def forward(self, messages):
role_tag = self.tokenizer.role_expr.format(role="assistant")
role_tag = self.tokenizer.tokenizer.role_expr.format(role="assistant")
return f"{role_tag} {self.phrase}"


Expand Down Expand Up @@ -73,7 +75,7 @@ def forward(self, messages, **kwargs):
return self.postprocess(output)

def postprocess(self, output, **kwargs):
output = output.split(self.model.tokenizer.role_expr.format(role="assistant"))[
output = output.split(self.model.tokenizer.tokenizer.role_expr.format(role="assistant"))[
-1
].strip()
return output
Expand Down
5 changes: 1 addition & 4 deletions prompting/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,7 @@ def init_wandb(self, reinit=False):
tags=tags,
notes=self.config.wandb.notes,
)
bt.logging.success(
prefix="Started a new wandb run",
sufix=f"<blue> {self.wandb.name} </blue>",
)
bt.logging.success(f"Started a new wandb run <blue> {self.wandb.name} </blue>")


def reinit_wandb(self):
Expand Down