In [36]:
import pickle

# Load the previously saved category tree.
# NOTE(review): pickle.load can execute arbitrary code while deserializing, so
# ./tree.pkl must come from a trusted source.
with open("./tree.pkl", "rb") as f:
    root = pickle.load(f)

In [37]:
from model import Node

In [None]:
from typing import Callable
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_community.callbacks import get_openai_callback
from typing import Callable, Optional, Tuple
from langchain_core.language_models.chat_models import BaseChatModel
from model.data import TokenCounts


def create_llm(model: str = "qwen2.5:32b") -> "ChatOllama":
    """Build a fresh ChatOllama chat model.

    Written as a named function instead of a lambda bound to a name, so it has
    a real ``__name__`` in tracebacks and can carry documentation. Calling it
    with no arguments behaves exactly like the previous zero-argument factory.

    Args:
        model: Ollama model tag to use. Defaults to the tag that was previously
            hard-coded here.

    Returns:
        A newly constructed ``ChatOllama`` instance.
    """
    return ChatOllama(
        model=model,
        # temperature=0,
    )

class ImprovedCategory(BaseModel):
    # Schema handed to `with_structured_output` for the renaming prompt.
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a model docstring into the generated schema description,
    # which would change what is sent to the LLM.
    category_name: str = Field(
        description=(
            "The improved category name which properly describes the elements within; in plain english."
        ),
    )

def prompt_improved_category_description(node: Node, create_llm: Callable[[], BaseChatModel]) -> Tuple[ImprovedCategory, TokenCounts]:
    """Ask the LLM for a better name for ``node`` based on its children.

    The prompt contains the node's current ``condition`` and a bullet list of
    each direct child's ``condition``.

    Args:
        node: Category node to rename; its ``children`` supply the item list.
        create_llm: Zero-argument factory returning the chat model to query.

    Returns:
        A tuple of the structured model response and the token usage recorded
        by the callback for this single call.
    """
    template = """
    With the following old category name and the items within it, create a new category name which more closely describes its contents.
    Someone reading the new category name should immediately know if the item they are looking for is contained within the category.
    The name should be a reasonably short length so that it can be read quickly and entirely in english.
    The name should be specific to only the contents within so that it is not ambiguous with other categories.
    
    Old Category: {old_category}
    
    Items:
    {items}
    """.strip()

    items = "* " + "\n* ".join([n.condition for n in node.children])

    # format_messages() yields the chat messages directly. The earlier
    # ChatPromptTemplate.format() call rendered the prompt to a transcript
    # string ("Human: ..."), which was then wrapped in another HumanMessage, so
    # the model received a stray "Human: " prefix in front of the instructions.
    messages = ChatPromptTemplate.from_template(template).format_messages(
        old_category=node.condition,
        items=items,
    )

    # The callback context accumulates token usage for calls made inside it.
    with get_openai_callback() as cb:
        llm = create_llm().with_structured_output(ImprovedCategory)
        response = llm.invoke(messages)

    return response, TokenCounts(prompt=cb.prompt_tokens, completion=cb.completion_tokens, total=cb.total_tokens)

In [39]:
from ipywidgets import IntProgress, VBox, Label
from IPython.display import display
import time

def seconds_to_iso_format(seconds: float) -> str:
    """Format a duration in seconds as ``D:HH:MM:SS``.

    Despite the name, the result is a colon-separated clock-style string, not
    an ISO 8601 duration. Hours, minutes and seconds are zero-padded to two
    digits so the fields are unambiguous (``0:00:05:03`` rather than
    ``0:0:5:3``).

    Args:
        seconds: Duration in seconds. Fractional seconds are truncated.
            Negative values are clamped to zero, because floor division on a
            negative number would otherwise produce a negative day count with
            wrapped-around time fields.

    Returns:
        The formatted duration string.
    """
    total_seconds = max(0, int(seconds))
    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{days}:{hours:02d}:{minutes:02d}:{secs:02d}"

def improve_category_descriptions(root: Node, create_llm: Callable[[], BaseChatModel]) -> TokenCounts:
    """Rename every non-leaf node under ``root`` via the LLM, showing progress.

    Nodes are visited depth-first with children handled before their parent,
    so a parent's prompt is built from its children's already-updated
    ``condition`` values. Each visited node's ``condition`` is overwritten in
    place.

    Args:
        root: Root of the category tree to update.
        create_llm: Zero-argument factory forwarded to
            ``prompt_improved_category_description``.

    Returns:
        The sum of the token counts reported for every LLM call made.
    """
    # Count the nodes that will trigger an LLM call (everything that is not a leaf).
    total_nodes = 0
    pending = [root]
    while pending:
        current = pending.pop()
        if current.is_leaf():
            continue
        total_nodes += 1
        pending.extend(current.children)

    # Progress widgets: a bar, an "n/total" counter and a remaining-time estimate.
    progress_bar = IntProgress(min=0, max=total_nodes, description='Progress:', bar_style='info')
    count_label = Label(value=f"0/{total_nodes} nodes processed")
    eta_label = Label(value="Estimating time...")
    display(VBox([progress_bar, count_label, eta_label]))

    started_at = time.time()
    done = 0

    def visit(node: Node) -> TokenCounts:
        """Process ``node``'s subtree bottom-up and return its token usage."""
        nonlocal done
        if node.is_leaf():
            return TokenCounts()

        # Children first, so their new names feed into this node's prompt.
        subtotal = TokenCounts()
        for child in node.children:
            subtotal += visit(child)

        renamed, usage = prompt_improved_category_description(node=node, create_llm=create_llm)
        print(f"Created new category name: {node.condition} -> {renamed.category_name}")
        node.condition = renamed.category_name

        done += 1
        elapsed = time.time() - started_at
        progress_bar.value = done
        count_label.value = f"{done}/{total_nodes} nodes processed"
        # Linear extrapolation: average time per node so far times the total,
        # minus the time already spent.
        remaining = elapsed / done * total_nodes - elapsed
        eta_label.value = f"Estimated remaining time: {seconds_to_iso_format(remaining)}"

        return subtotal + usage

    return visit(root)

In [40]:
# Makes one LLM call per non-leaf node and overwrites each node's `condition`
# on `root` in place; the value displayed below is the accumulated TokenCounts.
improve_category_descriptions(root, create_llm=create_llm)

VBox(children=(IntProgress(value=0, bar_style='info', description='Progress:', max=4359), Label(value='0/4359 …

Created new category name: Angles -> Metallic & Non-Metallic Angles
Created new category name: Beams -> Metallic & Composite Structural Beams
Created new category name: Channels -> Metallic & Composite Profile Channels
Created new category name: Foil -> Metallic & Plastic Foil Sheets
Created new category name: Metallic Plates -> Metal Sheets and Plates
Created new category name: Non-Metallic Plates -> Composite & Synthetic Plates
Created new category name: Plate -> Metallic & Synthetic Plates
Created new category name: Profiles -> Metallic & Non-Metallic Profiles
Created new category name: Rod -> Metallic & Non-Metallic Rods
Created new category name: Piling -> Construction Pilings
Created new category name: Post -> Structural Posts
Created new category name: Rails -> Track Rails
Created new category name: Grating -> Industrial Gratings
Created new category name: Honeycomb core -> Metallic & Composite Honeycomb Cores
Created new category name: Structural products -> Wood and Structural

TokenCounts(prompt=855752, completion=75469, total=931221)

In [41]:
# Save the tree under a new filename so the original ./tree.pkl is left untouched.
with open("./tree_improved.pkl", "wb") as f:
    pickle.dump(root, f)

In [42]:
from model import display_lazy_tree

# Render the tree for inspection.
# NOTE(review): presumably max_initial_depth limits how many levels are
# rendered up front — confirm against model.display_lazy_tree.
display_lazy_tree(root, max_initial_depth=3)

VBox(children=(FigureWidget({
    'data': [{'branchvalues': 'total',
              'ids': [61518b9d-27a8-4858-…