In [None]:
!pip install graphviz
import graphviz

## Node Class
The `Node` class represents a single node in a decision tree. Each node may store information necessary for making a decision or classification and contains pointers to its child nodes.

### Attributes:
* `feature_idx` (`int`): Index of the feature used to split the data at this node. Default is `-1`, which indicates the node is a leaf or uninitialized.
* `threshold` (`float`): The value at which the feature is split. Default is `0.0`.
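* `label` (`int`): Class label assigned to the node. Used when the node is a leaf. Default is `-1`.
* `left` (`Node` or `None`): Pointer to the left child node (i.e., the subtree where feature value is <= threshold). Default is `None`.
* `right` (`Node` or `None`): Pointer to the right child node (i.e., the subtree where feature value is > threshold). Default is `None`.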


The class is typically used when building decision trees manually.

In [None]:
class Node:
  """A single decision-tree node, usable as an internal split or as a leaf.

  Attributes:
    feature_idx: Index of the feature tested at this node; starts at -1.
    threshold: Value the feature is compared against; starts at 0.0.
    label: Class label stored on the node; starts at -1.
    left: Left child node; starts as None.
    right: Right child node; starts as None.
  """

  def __init__(self):
    # Every field starts at its "unset" default; callers fill them in later.
    self.feature_idx, self.threshold, self.label = -1, 0.0, -1
    self.left = self.right = None

## Calculate gini
Calculates the Gini impurity of a dataset, which is a measure of how often a randomly chosen element from the dataset would be incorrectly labelled if it were randomly labelled according to the label distribution in the dataset.

### Parameters:
* `samples` (`list` of `list` or `tuple`): Each sample is expected to be a list or tuple where the label is stored at index `2`. The function assumes binary classification, with class labels `0` and `1`.

### Returns:
* `float`: The Gini impurity value ranging from `0.0` (pure) to `0.5` (maximum impurity for binary classes).

### How it works:
 1. If the input `samples` list is empty, the function returns `0.0` (no impurity).
 2. It counts how many samples belong to class `0` and class `1`.
 3. It computes the proportion of each class.
 4. It applies the Gini formula:
  $$
  \text{Gini} = 1 - \sum_{i=1}^{n} p_i^2
  $$
  where $p_i$ is the proportion of samples in class $i$ and $n$ is the number of classes.




In [None]:
def calculate_gini(samples):
  """Return the Gini impurity of the class labels found in `samples`.

  Args:
    samples: Sequence of records whose class label is stored at index 2.
      Each label is converted with int(); any number of distinct classes
      is accepted (not only 0 and 1).

  Returns:
    1 - sum(p_i ** 2) over the classes present, where p_i is the fraction
    of samples belonging to class i. Returns 0.0 for an empty input.
  """
  from collections import Counter

  if not samples:
    return 0.0

  total = len(samples)
  class_counts = Counter(int(sample[2]) for sample in samples)

  gini = 1.0
  # Visit classes in ascending label order so the floating-point result does
  # not depend on the order in which samples are listed.
  for label in sorted(class_counts):
    fraction = class_counts[label] / total
    gini -= fraction * fraction
  return gini

## Find best split

Finds the best feature and threshold to split a dataset in order to minimize the Gini impurity. This function is used in building decision trees for classification.

### Parameters
* `samples` (`list` of `list` or `tuple`): Each sample is expected to have feature values at index `0` and `1`, and the class label at index `2`.

### Returns:
* `tuple`:
  * `best_feature` (`int`): Index of the feature (0 or 1) that gives the best split.
  * `best_threshold` (`float`): Threshold value for the best split.
  * `best_gini` (`float`): The Gini impurity score of the best split (lower is better).

### How it works:
1. Initializes the best Gini score as infinity and placeholders for the best feature and threshold.
2. Iterates over each feature (in this case, feature 0 and 1).
3. For each feature, gathers unique values (used as possible thresholds) and sorts them.
4. For each threshold:
  * Splits the dataset into two groups: `left` (values <= threshold) and `right` (values > threshold).
  * Skips the threshold if either split is empty.
  * Calculates Gini impurity for both groups and computes the size-weighted Gini impurity of the split.

5. Keeps track of the feature and threshold that produces the lowest weighted Gini impurity.
6. Returns the best feature index, threshold, and the corresponding Gini score.



In [None]:
def find_best_split(samples):
  """Find the feature/threshold pair whose split has the lowest Gini impurity.

  Every distinct value of feature 0 and feature 1 is tried as a threshold.
  Samples with a value <= threshold go left, the rest go right, and splits
  that leave either side empty are skipped.

  Args:
    samples: Sequence of records with feature values at indices 0 and 1 and
      the class label at index 2.

  Returns:
    A tuple (best_feature, best_threshold, best_gini). best_gini is the
    size-weighted average of the two groups' Gini impurities. When no
    threshold produces two non-empty groups (empty input, or every sample
    sharing the same feature values) the result is (-1, 0.0, inf).
  """
  # Candidates are ranked by the size-weighted SUM of impurities; dividing by
  # the constant sample count would not change the ranking, so the division
  # is done once on the winning score before returning.
  best_score = float('inf')
  best_feature = -1
  best_threshold = 0.0

  for feature in range(2):
    thresholds = sorted({sample[feature] for sample in samples})

    for thresh in thresholds:
      left = [s for s in samples if s[feature] <= thresh]
      right = [s for s in samples if s[feature] > thresh]
      if not left or not right:
        continue

      score = len(left) * calculate_gini(left) + len(right) * calculate_gini(right)

      # Strict '<' keeps the first candidate found when scores tie.
      if score < best_score:
        best_score = score
        best_feature = feature
        best_threshold = thresh

  if best_feature == -1:
    return best_feature, best_threshold, best_score

  return best_feature, best_threshold, best_score / len(samples)

In [None]:
def build_tree(samples, depth, max_depth):
  """Recursively build a decision tree for samples labelled 0 or 1.

  Args:
    samples: Sequence of records with feature values at indices 0 and 1 and
      a class label (0 or 1) at index 2.
    depth: Depth of the node being built; each recursive call passes depth + 1.
    max_depth: A node whose depth is >= max_depth becomes a leaf.

  Returns:
    A Node. A leaf has `label` set to the majority class of its samples
    (ties and empty input resolve to 0). An internal node has `feature_idx`,
    `threshold`, `left` (values <= threshold) and `right` (values > threshold)
    set, and keeps `label` at -1.
  """
  node = Node()

  should_stop = (
      depth >= max_depth
      or len(samples) < 2
      or calculate_gini(samples) == 0.0
  )

  if not should_stop:
    best_feature, best_threshold, _ = find_best_split(samples)
    # A feature index of -1 means no threshold yields two non-empty groups
    # (e.g. identical feature values with mixed labels). Splitting on it
    # would index the label column, so the node becomes a leaf instead.
    should_stop = best_feature == -1

  if should_stop:
    count = [0, 0]
    for sample in samples:
      count[int(sample[2])] += 1
    node.label = 1 if count[1] > count[0] else 0
    return node

  left_samples = [s for s in samples if s[best_feature] <= best_threshold]
  right_samples = [s for s in samples if s[best_feature] > best_threshold]

  node.feature_idx = best_feature
  node.threshold = best_threshold
  node.left = build_tree(left_samples, depth + 1, max_depth)
  node.right = build_tree(right_samples, depth + 1, max_depth)
  return node

In [None]:
def predict(node, sample):
  """Return the class label the tree rooted at `node` assigns to `sample`.

  Args:
    node: Root of the (sub)tree to evaluate. A node whose `label` is not -1
      is treated as a leaf.
    sample: Indexable feature values; `sample[node.feature_idx]` is compared
      against `node.threshold` at each internal node.

  Returns:
    The `label` of the leaf reached by going left when the feature value is
    <= the threshold and right otherwise.
  """
  current = node
  # Descend until a node carrying a real label is reached.
  while current.label == -1:
    goes_left = sample[current.feature_idx] <= current.threshold
    current = current.left if goes_left else current.right
  return current.label

In [None]:
def tree_to_dot(node, dot=None):
  """Add the tree rooted at `node` to a Graphviz graph and return the graph.

  Args:
    node: Root of the (sub)tree to draw. A node whose `label` is not -1 is
      drawn as a box naming its class (1 -> 'Igneous', otherwise
      'Sedimentary'); any other node is drawn as an oval showing the feature
      name (index 0 -> 'Silica', otherwise 'Grain Size') and its threshold.
    dot: Graph object exposing `node(...)` and `edge(...)`. When None, a new
      top-to-bottom `graphviz.Digraph` is created.

  Returns:
    The graph object that was passed in or created.
  """
  if dot is None:
    from graphviz import Digraph
    dot = Digraph()
    dot.attr(rankdir='TB')

  # id() gives each Node object a distinct graph identifier.
  node_id = str(id(node))

  if node.label != -1:
    label = f"Class: {'Igneous' if node.label==1 else 'Sedimentary'}"
    dot.node(node_id, label, shape='box')
  else:
    feature_name = 'Silica' if node.feature_idx == 0 else 'Grain Size'
    # '.2f' = two decimal places; the previous '2f' spec only set a minimum
    # width and therefore printed six decimals.
    label = f"{feature_name}\n<= {node.threshold:.2f}"
    dot.node(node_id, label, shape='oval')

  for child in (node.left, node.right):
    if child:
      dot.edge(node_id, str(id(child)))
      tree_to_dot(child, dot)

  return dot

In [None]:
if __name__=="__main__":
  # Each training row is [silica %, grain size in mm, label];
  # label 1 is reported as 'Igneous' and 0 as 'Sedimentary'.
  dataset=[
      [70.0, 1.0, 1],
      [55.0, 0.5, 1],
      [65.0, 2.0, 1],
      [30.0, 0.1, 0],
      [40.0, 1.5, 0],
      [50.0, 0.05, 0]
  ]

  # Start at depth 0 and allow at most two levels of splits.
  root=build_tree(dataset, 0, 2)

  print("Testing rock sample")
  test_samples=[
      [60.0, 1.2],
      [35.0, 0.2],
  ]

  for sample in test_samples:
    pred=predict(root, sample)
    print(f"Silica: {sample[0]}%, Grain Size: {sample[1]}mm -> {'Igneous' if pred else 'Sedimentary'}")

  # Render the tree once, after all predictions. Rendering is best-effort, so
  # a failure is reported together with its cause instead of stopping the run.
  try:
    dot=tree_to_dot(root)
    dot.render("rock_decision_tree", format="png", cleanup=True)
    print("\nDecision tree visualization saved as 'rock_decision_tree.png'")
  except Exception as e:
    print("Couldn't create visualization. Make Sure graphviz is installed")
    print(f"Reason: {e}")
    print("!pip install graphviz")
    print("Also ensure Graphviz is installed on your system")