In [1]:
import torch
import gurobipy as gp

In [2]:
from learn.train import build_inst, build_graphs, remove_redundant_nodes, get_train_mask, get_solution_mask, get_mask_node_feature
from learn.model import Model, FocalLoss
from learn.generator import maximum_independent_set_problem

In [13]:
%%capture
inst = build_inst(maximum_independent_set_problem, 256)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

In [14]:
mask_feat_size = 2
n_node_feats = graphs[0].ndata['feat'].shape[1] + mask_feat_size
n_edge_feats = graphs[0].edata['feat'].shape[1]
num_classes = int(graphs[0].ndata['label'].max()) + 1
hidden_size = 64

In [15]:
model = Model(n_node_feats, n_edge_feats, hidden_size, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
import random

print("Total number of graphs", len(graphs))
num_epochs = 500
for epoch in range(num_epochs):
    
    model.train()
    optimizer.zero_grad()
    
    cntr = 0
    loss = 0

    random.shuffle(graphs)
    for i, g in enumerate(graphs):

        train_mask = get_train_mask(g, ratio=0.5)
        solution_mask = get_solution_mask(train_mask, (0.2, 1.0))
        node_feat_with_hint, hint_mask = get_mask_node_feature(g.ndata['feat'], g.ndata['label'], solution_mask)
        
        logits = model(g, node_feat_with_hint, g.edata['feat'])
        labels = g.ndata['label']
        
        loss += FocalLoss()(logits[~hint_mask], labels[~hint_mask])
        cntr += 1

        if cntr == 256:
            print("loss", loss.detach().numpy())
            print("#"*78)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            loss = 0
            cntr = 0

    
    print("-"*78)
    n_vars = g.ndata['feat'][:, 2].sum().int()
    print(
        torch.hstack(
            [logits, labels.reshape((labels.shape[0], -1))]
        )[:n_vars, :].detach().numpy()
    )

    print("-"*78)
    print(logits[:n_vars][~hint_mask[:n_vars]].detach().numpy()[:10])
    print(labels[:n_vars][~hint_mask[:n_vars]].detach().numpy()[:10])
    print('^'*78)

Total number of graphs 965
loss 130.79868
##############################################################################
loss 121.78383
##############################################################################
loss 111.15831
##############################################################################
------------------------------------------------------------------------------
[[0.2569305  0.74306947 0.        ]
 [0.30245173 0.6975483  0.        ]
 [0.25150758 0.7484925  0.        ]
 [0.2649356  0.7350644  0.        ]
 [0.25885    0.74114996 0.        ]
 [0.22214949 0.77785057 0.        ]
 [0.27809158 0.72190845 0.        ]
 [0.27028927 0.72971076 0.        ]
 [0.2971126  0.70288736 0.        ]
 [0.24381727 0.75618273 0.        ]
 [0.2242402  0.7757598  0.        ]
 [0.23180696 0.768193   0.        ]
 [0.22373532 0.77626467 0.        ]
 [0.26100263 0.73899734 0.        ]
 [0.23463428 0.7653657  0.        ]
 [0.26202148 0.7379785  0.        ]
 [0.4171406  0.5828594  1.        ]


In [None]:
assert 1 == 2