/
base.py
85 lines (67 loc) · 2.78 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Hypothetical Document Embeddings.
https://arxiv.org/abs/2212.10496
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import numpy as np
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra
from langchain.schema.embeddings import Embeddings
from langchain.schema.language_model import BaseLanguageModel
class HypotheticalDocumentEmbedder(Chain, Embeddings):
    """Generate hypothetical document for query, and then embed that.

    The class is both a ``Chain`` (it forwards calls to ``llm_chain``) and an
    ``Embeddings`` implementation (it delegates document embedding to
    ``base_embeddings``). For queries, the LLM chain first writes one or more
    hypothetical documents, which are embedded and averaged into one vector.

    Based on https://arxiv.org/abs/2212.10496
    """

    # Embedding model used to embed both real and hypothetical documents.
    base_embeddings: Embeddings
    # LLM chain that produces the hypothetical document(s) for a query.
    llm_chain: LLMChain

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys for Hyde's LLM chain."""
        return self.llm_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Output keys for Hyde's LLM chain."""
        return self.llm_chain.output_keys

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` by delegating to ``base_embeddings``.

        Args:
            texts: The documents to embed.

        Returns:
            Whatever ``base_embeddings.embed_documents`` returns for ``texts``.
        """
        return self.base_embeddings.embed_documents(texts)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine several embeddings into one by element-wise averaging.

        Args:
            embeddings: The embedding vectors to average.

        Returns:
            The mean over the first axis, as a list of native Python floats.

        Raises:
            ValueError: If ``embeddings`` is empty, since there is nothing to
                average.
        """
        # Fail with a clear message instead of NumPy's "Mean of empty slice"
        # warning followed by an unrelated-looking TypeError.
        if len(embeddings) == 0:
            raise ValueError("Cannot combine an empty list of embeddings.")
        # ``tolist()`` converts NumPy scalars to built-in floats, matching the
        # declared return type; ``list(ndarray)`` would keep NumPy scalars.
        return np.array(embeddings).mean(axis=0).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Generate a hypothetical document and embed it.

        The query is passed to ``llm_chain`` under the chain's first input
        key. The text of every generation returned for that single input is
        embedded with ``embed_documents``, and the resulting vectors are merged
        with ``combine_embeddings``.

        Args:
            text: The query text.

        Returns:
            The combined embedding of the generated hypothetical document(s).
        """
        var_name = self.llm_chain.input_keys[0]
        result = self.llm_chain.generate([{var_name: text}])
        # Only one input was submitted, so its generations are at index 0.
        documents = [generation.text for generation in result.generations[0]]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Call the internal llm chain.

        Args:
            inputs: Inputs forwarded unchanged to ``llm_chain``.
            run_manager: Callback manager for this run; a no-op manager is
                used when none is supplied.

        Returns:
            The output of ``llm_chain`` for ``inputs``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.llm_chain(inputs, callbacks=_run_manager.get_child())

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        base_embeddings: Embeddings,
        prompt_key: str,
        **kwargs: Any,
    ) -> HypotheticalDocumentEmbedder:
        """Load and use LLMChain for a specific prompt key.

        Args:
            llm: Language model used to build the internal ``LLMChain``.
            base_embeddings: Embedding model stored on the new instance.
            prompt_key: Key into ``PROMPT_MAP`` selecting the prompt.
            **kwargs: Extra keyword arguments passed to the class constructor.

        Returns:
            A new instance wired with the selected prompt.

        Raises:
            KeyError: If ``prompt_key`` is not present in ``PROMPT_MAP``.
        """
        prompt = PROMPT_MAP[prompt_key]
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "hyde_chain"