In [1]:
import torch
import gurobipy as gp

In [2]:
from learn.train import build_inst, build_graphs, remove_redundant_nodes, get_train_mask, get_solution_mask, get_mask_node_feature
from learn.model import Model, FocalLoss
from learn.generator import maximum_independent_set_problem

In [3]:
%%capture

inst = build_inst(maximum_independent_set_problem, 10)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

In [4]:
mask_feat_size = 2
n_node_feats = graphs[0].ndata['feat'].shape[1] + mask_feat_size
n_edge_feats = graphs[0].edata['feat'].shape[1]
num_classes = int(graphs[0].ndata['label'].max()) + 1
hidden_size = 64

In [5]:
model = Model(n_node_feats, n_edge_feats, hidden_size, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
import random

print("Total number of graphs", len(graphs))
num_epochs = 500
for epoch in range(num_epochs):
    
    model.train()
    optimizer.zero_grad()
    
    cntr = 0
    loss = 0

    random.shuffle(graphs)
    for i, g in enumerate(graphs):

        train_mask = get_train_mask(g, ratio=0.8)
        solution_mask = get_solution_mask(train_mask, (0.2, 1.0))
        node_feat_with_hint, hint_mask = get_mask_node_feature(g.ndata['feat'], g.ndata['label'], solution_mask)
        
        logits = model(g, node_feat_with_hint, g.edata['feat'])
        labels = g.ndata['label']

        n_vars = g.ndata['feat'][:, 2].sum().int()
        loss += FocalLoss()(
            logits[:n_vars][~hint_mask[:n_vars]], 
            labels[:n_vars][~hint_mask[:n_vars]]
        )
        cntr += 1

        if cntr == 256:
            print("loss", loss.detach().numpy())
            print("#"*78)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            loss = 0
            cntr = 0

    print("-"*78)
    print(logits[:n_vars][~hint_mask[:n_vars]].detach().numpy())
    print(labels[:n_vars][~hint_mask[:n_vars]].detach().numpy())
    print('^'*78)

Total number of graphs 2198


  node_x = F.softmax(node_x)


loss 41.254913
##############################################################################
loss 41.25874
##############################################################################
loss 40.09499
##############################################################################
loss 40.028175
##############################################################################
loss 40.33964
##############################################################################
loss 39.8017
##############################################################################
loss 40.958214
##############################################################################
loss 39.75825
##############################################################################
------------------------------------------------------------------------------
[[0.64065945 0.3593405 ]
 [0.6293926  0.37060735]
 [0.662728   0.337272  ]
 [0.6211132  0.3788868 ]
 [0.61756194 0.38243806]
 [0.6544853  0.3455147 ]
 [0.64304966 0.3569503 ]


In [4]:
from learn.info import ModelInfo
m = maximum_independent_set_problem()
m.optimize()
solution = [v.X for v in m.getVars()]
info = ModelInfo.from_model(m)

Gurobi Optimizer version 12.0.0 build v12.0.0rc1 (mac64[x86] - Darwin 22.4.0 22E252)

CPU model: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
Thread count: 8 physical cores, 16 logical processors, using up to 16 threads

Optimize a model with 18 rows, 14 columns and 36 nonzeros
Model fingerprint: 0x63956828
Variable types: 0 continuous, 14 integer (14 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [1e+00, 1e+00]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+00]
Found heuristic solution: objective 6.0000000
Presolve removed 18 rows and 14 columns
Presolve time: 0.00s
Presolve: All rows and columns removed

Explored 0 nodes (0 simplex iterations) in 0.01 seconds (0.00 work units)
Thread count was 1 (of 16 available processors)

Solution count 2: 8 6 

Optimal solution found (tolerance 1.00e-04)
Best objective 8.000000000000e+00, best bound 8.000000000000e+00, gap 0.0000%


In [18]:
type(info.con_info)

learn.info.ConInfo

In [22]:
lhs, vs, rhs, ops = get_side_matrices(solution, info.con_info)


In [25]:
lhs @ vs - rhs

array([[ 0.],
       [-1.],
       [-1.],
       [ 0.],
       [-1.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [-1.],
       [ 0.],
       [ 0.],
       [ 0.],
       [-1.],
       [ 0.],
       [ 0.]])

In [26]:
info.con_info.types

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [None]:
# optim.nsolver

In [None]:
def deep_solve(problem, solution=None):
    sub_problems = partition(problem)    
    for p in sub_problems:
        
        if has_solution(p):
            seed = solution[p]
            break
            
        if is_small(p):
            seed = exact_solve(p)
            break

        seed = deep_solve(problem)
    
    infer(seed)
    solution = recursive_cross()
    return solution
    

    
        
        
    
    
        
    
    
        
        

In [62]:
# TODO: normalize violation by constraint type
import numpy as np


def get_constraint_side_matrices(x, con_info):
    
    n_con = len(con_info.lhs_p)
    n_var = len(solution)
    lhs = np.zeros((n_con, n_var))
    
    for con_idx in range(n_con):
        var_idxs = con_info.lhs_p[con_idx]
        var_cefs = con_info.lhs_c[con_idx]
        for var_idx, var_cef in zip(var_idxs, var_cefs):
            lhs[con_idx][var_idx] = var_cef

    var_vals = np.array(solution)[:, np.newaxis]
    rhs = np.array(con_info.rhs)[:, np.newaxis]
    return lhs, var_vals, rhs


def get_constraint_violations(x, con_info):
    lhs, vs, rhs = get_constraint_side_matrices(x, con_info)
    ops = np.array(con_info.types)[:, np.newaxis]
    lt_ops = ops == con_info.ENUM_TO_OP["<="]
    eq_ops = ops == con_info.ENUM_TO_OP["=="]
    gt_ops = ops == con_info.ENUM_TO_OP[">="]
    
    diff = rhs - lhs @ vs
    violations = np.zeros_like(diff, dtype=bool)
    violations[lt_ops] = diff[lt_ops] > 0
    violations[gt_ops] = diff[gt_ops] < 0
    violations[eq_ops] = diff[eq_ops] != 0
    diff[violations] = 0
    return np.abs(diff)
    

def freeze_variables(x, con_info, var_info):
    violations = get_constraint_violations(x, con_info)
    
    



    
    

In [73]:
solution[0] = 1
solution[2] = 1
solution[3] = 1
solution[4] = 1
solution[5] = 1
solution[6] = 1
solution[7] = 1

In [76]:
solution = [1 for _ in range(len(solution))]

In [77]:
get_constraint_violations(solution, info.con_info)

array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]])

In [None]:
assert 1 == 2