## 0. Imports

In [1]:
import pandas as pd
import numpy as np

import interpret
from interpret import show

import pyAgrum as gum
import pyAgrum.lib.notebook as gnb
import pyAgrum.skbn as skbn

from sklearn.model_selection import train_test_split

## 1. Loading IRIS Dataset

We will use the Iris dataset to test the Tree Augmented Naive Bayes (TAN) model:

In [2]:
iris = pd.read_csv('data/iris.csv')
iris.columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

Let's turn this into a binary problem (Iris-setosa versus the rest) and discretize each feature into four quantile-based bins.

In [3]:
iris['species'] = np.where(iris['species'] == 'Iris-setosa', 1, 0)

iris['sepal_length'] = pd.qcut(iris['sepal_length'], 4, labels=False)
iris['sepal_width'] = pd.qcut(iris['sepal_width'], 4, labels=False)
iris['petal_length'] = pd.qcut(iris['petal_length'], 4, labels=False)
iris['petal_width'] = pd.qcut(iris['petal_width'], 4, labels=False)
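
To see what `pd.qcut(..., 4, labels=False)` produces, here is a minimal sketch on a made-up series (the values are hypothetical and not taken from the dataset). Each value is replaced by the integer index of the quartile bin it falls into:

```python
import pandas as pd

# Hypothetical toy values, chosen only to illustrate quartile binning.
toy = pd.Series([1, 2, 3, 4, 5, 6, 7, 8])

# Four quantile-based bins; labels=False returns the bin indices 0..3
# instead of interval labels.
codes = pd.qcut(toy, 4, labels=False)
print(codes.tolist())
```

Because the bins are quantile-based, each one receives roughly the same number of rows, which is why the discretized features end up with the categories 0 to 3.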

In [4]:
X = iris.drop('species', axis=1)
y = iris['species']

Finally, let's split the data.

In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

## 2. Tree Augmented Naive Bayes Model

Let's load the TAN class that implements the model explanations.

In [6]:
from interpret.glassbox import TANClassifier

TAN_model = TANClassifier()

In [7]:
TAN_model.fit(X_train, y_train)

<interpret.glassbox._tan.TANClassifier at 0x1ddd5fccc40>

In [8]:
pred = TAN_model.predict(X_test)
pred

array([0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0.,
       0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1.])

In [9]:
print(TAN_model.score(X_test, y_test))

1.0
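
Assuming `score` follows the scikit-learn classifier convention of returning mean accuracy (an assumption, since its implementation is not shown here), the value is the fraction of test rows whose prediction matches the label. A minimal sketch with hypothetical labels:

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / len(y_true)

# Hypothetical labels, for illustration only.
y_true = [0, 1, 0, 0, 1]
y_pred = [0, 1, 0, 1, 1]
acc = accuracy(y_true, y_pred)
print(acc)
```

A perfect score on this task is plausible, because Iris-setosa is well separated from the other two species.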


The underlying pyAgrum classifier exposes many attributes of its own:

In [10]:
print(dir(TAN_model.TAN_class))

['DirichletCsv', 'MarkovBlanket', 'XYfromCSV', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__sklearn_clone__', '__str__', '__subclasshook__', '__weakref__', '_binary_predict', '_build_request_for_signature', '_check_feature_names', '_check_n_features', '_estimator_type', '_get_default_requests', '_get_metadata_request', '_get_param_names', '_get_tags', '_more_tags', '_nary_predict', '_repr_html_', '_repr_html_inner', '_repr_mimebundle_', '_validate_data', '_validate_params', 'beta', 'bn', 'constraints', 'discretizationNbBins', 'discretizationStrategy', 'discretizationThreshold', 'discretizer', 'fit', 'fromModel', 'fromTrainedModel', 'get_metadata_routing', 'get_params', 'isBinaryClassifier', 'label', 'le

In [11]:
print(dir(TAN_model.TAN_class.bn))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__swig_destroy__', '__weakref__', '_repr_html_', 'add', 'addAMPLITUDE', 'addAND', 'addArc', 'addArcs', 'addCOUNT', 'addEXISTS', 'addFORALL', 'addLogit', 'addMAX', 'addMEDIAN', 'addMIN', 'addNoisyAND', 'addNoisyOR', 'addNoisyORCompound', 'addNoisyORNet', 'addOR', 'addSUM', 'addStructureListener', 'addVariables', 'addWeightedArc', 'adjacencyMatrix', 'ancestors', 'arcs', 'beginTopologyTransformation', 'changePotential', 'changeVariableLabel', 'changeVariableName', 'check', 'children', 'clear', 'completeInstantiation', 'connectedComponents', 'cpt', 'dag', 'descendants', 'dim', 'empty', 'endTopologyTransform

The fitted model has built a tree that adds dependencies between features on top of the arcs from the target:

In [12]:
gnb.sideBySide(TAN_model._model().bn, gnb.getInference(TAN_model._model().bn, size='5!'))

[Figure: learned network structure (left) and inference view (right). Arcs: y->sepal_width, y->petal_length, y->sepal_length, y->petal_width, petal_length->sepal_length, petal_length->petal_width, sepal_length->sepal_width.]
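
The arcs in the graph output above can be checked against the defining property of a TAN structure: every feature has the class as a parent, plus at most one other feature. The sketch below is plain Python and does not use pyAgrum; the arc list is transcribed by hand from that output:

```python
# Arcs transcribed from the graph output, as (parent, child) pairs.
arcs = [
    ("y", "sepal_width"), ("y", "petal_length"),
    ("y", "sepal_length"), ("y", "petal_width"),
    ("petal_length", "sepal_length"), ("petal_length", "petal_width"),
    ("sepal_length", "sepal_width"),
]

# Collect the parent set of every node.
parents = {}
for parent, child in arcs:
    parents.setdefault(child, set()).add(parent)
    parents.setdefault(parent, set())

features = [node for node in parents if node != "y"]

# TAN property: the class is a parent of every feature, and each
# feature has at most one additional feature as a parent.
is_tan = all(
    "y" in parents[f] and len(parents[f] - {"y"}) <= 1 for f in features
)

for f in sorted(features):
    print(f, sorted(parents[f]))
print("TAN structure:", is_tan)
```

Here `petal_length` is the root of the feature tree: its only parent is `y`, while the other three features each have one extra feature parent.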


We can inspect specific parts of the network, for example the conditional probability table (CPT) of a variable:

In [13]:
TAN_model._model().bn.cpt("petal_length")

| y | petal_length = 0 | petal_length = 1 | petal_length = 2 | petal_length = 3 |
|---|---|---|---|---|
| 0 | 0.0061 | 0.2744 | 0.3963 | 0.3232 |
| 1 | 0.8452 | 0.131 | 0.0119 | 0.0119 |


The shape of a CPT depends on the variable's parents. For a variable with two parents, the table has one row per combination of parent values:

In [14]:
TAN_model._model().bn.cpt("petal_width")

| y | petal_length | petal_width = 0 | petal_width = 1 | petal_width = 2 | petal_width = 3 |
|---|---|---|---|---|---|
| 0 | 0 | 0.25 | 0.25 | 0.25 | 0.25 |
| 0 | 1 | 0.0208 | 0.8542 | 0.1042 | 0.0208 |
| 0 | 2 | 0.0147 | 0.1029 | 0.6912 | 0.1912 |
| 0 | 3 | 0.0179 | 0.0179 | 0.2679 | 0.6964 |
| 1 | 0 | 0.8514 | 0.1216 | 0.0135 | 0.0135 |
| 1 | 1 | 0.3571 | 0.5 | 0.0714 | 0.0714 |
| 1 | 2 | 0.25 | 0.25 | 0.25 | 0.25 |
| 1 | 3 | 0.25 | 0.25 | 0.25 | 0.25 |
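
The uniform rows (0.25 in every cell) presumably correspond to parent combinations that never occur in the training data, and the remaining rows are consistent with additive smoothing using a pseudo-count of 0.5 per cell. That reading is inferred from the printed numbers, not from the library's documentation, and the counts below are hypothetical values chosen to reproduce one row of the table:

```python
def smoothed_row(counts, alpha=0.5):
    """Turn raw counts into probabilities with additive smoothing."""
    total = sum(counts) + alpha * len(counts)
    return [(c + alpha) / total for c in counts]

# Hypothetical counts for (y=0, petal_length=1); with alpha=0.5 they
# reproduce the printed row for that parent combination.
row_seen = [round(p, 4) for p in smoothed_row([0, 20, 2, 0])]

# A parent combination with no training rows falls back to a uniform row.
row_unseen = smoothed_row([0, 0, 0, 0])

print(row_seen)
print(row_unseen)
```

Under this reading, smoothing also explains why no cell in the table is exactly zero, even for value combinations that were never observed.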


Now let's look at the global explanations:

In [15]:
TAN_global = TAN_model.explain_global()
show(TAN_global)

Because the data are discretized, we see two types of visualization:
- **petal_length**: The visualization is exactly the same as for CategoricalNB, because this variable has only one parent: the target variable.
- **petal_width**, **sepal_length** and **sepal_width**: These are the new cases. Each of these variables has two parents, the target plus one other feature (**petal_length** for **petal_width** and **sepal_length**, and **sepal_length** for **sepal_width**, according to the learned structure), so the visualization can be drawn on a grid as a heatmap, where certain patterns can be observed.

Now for the local explanations:

In [16]:
TAN_local = TAN_model.explain_local(X_test, y_test)
show(TAN_local)

As with the other models, the local explanations are visualized in the same way. Here the values for the variables with two parents are read from the heatmaps shown above, but the resulting visualization is the same.
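
To make the link between the CPTs and a single prediction concrete, here is a minimal sketch that combines the `petal_length` and `petal_width` entries printed earlier into a class posterior for one row with `petal_length = 0` and `petal_width = 0`. The sepal features are left unobserved, the class prior is a hypothetical placeholder, and the per-feature log-ratio is only one plausible way to read a contribution. None of this is taken from the `TANClassifier` implementation:

```python
import math

# Hypothetical class prior P(y); not taken from the fitted model.
prior = {0: 2 / 3, 1: 1 / 3}

# CPT entries copied from the tables printed earlier.
p_pl = {0: 0.0061, 1: 0.8452}  # P(petal_length=0 | y)
p_pw = {0: 0.25, 1: 0.8514}    # P(petal_width=0 | y, petal_length=0)

# Unnormalized joint: P(y) * P(petal_length | y) * P(petal_width | y, petal_length).
joint = {y: prior[y] * p_pl[y] * p_pw[y] for y in (0, 1)}
norm = sum(joint.values())
posterior = {y: joint[y] / norm for y in (0, 1)}

# Log-ratio of each factor in favour of class 1 (illustrative reading only).
contrib = {
    "petal_length": math.log(p_pl[1] / p_pl[0]),
    "petal_width": math.log(p_pw[1] / p_pw[0]),
}

print(posterior)
print(contrib)
```

Leaving the sepal features out is valid in this structure: they have no observed descendants, so summing them out contributes a factor of 1 to each class.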