In [3]:
import math
from collections import Counter
import pprint
import csv

# ---- Entropy ----
def entropy(data_subset):
    """Return the Shannon entropy (in bits) of the "class" labels.

    Args:
        data_subset: iterable of records, each indexable by the key "class".

    Returns:
        The entropy as a float; 0.0 when there are no records or when all
        records share a single class.
    """
    class_counts = Counter(record["class"] for record in data_subset)
    n_records = sum(class_counts.values())

    result = 0.0
    for n_in_class in class_counts.values():
        fraction = n_in_class / n_records
        # fraction is always > 0 here, so log2 is well defined.
        result -= fraction * math.log2(fraction)
    return result

# ---- Information Gain ----
def info_gain(data_subset, attribute):
    """Return the information gain of splitting ``data_subset`` on ``attribute``.

    The gain is the entropy of the whole subset minus the size-weighted
    average entropy of the partitions induced by each distinct value of
    ``attribute``.

    Args:
        data_subset: sequence of records (supports ``len``), each indexable
            by ``attribute`` and by "class".
        attribute: key of the attribute to split on.

    Returns:
        The gain as a float; 0.0 when ``data_subset`` is empty.

    Raises:
        KeyError: if a dict record lacks ``attribute`` or "class".
    """
    total_entropy = entropy(data_subset)

    # Partition the records in a single pass instead of rescanning the whole
    # subset once per distinct value.  A dict keeps first-seen order, so the
    # floating-point accumulation below happens in the same order on every
    # run (set iteration order is not a stable order for string values).
    partitions = {}
    for record in data_subset:
        partitions.setdefault(record[attribute], []).append(record)

    total = len(data_subset)
    weighted_entropy = 0.0
    for subset in partitions.values():
        weighted_entropy += (len(subset) / total) * entropy(subset)
    return total_entropy - weighted_entropy

# ---- Majority class ----
def majority_class(data_subset):
    """Return the most frequent "class" label among the records.

    Ties are resolved by ``Counter.most_common``.  Indexing the empty result
    raises ``IndexError`` when ``data_subset`` has no records.
    """
    tally = Counter()
    for record in data_subset:
        tally[record["class"]] += 1
    top_label, _ = tally.most_common(1)[0]
    return top_label

# ---- ID3 Algorithm ----
def id3(data_subset, attributes, min_gain=1e-12):
    """Build a decision tree for ``data_subset`` using the ID3 algorithm.

    Args:
        data_subset: sequence of records, each indexable by "class" and by
            every name in ``attributes``.
        attributes: names of the attributes still available for splitting.
        min_gain: a best information gain at or below this threshold is
            treated as "no gain" and ends the recursion with a majority
            leaf.  The small positive default absorbs floating-point
            rounding noise in the entropy arithmetic; pass ``0.0`` to stop
            only on non-positive gains.

    Returns:
        Either a class label (leaf) or a nested dict of the form
        ``{attribute: {value: subtree_or_label, ...}}``.

    Raises:
        IndexError: if ``data_subset`` is empty (raised by
            ``majority_class``, which has no label to return).
    """
    labels = [record["class"] for record in data_subset]
    # Pure node: every record carries the same label.
    if len(set(labels)) == 1:
        return labels[0]
    # Nothing left to split on: fall back to the most common label.
    if not attributes:
        return majority_class(data_subset)

    gains = [(attr, info_gain(data_subset, attr)) for attr in attributes]
    # max() keeps the first attribute among equal gains.
    best_attr, best_gain = max(gains, key=lambda pair: pair[1])
    # Compare against a threshold rather than with ``== 0``: a truly
    # uninformative split can yield a gain a few ulps away from zero.
    if best_gain <= min_gain:
        return majority_class(data_subset)

    # Partition once; every partition is non-empty by construction, so no
    # empty-subset fallback is needed.
    partitions = {}
    for record in data_subset:
        partitions.setdefault(record[best_attr], []).append(record)

    # The remaining attribute list is the same for every branch.
    remaining_attrs = [a for a in attributes if a != best_attr]
    branches = {
        val: id3(subset, remaining_attrs, min_gain)
        for val, subset in partitions.items()
    }
    return {best_attr: branches}

# ---- Prediction ----
def predict(tree, sample, default_class=None):
    """Classify ``sample`` by walking ``tree`` from the root to a leaf.

    Args:
        tree: a class label (leaf) or a nested dict shaped like
            ``{attribute: {value: subtree_or_label}}``.
        sample: mapping from attribute name to value; looked up with
            ``.get``, so missing attributes read as ``None``.
        default_class: returned when the sample's value for the attribute
            being tested has no branch in the tree.

    Returns:
        The label stored at the leaf that was reached, or ``default_class``.
    """
    node = tree
    # Anything that is not a dict is a leaf holding the class label.
    while isinstance(node, dict):
        split_attr = next(iter(node))  # each internal node has one key
        branches = node[split_attr]
        observed = sample.get(split_attr)
        if observed not in branches:
            return default_class
        node = branches[observed]
    return node

if __name__ == "__main__":
    # Load the mushrooms dataset: each CSV row becomes a dict keyed by the
    # header names.  The encoding is given explicitly so the result does not
    # depend on the platform's default locale.
    with open("mushrooms.csv", newline='', encoding="utf-8") as csvfile:
        data = list(csv.DictReader(csvfile))

    # Fail with a clear message instead of an IndexError on data[0] below.
    if not data:
        raise ValueError("mushrooms.csv contains no data rows")

    # All attributes except target "class"
    attributes = list(data[0].keys())
    attributes.remove("class")

    # Build decision tree
    decision_tree = id3(data, attributes)

    print("Decision Tree:")
    pprint.pprint(decision_tree)

    # Example prediction: reuse the first training row (minus its label) as
    # the sample, falling back to the overall majority class for values the
    # tree has no branch for.
    new_sample = {attr: data[0][attr] for attr in attributes}
    prediction = predict(decision_tree, new_sample, default_class=majority_class(data))
    print(f"\nPredicted class for sample: {prediction}")


Decision Tree:
{'odor': {'a': 'e',
          'c': 'p',
          'f': 'p',
          'l': 'e',
          'm': 'p',
          'n': {'spore-print-color': {'b': 'e',
                                      'h': 'e',
                                      'k': 'e',
                                      'n': 'e',
                                      'o': 'e',
                                      'r': 'p',
                                      'w': {'habitat': {'d': {'gill-size': {'b': 'e',
                                                                            'n': 'p'}},
                                                        'g': 'e',
                                                        'l': {'cap-color': {'c': 'e',
                                                                            'n': 'e',
                                                                            'w': 'p',
                                                                            'y': 'p'}},
           