In [1]:
import random

import torch
import gurobipy as gp

In [2]:
from learn.train import build_inst, build_graphs, remove_redundant_nodes, get_train_mask, get_solution_mask, get_mask_node_feature
from learn.model import Model, FocalLoss
from learn.generator import maximum_independent_set_problem

In [None]:
%%capture

inst = build_inst(maximum_independent_set_problem, 1024)
graphs = build_graphs(inst)
for g in graphs:
    remove_redundant_nodes(g)

In [None]:
# Model dimensions derived from the generated graphs.
# NOTE(review): presumably the number of hint columns get_mask_node_feature
# appends to the raw node features — confirm against learn.train.
mask_feat_size = 2
n_node_feats = graphs[0].ndata['feat'].shape[1] + mask_feat_size
n_edge_feats = graphs[0].edata['feat'].shape[1]

# Every graph is fed through the same model, so feature widths must agree.
assert all(g.ndata['feat'].shape[1] == n_node_feats - mask_feat_size for g in graphs), \
    "graphs disagree on node-feature width"

# Take the largest label over ALL graphs rather than graphs[0] alone: a label
# value absent from the first graph would otherwise yield too few output
# classes, and the result would depend on the order of `graphs`.
num_classes = max(int(g.ndata['label'].max()) for g in graphs) + 1
hidden_size = 64

In [None]:
# Instantiate the network from the dimensions above and attach an Adam optimizer.
model = Model(n_node_feats, n_edge_feats, hidden_size, num_classes)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-4,            # step size
    weight_decay=1e-5,  # L2 penalty that torch's Adam adds to each gradient
)

In [None]:
# Mini-batch training loop.
# For each graph, get_train_mask / get_solution_mask select a subset of nodes
# whose labels get_mask_node_feature presumably injects into the node features
# as hints. Only the un-hinted variable nodes contribute to the focal loss.
# Per-graph losses are summed over BATCH_SIZE graphs, and one clipped Adam step
# is taken per batch.
BATCH_SIZE = 256
MAX_GRAD_NORM = 1
num_epochs = 500
# Built once and reused for every graph (assumes FocalLoss() keeps no per-call state).
criterion = FocalLoss()


def _apply_batch(batch_loss):
    """Log batch_loss, backpropagate it and take one gradient-clipped optimizer step."""
    print("loss", batch_loss.item())
    print("#" * 78)
    # Clear the previous step's gradients; otherwise every backward() would add
    # onto the gradients of all earlier batches in the epoch.
    optimizer.zero_grad()
    batch_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()


print("Total number of graphs", len(graphs))
for epoch in range(num_epochs):
    model.train()

    loss = 0
    n_in_batch = 0

    # Visit the graphs in a shuffled copy so the module-level `graphs` list,
    # which earlier cells index, keeps its order.
    for g in random.sample(graphs, k=len(graphs)):
        train_mask = get_train_mask(g, ratio=0.8)
        solution_mask = get_solution_mask(train_mask, (0.2, 1.0))
        node_feat_with_hint, hint_mask = get_mask_node_feature(
            g.ndata['feat'], g.ndata['label'], solution_mask
        )

        logits = model(g, node_feat_with_hint, g.edata['feat'])
        labels = g.ndata['label']

        # Summing feature column 2 counts the variable nodes. They are assumed
        # to occupy the first n_vars rows (TODO confirm against build_graphs).
        n_vars = int(g.ndata['feat'][:, 2].sum())
        unhinted = ~hint_mask[:n_vars]
        # Out-of-place add: an in-place `+=` on a tensor in the autograd graph
        # can break backward() depending on how FocalLoss builds its output.
        loss = loss + criterion(logits[:n_vars][unhinted], labels[:n_vars][unhinted])
        n_in_batch += 1

        if n_in_batch == BATCH_SIZE:
            _apply_batch(loss)
            loss = 0
            n_in_batch = 0

    # Train on a trailing partial batch instead of silently discarding its loss.
    if n_in_batch:
        _apply_batch(loss)

    # Epoch-end sanity check on the last graph visited: un-hinted logits vs labels.
    print("-" * 78)
    print(logits[:n_vars][unhinted].detach().cpu().numpy())
    print(labels[:n_vars][unhinted].detach().cpu().numpy())
    print('^' * 78)
    print('^'*78)

In [None]:
# Deliberate stop: halts "Run All" at this cell (presumably so that any cells
# below run only by hand). An explicit raise is used because `assert`
# statements are removed when Python runs with -O, which would silently
# disable this guard.
raise RuntimeError("Intentional stop: execution halted at this cell.")