In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

In [None]:
class TreeNode:
    """One node of the SVM decision tree.

    Attributes:
        is_leaf: True when the node is terminal and carries a prediction.
        class_label: Label returned for samples reaching this node; only
            meaningful on leaves.
        svm_model: Fitted binary classifier used to route samples; only
            meaningful on internal nodes.
        left: Child for the "below" side of the boundary.
        right: Child for the "above/on" side of the boundary.
    """

    def __init__(self, is_leaf=False, class_label=None, svm_model=None):
        # Node payload, stored exactly as supplied by the caller.
        self.is_leaf, self.class_label, self.svm_model = (
            is_leaf,
            class_label,
            svm_model,
        )
        # Children start detached; whoever builds the tree attaches them.
        self.left = self.right = None

In [None]:
def generate_combinations(classes):
    """List every 1-class and 2-class grouping of ``classes``.

    Args:
        classes: Indexable sequence of class labels.

    Returns:
        A list of lists: first one single-element list per class (in input
        order), then every unordered pair ``[classes[i], classes[j]]`` with
        ``i < j``, ordered by ``i`` and then ``j``.
    """
    n_classes = len(classes)
    singles = [[label] for label in classes]
    pairs = [
        [classes[first], classes[second]]
        for first in range(n_classes)
        for second in range(first + 1, n_classes)
    ]
    return singles + pairs

In [None]:
def calculate_gini(y_left, y_right):
    """Return the size-weighted Gini impurity of a two-way split.

    Each side's impurity is ``1 - sum(p_k ** 2)`` over the label frequencies
    ``p_k`` on that side; the two values are averaged, weighted by the share
    of samples on each side.

    Args:
        y_left: Labels of the samples routed to one side (array-like).
        y_right: Labels of the samples routed to the other side (array-like).

    Returns:
        The weighted impurity as a float. An empty side contributes 0, and a
        split with no samples at all yields 0.0.
    """
    def _gini(labels):
        if labels.size == 0:
            return 0.0
        # np.unique counts any label type (negative ints, strings, ...),
        # whereas np.bincount only accepts non-negative integers.
        _, counts = np.unique(labels, return_counts=True)
        probs = counts / labels.size
        return 1.0 - np.sum(probs ** 2)

    y_left = np.asarray(y_left)
    y_right = np.asarray(y_right)
    n_left, n_right = y_left.size, y_right.size
    n_total = n_left + n_right
    if n_total == 0:
        # Nothing to weigh; avoid dividing by zero.
        return 0.0

    return (n_left / n_total) * _gini(y_left) + (n_right / n_total) * _gini(y_right)

In [None]:
def predict(node, x):
    """Classify one sample by walking the tree from ``node`` down to a leaf.

    At each internal node the sample is wrapped in a one-element batch and
    passed to that node's ``svm_model.predict``; an output of 1 descends to
    the right child, anything else to the left child.

    Args:
        node: Node to start from (normally the tree root).
        x: A single sample, in whatever form the node models accept as one
            row of a batch.

    Returns:
        The ``class_label`` stored on the leaf that is reached.
    """
    current = node
    while True:
        if current.is_leaf:
            return current.class_label
        goes_right = current.svm_model.predict([x])[0] == 1
        current = current.right if goes_right else current.left

In [None]:
def _majority_class(y):
    """Return the most frequent label in ``y`` (smallest label wins ties)."""
    labels, counts = np.unique(y, return_counts=True)
    return labels[np.argmax(counts)]


def build_tree(X, y, depth=0, max_depth=5, kernel='linear', min_samples=10):
    """Recursively grow a decision tree whose internal nodes are binary SVMs.

    At each node every 1-class and 2-class grouping of the labels present is
    tried as the positive side of a one-vs-rest problem. An SVC is fitted for
    each grouping, and the one whose predicted partition has the lowest
    weighted Gini impurity becomes the node's router. Samples predicted as 1
    go to the right subtree, the rest to the left.

    Args:
        X: Training samples, one row per sample (array-like).
        y: Class labels aligned with ``X`` (array-like).
        depth: Depth of the node being built; callers normally leave this 0.
        max_depth: Nodes at this depth are always leaves.
        kernel: Kernel name forwarded to ``SVC``.
        min_samples: Nodes with fewer samples than this become leaves.

    Returns:
        The root ``TreeNode`` of the (sub)tree.

    Raises:
        ValueError: If ``y`` is empty.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if len(y) == 0:
        raise ValueError("build_tree requires at least one training sample")

    classes = np.unique(y)

    # Stopping conditions
    if (depth >= max_depth) or (len(X) < min_samples) or (len(classes) == 1):
        return TreeNode(is_leaf=True, class_label=_majority_class(y))

    best_gini = float('inf')
    best_svm = None
    best_mask = None

    # Try all 1-class and 2-class combinations
    for class_combo in generate_combinations(classes):
        # Create binary labels: 1=selected class(es), 0=others
        y_binary = np.isin(y, class_combo).astype(int)

        # A combo covering every class present (e.g. the pair when only two
        # classes remain) labels all samples 1; SVC.fit rejects such targets.
        if y_binary.min() == y_binary.max():
            continue

        # Train SVM
        svm = SVC(kernel=kernel).fit(X, y_binary)

        # Evaluate split using Gini index
        mask = svm.predict(X) == 1

        # A model that sends everything to one side separates nothing and
        # would hand an empty dataset to one of the children.
        if mask.all() or not mask.any():
            continue

        gini = calculate_gini(y[mask], y[~mask])

        if gini < best_gini:
            best_gini, best_svm, best_mask = gini, svm, mask

    # No candidate produced a usable split: stop here with a leaf.
    if best_svm is None:
        return TreeNode(is_leaf=True, class_label=_majority_class(y))

    # Recursively build subtrees
    node = TreeNode(svm_model=best_svm)
    node.left = build_tree(
        X[~best_mask], y[~best_mask], depth + 1, max_depth, kernel, min_samples
    )
    node.right = build_tree(
        X[best_mask], y[best_mask], depth + 1, max_depth, kernel, min_samples
    )
    return node