In [1]:
import numpy as np
import gurobipy as gb
from itertools import combinations

In [2]:
# Control samples are indexed 0 .. n_prime - 1, so n_prime is one more than
# the largest index that appears in any bucket.
def extract_n_prime(L_prime, k):
    """Return n_prime, the number of control samples.

    Parameters
    ----------
    L_prime : sequence of sequences of index arrays
        ``L_prime[p][i]`` holds the control-sample indices falling in bucket
        ``i`` of covariate ``p``.
    k : list of int
        Number of buckets per covariate (``k[p]`` buckets for covariate ``p``).

    Returns
    -------
    int
        One more than the largest index found in any bucket (0 if every
        bucket is empty).
    """
    largest = -1
    # Iterate over the len(k) covariates rather than a module-level P, so the
    # function does not depend on a later notebook cell having run.
    for p, k_p in enumerate(k):
        for i in range(k_p):
            bucket = L_prime[p][i]
            # Empty buckets are legal; Python's max() would raise on them.
            if len(bucket) > 0:
                largest = max(largest, int(np.max(bucket)))
    return largest + 1

def extract_n(l):
    """Return n: the number of treatment samples, which is also |S|.

    ``l[p]`` holds the treatment counts per bucket of covariate ``p``; all
    covariates are expected to sum to the same total, so reading the first
    one suffices.
    """
    total = 0
    for bucket_count in l[0]:
        total += bucket_count
    return int(total)

def extract_k(l):
    """Return, as a list, how many buckets each covariate has."""
    return [len(per_bucket_counts) for per_bucket_counts in l]

In [14]:
def compute_imbalance(map_L_prime_to_value, S, l):
    """Total absolute imbalance between a candidate control subset and the target counts.

    Parameters
    ----------
    map_L_prime_to_value : list of int arrays
        ``map_L_prime_to_value[p][j]`` is the bucket of covariate ``p`` that
        control sample ``j`` belongs to.
    S : iterable of int
        Indices of the selected control samples.
    l : sequence of arrays
        ``l[p][b]`` is the target (treatment) count for bucket ``b`` of
        covariate ``p``.

    Returns
    -------
    number
        ``sum_p sum_b |#{j in S : bucket_p(j) = b} - l[p][b]|``.
    """
    # Convert once (not per covariate); dtype=int keeps an empty S usable as an index.
    selected = np.asarray(list(S), dtype=int)
    total = 0
    # The number of covariates comes from l itself, not from a global P.
    for p, target in enumerate(l):
        # minlength makes buckets with no selected sample count as 0.
        counts = np.bincount(map_L_prime_to_value[p][selected], minlength=len(target))
        total += np.sum(np.abs(counts - target))
    return total

def brute_force(l, L_prime):
    """Exhaustively find every size-n control subset with minimum imbalance.

    Enumerates all C(n_prime, n) subsets, so it is only usable on tiny
    instances (it serves as a ground truth for the MIP formulations).

    Parameters
    ----------
    l : sequence of arrays
        ``l[p][b]`` is the treatment count for bucket ``b`` of covariate ``p``.
    L_prime : sequence of sequences of index arrays
        ``L_prime[p][i]`` holds the control indices in bucket ``i`` of covariate ``p``.

    Returns
    -------
    (list of tuple, number)
        All minimising subsets (as index tuples) and the minimum imbalance.
        If no size-n subset exists, the list is empty and the minimum is inf.

    Raises
    ------
    ValueError
        If some covariate's buckets do not cover every control sample.
    """
    k = extract_k(l)
    n_prime = extract_n_prime(L_prime, k)
    n = extract_n(l)

    # map_L_prime_to_value[p][j] = bucket of covariate p containing control j.
    map_L_prime_to_value = []
    for p, k_p in enumerate(k):
        # -1 marks "unassigned"; np.empty would leave garbage bucket ids and
        # silently corrupt the imbalance computation.
        bucket_of = np.full(n_prime, -1, dtype=int)
        for i in range(k_p):
            bucket_of[L_prime[p][i]] = i
        if (bucket_of < 0).any():
            raise ValueError(
                'covariate {} does not assign every control sample to a bucket'.format(p))
        map_L_prime_to_value.append(bucket_of)

    # inf (not a magic large number) guarantees the first subset becomes the incumbent.
    best = float('inf')
    best_subsets = []
    for S in combinations(range(n_prime), n):
        imbalance = compute_imbalance(map_L_prime_to_value, S, l)
        if imbalance < best:
            best = imbalance
            best_subsets = [S]
        elif imbalance == best:
            best_subsets.append(S)
    return best_subsets, best

In [4]:
def compute_A(L_prime, k, n_prime):
    """Build the 0/1 bucket-membership matrix A of shape (sum(k), n_prime).

    Row ``sum(k[:p]) + i`` of A has a 1 in column ``j`` exactly when control
    sample ``j`` lies in bucket ``i`` of covariate ``p``; the ``k[p]`` rows of
    each covariate are stacked consecutively in covariate order.

    Parameters
    ----------
    L_prime : sequence of sequences of index arrays
        ``L_prime[p][i]`` holds the control indices in bucket ``i`` of covariate ``p``.
    k : list of int
        Number of buckets per covariate.
    n_prime : int
        Number of control samples (columns of A).

    Returns
    -------
    numpy.ndarray of int
    """
    A = np.zeros((sum(k), n_prime), dtype=int)
    row = 0
    # Loop over len(k) covariates instead of relying on a global P.
    for p, k_p in enumerate(k):
        for i in range(k_p):
            A[row, L_prime[p][i]] = 1
            row += 1
    return A

def min_imbalance_solver(l, L_prime, verbose=False):
    """Solve the minimum-imbalance matching MIP (constraints 1a-1e) with Gurobi.

    Chooses exactly n control samples (z binary) minimising
    ``sum |A z - l|``, linearised with auxiliary variables y >= |A z - l|.

    Parameters
    ----------
    l : sequence of arrays
        ``l[p][b]`` is the treatment count for bucket ``b`` of covariate ``p``.
    L_prime : sequence of sequences of index arrays
        ``L_prime[p][i]`` holds the control indices in bucket ``i`` of covariate ``p``.
    verbose : bool
        If True, print whether Gurobi reported an optimal solution.

    Returns
    -------
    (numpy.ndarray, float)
        The solution values of z and the objective value ``sum(y)``.
    """
    min_imbalance = gb.Model()
    min_imbalance.modelSense = gb.GRB.MINIMIZE
    min_imbalance.setParam('outputFlag', 0)

    n = extract_n(l)
    k = extract_k(l)
    # Stack the per-covariate target counts to line up with the rows of A.
    l_flat = np.concatenate(l)
    n_prime = extract_n_prime(L_prime, k)

    # 1e
    z = min_imbalance.addMVar(n_prime, vtype=gb.GRB.BINARY)
    y = min_imbalance.addMVar(sum(k))

    A = compute_A(L_prime, k, n_prime)

    # 1b and 1c together give y >= |A z - l| elementwise
    min_imbalance.addConstr(A @ z - l_flat <= y)
    min_imbalance.addConstr(l_flat - A @ z <= y)
    # 1d: select exactly n control samples
    min_imbalance.addConstr(sum(z) == n)

    # 1a
    min_imbalance.setObjective(sum(y))

    min_imbalance.optimize()

    if verbose:
        # Named constant instead of the bare status code 2.
        if min_imbalance.status == gb.GRB.OPTIMAL:
            print('OK')
        else:
            print('Bad things happened')
    return z.x, sum(y.x)

In [5]:
def min_imbalance_solver_sec3(l, L_prime, verbose=False):
    """Solve the two-covariate formulation (constraints 2a-2f) with Gurobi.

    For each bucket row r, ``(A z)[r] + d[r] - e[r] = l[r]``, so e is the
    surplus and d the deficit of selected controls relative to the target;
    the objective minimises their total.

    Parameters
    ----------
    l : sequence of exactly two arrays
        ``l[p][b]`` is the treatment count for bucket ``b`` of covariate ``p``.
    L_prime : sequence of sequences of index arrays
        ``L_prime[p][i]`` holds the control indices in bucket ``i`` of covariate ``p``.
    verbose : bool
        If True, print whether Gurobi reported an optimal solution.

    Returns
    -------
    (numpy.ndarray, float)
        The solution values of z and the objective ``sum(e) + sum(d)``.
    """
    min_imbalance = gb.Model()
    min_imbalance.modelSense = gb.GRB.MINIMIZE
    min_imbalance.setParam('outputFlag', 0)

    k = extract_k(l)
    P = len(k)
    # Stack the per-covariate target counts to line up with the rows of A.
    l_flat = np.concatenate(l)
    n_prime = extract_n_prime(L_prime, k)

    assert P == 2

    # 2f
    z = min_imbalance.addMVar(n_prime, vtype=gb.GRB.BINARY)
    # 2e
    e = min_imbalance.addMVar(sum(k), lb=0.0)
    d = min_imbalance.addMVar(sum(k), lb=0.0)

    A = compute_A(L_prime, k, n_prime)

    for p in range(P):
        # rows of A / entries of l that belong to covariate p
        bottom_index = sum(k[:p])
        top_index = bottom_index + k[p]
        sl = slice(bottom_index, top_index)
        # 2b/2c
        min_imbalance.addConstr(A[sl] @ z + d[sl] - e[sl] == l_flat[sl])

    # 2d: summing 2b over covariate 0 gives sum(z) = n - sum(d_0) + sum(e_0),
    # so this constraint is what fixes |S| = n (assuming covariate 0's
    # buckets partition the control samples).
    min_imbalance.addConstr(sum(e[:k[0]]) - sum(d[:k[0]]) == 0)

    # 2a
    min_imbalance.setObjective(sum(e) + sum(d))

    min_imbalance.optimize()

    if verbose:
        # Named constant instead of the bare status code 2.
        if min_imbalance.status == gb.GRB.OPTIMAL:
            print('OK')
        else:
            print('Bad things happened')
    return z.x, sum(e.x) + sum(d.x)

In [16]:
# ----- VARIABLES -------
n_prime = 15
verbose = False
# Seed the random bucket assignment so "Restart & Run All" reproduces the outputs.
SEED = 0
rng = np.random.default_rng(SEED)

# l[p][b] is the number of treatment samples in bucket b of covariate p;
# every covariate must sum to n, the size of the treatment sample.
l = (np.array([5,3,5]), np.array([4,4,5,0,0]))
for li in l:
    assert sum(li) == sum(l[0])

# L_prime_sizes[p][b] is the number of control samples in bucket b of covariate p.
L_prime_sizes = [(10,2,3), (2,3,3,2,5)]
for Lp_prime_sizes in L_prime_sizes:
    assert n_prime == sum(Lp_prime_sizes)

for lp, Lp_prime_size in zip(l, L_prime_sizes):
    assert len(lp) == len(Lp_prime_size)
# -----------------------

# Kept at module level: helpers defined earlier in the notebook may read P as a global.
P = len(l)
L_prime = []
for p in range(P):
    # Randomly permute the control indices and cut the permutation into
    # consecutive buckets of the requested sizes.
    choice = rng.permutation(n_prime)
    indexes = np.cumsum(L_prime_sizes[p])[:-1]
    L_prime.append(np.split(choice, indexes))

if verbose:
    for p in range(P):
        print(p, L_prime[p])

result_sec1 = min_imbalance_solver(l, L_prime)
result_sec3 = min_imbalance_solver_sec3(l, L_prime)
print(result_sec1)
print(result_sec3)

S_m, m = brute_force(l, L_prime)
print('Min is {}'.format(m))
if verbose:
    for S in S_m:
        print(S)

# Both MIP formulations must reach the exhaustive-search optimum.
assert np.isclose(result_sec1[1], m), (result_sec1[1], m)
assert np.isclose(result_sec3[1], m), (result_sec3[1], m)

(array([1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.]), 16.0)
(array([1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.]), 16.0)
Min is 16
