In [None]:
import pickle
from xgboost import to_graphviz

# 1. Restore the pickled XGBoost model and its LabelEncoder from disk.
# NOTE: unpickling executes code embedded in the file, so both pickles
# must come from a trusted source.
with open('xgboost_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

with open('label_encoder.pkl', 'rb') as encoder_file:
    label_encoder = pickle.load(encoder_file)

# Prefer the encoder's `classes_` attribute; if the unpickled object has
# none, treat the object itself as the sequence of class labels.
classes = getattr(label_encoder, 'classes_', label_encoder)
num_classes = len(classes)

# 2. Access the underlying booster and count the trees it actually holds
booster = model.get_booster()

num_rounds = booster.num_boosted_rounds()

# Count trees from the booster's own dump (one entry per tree) instead of
# assuming rounds * classes. That product over-counts for binary objectives
# (one tree per round) and is wrong when num_parallel_tree > 1, which would
# make the rendering loop request tree indices that do not exist.
total_trees = len(booster.get_dump())

print(f"Boosting rounds: {num_rounds}")
print(f"Number of classes: {num_classes}")
print(f"Total trees to render: {total_trees}\n")

# 3. Render every tree to its own file, tagging each filename with a class label
for tree_index in range(total_trees):
    # Labels are assigned round-robin over `classes` by tree index.
    class_label = classes[tree_index % num_classes]
    print(f"Rendering tree {tree_index} for class '{class_label}'")

    # Build the graphviz object for this tree and write it out. The base
    # filename carries the tree index and class label, e.g. tree_0_ClassA;
    # graphviz appends the output format's extension when rendering.
    graph = to_graphviz(booster, num_trees=tree_index)
    graph.render(f"tree_{tree_index}_{class_label}")

print("\nAll trees have been rendered with class names in their filenames.")


Boosting rounds: 20
Number of classes: 6
Total trees to render: 120

Rendering tree 0 for class 'TCP500Mbps'
Rendering tree 1 for class 'TCP50Mbps'
Rendering tree 2 for class 'TCP5Mbps'
Rendering tree 3 for class 'UDP500Mbps'
Rendering tree 4 for class 'UDP50Mbps'
Rendering tree 5 for class 'UDP5Mbps'
Rendering tree 6 for class 'TCP500Mbps'
Rendering tree 7 for class 'TCP50Mbps'
Rendering tree 8 for class 'TCP5Mbps'
Rendering tree 9 for class 'UDP500Mbps'
Rendering tree 10 for class 'UDP50Mbps'
Rendering tree 11 for class 'UDP5Mbps'
Rendering tree 12 for class 'TCP500Mbps'
Rendering tree 13 for class 'TCP50Mbps'
Rendering tree 14 for class 'TCP5Mbps'
Rendering tree 15 for class 'UDP500Mbps'
Rendering tree 16 for class 'UDP50Mbps'
Rendering tree 17 for class 'UDP5Mbps'
Rendering tree 18 for class 'TCP500Mbps'
Rendering tree 19 for class 'TCP50Mbps'
Rendering tree 20 for class 'TCP5Mbps'
Rendering tree 21 for class 'UDP500Mbps'
Rendering tree 22 for class 'UDP50Mbps'
Rendering tree 23 fo