In [1]:
import torch
import gurobipy as gp

In [2]:
from learn.train import build_inst, build_graphs, remove_redundant_nodes, get_train_mask, get_solution_mask, get_mask_node_feature
from learn.model import Model, FocalLoss
from learn.generator import maximum_independent_set_problem

In [3]:
%%capture

inst = build_inst(maximum_independent_set_problem, 1024)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

Set parameter CloudAccessID
Set parameter CloudSecretKey
Set parameter CloudPool to value "831775-C3Dev"
Set parameter CSAppName to value "Josh"
Compute Server job ID: 708c38f8-7444-469c-a640-090f2794e28b
Capacity available on '831775-C3Dev' cloud pool - connecting...
Established HTTPS encrypted connection
Restricted license - for non-production use only - expires 2026-11-23
Gurobi Optimizer version 12.0.0 build v12.0.0rc1 (mac64[x86] - Darwin 22.4.0 22E252)
Gurobi Compute Server Worker version 12.0.0 build v12.0.0rc1 (linux64 - "Ubuntu 20.04.6 LTS")

CPU model: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz, instruction set [SSE2|AVX|AVX2|AVX512]
Thread count: 8 physical cores, 16 logical processors, using up to 16 threads

Non-default parameters:
CSIdleTimeout  1800

Optimize a model with 1796 rows, 109 columns and 3592 nonzeros
Model fingerprint: 0xce9b33f3
Variable types: 0 continuous, 109 integer (109 binary)
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective ran

In [4]:
mask_feat_size = 2
n_node_feats = graphs[0].ndata['feat'].shape[1] + mask_feat_size
n_edge_feats = graphs[0].edata['feat'].shape[1]
num_classes = int(graphs[0].ndata['label'].max()) + 1
hidden_size = 64

In [5]:
model = Model(n_node_feats, n_edge_feats, hidden_size, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
import random

print("Total number of graphs", len(graphs))
num_epochs = 500
for epoch in range(num_epochs):
    
    model.train()
    optimizer.zero_grad()
    
    cntr = 0
    loss = 0

    random.shuffle(graphs)
    for i, g in enumerate(graphs):

        train_mask = get_train_mask(g, ratio=0.8)
        solution_mask = get_solution_mask(train_mask, (0.2, 1.0))
        node_feat_with_hint, hint_mask = get_mask_node_feature(g.ndata['feat'], g.ndata['label'], solution_mask)
        
        logits = model(g, node_feat_with_hint, g.edata['feat'])
        labels = g.ndata['label']

        n_vars = g.ndata['feat'][:, 2].sum().int()
        loss += FocalLoss()(
            logits[:n_vars][~hint_mask[:n_vars]], 
            labels[:n_vars][~hint_mask[:n_vars]]
        )
        cntr += 1

        if cntr == 256:
            print("loss", loss.detach().numpy())
            print("#"*78)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            loss = 0
            cntr = 0

    print("-"*78)
    print(logits[:n_vars][~hint_mask[:n_vars]].detach().numpy())
    print(labels[:n_vars][~hint_mask[:n_vars]].detach().numpy())
    print('^'*78)

Total number of graphs 6


  node_x = F.softmax(node_x)


------------------------------------------------------------------------------
[[0.864218   0.13578196]
 [0.8800926  0.11990735]
 [0.86211854 0.13788152]
 [0.8580588  0.14194119]
 [0.856945   0.14305502]
 [0.866818   0.13318194]
 [0.8845978  0.11540218]
 [0.8647955  0.13520454]
 [0.85587674 0.14412323]
 [0.8645618  0.1354382 ]
 [0.87043417 0.12956586]
 [0.8537408  0.14625916]
 [0.85267353 0.14732647]
 [0.8637091  0.13629092]
 [0.86952955 0.13047048]
 [0.87184685 0.1281531 ]
 [0.88765955 0.11234042]
 [0.8506627  0.14933725]
 [0.87115777 0.12884228]
 [0.85703665 0.14296329]
 [0.86525416 0.13474578]
 [0.88535523 0.11464477]
 [0.8661467  0.13385336]
 [0.8601077  0.13989231]
 [0.8628315  0.13716854]
 [0.852235   0.14776497]
 [0.86911154 0.13088846]
 [0.86823976 0.13176027]
 [0.856351   0.14364898]
 [0.8778825  0.12211746]
 [0.8735538  0.12644623]
 [0.8851774  0.11482269]
 [0.85096425 0.14903574]
 [0.86115307 0.13884693]
 [0.85299265 0.14700738]
 [0.86090213 0.13909781]
 [0.8707111  0.129288

In [None]:
assert 1 == 2