## Create Labels

### Imports

In [1]:
import argparse
import json
import numpy as np
import torch
from pathlib import Path
from pcst_fast import pcst_fast as pcst
from tqdm import tqdm

### Métodos

In [2]:
def call_pcst_ensemble(edges, prizes, costs, root):
    """Call ``pcst`` with several argument combinations and keep the cheapest result.

    Each trial passes ``(edges, prizes, costs, root)`` followed by zero to three
    extra positional arguments (``1``, then ``"gw"`` or ``"strong"``, then ``0``).
    NOTE(review): these presumably map to pcst_fast's num_clusters / pruning /
    verbosity parameters -- confirm against the installed pcst_fast version.

    Args:
        edges: Edge list forwarded unchanged to ``pcst``.
        prizes: Node prizes forwarded unchanged to ``pcst``.
        costs: Edge costs; must support fancy indexing (``costs[e]``) with the
            edge indices returned by ``pcst``.
        root: Root node forwarded unchanged to ``pcst``.

    Returns:
        The ``(nodes, edge_indices)`` pair of the trial whose selected edges have
        the smallest total cost. Ties keep the earliest trial.

    Raises:
        RuntimeError: If no trial produced a usable result. The last underlying
            exception, if any, is attached as ``__cause__``.
    """
    trials = [
        (edges, prizes, costs, root),
        (edges, prizes, costs, root, 1),
        (edges, prizes, costs, root, 1, "gw"),
        (edges, prizes, costs, root, 1, "strong"),
        (edges, prizes, costs, root, 1, "gw", 0),
        (edges, prizes, costs, root, 1, "strong", 0),
    ]
    best_res, best_cost = None, float('inf')
    last_error = None
    for args in trials:
        try:
            n, e = pcst(*args)
            c = costs[e].sum() if len(e) > 0 else 0.0
        except Exception as err:
            # A failing trial is expected (not every argument combination is
            # accepted); remember why, but let KeyboardInterrupt/SystemExit
            # propagate so a long run can still be stopped.
            last_error = err
            continue
        if c < best_cost:
            best_cost, best_res = c, (n, e)

    if best_res is None:
        raise RuntimeError("PCST falhou") from last_error
    return best_res

In [3]:
def build_undirected_edges(g):
    """Collapse a directed, typed edge list into unique undirected edges with costs.

    Args:
        g: Mapping with ``"edge_index"`` (rows 0 and 1 hold the two endpoints of
            each edge, one column per edge) and ``"edge_type"`` (one type value
            per edge).

    Returns:
        ``(edges, costs)`` where ``edges`` is an ``int64`` array of shape
        ``(E, 2)`` with each pair stored as ``(min, max)``, and ``costs`` is a
        ``float64`` array of length ``E``. A pair that appears with type 0 in
        any direction costs 1.0; every other pair costs 1.2. An empty graph
        yields shapes ``(0, 2)`` and ``(0,)``.
    """
    ei, et = g["edge_index"], g["edge_type"]
    # Convert to plain Python lists once; indexing the tensor element by element
    # inside the loop is far slower and yields the same values.
    src, dst = ei[0].tolist(), ei[1].tolist()
    types = et.reshape(-1).tolist()

    pair_type = {}
    for i, (u, v) in enumerate(zip(src, dst)):
        t = int(types[i])
        k = tuple(sorted((int(u), int(v))))

        # First sighting records the type; a later type-0 sighting overrides a
        # non-zero type so that type 0 wins for the undirected pair.
        if k not in pair_type or (t == 0 and pair_type[k] != 0):
            pair_type[k] = t

    # reshape keeps the (E, 2) layout even when there are no edges at all.
    edges = np.array(list(pair_type.keys()), dtype=np.int64).reshape(-1, 2)
    costs = np.array(
        [1.0 if t == 0 else 1.2 for t in pair_type.values()], dtype=np.float64
    )
    return edges, costs

### Main

Configuração dos caminhos, com criação automática das subpastas de saída.

In [4]:
# Root folders, relative to the notebook's working directory.
input_dir = Path("../../inputs")
output_dir = Path("../../outputs")

# One subfolder per artefact kind under the output root.
snapshot_dir = output_dir / "snapshots"
instance_dir = output_dir / "instances"
label_dir = output_dir / "labels"

# Create any missing subfolder (and its parents); existing ones are left alone.
for stage_dir in (snapshot_dir, instance_dir, label_dir):
    stage_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Fail fast when the instances file is missing from the instances folder.
instances_file = instance_dir / "instances.jsonl"
if not instances_file.exists():
    raise FileNotFoundError("instances.jsonl")

In [6]:
# Count the lines up front (used as the progress-bar total). The `with` block
# guarantees the file handle is closed instead of being left to the garbage collector.
with open(instance_dir / "instances.jsonl", "r") as count_f:
    total_lines = sum(1 for _ in count_f)

In [7]:
# Per-snapshot caches: loaded graph objects and their derived (edges, costs) pairs.
graph_cache = {}
edge_cache = {}
# Running count of instances written to the labels file.
written = 0

In [8]:
# Both files are managed by the same `with`, so the input handle is closed too.
# The input is opened first so a failure there does not truncate labels.jsonl.
with open(instance_dir / "instances.jsonl", "r") as in_f, \
        open(label_dir / "labels.jsonl", "w", encoding="utf-8") as out_f:
    for line in tqdm(in_f, total=total_lines, desc="labels: escrevendo labels.jsonl", unit="instância"):
        # Blank lines carry no instance; skip them.
        if not line.strip():
            continue

        inst = json.loads(line)
        # Prefer the "snapshot_next" key; fall back to "snapshot" when it is absent or empty.
        snap = inst.get("snapshot_next") or inst.get("snapshot")
        if snap not in graph_cache:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load snapshots produced by a trusted pipeline.
            graph_cache[snap] = torch.load(snapshot_dir / f"as_graph_{snap}.pt", map_location="cpu")

        g = graph_cache[snap]
        if snap not in edge_cache:
            # The undirected edge list depends only on the snapshot, so build it once.
            edge_cache[snap] = build_undirected_edges(g)

        edges, costs = edge_cache[snap]

        # Cast the root the same way the terminals are cast, so a root stored as
        # a float or string in the JSON still works as an array index.
        root = int(inst["root"])
        # Root and every terminal get the same prize (10.0); all other nodes get 0.
        prizes = np.zeros(g["num_nodes"])
        prizes[root] = 10.0
        for t in inst["terminals_out"]:
            prizes[int(t)] = 10.0
        sel_nodes, sel_e_idx = call_pcst_ensemble(edges, prizes, costs, root)

        # Attach the label: selected node ids and the endpoint pairs of the selected edges.
        inst.update({
            "tree_nodes": list(map(int, sel_nodes)),
            "tree_edges": edges[sel_e_idx].tolist()
        })

        out_f.write(json.dumps(inst, ensure_ascii=False) + "\n")
        written += 1

print(f"  -> labels gerados: {written} instâncias\n")

labels: escrevendo labels.jsonl: 100%|██████████| 595/595 [12:03<00:00,  1.22s/instância]

  -> labels gerados: 595 instâncias




