In [1]:
import configargparse
from pathlib import Path
from torch.utils.data import DataLoader
from models.models import KKTNetMLP, GNNPolicy
from data.datasets import LPDataset, GraphDataset, pad_collate_graphs, make_pad_collate
import torch
import pickle

In [2]:
parser = configargparse.ArgumentParser(
        allow_abbrev=False,
        description="evaluate a trained model and plot predictions",
        default_config_files=["config.yml"],
    )

t = parser.add_argument_group("testing")
t.add_argument(
    "--ckpt", required=True, help="Path to best.pt or dir containing it."
)
t.add_argument("--split", choices=["test", "val"], default="test")
t.add_argument("--out_dir", type=str, default="outputs/eval-embeddings")
t.add_argument(
    "--device",
    type=str,
    default=None,
    help="CUDA device like '0' or 'cpu' (default: from ckpt args/devices).",
)

t.add_argument("--file_path", type=str, required=True, help="Path to single .bg file")

tr = parser.add_argument_group("training")
tr.add_argument("--lr", type=float, default=1e-3)
tr.add_argument("--use_bipartite_graphs", action="store_true")

GNNPolicy.add_args(parser)

args, _ = parser.parse_known_args()

In [3]:
args

Namespace(ckpt='exps/kkt_20251021_090044/best.pt', split='test', out_dir='outputs', device='0', file_path='data/instances/RND/instance/test/2/RND-2-0012.lp', lr=0.001, use_bipartite_graphs=False, embedding_size=128, cons_nfeats=4, edge_nfeats=1, var_nfeats=18, num_emb_type='periodic', num_emb_bins=32, num_emb_freqs=16)

In [4]:

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

In [5]:
checkpoint = torch.load(str(args.ckpt), map_location="cpu")

In [6]:
use_bipartite_graphs = args.use_bipartite_graphs
problems = ["RND"]
files = [args.file_path]

dataset = GraphDataset(files) if use_bipartite_graphs else LPDataset(files)

original problem has 2 variables (0 bin, 0 int, 0 impl, 2 cont) and 2 constraints


In [11]:
if not args.use_bipartite_graphs:
    m,n = dataset.shapes[0]

In [12]:
checkpoint

{'epoch': 189,
 'model': OrderedDict([('net.0.weight',
               tensor([[ 0.2921,  0.2666, -0.1132,  ...,  0.0313, -0.1886,  0.2727],
                       [ 0.3560, -0.3060,  0.3288,  ...,  0.1476,  0.1360,  0.0152],
                       [ 0.2938, -0.0162, -0.1531,  ..., -0.0897, -0.1494,  0.2415],
                       ...,
                       [ 0.0856, -0.0019, -0.3015,  ...,  0.0244,  0.3176,  0.2198],
                       [-0.1736, -0.1370, -0.1221,  ..., -0.3502,  0.1231,  0.2690],
                       [-0.0136, -0.2379, -0.2337,  ...,  0.0410,  0.0512,  0.4480]])),
              ('net.0.bias',
               tensor([ 0.1389,  0.0501, -0.0379,  0.1714,  0.1032, -0.3229, -0.2526,  0.0268,
                        0.1345,  0.1983, -0.2454, -0.3713, -0.0146, -0.0427,  0.0928, -0.0282,
                        0.2308,  0.0806,  0.1941, -0.2267,  0.0691,  0.0897, -0.2069,  0.1272,
                       -0.3195,  0.2527,  0.0558,  0.0275,  0.1401,  0.0078,  0.1144,  0.1

In [13]:

model = GNNPolicy(args) if use_bipartite_graphs else KKTNetMLP(m,n)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

model.eval()

KKTNetMLP(
  (net): Sequential(
    (0): Linear(in_features=8, out_features=256, bias=True)
    (1): SELU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): SELU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): SELU()
  )
  (head_x): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): SELU()
    (2): Linear(in_features=64, out_features=2, bias=True)
  )
  (head_lam): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): SELU()
    (2): Linear(in_features=64, out_features=2, bias=True)
  )
  (relu): ReLU()
)

In [14]:

loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=pad_collate_graphs
            if use_bipartite_graphs
            else make_pad_collate(M_fixed=m, N_fixed=n),
)

In [27]:
for batch in loader:
    input = batch[0]
    if use_bipartite_graphs:
        x_hat, lam_hat = model(
        input.constraint_features,
        input.edge_index,
        input.edge_attr,
        input.variable_features,
    )
    else:
        
        pred = model(input)
        if args.use_bipartite_graphs:
            x_hat, lam_hat = pred
        else:
            x_hat = pred[0][:n]
            lam_hat = pred[0][n:]
    

original problem has 2 variables (0 bin, 0 int, 0 impl, 2 cont) and 2 constraints


In [31]:
if args.use_bipartite_graphs:
    solution_path = args.file_path.replace("BG","solution").replace(".bg",".sol")
else:
    solution_path = args.file_path.replace("instance","solution") + ".sol"
    solution_path = solution_path.replace("solutions", "instances")

with open(solution_path, "rb") as file:
    solution_data = pickle.load(file)

In [32]:
solutions = solution_data["sols"]

In [33]:
len(solutions)

1

In [34]:
print(f"{'x_hat':>10} | {'optimal':>10} | {'Diff':>10}")
print("-" * 35)
for t, f in zip(x_hat, solutions[0]):
    print(f"{t.item():10.4f} | {f:10.4f} | {abs(t.item() - f):10.4f}")

     x_hat |    optimal |       Diff
-----------------------------------
   -0.2738 |    -0.5322 |     0.2584
    0.6501 |     0.9742 |     0.3241
