Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
55064c2
ensure dict type
shahules786 Feb 9, 2024
1ae2319
Merge branch 'main' of github.com:explodinggradients/ragas
shahules786 Feb 9, 2024
8ec0b01
Merge branch 'main' of github.com:explodinggradients/ragas
shahules786 Feb 9, 2024
0c8e60c
Merge branch 'main' of github.com:explodinggradients/ragas
shahules786 Feb 12, 2024
6cca7bb
Merge branch 'main' of github.com:explodinggradients/ragas
shahules786 Feb 17, 2024
0a996de
add node to doc similarity
shahules786 Feb 19, 2024
4f40c8e
change chunk size
shahules786 Feb 19, 2024
b48af97
prev and next for node
shahules786 Feb 20, 2024
3204b82
noted todo
shahules786 Feb 20, 2024
9aa4c27
convert prompts
shahules786 Feb 20, 2024
d445c16
add retry question
shahules786 Feb 22, 2024
544dd39
add retry with adjacent ndoe
shahules786 Feb 22, 2024
5db0fb1
fixed filenames added a __eq__ for fast comparisions
jjmachan Feb 22, 2024
c5cc197
add question rewriting to all evolutions
shahules786 Feb 22, 2024
882792e
fix retry
shahules786 Feb 22, 2024
77262b4
updated a few tests
jjmachan Feb 23, 2024
6f0ceb3
added tests
jjmachan Feb 23, 2024
46cbb11
type fixes
shahules786 Feb 23, 2024
462dbb4
Merge branch 'main' of github.com:explodinggradients/ragas into dev#627
shahules786 Feb 23, 2024
72d51dd
use .get
shahules786 Feb 24, 2024
aa2f4f5
Merge branch 'tmp' into dev#627
jjmachan Feb 24, 2024
a8f4d1a
Merge branch 'dev#627' of github.com:shahules786/ragas into dev#627
shahules786 Feb 24, 2024
3e7dd3f
add next/prev
shahules786 Feb 24, 2024
377d2db
fix types
shahules786 Feb 24, 2024
43011cc
remove print
shahules786 Feb 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 89 additions & 54 deletions src/ragas/testset/docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,42 @@

class Document(LCDocument):
doc_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
filename: t.Optional[str] = None
embedding: t.Optional[t.List[float]] = Field(default=None, repr=False)

@classmethod
def from_langchain_document(cls, doc: LCDocument):
doc_id = str(uuid.uuid4())
if doc.metadata.get("filename"):
filename = doc.metadata["filename"]
@property
def filename(self):
filename = self.metadata.get("filename")
if filename is not None:
filename = self.metadata["filename"]
else:
logger.info(
"Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
"Document [ID: %s] has no filename, using `doc_id` instead", self.doc_id
)
filename = doc_id
filename = self.doc_id

return filename

@classmethod
def from_langchain_document(cls, doc: LCDocument):
doc_id = str(uuid.uuid4())
return cls(
page_content=doc.page_content,
metadata=doc.metadata,
doc_id=doc_id,
filename=filename,
)

@classmethod
def from_llamaindex_document(cls, doc: LlamaindexDocument):
doc_id = str(uuid.uuid4())
if doc.metadata.get("filename"):
filename = doc.metadata["filename"]
else:
logger.info(
"Document [ID: %s] has no filename. Using doc_id as filename.", doc_id
)
filename = doc_id
return cls(
page_content=doc.text,
metadata=doc.metadata,
doc_id=doc_id,
filename=filename,
)


class Node(Document):
keyphrases: t.List[str] = Field(default_factory=list, repr=False)
def __eq__(self, other) -> bool:
    """Compare two documents by their ``doc_id``.

    Two objects are considered equal when their ``doc_id`` values match,
    regardless of content. Objects without a ``doc_id`` attribute (for
    example ``None``) are not comparable: ``NotImplemented`` is returned so
    Python falls back to its default handling and the comparison evaluates
    to ``False`` instead of raising ``AttributeError``.
    """
    if not hasattr(other, "doc_id"):
        return NotImplemented
    # if the doc_id's are same then the Document objects are same
    return self.doc_id == other.doc_id


class Direction(str, Enum):
Expand All @@ -84,6 +80,21 @@ class Direction(str, Enum):
DOWN = "down"


class Node(Document):
    """A document chunk that also tracks its neighbours and sampling state."""

    # Key phrases associated with this chunk.
    keyphrases: t.List[str] = Field(default_factory=list, repr=False)
    # Neighbouring nodes keyed by Direction; a direction may be absent or
    # mapped to None.
    relationships: t.Dict[Direction, t.Any] = Field(default_factory=dict, repr=False)
    # Similarity between this node and its parent document; None until the
    # store computes it.
    doc_similarity: t.Optional[float] = Field(default=None, repr=False)
    # Number of times this node has been returned by random sampling.
    wins: int = 0

    @property
    def next(self):
        """Return the node registered under ``Direction.NEXT``, or ``None``."""
        following = self.relationships.get(Direction.NEXT)
        return following

    @property
    def prev(self):
        """Return the node registered under ``Direction.PREV``, or ``None``."""
        preceding = self.relationships.get(Direction.PREV)
        return preceding


class DocumentStore(ABC):
def __init__(self):
self.documents = {}
Expand All @@ -110,12 +121,6 @@ def get_similar(
) -> t.Union[t.List[Document], t.List[Node]]:
...

@abstractmethod
def get_adjacent(
    self, node: Node, direction: Direction = Direction.NEXT
) -> t.Optional[Node]:
    """Return the node adjacent to ``node`` in ``direction``, or ``None``.

    Abstract: concrete stores must provide the implementation.
    """

def set_run_config(self, run_config: RunConfig):
    """Accept a run configuration; this implementation does nothing."""
    return None

Expand Down Expand Up @@ -206,7 +211,6 @@ def add_documents(self, docs: t.Sequence[Document], show_progress=True):
Node.from_langchain_document(d)
for d in self.splitter.transform_documents(docs)
]

self.add_nodes(nodes, show_progress=show_progress)

def add_nodes(
Expand Down Expand Up @@ -264,14 +268,70 @@ def add_nodes(
), "Embedding must be list or np.ndarray"
self.node_embeddings_list.append(n.embedding)

self.calculate_nodes_docs_similarity()
self.set_node_relataionships()

def set_node_relataionships(self):
    """Link consecutive nodes of the same file through their relationships.

    Walks ``self.nodes`` in order. For each pair of neighbours that share a
    ``filename``, the earlier node's ``NEXT`` and the later node's ``PREV``
    are pointed at each other; when the filenames differ, both entries are
    set to ``None``. The final node always gets ``NEXT = None``.
    """
    for earlier, later in zip(self.nodes, self.nodes[1:]):
        same_file = earlier.filename == later.filename
        # PREV is written before NEXT to mirror the order entries are
        # inserted into each node's relationships dict.
        later.relationships[Direction.PREV] = earlier if same_file else None
        earlier.relationships[Direction.NEXT] = later if same_file else None

    if self.nodes:
        # Nothing follows the last node in the store.
        self.nodes[-1].relationships[Direction.NEXT] = None

def calculate_nodes_docs_similarity(self):
    """Set ``doc_similarity`` on every node in ``self.nodes``.

    A document embedding is the mean of the embeddings of all nodes that
    share a ``filename``. Each node's ``doc_similarity`` is the value the
    module's ``similarity`` helper returns for the node embedding and its
    document embedding.

    When the number of distinct filenames equals the number of distinct
    ``doc_id`` values, a warning is logged and every node is assigned a
    similarity of ``1.0``.

    Raises:
        AssertionError: if a node has no embedding when its similarity is
            being computed.
    """
    # Resolve each filename once and reuse it below, instead of
    # re-evaluating the attribute for every node on every pass.
    filenames = [node.filename for node in self.nodes]
    filename_ids = {name for name in filenames if name is not None}
    node_ids = {node.doc_id for node in self.nodes}

    if len(filename_ids) == len(node_ids):
        logger.warning("Filename and doc_id are the same for all nodes.")
        for node in self.nodes:
            node.doc_similarity = 1.0
        return

    # Group embeddings per file in a single pass (node order is preserved)
    # rather than rescanning the full node list once per file.
    embeddings_by_file = {}
    for node, name in zip(self.nodes, filenames):
        if name is not None:
            embeddings_by_file.setdefault(name, []).append(node.embedding)

    doc_embeddings = {}
    for file_id, embeddings in embeddings_by_file.items():
        nodes_embedding = np.array(embeddings)
        # Flatten each embedding to one row so the mean is taken per dimension.
        nodes_embedding = nodes_embedding.reshape(len(nodes_embedding), -1)
        doc_embeddings[file_id] = np.mean(nodes_embedding, axis=0)

    for node, name in zip(self.nodes, filenames):
        assert node.embedding is not None, "Embedding cannot be None"
        node.doc_similarity = similarity(node.embedding, doc_embeddings[name])

def get_node(self, node_id: str) -> Node:
    """Look up a node in ``self.node_map`` by id; raises ``KeyError`` if absent."""
    found = self.node_map[node_id]
    return found

def get_document(self, doc_id: str) -> Node:
    """Document lookup is not implemented; always raises ``NotImplementedError``."""
    raise NotImplementedError()

def get_random_nodes(self, k=1) -> t.List[Node]:
    """Return ``k`` nodes drawn from ``self.nodes`` via ``rng.choice``.

    NOTE(review): ``rng`` is defined elsewhere in the module; presumably a
    NumPy random generator -- confirm.
    """
    pool = np.array(self.nodes)
    drawn = rng.choice(pool, size=k)
    return drawn.tolist()
def get_random_nodes(self, k=1, alpha=0.1) -> t.List[Node]:
    """Sample ``k`` nodes, weighted by document similarity and past picks.

    Each node's weight is ``exp(-alpha * wins) * doc_similarity``. The
    weights are normalised into probabilities and handed to ``rng.choice``.
    Every returned node has its ``wins`` counter incremented, so nodes that
    were already picked become less likely on later calls.

    Weights that are missing (``doc_similarity`` is ``None``), non-finite
    or not positive are treated as zero so they cannot produce invalid
    probabilities. If no usable weight remains, sampling is uniform.

    Args:
        k: number of nodes to draw.
        alpha: decay rate applied to a node's ``wins`` count; larger values
            penalise previously selected nodes more strongly.

    Returns:
        The sampled nodes; the same node may appear more than once.
    """
    wins = np.array([node.wins for node in self.nodes], dtype=float)
    # ``None`` similarities become NaN here and are zeroed out just below.
    doc_similarity = np.array(
        [node.doc_similarity for node in self.nodes], dtype=float
    )

    weights = np.exp(-alpha * wins) * doc_similarity
    weights = np.where(np.isfinite(weights) & (weights > 0), weights, 0.0)
    total = weights.sum()
    # ``p=None`` makes the draw uniform when there is nothing to weight by.
    prob = weights / total if total > 0 else None

    # Sample positions rather than objects: this avoids wrapping the nodes
    # in an array and the linear ``index`` lookup needed to find them again.
    picked = rng.choice(len(self.nodes), size=k, p=prob)
    nodes = [self.nodes[i] for i in picked.tolist()]

    for node in nodes:
        node.wins += 1

    return nodes

def get_similar(
self, node: Node, threshold: float = 0.7, top_k: int = 3
Expand All @@ -293,31 +353,6 @@ def get_similar(
items = [self.nodes[doc_id] for doc_id in doc_ids]
return items

def get_adjacent(
    self, node: Node, direction: Direction = Direction.NEXT
) -> t.Optional[Node]:
    """Return the neighbour of ``node`` in ``self.nodes`` for ``direction``.

    Only ``Direction.NEXT`` and ``Direction.PREV`` are handled; any other
    direction yields ``None``. A neighbour is returned only when it exists
    and shares ``node``'s ``filename``; otherwise the result is ``None``.

    Raises:
        ValueError: if ``node`` is not present in ``self.nodes``.
    """
    # linear search for the position of the node in the node list
    position = self.nodes.index(node)

    if direction == Direction.NEXT:
        neighbour_at = position + 1
        exists = neighbour_at < len(self.nodes)
    elif direction == Direction.PREV:
        neighbour_at = position - 1
        exists = position > 0
    else:
        return None

    if not exists:
        return None

    candidate = self.nodes[neighbour_at]
    if candidate.filename == node.filename:
        return candidate
    return None

def set_run_config(self, run_config: RunConfig):
if self.embeddings:
self.embeddings.set_run_config(run_config)
Loading