In [1]:
import pickle
from model import Node, categorize_next, display_lazy_tree, format_node, format_time

In [2]:
import chardet

# Detect the CSV encoding from the first 8 MiB only (presumably to keep
# detection quick on a large file); `result` is consumed by the read_csv cell.
with open("amazon_products.csv", 'rb') as raw_csv:
    result = chardet.detect(raw_csv.read(8 * 1024 * 1024))
    print(result)

{'encoding': 'utf-8', 'confidence': 0.99, 'language': ''}


In [3]:
import pandas as pd
# chardet only inspected a prefix of the file, so an "ascii" verdict does not
# rule out non-ASCII bytes further down (which would raise UnicodeDecodeError
# mid-read). UTF-8 is a strict superset of ASCII, so widen that case; any other
# detected encoding is used as-is (None falls back to pandas' default, UTF-8).
df = pd.read_csv(
    "amazon_products.csv",
    encoding="utf-8" if (result["encoding"] or "").lower() == "ascii" else result["encoding"],
)

In [4]:
# Work on a fixed-seed subsample so a Restart & Run All sees the same rows.
# min() keeps this from raising when the source has fewer rows than requested.
# NOTE: this rebinds `df`; re-running only this cell re-samples the existing
# sample, so re-run the load cell first for a fresh draw.
SAMPLE_SIZE = 20_000
SAMPLE_SEED = 42
df = df.sample(n=min(SAMPLE_SIZE, len(df)), random_state=SAMPLE_SEED)
df.head()

Unnamed: 0,asin,title,imgUrl,productURL,stars,reviews,price,listPrice,category_id,isBestSeller,boughtInLastMonth
494595,B0BCLYGL58,Tree House Pad & Paper | 27” x 17” Packing Pap...,https://m.media-amazon.com/images/I/7150xDXI0K...,https://www.amazon.com/dp/B0BCLYGL58,4.9,0,21.97,0.0,160,False,0
379172,B07DNB11ZL,Long Wallets for Men Leather RFID Blocking Bif...,https://m.media-amazon.com/images/I/71mzFekYsj...,https://www.amazon.com/dp/B07DNB11ZL,4.4,0,32.99,0.0,112,False,50
659430,B07PFF5GZD,Mens Analogue-Digital Quartz Watch with Resin ...,https://m.media-amazon.com/images/I/71e6DwbvZ6...,https://www.amazon.com/dp/B07PFF5GZD,4.6,0,139.91,148.89,113,False,0
582438,B0BPMMDLDV,Godox XProII-S Wireless Flash Trigger for Sony...,https://m.media-amazon.com/images/I/71OaPVoSnX...,https://www.amazon.com/dp/B0BPMMDLDV,4.7,0,89.0,0.0,79,False,0
561386,B09TS1BK51,Fashion Angels Laptop Beauty Pallete - Make-up...,https://m.media-amazon.com/images/I/81BBOQj6BS...,https://www.amazon.com/dp/B09TS1BK51,4.6,0,29.99,0.0,270,False,0


In [5]:
# Load the pre-built category tree used by the cells below.
# pickle.load can execute arbitrary code: only open trusted tree files.
with open("./tree.pkl", mode="rb") as tree_file:
    root = pickle.load(tree_file)

In [6]:
from typing import Callable, Optional, Tuple
from langchain_core.language_models.chat_models import BaseChatModel

from model.data import TokenCounts

def classify(
    item_description: str,
    root: Node,
    create_llm: Callable[[], BaseChatModel],
    verbose: bool = True,
) -> Tuple[Optional[Node], TokenCounts]:
    """Classify an item by walking the category tree with LLM-guided backtracking.

    At each internal node, ``categorize_next`` is asked to pick one of the
    node's not-yet-rejected children. The chosen subtree is explored
    recursively. If it yields no leaf, that child is removed from the
    candidates ("masked") and the LLM is asked again. If the LLM picks no
    child at all, the node is abandoned and the caller backtracks.

    Args:
        item_description: Text describing the item to classify.
        root: Root node of the category tree.
        create_llm: Chat-model factory, forwarded to ``categorize_next``.
        verbose: Print a trace of the search (default True, matching the
            previous behavior). Pass False in batch runs to keep notebook
            output small.

    Returns:
        ``(leaf, tokens)``: the selected leaf node, or None if every path was
        rejected, and the summed token usage of all LLM calls in the search.
    """
    log = print if verbose else (lambda *_args, **_kwargs: None)

    def classify_recur(node: Node) -> Tuple[Optional[Node], TokenCounts]:
        if node.is_leaf():
            return node, TokenCounts()
        # Copy, so masking rejected children never mutates the tree itself.
        children = [*node.children]
        all_tokens = TokenCounts()
        log(f"Entering {node.condition}")
        while children:
            log(f"Trying {node.condition} with {len(children)} unexplored child nodes.")
            choice, tokens = categorize_next(item_description=item_description, nodes=children, create_llm=create_llm)
            # Single quotes inside the f-string: nested same-type quotes are a
            # SyntaxError before Python 3.12.
            log(f"Chose {choice.condition if choice is not None else 'None'}")
            all_tokens += tokens
            if choice is None:
                # LLM rejected every remaining child: let the parent mask this node.
                log("Retrying previous node with mask.")
                return None, all_tokens

            log("Trying chosen subtree.")
            result, tokens = classify_recur(choice)
            log(f"Subtree yielded choice {result.condition if result is not None else 'None'}.")
            all_tokens += tokens
            if result is not None:
                log(f"Yielding subtree result {result.condition}")
                return result, all_tokens

            # Chosen subtree dead-ended: mask it and ask again among the rest.
            log(f"Retrying {node.condition} with mask.")
            children.remove(choice)

        return None, all_tokens

    return classify_recur(root)

In [10]:
from langchain_ollama import ChatOllama
def create_llm():
    """Return a new ChatOllama client for the local ``qwen2.5:32b`` model.

    Used as the model factory by the classification helpers (for example,
    ``prompt_tentative_description`` calls it once per request). Temperature
    is left at the model default.
    NOTE(review): pass ``temperature=0`` if more repeatable outputs are wanted.
    """
    return ChatOllama(model="qwen2.5:32b")

In [11]:
# Render the loaded category tree as an interactive widget (the cell output is a
# plotly FigureWidget inside a VBox). Only the first three levels are requested
# up front; deeper levels presumably load on interaction — confirm in
# model.display_lazy_tree.
display_lazy_tree(root, max_initial_depth=3)

VBox(children=(FigureWidget({
    'data': [{'branchvalues': 'total',
              'ids': [2b7b01f4-e27f-4f6a-…

In [12]:
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_community.callbacks import get_openai_callback


# Structured-output schema for prompt_tentative_description: the LLM is asked,
# via with_structured_output, to fill in this single field.
# Documented with `#` comments on purpose: pydantic turns a class docstring into
# the JSON-schema "description", which would presumably be sent to the LLM along
# with the schema and thereby change the prompt.
class Description(BaseModel):
    # The shortened, generic product description. The Field description is part
    # of the schema shown to the model, so treat it as prompt text.
    description: str = Field(description="The shortened description")

def prompt_tentative_description(item_description: str, create_llm: Callable[[], BaseChatModel]) -> Tuple[Description, TokenCounts]:
    """Ask the LLM for a short, generic restatement of a raw product title.

    Args:
        item_description: Raw e-commerce item text (e.g. a product title).
        create_llm: Factory returning a fresh chat model; it is wrapped with
            ``with_structured_output(Description)``.

    Returns:
        ``(description, tokens)``: the structured model response and the token
        usage recorded by ``get_openai_callback`` during the call.
    """
    template = """
    From the given e-commerce item description give a simplified short, generic description that would be useful in categorizing an item into product standardizations.
    From specific product names and branding provide a description that includes the true nature of the product.
    Use the given tool to provide your answer.
    
    {item}
    """.strip()

    prompt_text = ChatPromptTemplate.from_template(template).format(item=item_description)

    # Build and invoke the model inside the callback so its usage is counted.
    with get_openai_callback() as usage:
        structured_llm = create_llm().with_structured_output(Description)
        response = structured_llm.invoke([HumanMessage(prompt_text)])

    token_counts = TokenCounts(
        prompt=usage.prompt_tokens,
        completion=usage.completion_tokens,
        total=usage.total_tokens,
    )
    return response, token_counts
    

In [13]:
from ipywidgets import IntProgress, VBox, Label
from IPython.display import display
import time

def create_classification_sample(
    df: pd.DataFrame,
    root: Node,
    llm_factory: Optional[Callable[[], BaseChatModel]] = None,
) -> Tuple[pd.DataFrame, TokenCounts]:
    """Classify every row of ``df`` against the category tree.

    For each row, the ``title`` is first shortened by
    ``prompt_tentative_description``. The shortened text is then routed
    through the tree with ``classify``. Progress and an ETA are shown in an
    ipywidgets panel.

    Args:
        df: Frame with a ``title`` column (the only column read).
        root: Root of the category tree to classify against.
        llm_factory: Optional chat-model factory. When omitted, the
            notebook-level ``create_llm`` is looked up at call time, which
            preserves the previous behavior.

    Returns:
        ``(results, tokens)``: a frame with columns ``item``, ``imputed_desc``
        and ``classification`` (the string ``"None"`` when no leaf was found),
        plus the summed token usage of all LLM calls.
    """
    from typing import Any  # `any` (the builtin function) is not a type

    factory = create_llm if llm_factory is None else llm_factory
    items: list[dict[str, Any]] = []
    all_tokens = TokenCounts()

    total_rows = len(df)
    progress_bar = IntProgress(min=0, max=total_rows, description='Progress:', bar_style='info')
    label = Label(value=f"0/{total_rows} items processed")
    time_label = Label(value="Estimating time...")
    display(VBox([progress_bar, label, time_label]))

    # Monotonic clock: elapsed/ETA are unaffected by wall-clock adjustments.
    start_time = time.monotonic()

    for processed_rows, title in enumerate(df["title"], start=1):
        short_desc, desc_tokens = prompt_tentative_description(item_description=title, create_llm=factory)
        node, tokens = classify(item_description=short_desc.description, root=root, create_llm=factory)
        items.append({
            "item": title,
            "imputed_desc": short_desc.description,
            "classification": format_node(node) if node is not None else "None",
        })
        all_tokens += tokens + desc_tokens

        elapsed_time = time.monotonic() - start_time
        progress_bar.value = processed_rows
        label.value = f"{processed_rows}/{total_rows} items processed"
        # Linear extrapolation from the average time per item so far.
        remaining_time = elapsed_time / processed_rows * (total_rows - processed_rows)
        time_label.value = f"Estimated remaining time: {format_time(remaining_time)}"

    # Explicit columns so an empty input still yields the expected schema.
    result = pd.DataFrame(items, columns=["item", "imputed_desc", "classification"])
    return result, all_tokens


In [14]:
import os
import time
from typing import Any

# Evaluate every tree on the same fixed-seed sample, so runs are reproducible
# and the per-tree results are directly comparable.
CLASSIFICATION_SAMPLE_SIZE = 100
CLASSIFICATION_SAMPLE_SEED = 42
df_sample = df.sample(CLASSIFICATION_SAMPLE_SIZE, random_state=CLASSIFICATION_SAMPLE_SEED)

# One entry per tree to evaluate: the pickled tree to load, and the CSV name
# (inside output_dir) for its per-item classifications.
args: list[dict[str, Any]] = [
    {
        "model_file": "tree.pkl",
        "output_file": "classifications_cleaned_tree.csv"
    },
    {
        "model_file": "tree_improved.pkl",
        "output_file": "classifications_improved_tree.csv"
    }
]

output_dir = "classification_results"
# NOTE(review): the token/timing summary is written to the working directory,
# not output_dir — confirm that is intended.
tokens_output_file = "tokens.csv"

os.makedirs(output_dir, exist_ok=True)

tokens_dicts = []
for argset in args:
    # Timing covers loading the tree as well as classifying the sample.
    start_time = time.perf_counter()
    # pickle.load can execute arbitrary code: only load trusted tree files.
    # NOTE: this rebinds the notebook-level `root`; after the loop it holds the
    # last tree listed in `args`.
    with open(argset["model_file"], "rb") as f:
        root = pickle.load(f)

    out_df, tokens = create_classification_sample(df_sample, root)
    elapsed_time = time.perf_counter() - start_time
    tokens_dicts.append({
        **argset,
        "prompt": tokens.prompt,
        "completion": tokens.completion,
        "total": tokens.total,
        "elapsed_seconds": elapsed_time
    })
    # Write into output_dir (created above), so the file lands where the
    # message below reports it.
    output_path = os.path.join(output_dir, argset["output_file"])
    out_df.to_csv(output_path)
    print(f"Saved classifications from {argset['model_file']} to {output_path}")

tokens_df = pd.DataFrame(tokens_dicts)
tokens_df.to_csv(tokens_output_file)

VBox(children=(IntProgress(value=0, bar_style='info', description='Progress:'), Label(value='0/100 items proce…

Entering None
Trying None with 18 unexplored child nodes.
Chose Industrial Machinery and Tools
Trying chosen subtree.
Entering Industrial Machinery and Tools
Trying Industrial Machinery and Tools with 3 unexplored child nodes.
Chose Tools and General Machinery
Trying chosen subtree.
Entering Tools and General Machinery
Trying Tools and General Machinery with 4 unexplored child nodes.
Chose Hand tools
Trying chosen subtree.
Entering Hand tools
Trying Hand tools with 19 unexplored child nodes.
Chose Measuring and layout tools
Trying chosen subtree.
Entering Measuring and layout tools
Trying Measuring and layout tools with 3 unexplored child nodes.
Chose None
Retrying previous node with mask.
Subtree yielded choice None.
Retrying Hand tools with mask.
Trying Hand tools with 18 unexplored child nodes.
Chose Tool attachments and accessories
Trying chosen subtree.
Entering Tool attachments and accessories
Trying Tool attachments and accessories with 16 unexplored child nodes.
Chose Crimping 

VBox(children=(IntProgress(value=0, bar_style='info', description='Progress:'), Label(value='0/100 items proce…

Entering BusinessAndIndustrialSolutions
Trying BusinessAndIndustrialSolutions with 18 unexplored child nodes.
Chose Oilfield Mining & Industrial Equipment
Trying chosen subtree.
Entering Oilfield Mining & Industrial Equipment
Trying Oilfield Mining & Industrial Equipment with 3 unexplored child nodes.
Chose Industrial Equipment and Tooling
Trying chosen subtree.
Entering Industrial Equipment and Tooling
Trying Industrial Equipment and Tooling with 4 unexplored child nodes.
Chose Metalwork_and_Construction_Tools
Trying chosen subtree.
Entering Metalwork_and_Construction_Tools
Trying Metalwork_and_Construction_Tools with 19 unexplored child nodes.
Chose None
Retrying previous node with mask.
Subtree yielded choice None.
Retrying Industrial Equipment and Tooling with mask.
Trying Industrial Equipment and Tooling with 3 unexplored child nodes.
Chose None
Retrying previous node with mask.
Subtree yielded choice None.
Retrying Oilfield Mining & Industrial Equipment with mask.
Trying Oilfield