-
Notifications
You must be signed in to change notification settings - Fork 6
/
world_model.py
78 lines (73 loc) · 2.97 KB
/
world_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Summarizes the consequences of actions each day."""
from backends import (
OpenAIChatBackend,
OpenAICompletionBackend,
ClaudeCompletionBackend,
HuggingFaceCausalLMBackend,
)
from data_types import BackendResponse, WorldModelResponse
import prompts
from world import World
class WorldModel:
    """Summarize the consequences of each day's actions using an LLM backend.

    The backend is chosen from ``model_name`` by substring matching in
    ``__init__``. The special name ``"mock"`` loads no backend. In that case
    ``summarize_consequences`` returns a fixed placeholder instead of calling a
    model.
    """

    def __init__(self, model_name: str, **kwargs) -> None:
        """Construct the backend that serves ``model_name``.

        Args:
            model_name: Model identifier. Checks run in this order:
                ``"mock"`` (exact match) -> no backend;
                contains ``"gpt-4-base"``, ``"text-"``, ``"davinci"`` or
                ``"turbo-instruct"`` -> ``OpenAICompletionBackend``;
                contains ``"claude"`` -> ``ClaudeCompletionBackend``;
                contains ``"llama"`` -> ``HuggingFaceCausalLMBackend``;
                anything else -> ``OpenAIChatBackend``.
            **kwargs: ``disable_completion_preface`` is optional and
                defaults to ``False``. For ``"llama"`` models,
                ``local_llm_path``, ``device``, ``quantization`` and
                ``fourbit_compute_dtype`` are required. Any other keys are
                ignored.

        Raises:
            KeyError: A keyword argument required by the ``"llama"`` branch is
                missing.
        """
        self.model_name = model_name
        # NOTE(review): this flag is not read anywhere in this class;
        # presumably other code consumes it. TODO confirm.
        self.use_completion_preface = not kwargs.pop(
            "disable_completion_preface", False
        )

        # Name fragments that identify OpenAI completion (non-chat) models.
        completion_markers = ("gpt-4-base", "text-", "davinci", "turbo-instruct")

        if model_name == "mock":
            # No model to load; summarize_consequences short-circuits on this name.
            self.backend = None
        elif any(marker in model_name for marker in completion_markers):
            self.backend = OpenAICompletionBackend(model_name)
        elif "claude" in model_name:
            self.backend = ClaudeCompletionBackend(model_name)
        elif "llama" in model_name:
            # Each pop has no default, so a missing key raises KeyError naming it.
            self.local_llm_path = kwargs.pop("local_llm_path")
            self.device = kwargs.pop("device")
            self.quantization = kwargs.pop("quantization")
            self.fourbit_compute_dtype = kwargs.pop("fourbit_compute_dtype")
            self.backend = HuggingFaceCausalLMBackend(
                model_name,
                self.local_llm_path,
                self.device,
                self.quantization,
                self.fourbit_compute_dtype,
            )
        else:
            # Chat models can't specify the start of the completion, so the
            # preface is forced off here regardless of the keyword argument.
            self.use_completion_preface = False
            self.backend = OpenAIChatBackend(model_name)

    def summarize_consequences(self, world: World) -> WorldModelResponse:
        """Ask the model what each action in ``world`` led to.

        The system and user prompts are built from ``world`` by the
        ``prompts`` module and echoed back in the response. For the
        ``"mock"`` model no backend is called. Instead, a placeholder string
        with zero timing and zero token counts is returned.

        Args:
            world: World state the prompts are built from.

        Returns:
            A ``WorldModelResponse`` holding the generated consequences text,
            completion time, token usage, and the two prompts that were used.
        """
        system_prompt = prompts.get_world_model_system_prompt(world)
        user_prompt = prompts.get_world_model_user_prompt(world)

        if self.model_name == "mock":
            result_fields = dict(
                consequences="TODO placeholder, will replace with llm later",
                completion_time_sec=0.0,
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
        else:
            # temperature=0.0 / top_p=1.0: presumably chosen for
            # (near-)deterministic summaries. Confirm against the backends.
            completion: BackendResponse = self.backend.complete(
                system_prompt,
                user_prompt,
                temperature=0.0,
                top_p=1.0,
            )
            result_fields = dict(
                consequences=completion.completion,
                completion_time_sec=completion.completion_time_sec,
                prompt_tokens=completion.prompt_tokens,
                completion_tokens=completion.completion_tokens,
                total_tokens=completion.total_tokens,
            )

        return WorldModelResponse(
            **result_fields,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )