diff --git a/3.Reranker - Q.Transformation - Res.Synthesis/main.py b/3.Reranker - Q.Transformation - Res.Synthesis/main.py index 19c9220..5fa4a48 100644 --- a/3.Reranker - Q.Transformation - Res.Synthesis/main.py +++ b/3.Reranker - Q.Transformation - Res.Synthesis/main.py @@ -66,6 +66,26 @@ async def start(): ).send() +async def set_sources(response, response_message): + label_list = [] + count = 1 + for sr in response.source_nodes: + elements = [ + cl.Text( + name="S" + str(count), + content=f"{sr.node.text}", + display="side", + size="small", + ) + ] + response_message.elements = elements + label_list.append("S" + str(count)) + await response_message.update() + count += 1 + response_message.content += "\n\nSources: " + ", ".join(label_list) + await response_message.update() + + @cl.on_message async def main(message: cl.Message): query_engine = cl.user_session.get("query_engine") @@ -93,21 +113,5 @@ async def main(message: cl.Message): message_history = message_history[-6:] cl.user_session.set("message_history", message_history) - label_list = [] - count = 1 - - for sr in response.source_nodes: - elements = [ - cl.Text( - name="S" + str(count), - content=f"{sr.node.text}", - display="side", - size="small", - ) - ] - response_message.elements = elements - label_list.append("S" + str(count)) - await response_message.update() - count += 1 - response_message.content += "\n\nSources: " + ", ".join(label_list) - await response_message.update() + if response.source_nodes: + await set_sources(response, response_message)