# MILP Formulation from Nathan Kallus' Paper (Problem 4)

In [1]:
import gurobipy as gp
from gurobipy import GRB
import sys
import math
import numpy as np
import pandas as pd
import random

## Model specifications
#### Since we chose to modify the formulation to a certain extent, these flags let us revert to the original model
- delta_include: If true, constraint 4c from the original formulation is enforced. If false, only the added constraint that the entries of gamma[p] sum to 1 is enforced
- different_Cp: If true, every non-leaf node has its own set of branching choices (as in the original formulation). If false, a single static set C is shared by all nodes

In [2]:
# -- This class is an alternative to solving the right/left ancestor problem --
"""
INPUT: d = depth of tree (which includes root node), so d = 2 would make a tree with {1, 2, 3}
RELEVANT FUNCTIONS:
- get_right_left: For all leaf nodes, returns its right and left ancestors in a dictionary 
                  of {(p, q): 1 or -1 if q is right or left ancestor respectively}
"""
class Tree:
    """Complete binary tree with heap-style node numbering.

    The root is node 1 and node ``n`` has children ``2n`` and ``2n + 1``.
    For a tree of depth ``d`` (the root level counts as depth 1) the
    internal nodes are ``1 .. 2**(d-1) - 1`` and the leaves are
    ``2**(d-1) .. 2**d - 1``; e.g. ``d = 2`` gives nodes ``{1, 2, 3}``.

    Attributes:
        depth: the depth ``d`` passed to the constructor.
        nodes: list of internal (non-leaf) node ids.
        leaves: list of leaf node ids.
        ancestor_rl: cache filled by ``get_right_left``.
    """

    def __init__(self, d):
        self.depth = d
        self.nodes = list(range(1, 2**(d-1)))
        self.leaves = list(range(2**(d-1), 2**d))
        self.ancestor_rl = {}

    def get_left_children(self, n):
        """Return the left child ``2n`` of internal node ``n``.

        Raises:
            Exception: if ``n`` is not an internal node of this tree.
        """
        if n in self.nodes:
            return int(2*n)
        raise Exception('Invalid node n')

    def get_right_children(self, n):
        """Return the right child ``2n + 1`` of internal node ``n``.

        Raises:
            Exception: if ``n`` is not an internal node of this tree.
        """
        if n in self.nodes:
            return int(2*n+1)
        raise Exception('Invalid node n')

    def get_parent(self, n):
        """Return ``floor(n / 2)`` for any node ``n`` of the tree.

        For the root (``n == 1``) this returns 0, which is not a node.

        Raises:
            Exception: if ``n`` is neither an internal node nor a leaf.
        """
        # Logical `or` (not bitwise `|`) so the second lookup is skipped
        # when the first already succeeds.
        if n in self.nodes or n in self.leaves:
            return int(n // 2)
        # Message now matches the one raised by the child accessors.
        raise Exception('Invalid node n')

    def get_ancestors(self, direction, n):
        """List the ancestors of ``n`` that ``n`` descends from on one side.

        Args:
            direction: ``'r'`` selects ancestors whose right subtree contains
                ``n``; any other value selects ancestors whose left subtree
                contains ``n``.
            n: node id to start from.

        Returns:
            The matching ancestors, ordered from the closest one up to the root.
        """
        want_right = direction == 'r'
        ancestors = []
        current = n
        while current != 1:
            parent = self.get_parent(current)
            if want_right:
                came_from_side = self.get_right_children(parent) == current
            else:
                came_from_side = self.get_left_children(parent) == current
            if came_from_side:
                ancestors.append(parent)
            current = parent
        return ancestors

    def get_right_left(self):
        """Map every (leaf, ancestor) pair to +1 (right) or -1 (left).

        Returns:
            ``self.ancestor_rl``, a dict ``{(p, q): 1 or -1}`` where ``p`` is a
            leaf, ``q`` one of its ancestors, and the value is 1 when ``p``
            lies in the right subtree of ``q`` and -1 when it lies in the left.
        """
        for leaf in self.leaves:
            for ancestor in self.get_ancestors('r', leaf):
                self.ancestor_rl[(leaf, ancestor)] = 1
            for ancestor in self.get_ancestors('l', leaf):
                self.ancestor_rl[(leaf, ancestor)] = -1
        return self.ancestor_rl

# Evaluating Synthetic Data

In [3]:
def open_file(file_path):
    """Read a dataset CSV and split it by column position.

    Columns 0-19 are returned as the covariate frame, columns 20-21 as the
    ``real`` frame, column 22 as the treatment series and column 24 as the
    outcome series (column 23 is not used).

    Returns:
        Tuple ``(train_X, train_t, train_y, real)``.
    """
    frame = pd.read_csv(file_path)
    by_position = frame.iloc
    covariates = by_position[:, :20]
    treatment = by_position[:, 22]
    outcome = by_position[:, 24]
    real_outcomes = by_position[:, 20:22]
    return covariates, treatment, outcome, real_outcomes

In [4]:
def datapoint_tree(node, i, test_X, test_real, test_t):
    """Route datapoint ``i`` down the tree from ``node`` and score its leaf.

    Relies on the module-level ``L`` (leaf ids), ``branching`` (node -> column
    index tested at that node), ``treatments`` (leaf -> prescribed treatment)
    and ``tree`` (child lookup).

    Returns:
        Tuple ``(difference, count_optimal, same_treatment)``: the gap between
        the best entry of row ``i`` of ``test_real`` and the entry for the
        prescribed treatment, 1/0 for whether that gap is zero, and 1/0 for
        whether the prescribed treatment equals ``test_t[i]``.
    """
    current = node
    # Descend until a leaf is reached: value <= 0 goes left, otherwise right.
    while current not in L:
        if test_X.iloc[i, branching[current]] <= 0:
            current = tree.get_left_children(current)
        else:
            current = tree.get_right_children(current)

    prescribed = treatments[current]
    best_outcome = max(test_real.iloc[i, :])
    regret = best_outcome - test_real.iloc[i, prescribed]
    is_optimal = 1 if regret == 0 else 0
    matches_observed = 1 if prescribed == test_t[i] else 0
    return regret, is_optimal, matches_observed

In [5]:
def get_metrics(test_X, test_real, test_t):
    """Aggregate ``datapoint_tree`` results over every row of ``test_X``.

    Returns:
        Tuple of the summed outcome gap, the fraction of rows whose prescribed
        treatment is optimal, and the fraction of rows whose prescribed
        treatment equals the observed one.
    """
    num_rows = len(test_X)
    # Every row is routed starting at the root (node 1).
    per_row = [datapoint_tree(1, row, test_X, test_real, test_t)
               for row in range(num_rows)]
    total_gap = sum(gap for gap, _, _ in per_row)
    optimal_hits = sum(hit for _, hit, _ in per_row)
    same_hits = sum(same for _, _, same in per_row)
    return total_gap, float(optimal_hits)/num_rows, float(same_hits)/num_rows

## Declaring variables determined a-priori

### a) Specifying Input to Model
- m treatments indexed by t {1,..., m}
- n datapoints indexed by i {(X1, T1, Y1), ..., (Xn, Tn, Yn)} 
- d: depth of decision tree
- n_min: minimum number of datapoints of each treatment in node p
- num_features
- num_cuts

In [6]:
def run_tree(depth, bert):
    """Build and solve the prescriptive-tree MILP for one tree depth.

    The training data is not passed in: the function reads the module-level
    names ``train_X``, ``train_t`` and ``train_y``.

    Args:
        depth: tree depth (converted with ``int``); the root level counts,
            so leaves are nodes ``2**(d-1) .. 2**d - 1``.
        bert: when truthy, adds the ``f`` and ``beta`` variables, their
            big-M linking constraints, and a squared-error penalty weighted
            against the ``nu`` sum in the objective (``theta = 0.5``).

    Returns:
        None. The model is optimized with a 3600 s time limit and the
        solution attribute ``X`` is printed; the model object is not returned.
    """
    # Switches kept so the formulation can be reverted to the original one.
    # NOTE(review): both non-default settings look unfinished — with
    # different_Cp=True the list C is never built, and with delta_include=True
    # the names k, delta and z used below are not defined in this function.
    delta_include = False
    different_Cp = False
    bertsimas = bert

    m = {0, 1}          # treatment labels
    n = len(train_X)    # number of datapoints
    d = int(depth)
    n_min = 0           # minimum datapoints per treatment in each leaf (4f)
    num_features = 2    # not used below
    num_cuts = 2        # not used below

    """# ---- CONSTRUCTING COMPLETE BINARY TREE ----
    # - P = number of nodes in the tree
    # - L_c = set of non-leaf nodes
    # - L = set of leaf ndoes"""

    P = 2**d
    L_c = set(range(1, 2**(d-1)))
    L = set(range(2**(d-1), P))

    # - ancestors: dictionary {leaf node: list of its ancestors, nearest first}
    ancestors = {}
    for p in L:
        ancestors[p] = [math.floor(p/(2**j)) for j in range(1, d)]

        # Alternative way of retrieving right/left ancestors
    # right_left[(p, q)] is 1 when leaf p is in the right subtree of q, -1 when in the left
    tree = Tree(d)
    right_left = tree.get_right_left()


    """
    - C[p]: finite set of cuts on node p--determined apriori {(l, theta)}
        Representation: dictionary {non-leaf node: list of (l, theta)}
        Require a list instead of set so it's ordered (indexed easily)


    From Kallus' paper, he wants us to:
    1. For each l in [d], sort data along x_l
    2a. For each non-leaf node, pick #features from [d] randomly
    2b. Set J = {1, (n-1/#cuts), ..., n-1} in decreasing order of cuts
    2c. Set Cp = {(l, midpoint between the two buckets in J) for all dimensions chosen and for all j in J}

    Make version of ALG3 to take all features"""


    if different_Cp:
        pass
    else:
        # -- BINARY COVARIATES --> Create a finite set of cuts for C for all features --
        # One cut (column index, threshold 0) per column of train_X, shared by all nodes
        C = []
        for i in range(len(train_X.columns)):
            C.append((i, 0))


    """ --- OTHER DATA --- 
    BIG M Constraints:
    - Ybar
    - Ymax
    - M

    BINARY ENCODING FOR CUTS
    - k_p: dictionary {non-leaf node p: k_p value}
    - Z_p: dictionary {non-leaf node p: k_p x |C_p| 2d matrix}
    """

    # ---- Big M Constraints ----
    # ybar shifts the outcomes so that the smallest one becomes 0
    minimum = min(train_y)

    ybar = train_y - minimum
    #data['ybar'] = ybar

    # Ymax: largest shifted outcome, used as the bound when linearizing nu
    ymax = max(ybar)

    # M
    # Largest per-treatment datapoint count, reduced by len(L) * n_min
    #treatment_counts = treatment.value_counts().to_list()
    unique, counts = np.unique(train_t, return_counts=True)
    #frequencies = numpy.asarray((unique, counts)).T
    M = np.array(counts)
    M -= len(L) * n_min
    M = max(M)
    # NOTE(review): M is a datapoint count, while each summand it bounds in
    # 4j/4k is at most ymax in absolute value — confirm M is large enough
    # when ymax > 1.

    model = gp.Model("Kallus")
    model.params.TimeLimit = 3600

    # -- VARIABLE DECLARATION --

    # -- Variables to determine: gamma and lambda --
    # 1. gamma_p = choice of cut at node p ([0, 1]^C_p) (only applies to non-leaf node)
    #       - represent with a matrix gamma (|L_c| x |C_p|)

    gamma = model.addVars(L_c, len(C), vtype=GRB.BINARY, name='gamma') 
    # This assumes gamma is binary
    #gamma = model.addVars(L_c, len(C), vtype=GRB.BINARY, name='gamma') 

    # 2. lambda_pt = choice of treatment t at node p (only applies to leaf nodes L)
    #       - represent with a matrix lamb (|L| x m)
    lamb = model.addVars(L, m, vtype=GRB.BINARY, name='lamb')


    # -- Other Variables in Formulation --
    # 1. w_ip = membership of datapoint i in node p (only applies to leaf nodes L)
    #       - represent with a matrix w (n x |L|)

    # Declared continuous in [0, 1]; the binary alternative is the commented line below
    w = model.addVars(n, L, lb=0, ub=1, name='w')
    #w = model.addVars(n, L, vtype=GRB.BINARY, name='w') # Original paper has this be a continuous variable

    # 2. mu_p = mean outcome of prescribed treatment in node p
    #       - represent with a matrix mu (|L|)
    mu = model.addVars(L, lb=0, name='mu') # define in constraint

    # 3. nu_ip = "effect" of treatment in node p by multiplying mu and w
    #       - represent with a matrix nu (n x |L|)
    nu = model.addVars(n, L, lb=0, name='nu')

    # 4. delta_p = forces only 1 choice of cut at node p
    #       - represent with a dictionary {non-leaf node p: 1d matrix of size k_p}
    #delta = model.addVars(L, k, vtype=GRB.BINARY, name='delta')

    # 5. Chi_i(gamma) = 1 if choice of cut induces datapoint i to go left on the cut gamma, 0 otherwise
    chi = model.addVars(L_c, n, vtype=GRB.BINARY, name='chi')

    if bertsimas:
        # 6. f_i (one non-negative variable per datapoint)
        f = model.addVars(n, lb=0, name='f')

        # 7. Beta_lt (one non-negative variable per leaf and treatment)
        beta = model.addVars(L, m, lb=0, name='beta')

    # Weight between the nu sum and the squared-error term in the bertsimas objective
    theta = 0.5

        # --- OBJECTIVE FUNCTION ---
    if bertsimas:
        model.setObjective(theta * gp.quicksum(nu[i, p] for i in range(n) for p in L) 
                       - (1-theta) * gp.quicksum((train_y[i] - f[i]) * (train_y[i] - f[i]) for i in range(n)), GRB.MAXIMIZE)
    else:
        model.setObjective(gp.quicksum(nu[i, p] for i in range(n) for p in L), GRB.MAXIMIZE)


    # --- CONSTRAINTS ---
    # Constraint 4c (4b is done by definition of variables)
    # NOTE(review): unreachable while delta_include is False; k, delta and z are not defined here
    if delta_include:
        for p in L_c:
            # need to do matrix multiplication somehow, but this might work?
            for j in range(k):
                model.addConstr(delta[p, j] == gp.quicksum(gamma[p, i] * z[j, i] for i in range(len(C))))

    # Additional constraint that gamma[p] adds up to 1
    # CHECKED
    for p in L_c:
        model.addConstr(gp.quicksum(gamma[p, i] for i in range(len(C))) == 1)

    # Add constraint Chi
    # chi[p, i] equals the sum of gamma[p, j] over cuts whose threshold is >= datapoint i's value in that column
    for i in range(n):
        for p in L_c:
            model.addConstr(chi[p, i] == gp.quicksum(gamma[p, j] for j in range(len(C)) if C[j][1] >= train_X.iloc[i, C[j][0]]))


    # Constraint 4d&e (Membership restriction from its ancestors) CHECKED
    # R_pq = 1 gives w <= 1 - chi (datapoint must not go left at q); R_pq = -1 gives w <= chi
    for p in L:
        A_p = ancestors[p] #index ancestors of p
        for q in A_p:
            R_pq = right_left[(p, q)]
            for i in range(n):
                model.addConstr(w[i, p] <= (1+R_pq)/2 - R_pq * chi[q, i])


    #4e CHECKED
    # Lower bound: forces w[i, p] up to 1 when every ancestor decision points to p
    for p in L:
        A_p = ancestors[p] #index ancestors of p
        for i in range(n):
            model.addConstr(w[i, p] >= 1 + gp.quicksum(-chi[q, i] for q in A_p if right_left[(p, q)] == 1)
                        + gp.quicksum(-1+chi[q, i] for q in A_p if right_left[(p, q)] == -1))


    # Constraint 4f
    # CHECKED
    for t in m:
        for p in L:
            model.addConstr(gp.quicksum(w[i, p] for i in range(n) if train_t[i] == t) >= n_min) #assuming the input comes in vector (Xi, Ti, Yi)
            # only add the datapoints that have been given treatment t

    # Constraints 4g&h (Linearization of nu)
    # CHECKED
    for p in L:
        for i in range(n):
            model.addConstr(nu[i, p] <= ymax * w[i, p])
            model.addConstr(nu[i, p] <= mu[p])
            model.addConstr(nu[i, p] >= mu[p] - ymax * (1-w[i, p]))

    # Constraint 4i (Choice of treatment applied to p)
    # CHECKED
    for p in L:
        model.addConstr(gp.quicksum(lamb[p, t] for t in m) == 1)

    # Constraint 4j&k (Consistency between lambda and mu)
    # CHECKED. There are some inconsistencies where some w don't appear, but this is because ybar is 0 (i.e. the minimum)
    for p in L:
        for t in m:
            model.addConstr(gp.quicksum(nu[i, p] - w[i, p] * ybar[i] for i in range(n) if train_t[i] == t) <= M*(1-lamb[p, t]))
            model.addConstr(gp.quicksum(nu[i, p] - w[i, p] * ybar[i] for i in range(n) if train_t[i] == t) >= M*(lamb[p, t]-1))
    #model.addConstr(lamb[2, 0] == 1)
    if bertsimas:
        # Big-M pair: f[i] must equal beta[p, t] when w[i, p] = 1 and t is datapoint i's treatment
        for i in range(n):
            for p in L:
                for t in m:
                    if train_t[i] == t:
                        model.addConstr(f[i] - beta[p, t] <= M * (1-w[i, p]))
                        model.addConstr(f[i] - beta[p, t] >= M * (w[i, p]-1))

    model.optimize()
    model.printAttr('X')


In [8]:
"""kallus = pd.DataFrame(columns=['Dataset','Depth','P(Correct Treatment)','Test Error','Test % Optimal',
                               'Test % Same Treatment', 'Train Error', 'Train % Optimal', 
                               'Train % Same Treatment', 'Tree'])
bertsimas = pd.DataFrame(columns=['Dataset','Depth','P(Correct Treatment)','Test Error','Test % Optimal',
                               'Test % Same Treatment', 'Train Error', 'Train % Optimal', 
                                  'Train % Same Treatment', 'Tree'])"""


# Load the previously saved Bertsimas results table from the working directory.
# NOTE(review): no `kallus` table is created or loaded in this cell (its
# constructor only appears inside the string literal above), yet a later cell
# appends to `kallus` when `bert` is False — confirm where it is defined.
bertsimas = pd.read_csv('bertsimas_athey_500.csv')

   Dataset  Depth  P(Correct Treatment)  Test Error  Test % Optimal  \
0        1      2                   0.5   53.227602           76.83   
1        1      3                   0.5  123.833930           66.00   
2        1      3                   0.5  123.833930           66.00   
3        2      2                   0.5   24.915258           84.31   
4        2      3                   0.5   24.915258           84.31   
5        3      2                   0.5  166.280038           58.61   
6        3      3                   0.5   57.544135           77.50   

   Test % Same Treatment  Train Error  Train % Optimal  \
0                  49.33     1.833834             83.0   
1                  49.34     2.356756             85.2   
2                  49.34     2.356756             85.2   
3                  50.46     1.220292             87.8   
4                  50.46     1.220292             87.8   
5                  49.19     5.115352             75.0   
6                  49.48 

In [8]:
# Experiment configuration: tree depth, value used in the data file names,
# dataset number, and whether to run the Bertsimas variant of the model.
d, prob, dataset, bert = 2, 0.5, 3, False

# Train/test CSVs are named data_{train|test}_enc_<prob>_<dataset>.csv under data/.
train_filepath = f'data/data_train_enc_{prob}_{dataset}.csv'
test_filepath = f'data/data_test_enc_{prob}_{dataset}.csv'
train_X, train_t, train_y, train_real = open_file(train_filepath)
run_tree(d, bert)

Changed value of parameter TimeLimit to 3600.0
   Prev: inf  Min: 0.0  Max: inf  Default: inf
Gurobi Optimizer version 9.0.3 build v9.0.3rc0 (mac64)
Optimize a model with 5515 rows, 2526 columns and 21004 nonzeros
Model fingerprint: 0xb7b8bcd1
Variable types: 2002 continuous, 524 integer (524 binary)
Coefficient statistics:
  Matrix range     [3e-03, 3e+02]
  Objective range  [1e+00, 1e+00]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 3e+02]
Presolve removed 4282 rows and 2031 columns
Presolve time: 0.06s
Presolved: 1233 rows, 495 columns, 5135 nonzeros
Variable types: 378 continuous, 117 integer (117 binary)

Root relaxation: objective 4.951259e+02, 662 iterations, 0.01 seconds

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0  495.12588    0   63          -  495.12588      -     -    0s
H    0     0                     214.3258956  495.12588   131%     - 

    w[221,3]            1 
    w[222,2]            1 
    w[223,3]            1 
    w[224,2]            1 
    w[225,3]            1 
    w[226,2]            1 
    w[227,3]            1 
    w[228,2]            1 
    w[229,3]            1 
    w[230,2]            1 
    w[231,3]            1 
    w[232,2]            1 
    w[233,3]            1 
    w[234,3]            1 
    w[235,3]            1 
    w[236,3]            1 
    w[237,3]            1 
    w[238,3]            1 
    w[239,2]            1 
    w[240,2]            1 
    w[241,2]            1 
    w[242,3]            1 
    w[243,2]            1 
    w[244,3]            1 
    w[245,2]            1 
    w[246,3]            1 
    w[247,3]            1 
    w[248,3]            1 
    w[249,3]            1 
    w[250,3]            1 
    w[251,3]            1 
    w[252,3]            1 
    w[253,2]            1 
    w[254,2]            1 
    w[255,3]            1 
    w[256,2]            1 
    w[257,3]            1 
 

    nu[23,2]     0.551399 
    nu[24,2]     0.551399 
    nu[25,3]     0.479498 
    nu[26,3]     0.479498 
    nu[27,2]     0.551399 
    nu[28,3]     0.479498 
    nu[29,3]     0.479498 
    nu[30,2]     0.551399 
    nu[31,3]     0.479498 
    nu[32,3]     0.479498 
    nu[33,3]     0.479498 
    nu[34,3]     0.479498 
    nu[35,2]     0.551399 
    nu[36,2]     0.551399 
    nu[37,2]     0.551399 
    nu[38,2]     0.551399 
    nu[39,2]     0.551399 
    nu[40,3]     0.479498 
    nu[41,3]     0.479498 
    nu[42,3]     0.479498 
    nu[43,3]     0.479498 
    nu[44,3]     0.479498 
    nu[45,3]     0.479498 
    nu[46,3]     0.479498 
    nu[47,3]     0.479498 
    nu[48,3]     0.479498 
    nu[49,3]     0.479498 
    nu[50,2]     0.551399 
    nu[51,3]     0.479498 
    nu[52,2]     0.551399 
    nu[53,3]     0.479498 
    nu[54,2]     0.551399 
    nu[55,2]     0.551399 
    nu[56,3]     0.479498 
    nu[57,3]     0.479498 
    nu[58,3]     0.479498 
    nu[59,3]     0.479498 
 

   nu[327,2]     0.551399 
   nu[328,3]     0.479498 
   nu[329,2]     0.551399 
   nu[330,2]     0.551399 
   nu[331,3]     0.479498 
   nu[332,2]     0.551399 
   nu[333,3]     0.479498 
   nu[334,2]     0.551399 
   nu[335,2]     0.551399 
   nu[336,3]     0.479498 
   nu[337,2]     0.551399 
   nu[338,2]     0.551399 
   nu[339,2]     0.551399 
   nu[340,3]     0.479498 
   nu[341,2]     0.551399 
   nu[342,2]     0.551399 
   nu[343,3]     0.479498 
   nu[344,2]     0.551399 
   nu[345,3]     0.479498 
   nu[346,3]     0.479498 
   nu[347,3]     0.479498 
   nu[348,2]     0.551399 
   nu[349,3]     0.479498 
   nu[350,3]     0.479498 
   nu[351,2]     0.551399 
   nu[352,2]     0.551399 
   nu[353,3]     0.479498 
   nu[354,2]     0.551399 
   nu[355,3]     0.479498 
   nu[356,3]     0.479498 
   nu[357,3]     0.479498 
   nu[358,2]     0.551399 
   nu[359,2]     0.551399 
   nu[360,3]     0.479498 
   nu[361,2]     0.551399 
   nu[362,2]     0.551399 
   nu[363,3]     0.479498 
 

  chi[1,314]            1 
  chi[1,315]            1 
  chi[1,316]            1 
  chi[1,319]            1 
  chi[1,321]            1 
  chi[1,322]            1 
  chi[1,325]            1 
  chi[1,326]            1 
  chi[1,327]            1 
  chi[1,329]            1 
  chi[1,330]            1 
  chi[1,332]            1 
  chi[1,334]            1 
  chi[1,335]            1 
  chi[1,337]            1 
  chi[1,338]            1 
  chi[1,339]            1 
  chi[1,341]            1 
  chi[1,342]            1 
  chi[1,344]            1 
  chi[1,348]            1 
  chi[1,351]            1 
  chi[1,352]            1 
  chi[1,354]            1 
  chi[1,358]            1 
  chi[1,359]            1 
  chi[1,361]            1 
  chi[1,362]            1 
  chi[1,367]            1 
  chi[1,368]            1 
  chi[1,374]            1 
  chi[1,375]            1 
  chi[1,379]            1 
  chi[1,380]            1 
  chi[1,382]            1 
  chi[1,387]            1 
  chi[1,389]            1 
 

In [75]:
# Hard-coded tree per depth: `branching` maps each internal node to the column
# index it tests, `treatments` maps each leaf to its prescribed treatment.
# Depths other than 2 and 3 leave both names untouched.
if d in (2, 3):
    branching, treatments = {
        2: ({1: 16}, {2: 1, 3: 1}),
        3: ({1: 10, 2: 16, 3: 0}, {4: 1, 5: 1, 6: 1, 7: 1}),
    }[d]

In [76]:
# Rebind the module-level `tree` and `L` (leaf ids) that datapoint_tree reads,
# so they match the current depth `d` before any metric is computed.
tree = Tree(d)
L = set(range(2**(d-1), 2**d))
test_X, test_t, test_y, test_real = open_file(test_filepath)
# get_metrics returns (summed outcome gap, fraction optimal, fraction equal to observed treatment)
error, pct, acc = get_metrics(test_X, test_real, test_t)
error1, pct1, acc1 = get_metrics(train_X, train_real, train_t)
# Test-set metrics first, then training-set metrics
print(error, pct, acc)
print(error1, pct1, acc1)

267.9863312540136 0.4446 0.5066
21.252969854711775 0.354 0.5


In [77]:
# Assemble one results row (fractions converted to percentages) and append it
# to the table of whichever method was run.
tree_stats = f'branching = {branching}, treatments = {treatments}'
row = [
    dataset, d, prob,
    error, pct*100, acc*100,
    error1, pct1*100, acc1*100,
    tree_stats,
]
print(row)
results_table = bertsimas if bert else kallus
results_table.loc[len(results_table)] = row

[5, 3, 0.9, 267.9863312540136, 44.46, 50.660000000000004, 21.252969854711775, 35.4, 50.0, 'branching = {1: 10, 2: 16, 3: 0}, treatments = {4: 1, 5: 1, 6: 1, 7: 1}']


In [78]:
# Show the Bertsimas results table and overwrite its CSV (row index not written).
# NOTE(review): only `bertsimas` is saved here; when `bert` is False the new
# row was appended to `kallus`, which this cell does not write — confirm intent.
print(bertsimas)
bertsimas.to_csv('bertsimas_athey_500.csv', index=False)

    Dataset  Depth  P(Correct Treatment)  Test Error  Test % Optimal  \
0         1      2                   0.5   53.227602           76.83   
1         1      3                   0.5  123.833930           66.00   
2         1      3                   0.5  123.833930           66.00   
3         2      2                   0.5   24.915258           84.31   
4         2      3                   0.5   24.915258           84.31   
5         3      2                   0.5  166.280038           58.61   
6         3      3                   0.5   57.544135           77.50   
7         4      2                   0.5  432.810829           24.05   
8         4      3                   0.5  231.607156           44.69   
9         5      2                   0.5  189.645268           55.55   
10        5      3                   0.5   36.072978           82.59   
11        1      2                   0.9  224.473595           49.78   
12        1      3                   0.9  224.473595           4

### Summary of Performance for Athey's Synthetic Data

In [247]:
# SUMMARY: KALLUS on Athey
# 20 rows: datasets 1-5 at depths 2 and 3, first for P = 0.5 and then for P = 0.9.
summary = pd.DataFrame({
    'Method': [f'Dataset {k}' for _ in range(2) for k in range(1, 6) for _ in range(2)],
    'Depth': [2, 3] * 10,
    'P(Correct Treatment)': ['0.5'] * 10 + ['0.9'] * 10,
    'Error': [
        # P = 0.5
        74.67, 74.67, 46.40, 79.86, 145, 109.21, 18.15, 18.15, 152.55, 90.47,
        # P = 0.9
        199.70, 238.67, 219.02, 239.10, 201.21, 201.21, 99.23, 88.12, 196.07, 156.84,
    ],
    '% Classified': [
        # P = 0.5
        71.97, 71.97, 78.49, 71.11, 61.22, 68.03, 88.33, 88.33, 60.41, 70.03,
        # P = 0.9
        53.35, 47.49, 49.91, 47.49, 53.42, 53.42, 71.19, 74.11, 54.12, 58.02,
    ],
})
summary.to_csv('kallus_tree_athey.csv', index=False)

In [248]:
# SUMMARY: Bertsimas on Athey
# 20 rows: datasets 1-5 at depths 2 and 3, first for P = 0.5 and then for P = 0.9.
summary = pd.DataFrame({
    'Method': [f'Dataset {k}' for _ in range(2) for k in range(1, 6) for _ in range(2)],
    'Depth': [2, 3] * 10,
    'P(Correct Treatment)': ['0.5'] * 10 + ['0.9'] * 10,
    'Error': [
        # P = 0.5
        74.67, 74.67, 46.40, 55.17, 145, 109.21, 18.15, 18.15, 152.55, 90.47,
        # P = 0.9
        199.70, 154.07, 219.02, 180.83, 201.21, 116.75, 99.23, 88.12, 196.07, 156.84,
    ],
    '% Classified': [
        # P = 0.5
        71.97, 71.97, 78.49, 76.49, 61.22, 68.03, 88.33, 88.33, 60.41, 70.03,
        # P = 0.9
        53.35, 62.17, 49.91, 53.91, 53.42, 71.59, 71.19, 74.11, 54.12, 58.02,
    ],
})
summary.to_csv('bertsimas_tree_athey.csv', index=False)


In [37]:
def datapoint_tree_avg(node, i):
    """Return the leaf that training datapoint ``i`` reaches starting at ``node``.

    Uses the module-level ``L``, ``branching``, ``tree`` and ``train_X``; a
    value <= 0 in the tested column sends the datapoint to the left child,
    anything else to the right child.
    """
    current = node
    while current not in L:
        goes_left = train_X.iloc[i, branching[current]] <= 0
        if goes_left:
            current = tree.get_left_children(current)
        else:
            current = tree.get_right_children(current)
    return current

In [None]:
# Mean training outcome per (leaf, observed treatment) pair. Keys are the leaf
# id and the treatment label written side by side, e.g. leaf 2 with
# treatment 0 is stored under 20.
summation = dict.fromkeys((20, 21, 30, 31), 0)
count = dict.fromkeys((20, 21, 30, 31), 0)

for row in range(n):
    reached_leaf = datapoint_tree_avg(1, row)
    key = int(str(reached_leaf) + str(train_t[row]))
    summation[key] += train_y[row]
    count[key] += 1

avg = {}
for key, total in summation.items():
    avg[key] = float(total) / count[key]

print(avg)

In [49]:
# Mean training outcome per leaf, counting only the datapoints whose observed
# treatment equals the treatment prescribed at the leaf they reach.
summation = dict.fromkeys((2, 3), 0)
count = dict.fromkeys((2, 3), 0)

for row in range(n):
    reached_leaf = datapoint_tree_avg(1, row)
    if train_t[row] != treatments[reached_leaf]:
        continue
    summation[reached_leaf] += train_y[row]
    count[reached_leaf] += 1

avg = {}
for leaf, total in summation.items():
    avg[leaf] = float(total) / count[leaf]

print(avg)

{2: 0.39851651102745433, 3: 0.34277517675515534}


In [59]:
# Start of a second model ("Nathan"): only variables are declared in this cell.
# NOTE(review): L_c, C, m and n are read as module-level names here; within the
# visible notebook they are only assigned inside run_tree — confirm they exist
# in the kernel before running this cell.

# N = all nodes of the tree (leaves plus non-leaf nodes)
N = L.union(L_c)
print(N)

model = gp.Model("Nathan")

# -- VARIABLE DECLARATION --

# -- Variables to determine: gamma and lambda --
# 1. gamma_p = choice of cut at node p ([0, 1]^C_p) (only applies to non-leaf node)
#       - represent with a matrix gamma (|L_c| x |C_p|)

gamma = model.addVars(L_c, len(C), vtype=GRB.BINARY, name='gamma') 
# This assumes gamma is binary
#gamma = model.addVars(L_c, len(C), vtype=GRB.BINARY, name='gamma') 

# 2. lambda_pt = choice of treatment t at node p (only applies to leaf nodes L)
#       - represent with a matrix lamb (|L| x m)
lamb = model.addVars(L, m, vtype=GRB.BINARY, name='lamb')


# -- Other Variables in Formulation --
# NOTE(review): the previous comments on c, v and w were copied from run_tree
# and described that model's w, mu and nu. Only the index sets and types below
# are certain; the intended meaning of each variable needs to be confirmed.

# 1. c: binary, indexed by (datapoint, node) over ALL nodes N
c = model.addVars(n, N, vtype=GRB.BINARY, name='c')
#w = model.addVars(n, L, vtype=GRB.BINARY, name='w') # Original paper has this be a continuous variable

# 2. v: binary, indexed by (datapoint, non-leaf node)
v = model.addVars(n, L_c, vtype=GRB.BINARY, name='v') # define in constraint

# 3. w: binary, one per datapoint (unlike run_tree's w, which is per datapoint and leaf)
w = model.addVars(n, vtype=GRB.BINARY, name='w')

{1, 2, 3}


In [None]:


"""if different_Cp:
# ---- K and Z if we followed the definition of C_p from the paper -----
    for p in L_c:
        k[p] = math.ceil(math.log2(len(C[p])))

    z = {}
    for p in L_c:
        matrix = np.zeros((k[p], len(C[p])))
        print(matrix.shape)
        for i in range(1, k[p]+1):
            for j in range(1, len(C[p])+1):
                if math.floor(j/(2**i)) % 2 == 1: # odd number
                    z[i-1, j-1] = 1
                else:
                    z[i-1, j-1] = 0
        z[p] = matrix

else:
    # ---- K and Z if we had constant C for all nodes ----
    k = math.ceil(math.log2(len(C)))
    z = np.zeros((k, len(C)))
    for i in range(1, k+1):
        for j in range(1, len(C)+1):
            if math.floor(j/(2**i)) % 2 == 1: # odd number
                z[i-1, j-1] = 1
            else:
                z[i-1, j-1] = 0
print(z)"""