In [7]:
import configargparse
from pathlib import Path
from torch.utils.data import DataLoader
from models.models import KKTNetMLP, GNNPolicy
from data.datasets import LPDataset, GraphDataset, pad_collate_graphs, make_pad_collate
import torch
import pickle

In [8]:
parser = configargparse.ArgumentParser(
        allow_abbrev=False,
        description="evaluate a trained model and plot predictions",
        default_config_files=["config.yml"],
    )

t = parser.add_argument_group("testing")
t.add_argument(
    "--ckpt", required=True, help="Path to best.pt or dir containing it."
)
t.add_argument("--split", choices=["test", "val"], default="test")
t.add_argument("--out_dir", type=str, default="outputs/eval-embeddings")
t.add_argument(
    "--device",
    type=str,
    default=None,
    help="CUDA device like '0' or 'cpu' (default: from ckpt args/devices).",
)

t.add_argument("--file_path", type=str, required=True, help="Path to single .bg file")

tr = parser.add_argument_group("training")
tr.add_argument("--lr", type=float, default=1e-3)
tr.add_argument("--use_bipartite_graphs", action="store_true")

GNNPolicy.add_args(parser)

args, _ = parser.parse_known_args()

In [9]:
args

Namespace(ckpt='exps/kkt_20251020_203603/best.pt', split='test', out_dir='outputs', device='0', file_path='data/instances/RND/instance/test/5/RND-5-0012.lp', lr=0.001, use_bipartite_graphs=False, embedding_size=128, cons_nfeats=4, edge_nfeats=1, var_nfeats=18, num_emb_type='periodic', num_emb_bins=32, num_emb_freqs=16)

In [10]:

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

In [11]:
checkpoint = torch.load(str(args.ckpt), map_location="cpu")

In [12]:
use_bipartite_graphs = args.use_bipartite_graphs
problems = ["RND"]
files = [args.file_path]

dataset = GraphDataset(files) if use_bipartite_graphs else LPDataset(files)

original problem has 5 variables (0 bin, 0 int, 0 impl, 5 cont) and 5 constraints


In [13]:
if not args.use_bipartite_graphs:
    m,n = dataset.shapes[0]

In [14]:
checkpoint

{'epoch': 114,
 'model': OrderedDict([('net.0.weight',
               tensor([[ 0.0570,  0.0286, -0.0287,  ...,  0.0605,  0.0833, -0.0796],
                       [ 0.0592,  0.0303,  0.1093,  ..., -0.1462, -0.0694, -0.0838],
                       [ 0.1145,  0.0270,  0.1013,  ..., -0.0422, -0.0912,  0.0349],
                       ...,
                       [-0.1234,  0.0817, -0.0695,  ..., -0.0692,  0.1124, -0.0439],
                       [-0.0192,  0.0860, -0.0743,  ..., -0.1297, -0.0388, -0.1231],
                       [-0.0616,  0.0131,  0.1246,  ...,  0.0391, -0.0499, -0.0398]])),
              ('net.0.bias',
               tensor([-0.0160,  0.1157,  0.0871,  0.1268,  0.0916,  0.0507,  0.1160,  0.0174,
                        0.0930,  0.0680, -0.1314,  0.0893,  0.0474, -0.0238,  0.1014, -0.1008,
                        0.0820, -0.1097, -0.0011, -0.1256,  0.1151,  0.1078,  0.1526,  0.0974,
                       -0.0180,  0.0544,  0.0915, -0.1802, -0.0788,  0.0416, -0.0866,  0.0

In [15]:

model = GNNPolicy(args) if use_bipartite_graphs else KKTNetMLP(m,n)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

model.eval()

KKTNetMLP(
  (net): Sequential(
    (0): Linear(in_features=35, out_features=256, bias=True)
    (1): SELU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): SELU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): SELU()
  )
  (head_x): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): SELU()
    (2): Linear(in_features=64, out_features=5, bias=True)
  )
  (head_lam): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): SELU()
    (2): Linear(in_features=64, out_features=5, bias=True)
  )
  (relu): ReLU()
)

In [16]:

loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=pad_collate_graphs
            if use_bipartite_graphs
            else make_pad_collate(M_fixed=m, N_fixed=n),
)

In [17]:
for batch in loader:
    input = batch[0]
    if use_bipartite_graphs:
        x_hat, lam_hat = model(
        input.constraint_features,
        input.edge_index,
        input.edge_attr,
        input.variable_features,
    )
    else:
        
        pred = model(input)
        if args.use_bipartite_graphs:
            x_hat, lam_hat = pred
        else:
            x_hat = pred[0][:n]
            lam_hat = pred[0][n:]
    

original problem has 5 variables (0 bin, 0 int, 0 impl, 5 cont) and 5 constraints


In [18]:
if args.use_bipartite_graphs:
    solution_path = args.file_path.replace("BG","solution").replace(".bg",".sol")
else:
    solution_path = args.file_path.replace("instance","solution") + ".sol"
    solution_path = solution_path.replace("solutions", "instances")

with open(solution_path, "rb") as file:
    solution_data = pickle.load(file)

In [19]:
solutions = solution_data["sols"]

In [20]:
len(solutions)

1

In [21]:
print(f"{'x_hat':>10} | {'optimal':>10} | {'Diff':>10}")
print("-" * 35)
for t, f in zip(x_hat, solutions[0]):
    print(f"{t.item():10.4f} | {f:10.4f} | {abs(t.item() - f):10.4f}")

     x_hat |    optimal |       Diff
-----------------------------------
   -0.0426 |     0.2965 |     0.3391
   -0.1098 |    -1.2941 |     1.1843
   -0.1512 |    -0.4512 |     0.3001
   -0.1817 |    -1.6734 |     1.4917
   -0.3921 |    -1.5761 |     1.1840
