<a href="https://colab.research.google.com/github/cedamusk/AI-N-ML/blob/main/scratchdecision_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install graphviz  # Python bindings; rendering to PNG also needs the Graphviz system binaries (the dot command)
import graphviz

## Node Class
The `Node` class represents a single node in a binary decision tree. An internal node stores the split it applies (a feature index and a threshold) and references to its left and right children; a leaf node stores only a class label.

### Attributes:
* `feature_idx` (`int`): Index of the feature used to split the data at this node. Default is `-1`, which indicates the node is a leaf or uninitialized.
* `threshold` (`float`): The value at which the feature is split. Default is `0.0`.
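* `label` (`int`): Class label assigned to the node. Used only when the node is a leaf; any value other than `-1` marks the node as a leaf. Default is `-1`.
* `left` (`Node` or `None`): Pointer to the left child node (i.e., the subtree where the feature value is <= threshold). Default is `None`.
* `right` (`Node` or `None`): Pointer to the right child node (i.e., the subtree where the feature value is > threshold). Default is `None`.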


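The class has no methods of its own: `build_tree` creates and links `Node` instances, while `predict` and `tree_to_dot` traverse them. A tree can also be assembled by hand by setting these attributes directly.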

In [None]:
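class Node:
  def __init__(self):
    self.feature_idx=-1   # feature used for the split; -1 for a leaf
    self.threshold=0.0    # samples with feature value <= threshold go left
    self.label=-1         # class label of a leaf; -1 for an internal node
    self.left=None        # subtree with feature value <= threshold
    self.right=None       # subtree with feature value > threshold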
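To see how the attributes fit together, the sketch below wires up a depth-1 tree by hand. It repeats the `Node` class so it runs on its own; the split (silica at index `0`, threshold `50.0`) is chosen for illustration, and `is_leaf` is a hypothetical helper that is not part of this notebook.

```python
class Node:
  def __init__(self):
    self.feature_idx = -1
    self.threshold = 0.0
    self.label = -1
    self.left = None
    self.right = None


def is_leaf(node):
  # Hypothetical helper: leaves are the nodes that carry a class label
  return node.label != -1


# Two leaves: class 0 (Sedimentary) and class 1 (Igneous)
left_leaf = Node()
left_leaf.label = 0
right_leaf = Node()
right_leaf.label = 1

# Internal node: "feature 0 <= 50.0" sends a sample left, otherwise right
root = Node()
root.feature_idx = 0
root.threshold = 50.0
root.left = left_leaf
root.right = right_leaf

print(is_leaf(root), is_leaf(root.left), is_leaf(root.right))  # False True True
```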

## Calculate gini
Calculates the Gini impurity of a dataset: the probability that a randomly chosen sample would be labelled incorrectly if it were given a label drawn at random from the dataset's label distribution.

### Parameters:
* `samples` (`list` of `list` or `tuple`): Each sample is expected to be a list or tuple where the label is stored at index `2`. The function assumes binary classification with class labels `0` and `1`.

### Returns:
* `float`: The Gini impurity value ranging from `0.0` (pure) to `0.5` (maximum impurity for binary classes).

### How it works:
 1. If the input `samples` list is empty, the function returns `0.0` (no impurity).
 2. It counts how many samples belong to class `0` and class `1`.
 3. It computes the proportion of each class.
 4. It applies the Gini formula:
  $$
  \text{Gini}=1 - \sum_{i=1}^{n} p_i^2
  $$
  where $p_i$ is the proportion of samples in class $i$ and $n$ is the number of classes ($n=2$ here).
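
As a worked example, the six-sample rock dataset in the script block at the end of this notebook contains three Igneous (class `1`) and three Sedimentary (class `0`) samples, so it sits at the binary maximum:

$$
\text{Gini} = 1 - \left(\left(\tfrac{3}{6}\right)^2 + \left(\tfrac{3}{6}\right)^2\right) = 1 - 0.5 = 0.5
$$

A group with three samples of one class and one of the other is less mixed, and a pure group scores zero:

$$
1 - \left(\left(\tfrac{3}{4}\right)^2 + \left(\tfrac{1}{4}\right)^2\right) = 1 - \tfrac{10}{16} = 0.375, \qquad 1 - \left(1^2 + 0^2\right) = 0
$$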




In [None]:
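def calculate_gini(samples):
  if not samples:
    return 0.0
  # Count samples per class (binary labels 0 and 1, stored at index 2)
  count=[0, 0]
  for sample in samples:
    count[int(sample[2])]+=1
  total=len(samples)
  # Gini = 1 - sum of squared class proportions
  gini=1.0
  for c in count:
    p=c/total
    gini-=p*p
  return gini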
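`calculate_gini` hard-codes two classes through `count=[0, 0]`, so a label of `2` would raise an `IndexError`. If more classes were ever needed, one option is to count labels with `collections.Counter`. The sketch below is a hypothetical variant (`gini_any_classes` is not part of this notebook) that, on binary data, matches `calculate_gini` up to floating-point rounding.

```python
from collections import Counter


def gini_any_classes(samples, label_idx=2):
  # Gini impurity for any number of classes; labels are read from label_idx
  if not samples:
    return 0.0
  counts = Counter(sample[label_idx] for sample in samples)
  total = len(samples)
  return 1.0 - sum((c / total) ** 2 for c in counts.values())


print(gini_any_classes([[70.0, 1.0, 1], [30.0, 0.1, 0]]))  # 0.5
print(gini_any_classes([[1, 1, "a"], [2, 2, "b"], [3, 3, "c"]]))  # about 0.667
```

Labels only need to be hashable here, which is why the three-class call can use strings.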

## Find best split

Finds the feature and threshold that split a dataset into two groups with the lowest weighted Gini impurity. This function is used when building decision trees for classification.

### Parameters
* `samples` (`list` of `list` or `tuple`): Each sample is expected to have feature values at index `0` and `1`, and the class label at index `2`.

### Returns:
* `tuple`:
  * `best_feature` (`int`): Index of the feature (0 or 1) that gives the best split.
  * `best_threshold` (`float`): Threshold value for the best split.
  * `best_gini` (`float`): The weighted Gini impurity of the best split (lower is better).

### How it works:
1. Initializes the best Gini score as infinity and placeholders for the best feature and threshold.
2. Iterates over each feature (in this case, feature 0 and 1).
3. For each feature, gathers unique values (used as possible thresholds) and sorts them.
4. For each threshold:
  * Splits the dataset into two groups: `left` (values <= threshold) and `right` (values > threshold).
  * Skips the threshold if either split is empty.
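  * Calculates the Gini impurity of both groups and combines them into the weighted Gini of the split, $\frac{n_L}{n}\,\text{Gini}_L + \frac{n_R}{n}\,\text{Gini}_R$, where $n_L$ and $n_R$ are the group sizes and $n$ is the total number of samples.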

5. Keeps track of the feature and threshold that produce the lowest weighted Gini impurity.
6. Returns the best feature index, threshold, and the corresponding weighted Gini score.



In [None]:
def find_best_split(samples):
  best_gini=float('inf')
  best_feature=-1      # stays -1 if no threshold yields two non-empty groups
  best_threshold=0.0
  total=len(samples)

  for feature in range(2):
    # Every distinct value of the feature is a candidate threshold
    thresholds=sorted(set(sample[feature] for sample in samples))

    for thresh in thresholds:
      left=[s for s in samples if s[feature]<=thresh]
      right=[s for s in samples if s[feature]> thresh]
      if not left or not right:
        continue

      gini_left=calculate_gini(left)
      gini_right=calculate_gini(right)
      # Weight each group's impurity by its share of the samples
      weighted_gini=(len(left)*gini_left+len(right)*gini_right)/total

      if weighted_gini< best_gini:
        best_gini=weighted_gini
        best_feature=feature
        best_threshold=thresh

  return best_feature, best_threshold, best_gini
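To make the search concrete, the sketch below scores every candidate silica threshold (feature `0`) on the six-sample rock dataset from the script block at the end of this notebook, using the size-weighted Gini of the two groups. `scan_thresholds` is a hypothetical helper written for illustration, and it re-implements the binary Gini so it runs on its own.

```python
def gini(samples):
  # Binary Gini impurity; the label (0 or 1) is at index 2
  if not samples:
    return 0.0
  count = [0, 0]
  for s in samples:
    count[int(s[2])] += 1
  return 1.0 - sum((c / len(samples)) ** 2 for c in count)


def scan_thresholds(samples, feature):
  # Return (threshold, weighted Gini) for every split with two non-empty groups
  total = len(samples)
  results = []
  for thresh in sorted(set(s[feature] for s in samples)):
    left = [s for s in samples if s[feature] <= thresh]
    right = [s for s in samples if s[feature] > thresh]
    if not left or not right:
      continue
    weighted = (len(left) * gini(left) + len(right) * gini(right)) / total
    results.append((thresh, weighted))
  return results


dataset = [
  [70.0, 1.0, 1], [55.0, 0.5, 1], [65.0, 2.0, 1],
  [30.0, 0.1, 0], [40.0, 1.5, 0], [50.0, 0.05, 0],
]

for thresh, weighted in scan_thresholds(dataset, feature=0):
  print(f"Silica <= {thresh:5.1f}: weighted Gini = {weighted:.3f}")
```

On this dataset the scan reports 0.400, 0.250, 0.000, 0.250 and 0.400 for thresholds 30, 40, 50, 55 and 65; the threshold 70 is skipped because no sample lies above it. `Silica <= 50.0` therefore separates the two classes perfectly.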

## Build Tree
Recursively builds a binary decision tree using the Gini impurity criterion. Each node is split on the feature and threshold with the lowest weighted Gini impurity. Splitting stops when the maximum depth is reached, a node holds fewer than two samples, or a node is pure.

### Parameters
* `samples` (`list` of `list` or `tuple`): Each sample must have at least three elements: feature 0, feature 1, and a label at index 2.
* `depth` (`int`): The current depth of the tree (`0` on the first call).
* `max_depth` (`int`): The maximum depth allowed for the tree.

### Returns:
* `Node`: The root node of the (sub)tree, containing either split information or a predicted label.

### How it works
1. Creates a new `Node` instance.
2. Checks stopping conditions:
  * If the maximum depth is reached.
  * If there are fewer than 2 samples.
  * If the Gini impurity is 0 (pure node).

  In any of these cases, the node becomes a leaf and is assigned the majority class label; ties go to class `0`.

3. Otherwise, the function:
  * Finds the best feature and threshold to split the samples using `find_best_split`.
  * Splits the dataset into `left_samples` and `right_samples`.
  * Recursively builds the left and right subtrees.
  * Sets the node's feature index and threshold, and attaches the child nodes.

In [None]:
def build_tree(samples, depth, max_depth):
  node=Node()

  # Majority class among the samples; ties go to class 0
  count=[0,0]
  for sample in samples:
    count[int(sample[2])]+=1
  majority=1 if count[1]>count[0] else 0

  # Stop: maximum depth reached, too few samples, or the node is already pure
  if depth>= max_depth or len(samples)<2 or calculate_gini(samples)==0.0:
    node.label=majority
    return node

  best_feature, best_threshold, _=find_best_split(samples)
  if best_feature==-1:
    # No threshold separates the samples (e.g., identical features with mixed
    # labels), so make a leaf instead of splitting on feature index -1
    node.label=majority
    return node

  left_samples=[s for s in samples if s[best_feature]<=best_threshold]
  right_samples=[s for s in samples if s[best_feature]>best_threshold]

  node.feature_idx=best_feature
  node.threshold=best_threshold
  node.left=build_tree(left_samples, depth+1, max_depth)
  node.right=build_tree(right_samples, depth+1, max_depth)
  return node
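The sketch below condenses `Node`, `calculate_gini`, `find_best_split` and `build_tree` into one self-contained cell so the stopping rules can be checked on the rock dataset. It is a re-implementation for illustration (the names `gini`, `best_split` and `build` are hypothetical), and it includes an explicit guard for the case where no threshold separates the samples, returning a majority-label leaf instead.

```python
class Node:
  def __init__(self):
    self.feature_idx, self.threshold = -1, 0.0
    self.label, self.left, self.right = -1, None, None


def gini(samples):
  # Binary Gini impurity; the label is at index 2
  if not samples:
    return 0.0
  p1 = sum(1 for s in samples if s[2] == 1) / len(samples)
  return 1.0 - p1 * p1 - (1.0 - p1) * (1.0 - p1)


def best_split(samples):
  # Lowest weighted Gini over features 0 and 1; feature -1 means no valid split
  best = (-1, 0.0, float("inf"))
  for f in range(2):
    for t in sorted(set(s[f] for s in samples)):
      left = [s for s in samples if s[f] <= t]
      right = [s for s in samples if s[f] > t]
      if left and right:
        w = (len(left) * gini(left) + len(right) * gini(right)) / len(samples)
        if w < best[2]:
          best = (f, t, w)
  return best


def build(samples, depth, max_depth):
  node = Node()
  ones = sum(1 for s in samples if s[2] == 1)
  majority = 1 if ones > len(samples) - ones else 0  # ties go to class 0
  if depth >= max_depth or len(samples) < 2 or gini(samples) == 0.0:
    node.label = majority
    return node
  f, t, _ = best_split(samples)
  if f == -1:  # identical feature values with mixed labels: no split possible
    node.label = majority
    return node
  node.feature_idx, node.threshold = f, t
  node.left = build([s for s in samples if s[f] <= t], depth + 1, max_depth)
  node.right = build([s for s in samples if s[f] > t], depth + 1, max_depth)
  return node


dataset = [
  [70.0, 1.0, 1], [55.0, 0.5, 1], [65.0, 2.0, 1],
  [30.0, 0.1, 0], [40.0, 1.5, 0], [50.0, 0.05, 0],
]

root = build(dataset, 0, 2)
print(root.feature_idx, root.threshold, root.left.label, root.right.label)  # 0 50.0 0 1
print(build(dataset, 0, 0).label)  # 0 (a 3-vs-3 tie goes to class 0)
print(build([[1.0, 1.0, 0], [1.0, 1.0, 1]], 0, 5).label)  # 0 (no valid split)
```

With `max_depth=2` the rock data is separated by the single root split, so both children stop early because they are pure.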

## Predict
Predicts the class label for a single input sample using a decision tree built with `Node` objects.

### Parameters
* `node` (`Node`): The current node in the decision tree (typically the root node when called initially).
* `sample` (`list` or `tuple`): A single input sample with feature values at indices `0` and `1`, in the same order used for training. A trailing label, if present, is ignored.

### Returns
* `int`: The predicted class label (`0` or `1`).

### How it works
1. Base case: If the current node is a leaf (i.e., `node.label !=-1`), it returns the label stored in that node.
2. Recursive case: It checks the feature at `node.feature_idx`:
  * If the sample's value at that index is less than or equal to the node's threshold, it continues down the left subtree.
  * Otherwise, it goes down the right subtree.

3. The function continues traversing until it reaches a leaf node and returns the final label.


In [None]:
def predict(node, sample):
  if node.label !=-1:
    return node.label
  if sample[node.feature_idx]<=node.threshold:
    return predict(node.left, sample)
  return predict(node.right, sample)
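Because each step only moves one level down, the recursion can also be written as a loop, which avoids Python's recursion limit on very deep trees. The sketch below is a hypothetical `predict_iterative` (not part of this notebook) that follows the same `<=`/`>` convention; it is tried on a hand-built stump that splits on silica at `50.0`.

```python
class Node:
  def __init__(self):
    self.feature_idx = -1
    self.threshold = 0.0
    self.label = -1
    self.left = None
    self.right = None


def predict_iterative(node, sample):
  # Walk down from the root until a node with a label (a leaf) is reached
  while node.label == -1:
    if sample[node.feature_idx] <= node.threshold:
      node = node.left
    else:
      node = node.right
  return node.label


# Hand-built stump: Silica (index 0) <= 50.0 -> Sedimentary (0), else Igneous (1)
root = Node()
root.feature_idx, root.threshold = 0, 50.0
root.left, root.right = Node(), Node()
root.left.label, root.right.label = 0, 1

print(predict_iterative(root, [60.0, 1.2]))  # 1
print(predict_iterative(root, [35.0, 0.2]))  # 0
```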

## Tree to DOT
Generates a Graphviz DOT representation of a decision tree for visualization purposes. Each node in the tree is represented either as a box (leaf node with class label) or an oval (internal decision node).

### Parameters:
* `node` (`Node`): The root node of the decision tree (or of the subtree being drawn, in recursive calls).
* `dot` (`graphviz.Digraph`, optional): A `Digraph` object shared across recursive calls. On the first call, leave it as `None`.

### Returns:
* `graphviz.Digraph`: A `Digraph` object representing the structure of the tree, ready for rendering or exporting.

### How it works:
1. Initial setup: If `dot` is `None`, creates a new `Digraph` instance and sets the layout direction to top-to-bottom (`rankdir='TB'`).
2. Node identification: Uses Python's `id()` to generate a unique identifier for each node.
3. Node labels:
  * Leaf nodes: Displayed as boxes with the predicted class (`Igneous` for label 1, `Sedimentary` for label 0).
  * Internal nodes: Displayed as ovals with the decision rule (e.g., `Silica <= threshold` or `Grain Size<=threshold`).
4. Edges: Adds edges from the current node to its left and right children, if they exist.
5. Recursively calls itself to continue building the tree structure.

In [None]:
def tree_to_dot(node, dot=None):
  if dot is None:
    from graphviz import Digraph
    dot=Digraph()
    dot.attr(rankdir='TB')

  node_id=str(id(node))

  if node.label != -1:
    label=f"Class: {'Igneous' if node.label==1 else 'Sedimentary'}"
    dot.node(node_id, label, shape='box')
  else:
    feature_name='Silica' if node.feature_idx==0 else 'Grain Size'
    label=f"{feature_name}\n<= {node.threshold:.2f}"  # threshold with two decimal places
    dot.node(node_id, label, shape='oval')

  if node.left:
    dot.edge(node_id, str(id(node.left)))
    tree_to_dot(node.left, dot)
  if node.right:
    dot.edge(node_id, str(id(node.right)))
    tree_to_dot(node.right, dot)

  return dot
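When the Graphviz system binaries are missing, `dot.render` fails, so a plain-text dump of the same labels can serve as a fallback. The sketch below is a hypothetical `tree_to_text` helper (not part of this notebook) that visits nodes in the same order as `tree_to_dot`, left child first, and indents each level. Like `build_tree`'s output, it assumes every internal node has both children.

```python
class Node:
  def __init__(self):
    self.feature_idx = -1
    self.threshold = 0.0
    self.label = -1
    self.left = None
    self.right = None


def tree_to_text(node, depth=0):
  # One line per node, indented two spaces per level
  indent = "  " * depth
  if node.label != -1:
    name = "Igneous" if node.label == 1 else "Sedimentary"
    return [f"{indent}Class: {name}"]
  feature_name = "Silica" if node.feature_idx == 0 else "Grain Size"
  lines = [f"{indent}{feature_name} <= {node.threshold:.2f}"]
  lines += tree_to_text(node.left, depth + 1)   # feature value <= threshold
  lines += tree_to_text(node.right, depth + 1)  # feature value > threshold
  return lines


# Hand-built stump for illustration: Silica <= 50.0
root = Node()
root.feature_idx, root.threshold = 0, 50.0
root.left, root.right = Node(), Node()
root.left.label, root.right.label = 0, 1

print("\n".join(tree_to_text(root)))
```

For the stump above this prints `Silica <= 50.00` followed by the indented lines `Class: Sedimentary` and `Class: Igneous`.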

## Script execution block
This section runs when the script is executed directly (in a notebook, `__name__` is also `"__main__"`, so the cell runs there too). It demonstrates how to build a decision tree that classifies rock types (Igneous vs. Sedimentary) from silica content and grain size, make predictions, and visualize the tree.

In [None]:
if __name__=="__main__":
  dataset=[
      [70.0, 1.0, 1], #Igneous
      [55.0, 0.5, 1], #Igneous
      [65.0, 2.0, 1], #Igneous
      [30.0, 0.1, 0], #Sedimentary
      [40.0, 1.5, 0], #Sedimentary
      [50.0, 0.05, 0] #Sedimentary
  ]

  # Each sample is structured as [Silica %, Grain size (mm), Label], where Label
  # is 1 for Igneous and 0 for Sedimentary

  root=build_tree(dataset, 0, 2) #Builds the decision tree from the dataset, starting at depth 0, with a max depth of 2

  print("Testing rock samples")
  test_samples=[
      [60.0, 1.2], #Expected: Igneous
      [35.0, 0.2], #Expected: Sedimentary
  ]

  # Predict the class label for each sample and print the result
  for sample in test_samples:
    pred=predict(root, sample)
    print(f"Silica: {sample[0]}%, Grain Size: {sample[1]}mm -> {'Igneous' if pred else 'Sedimentary'}")

  # Render the tree once, after all predictions (not once per sample)
  try:
    dot=tree_to_dot(root)
    dot.render("rock_decision_tree", format="png", cleanup=True)
    print("\nDecision tree visualization saved as 'rock_decision_tree.png'")
  except Exception as e:
    print(f"Couldn't create visualization: {e}")
    print("Make sure the graphviz Python package is installed: !pip install graphviz")
    print("Also ensure the Graphviz system binaries are installed on your system")