Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu committed Aug 8, 2023
1 parent 9e1bce2 commit 859bf33
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions docs/examples/retrievers/ensemble_retrieval.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@
"source": [
"# try loading great gatsby\n",
"\n",
"documents = SimpleDirectoryReader(input_files=[\"../../../examples/gatsby/gatsby_full.txt\"]).load_data()"
"documents = SimpleDirectoryReader(\n",
" input_files=[\"../../../examples/gatsby/gatsby_full.txt\"]\n",
").load_data()"
]
},
{
Expand Down Expand Up @@ -133,23 +135,23 @@
"vector_indices = []\n",
"query_engines = []\n",
"for chunk_size in chunk_sizes:\n",
" print(f'Chunk Size: {chunk_size}')\n",
" print(f\"Chunk Size: {chunk_size}\")\n",
" service_context = ServiceContext.from_defaults(chunk_size=chunk_size, llm=llm)\n",
" service_contexts.append(service_context)\n",
" nodes = service_context.node_parser.get_nodes_from_documents(documents)\n",
" \n",
"\n",
" # add chunk size to nodes to track later\n",
" for node in nodes:\n",
" node.metadata[\"chunk_size\"] = chunk_size\n",
" node.excluded_embed_metadata_keys = [\"chunk_size\"]\n",
" node.excluded_llm_metadata_keys = [\"chunk_size\"]\n",
" \n",
"\n",
" nodes_list.append(nodes)\n",
" \n",
"\n",
" # build vector index\n",
" vector_index = VectorStoreIndex(nodes)\n",
" vector_indices.append(vector_index)\n",
" \n",
"\n",
" # query engines\n",
" query_engines.append(vector_index.as_query_engine())"
]
Expand All @@ -173,7 +175,7 @@
" retriever=vector_index.as_retriever(),\n",
" description=f\"Retrieves relevant context from the Great Gatsby (chunk size {chunk_size})\",\n",
" )\n",
" retriever_tools.append(retriever_tool)\n"
" retriever_tools.append(retriever_tool)"
]
},
{
Expand All @@ -185,15 +187,13 @@
},
"outputs": [],
"source": [
"from llama_index.selectors.pydantic_selectors import (\n",
" PydanticMultiSelector\n",
")\n",
"from llama_index.selectors.pydantic_selectors import PydanticMultiSelector\n",
"from llama_index.retrievers import RouterRetriever\n",
"\n",
"\n",
"retriever = RouterRetriever(\n",
" selector=PydanticMultiSelector.from_defaults(llm=llm, max_outputs=4),\n",
" retriever_tools=retriever_tools\n",
" retriever_tools=retriever_tools,\n",
")"
]
},
Expand Down Expand Up @@ -221,7 +221,9 @@
}
],
"source": [
"nodes = await retriever.aretrieve(\"Describe and summarize the interactions between Gatsby and Daisy\")"
"nodes = await retriever.aretrieve(\n",
" \"Describe and summarize the interactions between Gatsby and Daisy\"\n",
")"
]
},
{
Expand Down Expand Up @@ -560,7 +562,12 @@
"outputs": [],
"source": [
"# define reranker\n",
"from llama_index.indices.postprocessor import LLMRerank, SentenceTransformerRerank, CohereRerank\n",
"from llama_index.indices.postprocessor import (\n",
" LLMRerank,\n",
" SentenceTransformerRerank,\n",
" CohereRerank,\n",
")\n",
"\n",
"# reranker = LLMRerank()\n",
"# reranker = SentenceTransformerRerank(top_n=10)\n",
"reranker = CohereRerank(top_n=10)"
Expand All @@ -578,10 +585,7 @@
"# define RetrieverQueryEngine\n",
"from llama_index.query_engine import RetrieverQueryEngine\n",
"\n",
"query_engine = RetrieverQueryEngine(\n",
" retriever,\n",
" node_postprocessors=[reranker]\n",
")"
"query_engine = RetrieverQueryEngine(retriever, node_postprocessors=[reranker])"
]
},
{
Expand All @@ -604,7 +608,9 @@
}
],
"source": [
"response = query_engine.query(\"Describe and summarize the interactions between Gatsby and Daisy\")"
"response = query_engine.query(\n",
" \"Describe and summarize the interactions between Gatsby and Daisy\"\n",
")"
]
},
{
Expand Down Expand Up @@ -1003,7 +1009,9 @@
}
],
"source": [
"display_response(response, show_source=True, source_length=500, show_source_metadata=True)"
"display_response(\n",
" response, show_source=True, source_length=500, show_source_metadata=True\n",
")"
]
},
{
Expand All @@ -1019,6 +1027,7 @@
"from collections import defaultdict\n",
"import pandas as pd\n",
"\n",
"\n",
"def mrr_all(metadata_values, metadata_key, source_nodes):\n",
" # source nodes is a ranked list\n",
" # go through each value, find out positioning in source_nodes\n",
Expand All @@ -1027,18 +1036,17 @@
" mrr = 0\n",
" for idx, source_node in enumerate(source_nodes):\n",
" if source_node.node.metadata[metadata_key] == metadata_value:\n",
" mrr = 1 / (idx+1)\n",
" mrr = 1 / (idx + 1)\n",
" break\n",
" else:\n",
" continue\n",
" \n",
"\n",
" # normalize AP, set in dict\n",
" value_to_mrr_dict[metadata_value] = mrr\n",
" \n",
"\n",
" df = pd.DataFrame(value_to_mrr_dict, index=[\"MRR\"])\n",
" df.style.set_caption(\"Mean Reciprocal Rank\")\n",
" return df\n",
" "
" return df"
]
},
{
Expand Down Expand Up @@ -1108,7 +1116,7 @@
"source": [
"# Compute the Mean Reciprocal Rank for each chunk size (higher is better)\n",
"# we can see that chunk size of 256 has the highest ranked results.\n",
"print('Mean Reciprocal Rank for each Chunk Size')\n",
"print(\"Mean Reciprocal Rank for each Chunk Size\")\n",
"mrr_all(chunk_sizes, \"chunk_size\", response.source_nodes)"
]
},
Expand Down Expand Up @@ -1143,7 +1151,9 @@
},
"outputs": [],
"source": [
"response_1024 = query_engine_1024.query(\"Describe and summarize the interactions between Gatsby and Daisy\")"
"response_1024 = query_engine_1024.query(\n",
" \"Describe and summarize the interactions between Gatsby and Daisy\"\n",
")"
]
},
{
Expand Down

0 comments on commit 859bf33

Please sign in to comment.