In [1]:
import pandas as pd
import graphviz
import pydotplus
import collections
from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

In [2]:
def gini_to_alpha(gini):
    """Turn a Gini impurity into a two-digit hexadecimal alpha channel.

    The impurity is mapped to an opacity of ``1 - 2 * gini``, so an impurity
    of 0 gives full opacity and an impurity of 0.5 (or more) gives full
    transparency. The opacity is clamped to [0, 1], scaled to 0-255 and
    truncated to an integer before being formatted.

    Args:
        gini: The impurity as a number or as text accepted by ``float()``.
            The empty string is treated as "no value" and yields opacity 0.

    Returns:
        A two-character lowercase hex string, e.g. ``'00'``, ``'7f'``, ``'ff'``.

    Raises:
        ValueError: If ``gini`` is non-empty text that ``float()`` cannot parse.
    """
    # An empty string means no impurity was available: fully transparent.
    opacity = 0 if gini == '' else 1 - 2 * float(gini)

    # Clamp to the valid opacity range before scaling.
    if opacity > 1.0:
        opacity = 1.0
    elif opacity < 0.0:
        opacity = 0.0

    # Scale to 0-255 (truncating) and zero-pad to exactly two hex digits.
    return f'{int(opacity * 255):02x}'

In [3]:
# Read the CSV one directory above the working directory into a DataFrame.
# NOTE(review): the file name suggests this is the training split — confirm.
df = pd.read_csv("../fire_data_train.csv")

In [4]:
# Names of the columns of `df` used as model inputs.
features = [
    'dom_vel', 'dom_dir',
    'max_temp', 'min_temp', 'mean_temp',
    'rain_7_days', 'ndvi', 'lst_day',
    'slope', 'dem',
    'corine_gr1', 'corine_gr4', 'corine_gr5',
    'corine_gr21', 'corine_gr22', 'corine_gr23', 'corine_gr24',
    'corine_gr31', 'corine_gr32', 'corine_gr33',
]

# Target column and input matrix, both taken from the loaded frame.
y = df['fire']
X = df[features]

# Depth is capped at 2 and the seed is fixed so repeated runs build the
# same tree; `fit` returns the estimator itself, so the call can be chained.
clf = tree.DecisionTreeClassifier(random_state=42, max_depth=2).fit(X, y)

In [5]:
# Export the fitted tree as DOT source (returned as a string because
# out_file is None) and parse it into an editable pydotplus graph.
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=features,
    class_names=['no-fire', 'fire'],
    filled=True,
    rounded=True,
    proportion=True,
    node_ids=True,
    special_characters=True,
)
graph = pydotplus.graph_from_dot_data(dot_data)

# Base fill colours: colors[0] goes to the lower-numbered child of each
# split and colors[1] to the higher-numbered one.
# NOTE(review): presumably lower id == left branch — confirm against sklearn.
colors = ('#1792EA', '#E86E17')

# Map each source node name to the integer ids of the nodes it points to.
edges = collections.defaultdict(list)
for link in graph.get_edge_list():
    edges[link.get_source()].append(int(link.get_destination()))

for child_ids in edges.values():
    child_ids.sort()
    for side, base_color in enumerate(colors):
        child_node = graph.get_node(str(child_ids[side]))[0]
        # Pull the text between 'gini = ' and the following samples line
        # out of the node label; it is '' when the label has no gini entry.
        label = child_node.get_attributes()['label']
        after_gini = label.partition('gini = ')[2]
        impurity = after_gini.partition('<br/>samples')[0]
        # Append the impurity-derived alpha to form an RGBA fill colour.
        child_node.set_fillcolor(base_color + gini_to_alpha(impurity))

# Render the recoloured graph to disk; the call's return value is the
# cell's displayed output.
graph.write_png('tree.png')

True