In [74]:
import configargparse
from pathlib import Path
from torch.utils.data import DataLoader
from models.models import KKTNetMLP, GNNPolicy
from data.datasets import LPDataset, GraphDataset, pad_collate_graphs, make_pad_collate
import torch

In [79]:
parser = configargparse.ArgumentParser(
        allow_abbrev=False,
        description="evaluate a trained model and plot predictions",
        default_config_files=["config.yml"],
    )

t = parser.add_argument_group("testing")
t.add_argument(
    "--ckpt", required=True, help="Path to best.pt or dir containing it."
)
t.add_argument("--split", choices=["test", "val"], default="test")
t.add_argument("--out_dir", type=str, default="outputs/eval-embeddings")
t.add_argument(
    "--device",
    type=str,
    default=None,
    help="CUDA device like '0' or 'cpu' (default: from ckpt args/devices).",
)

t.add_argument("--file_path", type=str, required=True, help="Path to single .bg file")

tr = parser.add_argument_group("training")
tr.add_argument("--lr", type=float, default=1e-3)

GNNPolicy.add_args(parser)

args, _ = parser.parse_known_args()

In [80]:
args

Namespace(ckpt='exps/kkt_20251021_121203/best.pt', split='test', out_dir='outputs', device='0', file_path='data/instances/CA/BG/test/5/CA-5-0012.lp.bg', lr=0.001, embedding_size=128, cons_nfeats=4, edge_nfeats=1, var_nfeats=18, num_emb_type='periodic', num_emb_bins=32, num_emb_freqs=16)

In [59]:

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

In [72]:
checkpoint = torch.load(str(args.ckpt), map_location="cpu")

In [73]:
use_bipartite_graphs = True
problems = ["IS", "CA"]
files = [args.file_path]

dataset = GraphDataset(files) if use_bipartite_graphs else LPDataset(files)

In [78]:
if not use_bipartite_graphs:
    m,n = dataset.shapes

In [82]:
checkpoint

{'epoch': 198,
 'model': OrderedDict([('cons_num_emb.0.periodic.weight',
               tensor([[-0.0222, -0.0598, -0.0253, -0.0041,  0.0545, -0.0080,  0.0547,  0.0283,
                         0.0245,  0.0141,  0.0074,  0.0733,  0.0439,  0.0037, -0.0645, -0.0206],
                       [ 0.1282,  0.1255,  0.0463,  0.0751, -0.0858,  0.1062, -0.0876, -0.1223,
                        -0.1405, -0.1793, -0.0514, -0.1046, -0.0716, -0.1046,  0.0412,  0.1549],
                       [-0.0088, -0.0003, -0.0211,  0.0029,  0.0331, -0.0013,  0.0412,  0.0133,
                         0.0281,  0.0263,  0.0364,  0.0340,  0.0506,  0.0055, -0.0221, -0.0009],
                       [-0.0191, -0.0595, -0.0625, -0.0222,  0.0268, -0.0077,  0.0410,  0.0338,
                         0.0243,  0.0124,  0.0621,  0.0354,  0.0408,  0.0322, -0.0281, -0.0143]])),
              ('cons_num_emb.0.linear.weight',
               tensor([[ 5.4167e-02, -3.6830e-02,  1.4662e-01, -1.0478e-01, -1.0543e-01,
                

In [84]:

model = GNNPolicy(args) if use_bipartite_graphs else KKTNetMLP(m,n)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

model.eval()

GNNPolicy(
  (cons_num_emb): Sequential(
    (0): PeriodicEmbeddings(
      (periodic): _Periodic()
      (linear): Linear(in_features=32, out_features=24, bias=True)
      (activation): ReLU()
    )
    (1): Flatten(start_dim=1, end_dim=-1)
  )
  (cons_proj): Sequential(
    (0): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=96, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): ReLU()
  )
  (var_num_emb): Sequential(
    (0): PeriodicEmbeddings(
      (periodic): _Periodic()
      (linear): Linear(in_features=32, out_features=24, bias=True)
      (activation): ReLU()
    )
    (1): Flatten(start_dim=1, end_dim=-1)
  )
  (var_proj): Sequential(
    (0): LayerNorm((432,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=432, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): ReLU()
  )
  (edge_num_emb): Seque

In [85]:

loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=pad_collate_graphs
            if use_bipartite_graphs
            else make_pad_collate(M_fixed=m, N_fixed=n),
)

In [93]:
for batch in loader:
    input = batch[0]
    if use_bipartite_graphs:
        x_hat, lam_hat = model(
        input.constraint_features,
        input.edge_index,
        input.edge_attr,
        input.variable_features,
    )
    else:
        x_hat, lam_hat = model(input)
    

In [92]:
print(x_hat,lam_hat)

tensor([-0.0995,  0.0824, -0.1006,  0.2154,  0.1396,  0.1784,  0.1293,  0.0137,
         0.0348, -0.1184,  0.0560,  0.0852, -0.0091,  0.1467,  0.0621,  0.0562,
         0.0148, -0.0049, -0.1619,  0.1202,  0.1715,  0.1500,  0.1643,  0.1658,
         0.0927], grad_fn=<SqueezeBackward1>) tensor([50.2684, 56.7186, 38.1299, 43.3105, 85.8977, 35.6341, 73.6087],
       grad_fn=<SoftplusBackward0>)
