# Tree Creation From Classification Labels

## Method

Trees are created from the classification labels only. The data to be classified is not used, so labeled datasets are not required for this step. Labeled datasets are useful for evaluation instead.

'Categories', 'Conditions', and 'Labels' are used interchangeably here. Categories are created from the labels' texts and new categories will be made in the same style as the original dataset.

### Algorithms

#### Tree Formation

* Create the tree from any hierarchical structure present in the output labels, if it exists. Otherwise, create a tree with a single root node and one leaf for every classification label.
* Walk down the tree starting at the root and stop when the current node has more children than desired. Then:
    * Place the category texts of the child nodes in a vector store and generate embeddings for each of them.
    * Sample from the vector store to choose well-spread categories. To do this, run KMeans clustering on the embeddings and pick a representative from each cluster.
    * Provide the LLM with representative categories and ask it to create new categories that divide those further.
    * Classify the old child categories into the newly created categories.
    * Place old child categories that cannot be classified into the new categories as children of the current node.
    * Repeat this process until the current node has no more than the desired number of children.
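The sampling step in the list above can be sketched with a small k-means implementation that has no dependencies. This is a hypothetical illustration and not the `model` module's code. The function name `kmeans_representatives`, the deterministic farthest-point initialisation, and the choice of the member nearest each centroid as its cluster's representative are all assumptions. In practice the vectors would be the embeddings of the child categories' texts, and a library KMeans would be the natural choice.

```python
from typing import List, Sequence


def _dist2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _mean(vectors: List[Sequence[float]]) -> List[float]:
    """Component-wise mean of a non-empty list of vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def kmeans_representatives(texts: List[str], vectors: List[List[float]],
                           k: int, iterations: int = 20) -> List[str]:
    """Cluster vectors with k-means and return one text per non-empty cluster:
    the member closest to that cluster's centroid (hypothetical helper)."""
    k = min(k, len(vectors))
    if k <= 0:
        return []
    # Deterministic farthest-point initialisation: start with the first vector,
    # then repeatedly add the vector farthest from every centroid chosen so far.
    centroids = [list(vectors[0])]
    while len(centroids) < k:
        centroids.append(list(max(
            vectors, key=lambda v: min(_dist2(v, c) for c in centroids))))
    for _ in range(max(1, iterations)):
        # Assignment step: each vector joins its nearest centroid.
        clusters: List[List[int]] = [[] for _ in centroids]
        for i, v in enumerate(vectors):
            nearest = min(range(len(centroids)),
                          key=lambda j: _dist2(v, centroids[j]))
            clusters[nearest].append(i)
        # Update step: move each centroid to the mean of its members.
        new_centroids = [_mean([vectors[i] for i in members]) if members else centroids[j]
                         for j, members in enumerate(clusters)]
        if new_centroids == centroids:
            break  # converged
        centroids = new_centroids
    # One representative per cluster: the actual member nearest its centroid.
    return [texts[min(members, key=lambda i: _dist2(vectors[i], centroids[j]))]
            for j, members in enumerate(clusters) if members]


texts = ["Cats", "Dogs", "Horses", "Drills", "Saws", "Hammers"]
# Toy 2-D vectors standing in for embeddings: animals near (0, 1), tools near (5, 0)
vectors = [[0.0, 1.0], [0.1, 0.9], [0.2, 1.0], [5.0, 0.0], [5.1, 0.2], [4.8, 0.1]]
print(kmeans_representatives(texts, vectors, k=2))  # ['Dogs', 'Drills']
```

Only these representatives would be shown to the LLM. That keeps the prompt short and still covers the range of child categories.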

#### Classification

* Start at the root node and work down the tree.
* Prompt the LLM to choose the correct next category from among the children of the current node if a valid one exists.
* If the LLM responds that none of the child categories are suitable for the item being classified, mask (hide) the current node and continue from the parent of the current node.
* If the failure above occurs while the current node is the root, the item will be skipped and receive no classification.
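The classification walk can be sketched as follows, including the masking step. `ToyNode`, `classify`, and `scripted_chooser` are hypothetical names. `ToyNode` is a minimal stand-in for the notebook's `Node`. `choose(item, options)` stands in for the LLM prompt and returns one of the offered children, or `None` when no child fits. Masked nodes stay hidden for the rest of that item's walk. This duration is an assumption, but it guarantees termination because each backtrack hides one more node.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass(eq=False)  # eq=False: nodes hash and compare by identity
class ToyNode:
    condition: str
    children: List["ToyNode"] = field(default_factory=list)
    parent: Optional["ToyNode"] = field(default=None, repr=False)

    def add(self, condition: str) -> "ToyNode":
        child = ToyNode(condition, parent=self)
        self.children.append(child)
        return child


Chooser = Callable[[str, List[ToyNode]], Optional[ToyNode]]


def classify(item: str, root: ToyNode, choose: Chooser) -> Optional[ToyNode]:
    """Walk from the root to a leaf, backing up past dead ends."""
    masked = set()  # nodes hidden after all their children were rejected
    node = root
    while node.children:
        options = [c for c in node.children if c not in masked]
        picked = choose(item, options) if options else None
        if picked is not None:
            node = picked
        elif node is root:
            return None  # dead end at the root: the item gets no label
        else:
            masked.add(node)  # hide this node and retry from its parent
            node = node.parent
    return node


def scripted_chooser(prefs):
    """Stub for the LLM: prefs maps a parent's condition to the child
    conditions it would accept, in order of preference."""
    def choose(item, options):
        parent = options[0].parent.condition
        for wanted in prefs.get(parent, []):
            for option in options:
                if option.condition == wanted:
                    return option
        return None
    return choose


toy_root = ToyNode("root")
toy_root.add("Animals").add("Cats")
toy_root.add("Tools").add("Drills")

# The stub first (wrongly) picks "Animals", finds no fitting child there,
# masks "Animals", backs up to the root, then takes "Tools" -> "Drills".
chooser = scripted_chooser({"root": ["Animals", "Tools"], "Tools": ["Drills"]})
print(classify("cordless drill", toy_root, chooser).condition)  # Drills
```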

## Code

### Reading The Dataset

The dataset used here is [UNSPSC Codes](https://data.ok.gov/dataset/unspsc-codes).

Detect the correct charset to use when reading the categories dataset. Then read the dataset as a DataFrame.

In [1]:
import chardet

with open("data-unspsc-codes.csv", 'rb') as f:
    result = chardet.detect(f.read())
    print(result)

{'encoding': 'Windows-1252', 'confidence': 0.73, 'language': ''}


In [2]:
import pandas as pd

df = pd.read_csv("data-unspsc-codes.csv", encoding=result['encoding'])

df.head()

| Unnamed: 0 | Segment | Segment Name | Family | Family Name | Class | Class Name | Commodity | Commodity Name |
|---|---|---|---|---|---|---|---|---|
| 0 | 10000000 | Live Plant and Animal Material and Accessories... | 10100000 | Live animals | 10101500 | Livestock | 10101501 | Cats |
| 1 | 10000000 | Live Plant and Animal Material and Accessories... | 10100000 | Live animals | 10101500 | Livestock | 10101502 | Dogs |
| 2 | 10000000 | Live Plant and Animal Material and Accessories... | 10100000 | Live animals | 10101500 | Livestock | 10101504 | Mink |
| 3 | 10000000 | Live Plant and Animal Material and Accessories... | 10100000 | Live animals | 10101500 | Livestock | 10101505 | Rats |
| 4 | 10000000 | Live Plant and Animal Material and Accessories... | 10100000 | Live animals | 10101500 | Livestock | 10101506 | Horses |


### Dataset Exploration And Evaluation

Test the existing hierarchical structure to determine whether it can be used directly for classification.

In [3]:
cat_cols = ["Segment Name", "Family Name", "Class Name", "Commodity Name"]

In [4]:
from itertools import combinations

for i, col in enumerate(cat_cols):
    unique = pd.unique(df[col])
    n_unique = len(unique)
    total = len(df[col])
    print(f"col: {col}\nunqiue: {n_unique}\ntotal: {total}\n")
    
    if i < len(cat_cols) - 1:
        next_branch_counts = []
        next_branches: dict[str, set] = {}
        
        for uc in unique:
            df_next = df[df[col] == uc]
            next_branch_counts.append(len(pd.unique(df_next[cat_cols[i+1]])))
            next_branches[uc] = set(pd.unique(df_next[cat_cols[i+1]]))
    
        print(f"next level:\nmax:{max(next_branch_counts)}\navg:{sum(next_branch_counts)/len(next_branch_counts)}\n")
        
        ambiguous = False
        for a, b in combinations(next_branches.keys(), 2):
            intersection = next_branches[a] & next_branches[b]
            if len(intersection) > 0:
                ambiguous = True
                print(f"The following nodes appear in both {a} and {b} for {cat_cols[i+1]}: {intersection}")
            
        if ambiguous:    
            print(f"{col} has ambiguous branches\n")
        else:
            print(f"{col} does NOT have ambiguous branches\n")
            

col: Segment Name
unqiue: 57
total: 71502

next level:
max:43
avg:8.157894736842104

Segment Name does NOT have ambiguous branches

col: Family Name
unqiue: 465
total: 71502

next level:
max:67
avg:11.425806451612903

Family Name does NOT have ambiguous branches

col: Class Name
unqiue: 5313
total: 71502

next level:
max:99
avg:13.45793337097685

Class Name does NOT have ambiguous branches

col: Commodity Name
unqiue: 71502
total: 71502



The results above show that, with the 4 existing hierarchical levels, a node can have up to 99 children. This exceeds our objective of at most 25 children per node.

We will need to divide up these children to make classification using an LLM more reliable.
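A quick calculation suggests the 25-child cap is a branching problem and not a problem of total size. Assume a perfectly balanced tree, which the LLM-built tree will not be. Under that assumption, 71,502 leaves fit in 4 levels below the root (25^3 = 15,625 is too small, 25^4 = 390,625 suffices). That is the same depth as the existing Segment/Family/Class/Commodity hierarchy. The trouble is the uneven fan-out. A node with 99 children needs them regrouped into at least ⌈99/25⌉ = 4 new categories. Both helper names below are hypothetical.

```python
import math


def min_levels(n_leaves: int, max_children: int) -> int:
    """Fewest levels below the root that can hold n_leaves leaves when no node
    has more than max_children children (lower bound: perfectly balanced tree)."""
    levels, capacity = 0, 1
    while capacity < n_leaves:
        capacity *= max_children
        levels += 1
    return levels


def min_new_groups(n_children: int, max_children: int) -> int:
    """Fewest new intermediate categories needed so that each one
    holds at most max_children of the original children."""
    return math.ceil(n_children / max_children)


print(min_levels(71502, 25))   # 4
print(min_new_groups(99, 25))  # 4
```

These are lower bounds only. In the method above, old children that fit none of the new categories stay under the current node, so real trees end up deeper.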

In [None]:
from model import create_tree_from_breadcrumbs, check_tree, Node, display_tree, create_vector_store, ask_model_category, format_node, optimize_tree, ProgressBars, display_lazy_tree, clean_tree

Create the tree from the breadcrumbs present in the dataset so the existing hierarchy is retained as a head start.

In [6]:
root = create_tree_from_breadcrumbs(df, breadcrumb_cols=["Segment Name", "Family Name", "Class Name", "Commodity Name"], extra_cols_map={"Segment Name": ["Segment"], "Family Name": ["Family"], "Class Name": ["Class"], "Commodity Name": ["Commodity"]})
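Building a tree from breadcrumbs amounts to inserting each row's path into a prefix tree. The sketch below is hypothetical and not `create_tree_from_breadcrumbs`. `ToyNode` and `tree_from_breadcrumbs` are illustrative names, and the `extra_cols_map` handling (attaching the numeric codes) is omitted. Children are keyed by (parent, condition), so identically named categories under different parents stay separate. The ambiguity check above found no such overlaps in this dataset anyway.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Sequence


@dataclass(eq=False)  # identity hashing, so nodes can be dictionary keys
class ToyNode:
    condition: str
    children: List["ToyNode"] = field(default_factory=list)
    parent: Optional["ToyNode"] = field(default=None, repr=False)


def tree_from_breadcrumbs(rows: Iterable[Sequence[str]]) -> ToyNode:
    """Build a tree from breadcrumb paths, one path per leaf label.
    Paths that share a prefix share nodes."""
    root = ToyNode("root")
    lookup = {}  # (parent node, condition) -> existing child node
    for path in rows:
        node = root
        for condition in path:
            key = (node, condition)
            child = lookup.get(key)
            if child is None:
                child = ToyNode(condition, parent=node)
                node.children.append(child)
                lookup[key] = child
            node = child
    return root


rows = [
    ("Live animals", "Livestock", "Cats"),
    ("Live animals", "Livestock", "Dogs"),
    ("Live animals", "Livestock", "Horses"),
]
toy_root = tree_from_breadcrumbs(rows)
livestock = toy_root.children[0].children[0]
print([c.condition for c in livestock.children])  # ['Cats', 'Dogs', 'Horses']
```

With the DataFrame, the rows could be supplied as `df[cat_cols].itertuples(index=False, name=None)`, which yields plain tuples.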

Validate the new tree representation by running a similar check and comparing it with the results above. The per-level counts, averages, and maxima match.

In [7]:
check_tree(root)

sub_branches: 465, avg: 8.157894736842104, max: 43

leaves at this level: 0
sub_branches: 5313, avg: 11.425806451612903, max: 67

leaves at this level: 0
sub_branches: 71502, avg: 13.45793337097685, max: 99

leaves at this level: 0
sub_branches: 0, avg: 0.0, max: 0

leaves at this level: 71502
total leaves: 71502
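The printed numbers suggest what `check_tree` reports at each level, starting from the root's children. `sub_branches` is the number of nodes one level further down. `avg` is `sub_branches` divided by the number of nodes at this level (465 / 57 ≈ 8.158 for the Segment level). `max` is the largest fan-out. The last line counts childless nodes. The sketch below reproduces that interpretation on a toy tree. It is a guess based on the output and not the `model` implementation. `ToyNode` and `level_stats` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)
class ToyNode:
    condition: str
    children: List["ToyNode"] = field(default_factory=list)


def level_stats(root: ToyNode) -> List[Dict[str, float]]:
    """Per-level branching statistics, starting at the root's children."""
    stats = []
    level = list(root.children)
    while level:
        counts = [len(n.children) for n in level]
        stats.append({
            "sub_branches": sum(counts),       # nodes one level further down
            "avg": sum(counts) / len(counts),  # mean children per node here
            "max": max(counts),                # largest fan-out at this level
            "leaves": counts.count(0),         # childless nodes at this level
        })
        level = [child for n in level for child in n.children]
    return stats


toy_root = ToyNode("root", [
    ToyNode("A", [ToyNode("a1"), ToyNode("a2")]),
    ToyNode("B"),
])
for row in level_stats(toy_root):
    print(row)
print(sum(row["leaves"] for row in level_stats(toy_root)))  # total leaves: 3
```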


Explore the new tree visually if desired. Many nodes have a large number of children.

In [8]:
node = root

print(format_node(node))

display_lazy_tree(node, max_initial_depth=2)





Create Ollama model instances to process LLM requests locally. (Replace these with your preferred LangChain chat and embedding models.)

In [9]:
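from langchain_ollama import ChatOllama, OllamaEmbeddings

# Embedding model used to embed category texts
embeddings = OllamaEmbeddings(
    model="mxbai-embed-large",
)

# Vector store built from the category texts of the current node's children
cats = [n.condition for n in node.children]
vectorstore = create_vector_store(texts=cats, embeddings=embeddings)

# Factory that returns a new chat model instance each time it is called
def create_llm():
    return ChatOllama(
        model="qwen2.5:14b",
        # temperature=0,
    )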

Test a single iteration of creating new categories at the root level.

In [10]:
cats, tokens = ask_model_category(node=root, embeddings=embeddings, create_llm=create_llm)
cats

CategoryAnswer(categories=['Agriculture and Forestry', 'Construction and Infrastructure Materials', 'Education and Professional Training', 'Healthcare and Medical Equipment', 'Industrial Machinery and Supplies', 'Consumer Goods and Services', 'Natural Resources Extraction and Services', 'Chemical Products and Biochemicals', 'Financial and Legal Services', 'Environmental and Sustainability Solutions', 'Electrical Engineering and Lighting Technology', 'Media and Communication Equipment'])

Functions that make LLM calls also capture token counts and return them alongside their results.

In [11]:
tokens

TokenCounts(prompt=401, completion=101, total=502)
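The optimizer makes many LLM calls, so it helps to be able to sum these counts. The class below is a hypothetical mirror of the repr above with fields `prompt`, `completion`, and `total`. The real `TokenCounts` lives in `model` and may differ. LangChain chat models can report usage on the returned message's `usage_metadata` (`input_tokens`, `output_tokens`, `total_tokens`). Whether `model` collects its counts that way is an assumption. The second entry in `calls` is made-up sample data.

```python
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class TokenCounts:
    prompt: int = 0
    completion: int = 0
    total: int = 0

    @classmethod
    def from_usage_metadata(cls, usage: Mapping[str, int]) -> "TokenCounts":
        """Build from a LangChain-style usage dict (assumed key names)."""
        return cls(
            prompt=usage.get("input_tokens", 0),
            completion=usage.get("output_tokens", 0),
            total=usage.get("total_tokens", 0),
        )

    def __add__(self, other: "TokenCounts") -> "TokenCounts":
        # Field-wise sum so counts from many calls can be accumulated
        return TokenCounts(
            self.prompt + other.prompt,
            self.completion + other.completion,
            self.total + other.total,
        )


calls = [
    TokenCounts(prompt=401, completion=101, total=502),
    TokenCounts.from_usage_metadata(
        {"input_tokens": 350, "output_tokens": 90, "total_tokens": 440}),
]
print(sum(calls, TokenCounts()))  # TokenCounts(prompt=751, completion=191, total=942)
```

Passing `TokenCounts()` as the start value lets the built-in `sum` fold the list using `__add__`.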

The original dataset includes more categories than needed. Since we will later classify consumer products from Amazon, remove the irrelevant segments and families to save time when optimizing the tree.

We will remove all service segments as well as 'Food Beverage and Tobacco Products', which is very large, complex, and not particularly relevant. We will also remove the 'live' and 'fresh' families from the 'Live Plant and Animal Material and Accessories and Supplies' segment.

The following segments will be removed:
* Farming and Fishing and Forestry and Wildlife Contracting Services
* Mining and oil and gas services
* Building and Facility Construction and Maintenance Services
* Industrial Production and Manufacturing Services
* Industrial Cleaning Services
* Environmental Services
* Transportation and Storage and Mail Services
* Management and Business Professionals and Administrative Services
* Engineering and Research and Technology Based Services
* Editorial and Design and Graphic and Fine Art Services
* Public Utilities and Public Sector Related Services
* Financial and Insurance Services
* Healthcare Services
* Education and Training Services
* Travel and Food and Lodging and Entertainment Services
* Personal and Domestic Services
* National Defense and Public Order and Security and Safety Services
* Politics and Civic Affairs Services
* Food Beverage and Tobacco Products

The following families will be removed:
* Live animals
* Live rose bushes
* Live plants of high species or variety count flowers
* Live plants of low species or variety count flowers
* Live chrysanthemums
* Live carnations
* Live orchids
* Fresh cut rose
* Fresh cut blooms of high species or variety count flowers
* Fresh cut blooms of low species or variety count flowers
* Fresh cut chrysanthemums
* Fresh cut floral bouquets
* Fresh cut carnations
* Fresh cut orchids
* Fresh cut greenery

In [12]:
segments_to_remove = [
    s for s in pd.unique(df["Segment Name"]) if "services" in s.lower()
]

segments_to_remove.append("Food Beverage and Tobacco Products")

families_to_remove = [
    f for f in pd.unique(df[df["Segment Name"] == "Live Plant and Animal Material and Accessories and Supplies"]["Family Name"]) if "live" in f.lower() or "fresh" in f.lower()
]

# Keep only rows that are in neither a removed segment nor a removed family
df_filt = df[
    ~df["Segment Name"].isin(segments_to_remove)
    & ~df["Family Name"].isin(families_to_remove)
]

len(df_filt)

31576

Create a fresh tree as shown above and run the optimizer. It creates new categories as needed so that no node has more than 25 children. (This step takes a very long time.)

In [None]:
root = create_tree_from_breadcrumbs(df_filt, breadcrumb_cols=["Segment Name", "Family Name", "Class Name", "Commodity Name"], extra_cols_map={"Segment Name": ["Segment"], "Family Name": ["Family"], "Class Name": ["Class"], "Commodity Name": ["Commodity"]})
progress_bars = ProgressBars(n_leaves=len(df_filt))
display(progress_bars.ui)
optimize_tree(root=root, max_children=25, progress_bars=progress_bars, embeddings=embeddings, create_llm=create_llm)
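The splitting loop from the Tree Formation algorithm could take roughly the following shape. This is a simplified sketch and not `optimize_tree`. `propose` and `classify` are hypothetical stand-ins for the two LLM calls. The embedding and KMeans sampling of representatives and the progress bars are omitted. The `max_rounds` and `max_depth` guards are assumptions added because an LLM is not guaranteed to make progress on every split.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass(eq=False)
class ToyNode:
    condition: str
    children: List["ToyNode"] = field(default_factory=list)
    parent: Optional["ToyNode"] = field(default=None, repr=False)


Propose = Callable[[List[str]], List[str]]
Classify = Callable[[str, List[str]], Optional[str]]


def optimize(node: ToyNode, max_children: int, propose: Propose,
             classify: Classify, max_rounds: int = 10, max_depth: int = 20) -> None:
    """Regroup over-full nodes under new categories, then recurse downwards."""
    if max_depth <= 0:
        return  # safety stop in case splits never make progress
    rounds = 0
    while len(node.children) > max_children and rounds < max_rounds:
        rounds += 1
        old_children = node.children
        # Stand-in for the LLM proposing new categories from the old ones
        groups = {name: ToyNode(name, parent=node)
                  for name in propose([c.condition for c in old_children])}
        kept = []
        for child in old_children:
            # Stand-in for the LLM sorting an old category into a new one
            target = classify(child.condition, list(groups))
            if target in groups:
                child.parent = groups[target]
                groups[target].children.append(child)
            else:
                kept.append(child)  # unclassifiable: stays under this node
        node.children = [g for g in groups.values() if g.children] + kept
    for child in node.children:
        optimize(child, max_children, propose, classify, max_rounds, max_depth - 1)


TOPICS = {"Cats": "Animals", "Dogs": "Animals", "Drills": "Tools", "Saws": "Tools"}


def propose_stub(texts: List[str]) -> List[str]:
    return ["Animals", "Tools"]


def classify_stub(text: str, categories: List[str]) -> Optional[str]:
    topic = TOPICS.get(text)
    return topic if topic in categories else None


toy_root = ToyNode("root")
for name in ["Cats", "Dogs", "Drills", "Saws", "Widgets"]:
    toy_root.children.append(ToyNode(name, parent=toy_root))

optimize(toy_root, max_children=3, propose=propose_stub, classify=classify_stub)
print([c.condition for c in toy_root.children])  # ['Animals', 'Tools', 'Widgets']
```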

In [42]:
clean_tree(root=root)

Save the tree for later.

In [84]:
import pickle

with open("tree.pkl", "wb") as f:
    pickle.dump(root, f)

Test the saved tree by loading it again (here from a copy of the file, `tree_raw.pkl`).

In [40]:
import pickle

with open("tree_raw.pkl", "rb") as f:
    root = pickle.load(f)
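A cheap sanity check after loading is to count leaves. For this tree the total should equal `len(df_filt)`, which is 31,576. That matches the `total leaves` line that `check_tree` prints next. The helper `count_leaves` is hypothetical. It relies only on a `.children` attribute, which the notebook's `Node` has, and it is shown here on a toy tree with a pickle round trip.

```python
import pickle
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)
class ToyNode:
    condition: str
    children: List["ToyNode"] = field(default_factory=list)


def count_leaves(root) -> int:
    """Count childless nodes with an explicit stack (no recursion)."""
    stack, leaves = [root], 0
    while stack:
        node = stack.pop()
        if node.children:
            stack.extend(node.children)
        else:
            leaves += 1
    return leaves


toy_root = ToyNode("root", [ToyNode("A", [ToyNode("a1"), ToyNode("a2")]), ToyNode("B")])
restored = pickle.loads(pickle.dumps(toy_root))
print(count_leaves(toy_root), count_leaves(restored))  # 3 3
```

On the real tree, `count_leaves(root) == len(df_filt)` would be the corresponding check.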

Explore the final tree.

In [43]:
check_tree(root)

sub_branches: 82, avg: 4.555555555555555, max: 12

leaves at this level: 0
sub_branches: 482, avg: 5.878048780487805, max: 21

leaves at this level: 1
sub_branches: 3422, avg: 7.0995850622406635, max: 25

leaves at this level: 18
sub_branches: 12084, avg: 3.5312682641729984, max: 25

leaves at this level: 2033
sub_branches: 11665, avg: 0.9653260509764978, max: 25

leaves at this level: 10713
sub_branches: 5059, avg: 0.433690527218174, max: 25

leaves at this level: 11038
sub_branches: 1925, avg: 0.38050998220992294, max: 25

leaves at this level: 4792
sub_branches: 796, avg: 0.4135064935064935, max: 24

leaves at this level: 1820
sub_branches: 312, avg: 0.39195979899497485, max: 25

leaves at this level: 768
sub_branches: 89, avg: 0.28525641025641024, max: 21

leaves at this level: 304
sub_branches: 0, avg: 0.0, max: 0

leaves at this level: 89
total leaves: 31576


In [16]:
root.children[0].condition

'Infrastructure Components'

In [44]:
display_lazy_tree(root, max_initial_depth=3)
