In [1]:
import configargparse
from pathlib import Path
from torch.utils.data import DataLoader
from models.models import KKTNetMLP, GNNPolicy
from data.datasets import LPDataset, GraphDataset, pad_collate_graphs, make_pad_collate
import torch
import pickle

In [2]:
# Build the command-line / config-file parser for this evaluation notebook.
# Values not given on the command line are looked up in config.yml
# (see default_config_files).
parser = configargparse.ArgumentParser(
    allow_abbrev=False,
    description="evaluate a trained model and plot predictions",
    default_config_files=["config.yml"],
)

# Options that control which checkpoint / instance is evaluated.
testing_group = parser.add_argument_group("testing")
testing_group.add_argument(
    "--ckpt", required=True, help="Path to best.pt or dir containing it."
)
testing_group.add_argument("--split", choices=["test", "val"], default="test")
testing_group.add_argument("--out_dir", type=str, default="outputs/eval-embeddings")
testing_group.add_argument(
    "--device",
    type=str,
    default=None,
    help="CUDA device like '0' or 'cpu' (default: from ckpt args/devices).",
)

testing_group.add_argument(
    "--file_path", type=str, required=True, help="Path to single .bg file"
)

# Training-time options that are still needed to rebuild the model/optimizer.
training_group = parser.add_argument_group("training")
training_group.add_argument("--lr", type=float, default=1e-3)
training_group.add_argument("--use_bipartite_graphs", action="store_true")

# Let the GNN register its own hyper-parameter arguments on the same parser.
GNNPolicy.add_args(parser)

# parse_known_args() returns (namespace, leftover argv) and does not fail on
# unrecognised arguments (e.g. whatever the notebook kernel was launched with).
# The leftovers get a real name rather than `_`, because IPython uses `_` for
# the previous cell output.
args, unknown_args = parser.parse_known_args()

In [3]:
# Echo the parsed namespace so the settings in effect are recorded in the output.
args

Namespace(ckpt='exps/kkt_20251114_055905/best.pt', split='test', out_dir='outputs', device='0', file_path='data/instances/RND/BG/test/10/RND-10-0012.lp.bg', lr=0.001, use_bipartite_graphs=True, embedding_size=128, cons_nfeats=4, edge_nfeats=1, var_nfeats=18, num_emb_type='periodic', num_emb_bins=32, num_emb_freqs=16)

In [4]:

# Make sure the output directory exists; missing parents are created and an
# already-existing directory is not an error.
out_dir = Path(args.out_dir)
out_dir.mkdir(exist_ok=True, parents=True)

In [5]:
# --ckpt is documented as "Path to best.pt or dir containing it.", so accept
# both forms: when a directory is given, look for best.pt inside it.
ckpt_path = Path(args.ckpt)
if ckpt_path.is_dir():
    ckpt_path = ckpt_path / "best.pt"

# Load every tensor onto the CPU regardless of the device it was saved from.
checkpoint = torch.load(str(ckpt_path), map_location="cpu")

In [6]:
# Cache the representation switch; several later cells branch on it.
use_bipartite_graphs = args.use_bipartite_graphs
problems = ["RND"]  # NOTE(review): not referenced by any other visible cell
# A single instance file is evaluated, wrapped in a list for the dataset classes.
files = [args.file_path]

# Graph inputs use GraphDataset; otherwise fall back to LPDataset.
if use_bipartite_graphs:
    dataset = GraphDataset(files)
else:
    dataset = LPDataset(files)

In [7]:
# Only the non-graph path needs explicit sizes: m and n are passed to
# KKTNetMLP and to make_pad_collate further down.
# NOTE(review): presumably (rows, columns) of the first instance - confirm
# against LPDataset.shapes.
if not args.use_bipartite_graphs:
    m, n = dataset.shapes[0]

In [8]:
# Inspect the loaded checkpoint object.
# NOTE(review): this echoes the whole object, including every weight tensor
# stored under 'model'; consider displaying only checkpoint.keys() instead.
checkpoint

{'epoch': 96,
 'model': OrderedDict([('cons_num_emb.0.periodic.weight',
               tensor([[ 3.6042e-02,  3.6437e-02, -4.0318e-02, -3.3122e-02,  1.5622e-02,
                        -3.8417e-02, -3.5618e-03, -3.1772e-02,  1.6489e-02, -2.4132e-02,
                        -1.0678e-02,  5.3796e-03, -7.9424e-04,  6.7209e-02,  7.5781e-02,
                        -2.5948e-02],
                       [ 4.5148e-02,  3.2090e-02, -3.2970e-02,  4.8956e-05,  2.5128e-02,
                        -1.3864e-02,  3.0738e-02, -1.6493e-02,  1.5307e-02, -1.1089e-02,
                         5.9743e-03,  1.8313e-02,  2.6063e-02, -6.9464e-03,  4.0403e-02,
                        -8.8276e-03],
                       [ 7.4965e-02,  9.5240e-02, -1.0169e-01, -7.8384e-02,  1.2400e-01,
                        -7.6879e-02,  8.5957e-02, -2.6173e-02,  5.6656e-02, -3.9505e-03,
                        -4.1745e-02,  3.5903e-02,  4.0891e-02,  1.1397e-01,  1.6080e-01,
                        -5.8769e-02],
             

In [9]:

# Rebuild the same architecture that was trained, chosen by input representation.
if use_bipartite_graphs:
    model = GNNPolicy(args)
else:
    model = KKTNetMLP(m, n)

# An Adam optimizer is recreated so that its saved state can be restored too.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Restore the weights and optimizer state stored in the checkpoint.
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

# Switch to evaluation mode; as the last expression this also displays the
# module structure.
model.eval()

GNNPolicy(
  (cons_num_emb): Sequential(
    (0): PeriodicEmbeddings(
      (periodic): _Periodic()
      (linear): Linear(in_features=32, out_features=24, bias=True)
      (activation): ReLU()
    )
    (1): Flatten(start_dim=1, end_dim=-1)
  )
  (cons_proj): Sequential(
    (0): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=96, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): ReLU()
  )
  (var_num_emb): Sequential(
    (0): PeriodicEmbeddings(
      (periodic): _Periodic()
      (linear): Linear(in_features=32, out_features=24, bias=True)
      (activation): ReLU()
    )
    (1): Flatten(start_dim=1, end_dim=-1)
  )
  (var_proj): Sequential(
    (0): LayerNorm((432,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=432, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): ReLU()
  )
  (edge_num_emb): Seque

In [10]:

# Pick the collate function matching the dataset type. The padded collate is
# only built on the non-graph path, where m and n are defined.
if use_bipartite_graphs:
    collate_fn = pad_collate_graphs
else:
    collate_fn = make_pad_collate(M_fixed=m, N_fixed=n)

# One instance per batch, in dataset order.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

In [11]:
# Run the model over every batch. Gradients are not needed for evaluation, so
# the forward passes run under torch.no_grad().
# x_hat / lam_hat are overwritten each iteration, so after the loop they hold
# the predictions for the last batch.
with torch.no_grad():
    for batch in loader:
        # The model input is the first element of the collated batch.
        # Named `sample` so the builtin input() is not shadowed.
        sample = batch[0]
        if use_bipartite_graphs:
            # The graph model returns the two prediction tensors directly.
            x_hat, lam_hat = model(
                sample.constraint_features,
                sample.edge_index,
                sample.edge_attr,
                sample.variable_features,
            )
        else:
            # The MLP returns one tensor per sample: the first n entries are
            # taken as x_hat and the remainder as lam_hat.
            pred = model(sample)
            x_hat = pred[0][:n]
            lam_hat = pred[0][n:]
    

In [12]:
# Derive the solution file location from the instance path by string
# substitution; the two input representations use different naming schemes.
if args.use_bipartite_graphs:
    solution_path = args.file_path.replace("BG", "solution")
    solution_path = solution_path.replace(".bg", ".sol")
else:
    solution_path = (
        args.file_path.replace("instance", "solution") + ".sol"
    ).replace("solutions", "instances")

# NOTE: pickle.load can execute arbitrary code embedded in the file, so this
# must only be pointed at solution files from a trusted source.
with open(solution_path, "rb") as solution_file:
    solution_data = pickle.load(solution_file)

In [13]:
# Pull the entry stored under the "sols" key out of the unpickled payload.
solutions = solution_data["sols"]

In [14]:
# Check how many entries "sols" contains; only solutions[0] is used below.
len(solutions)

1

In [15]:
# Print a side-by-side table of the predicted values (x_hat) and the first
# stored solution, together with their absolute difference.
print(f"{'x_hat':>10} | {'optimal':>10} | {'Diff':>10}")
print("-" * 35)
# zip stops at the shorter of the two sequences.
for predicted, optimal in zip(x_hat, solutions[0]):
    # Convert the tensor element to a Python float once and reuse it.
    predicted_value = predicted.item()
    print(
        f"{predicted_value:10.4f} | {optimal:10.4f} | "
        f"{abs(predicted_value - optimal):10.4f}"
    )

     x_hat |    optimal |       Diff
-----------------------------------
   -0.1128 |     0.1840 |     0.2968
   -0.1551 |    -0.7846 |     0.6295
    0.0709 |    -0.3813 |     0.4522
   -0.2497 |    -1.2589 |     1.0092
    0.4674 |    -0.3173 |     0.7847
    0.2759 |     0.8757 |     0.5997
   -0.2591 |    -0.4133 |     0.1543
   -0.0087 |     0.7101 |     0.7188
   -0.1343 |    -0.2696 |     0.1353
   -0.2191 |    -0.3734 |     0.1544
