# Tutorial: General Usage
In this notebook, we use `aughs` to fit a decision tree and a random forest with augmented hierarchical shrinkage (HS) to the `heart` dataset from `imodels`. We evaluate both models on a held-out test set and then plot the resulting feature importances.

In [6]:
import sys
sys.path.append('../')  # Necessary to import aughs from parent directory

from aughs import ShrinkageClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from imodels.util.data_util import get_clean_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score

## 1. Fitting the models

In [7]:
# Load the "heart" dataset via imodels, then hold out 20% of the rows as a test set.
# The split seed is fixed so the train/test partition is identical on every run.
X, y, feature_names = get_clean_dataset("heart", data_source="imodels")
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=42,
    test_size=0.2,
)

In [9]:
# Train a decision tree with entropy-based augmented HS and evaluate it on the test set.
# random_state is fixed (same seed as the train/test split): sklearn trees randomly
# permute features at each split, so without a seed, ties between equally good splits
# can resolve differently and the scores below would change from run to run.
clf = ShrinkageClassifier(
    DecisionTreeClassifier(random_state=42),  # base estimator that HS is applied to
    shrink_mode="hs_entropy",  # entropy-based augmented HS
    lmb=10,  # HS regularization parameter lambda
)
clf.fit(X_train, y_train)
print("Accuracy:", clf.score(X_test, y_test))
print("Balanced accuracy:", balanced_accuracy_score(y_test, clf.predict(X_test)))

Accuracy: 0.7777777777777778
Balanced accuracy: 0.7835497835497836


In [None]:
# Train a random forest with log-cardinality-based augmented HS and evaluate it on the
# test set. The forest is seeded so bootstrap sampling and feature subsampling (and hence
# the scores) are reproducible.
# lmb matches the decision-tree example above so the two models are comparable;
# TODO: tune lambda for the forest if needed.
clf = ShrinkageClassifier(
    RandomForestClassifier(random_state=42),  # base estimator that HS is applied to
    shrink_mode="hs_log_cardinality",  # log-cardinality-based augmented HS
    lmb=10,  # HS regularization parameter lambda
)
clf.fit(X_train, y_train)
print("Accuracy:", clf.score(X_test, y_test))
print("Balanced accuracy:", balanced_accuracy_score(y_test, clf.predict(X_test)))

## 2. Plotting feature importances