In [2]:
from functools import partial
import numpy as np
import scipy
import scipy.optimize
from tqdm import tqdm
import data
# Load everything the notebook works on from the project-local `data` module.
# As used further down in this notebook: `violations` is called on a sample vector and its
# first element is taken as the score; `vlow`/`vhigh` are passed to the LP as lower/upper
# bounds; `deviant_taxonomy_np` is iterated column by column (via `.T`); `nim_np` is the
# matrix multiplied with the candidate x.
# NOTE(review): `deviant_taxonomy` and `nim` are presumably the labelled (DataFrame)
# counterparts of the `*_np` arrays - confirm against data.preprocess().
violations, vlow, vhigh, deviant_taxonomy_np, nim_np, deviant_taxonomy, nim = data.preprocess()

NORMAL     1414
OTHER      1271
DEVIANT     142
Name: group, dtype: int64
NIM Amino Acids Shape: (400, 7)
NIM Amino Acids D Shape: (400, 6)
NIM Sugars Shape: (400, 56)
NIM Vitamins Shape: (400, 11)
nim                                                        Trp     His     Pro  \
taxonomy                                                                     
Faecalibacterium prausnitzii                        0.9377  0.4771  0.0000   
Phocaeicola vulgatus                                0.0000  0.0000  0.1089   
Prevotella copri                                    0.0074  0.0000  0.0077   
Bacteroides uniformis                               0.0000  0.0000  0.0000   
[Eubacterium] rectale                               0.0000  0.0000  0.0000   
...                                                    ...     ...     ...   
Bacteroides caccae/Bacteroides intestinalis/All...  1.0000  1.0000  1.0000   
Ruthenibacterium lactatiformans/Fournierella ma...  1.0000  0.2000  0.0000   
Stenotrophomonas g

In [3]:
# Extra definitions.
# nvars is the number of columns of nim_np, i.e. the number of entries of the decision
# vector x. The LP solution vector also carries slack variables after x, so the search
# helpers use nvars to slice x out of it (`r.x[:nvars]`).
nvars = nim_np.shape[1]

In [4]:
def convert_problem_into_lp(u, v_low, v_high, nim, sparsity_penalty, constraint_violation_penalty):
    """Build the LP that softly enforces ``v_low <= nim @ x + u <= v_high`` with a cheap ``x``.

    The LP variables are laid out as ``[x (m entries), y_high (n entries), y_low (n entries)]``.
    Every bound gets its own slack variable, so the constraints are::

         nim @ x - y_high <=  v_high - u      (upper bounds, rows 0 .. n-1)
        -nim @ x - y_low  <= -v_low  + u      (lower bounds, rows n .. 2n-1)

    and the objective charges ``sparsity_penalty`` per unit of ``x`` and
    ``constraint_violation_penalty`` per unit of slack.

    Args:
        u: length-n vector that is shifted by ``nim @ x``.
        v_low: length-n vector of lower bounds.
        v_high: length-n vector of upper bounds.
        nim: (n, m) coefficient matrix.
        sparsity_penalty: objective coefficient of every entry of ``x``.
        constraint_violation_penalty: objective coefficient of every slack variable.

    Returns:
        Tuple ``(A_ub, b_ub, c)`` with shapes ``(2n, m + 2n)``, ``(2n,)`` and ``(m + 2n,)``,
        in the form expected by ``scipy.optimize.linprog(c, A_ub=..., b_ub=...)``.

    Raises:
        AssertionError: if ``u``, ``v_low`` or ``v_high`` do not have ``n`` entries.
    """
    nim = np.asarray(nim, dtype=float)
    u = np.asarray(u, dtype=float)
    v_low = np.asarray(v_low, dtype=float)
    v_high = np.asarray(v_high, dtype=float)
    n, m = nim.shape
    assert u.shape[0] == v_low.shape[0] == v_high.shape[0] == n

    # Assemble the constraint matrix block-wise instead of row by row in Python: this
    # function runs once per LP solve inside the search loops, and the list-based
    # construction was quadratic interpreter-level work in n.
    own_slack = -np.eye(n)          # each row subtracts exactly its own slack variable
    other_slack = np.zeros((n, n))  # ...and ignores the slacks of the opposite bound
    upper_rows = np.hstack([nim, own_slack, other_slack])
    lower_rows = np.hstack([-nim, other_slack, own_slack])
    constraint_coef = np.vstack([upper_rows, lower_rows])

    bias_v = np.concatenate([v_high - u, -v_low + u])

    c = np.concatenate([
        np.full(m, sparsity_penalty, dtype=float),
        np.full(2 * n, constraint_violation_penalty, dtype=float),
    ])
    return constraint_coef, bias_v, c

# Same binding as in the "extra defs" cell (number of columns of nim_np), repeated here -
# presumably so this cell can be run on its own. grid_search and binary_search below read
# this module-level name to slice x out of the LP solution vector.
nvars = nim_np.shape[1]

def grid_search(k_sparse, sparsity_low, sparsity_high, sparsity_delta, violation_low, violation_high, violation_delta, lp_fn, eval_fn, n_vars=None):
    """Exhaustively try a grid of (sparsity_penalty, constraint_violation_penalty) pairs.

    For every pair the LP is solved through ``lp_fn``; solutions whose ``x`` part has at
    most ``k_sparse`` non-zero entries are scored with ``eval_fn`` and the lowest score wins.
    Each axis is sampled at ``low + i * delta`` for ``i`` in ``range(chunks)`` with
    ``chunks = int((high - low) / delta + 0.5)``, so ``high`` itself is not visited.

    Args:
        k_sparse: maximum number of non-zero entries allowed in ``x``.
        sparsity_low, sparsity_high, sparsity_delta: grid for the sparsity penalty.
        violation_low, violation_high, violation_delta: grid for the violation penalty.
        lp_fn: callable accepting the keywords ``sparsity_penalty`` and
            ``constraint_violation_penalty`` and returning an object with ``success``
            and ``x`` attributes (e.g. the result of ``scipy.optimize.linprog``).
        eval_fn: callable mapping a candidate ``x`` to a score (lower is better).
        n_vars: number of leading entries of ``r.x`` that make up ``x``. Defaults to the
            module-level ``nvars`` so existing calls behave as before.

    Returns:
        ``(score, sparsity_penalty, violation_penalty, x)`` for the best candidate, or the
        sentinel ``(1e9,)`` when no grid point yields a successful, sparse-enough solution.
    """
    if n_vars is None:
        # Backward-compatible default: the notebook-level variable count.
        n_vars = nvars
    best_answer = (1e9, )
    sparsity_chunks = int((sparsity_high - sparsity_low) / sparsity_delta + 0.5)
    violation_chunks = int((violation_high - violation_low) / violation_delta + 0.5)
    for sparsity_i in range(sparsity_chunks):
        sparsity_p = sparsity_low + sparsity_i * sparsity_delta
        for violation_i in range(violation_chunks):
            violation_p = violation_low + violation_i * violation_delta
            r = lp_fn(sparsity_penalty=sparsity_p, constraint_violation_penalty=violation_p)
            if not r.success:
                continue
            x = r.x[: n_vars]
            if len(x.nonzero()[0]) > k_sparse:
                continue
            candidate = (eval_fn(x), sparsity_p, violation_p, x)
            # Rank on the scalar prefix only, so the ndarray in the last slot can never
            # take part in a tuple comparison (which would raise on an ambiguous truth value).
            if candidate[:3] < best_answer[:3]:
                best_answer = candidate
    return best_answer

def binary_search(k_sparse, sparsity_low, sparsity_high, sparsity_delta, violation_low, violation_high, violation_delta, lp_fn, eval_fn, n_vars=None):
    """Nested bisection over the violation penalty (outer) and sparsity penalty (inner).

    The search relies on two monotonicity assumptions:
      * a higher sparsity penalty makes it more likely that ``x`` is sparse enough (valid);
      * a higher violation penalty gives a better solution.
    So the inner loop looks for the lowest sparsity penalty that is still valid, and the
    outer loop looks for the highest violation penalty for which some valid ``x`` exists.
    Every valid candidate met along the way is scored and the lowest score is kept.

    Args:
        k_sparse: maximum number of non-zero entries allowed in ``x``.
        sparsity_low, sparsity_high: bracket for the sparsity penalty.
        sparsity_delta: the inner bisection stops once the bracket is no wider than this.
        violation_low, violation_high: bracket for the violation penalty.
        violation_delta: the outer bisection stops once the bracket is no wider than this.
        lp_fn: callable accepting the keywords ``sparsity_penalty`` and
            ``constraint_violation_penalty`` and returning an object with ``success``
            and ``x`` attributes (e.g. the result of ``scipy.optimize.linprog``).
        eval_fn: callable mapping a candidate ``x`` to a score (lower is better).
        n_vars: number of leading entries of ``r.x`` that make up ``x``. Defaults to the
            module-level ``nvars`` so existing calls behave as before.

    Returns:
        ``(score, sparsity_penalty, violation_penalty, x)`` for the best candidate, or the
        sentinel ``(1e9,)`` when no probed pair yields a successful, sparse-enough solution.
    """
    if n_vars is None:
        # Backward-compatible default: the notebook-level variable count.
        n_vars = nvars
    best_answer = (1e9, )
    while violation_low < violation_high - violation_delta:
        violation_p = (violation_low + violation_high) / 2
        s_low = sparsity_low
        s_high = sparsity_high
        found = False
        while s_low < s_high - sparsity_delta:
            sparsity_p = (s_low + s_high) / 2
            r = lp_fn(sparsity_penalty=sparsity_p, constraint_violation_penalty=violation_p)
            x = r.x[: n_vars] if r.success else None
            if x is not None and len(x.nonzero()[0]) <= k_sparse:
                candidate = (eval_fn(x), sparsity_p, violation_p, x)
                # Rank on the scalar prefix only, so the ndarray in the last slot can never
                # take part in a tuple comparison.
                if candidate[:3] < best_answer[:3]:
                    best_answer = candidate
                # Valid: try a weaker sparsity penalty.
                s_high = sparsity_p
                found = True
            else:
                # Infeasible or not sparse enough: need a stronger sparsity penalty.
                s_low = sparsity_p
        if found:
            # Some valid x exists at this violation penalty: push it higher.
            violation_low = violation_p
        else:
            violation_high = violation_p
    return best_answer

def solve_problem_with_lp(u, v_low, v_high, nim, sparsity_penalty, constraint_violation_penalty):
    """Build the LP for the given inputs and solve it with SciPy's HiGHS backend.

    All arguments are forwarded unchanged to ``convert_problem_into_lp``. No ``bounds``
    are passed, so ``linprog``'s default variable bounds apply.

    Returns:
        The result object of ``scipy.optimize.linprog``.
    """
    lp_matrix, lp_rhs, lp_cost = convert_problem_into_lp(
        u, v_low, v_high, nim, sparsity_penalty, constraint_violation_penalty
    )
    return scipy.optimize.linprog(lp_cost, A_ub=lp_matrix, b_ub=lp_rhs, method='highs')

In [5]:
def solve_problem_with_random(nvars, k_sparse):
    """Draw a random vector of length ``nvars`` with ``k_sparse`` randomly placed entries.

    Positions are sampled WITHOUT replacement so the requested number of distinct entries
    is actually filled; sampling with replacement lets duplicate indices overwrite each
    other and silently yields fewer active entries. Values are uniform on [0, 1).

    Args:
        nvars: length of the returned vector.
        k_sparse: number of entries to fill; clamped to ``nvars`` when larger.

    Returns:
        1-D float array of length ``nvars`` that is zero outside the chosen positions.
    """
    n_active = min(k_sparse, nvars)
    indices = np.random.choice(nvars, n_active, replace=False)
    values = np.random.uniform(0, 1, n_active)
    x = np.zeros(nvars)
    x[indices] = values
    return x

def random_trial(k_sparse, max_trials, sample_fn, eval_fn):
    """Random-search baseline: sample candidates, rescale each one, keep the best score.

    For each of ``max_trials`` trials one vector is drawn with ``sample_fn(k_sparse=...)``
    and evaluated at the 20 scalings ``step / 10`` for ``step`` in 0..19 (0.0 up to 1.9).

    Returns:
        ``(score, step, trial, scaled_x)`` of the lowest-scoring candidate, or the sentinel
        ``(1e9,)`` if nothing scored below it (including ``max_trials == 0``).
    """
    incumbent = (1e9, )
    for trial_idx in range(max_trials):
        direction = sample_fn(k_sparse=k_sparse)
        for step in range(20):
            scaled = direction * step / 10
            contender = (eval_fn(scaled), step, trial_idx, scaled)
            # Only a strictly smaller tuple replaces the incumbent.
            if contender < incumbent:
                incumbent = contender
    return incumbent

In [6]:
# Maximum number of non-zero entries allowed in a candidate x.
k_sparse = 10
# One entry per processed sample: violations(sample)[0] before any correction, and the
# best score found by the search for that sample.
violations_before = []
violations_after = []
def eval_fn(x):
    """Score a candidate x as ``violations(nim_np @ x + deviant_sample)[0]``.

    NOTE: ``deviant_sample`` is not a parameter; it is the module-level loop variable of
    the ``for`` loop below, so this function always scores against the sample currently
    being processed and must not be called before that loop has started.
    """
    # x is lifted to a column vector for the matrix product and squeezed back to 1-D.
    return violations((nim_np @ x[:, np.newaxis]).squeeze(1) + deviant_sample)[0]
# Iterating over the transpose: every column of deviant_taxonomy_np is one sample.
for deviant_sample in tqdm(deviant_taxonomy_np.T):
    # linear_programming, can use either binary_search or grid search
    # lp_fn still expects sparsity_penalty / constraint_violation_penalty as keywords.
    lp_fn = partial(solve_problem_with_lp, u=deviant_sample, v_low=vlow, v_high=vhigh, nim=nim_np)
    # Both penalties are bisected over [0, 1] down to a bracket width of 0.2.
    best_answer = binary_search(k_sparse, 0, 1, 0.2, 0, 1, 0.2, lp_fn, eval_fn)

    # random search
    # (alternative baseline, disabled; uncomment the two lines below to use it instead)
#     sample_fn = partial(solve_problem_with_random, nvars=nvars)
#     best_answer = random_trial(k_sparse, 1000, sample_fn, eval_fn)

    violations_before.append(violations(deviant_sample)[0])
    # NOTE(review): if the search found no valid candidate, best_answer is the sentinel
    # (1e9,) and 1e9 is appended here, which would dominate the mean printed below.
    violations_after.append(best_answer[0])
# Mean score over all samples before vs. after the correction.
print(np.mean(violations_before), np.mean(violations_after))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 142/142 [11:40<00:00,  4.94s/it]

11.76056338028169 3.8028169014084505



