In [None]:
# Import packages
from configs import settings, const

import asyncio, os, time, yaml, json, datetime, copy
from typing import Any, AsyncGenerator, Generator, Callable, Literal, Optional, TypeAlias, Union
from tqdm import tqdm
from loguru import logger
from pprint import pprint
from rich import print as rprint

from toolkit.llm.llama_index import (
	agents, cores, deploys as dpls, evaluation, messages, models, 
	observability, types, utils as utils_llama_index, workflows as wfs
)
from toolkit.llm.llama_index.data import loading, querying, storing

from toolkit.utils import utils, typer as t
from toolkit.utils.llm import measure_performance, main as utils_llm


# LlamaIndex

In [None]:
# response = query_engine.query(query)
# pprint(response)

# start_time = time.time()
# if not query_engine._response_synthesizer._streaming:
# 		response = query_engine.query(query)
# 		total_time = time.time() - start_time

# else:
# 		response = query_engine.query(query)
		
# 		for i, token in enumerate(response.response_gen):
# 				if i == 0:
# 						first_token_time = time.time() - start_time
# 						print(f"First token time: {first_token_time:.2f} seconds", f"\n{'-'*80}")
				
# 				print(token, end="", flush=True)
		
# 		total_time = time.time() - start_time

# res = model.chat(
# 	[
# 		# ChatMessage(role="user", content="Hi there!"),
# 		# ChatMessage(role="assistant", content="Yes?"),
# 		# ChatMessage(role="user", content="What is the meaning of life?"),
# 		prompts.ChatMessage(role="user", content=prompt.format(user_question=query, retrieved_data=result_texts)),
# 	]
# )

# pprint(res.message.content)

## Structured predict: categorize a user query

In [None]:
prompts_agent_car = settings.prompts_agent_car

prompt_categorize_query = prompts_agent_car["CategorizeQuestion"]["dev"]

examples = prompts_agent_car["CategorizeQuestion"]["examples"]
examples = await utils_llm.convert_examples_to_string(examples)

# Intent labels a user query can be classified into.
# NOTE(review): these values (and their order) end up in the structured-output schema the
# LLM fills in below, so change them with care.
class UserQueryCategory(str, t.Enum):
	# Routed to the car-control flow in the Dev workflow ("car_control" mapping).
	CAR_CONTROL = "car_control"
	# Routed to the RAG flow in the Dev workflow ("car_manual" mapping).
	CAR_MANUAL = "car_manual"

# Output container passed as `output_cls` to the structured_predict call in this cell.
# NOTE(review): presumably a pydantic model via t.BaseModel — its schema is what the LLM returns.
class TypeUserQueryCategory(t.BaseModel):
	user_query_category: UserQueryCategory

user_query = "It's too cold"

# Ask the configured LLM to fill a TypeUserQueryCategory for this query, rendering the
# categorization prompt with the query and the few-shot examples.
user_query_category = cores.Settings.llm.structured_predict(
	output_cls=TypeUserQueryCategory,
	prompt=messages.PromptTemplate(prompt_categorize_query),
	user_query=user_query,
	examples=examples,
)

pprint(user_query_category)

# Dev: car-assistant query-routing workflow

In [None]:
prompts_agent_car = settings.prompts_agent_car

# ---- Workflow events --------------------------------------------------------
# Each event carries at most one payload field; the rest are pure signals whose
# type alone decides which workflow step runs next.

class EvUserQueryCategorized(wfs.Event):
	"""Emitted once the user query has been assigned a category."""
	user_query_category: t.UserQueryCategory

class EvFlowStartedRag(wfs.Event):
	"""Signal: handle the query with the RAG ("car_manual") flow."""

class EvFlowStartedControl(wfs.Event):
	"""Signal: handle the query with the car-control flow."""

class EvFlowCtrlTasksSeparated(wfs.Event):
	"""Signal: the control query has been split into individual tasks."""

class EvFlowCtrlTasksStarted(wfs.Event):
	"""Request to execute a single control task."""
	task: str

class EvFlowCtrlTasksCompleted(wfs.Event):
	"""Result of a single executed control task."""
	task_result: str

class EvFlowCompletedRag(wfs.Event):
	"""Signal: the RAG flow has stored its result."""

class EvFlowCompletedControl(wfs.Event):
	"""Signal: the control-task results have been aggregated and stored."""

class EvFlowCompleted(wfs.Event):
	"""Signal: the active flow has finished and been confirmed."""

class EvHumanFeedbackCompleted(wfs.Event):
	"""Carries the collected human feedback."""
	human_feedback: dict[str, Any]

class EvHumanSatisfied(wfs.Event):
	"""Signal: no retry requested; the workflow can stop."""

class MyWorkflow(wfs.Workflow):
	"""Route a car-assistant query to a RAG flow or a car-control flow.

	Step graph, as wired by the event types each step accepts and returns:

	    StartEvent -> categorize_user_query -> start_flow
	    start_flow -> run_flow_rag ("car_manual") -> complete_flow
	    start_flow -> run_flow_control ("car_control") -> run_control_tasks
	        -> process_control_task (one per task) -> process_control_tasks -> complete_flow
	    complete_flow -> human_feedback -> retry -> start_flow (retry) | stop

	Context keys written by the steps:
	    "user_query": the raw query string from the StartEvent.
	    "user_query_category": the value returned by ``apis_car.categorize_user_query``.
	    "flow_info": dict with "activated", "task" ({"name", "tasks", "n_tasks"}),
	        "completed", "result" and "confirmed" entries, filled in as the flow progresses.
	    "human_feedback": dict with "feedback" and "retry" entries.

	NOTE(review): ``apis_car`` and ``apis_rag`` are not imported in the visible cells; they
	must be defined earlier in the kernel for this workflow to run.
	"""

	@wfs.step()
	async def categorize_user_query(
		self, ctx: wfs.Context, ev: wfs.StartEvent,
	) -> EvUserQueryCategorized:
		"""Categorize the StartEvent's "user_query" and store query + category in the context."""
		user_query = ev.get("user_query", "")

		user_query_category = await apis_car.categorize_user_query(user_query)
		print(user_query_category)

		await ctx.set("user_query", user_query)
		await ctx.set("user_query_category", user_query_category)

		return EvUserQueryCategorized(user_query_category=user_query_category)

	@wfs.step()
	async def start_flow(
		self, ctx: wfs.Context, ev: EvUserQueryCategorized
	) -> EvFlowStartedRag | EvFlowStartedControl:
		"""Initialise "flow_info" and emit the start event of the flow matching the category.

		Raises:
			ValueError: if the stored category maps to no flow. Emitting no event would
				otherwise leave the workflow idle until its timeout.
		"""
		user_query_category: t.UserQueryCategory = await ctx.get("user_query_category")

		flow_mapping = {
			"car_manual": ("rag", EvFlowStartedRag),
			"car_control": ("control", EvFlowStartedControl),
		}

		if user_query_category not in flow_mapping:
			raise ValueError(
				f"Unsupported user query category {user_query_category!r}; "
				f"expected one of {sorted(flow_mapping)}"
			)

		task, event_class = flow_mapping[user_query_category]
		await ctx.set("flow_info", {"activated": True, "task": {"name": task}})
		return event_class()

	@wfs.step()
	async def run_flow_rag(
		self, ctx: wfs.Context, ev: EvFlowStartedRag,
	) -> EvFlowCompletedRag:
		"""Answer the stored query via ``apis_rag.do_querying`` and record it in "flow_info"."""
		flow_info = await ctx.get("flow_info")
		user_query = await ctx.get("user_query")

		result = await apis_rag.do_querying(user_query=user_query, mode="achat")

		if flow_info:
			flow_info["completed"] = True
			flow_info["result"] = result

		await ctx.set("flow_info", flow_info)

		return EvFlowCompletedRag()

	@wfs.step()
	async def run_flow_control(
		self, ctx: wfs.Context, ev: EvFlowStartedControl,
	) -> EvFlowCtrlTasksSeparated:
		"""Split the stored query into control tasks and register them in "flow_info"."""
		flow_info = await ctx.get("flow_info")
		user_query = await ctx.get("user_query")

		tasks = await apis_car.separate_tasks(user_query=user_query)

		if flow_info:
			# Keyed by task text, so duplicate tasks collapse into one entry. n_tasks must
			# count the unique keys: run_control_tasks dispatches one event per key and
			# process_control_tasks waits for exactly n_tasks completions, so counting
			# duplicates would make the fan-in wait forever.
			task_results = dict.fromkeys(tasks)
			flow_info["task"]["tasks"] = task_results
			flow_info["task"]["n_tasks"] = len(task_results)

		await ctx.set("flow_info", flow_info)

		return EvFlowCtrlTasksSeparated()

	@wfs.step(num_workers=5)
	async def run_control_tasks(
		self, ctx: wfs.Context, ev: EvFlowCtrlTasksSeparated,
	) -> EvFlowCtrlTasksStarted | None:
		"""Fan out: emit one EvFlowCtrlTasksStarted per registered control task.

		Raises:
			ValueError: if no tasks were extracted; no completion event would ever arrive.
		"""
		flow_info = await ctx.get("flow_info")

		tasks = list(flow_info["task"]["tasks"])
		if not tasks:
			raise ValueError("No control tasks were extracted from the user query")

		for task in tasks:
			ctx.send_event(EvFlowCtrlTasksStarted(task=task))

		# All events were delivered through send_event; this step emits nothing itself.
		return None

	@wfs.step(num_workers=5)
	async def process_control_task(
		self, ctx: wfs.Context, ev: EvFlowCtrlTasksStarted,
	) -> EvFlowCtrlTasksCompleted:
		"""Execute one control task via ``apis_car.do_controlling`` and emit its result."""
		flow_info = await ctx.get("flow_info")

		task = ev.task

		# Each task is sent to the controller as a standalone query.
		result = await apis_car.do_controlling(
			user_query=task,
			mode="achat",
		)

		if flow_info:
			# NOTE(review): not followed by ctx.set(); only visible to later steps if
			# ctx.get() returns the stored object rather than a copy — confirm. The
			# aggregate result is built from the completion events, not from this dict.
			flow_info["task"]["tasks"][task] = result

		return EvFlowCtrlTasksCompleted(task_result=result)

	@wfs.step()
	async def process_control_tasks(
		self, ctx: wfs.Context, ev: EvFlowCtrlTasksCompleted,
	) -> EvFlowCompletedControl | None:
		"""Fan in: once all n_tasks completions arrived, join their results into "flow_info"."""
		flow_info = await ctx.get("flow_info")
		n_tasks = flow_info["task"]["n_tasks"]

		# collect_events buffers events and returns None until n_tasks of them have
		# arrived; returning None simply waits for the next completion event.
		results = ctx.collect_events(ev, [EvFlowCtrlTasksCompleted] * n_tasks)
		if results is None:
			return None

		valid_results = [res.task_result for res in results if res and res.task_result]
		results_text = "\n".join(valid_results) if valid_results else "Some tasks failed to complete"

		if flow_info:
			flow_info["completed"] = True
			flow_info["result"] = results_text

		await ctx.set("flow_info", flow_info)

		return EvFlowCompletedControl()

	@wfs.step()
	async def complete_flow(
		self, ctx: wfs.Context, ev: EvFlowCompletedRag | EvFlowCompletedControl
	) -> EvFlowCompleted:
		"""Mark the finished flow (RAG or control) as confirmed and print "flow_info"."""
		flow_info = await ctx.get("flow_info")

		if flow_info:
			flow_info["confirmed"] = True
		pprint(flow_info)
		await ctx.set("flow_info", flow_info)

		return EvFlowCompleted()

	@wfs.step()
	async def human_feedback(
		self, ctx: wfs.Context, ev: EvFlowCompleted,
	) -> EvHumanFeedbackCompleted:
		"""Record human feedback on the result.

		Placeholder: always reports "OK!" with no retry requested.
		"""
		human_feedback = {
			"feedback": "OK!",
			"retry": False,
		}

		await ctx.set("human_feedback", human_feedback)
		return EvHumanFeedbackCompleted(human_feedback=human_feedback)

	@wfs.step()
	async def retry(
		self, ctx: wfs.Context, ev: EvHumanFeedbackCompleted
	) -> EvHumanSatisfied | EvUserQueryCategorized:
		"""Re-enter routing with the stored category if feedback asks for a retry, else finish."""
		human_feedback = await ctx.get("human_feedback")

		if human_feedback["retry"]:
			return EvUserQueryCategorized(
				user_query_category=await ctx.get("user_query_category")
			)
		return EvHumanSatisfied()

	@wfs.step()
	async def stop(
		self, ctx: wfs.Context, ev: EvHumanSatisfied,
	) -> wfs.StopEvent:
		"""Print the context contents and stop with the result stored in "flow_info"."""
		# NOTE(review): `ctx.data` looks like a backward-compatibility accessor in recent
		# LlamaIndex releases (ctx.get/ctx.set preferred) — confirm for the pinned version.
		utils.print_dict_as_table(ctx.data)

		flow_info: dict = await ctx.get("flow_info")
		result = flow_info["result"]

		return wfs.StopEvent(result=result)

async def main():
	workflow = MyWorkflow(timeout=60, verbose=True)

	# user_query = queries[2]
	# user_query = "It's too hot"
	# user_query = "Yes"
	user_query = "Activate the AC mode. Increase front wiper speed"
	# user_query = "Is the car locked? Is the car trunk opened?"
	result = await workflow.run(user_query=user_query)

wfs.draw_all_possible_flows(MyWorkflow)

asyncio.run(main())