In [1]:
import torch
import gurobipy as gp

In [2]:
from learn.train import build_inst, build_graphs, remove_redundant_nodes, get_train_mask, get_solution_mask, get_mask_node_feature
from learn.model import Model, FocalLoss
from learn.generator import maximum_independent_set_problem
from learn.info import ModelInfo, VarInfo, ConInfo
from learn import solver

In [3]:
import json

def with_lic(m):
    with open("gb.lic") as f:
        env = gp.Env(params=json.load(f))
    return m.copy(env=env)

In [4]:
def get_constraint_side_matrices(var_info: VarInfo, con_info: ConInfo):

    idxs = [[], []]
    vals = []

    for con_idx in tqdm(range(con_info.n)):
        var_idxs = con_info.lhs_p[con_idx]
        var_cefs = con_info.lhs_c[con_idx]
        for var_idx, var_cef in zip(var_idxs, var_cefs):
            idxs[0].append(con_idx)
            idxs[1].append(var_idx)
            vals.append(var_cef)

    lhs = torch.sparse_coo_tensor(idxs, vals, (con_info.n, var_info.n))
    rhs = np.array(con_info.rhs)[:, np.newaxis]
    ops = np.array(con_info.types)[:, np.newaxis]
    return lhs, rhs, ops


def get_constraint_violations(lhs, vs, rhs, ops):
    # TODO: add type handling/preprocessing
    vs = torch.as_tensor(np.array(vs)[:, np.newaxis]).float()
    lt_ops = ops == ConInfo.ENUM_TO_OP["<="]
    eq_ops = ops == ConInfo.ENUM_TO_OP["=="]
    gt_ops = ops == ConInfo.ENUM_TO_OP[">="]

    lhs_vs = lhs @ vs
    lhs_vs = lhs_vs.numpy()

    diff = lhs_vs - rhs
    violations = np.zeros_like(diff, dtype=bool)
    violations[lt_ops] = diff[lt_ops] <= 0
    violations[gt_ops] = diff[gt_ops] >= 0
    violations[eq_ops] = diff[eq_ops] == 0
    diff[violations] = 0
    return np.abs(diff)


In [10]:
torch.tensor

<function torch._VariableFunctionsClass.tensor>

In [5]:
from tqdm import tqdm
import numpy as np 

In [6]:
def collect_feasible_solutions(model, n=64):
    model.setParam('PoolSolutions', n)
    model.setParam('PoolSearchMode', 2)
    model.optimize()
    
    vs = model.getVars()
    sols = []
    for i in tqdm(range(model.SolCount)):
        # TODO setting solution number actually takes long time
        # try to optimize
        model.params.SolutionNumber = i
        obj_val = model.PoolObjVal
        s = [v.Xn for v in vs]
        sols.append((s, obj_val))
    return sols
        
m = maximum_independent_set_problem(num_nodes=128)
info = ModelInfo.from_model(m)

lhs, rhs, ops = get_constraint_side_matrices(info.var_info, info.con_info)

m = with_lic(m)
m.params.NoRelHeurTime = 30
s = collect_feasible_solutions(m)

Restricted license - for non-production use only - expires 2026-11-23


100%|███████████████████| 2489/2489 [00:00<00:00, 709117.15it/s]

Set parameter CloudAccessID
Set parameter CloudSecretKey
Set parameter CloudPool to value "831775-C3Dev"
Set parameter CSAppName to value "Josh"





Compute Server job ID: 5cd76ef2-93cc-4129-9e65-038202ebccbf
Capacity available on '831775-C3Dev' cloud pool - connecting...
Established HTTPS encrypted connection
Set parameter NoRelHeurTime to value 30
Set parameter PoolSolutions to value 64
Set parameter PoolSearchMode to value 2
Gurobi Optimizer version 12.0.0 build v12.0.0rc1 (mac64[x86] - Darwin 22.4.0 22E252)
Gurobi Compute Server Worker version 12.0.0 build v12.0.0rc1 (linux64 - "Ubuntu 20.04.6 LTS")

CPU model: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz, instruction set [SSE2|AVX|AVX2|AVX512]
Thread count: 8 physical cores, 16 logical processors, using up to 16 threads

Non-default parameters:
NoRelHeurTime  30
CSIdleTimeout  1800
PoolSolutions  64
PoolSearchMode  2

Optimize a model with 2489 rows, 129 columns and 4978 nonzeros
Model fingerprint: 0x0700a2ed
Variable types: 1 continuous, 128 integer (128 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [1e+00, 1e+00]
  Bounds range     [1e

100%|███████████████████████████| 64/64 [00:14<00:00,  4.57it/s]


In [48]:
def sv(sol):
    violation = get_constraint_violations(lhs, sol, rhs, ops) 
    return not any(violation)

In [49]:
def perturb_solution(solution, validator, n=32, ratio=0.1):
    
    n_vars = len(solution)
    rand_idxs = np.random.randint(0, len(solution), size=max(int(n_vars * ratio), 1))
    
    pos = []
    neg = []

    for _ in tqdm(range(n)):
        cur_pos = []
        cur_neg = []

        s = solution
        p = []
        for i in rand_idxs:
            
            s = s.copy()
            s[i] = 1 - s[i]
            p.append(i)
            perturbed = (s, p.copy())
            
            if cur_neg:
                cur_neg.append(perturbed)
                continue
                
            valid = validator(s)
            if not valid:
                cur_neg.append(perturbed)
            else:
                cur_pos.append(perturbed)

        pos.extend(cur_pos)
        neg.extend(cur_neg)
        
    return pos, neg
        
    

In [50]:
for each in s:
    pos, neg = perturb_solution(each[0], sv, n=1000)

100%|█████████████████████| 1000/1000 [00:00<00:00, 3212.47it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 1209.14it/s]


1000


100%|█████████████████████| 1000/1000 [00:00<00:00, 1890.59it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3851.52it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 1766.90it/s]


1000


100%|█████████████████████| 1000/1000 [00:00<00:00, 1317.32it/s]


1000


100%|█████████████████████| 1000/1000 [00:00<00:00, 3367.17it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3816.40it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2927.63it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 4474.57it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3960.79it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2451.24it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 4152.55it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 4302.22it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2610.54it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3704.49it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2207.72it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2008.65it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 4696.68it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2848.51it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3895.38it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3566.72it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2648.97it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2991.83it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2873.12it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 2247.04it/s]


0


100%|█████████████████████| 1000/1000 [00:00<00:00, 3930.88it/s]


0


 76%|████████████████▊     | 762/1000 [00:00<00:00, 4405.27it/s]


KeyboardInterrupt: 

In [45]:
len(pos)

0

In [8]:
assert 1 == 2

AssertionError: 

In [None]:
s[0][0] == s[1][0]

In [None]:

m.setParam("PoolSolutions", int(1e5))
m.optimize()

In [None]:
# collect solutions
# random pertub to get feasible - infeasible solution
    # continue infeasible solution
    # 

In [None]:
%%capture
inst = build_inst(maximum_independent_set_problem, 4096)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

In [None]:
mask_feat_size = 2
n_node_feats = graphs[0].ndata['feat'].shape[1] + mask_feat_size
n_edge_feats = graphs[0].edata['feat'].shape[1]
num_classes = int(graphs[0].ndata['label'].max()) + 1
hidden_size = 256

In [None]:
model = Model(n_node_feats, n_edge_feats, hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLabelFocalLoss(nn.Module):
    """
    Multi-label Focal Loss for each class independently.
    
    Args:
        alpha (float): Weighting factor for positive examples. Default: 0.25 (commonly 0.25 or 0.5).
        gamma (float): Exponent for down-weighting easy examples. Default: 2.0.
        reduction (str): 'mean', 'sum' or 'none'. Default: 'mean'.

    Shapes:
        pred: (batch_size, num_classes) - raw, unnormalized logits
        target: (batch_size, num_classes) - multi-label targets in {0,1}

    Example usage:
        criterion = MultiLabelFocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
        logits = torch.randn(batch_size, num_classes)  # model outputs
        targets = torch.randint(0, 2, (batch_size, num_classes)).float()
        loss = criterion(logits, targets)
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(MultiLabelFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, probs, target):
        # pred: (m, n) raw logits
        # target: (m, n) in {0, 1}
        
        # Apply sigmoid to get probabilities
        # probs = torch.sigmoid(pred)  # (m, n)
        # Add a small epsilon to avoid log(0)
        eps = 1e-8

        # Calculate the focal loss
        # Positive term
        pos_loss = -self.alpha * (1 - probs).pow(self.gamma) * target * torch.log(probs + eps)
        # Negative term
        neg_loss = -(1 - self.alpha) * probs.pow(self.gamma) * (1 - target) * torch.log(1 - probs + eps)

        loss = pos_loss + neg_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


In [None]:
import random
from torch.nn import MSELoss

print("Total number of graphs", len(graphs))
num_epochs = 5000
for epoch in range(num_epochs):
    
    model.train()
    optimizer.zero_grad()
    
    cntr = 0
    loss = 0

    random.shuffle(graphs)
    for i, g in enumerate(graphs):

        train_mask = get_train_mask(g, ratio=1.0)
        solution_mask = get_solution_mask(train_mask, (0.2, 0.8))
        node_feat_with_hint, hint_mask = get_mask_node_feature(g.ndata['feat'], g.ndata['label'], solution_mask)
        
        probs, dists = model(g, node_feat_with_hint, g.edata['feat'])
        labels = g.ndata['label']
        distances = g.ndata['distance'].float()

        n_vars = g.ndata['feat'][:, 2].sum().int()
        
        fkl = MultiLabelFocalLoss()(
            probs[:n_vars][~hint_mask[:n_vars]], 
            labels[:n_vars][~hint_mask[:n_vars]]
        )
        msl = MultiLabelFocalLoss()(
            dists[:n_vars][hint_mask[:n_vars]], 
            distances[:n_vars][hint_mask[:n_vars]]
        )
        loss += fkl + msl
        cntr += 1

        if cntr == 256:
            print("loss", loss.detach().numpy())
            print("#"*78)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            loss = 0
            cntr = 0

    print("-"*78)
    print(probs[:n_vars][~hint_mask[:n_vars]].detach().numpy())
    print(labels[:n_vars][~hint_mask[:n_vars]].detach().numpy())
    print(">"*78)
    print(dists[:n_vars][hint_mask[:n_vars]].detach().numpy())
    print(distances[:n_vars][hint_mask[:n_vars]].detach().numpy())
    print('^'*78)

In [None]:
m = maximum_independent_set_problem(num_nodes=128, edge_prob=0.5)
m.optimize()

In [None]:
inst = build_inst(lambda: maximum_independent_set_problem(num_nodes=128), 1)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

In [None]:
MODEL[0]
model_vars = MODEL[0].getVars()
model_vars[56].BranchPriority = 3
model_vars[42].BranchPriority = 2
model_vars[2].BranchPriority = 1

In [None]:
MODEL[0].optimize()

In [None]:
with open("gb.lic") as f:
    env = gp.Env(params=json.load(f))

In [None]:
import math
import random

import gurobipy as gp
import numpy as np

np.random.seed(0)
random.seed(0)

MODEL = []
def maximum_independent_set_problem(
    num_nodes=128,
    edge_prob=0.3,
) -> gp.Model:
    edges = []
    num_nodes = random.randint(num_nodes-10, num_nodes)
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if np.random.rand() < edge_prob:
                edges.append((i, j))

    m = gp.Model("maximum_independent_set")
    x = m.addVars(num_nodes, vtype=gp.GRB.BINARY, name="x")

    for i, j in edges:
        m.addConstr(x[i] + x[j] <= 1, name=f"edge_{i}_{j}")

    m.setObjective(gp.quicksum(x[i] for i in range(num_nodes)), gp.GRB.MAXIMIZE)
    m.update()
    MODEL.append(m.copy())
    return m


In [None]:
inst = build_inst(maximum_independent_set_problem, 1, env=env)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

graph = graphs[0]
is_var_flag_idx = 2
var_flag = graph.ndata["feat"][:, is_var_flag_idx]
size = list(var_flag[var_flag == 1].size())[0]

In [None]:
# g.ndata['label'][:size]

In [None]:
from tqdm import tqdm
importances = {i: 0 for i in range(size)}

for trial_idx in tqdm(range(500)):
        
    prev_count = 0
    seeds = torch.zeros((g.ndata['feat'].shape[0], 2))
    trail_importance = {i: 0 for i in range(size)}
    while (seeds[:size, 1] == 0).any():
        
        for i in range(size//20):
            rand_var_idx = random.randint(0, size)
            rand_var_val = random.randint(0, 1)
            seeds[rand_var_idx, 0] = rand_var_val
            seeds[rand_var_idx, 1] = 1
            
        feat = torch.hstack([g.ndata['feat'], seeds])    
        probs, dists = model(g, feat, g.edata['feat'])

        new_vars = set()
        for i in range(size):
            if seeds[i, 1] == 1:
                continue
                
            if probs[i, 0] >= 0.9:
                seeds[i, 0] = 0
                seeds[i, 1] = 1
                new_vars.add(i)
                
            if probs[i, 1] >= 0.9:
                seeds[i, 0] = 1
                seeds[i, 1] = 1
                new_vars.add(i)
                
        curr_count = seeds[:size, 1].sum()
        if curr_count == prev_count:
            for k in trail_importance:
                trail_importance[k] = -trail_importance[k]
            break

        impv = curr_count - prev_count
        for nv in new_vars:
            trail_importance[nv] += impv/len(new_vars)
        
        prev_count = curr_count
        
    for k in importances:
        importances[k] += trail_importance[k] 

In [None]:
import json
import random
from functools import partial
from typing import Callable, List, Tuple, Union

import dgl
import gurobipy as gp
import numpy as np
import torch

from learn.feature import ConFeature, EdgFeature, VarFeature
from learn.info import ModelInfo

__DEVICE_PTR = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]


def SET_DEVICE(device):
    __DEVICE_PTR[0] = device


def GET_DEVICE():
    return __DEVICE_PTR[0]


class Inst:
    def __init__(
        self,
        v_features: List[VarFeature],
        c_features: List[ConFeature],
        e_features: List[EdgFeature],
        solutions: List[List[Union[float, int]]],
        distances: List[List[Union[float, int]]],
    ):
        assert len(v_features) == len(c_features) == len(e_features) == len(solutions)
        self.v_features = v_features
        self.c_features = c_features
        self.e_features = e_features
        self.solutions = solutions
        self.distances = distances

    @property
    def n(self):
        return len(self.solutions)

    # TODO: use consistent naming for getting the features
    @property
    def xs(self):

        n_var = []
        n_con = []
        c_v_edges = []
        v_c_edges = []
        node_features = []

        for i in range(self.n):

            # [v0, v1, v2, ... c0, c1, c2]
            var_xs = self.v_features[i].values
            con_xs = self.c_features[i].values

            # TODO: use dgl or pyg to remove the need of padding
            con_xs, var_xs = self._pad_features(con_xs, var_xs)
            n_var.append(len(var_xs))
            n_con.append(len(con_xs))
            xs = np.vstack([var_xs, con_xs])
            node_features.append(torch.as_tensor(xs, dtype=torch.float32))

            con_idxs, var_idxs = self.e_features[i].indices
            con_idxs, var_idxs = self._shift_idxs(con_idxs, var_idxs, len(var_xs))
            assert len(con_idxs) == len(var_idxs)

            cve = []
            vce = []

            n_edges = len(con_idxs)
            for i in range(n_edges):
                cve.append([con_idxs[i], var_idxs[i]])
                vce.append([var_idxs[i], con_idxs[i]])
            cve = torch.as_tensor(cve, dtype=torch.int)
            vce = torch.as_tensor(vce, dtype=torch.int)

            c_v_edges.append(cve)
            v_c_edges.append(vce)

        edge_features = [
            torch.as_tensor(f.values, dtype=torch.float32) for f in self.e_features
        ]
        return c_v_edges, v_c_edges, node_features, edge_features, n_var, n_con

    @staticmethod
    def _shift_idxs(con_idxs, var_idxs, n_vars):
        # [v0, v1, v2, ... c0, c1, c2]
        return con_idxs + n_vars, var_idxs

    @staticmethod
    def _pad_features(con_xs, var_xs):
        con_x_dim = con_xs.shape[1]
        var_x_dim = var_xs.shape[1]

        if con_x_dim < var_x_dim:
            pad_size = var_x_dim - con_x_dim
            con_xs = np.pad(
                con_xs,
                pad_width=((0, 0), (0, pad_size)),
                mode="constant",
                constant_values=0,
            )

        if var_x_dim < con_x_dim:
            pad_size = con_x_dim - var_x_dim
            var_xs = np.pad(
                var_xs,
                pad_width=((0, 0), (0, pad_size)),
                mode="constant",
                constant_values=0,
            )

        return con_xs, var_xs

    @property
    def ys(self):
        # TODO: handle other type of x
        values = []
        for i, s in enumerate(self.solutions):

            # TODO: use accessor method
            n_constr = len(self.c_features[i].values)

            # TODO: remove padding
            # TODO: replace with hetro-graph

            arr = np.array([[0, 1] if v == 1 else [1, 0] for v in s] + [[0, 0]] * n_constr)
            values.append(torch.as_tensor(arr, dtype=torch.int32))

        distances = []
        for i, d in enumerate(self.distances):

            # TODO: use accessor method
            n_constr = len(self.c_features[i].values)

            # TODO: remove padding
            # TODO: replace with hetro-graph
            arr = np.array(d + [[0, 0, 0]] * n_constr)
            distances.append(torch.as_tensor(arr, dtype=torch.int32))

        return values, distances


def get_train_mask(graph, ratio: float):
    """return mask for solution that can be included as hint"""
    assert 0.0 <= ratio <= 1.0
    is_var_flag_idx = 2
    var_flag = graph.ndata["feat"][:, is_var_flag_idx]
    size = list(var_flag[var_flag == 1].size())[0]
    n_include = int(round(size * ratio))
    mask = torch.zeros(size, dtype=torch.bool)
    idx = torch.randperm(size)[:n_include]
    mask[idx] = 1
    return mask


def get_solution_mask(
    mask: torch.Tensor, ratio: Union[float, Tuple[float, float]] = (0.5, 1.0)
) -> torch.Tensor:
    """get solution mask that has ratio between given ratio range that can be included as hint"""
    mask = mask.clone()
    ones_indices = torch.where(mask == 1)[0]
    ratio = (ratio, ratio) if isinstance(ratio, float) else ratio
    assert 0.0 <= ratio[0] <= 1.0 and 0.0 <= ratio[1] <= 1.0

    min_num_keep = int(round(len(ones_indices) * ratio[0]))
    max_num_keep = int(round(len(ones_indices) * ratio[1]))
    max_num_keep = max(min_num_keep, max_num_keep)
    num_keep = random.randint(min_num_keep, max_num_keep)

    if num_keep <= 0:
        mask[ones_indices] = 0
        return mask

    if num_keep >= len(ones_indices):
        return mask

    selected_indices = torch.randperm(len(ones_indices))[:num_keep]
    keep_indices = ones_indices[selected_indices]
    mask[ones_indices] = 0
    mask[keep_indices] = 1
    return mask


def get_mask_node_feature(node_feature, y, mask):
    """add mask and hint into feature"""
    node_feature_with_y = torch.hstack([node_feature, (y[:, 1] == 0).unsqueeze(1)])
    mask = torch.cat([mask, torch.zeros(len(y) - len(mask), dtype=torch.bool)])
    masked = node_feature_with_y.clone()
    masked[~mask, -1] = 0
    return torch.hstack([masked, mask.unsqueeze(1)]), mask


def build_graphs(inst):
    c_v_edges, v_c_edges, node_features, edge_features, _, _ = inst.xs
    ys, dists = inst.ys

    graphs = []
    for i in range(len(ys)):
        srcs = torch.cat([c_v_edges[i][:, 0], v_c_edges[i][:, 0]])
        dsts = torch.cat([c_v_edges[i][:, 1], v_c_edges[i][:, 1]])

        # TODO: replace with hetro-graph
        g = dgl.graph((srcs, dsts))
        g.ndata["feat"] = node_features[i]
        g.ndata["label"] = ys[i]
        g.ndata["distance"] = dists[i]
        g.edata["feat"] = torch.cat([edge_features[i], edge_features[i]])
        assert (g.in_degrees() == g.out_degrees()).all()
        graphs.append(g)

    return graphs

MODEL = []
def build_inst(model_generator: Callable[[], gp.Model], n=1024, env=None) -> Inst:

    # if env is None:
    #     with open("gb.lic") as f:
    #         params = json.load(f)
    #         env = gp.Env(params=params)

    var_feats = []
    con_feats = []
    edg_feats = []
    solutions = []
    distances = []

    for _ in range(n):
        raw_m = model_generator()
        MODEL.append(raw_m.copy())
        info = ModelInfo.from_model(raw_m)
        vf = VarFeature.from_info(info.var_info, info.obj_info)
        cf = ConFeature.from_info(info.con_info)
        ef = EdgFeature.from_info(info.con_info)
        
        m = raw_m if env is None else raw_m.copy(env=env) 
        m.update()

        ss = []
        vs = m.getVars()
        m.optimize(partial(_collect_mip_sol, variables=vs, collection=ss))

        final_s = [v.X for v in vs]
        if ss and ss[-1] != final_s:
            ss.append(final_s)

        for s in ss:
            var_feats.append(vf)
            con_feats.append(cf)
            edg_feats.append(ef)
            solutions.append(s)
            d = []
            for v1, v2 in zip(s, ss[-1]):
                if v1 == v2:
                    d.append([0, 1, 0])
                    continue
                if v1 < v2:
                    d.append([0, 0, 1])
                    continue
                if v1 > v2:
                    d.append([1, 0, 0])
                    continue
            distances.append(d)

        raw_m.dispose()
        m.dispose()

    return Inst(var_feats, con_feats, edg_feats, solutions, distances)


def remove_redundant_nodes(g) -> None:
    to_remove = (g.in_degrees() == 0).nonzero().reshape(-1).int()
    g.remove_nodes(to_remove)


# TODO: take the objective value into consideration and weight the sample
def _collect_mip_sol(
    model: gp.Model, where: int, variables: List, collection: List
) -> None:
    if where == gp.GRB.Callback.MIPSOL:
        s = model.cbGetSolution(variables)
        collection.append(s)


In [None]:
to_solve = with_lic(MODEL[0])
to_set = to_solve.getVars()

imp = 1
for vidx, _ in sorted(importances.items(), key=lambda tup: tup[1])[-len(to_set)//4: ]:
    to_set[vidx].branch_priority = imp
    imp += 1
to_solve.update()
    
    

In [None]:
to_solve.optimize()

In [None]:
import random
for i in range(size//6):
    rand_var_idx = random.randint(0, size)
    rand_var_val = random.randint(0, 1)
    seeds[rand_var_idx, 0] = rand_var_val
    seeds[rand_var_idx, 1] = 1

In [None]:
feat = torch.hstack([g.ndata['feat'], seeds])    
probs, dists = model(g, feat, g.edata['feat'])

In [None]:
probs[:size]

In [None]:
for i in range(size):
    if seeds[i, 1] == 1:
        continue
    if probs[i, 0] >= 0.9:
        seeds[i, 0] = 0
        seeds[i, 1] = 1
    if probs[i, 1] >= 0.9:
        seeds[i, 0] = 1
        seeds[i, 1] = 1

In [None]:
seeds[:size]

In [None]:
feat = torch.hstack([g.ndata['feat'], seeds])    
probs, dists = model(g, feat, g.edata['feat'])

In [None]:
probs[:size]

In [None]:
for i in range(size):
    if seeds[i, 1] == 1:
        continue
    if probs[i, 0] >= 0.9:
        seeds[i, 0] = 0
        seeds[i, 1] = 1
    if probs[i, 1] >= 0.9:
        seeds[i, 0] = 1
        seeds[i, 1] = 1

In [None]:
seeds[:size]