In [175]:
import pickle

# Load the previously built category tree into `root`.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("./tree_v4.pkl", "rb") as f:
    root = pickle.load(f)

In [176]:
from model import Node, optimize_tree

In [177]:
from typing import Callable
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_community.callbacks import get_openai_callback
from typing import Callable, Optional, Tuple
from langchain_core.language_models.chat_models import BaseChatModel
from model.data import TokenCounts
from langchain_core.embeddings import Embeddings


def create_llm() -> ChatOllama:
    """Return a newly constructed ChatOllama client for the ``qwen2.5:32b`` model.

    Passed around as a zero-argument factory so each prompt helper builds its
    own model instance.
    """
    return ChatOllama(
        model="qwen2.5:32b",
        # temperature=0,
    )

from langchain_ollama import OllamaEmbeddings

# Shared embedding client; handed to improve_tree / resolve_duplicates, which
# forward it to optimize_tree.
embeddings = OllamaEmbeddings(
    model="mxbai-embed-large",
)

# Structured-output schema for prompt_improved_category_description: the model
# returns its reasoning plus the proposed category name.
# NOTE(review): documented with `#` comments rather than a class docstring on
# purpose -- pydantic is expected to expose a class docstring as the schema
# description, which would change what the LLM is shown. Confirm before adding one.
class ImprovedCategory(BaseModel):
    reasoning: str = Field(description="Reasoning on your choice.")
    category_name: str = Field(description="The improved category name which properly describes the elements within; in plain english.")

def prompt_improved_category_description(node: Node, create_llm: Callable[[], BaseChatModel]) -> Tuple[ImprovedCategory, TokenCounts]:
    """Ask the LLM for a better category name for ``node``.

    The prompt lists the ``condition`` of each direct child of ``node`` (one
    per line) and requests a structured ``ImprovedCategory`` answer.

    Args:
        node: Category node whose direct children are listed in the prompt.
        create_llm: Zero-argument factory returning the chat model to query.

    Returns:
        A tuple of the structured model response and the token usage captured
        by the callback while the request was running.
    """
    template = """
    With the following items within a category, create a new category name which closely describes its contents.
    Someone reading the new category name should immediately know if the item they are looking for is contained within the category.
    The name should be a reasonably short length so that it can be read quickly and entirely in english.
    The name should be specific to only the contents within so that it is not ambiguous with other categories.
    The name should describe ALL of the contents within.
    
    Items:
    {items}
    """.strip()

    prompt = ChatPromptTemplate.from_template(template)
    items = "\n".join(child.condition for child in node.children)
    # Use format_messages() so the model receives the human message as-is.
    # ChatPromptTemplate.format() renders the chat prompt to a transcript
    # string ("Human: ..."); wrapping that in a second HumanMessage leaked the
    # role prefix into the prompt text.
    messages = prompt.format_messages(items=items)

    # NOTE(review): get_openai_callback is an OpenAI-oriented usage tracker;
    # confirm it actually reports token counts for the model create_llm returns.
    with get_openai_callback() as cb:
        llm = create_llm().with_structured_output(ImprovedCategory)
        response = llm.invoke(messages)

    return response, TokenCounts(prompt=cb.prompt_tokens, completion=cb.completion_tokens, total=cb.total_tokens)


# Structured-output schema for prompt_duplicate_categories: the model returns
# its reasoning plus the 1-based list positions of the categories it considers
# duplicates of each other (positions follow the numbered list in the prompt).
# NOTE(review): documented with `#` comments rather than a class docstring on
# purpose -- pydantic is expected to expose a class docstring as the schema
# description, which would change what the LLM is shown. Confirm before adding one.
class DuplicateCategories(BaseModel):
    reasoning: str = Field(description='Reasoning on your choice of categories.')
    duplicate_categories: list[int] = Field(description='A single subset of categories that are duplicates-of or near-duplicates-of each other. Give the numbers of the categories chosen according to the list. Example: [1, 2] where 1 and 2 are "Beauty and Personal Care" and "Personal Care Products"')
    
    
def prompt_duplicate_categories(node: Node, create_llm: Callable[[], BaseChatModel]) -> Tuple[list[Node], TokenCounts]:
    """Ask the LLM which direct children of ``node`` are (near-)duplicate categories.

    The children's ``condition`` strings are presented as a 1-based numbered
    list; the model answers with the numbers of one duplicate subset.

    Args:
        node: Parent node whose direct children are compared.
        create_llm: Zero-argument factory returning the chat model to query.

    Returns:
        A tuple of the selected child nodes (each at most once, in the order
        the model listed them) and the token usage captured by the callback.
        Numbers that do not refer to an existing child are ignored.
    """
    template = """
    You will be given a list of categories that may contain duplicates or near-duplicates.
    Two or more categories should be considered duplicates or near-duplicates if and only if there would be significant overlap in the items that fit within them or if they are ambiguous with each other.
    Give a set of duplicates or near-duplicates if they are found. Only include a single subset of duplicates that overlap with each other.
    
    Categories:
    {cats}
    """.strip()

    prompt = ChatPromptTemplate.from_template(template)
    items = "\n ".join([f"{i+1}. {n.condition}" for i, n in enumerate(node.children)])
    # Use format_messages() so the model receives the human message as-is.
    # ChatPromptTemplate.format() renders the chat prompt to a transcript
    # string ("Human: ..."); wrapping that in a second HumanMessage leaked the
    # role prefix into the prompt text.
    messages = prompt.format_messages(cats=items)

    # NOTE(review): get_openai_callback is an OpenAI-oriented usage tracker;
    # confirm it actually reports token counts for the model create_llm returns.
    with get_openai_callback() as cb:
        llm = create_llm().with_structured_output(DuplicateCategories)
        response: DuplicateCategories = llm.invoke(messages)

    child_count = len(node.children)
    res = []
    seen_positions = set()
    for idx in response.duplicate_categories:
        # The numbers come from the model and cannot be trusted: 0 would
        # silently select the last child via index -1, anything above the
        # list length would raise IndexError, and a repeated number would
        # return the same node twice.
        if idx < 1 or idx > child_count or idx in seen_positions:
            continue
        seen_positions.add(idx)
        res.append(node.children[idx - 1])

    return res, TokenCounts(prompt=cb.prompt_tokens, completion=cb.completion_tokens, total=cb.total_tokens)

In [178]:
from ipywidgets import IntProgress, VBox, Label
from IPython.display import display
import time

def seconds_to_iso_format(seconds: float) -> str:
    """Format a duration as ``days:hours:minutes:seconds``.

    Despite the name, the result is not an ISO 8601 duration: it is four
    colon-separated integers without zero padding (e.g. ``"0:1:1:1"``).
    Fractional seconds are truncated.

    Args:
        seconds: Duration in seconds. Negative values (e.g. an ETA that has
            slightly overshot) are treated as zero instead of producing a
            negative day count with wrapped hour/minute/second fields.

    Returns:
        The formatted duration string.
    """
    whole_seconds = max(int(seconds), 0)
    days, remainder = divmod(whole_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{days}:{hours}:{minutes}:{secs}"

def improve_category_descriptions(root: Node, create_llm: Callable[[], BaseChatModel]) -> TokenCounts:
    """Rename every non-leaf node below ``root`` using the LLM.

    Nodes are visited depth-first with children renamed before their parent,
    so a parent's prompt sees the already-improved child names. ``root``
    itself is not renamed. A progress bar with a remaining-time estimate is
    displayed while the work runs.

    Args:
        root: Root of the tree; each of its children is processed as a subtree.
        create_llm: Zero-argument factory returning the chat model to query.

    Returns:
        The accumulated token usage of all naming prompts.
    """
    def _internal_node_count(node: Node) -> int:
        # Leaves contribute nothing; the root is traversed but not counted
        # because it is never renamed.
        if node.is_leaf():
            return 0
        own = 0 if node.is_root() else 1
        below = 0
        for child in node.children:
            below += _internal_node_count(child)
        return own + below

    total_nodes = _internal_node_count(root)

    progress_bar = IntProgress(min=0, max=total_nodes, description='Progress:', bar_style='info')
    label = Label(value=f"0/{total_nodes} nodes processed")
    time_label = Label(value="Estimating time...")
    display(VBox([progress_bar, label, time_label]))

    processed_nodes = 0
    start_time = time.time()

    def _mark_one_done() -> None:
        # Advance the counter and refresh all three widgets; the ETA is a
        # linear extrapolation from the average time per processed node.
        nonlocal processed_nodes
        processed_nodes += 1
        elapsed_time = time.time() - start_time
        progress_bar.value = processed_nodes
        label.value = f"{processed_nodes}/{total_nodes} nodes processed"
        projected_total = elapsed_time / processed_nodes * total_nodes
        remaining_time = projected_total - elapsed_time
        time_label.value = f"Estimated remaining time: {seconds_to_iso_format(remaining_time)}"

    def _rename_subtree(node: Node) -> TokenCounts:
        if node.is_leaf():
            return TokenCounts()

        # Post-order: descendants first, then this node.
        subtree_tokens = TokenCounts()
        for child in node.children:
            subtree_tokens += _rename_subtree(child)

        improved_desc, spent = prompt_improved_category_description(node=node, create_llm=create_llm)
        print(f"Created new category name: {node.condition} -> {improved_desc.category_name}")
        node.condition = improved_desc.category_name

        _mark_one_done()
        return subtree_tokens + spent

    tokens = TokenCounts()
    for top_level in root.children:
        tokens += _rename_subtree(top_level)

    return tokens

In [None]:
from typing import List

# Per-node child limit passed as `max_children` to fill_to_max_children and
# improve_tree in the cells below.
MAX_CHILDREN = 15

def fill_to_max_children(node: 'Node', max_children: int):
    """Flatten the tree in place by hoisting grandchildren up into ``node``.

    Subtrees are processed first (bottom-up). Then, for each non-leaf child of
    ``node``, the child is replaced by its own children whenever ``node`` is
    currently below ``max_children`` and the replacement would not push it
    above that limit. Hoisted nodes are attached via ``node.add_children`` and
    the dissolved child is left with an empty ``children`` list.

    Args:
        node: Subtree root to flatten; modified in place.
        max_children: Maximum number of direct children ``node`` may end up with
            as a result of hoisting.
    """
    # Flatten every subtree before looking at this level.
    for subtree in node.children:
        fill_to_max_children(subtree, max_children)

    pos = 0
    while pos < len(node.children):
        candidate = node.children[pos]

        # Nothing to hoist when this node is already full or the candidate
        # has no children of its own.
        if len(node.children) >= max_children or candidate.is_leaf():
            pos += 1
            continue

        # Size of node.children after swapping the candidate for its children.
        size_after_hoist = len(node.children) - 1 + len(candidate.children)
        if size_after_hoist > max_children:
            pos += 1
            continue

        # Dissolve the candidate. `pos` is deliberately not advanced: the
        # element that slides into this slot still has to be examined.
        del node.children[pos]
        node.add_children(candidate.children)
        candidate.children = []
        
def resolve_duplicates(node: Node, embeddings: Embeddings, create_llm: Callable[[], BaseChatModel], max_children: int) -> Tuple[bool, TokenCounts]:
    """Merge duplicate sibling categories below ``node``, deepest levels first.

    After recursing into every child, the LLM is asked which direct children
    of ``node`` are duplicates of each other. If at least two distinct
    non-leaf children are reported, they are replaced by a single new node
    holding all of their children; ``optimize_tree`` is then run on ``node``
    and the new node is named via ``prompt_improved_category_description``.

    Args:
        node: Subtree root to process; modified in place.
        embeddings: Embedding client forwarded to ``optimize_tree``.
        create_llm: Zero-argument factory returning the chat model to query.
        max_children: Child limit forwarded to ``optimize_tree``.

    Returns:
        A tuple ``(did_resolve_duplicates, tokens)``: whether any merge
        happened in this subtree, and the accumulated token usage.
    """
    tokens = TokenCounts()
    did_resolve_duplicates = False
    if node.is_leaf():
        return did_resolve_duplicates, tokens

    # With only leaves below there are no categories to compare.
    if all(n.is_leaf() for n in node.children):
        return did_resolve_duplicates, tokens

    for child in node.children:
        did_work, t = resolve_duplicates(child, embeddings=embeddings, create_llm=create_llm, max_children=max_children)
        did_resolve_duplicates = did_resolve_duplicates or did_work
        tokens += t

    duplicates, t = prompt_duplicate_categories(node, create_llm=create_llm)
    tokens += t

    # Keep only distinct non-leaf categories:
    #  * a leaf has no children to re-home, so "merging" it would simply drop
    #    it from the tree -- leave it where it is;
    #  * the same node listed twice would make the second remove() below raise.
    mergeable = []
    for candidate in duplicates:
        if candidate.is_leaf():
            continue
        if any(candidate is kept for kept in mergeable):
            continue
        mergeable.append(candidate)

    # A merge needs at least two categories. Wrapping a single node in a new
    # parent changes nothing structurally but used to report "work done",
    # which could keep improve_tree looping indefinitely.
    if len(mergeable) < 2:
        return did_resolve_duplicates, tokens

    did_resolve_duplicates = True
    print(f"Found duplicates: {[d.condition for d in mergeable]}")
    new_parent = Node(condition="TBD", parent=node)
    for old_parent in mergeable:
        node.children.remove(old_parent)
        new_parent.add_children(old_parent.children)
    node.add_children([new_parent])
    tokens += optimize_tree(root=node, max_children=max_children, embeddings=embeddings, create_llm=create_llm)
    # Name the merged category only after optimize_tree has run, so the name
    # reflects the children it ends up with.
    improved_category, t = prompt_improved_category_description(new_parent, create_llm=create_llm)
    tokens += t
    new_parent.condition = improved_category.category_name

    return did_resolve_duplicates, tokens



def improve_tree(root: Node, embeddings: Embeddings, create_llm: Callable[[], BaseChatModel], max_children: int) -> TokenCounts:
    """Repeatedly flatten, rename and de-duplicate the tree until it is stable.

    Each pass:
      1. fills out layers as far as ``max_children`` allows,
      2. decides new category names based on content, depth-first,
      3. identifies duplicate categories and merges them (2+ nodes are
         combined into a single node, followed by subtree splitting).
    The pass is repeated for as long as step 3 reports that it merged
    something. There is no cap on the number of passes; termination depends on
    the LLM eventually reporting no further duplicates.

    Args:
        root: Root of the tree; modified in place.
        embeddings: Embedding client forwarded to ``resolve_duplicates``.
        create_llm: Zero-argument factory returning the chat model to query.
        max_children: Per-node child limit used for flattening and forwarded
            to ``resolve_duplicates``.

    Returns:
        The token usage accumulated over all passes.
    """
    tokens = TokenCounts()
    # Iterate instead of recursing so a long run of passes cannot exhaust the
    # call stack.
    while True:
        # Honour the caller's limit; this previously used the module-level
        # MAX_CHILDREN constant and silently ignored the parameter.
        fill_to_max_children(root, max_children=max_children)
        tokens += improve_category_descriptions(root, create_llm=create_llm)
        did_work, t = resolve_duplicates(root, embeddings=embeddings, create_llm=create_llm, max_children=max_children)
        tokens += t
        if not did_work:
            return tokens

In [180]:
from model import display_lazy_tree

# Show the tree as loaded, before any flattening is applied.
display_lazy_tree(root, max_initial_depth=8)

VBox(children=(FigureWidget({
    'data': [{'branchvalues': 'total',
              'ids': [b8e633af-c5bb-45d9-…

In [181]:
# Flatten `root` in place; re-running this cell operates on the already-flattened tree.
fill_to_max_children(root, max_children=MAX_CHILDREN)

In [182]:
# Show the tree again after flattening, for comparison with the view above.
display_lazy_tree(root, max_initial_depth=8)

VBox(children=(FigureWidget({
    'data': [{'branchvalues': 'total',
              'ids': [44ece323-c5b8-4ead-…

In [183]:
from model import check_tree

# Print per-level branch and leaf statistics for the flattened tree (see output below).
check_tree(root)

sub_branches: 74, avg: 5.6923076923076925, max: 15

leaves at this level: 6
sub_branches: 100, avg: 1.3513513513513513, max: 15

leaves at this level: 66
sub_branches: 77, avg: 0.77, max: 15

leaves at this level: 92
sub_branches: 10, avg: 0.12987012987012986, max: 6

leaves at this level: 74
sub_branches: 0, avg: 0.0, max: 0

leaves at this level: 10
total leaves: 248


In [184]:
# Long-running: issues one LLM call per non-leaf node per pass and mutates `root` in place.
# The recorded run below was interrupted manually (KeyboardInterrupt) during its second pass.
improve_tree(root, embeddings=embeddings, create_llm=create_llm, max_children=MAX_CHILDREN)

VBox(children=(IntProgress(value=0, bar_style='info', description='Progress:', max=26), Label(value='0/26 node…

Created new category name: Baby & Kids Supplies -> Kids' Care & Toy Vehicles
Created new category name: Education & Research Supplies -> Children's Learning & Apparel
Created new category name: Home and Baby Essentials -> Baby & Small Pet Essentials
Created new category name: Consumer Goods -> Travel Gear & Accessories
Created new category name: Jewelry & Accessories -> Jewelry & Accessories
Created new category name: Fashions & Accessories -> Fashion & Apparel
Created new category name: Home Entertainment & Office Electronics -> Display & Audio Entertainment Systems
Created new category name: Electronics -> Consumer Tech & Entertainment
Created new category name: Car Electronics & Accessories -> Auto Parts & Accessories
Created new category name: Smart Home: New Smart Devices -> Smart Home Tech Solutions
Created new category name: Electronics & Gadgets -> Tech & Smart Solutions
Created new category name: DIY Projects & Craft Supplies -> Craft & Hobby Organization Supplies
Created new 

VBox(children=(IntProgress(value=0, bar_style='info', description='Progress:', max=87), Label(value='0/87 node…

Created new category name: Kids' Care & Toy Vehicles -> Baby & Kids Play & Care Essentials
Created new category name: Children's Learning & Apparel -> Children's Education & Play Essentials
Created new category name: Vehicle Parts & Accessories -> Vehicle Parts & Equipment
Created new category name: Industrial Supplies & Equipment -> Industrial Maintenance & Safety Supplies
Created new category name: Personal Care & Health Essentials -> Personal Care & Health Supplies
Created new category name: Personal Care & Health Supplies -> Personal Health & Grooming
Created new category name: Personal Health & Grooming -> Personal Care & Wellness Products
Created new category name: Health & Wellness Supplies -> Health & Personal Wellness
Created new category name: Pet & Animal Care Supplies -> Pet Care Supplies
Created new category name: Console Gaming Systems & Media -> Multi-Platform Game Consoles & Accessories
Created new category name: Multi-Platform Game Consoles & Accessories -> Console Gam

KeyboardInterrupt: 

In [None]:
# Persist the current state of `root` to a separate file so the input pickle is not overwritten.
with open("./tree_v4_improved.pkl", "wb") as f:
    pickle.dump(root, f)

In [None]:
from model import display_lazy_tree

# Show the tree in its current (post-improvement) state.
display_lazy_tree(root, max_initial_depth=5)

VBox(children=(FigureWidget({
    'data': [{'branchvalues': 'total',
              'ids': [69b2e114-cf2d-43f5-…

: 