In [1]:
import os
import glob
from pathlib import Path
import json
import numpy as np
import pandas as pd
from multiprocessing import Pool, cpu_count
import time
import argparse
from typing import List, Dict
import pickle
import gurobipy as gp
from gurobipy import GRB
from multiprocessing import Process

from problem import setcover

In [2]:
from joblib import Parallel, delayed
import time
import numpy as np

In [3]:
from typing import Callable, List, Tuple
import numpy as np

In [4]:
from problem import parallel_generate_problem, parallel_generate_solutions, setcover

In [36]:
# parallel_generate_problem(setcover, "temp_pretrain", n_insts=100, n_jobs=10)

In [35]:
# parallel_generate_solutions("temp_pretrain/", 16)

In [7]:
import random
import torch
from info import ModelInfo, ConInfo, VarInfo


def get_lhs_matrix(n_var: int, con_info: "ConInfo") -> torch.Tensor:
    """Assemble the constraint coefficient matrix as a sparse COO tensor of shape (con_info.n, n_var).

    Row ``c`` takes its column indices from ``con_info.lhs_p[c]`` and the matching coefficients
    from ``con_info.lhs_c[c]``; only the first ``con_info.n`` rows of those lists are read.
    """
    entries = [
        (con_idx, var_idx, coef)
        for con_idx in range(con_info.n)
        for var_idx, coef in zip(con_info.lhs_p[con_idx], con_info.lhs_c[con_idx])
    ]
    rows = [row for row, _, _ in entries]
    cols = [col for _, col, _ in entries]
    coefs = [coef for _, _, coef in entries]
    return torch.sparse_coo_tensor([rows, cols], coefs, (con_info.n, n_var))


def random_shift_binary_var_val(vals, var_info: "VarInfo", prob: float=0.2):
    """Return a copy of ``vals`` where each binary variable is flipped (v -> 1 - v) with probability ``prob``.

    Only positions with ``var_info.types[i] == gp.GRB.BINARY`` are candidates. ``random.random()``
    is drawn once per binary candidate; non-binary positions consume no randomness.

    NOTE(review): values straight from a solver may be e.g. 0.9999999, so a flip yields ~1e-7
    rather than exactly 0; consider rounding binaries upstream if exact 0/1 is needed.

    Returns:
        np.ndarray with the (possibly) flipped values.
    """
    shifted = vals.copy()
    for idx in range(len(vals)):
        if var_info.types[idx] == gp.GRB.BINARY and random.random() <= prob:
            shifted[idx] = 1 - vals[idx]
    return np.array(shifted)
            

def get_con_shift(lhs, dv):
    """Right-hand-side shift ``lhs @ dv`` induced by moving the variables by ``dv``.

    Args:
        lhs: (n_con, n_var) torch matrix (sparse or dense), e.g. from ``get_lhs_matrix``.
        dv: variable shift, either 1-D of length n_var or an (n_var, 1) column; numpy array or tensor.

    Returns:
        numpy array of length n_con (squeezed, so 0-d when n_con == 1).
    """
    # Promote a 1-D shift to an (n_var, 1) column so this is a plain matrix-matrix product.
    # Note the comma: `dv[:np.newaxis]` would only be the slice `dv[:None]`, i.e. a no-op.
    dv = dv[:, np.newaxis] if len(dv.shape) == 1 else dv
    shift = lhs @ torch.as_tensor(dv).float()
    return shift.numpy().squeeze()


def get_obj_shift(ks, dv):
    """Objective change from shifting variables by ``dv``: sum of ``ks[i] * dv[i]`` over the keys of ``ks``.

    ``ks`` maps a variable index to its multiplier; indices absent from ``ks`` contribute nothing,
    and an empty ``ks`` yields 0. A 2-D ``dv`` is squeezed before indexing.
    """
    flat_dv = dv.squeeze() if dv.ndim == 2 else dv
    total = 0
    for var_idx, coef in ks.items():
        total = total + coef * flat_dv[var_idx]
    return total


def shift_model(model, var_shift, rhs_shift):
    """Return a copy of ``model`` restricted to agree with a variable shift (ONLY USED FOR VALIDATION).

    For each variable, a positive shift sets ``lb = 1`` and a negative shift sets ``ub = 0``; zero
    leaves it unchanged. Every constraint's ``rhs`` is increased by the matching ``rhs_shift``
    entry. The input model is not modified; ``update()`` is called on the copy.

    Args:
        model: Gurobi-style model exposing ``copy``, ``getVars``, ``getConstrs`` and ``update``.
        var_shift: per-variable shift, 1-D or a 2-D array that is squeezed.
        rhs_shift: per-constraint rhs shift (e.g. from ``get_con_shift``).
    """
    if len(var_shift.shape) == 2:
        var_shift = var_shift.squeeze()

    shifted = model.copy()
    # TODO: allow C and I variable bound change
    for var, delta in zip(shifted.getVars(), var_shift):
        if delta > 0:
            var.setAttr("lb", 1)
        elif delta < 0:
            var.setAttr("ub", 0)

    for con, delta in zip(shifted.getConstrs(), rhs_shift):
        con.setAttr("rhs", con.rhs + delta)

    shifted.update()
    return shifted


def shift_model_info(info: "ModelInfo", var_shift, con_shift, obj_shift):
    """Return a copy of ``info`` translated by ``x' = x + var_shift``.

    Effects on the copy:
      * every stored solution row gets ``var_shift`` added to its values (``sols[:, 1:]``)
        and ``obj_shift`` added to its objective column (``sols[:, 0]``);
      * each variable with a non-zero shift has both bounds moved by the shift, and binary
        variables are then clipped back into [0, 1];
      * each constraint rhs is moved by the matching non-zero ``con_shift`` entry.

    NOTE(review): the same ``var_shift`` is applied to every row of ``sols``. Rows that differ
    from the row the shift was built from on a flipped binary end up outside the clipped bounds.
    Confirm whether only row 0 is meant to stay valid.
    NOTE(review): correctness relies on ``ModelInfo.copy`` copying ``sols``/``lbs``/``ubs``/``rhs``.
    If it is shallow, the in-place updates below also mutate ``info``. Confirm.
    """
    shifted = info.copy()
    if len(var_shift.shape) == 2:
        var_shift = var_shift.squeeze()

    var_info = shifted.var_info
    var_info.sols[:, 1:] += var_shift
    var_info.sols[:, 0] += obj_shift

    for idx, delta in enumerate(var_shift):
        if delta == 0:
            continue
        new_lb = var_info.lbs[idx] + delta
        new_ub = var_info.ubs[idx] + delta
        if var_info.types[idx] == gp.GRB.BINARY:
            new_lb = max(new_lb, 0.0)
            new_ub = min(new_ub, 1.0)
        var_info.lbs[idx] = new_lb
        var_info.ubs[idx] = new_ub

    rhs = shifted.con_info.rhs
    for idx, delta in enumerate(con_shift):
        if delta != 0:
            rhs[idx] += delta

    return shifted


def augment_model_info(info: "ModelInfo", prob=0.2, n=10):
    """Create ``n`` shifted copies of ``info`` by randomly flipping binaries in its first solution.

    Each sample flips each binary of ``info.var_info.sols[0, 1:]`` with probability ``prob``. The
    induced constraint-rhs and objective shifts are propagated with ``get_con_shift`` /
    ``get_obj_shift``, and the result is built with ``shift_model_info``.

    Args:
        info: model info whose ``var_info.sols`` is set (row 0 = [objective, values...]).
        prob: per-binary flip probability.
        n: number of augmented copies.

    Returns:
        list of ``n`` shifted ModelInfo objects.

    Raises:
        AssertionError: if ``info.var_info.sols`` is None.
    """
    assert info.var_info.sols is not None, "info must contain solution at var_info.sols"

    # Loop-invariant: the constraint matrix does not depend on the random flips.
    lhs = get_lhs_matrix(info.var_info.n, info.con_info)
    # Copy, not a view into sols, so in-place edits of sols can never move the reference point.
    base_vals = np.array(info.var_info.sols[0, 1:], copy=True)

    aug = []
    for _ in range(n):
        shifted_vals = random_shift_binary_var_val(base_vals, info.var_info, prob=prob)
        var_shift = shifted_vals - base_vals
        con_shift = get_con_shift(lhs, var_shift)
        obj_shift = get_obj_shift(info.obj_info.ks, var_shift)
        aug.append(shift_model_info(info, var_shift, con_shift, obj_shift))
    return aug
    

In [8]:
from functools import partial

In [9]:
# Build a set-cover model via problem.setcover() and extract its data into a ModelInfo.
# update() is called first, presumably so pending model changes are visible to from_model.
m = setcover()
m.update()
info = ModelInfo.from_model(m)

Restricted license - for non-production use only - expires 2026-11-23


In [10]:
m.optimize()
# Optimal values in variable order (the same list as [v.x for v in m.getVars()]).
vals = m.getAttr("X", m.getVars())

Gurobi Optimizer version 12.0.1 build v12.0.1rc0 (mac64[x86] - Darwin 22.4.0 22E252)

CPU model: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
Thread count: 8 physical cores, 16 logical processors, using up to 16 threads

Optimize a model with 100 rows, 200 columns and 2057 nonzeros
Model fingerprint: 0xf16fb26f
Variable types: 0 continuous, 200 integer (200 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [5e+00, 2e+01]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+00]
Found heuristic solution: objective 216.0000000
Presolve time: 0.00s
Presolved: 100 rows, 200 columns, 2057 nonzeros
Variable types: 0 continuous, 200 integer (200 binary)

Root relaxation: objective 6.296904e+01, 230 iterations, 0.00 seconds (0.00 work units)

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0   62.96904    0   51  216.00000   62.96904  70.8%

In [11]:
# Store the optimum as a single row [objective, x_0, ..., x_{n-1}]; augment_model_info reads
# sols[0, 1:] as values and shift_model_info updates sols[:, 0] as the objective.
info.var_info.sols = np.array([[m.objVal] + vals])

In [12]:
# Smoke test: two augmented copies of the reference instance.
a = augment_model_info(info, n=2)

In [13]:
# Validation: flip random binaries of the optimal solution (default prob=0.2), propagate the
# shift to the constraint rhs and the objective, and build the correspondingly restricted
# Gurobi model. Presumably its optimum should equal m.objVal + obj_shift (checked below).
shifted_vals = random_shift_binary_var_val(vals, info.var_info)
lhs = get_lhs_matrix(info.var_info.n, info.con_info)

diff = shifted_vals - vals
con_shift = get_con_shift(lhs, diff)
obj_shift = get_obj_shift(info.obj_info.ks, diff)

shifted_m = shift_model(m, diff, con_shift)

In [14]:
# Objective change implied by the random flips.
obj_shift

361.0

In [15]:
# Solve the shifted model; compare its optimum with m.objVal + obj_shift.
shifted_m.optimize()

Gurobi Optimizer version 12.0.1 build v12.0.1rc0 (mac64[x86] - Darwin 22.4.0 22E252)

CPU model: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
Thread count: 8 physical cores, 16 logical processors, using up to 16 threads

Optimize a model with 100 rows, 200 columns and 2057 nonzeros
Model fingerprint: 0x2b99be54
Variable types: 0 continuous, 200 integer (200 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [5e+00, 2e+01]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+01]
Found heuristic solution: objective 553.0000000
Presolve removed 37 rows and 45 columns
Presolve time: 0.00s
Presolved: 63 rows, 155 columns, 1018 nonzeros
Variable types: 0 continuous, 155 integer (155 binary)

Root relaxation: objective 4.397278e+02, 103 iterations, 0.00 seconds (0.00 work units)

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0  439.72780

In [16]:
import networkx as nx
import torch
from sklearn.preprocessing import StandardScaler
from info import ModelInfo
from graph_preprocessing import get_bipartite_graph, add_label

# Inspect one stored instance: read the model and its saved solutions, build the bipartite
# graph and attach labels. NOTE(review): this rebinds the notebook globals `m` and `info`.
m = gp.read("temp/0_0.lp")
s = np.load("temp/0_0.npz")['solutions']

info = ModelInfo.from_model(m)

g, con_names = get_bipartite_graph(info)
g = add_label(g, info, s)

Read LP format model from file temp/0_0.lp
Reading time = 0.00 seconds
: 100 rows, 200 columns, 2033 nonzeros


In [17]:
from graph_preprocessing import BipartiteData, constraint_valuation, create_data_object

In [18]:
import os
import random
from typing import Optional
from torch_geometric.data import InMemoryDataset
from tqdm import tqdm


class ModelGraphDataset(InMemoryDataset):
    """In-memory PyG dataset of bipartite variable/constraint graphs built from MILP instances.

    ``root`` must contain matching ``<name>.lp`` (model) and ``<name>.npz`` (array under key
    ``"solutions"``) files. Processed graphs are saved to ``processed_paths[0]`` (``data.pt``).
    NOTE(review): when that file already exists, PyG presumably skips ``process()``, so changing
    ``augment`` has no effect until the cache is deleted.

    Args:
        root: directory holding the ``.lp``/``.npz`` pairs.
        transform, pre_transform, pre_filter: forwarded to ``InMemoryDataset``.
        augment: optional callable ``ModelInfo -> list[ModelInfo]``; its outputs are added as
            extra samples named ``aug_<i>_<name>``.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, augment=None):
        # Set before super().__init__, because process() reads both attributes.
        self._inst_names = self._get_inst_names(root)
        self._augment = augment
        super().__init__(root, transform, pre_transform, pre_filter)
        self.load(self.processed_paths[0])

    @property
    def inst_names(self):
        """Instance stems (file names without ``.lp``) as a fresh list."""
        return list(self._inst_names)

    @property
    def raw_file_names(self):
        mdl_paths = [os.path.join(self.root, f"{n}.lp") for n in self._inst_names]
        sol_paths = [os.path.join(self.root, f"{n}.npz") for n in self._inst_names]
        return mdl_paths + sol_paths

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Read every instance (plus optional augmentations), convert to graph data, shuffle and save."""
        raw_info = []
        for n in self._inst_names:
            m = gp.read(os.path.join(self.root, f"{n}.lp"))
            # The context manager closes the .npz archive instead of leaving its handle to the GC.
            with np.load(os.path.join(self.root, f"{n}.npz")) as npz:
                s = npz['solutions']
            info = ModelInfo.from_model(m)
            info.var_info.sols = s
            raw_info.append((n, info))

        aug_info = []
        if self._augment is not None:
            for n, info in tqdm(raw_info, desc="model info augmentation"):
                aug_infos = self._augment(info)
                aug_names = [f"aug_{i}_{n}" for i in range(len(aug_infos))]
                aug_info.extend(zip(aug_names, aug_infos))

        processed = []
        for n, info in tqdm(raw_info + aug_info, desc="create data"):
            data = self.info_to_data(info)
            data.instance_name = n
            processed.append(data)

        # NOTE(review): uses the global `random` state; seed it beforehand for a reproducible order.
        random.shuffle(processed)
        torch.save(self.collate(processed), self.processed_paths[0])

    @staticmethod
    def info_to_data(info: "ModelInfo"):
        """Convert a ModelInfo to a graph data object, labelled when ``info.var_info.sols`` is set."""
        sol = info.var_info.sols
        g, _ = get_bipartite_graph(info)
        g = add_label(g, info, sol) if sol is not None else g
        data = create_data_object(g, sol is not None)
        return data

    def get(self, idx):
        """Return ``(idx, data)`` rather than just ``data``, so loaders also see the sample index."""
        data = super().get(idx)
        return idx, data

    @staticmethod
    def _get_inst_names(root):
        """Return the stems of ``root/*.lp``, asserting each has a same-stem ``.npz`` partner."""
        lp_suffix, npz_suffix = ".lp", ".npz"
        mdl_paths = sorted(p for p in os.listdir(root) if p.endswith(lp_suffix))
        sol_paths = sorted(p for p in os.listdir(root) if p.endswith(npz_suffix))
        assert len(mdl_paths) == len(sol_paths), (len(mdl_paths), len(sol_paths))
        # Both suffixes start with ".", so sorting by full name pairs up equal stems position-wise.
        # all(...) is required: a set of booleans like {False} is truthy and would never fail.
        assert all(
            mp[:-len(lp_suffix)] == sp[:-len(npz_suffix)] for mp, sp in zip(mdl_paths, sol_paths)
        ), "every .lp file in root needs a .npz file with the same stem"
        return [p[:-len(lp_suffix)] for p in mdl_paths]

In [20]:
root = "temp"
mdl_paths = sorted(p for p in os.listdir(root) if p.endswith(".lp"))
sol_paths = sorted(p for p in os.listdir(root) if p.endswith(".npz"))

# Delete models that have no solution file (irreversible) so the dataset sees matched pairs only.
sol_set = set(sol_paths)
orphans = [p for p in mdl_paths if p.replace(".lp", ".npz") not in sol_set]
for p in orphans:
    os.remove(os.path.join(root, p))

In [43]:
# Training dataset from ./temp, augmented with randomly flipped-binary variants of each instance.
# Feature widths come from the first sample; ModelGraphDataset.get returns (idx, data), hence [1].
d = ModelGraphDataset("./temp", augment=augment_model_info)
data = d[0][1]
var_feature_size = data.var_node_features.size(-1)
con_feature_size = data.con_node_features.size(-1) 

In [44]:
root = "temp_pretrain"
mdl_paths = sorted(p for p in os.listdir(root) if p.endswith(".lp"))
sol_paths = sorted(p for p in os.listdir(root) if p.endswith(".npz"))

# Delete models that have no solution file (irreversible) so the dataset sees matched pairs only.
sol_set = set(sol_paths)
orphans = [p for p in mdl_paths if p.replace(".lp", ".npz") not in sol_set]
for p in orphans:
    os.remove(os.path.join(root, p))

In [45]:
# Pretraining dataset from ./temp_pretrain, augmented like the training set.
d_pretrain = ModelGraphDataset("./temp_pretrain", augment=augment_model_info)
# Read the feature widths from d_pretrain itself (the copy-pasted version re-read `d`).
# The model below is built from these sizes, so pretrain and train data must share them.
data = d_pretrain[0][1]
var_feature_size = data.var_node_features.size(-1)
con_feature_size = data.con_node_features.size(-1)

In [46]:
root = "temp_valid"
mdl_paths = sorted(p for p in os.listdir(root) if p.endswith(".lp"))
sol_paths = sorted(p for p in os.listdir(root) if p.endswith(".npz"))

# Delete models that have no solution file (irreversible) so the dataset sees matched pairs only.
sol_set = set(sol_paths)
orphans = [p for p in mdl_paths if p.replace(".lp", ".npz") not in sol_set]
for p in orphans:
    os.remove(os.path.join(root, p))

In [47]:
# Validation dataset: no augmentation (augment defaults to None).
valid_d = ModelGraphDataset("./temp_valid")

In [48]:
import pandas as pd
# Model configurations indexed by the sheet's first column (presumably one row per configuration).
cfgs = pd.read_excel("trained_models/setcover_model_configs.xlsx", index_col=0)

In [49]:
# Configuration row 0 as a plain dict, with the epoch count overridden for a quick run.
config = {**cfgs.loc[0].T.to_dict(), "num_epochs": 1}

In [50]:
from utils import get_model

In [51]:
# utils.get_model builds the model and its training objects, sized to the node-feature widths above.
model_name, model, criterion, optimizer, scheduler = get_model(".", var_feature_size, con_feature_size, n_batches=1, **config)

In [52]:
from trainer import train

In [53]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed numpy and ``random`` from torch's seed for this process.

    Args:
        worker_id: worker index supplied by the DataLoader (unused; torch already derives a
            distinct ``initial_seed`` per worker).
    """
    # torch.initial_seed() may be a 64-bit value, but np.random.seed only accepts [0, 2**32).
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

from torch_geometric.loader import DataLoader
def _make_loader(dataset):
    # A fresh generator per loader, seeded with 0, keeps each loader's shuffle order reproducible.
    return DataLoader(dataset, batch_size=8, shuffle=True, worker_init_fn=seed_worker,
                      generator=torch.Generator().manual_seed(0))


pretrain_loader = _make_loader(d_pretrain)
train_loader = _make_loader(d)
# NOTE(review): shuffling validation data is unnecessary for aggregate metrics; consider shuffle=False.
val_loader = _make_loader(valid_d)

In [54]:
# NOTE(review): overrides the scheduler's step budget after construction; 10000 is a magic number
# (presumably sized for pretraining + training steps). Confirm the scheduler reads this attribute.
scheduler.total_steps = 10000

In [None]:
# Pretrain, then train and validate. The positional False and "./" map to trainer.train's
# remaining parameters; see its signature for their meaning.
train(model_name, model, criterion, optimizer, scheduler, pretrain_loader, train_loader, val_loader, config, False, "./")

>> Training starts on the current device cpu
>> Pretraining for prenorm...


100%|████████████████████████| 132/132 [00:01<00:00, 131.08it/s]
100%|████████████████████████| 132/132 [00:00<00:00, 216.40it/s]
100%|████████████████████████| 132/132 [00:00<00:00, 161.25it/s]
100%|█████████████████████████| 132/132 [00:01<00:00, 68.57it/s]
100%|█████████████████████████| 132/132 [00:03<00:00, 41.31it/s]
100%|█████████████████████████| 132/132 [00:04<00:00, 31.47it/s]
100%|█████████████████████████| 132/132 [00:05<00:00, 23.68it/s]
100%|█████████████████████████| 132/132 [00:06<00:00, 20.45it/s]
100%|█████████████████████████| 132/132 [00:07<00:00, 16.60it/s]
100%|█████████████████████████| 132/132 [00:09<00:00, 13.71it/s]
100%|█████████████████████████| 132/132 [00:10<00:00, 12.76it/s]
100%|█████████████████████████| 132/132 [00:13<00:00,  9.70it/s]
100%|█████████████████████████| 132/132 [00:16<00:00,  8.10it/s]
100%|█████████████████████████| 132/132 [00:16<00:00,  8.04it/s]
100%|█████████████████████████| 132/132 [00:19<00:00,  6.65it/s]
100%|████████████████████

>> Epoch 1 ----------------------------------------------------------------------------------------------------
Training... 0


  4%|█                        | 58/1364 [00:43<24:14,  1.11s/it]

In [None]:
# Deliberate stop so "Run All" halts here; the following cells are meant to be run by hand.
# An explicit raise (unlike `assert`) still fires when Python runs with -O and states its intent.
raise RuntimeError("Intentional stop: run the cells below manually.")

In [None]:
large_m = setcover(n_rows=200, n_cols=200)
large_m.update()
# ModelGraphDataset has no `inst_to_data`: convert to ModelInfo first, then use the
# `info_to_data` static method (no solutions attached, so the graph is unlabelled).
large_d = ModelGraphDataset.info_to_data(ModelInfo.from_model(large_m))

In [None]:
# Deliberate stop so "Run All" halts here; the following cells are meant to be run by hand.
# An explicit raise (unlike `assert`) still fires when Python runs with -O and states its intent.
raise RuntimeError("Intentional stop: run the cells below manually.")

In [None]:
# Exploratory forward pass on the large instance.
logits = model(large_d)
logits

In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from tempfile import NamedTemporaryFile


def _clamped_exp(y):
    # Clamp before exponentiating so large logits cannot overflow to inf.
    return torch.exp(torch.clamp(y, -10, 10))


# Name -> callable mapping logits to evidence; selected by get_uncertainty(evidence_func_name=...).
EVIDENCE_FUNCS = {
    "softplus": F.softplus,
    "relu": F.relu,
    "exp": _clamped_exp,
}


def to_numpy(tensor_obj):
    """Detach ``tensor_obj`` from autograd, move it to the CPU and return it as a numpy array."""
    detached = tensor_obj.detach()
    return detached.cpu().numpy()


def get_predictions(logits, is_binary=None):
    """Turn per-variable class logits into class probabilities and point predictions.

    Args:
        logits: tensor of shape (n_vars, n_classes) with n_classes >= 2 (column 1 is read).
        is_binary: boolean mask (tensor or array) over the n_vars rows marking binary variables,
            whose predictions are rounded to 0/1. Defaults to the notebook-global
            ``data.is_binary`` (the behaviour the original implementation hard-wired).

    Returns:
        (probs, preds): ``probs`` is the (n_vars, n_classes) softmax as a numpy array; ``preds``
        is a 1-D numpy array with the class-1 probability, rounded for binary variables.
    """
    if is_binary is None:
        is_binary = data.is_binary
    if torch.is_tensor(is_binary):
        is_binary = is_binary.detach().cpu().numpy()
    # Cast to bool so the mask selects rows instead of acting as integer indices.
    binary_mask = np.asarray(is_binary).astype(bool).reshape(-1)

    probs = torch.softmax(logits, dim=1).detach().cpu().numpy()
    # Copy: a view would make the rounding below silently overwrite probs[:, 1] as well.
    preds = probs[:, 1].copy()
    preds[binary_mask] = preds[binary_mask].round()
    return probs, preds


def get_uncertainty(logits, evidence_func_name: str="softplus"):
    """Per-row uncertainty ``K / sum_k(evidence_k + 1)``, where ``K = logits.shape[1]``.

    Evidence is ``EVIDENCE_FUNCS[evidence_func_name](logits)``. For 2-D logits of shape (N, K),
    the result has shape (N, 1) (the sum keeps its dimension).
    """
    to_evidence = EVIDENCE_FUNCS[evidence_func_name]
    alpha = to_evidence(logits) + 1
    n_classes = logits.shape[1]
    return n_classes / alpha.sum(dim=1, keepdim=True)


def get_threshold(uncertainty: torch.Tensor, r_min: float=0.4, r_max: float=0.55):
    """Pick an uncertainty cut-off so that roughly r_min..r_max of the entries fall at or below it.

    Starts from the midpoint quantile ``(r_min + r_max) / 2``. If ties push the fraction at or
    below that value outside [r_min, r_max], the r_max (resp. r_min) quantile is returned instead;
    with heavy ties that fallback can still be outside the range.

    Args:
        uncertainty: 1-D tensor or array-like of uncertainties (converted to a floating tensor).
        r_min, r_max: target range for the fraction of entries <= threshold.

    Returns:
        0-d tensor threshold.
    """
    if not torch.is_tensor(uncertainty):
        uncertainty = torch.as_tensor(uncertainty)
    if not uncertainty.is_floating_point():
        # torch.quantile requires a floating dtype.
        uncertainty = uncertainty.float()

    threshold = torch.quantile(uncertainty, (r_min + r_max) / 2)
    ratio = (uncertainty <= threshold).float().mean()

    if ratio > r_max:
        return torch.quantile(uncertainty, r_max)
    if ratio < r_min:
        return torch.quantile(uncertainty, r_min)
    return threshold


def get_confident_idx(indices, uncertainty, threashold):
    """Return, sorted as a list, the entries of ``indices`` whose uncertainty is <= ``threashold``.

    ``indices`` and ``uncertainty`` must be aligned arrays/tensors that support boolean-mask
    indexing. (The parameter keeps its historical spelling for keyword-argument callers.)
    """
    return sorted(indices[uncertainty <= threashold])


def fix_var(inst, idxs, vals):
    """Fix variables ``idxs`` at ``vals`` (lb = ub = value) and return their previous bounds.

    Args:
        inst: model exposing ``getVars()``/``update()``; variables expose ``lb``/``ub`` and ``setAttr``.
        idxs: variable indices to fix.
        vals: values, aligned with ``idxs``.

    Returns:
        dict ``{idx: (old_lb, old_ub)}``, suitable for restoring the bounds later.

    Raises:
        AssertionError: if ``idxs`` and ``vals`` differ in length.
    """
    assert len(idxs) == len(vals)
    all_vars = inst.getVars()
    saved = {}
    for idx, val in zip(idxs, vals):
        var = all_vars[idx]
        saved[idx] = (var.lb, var.ub)
        for attr in ("lb", "ub"):
            var.setAttr(attr, val)
    inst.update()
    return saved
        

def unfix_var(inst, idxs, bounds):
    """Restore the bounds of variables ``idxs`` (the inverse of ``fix_var``).

    Args:
        inst: model exposing ``getVars()``/``update()``; variables accept ``setAttr("lb"/"ub", ...)``.
        idxs: variable indices to restore.
        bounds: either the ``{idx: (lb, ub)}`` dict returned by ``fix_var`` (it may hold more
            indices than ``idxs``) or a sequence of ``(lb, ub)`` pairs aligned with ``idxs``.

    Raises:
        KeyError: if ``bounds`` is a dict missing one of ``idxs``.
        AssertionError: if a sequence ``bounds`` is not aligned with ``idxs``.
    """
    if isinstance(bounds, dict):
        pairs = [bounds[idx] for idx in idxs]
    else:
        assert len(idxs) == len(bounds)
        pairs = list(bounds)

    vs = inst.getVars()
    for idx, (lb, ub) in zip(idxs, pairs):
        v = vs[idx]
        v.setAttr("lb", lb)
        v.setAttr("ub", ub)
    # Mirror fix_var, which also flushes its bound changes.
    inst.update()
        

def solve(inst):
    """Optimize ``inst`` and return the ``X`` value of every variable, in ``getVars()`` order."""
    decision_vars = inst.getVars()
    inst.optimize()
    solution = inst.getAttr("X", decision_vars)
    return solution


def get_iis_vars(inst):
    """Compute an IIS of ``inst`` and return the tokens of its ``.ilp`` export.

    The returned set contains every whitespace-separated token of the written IIS file:
    variable and constraint names plus LP keywords and coefficients. Callers are expected to
    filter it against known variable names.

    Returns:
        set of str tokens; empty when the model is feasible.

    Raises:
        Exception: any error from ``computeIIS`` other than the feasible-model case is re-raised.
    """
    from tempfile import TemporaryDirectory

    try:
        inst.computeIIS()
    except Exception as e:
        if "Cannot compute IIS on a feasible model" in str(e):
            return set()
        raise

    # Write to a fresh path in a private directory: portable (a NamedTemporaryFile cannot be
    # reopened by name on Windows) and read back through a new handle after the writer finished.
    with TemporaryDirectory() as tmp_dir:
        ilp_path = os.path.join(tmp_dir, "iis.ilp")
        inst.write(ilp_path)
        with open(ilp_path) as f:
            return set(f.read().split())


def repair(inst, fixed: set, bounds: dict):
    """Release fixed variables that appear in an IIS until the model is feasible or no progress is made.

    Repeatedly computes an IIS via ``get_iis_vars`` and restores the original bounds of every
    still-fixed variable named in it. Stops when the model is feasible, or when an IIS contains
    no releasable fixed variable: the infeasibility then does not come from the fixings, and
    looping would never terminate. The model may remain infeasible in that case.

    Args:
        inst: Gurobi model, modified in place. Its ``IISMethod`` parameter is set to 0 during
            repair and restored afterwards, even on error.
        fixed: indices of fixed variables (any container supporting ``in``; copied into a set).
        bounds: ``{idx: (lb, ub)}`` original bounds, as returned by ``fix_var``.

    Returns:
        set of variable indices whose bounds were restored.
    """
    fixed = set(fixed)
    old_iis_method = inst.Params.IISMethod
    inst.setParam("IISMethod", 0)

    vs = inst.getVars()
    name_to_idx = {n: i for i, n in enumerate(inst.getAttr("varName", vs))}

    freed = set()
    try:
        while iis_names := get_iis_vars(inst):
            newly_freed = [
                name_to_idx[n] for n in iis_names
                if n in name_to_idx and name_to_idx[n] in fixed and name_to_idx[n] not in freed
            ]
            if not newly_freed:
                break
            for var_idx in newly_freed:
                lb, ub = bounds[var_idx]
                vs[var_idx].lb = lb
                vs[var_idx].ub = ub
                freed.add(var_idx)
            inst.update()
    finally:
        inst.setParam("IISMethod", old_iis_method)
    return freed
    

def set_warmstarts(inst, starts):
    """Give variables MIP start values without restricting the model.

    Args:
        inst: model exposing ``getVars()``; variables accept ``setAttr("Start", value)``.
        starts: mapping variable index -> start value.
    """
    vs = inst.getVars()
    for i, s in starts.items():
        # "Start" only suggests a value. Writing "lb" would tighten the feasible region.
        vs[i].setAttr("Start", s)


def get_priorities(uncertainty, indices):
    """Placeholder, not implemented: currently returns None.

    NOTE(review): presumably meant to derive per-variable priorities from ``uncertainty`` for the
    variables in ``indices``; confirm the intent before implementing.
    """
    ...


def set_priority(inst, priorities: dict):
    """Placeholder, not implemented: currently does nothing and returns None.

    NOTE(review): presumably meant to apply ``priorities`` (variable index -> priority) to the
    variables of ``inst``; confirm the intent before implementing.
    """
    ...


def reduce_by_uncertainty(inst, prediction, uncertainty, indices, max_iter, timelimit):
    """Solve ``inst`` after fixing confidently predicted variables, relaxing fixings between solves.

    Steps:
      1. Fix candidates whose uncertainty is at or below ``get_threshold`` to their prediction.
      2. Release fixings that appear in an IIS (``repair``).
      3. Solve ``max_iter`` times. Between solves, unfix variables that leave the confident set
         and warm-start them (and the repaired ones) from the previous solution.

    Args:
        inst: Gurobi model, modified in place (bounds, Start values, TimeLimit).
        prediction: predicted value per variable, indexed by variable index.
        uncertainty: uncertainty per candidate, aligned with ``indices``.
        indices: variable indices of the candidates.
        max_iter: number of solves; must be >= 1.
        timelimit: TimeLimit (seconds) applied to each solve, or None to keep the model's setting.

    Returns:
        list of variable values from the final solve.

    Raises:
        ValueError: if ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    prediction = np.asarray(prediction)
    uncertainty = np.asarray(uncertainty, dtype=float)
    indices = np.asarray(indices)
    if timelimit is not None:
        inst.setParam("TimeLimit", timelimit)

    threshold = float(get_threshold(torch.as_tensor(uncertainty)))
    conf_idxs = get_confident_idx(indices, uncertainty, threshold)
    bounds = fix_var(inst, conf_idxs, prediction[conf_idxs])

    fixed = set(conf_idxs)
    freed = set(repair(inst, fixed, bounds))
    fixed -= freed

    min_q = float(np.mean(uncertainty <= threshold))
    max_q = 1.0
    # NOTE(review): q runs from max_q - dq down to min_q, i.e. thresholds at or above the initial
    # one. The confident set therefore always contains `fixed`, and this loop appears to never
    # unfix anything. Presumably the schedule should go below min_q; confirm the intended direction.
    dq = (max_q - min_q) / (max_iter - 1) if max_iter > 1 else 0.0

    sol = solve(inst)
    for step in range(1, max_iter):
        q = max_q - dq * step
        threshold = np.quantile(uncertainty, q)
        still_confident = set(get_confident_idx(indices, uncertainty, threshold))
        to_unfix = sorted(fixed - still_confident)
        unfix_var(inst, to_unfix, [bounds[i] for i in to_unfix])
        fixed -= set(to_unfix)
        starts = {i: sol[i] for i in to_unfix}
        starts.update({i: sol[i] for i in freed})
        set_warmstarts(inst, starts)
        # Re-solve so the last unfix/warm start actually takes effect.
        sol = solve(inst)
    return sol

In [None]:
from learn.info import ModelInfo
from learn import solver

In [None]:
import json

# Licensed Gurobi environments, one per license-parameter file, reused across calls.
_LIC_ENVS = {}


def with_lic(m, lic_path="gb.lic"):
    """Return a copy of model ``m`` attached to a Gurobi environment configured from ``lic_path``.

    ``lic_path`` is parsed as JSON and passed as ``gp.Env(params=...)``; presumably it holds
    license credentials, so keep it out of version control. The environment is created once per
    path and cached, rather than opening a new Env (and license session) on every call that is
    never released. ``m`` is not modified; the caller owns the returned copy and should
    ``dispose()`` it when done.
    """
    env = _LIC_ENVS.get(lic_path)
    if env is None:
        with open(lic_path) as f:
            env = gp.Env(params=json.load(f))
        _LIC_ENVS[lic_path] = env
    return m.copy(env=env)

In [None]:
import gurobipy as gp
# Load a saved LP model and extract its data for partitioning (rebinds the globals `m` and `info`).
m = gp.read("model_11.lp")
info = ModelInfo.from_model(m)

In [None]:
# Bipartite edge list [var_i, n_var + con_i]: constraint nodes are numbered after all variables.
edges = [
    [var_i for lhs_p in info.con_info.lhs_p for var_i in lhs_p],
    [con_i + info.var_info.n
     for con_i, lhs_p in enumerate(info.con_info.lhs_p)
     for _ in lhs_p],
]

In [None]:
parts = solver.fennel_partition(edges, 4, 0.3, 1)
# info.subset(var_idxs) returns (sub_info, sub_mapping); keep both halves in parallel lists.
subsets = [info.subset(var_idxs) for var_idxs, _ in parts]
sub_infos = [sub_info for sub_info, _ in subsets]
sub_mappings = [sub_mapping for _, sub_mapping in subsets]

In [None]:
# Solve every sub-problem under the licensed env and scatter its values back to full-model indices
# (sub_mapping maps sub-model index -> full-model index).
stitch_x = [0 for _ in range(info.var_info.n)]
for sub_info, sub_mapping in zip(sub_infos, sub_mappings):
    partial_m, _ = solver.build_partial_model(sub_info)
    cur_m = with_lic(partial_m)
    # The licensed copy is independent, so the unlicensed original can be released right away.
    partial_m.dispose()
    cur_m.optimize()
    if cur_m.SolCount == 0:
        cur_m.dispose()
        raise RuntimeError("a sub-problem has no feasible solution; cannot stitch")
    cur_x = cur_m.getAttr("X", cur_m.getVars())
    cur_m.dispose()
    for new_i, old_i in sub_mapping.items():
        stitch_x[old_i] = cur_x[new_i]

In [None]:
import random

predictions, idxs, uncertainty = [], [], []

# Keep only binary variables; the uncertainty is a random placeholder for now.
for i, (v, x) in enumerate(zip(m.getVars(), stitch_x)):
    if v.vtype == "B":
        predictions.append(x)
        idxs.append(i)
        uncertainty.append(random.random())
    

In [None]:
# Randomly keep about 5% of the binary candidates to fix at their stitched value.
fixed_pairs = [(i, p) for i, p in zip(idxs, predictions) if random.random() <= 0.05]
fixed_idxs = [i for i, _ in fixed_pairs]
fixed_vals = [p for _, p in fixed_pairs]
        

In [None]:
# Re-create the model in the licensed environment. NOTE(review): the previous `m` is not disposed.
m = with_lic(m)

In [None]:
# Fix the sampled variables, keeping their original (lb, ub) so repair can restore them.
bounds = fix_var(m, fixed_idxs, fixed_vals)


In [None]:
# repair only tests membership in `fixed`, so pass a set rather than the list.
freed = repair(m, set(fixed_idxs), bounds)

In [None]:
# Solve with the remaining (repaired) fixings.
m.optimize()

In [None]:
# Re-solve with heuristic-oriented parameters and a 120 s time limit.
# NOTE(review): NoRelHeurTime (128) exceeds TimeLimit (120); confirm this is intended.
for param_name, param_value in {"MIPFocus": 1, "RINS": 10, "TimeLimit": 120, "NoRelHeurTime": 128}.items():
    m.setParam(param_name, param_value)
m.optimize()

In [None]:
# Scratch check: dict.update mutates in place and returns None, so this cell displays nothing.
{1:2}.update({3:4})

In [None]:
# Number of fixed variables whose bounds repair() had to restore.
len(freed)

In [None]:
# Number of variables that were fixed before repair.
len(fixed_idxs)

In [None]:
# Deliberate stop so "Run All" halts here; the following cells are meant to be run by hand.
# An explicit raise (unlike `assert`) still fires when Python runs with -O and states its intent.
raise RuntimeError("Intentional stop: run the cells below manually.")