### Imports

In [1]:
import pandas as pd
import numpy as np

### Datasets

In [2]:
# Load the sample metadata, keeping only the sample identifier and its group label
metadata = pd.read_csv(
    'data/metadata.csv',
    usecols=['sample-id', 'group'],
)
metadata

Unnamed: 0,sample-id,group
0,ERR1072630,OTHER
1,ERR1072633,OTHER
2,ERR1072638,OTHER
3,ERR1072639,OTHER
4,ERR1072646,NORMAL
...,...,...
2822,ERR2092593,NORMAL
2823,ERR2092610,DEVIANT
2824,ERR2092680,NORMAL
2825,ERR2092729,NORMAL


In [3]:
# Number of samples per group label
metadata.loc[:, 'group'].value_counts()

NORMAL     1414
OTHER      1271
DEVIANT     142
Name: group, dtype: int64

In [4]:
# Load the four NIM tables; the first CSV column becomes the index of each frame.
# Only the first 7 columns of the amino-acid table are kept. Selecting them
# positionally replaces the previous `drop(<DataFrame slice>, axis=1)` call,
# which only worked because iterating a DataFrame yields its column labels.
nim_aminoacids = pd.read_csv('data/nim-aminoacids_400.csv', index_col=0).iloc[:, :7]
nim_aminoacidsD = pd.read_csv('data/nim-aminoacidsD_400.csv', index_col=0)
nim_sugars = pd.read_csv('data/nim-sugars_400.csv', index_col=0)
nim_vitamins = pd.read_csv('data/nim-vitamins_400.csv', index_col=0)

# Report (rows, columns) of every table as a quick sanity check
print(f"NIM Amino Acids Shape: {nim_aminoacids.shape}")
print(f"NIM Amino Acids D Shape: {nim_aminoacidsD.shape}")
print(f"NIM Sugars Shape: {nim_sugars.shape}")
print(f"NIM Vitamins Shape: {nim_vitamins.shape}")

NIM Amino Acids Shape: (400, 7)
NIM Amino Acids D Shape: (400, 6)
NIM Sugars Shape: (400, 56)
NIM Vitamins Shape: (400, 11)


In [5]:
# Inner-join the four NIM tables on 'taxonomy' into one wide frame, then add 1
# to every entry. Later cells multiply these values element-wise with abundances
# (presumably turning a relative impact into a multiplicative factor - confirm).
nim = (
    nim_aminoacids
    .merge(nim_aminoacidsD, on='taxonomy', how='inner')
    .merge(nim_sugars, on='taxonomy', how='inner')
    .merge(nim_vitamins, on='taxonomy', how='inner')
    .add(1)
)
nim

Unnamed: 0_level_0,Trp,His,Pro,Leu,Arg,Ile_Val,Tyr,Thr_D,Trp_D,His_D,...,B2,B3,B5,B6,B7,B9,B12,Q,Lipoate,K
taxonomy,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Faecalibacterium prausnitzii,1.9377,1.4771,1.0000,1.0,1.0623,1.0,1.0,1.0000,1.0,1.0623,...,1.4771,2.0,2.0,2.0000,2.0000,1.9377,1.0623,2.0000,2.0000,2.0000
Phocaeicola vulgatus,1.0000,1.0000,1.1089,1.0,1.0000,1.0,1.0,1.8911,1.0,1.8911,...,1.0000,1.0,1.0,1.1089,1.1089,1.0000,1.1089,1.0000,1.1089,1.0000
Prevotella copri,1.0074,1.0000,1.0077,1.0,1.0000,1.0,1.0,1.0000,1.0,1.0000,...,1.0000,1.0,1.0,1.0077,2.0000,1.0077,2.0000,1.3099,2.0000,1.0074
Bacteroides uniformis,1.0000,1.0000,1.0000,1.0,1.0000,1.0,1.0,2.0000,2.0,2.0000,...,1.0000,1.0,1.0,1.0000,1.0000,1.0000,1.0000,1.0000,1.0000,2.0000
[Eubacterium] rectale,1.0000,1.0000,1.0000,1.0,1.0000,1.0,1.0,1.0000,1.0,1.0000,...,1.0000,1.0,1.0,1.0000,2.0000,2.0000,1.0000,2.0000,2.0000,2.0000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Bacteroides caccae/Bacteroides intestinalis/Alloprevotella rava,2.0000,2.0000,2.0000,2.0,2.0000,2.0,1.0,2.0000,1.0,2.0000,...,1.0000,2.0,1.0,1.0000,2.0000,1.0000,2.0000,1.0000,2.0000,1.0000
Ruthenibacterium lactatiformans/Fournierella massiliensis,2.0000,1.2000,1.0000,1.0,1.6000,1.0,1.0,1.2000,1.0,1.0000,...,2.0000,1.8,2.0,2.0000,2.0000,2.0000,1.0000,2.0000,2.0000,2.0000
Stenotrophomonas geniculata/maltophilia/pavanii,1.0000,1.0000,1.0000,1.0,1.0000,1.0,1.0,2.0000,2.0,2.0000,...,1.0000,1.0,1.0,1.0000,1.0000,1.0000,2.0000,1.0000,1.0000,2.0000
Tepidibaculum saccharolyticum/Ruminococcus albus,1.0000,1.0000,1.0000,1.0,1.0000,1.0,1.0,1.0000,1.0,1.0000,...,1.5000,1.0,1.5,1.5000,2.0000,2.0000,2.0000,1.5000,2.0000,2.0000


In [6]:
# Load the abundance table (first CSV column is the index), strip the '%' sign
# from every cell and convert the remaining text to float64.
taxonomy = (
    pd.read_csv('data/taxonomy_400.csv', index_col=0)
    .replace('%', '', regex=True)
    .astype(np.float64)
)
taxonomy

Unnamed: 0_level_0,ERR2032802,ERR1845748,ERR2092355,ERR1845937,ERR1249738,ERR1090583,ERR1250049,ERR1459183,ERR1075840,ERR1458892,...,ERR2057080,ERR1678465,ERR1074540,ERR1077998,ERR2033465,ERR1389801,ERR1845840,ERR1842195,ERR1075183,ERR1077294
taxonomy,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Faecalibacterium prausnitzii,4.4,12.3,2.0,3.1,11.6,5.8,2.5,9.8,6.8,3.3,...,1.8,10.3,2.7,0.9,2.6,0.7,3.6,2.9,0.2,4.7
Phocaeicola vulgatus,23.4,11.0,19.4,10.9,4.4,5.7,1.7,16.8,5.2,16.1,...,0.5,0.3,0.1,0.0,8.5,5.3,0.5,0.1,0.1,0.0
Prevotella copri,0.1,0.0,0.1,32.4,48.0,71.8,63.9,0.0,0.0,0.0,...,0.0,0.9,7.6,0.1,0.0,0.0,0.7,0.0,0.2,0.5
Bacteroides uniformis,4.1,5.5,4.0,0.0,0.2,0.0,0.7,4.8,7.2,1.1,...,0.5,0.3,0.0,2.3,2.0,0.2,1.0,0.0,0.0,0.1
[Eubacterium] rectale,4.2,22.7,1.1,0.0,0.3,0.0,0.2,6.6,2.9,0.0,...,0.0,0.2,0.1,0.0,0.6,0.1,0.6,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Bacteroides caccae/Bacteroides intestinalis/Alloprevotella rava,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ruthenibacterium lactatiformans/Fournierella massiliensis,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Stenotrophomonas geniculata/maltophilia/pavanii,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Tepidibaculum saccharolyticum/Ruminococcus albus,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Data Cleaning

In [7]:
# Split the metadata into the NORMAL and DEVIANT cohorts and drop the label column.
# The *_ids arrays come from a DataFrame, so they are 2-D; the next cell flattens
# them with reshape(-1) before using them as column selectors.
normal = metadata[metadata['group'] == 'NORMAL'].drop(columns=['group'])
normal_ids = normal.to_numpy()

deviant = metadata[metadata['group'] == 'DEVIANT'].drop(columns=['group'])
deviant_ids = deviant.to_numpy()

In [8]:
# Select the abundance columns of each cohort and convert them straight to NumPy.
# The row order of `taxonomy` is kept; columns follow the order of the id arrays.
normal_taxonomy = taxonomy.loc[:, normal_ids.ravel()].to_numpy()
deviant_taxonomy = taxonomy.loc[:, deviant_ids.ravel()].to_numpy()

### Computation

In [9]:
# Spot check: 10th / 90th percentile of the first row of the normal cohort
# (labelled Faecalibacterium prausnitzii in the printout below)
vlow_test = np.percentile(normal_taxonomy[0], q=10)
vhigh_test = np.percentile(normal_taxonomy[0], q=90)

print("Faecalibacterium prausnitzii")
print(f"v_low : {vlow_test}\nv_high: {vhigh_test}")

Faecalibacterium prausnitzii
v_low : 5.4
v_high: 18.3


In [10]:
# Per-row bounds: 10th and 90th percentile across the normal samples (axis=1),
# widened by an absolute allowance of 5 on each side. The abundances were read
# as percentages, so this is 5 percentage points, not 5% of the bound itself.
vlow = np.subtract(np.percentile(normal_taxonomy, 10, axis=1), 5)
vhigh = np.add(np.percentile(normal_taxonomy, 90, axis=1), 5)

print("Faecalibacterium prausnitzii")
print(f"v_low : {vlow[0]}\nv_high: {vhigh[0]}")

Faecalibacterium prausnitzii
v_low : 0.40000000000000036
v_high: 23.3


In [11]:
def violations(u, v_low=None, v_high=None):
    """
    Count, per sample, how many ASVs fall outside the [v_low, v_high] band.

    Every row of ``u`` is first rescaled so that it sums to 100, i.e. the bounds
    are compared against each entry's percentage of its row total.

    Parameters
    ----------
    u : array-like
        Either a single sample of shape (n_asvs,) or a batch of shape
        (n_samples, n_asvs). A 1-D input is treated as a batch of one.
    v_low, v_high : array-like of length n_asvs, optional
        Lower / upper bound per ASV. When omitted, the notebook-level ``vlow``
        and ``vhigh`` arrays are used, which keeps ``violations(u)`` working
        exactly as before.

    Returns
    -------
    numpy.ndarray
        One integer per row of ``u``: the number of entries strictly below
        ``v_low`` or strictly above ``v_high``.

    Notes
    -----
    A row whose total is zero rescales to NaN; NaN compares False against both
    bounds, so such a row reports 0 violations. The resulting "invalid value"
    floating-point warning is silenced only inside this function, so callers no
    longer need to change NumPy's global error state.
    """
    if v_low is None:
        v_low = vlow
    if v_high is None:
        v_high = vhigh
    v_low = np.asarray(v_low)
    v_high = np.asarray(v_high)

    u = np.asarray(u)
    if u.ndim == 1:
        # promote a single sample to a batch of one
        u = u[np.newaxis, :]

    assert u.shape[1] == len(v_low) == len(v_high)

    # 0 / 0 on an all-zero row is expected here (see Notes); silence it locally
    with np.errstate(invalid='ignore'):
        u = u / u.sum(axis=1, keepdims=True) * 100

    vio_low = u < v_low[np.newaxis, :]
    vio_high = u > v_high[np.newaxis, :]
    vio = (vio_low | vio_high).astype(np.int32).sum(axis=1)

    return vio

In [12]:
# Baseline violation count per sample. The taxonomy arrays are transposed first
# because violations() reads each row of its input as one sample.
normal_taxonomy_violoation = violations(normal_taxonomy.T)
deviant_taxonomy_violoation = violations(deviant_taxonomy.T)

print(f"Normal Taxonomy \n{normal_taxonomy_violoation}\nMean Violoations: {normal_taxonomy_violoation.mean()}\n")
print(f"Deviant Taxonomy \n{deviant_taxonomy_violoation}\nMean Violoations: {deviant_taxonomy_violoation.mean()}\n")

Normal Taxonomy 
[0 1 0 ... 3 0 1]
Mean Violoations: 1.1364922206506365

Deviant Taxonomy 
[ 5  8  2  2  5  1  5  3  4  6  3  5  4  4  2  4  4  4  6  3  2  4  4  3
  7  2  2  5  3  3  7  3  4  2  6  1  3  2  6  4  2  4  3  2  3  6  5  4
  3  4  4  5  6  5 11  6  7  8  6  7  5  7  6  8  8  7  5  8  8  6  7  5
  8  6  5  8  7  6  8  8  7  5 10  6  5  8  7  8  9  6  2  3  2  3  2  2
  3  5  3  1  1  7  1  1  2  3  3  4  1  1  2  4  3  4  1  7  3  4  5  6
  3  1  2  8  2  1  3  3  1  2  4  4  4  2  5  1  4  3  1  1  7  5]
Mean Violoations: 4.345070422535211



In [13]:
# One row per deviant sample, so a row lines up with a NIM column
deviant_taxonomy_T = np.transpose(deviant_taxonomy)

# Sample usage: the 'His' column multiplied with the first deviant sample.
# The product is positional (Series * ndarray), so this relies on `nim` and
# `taxonomy` listing their rows in the same order - TODO confirm.
working = nim['His'].mul(deviant_taxonomy_T[0])
print(working)

taxonomy
Faecalibacterium prausnitzii                                       7.3855
Phocaeicola vulgatus                                               0.3000
Prevotella copri                                                   0.0000
Bacteroides uniformis                                              0.0000
[Eubacterium] rectale                                              0.4000
                                                                    ...  
Bacteroides caccae/Bacteroides intestinalis/Alloprevotella rava    0.0000
Ruthenibacterium lactatiformans/Fournierella massiliensis          0.0000
Stenotrophomonas geniculata/maltophilia/pavanii                    0.0000
Tepidibaculum saccharolyticum/Ruminococcus albus                   0.0000
Sellimonas intestinalis/Drancourtella massiliensis                 0.1000
Name: His, Length: 400, dtype: float64


In [14]:
# Per-nutrient products for the first deviant sample, one Series per NIM column.
arr = []
for n in nim.columns:
    working = nim[n].mul(deviant_taxonomy_T[0])
    arr.append(working)
# TODO: use v_high / v_low to filter these results.

### Impact

#### Deviant Taxonomy

In [15]:
# One row per deviant sample, so a row lines up with the NIM rows
deviant_taxonomy_T = np.transpose(deviant_taxonomy)

# NIM as a plain NumPy matrix
nim_np = nim.to_numpy()

# First deviant sample as a column vector, so it broadcasts over every NIM column
deviant_sample_0 = deviant_taxonomy_T[0].reshape(-1, 1)

# Check shapes
print(f"Shape of NIM: {nim_np.shape}")
print(f"Shape of Test Deviant Sample: {deviant_sample_0.shape}")

# Nutrient impact on a single deviant sample: column j holds the sample's
# abundances scaled by the j-th NIM column
NIM_deviant_0 = nim_np * deviant_sample_0

Shape of NIM: (400, 80)
Shape of Test Deviant Sample: (400, 1)


In [16]:
# Testing: violations of the first deviant sample after each nutrient is applied.
# NOTE: np.seterr changes NumPy's error handling for the rest of the session.
np.seterr(invalid='ignore')
NIM_dev0_violation = violations(NIM_deviant_0.T)

# Change in the number of violations per nutrient, relative to the sample's baseline
np.subtract(NIM_dev0_violation, deviant_taxonomy_violoation[0])

array([ 1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,
        0,  1,  0,  1,  1,  0,  2,  2,  0,  3,  1,  1,  0, -1,  0,  0,  1,
       -2,  0,  1,  1,  0,  0, -1,  1,  1,  0,  1,  1,  1,  1,  0,  2, -1,
        0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  0,  0,  0,
        1,  1,  2,  0,  1,  1,  0, -1,  0,  1,  1,  1])

In [17]:
# Quick (unoptimised) aggregation of the single-sample steps above:
# for every deviant sample, scale it by each NIM column, count the violations per
# nutrient and subtract that sample's baseline count.
violation_change_deviant = np.array([
    violations((nim_np * sample[:, np.newaxis]).T) - baseline
    for sample, baseline in zip(deviant_taxonomy_T, deviant_taxonomy_violoation)
])
# rows: deviant samples, columns: nutrients
violation_change_deviant

array([[ 1,  0,  1, ...,  1,  1,  1],
       [-1, -1, -1, ..., -2, -2, -2],
       [ 0,  0,  0, ...,  0,  0,  0],
       ...,
       [ 0,  0,  0, ...,  1,  0,  1],
       [ 0, -1,  0, ...,  0,  1,  0],
       [-1,  0,  0, ..., -1,  0, -1]])

#### Normal Taxonomy

In [18]:
# Transpose the normal taxonomy so that each row is one sample, matching NIM rows
normal_taxonomy_T = normal_taxonomy.T

# NIM as a plain NumPy matrix
nim_np = nim.to_numpy()

# First normal sample as a column vector, so it broadcasts over every NIM column
normal_sample_0 = normal_taxonomy_T[0]
normal_sample_0 = normal_sample_0[:, np.newaxis]

# Check shapes (the label used to say "Deviant", copied from the deviant section)
print(f"Shape of NIM: {nim_np.shape}")
print(f"Shape of Test Normal Sample: {normal_sample_0.shape}")

# Nutrient impact on a single normal sample: column j holds the sample's
# abundances scaled by the j-th NIM column
NIM_normal_0 = np.multiply(nim_np, normal_sample_0)
NIM_normal_0

Shape of NIM: (400, 80)
Shape of Test Deviant Sample: (400, 1)


array([[25.77141, 19.64543, 13.3    , ..., 26.6    , 26.6    , 26.6    ],
       [12.6    , 12.6    , 13.97214, ..., 12.6    , 13.97214, 12.6    ],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ],
       ...,
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ]])

In [19]:
# Testing: violations of the first normal sample after each nutrient is applied.
# NOTE: np.seterr changes NumPy's error handling for the rest of the session.
np.seterr(invalid='ignore')
NIM_norm0_violation = violations(NIM_normal_0.T)

# Change in the number of violations per nutrient, relative to the sample's baseline
np.subtract(NIM_norm0_violation, normal_taxonomy_violoation[0])

array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
       0, 0, 2, 1, 0, 2, 1, 1, 2, 0, 0, 2, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0])

In [20]:
# For every normal sample: scale it by each NIM column, count the violations per
# nutrient and subtract that sample's baseline count.
violation_change_normal = np.array([
    violations((nim_np * sample[:, np.newaxis]).T) - baseline
    for sample, baseline in zip(normal_taxonomy_T, normal_taxonomy_violoation)
])
# rows: normal samples, columns: nutrients
violation_change_normal

array([[ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  1,  1,  0],
       [ 2,  1,  0, ...,  1,  2,  1],
       ...,
       [ 0,  0,  0, ...,  0, -2, -1],
       [ 0,  0,  0, ...,  1,  0,  1],
       [ 0,  0,  0, ...,  0,  0, -1]])

### Selecting Nutrients

In [21]:
### Find the top nutrients that reduce the number of violations the most
### in violation_change_deviant

In [22]:
# Show the per-sample / per-nutrient change matrix and its (samples, nutrients) shape
print(violation_change_deviant, violation_change_deviant.shape, sep='\n')

[[ 1  0  1 ...  1  1  1]
 [-1 -1 -1 ... -2 -2 -2]
 [ 0  0  0 ...  0  0  0]
 ...
 [ 0  0  0 ...  1  0  1]
 [ 0 -1  0 ...  0  1  0]
 [-1  0  0 ... -1  0 -1]]
(142, 80)


In [23]:
# Net change in the total number of violations per nutrient, summed over all
# deviant samples (axis 0). A negative entry means the nutrient removed more
# violations than it introduced; a positive entry means the opposite.
violation_change_d_sum = np.sum(violation_change_deviant, axis=0)
violation_change_d_sum

array([ 10,  -2,   4,   9,   2,   6,   3,   9,  12,  -1, -24,   0,   7,
         3,  16, -15,  12,  -5, -13,  17,   1,  -2, -21,  15,  15,   8,
        -8,  -2,   1, -21,   3, -13,  15,   1,   0, -26, -23,  22,  -4,
        13,   6,   1,  11,  18,  -2,  16,  -8,  21,  -2,   8,   9,   4,
        -1,   3,   4,  10,   9,   1,   8,  -1,   5,   5,   3,   3,   1,
         1,  -4,   2,  -1,  13,  25,   6,  15,  -5,   2,   5,   7,  -9,
        17,  15])

In [24]:
# Column indices of the 10 nutrients with the lowest (most negative) net change,
# i.e. those that reduced the number of violations the most
smallest_indices = violation_change_d_sum.argsort()[:10]
smallest_indices

array([35, 10, 36, 29, 22, 15, 18, 31, 77, 46], dtype=int64)

In [25]:
### Find the top nutrients that increase the number of violations the least
### in violation_change_normal

In [26]:
# Show the per-sample / per-nutrient change matrix and its (samples, nutrients) shape
print(violation_change_normal, violation_change_normal.shape, sep='\n')

[[ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  1  1  0]
 [ 2  1  0 ...  1  2  1]
 ...
 [ 0  0  0 ...  0 -2 -1]
 [ 0  0  0 ...  1  0  1]
 [ 0  0  0 ...  0  0 -1]]
(1414, 80)


In [27]:
# Net change in the total number of violations per nutrient, summed over all
# normal samples (axis 0). A positive entry means the nutrient introduced more
# violations than it removed.
violation_change_n_sum = np.sum(violation_change_normal, axis=0)
violation_change_n_sum

array([ 86,  91, -10, 162,  47, 146,  27, 309, 232, 267, 245,  92, -43,
       183, 131, 195, 216, 182, 157, 133, 237,  63, 273,  39, 312, 262,
        61, 277, 254, 319, 300, 225, -51, 207,  66, 293, 331, 225, 338,
         9,  59, 269, -33, -11, 295,  -7,  20, 235, 258, -40, 117, 296,
         4,  35, -36,  -6, -26, -19,  15,  38,  -9,  11,  -6,   0,   0,
        -4,  18,  25,   5, 108,  64, 160,  13,  52, -37,  67, 397,  60,
        13, -51])

In [28]:
# Column indices of the 10 nutrients with the lowest net change among normal
# samples, i.e. those that increased the number of violations the least.
# NOTE: this reuses (and overwrites) the `smallest_indices` name from the deviant cell.
smallest_indices = violation_change_n_sum.argsort()[:10]
smallest_indices

array([79, 32, 12, 49, 74, 54, 42, 56, 57, 43], dtype=int64)

In [29]:
from functools import partial
import scipy
import scipy.optimize
from tqdm import tqdm

def convert_problem_into_lp(u, v_low, v_high, nim, sparsity_penalty, constraint_violation_penalty):
    """
    Build the inequality form ``A @ z <= b`` and cost vector ``c`` of the LP that
    softly enforces ``v_low <= nim @ x + u <= v_high``.

    The decision vector is ``z = [x, y_high, y_low]`` with ``m`` entries for ``x``
    and ``n`` slack entries for each side of the band, where ``(n, m) = nim.shape``:

        nim @ x - y_high <= v_high - u      (rows 0 .. n-1)
       -nim @ x - y_low  <= u - v_low       (rows n .. 2n-1)

    The cost charges ``sparsity_penalty`` per unit of ``x`` and
    ``constraint_violation_penalty`` per unit of slack. This function does not
    emit non-negativity bounds; the formulation presumably relies on the solver
    applying them (scipy's ``linprog`` defaults to ``(0, None)``) - confirm when
    using another solver.

    Parameters
    ----------
    u, v_low, v_high : array-like of length n
    nim : array-like of shape (n, m)
    sparsity_penalty, constraint_violation_penalty : scalar

    Returns
    -------
    (A, b, c) : tuple of numpy.ndarray
        ``A`` has shape (2n, m + 2n), ``b`` has length 2n, ``c`` has length m + 2n.
    """
    u = np.asarray(u)
    v_low = np.asarray(v_low)
    v_high = np.asarray(v_high)
    nim = np.asarray(nim)

    n, m = nim.shape
    assert u.shape[0] == v_low.shape[0] == v_high.shape[0] == n

    # Integer blocks keep the dtype of the result driven by `nim`, as the
    # list-based construction did.
    minus_identity = -np.eye(n, dtype=int)
    zero_block = np.zeros((n, n), dtype=int)

    # Assemble the whole matrix at once instead of one Python list per row
    constraint_coef = np.block([
        [nim, minus_identity, zero_block],
        [-nim, zero_block, minus_identity],
    ])
    bias_v = np.concatenate([v_high - u, -v_low + u])
    c = np.array([sparsity_penalty] * m + [constraint_violation_penalty] * (2 * n))
    return constraint_coef, bias_v, c

# nim_np = nim.to_numpy()

# A, b, c = convert_problem_into_lp(deviant_taxonomy[:, 0], vlow, vhigh, nim_np, 0., 0.0)
# print(A.shape, b.shape, c.shape)
# import scipy.optimize
# r = scipy.optimize.linprog(c, A_ub=A, b_ub=b, A_eq=None, b_eq=None, bounds=None, method='highs', callback=None, options=None, x0=None)

# print(r.x[: 80].nonzero())
# print(r.success)
# print(violations(deviant_taxonomy[:, 0]))
# print(violations((nim_np @ r.x[: 80][:, np.newaxis]).squeeze(1) + deviant_taxonomy[:, 0]))


# NIM as a plain NumPy matrix, and the number of nutrient columns in it
nim_np = nim.to_numpy()
nvars = nim_np.shape[-1]

def grid_search(k_sparse, sparsity_low, sparsity_high, sparsity_delta, violation_low, violation_high, violation_delta, lp_fn, eval_fn, n_vars=None):
    """
    Try every (sparsity penalty, violation penalty) pair on a regular grid and
    return the best score among the sufficiently sparse LP solutions.

    Grid points are ``low + i * delta`` for ``i`` in
    ``range(int((high - low) / delta + 0.5))``, so the lower end is included and
    the upper end is not.

    Parameters
    ----------
    k_sparse : int
        Maximum number of non-zero entries allowed among the first ``n_vars``
        components of the solution.
    sparsity_low, sparsity_high, sparsity_delta : number
        Start, end and step of the sparsity-penalty grid.
    violation_low, violation_high, violation_delta : number
        Start, end and step of the violation-penalty grid.
    lp_fn : callable
        Called as ``lp_fn(sparsity_penalty=..., constraint_violation_penalty=...)``;
        must return an object with ``success`` and ``x`` attributes.
    eval_fn : callable
        Scores the first ``n_vars`` components of a solution; lower is better.
    n_vars : int, optional
        Number of leading solution components to treat as decision variables.
        Defaults to the notebook-level ``nvars``.

    Returns
    -------
    The smallest ``eval_fn`` value seen, or the sentinel ``1e9`` when no grid
    point produced a successful solution that is sparse enough.
    """
    if n_vars is None:
        n_vars = nvars

    best_answer = 1e9
    sparsity_chunks = int((sparsity_high - sparsity_low) / sparsity_delta + 0.5)
    violation_chunks = int((violation_high - violation_low) / violation_delta + 0.5)
    for sparsity_i in range(sparsity_chunks):
        sparsity_p = sparsity_low + sparsity_i * sparsity_delta
        for violation_i in range(violation_chunks):
            violation_p = violation_low + violation_i * violation_delta
            r = lp_fn(sparsity_penalty=sparsity_p, constraint_violation_penalty=violation_p)
            if not r.success:
                continue
            candidate = r.x[:n_vars]
            # only exact zeros count as "unused"
            if np.count_nonzero(candidate) <= k_sparse:
                best_answer = min(best_answer, eval_fn(candidate))
    return best_answer

def binary_search(k_sparse, sparsity_low, sparsity_high, sparsity_delta, violation_low, violation_high, violation_delta, lp_fn, eval_fn, n_vars=None):
    """
    Nested bisection over the violation penalty (outer) and the sparsity penalty
    (inner), returning the best score among the sufficiently sparse LP solutions.

    Heuristic behind the search, as stated by the original author: a higher
    sparsity penalty makes a sparse (valid) ``x`` more likely, and a higher
    violation penalty gives a better solution. The outer loop therefore looks
    for the highest violation penalty at which some sparsity penalty still
    yields an admissible solution.

    Parameters
    ----------
    k_sparse : int
        Maximum number of non-zero entries allowed among the first ``n_vars``
        components of the solution.
    sparsity_low, sparsity_high, sparsity_delta : number
        Search interval for the sparsity penalty and the width at which the
        inner bisection stops.
    violation_low, violation_high, violation_delta : number
        Search interval for the violation penalty and the width at which the
        outer bisection stops.
    lp_fn : callable
        Called as ``lp_fn(sparsity_penalty=..., constraint_violation_penalty=...)``;
        must return an object with ``success`` and ``x`` attributes.
    eval_fn : callable
        Scores the first ``n_vars`` components of a solution; lower is better.
    n_vars : int, optional
        Number of leading solution components to treat as decision variables.
        Defaults to the notebook-level ``nvars``.

    Returns
    -------
    The smallest ``eval_fn`` value seen, or the sentinel ``1e9`` when no
    admissible solution was found.

    Raises
    ------
    ValueError
        If either delta is not strictly positive; the bisection loops are not
        guaranteed to terminate in that case.
    """
    if sparsity_delta <= 0 or violation_delta <= 0:
        raise ValueError("sparsity_delta and violation_delta must be > 0")
    if n_vars is None:
        n_vars = nvars

    best_answer = 1e9
    while violation_low < violation_high - violation_delta:
        violation_p = (violation_low + violation_high) / 2
        s_low = sparsity_low
        s_high = sparsity_high
        found = False
        while s_low < s_high - sparsity_delta:
            sparsity_p = (s_low + s_high) / 2
            r = lp_fn(sparsity_penalty=sparsity_p, constraint_violation_penalty=violation_p)
            if r.success and np.count_nonzero(r.x[:n_vars]) <= k_sparse:
                best_answer = min(best_answer, eval_fn(r.x[:n_vars]))
                # admissible: try a smaller sparsity penalty next
                s_high = sparsity_p
                found = True
            else:
                s_low = sparsity_p
        if found:
            # an admissible solution exists here: push the violation penalty up
            violation_low = violation_p
        else:
            violation_high = violation_p
    return best_answer

def solve_problem_with_lp(u, v_low, v_high, nim, sparsity_penalty, constraint_violation_penalty):
    """
    Build the LP for one sample with convert_problem_into_lp and solve it.

    All arguments are forwarded unchanged to convert_problem_into_lp. Returns the
    result object of scipy.optimize.linprog (HiGHS method); no explicit bounds
    are passed, so linprog's default variable bounds apply.
    """
    constraint_matrix, upper_bounds, cost = convert_problem_into_lp(
        u, v_low, v_high, nim, sparsity_penalty, constraint_violation_penalty
    )
    solution = scipy.optimize.linprog(cost, A_ub=constraint_matrix, b_ub=upper_bounds, method='highs')
    return solution

# Run the penalty grid search for every deviant sample, allowing at most
# k_sparse non-zero nutrient amounts per solution, and compare the mean number
# of violations before and after.
k_sparse = 10
violations_before = []
violations_after = []
for deviant_sample in tqdm(deviant_taxonomy.T):
    # LP solver with everything but the two penalties fixed for this sample
    lp_fn = partial(solve_problem_with_lp, u=deviant_sample, v_low=vlow, v_high=vhigh, nim=nim_np)

    def eval_fn(x):
        # violations of the sample after adding the nutrient effect nim_np @ x
        nutrient_effect = nim_np.dot(x.reshape(-1, 1))[:, 0]
        return violations(nutrient_effect + deviant_sample)[0]

    best_answer = grid_search(k_sparse, 0, 1, 0.2, 0, 1, 0.2, lp_fn, eval_fn)
    # grid_search returns 1e9 when no admissible solution exists; that value
    # would then be averaged in as-is below.
    violations_before.append(violations(deviant_sample)[0])
    violations_after.append(best_answer)
print(np.mean(violations_before), np.mean(violations_after))

100%|████████████████████████████████████████████████████████████████████████████████| 142/142 [11:33<00:00,  4.88s/it]

4.345070422535211 4.253521126760563



