
Commit

Adjusted GPT4All llm to streaming API and added support for GPT4All_J (#4131)

Fix for these issues:
#4126

#3839 (comment)

---------

Co-authored-by: Pawel Faron <ext-pawel.faron@vaisala.com>
PawelFaron and Pawel Faron committed May 6, 2023
1 parent 075d963 commit 04b74d0
Showing 2 changed files with 55 additions and 18 deletions.
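
For orientation before the diffs: the patch moves the wrapper to pygpt4all's generator-style `generate()` and adds a `backend` field for loading GPT4All-J weights. A minimal usage sketch of the patched wrapper, assuming pygpt4all is installed; the model paths are illustrative placeholders, not files shipped with this repo:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All

# Tokens are streamed to stdout as they are generated.
callbacks = [StreamingStdOutCallbackHandler()]

# Default backend ("llama") for LLaMA-based GPT4All weights (placeholder path).
llm = GPT4All(model="./models/gpt4all-lora-quantized.bin", callbacks=callbacks, verbose=True)

# New in this commit: backend="gptj" selects pygpt4all's GPT4All_J loader (placeholder path).
llm_j = GPT4All(model="./models/ggml-gpt4all-j.bin", backend="gptj", callbacks=callbacks, verbose=True)

print(llm_j("Once upon a time, "))
```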
docs/modules/models/llms/integrations/gpt4all.ipynb (4 changes: 3 additions & 1 deletion)
@@ -125,7 +125,9 @@
     "# Callbacks support token-wise streaming\n",
     "callbacks = [StreamingStdOutCallbackHandler()]\n",
     "# Verbose is required to pass to the callback manager\n",
-    "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)"
+    "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
+    "# If you want to use GPT4ALL_J model add the backend parameter\n",
+    "llm = GPT4All(model=local_path, backend='gptj', callbacks=callbacks, verbose=True)"
    ]
   },
   {
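The notebook's surrounding cells (outside this hunk) drive the model through an LLMChain, so the streamed tokens appear while the chain runs. A hedged sketch of that flow, with a placeholder model path and an illustrative question:

```python
from langchain import LLMChain, PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All

local_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"  # placeholder path
llm = GPT4All(model=local_path, backend="gptj",
              callbacks=[StreamingStdOutCallbackHandler()], verbose=True)

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])

# The answer streams to stdout token by token while the chain runs.
llm_chain = LLMChain(prompt=prompt, llm=llm)
llm_chain.run("What NFL team won the Super Bowl in the year Justin Bieber was born?")
```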
langchain/llms/gpt4all.py (69 changes: 52 additions & 17 deletions)
@@ -28,6 +28,8 @@ class GPT4All(LLM):
     model: str
     """Path to the pre-trained GPT4All model file."""
 
+    backend: str = Field("llama", alias="backend")
+
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
 
@@ -93,21 +95,28 @@ class Config:
 
         extra = Extra.forbid
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
+    def _llama_default_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "seed": self.seed,
             "n_predict": self.n_predict,
             "n_threads": self.n_threads,
             "n_batch": self.n_batch,
             "repeat_last_n": self.repeat_last_n,
             "repeat_penalty": self.repeat_penalty,
             "top_k": self.top_k,
             "top_p": self.top_p,
             "temp": self.temp,
         }
 
+    def _gptj_default_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "n_predict": self.n_predict,
+            "n_threads": self.n_threads,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temp": self.temp,
+        }
+
     @staticmethod
     def _llama_param_names() -> Set[str]:
         """Get the identifying parameters."""
@@ -122,14 +131,41 @@ def _llama_param_names() -> Set[str]:
             "embedding",
         }
 
+    @staticmethod
+    def _gptj_param_names() -> Set[str]:
+        """Get the identifying parameters."""
+        return set()
+
+    @staticmethod
+    def _model_param_names(backend: str) -> Set[str]:
+        if backend == "llama":
+            return GPT4All._llama_param_names()
+        else:
+            return GPT4All._gptj_param_names()
+
+    def _default_params(self) -> Dict[str, Any]:
+        if self.backend == "llama":
+            return self._llama_default_params()
+        else:
+            return self._gptj_default_params()
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in the environment."""
         try:
-            from pygpt4all.models.gpt4all import GPT4All as GPT4AllModel
-
-            llama_keys = cls._llama_param_names()
-            model_kwargs = {k: v for k, v in values.items() if k in llama_keys}
+            backend = values["backend"]
+            if backend == "llama":
+                from pygpt4all import GPT4All as GPT4AllModel
+            elif backend == "gptj":
+                from pygpt4all import GPT4All_J as GPT4AllModel
+            else:
+                raise ValueError(f"Incorrect gpt4all backend {backend}")
+
+            model_kwargs = {
+                k: v
+                for k, v in values.items()
+                if k in GPT4All._model_param_names(backend)
+            }
             values["client"] = GPT4AllModel(
                 model_path=values["model"],
                 **model_kwargs,
@@ -147,11 +183,11 @@ def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            **self._default_params,
+            **self._default_params(),
             **{
                 k: v
                 for k, v in self.__dict__.items()
-                if k in GPT4All._llama_param_names()
+                if k in self._model_param_names(self.backend)
             },
         }
 
@@ -181,15 +217,14 @@ def _call(
                 prompt = "Once upon a time, "
                 response = model(prompt, n_predict=55)
         """
+        text_callback = None
         if run_manager:
             text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
-            text = self.client.generate(
-                prompt,
-                new_text_callback=text_callback,
-                **self._default_params,
-            )
-        else:
-            text = self.client.generate(prompt, **self._default_params)
+        text = ""
+        for token in self.client.generate(prompt, **self._default_params()):
+            if text_callback:
+                text_callback(token)
+            text += token
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text

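Because the rewritten `_call` pushes every token through `run_manager.on_llm_new_token` before appending it to the result, any callback handler can observe the stream. A sketch with a hypothetical `CapturingHandler` (not part of langchain; the model path is a placeholder):

```python
from typing import Any, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import GPT4All


class CapturingHandler(BaseCallbackHandler):
    """Hypothetical handler that records each streamed token."""

    def __init__(self) -> None:
        self.tokens: List[str] = []

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.tokens.append(token)


handler = CapturingHandler()
llm = GPT4All(model="./models/ggml-gpt4all-j.bin",  # placeholder path
              backend="gptj", callbacks=[handler])
text = llm("Once upon a time, ")
print("".join(handler.tokens))  # same text that _call returned
```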