# In [None]:
import pandas as pd
import numpy as np
import gurobipy as gp
from sklearn import tree
from collections import namedtuple

class maxFlowOptimalDecisionTreeClassifier:
    """Optimal binary classification tree trained with a flow-based MIP (Gurobi).

    Every training instance may send at most one unit of flow from the root
    down the path dictated by its feature values. The flow can only leave the
    tree through a leaf whose assigned class equals the instance's label, so
    maximizing total flow maximizes the number of correctly classified
    training instances.

    The tree is a complete binary tree of depth ``max_depth`` stored in heap
    order: branch node ``n`` has children ``2n + 1`` (taken when the selected
    feature is 0) and ``2n + 2`` (taken when it is 1). Features are assumed to
    be binary 0/1 -- TODO confirm with callers; an instance with a feature value
    other than 0/1 on its path can carry no flow in the MIP.

    Parameters
    ----------
    max_depth : int
        Depth of the complete tree.
    alpha : float
        Weight of the ``alpha * sum(b)`` penalty in the objective. NOTE(review):
        every branch node must select exactly one feature, so ``sum(b)`` is the
        constant ``len(self.B)`` and this term does not influence the solution.
    warmstart : bool
        Seed the MIP with a scikit-learn CART tree of the same depth.
    timelimit : float
        Gurobi time limit in seconds.
    output : bool
        Print the solver log and a training summary.
    """

    def __init__(self, max_depth=3, alpha=0, warmstart=True, timelimit=600, output=True):
        self.max_depth = max_depth
        self.alpha = alpha
        self.warmstart = warmstart
        self.timelimit = timelimit
        self.output = output
        self.trained = False
        self.optgap = None
        self.leaf_counts = {}

        # Complete binary tree in heap order: branch nodes first, then leaves.
        self.B = list(range(2**self.max_depth - 1))  # Branch nodes
        self.T = list(range(2**self.max_depth - 1, 2**(self.max_depth + 1) - 1))  # Leaf nodes

    def fit(self, x, y):
        """Train the tree.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Feature matrix (expected to contain 0/1 values).
        y : array-like of shape (n_samples,)
            Non-negative integer labels; the classes are ``0..max(y)``.

        Raises
        ------
        RuntimeError
            If Gurobi stops (e.g. on the time limit) without any feasible solution.
        """
        from collections import Counter

        # Positional indexing (x[i][f], y[i]) is used throughout, so normalise
        # DataFrames / lists to arrays.
        x = np.asarray(x)
        y = np.asarray(y)
        self.n, self.m = x.shape
        if self.output:
            print(f'Training data include {self.n} instances, {self.m} features.')

        self.K = list(range(int(np.max(y)) + 1))
        self.I = list(range(self.n))
        self.F = list(range(self.m))

        m, b, w = self._buildMIP(x, y)
        if self.warmstart:
            self._setStart(x, y, b, w)
        m.optimize()
        if m.SolCount == 0:
            raise RuntimeError(f'Gurobi found no feasible solution (status {m.Status}).')
        self.optgap = m.MIPGap

        # Keep the raw flow values on the object for later inspection.
        self._zval = m.getAttr('x', m._z)

        self._tree_construction(m, b, w)

        # Route each training instance through the learned tree. The flow
        # variables cannot be used for this: misclassified instances carry no
        # flow, so they would be missing from the per-leaf statistics.
        leaves = [self._leaf_of(row) for row in x]
        correct = sum(self.labels.get(leaf) == label for leaf, label in zip(leaves, y))
        accuracy = correct / len(self.I)

        # Number of training instances reaching each leaf.
        per_leaf = Counter(leaves)
        self.leaf_counts = {n: per_leaf.get(n, 0) for n in self.T}

        # Class distribution of the training instances reaching each leaf.
        self.leaf_class_distribution = {n: Counter() for n in self.T}
        for leaf, label in zip(leaves, y):
            self.leaf_class_distribution[leaf][label] += 1

        if self.output:
            print(f"Correctly classified: {correct}/{len(self.I)} ({accuracy:.2%})")

            print("\nAntal data i hvert blad:")
            for n, count in self.leaf_counts.items():
                print(f"  Leaf {n}: {count} data points")

            print("\nKlassedistribution i hvert blad:")
            for n, dist in self.leaf_class_distribution.items():
                print(f"  Leaf {n}: {dict(dist)}")

        self.trained = True

    def _buildMIP(self, x, y):
        """Build the max-flow MIP and return ``(model, b, w)``.

        Variables: ``b[n, f]`` -- branch node n tests feature f; ``w[t, k]`` --
        leaf t predicts class k; ``z[i, n]`` -- flow of instance i into node n.
        """
        m = gp.Model("StrongOCT")
        m.Params.outputFlag = self.output
        m.Params.LogToConsole = self.output
        m.Params.timelimit = self.timelimit
        m.Params.threads = 0

        b = m.addVars(self.B, self.F, vtype=gp.GRB.BINARY, name="b")
        w = m.addVars(self.T, self.K, vtype=gp.GRB.BINARY, name="w")
        z = m.addVars(self.I, self.B + self.T, vtype=gp.GRB.BINARY, name="z")

        # Each branch node selects exactly one feature.
        m.addConstrs((b.sum(n, '*') == 1 for n in self.B))

        for i in self.I:
            # Loop-invariant per instance: features that send i left / right.
            features_zero = [f for f in self.F if x[i][f] == 0]
            features_one = [f for f in self.F if x[i][f] == 1]

            # The root inflow z[i, 0] is binary (hence <= 1) and deliberately NOT
            # fixed to 1: an instance that cannot reach a leaf labelled y[i]
            # carries no flow. Fixing it to 1 would demand perfect accuracy and
            # make the model infeasible on non-separable data.
            for n in self.B:
                l, r = self._tree_children(n)
                # Flow conservation at branch nodes.
                m.addConstr(z[i, n] == z[i, l] + z[i, r])
                # Flow may only follow the branch chosen by the selected feature.
                m.addConstr(z[i, l] <= gp.quicksum(b[n, f] for f in features_zero))
                m.addConstr(z[i, r] <= gp.quicksum(b[n, f] for f in features_one))

            # Flow can leave through a leaf only if the leaf predicts y[i].
            # (z_parent >= z_leaf already follows from conservation above.)
            for n in self.T:
                m.addConstr(z[i, n] <= w[n, y[i]], name=f"label_match_{i}_{n}")

        # Each leaf assigns exactly one class.
        m.addConstrs((w.sum(n, '*') == 1 for n in self.T))

        # Total flow = number of correctly classified instances. The alpha term is
        # constant (see class docstring) and kept only for interface compatibility.
        obj = gp.quicksum(z[i, n] for i in self.I for n in self.T) - self.alpha * b.sum()
        m.setObjective(obj, gp.GRB.MAXIMIZE)

        # Attach variables/data to the model for later retrieval.
        m._b = b
        m._w = w
        m._z = z
        m._X = x
        m._Y = y

        return m, b, w

    def _leaf_of(self, row):
        """Return the index of the node reached by routing ``row`` through the tree."""
        n = 0
        while n in self.branches:
            n = 2 * n + 1 if row[self.branches[n]] == 0 else 2 * n + 2
        return n

    def predict(self, x):
        """Return predicted labels for the rows of ``x``.

        Raises
        ------
        AssertionError
            If called before ``fit``.
        """
        if not self.trained:
            raise AssertionError('Model not trained yet.')
        # Unknown leaf labels fall back to class 0.
        return np.array([self.labels.get(self._leaf_of(row), 0) for row in np.asarray(x)])

    @staticmethod
    def _tree_children(node):
        """Return ``(left, right)`` child indices of ``node`` in heap order."""
        return 2 * node + 1, 2 * node + 2

    def _tree_construction(self, m, b, w):
        """Read the solution into ``self.branches`` (node -> feature) and ``self.labels`` (leaf -> class)."""
        b_val = m.getAttr('x', b)
        w_val = m.getAttr('x', w)

        # Argmax instead of a hard threshold, so solver tolerances can never
        # leave a node without a feature/class.
        self.branches = {n: max(self.F, key=lambda f: b_val[n, f]) for n in self.B}
        self.labels = {n: max(self.K, key=lambda k: w_val[n, k]) for n in self.T}

    def _setStart(self, x, y, b, w):
        """Warm-start ``b`` and ``w`` from a scikit-learn CART tree of the same depth.

        Nodes without a usable CART rule are left unset; Gurobi completes
        partial MIP starts itself.
        """
        clf = tree.DecisionTreeClassifier(max_depth=self.max_depth)
        clf.fit(x, y)
        rules = self._getRules(clf)

        for n in self.B:
            feat = rules[n].feat
            if feat is None:
                continue
            feat = int(feat)
            for f in self.F:
                b[n, f].start = 1 if f == feat else 0

        for n in self.T:
            value = rules[n].value
            if value is None:
                continue
            # value is the per-class distribution, ordered like clf.classes_,
            # so map the argmax index back to the actual label.
            leaf_class = int(clf.classes_[int(np.argmax(value))])
            if leaf_class not in self.K:
                continue
            for k in self.K:
                w[n, k].start = 1 if k == leaf_class else 0

    def _getRules(self, clf):
        """Map every node of the complete tree to a ``Rules(feat, threshold, value)`` tuple.

        ``feat``/``threshold`` are None where the corresponding CART node is a
        leaf (or missing). ``value`` is the CART node's class distribution (one
        entry per class). When CART stops splitting above full depth, both
        subtrees below that node inherit its leaf, so the complete tree's leaves
        still receive a class.
        """
        tree_ = clf.tree_
        node_map = {0: 0}
        for n in self.B:
            mapped = node_map.get(n, -1)
            if mapped == -1 or mapped >= tree_.node_count:
                continue
            left = tree_.children_left[mapped]
            right = tree_.children_right[mapped]
            if left == -1:  # sklearn TREE_LEAF: propagate this leaf downwards
                left = right = mapped
            node_map[2*n + 1] = left
            node_map[2*n + 2] = right

        rule = namedtuple('Rules', ('feat', 'threshold', 'value'))
        rules = {}
        for n in self.B + self.T:
            mapped = node_map.get(n, -1)
            if mapped == -1 or mapped >= tree_.node_count:
                rules[n] = rule(None, None, None)
                continue
            feat = tree_.feature[mapped]
            is_split = feat >= 0
            rules[n] = rule(
                feat if is_split else None,
                tree_.threshold[mapped] if is_split else None,
                tree_.value[mapped, 0],  # recorded for leaves too (needed by _setStart)
            )
        return rules

    def get_leaf_counts(self):
        """Return dictionary of leaf node counts"""
        if not self.trained:
            raise AssertionError("Model not trained yet.")
        return self.leaf_counts