In [None]:
from pathlib import Path

from torch_geometric.loader import DataLoader
from gfos.data.constants import CONFIG_RUNTIME_MEAN_STD
from gfos.metrics import LayoutMetrics

from gfos.data.utils import load_layout

# Dataset selection: both values are forwarded to `load_layout` below.
source, search = "nlp", "default"

# Root folder holding the layout files on the local machine.
data_root = Path(r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout")

# NOTE(review): despite the "xla" in the name, this holds whatever collection
# `source` selects (currently "nlp").
xla_default = load_layout(data_root, model_type=source, compile_type=search)

# Training hyper-parameters.
num_configs = 256  # configs drawn per graph (passed to the training dataset)
accum_iter = 4  # batches accumulated before each optimizer step
grad_clip = 1.0  # max gradient norm; clipping is skipped when <= 0
epoch_infer = 400  # run validation every `epoch_infer` epochs

In [None]:
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from tqdm import tqdm

from gfos.data.graph import get_config_graph


class LayoutData(Data):
    """`Data` subclass that overrides the batching hooks for layout samples."""

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Return the offset applied to ``key`` when samples are batched.

        Indices into the config-node set are shifted by the number of
        configurable nodes, indices into the full node set by the number of
        nodes, and every other attribute is left unshifted.
        """
        if key == "config_edge_index":
            return self.num_config_nodes
        if key in ("node_config_ids", "edge_index"):
            return self.num_nodes
        return 0

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Return the dimension along which ``key`` is concatenated."""
        # Checked first on purpose: any key containing "index" wins over the
        # substring matches below.
        if key == "node_config_feat" or "index" in key:
            return 1

        last_dim_markers = ("node_opcode", "node_config_ids", "config_runtime")
        if any(marker in key for marker in last_dim_markers):
            return -1

        return 0


@dataclass
class Normalizer:
    """Min-max scaler for node features and per-config node features.

    Each ``*_mask`` is a boolean tensor selecting the feature columns to keep;
    the matching ``*_min`` / ``*_max`` tensors are already restricted to the
    kept columns (see :meth:`from_dict`).

    NOTE(review): a kept column with ``max == min`` produces a division by
    zero; presumably the masks exclude such columns - confirm against the
    code that generates the normalizer file.
    """

    node_feat_mask: torch.Tensor
    node_feat_min: torch.Tensor
    node_feat_max: torch.Tensor
    node_config_feat_mask: torch.Tensor
    node_config_feat_min: torch.Tensor
    node_config_feat_max: torch.Tensor

    def normalize_node_feat(self, node_feat: torch.Tensor) -> torch.Tensor:
        """Select the masked columns of a 2D tensor and scale them to [0, 1]."""
        assert node_feat.ndim == 2, "node_feat must be 2D"
        node_feat = node_feat[:, self.node_feat_mask]

        return (node_feat - self.node_feat_min) / (
            self.node_feat_max - self.node_feat_min
        )

    def normalize_node_config_feat(
        self, node_config_feat: torch.Tensor
    ) -> torch.Tensor:
        """Select the masked last-dim columns of a 3D tensor and scale them."""
        assert node_config_feat.ndim == 3, "node_config_feat must be 3D"
        node_config_feat = node_config_feat[:, :, self.node_config_feat_mask]
        return (node_config_feat - self.node_config_feat_min) / (
            self.node_config_feat_max - self.node_config_feat_min
        )

    @classmethod
    def from_dict(
        cls,
        configs: dict,
        source: Literal["xla", "nlp"],
        search: Literal["default", "random"],
    ) -> "Normalizer":
        """Build a normalizer from ``configs[source][search]``.

        The selected entry must provide ``node_feat_mask``, ``node_feat_min``,
        ``node_feat_max`` and the three ``node_config_feat_*`` counterparts.

        Raises:
            KeyError: if ``source`` / ``search`` is not present in ``configs``.
        """
        try:
            data = configs[source][search]
        except KeyError as err:
            # Chain the original error so the missing key stays visible.
            raise KeyError(
                f"Invalid source or search: source={source}, search={search}"
            ) from err

        node_feat_mask = torch.tensor(data["node_feat_mask"], dtype=torch.bool)
        node_config_feat_mask = torch.tensor(
            data["node_config_feat_mask"], dtype=torch.bool
        )

        def masked(key: str, mask: torch.Tensor) -> torch.Tensor:
            # Statistics are stored for all columns; keep only the used ones.
            return torch.tensor(data[key], dtype=torch.float)[mask]

        # `cls(...)` instead of `Normalizer(...)` so subclasses get their own type.
        return cls(
            node_feat_mask=node_feat_mask,
            node_feat_min=masked("node_feat_min", node_feat_mask),
            node_feat_max=masked("node_feat_max", node_feat_mask),
            node_config_feat_mask=node_config_feat_mask,
            node_config_feat_min=masked(
                "node_config_feat_min", node_config_feat_mask
            ),
            node_config_feat_max=masked(
                "node_config_feat_max", node_config_feat_mask
            ),
        )

    @classmethod
    def from_json(cls, path, source, search) -> "Normalizer":
        """Load a JSON file and delegate to :meth:`from_dict`."""
        import json

        # Context manager: the previous `json.load(open(path))` never closed
        # the file handle.
        with open(path, encoding="utf-8") as fh:
            json_data = json.load(fh)
        return cls.from_dict(json_data, source, search)


class LayoutDataset(Dataset):
    """Load all data in advance.

    Every file in ``files`` is read with ``np.load`` during construction,
    its runtimes are standardised, a subset of configurations is selected,
    and all arrays are converted to tensors. ``__getitem__`` then draws a
    (possibly random) subset of the stored configurations per access.

    Args:
        files: Paths of the ``.npz`` layout files.
        max_configs: Upper bound on configs kept per file at load time;
            ``<= 0`` keeps all of them.
        num_configs: Number of configs returned per ``__getitem__`` call;
            ``<= 0`` falls back to ``max_configs`` (or all configs).
        normalizer: Optional feature normalizer applied once at load time.
        bins: Optional bin edges; when given, runtimes are digitised into a
            ``cls_label`` per config (computed on the raw runtimes).
        three_split_sampling: Use ``sample_configs`` (best / worst / random)
            instead of a uniform random subset when ``max_configs > 0``.
        indices_dir: Directory of ``<model_id>.npy`` index files. When given,
            only files with a matching index file are loaded and the stored
            indices select the configs.
        runtime_mean, runtime_std: Global statistics for standardising the
            runtimes; if either is ``None`` the per-file mean/std is used.
        thres: Node-count threshold controlling how many partitions each
            graph is split into (``num_nodes // thres + 1``).

    Raises:
        FileNotFoundError: if ``indices_dir`` or a required index file is
            missing.
    """

    def __init__(
        self,
        files: list[str],
        max_configs: int = -1,
        num_configs: int = -1,
        normalizer: Normalizer = None,
        bins: np.ndarray = None,
        three_split_sampling: bool = True,
        indices_dir: str = None,
        runtime_mean: float = None,
        runtime_std: float = None,
        thres: int = 5000,
    ):
        self.max_configs = max_configs
        self.num_configs = num_configs
        self.normalizer = normalizer
        self.thres = thres

        if indices_dir is not None:
            if not Path(indices_dir).exists():
                raise FileNotFoundError(
                    f"Fold index dir <{indices_dir}> " "specified but does not exist"
                )
            indices_dir = Path(indices_dir)
            # Keep only the files that have a matching index file.
            target_models = {f.stem for f in indices_dir.glob("*.npy")}
            self.files = [f for f in files if Path(f).stem in target_models]
        else:
            self.files = files

        self.data = []
        pbar = tqdm(self.files, desc="Loading data")
        # Running offset of this graph's first (partition, config) slot.
        parts_cnt = 0

        for file in pbar:
            # Materialise all arrays, then close the archive right away; the
            # previous `dict(np.load(file))` left one open handle per file.
            with np.load(file) as npz:
                record = dict(npz)
            model_id = Path(file).stem
            pbar.set_postfix_str(model_id)

            record["model_id"] = model_id
            runtime = record["config_runtime"]

            if bins is not None:
                # Class labels are derived from the raw (unnormalised) runtimes.
                cls_labels = np.digitize(runtime, bins)

            if runtime_mean is None or runtime_std is None:
                runtime = (runtime - runtime.mean()) / runtime.std()
            else:
                runtime = (runtime - runtime_mean) / runtime_std

            if indices_dir is not None:
                indices_file = Path(indices_dir) / f"{model_id}.npy"
                if indices_file.exists():
                    config_indices = np.load(indices_file)
                    runtime_sampled = runtime[config_indices]
                else:
                    raise FileNotFoundError(f"{indices_file} does not exist")
            else:
                if self.max_configs > 0:
                    # sample `max_configs` with order
                    # [good_configs, bad_configs, random_configs]
                    if three_split_sampling:
                        runtime_sampled, config_indices = sample_configs(
                            runtime, max_configs
                        )
                    else:
                        config_indices = torch.randperm(len(runtime))[:max_configs]
                        runtime_sampled = runtime[config_indices]
                else:
                    # use all configs
                    runtime_sampled = runtime
                    config_indices = torch.arange(len(runtime))

            record["config_runtime"] = runtime_sampled
            record["node_config_feat"] = record["node_config_feat"][config_indices]
            # Positions of the kept configs ordered from fastest to slowest.
            record["argsort_runtime"] = np.argsort(runtime_sampled)

            if bins is not None:
                record["cls_label"] = cls_labels[config_indices]

            # create graph for configurable nodes
            config_edge_index, edge_weight, paths = get_config_graph(
                record["edge_index"],
                record["node_config_ids"],
            )
            record["config_edge_weight"] = torch.tensor(edge_weight, dtype=torch.float)
            record["config_edge_path"] = paths

            config_edge_index = torch.tensor(
                config_edge_index.T,
                dtype=torch.long,
            )
            record["config_edge_index"] = config_edge_index

            record["config_runtime"] = torch.tensor(
                record["config_runtime"], dtype=torch.float
            )
            record["argsort_runtime"] = torch.tensor(
                record["argsort_runtime"], dtype=torch.long
            )
            record["node_feat"] = torch.tensor(record["node_feat"], dtype=torch.float)
            record["node_opcode"] = torch.tensor(
                record["node_opcode"], dtype=torch.long
            )
            # Stored transposed on disk; after `.T` the first dim has size 2
            # (it is indexed with [0] / [1] elsewhere in this file).
            record["edge_index"] = torch.tensor(
                record["edge_index"].T, dtype=torch.long
            )
            record["node_config_feat"] = torch.tensor(
                record["node_config_feat"], dtype=torch.float
            )
            record["node_config_ids"] = torch.tensor(
                record["node_config_ids"], dtype=torch.long
            )

            # GST: split the node range into contiguous partitions described
            # by boundary pointers `partptr` (first entry 0, last num_nodes).
            num_nodes = torch.tensor(record["node_feat"].shape[0])
            num_parts = num_nodes // self.thres + 1
            interval = num_nodes // num_parts
            partptr = torch.arange(0, num_nodes, interval + 1)
            if partptr[-1] != num_nodes:
                partptr = torch.cat([partptr, torch.tensor([num_nodes])])

            record["partptr"] = partptr
            record["num_nodes"] = num_nodes
            record["num_configs"] = torch.tensor(len(record["config_runtime"]))
            record["partition_idx"] = parts_cnt
            # Reserve num_parts * num_configs slots for this graph.
            parts_cnt += (num_parts * record["num_configs"]).item()

            if self.normalizer is not None:
                record["node_feat"] = self.normalizer.normalize_node_feat(
                    record["node_feat"]
                )
                record["node_config_feat"] = self.normalizer.normalize_node_config_feat(
                    record["node_config_feat"]
                )

            if bins is not None:
                record["cls_label"] = torch.tensor(
                    record["cls_label"], dtype=torch.long
                )

            self.data.append(record)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx) -> "LayoutData":
        """Return one graph with a subset of its stored configurations.

        When ``num_configs`` or ``max_configs`` is positive, configs are drawn
        at random without replacement with a bias toward low runtimes;
        otherwise all stored configs are returned in storage order.
        """
        record = self.data[idx]

        config_runtime = record["config_runtime"]
        node_config_feat = record["node_config_feat"]
        node_config_ids = record["node_config_ids"]
        config_edge_index = record["config_edge_index"]

        c = len(config_runtime)

        if self.num_configs > 0:
            num_configs = min(self.num_configs, c)
        elif self.max_configs > 0:
            num_configs = min(self.max_configs, c)
        else:
            num_configs = c

        # Sample
        if self.max_configs > 0 or self.num_configs > 0:
            # Gumbel top-k: add Gumbel noise to linearly decreasing scores
            # (rank 0 = fastest config scores highest) and keep the
            # `num_configs` largest. The result is a set of ranks, which
            # `argsort_runtime` maps back to config positions.
            scores = (c - torch.arange(c)) / c - torch.log(
                -torch.log(torch.rand(c))
            )
            ranks = torch.topk(scores, num_configs)[1]
            config_indices = record["argsort_runtime"][ranks]
        else:
            config_indices = torch.arange(num_configs)

        sample = dict(
            model_id=record["model_id"],
            node_feat=record["node_feat"],
            node_opcode=record["node_opcode"],
            edge_index=record["edge_index"],
            node_config_feat=node_config_feat[config_indices],
            node_config_ids=node_config_ids,
            config_runtime=config_runtime[config_indices],
            config_edge_index=config_edge_index,
            num_config_nodes=len(node_config_ids),
            num_config_edges=len(config_edge_index[0]),
            num_nodes=record["num_nodes"],
            num_configs=num_configs,
            partptr=record["partptr"],
            config_indices=config_indices,
            partition_idx=record["partition_idx"],
        )

        if "cls_label" in record:
            sample["cls_label"] = record["cls_label"][config_indices]

        return LayoutData(**sample)


def sample_configs(
    config_runtime: np.ndarray, max_configs: int
) -> tuple[np.ndarray, np.ndarray]:
    """Sample 1/3 max_configs of best configs and 1/3 of worst configs,
    and the rest randomly. Return the sampled configs and indices.

    Args:
        config_runtime: 1D array of runtimes (lower is better).
        max_configs: Number of configs to keep; values ``<= 0`` or larger
            than ``len(config_runtime)`` keep ``len(config_runtime)``.

    Returns:
        ``(config_runtime[keep_indices], keep_indices)`` where
        ``keep_indices`` is ordered ``[best, worst, random]`` and contains
        no duplicates.
    """
    c = len(config_runtime)
    max_configs = min(max_configs, c) if max_configs > 0 else c
    third = max_configs // 3

    sorted_indices = np.argsort(config_runtime)

    # Use explicit end offsets. With `third == 0` the negative-slice forms
    # break: `[-0:]` is the whole array and `[0:-0]` is empty.
    good = sorted_indices[:third]
    bad = sorted_indices[c - third :]
    middle = sorted_indices[third : c - third]

    # Draw without replacement so the same config is never selected twice.
    # The pool is always large enough because max_configs <= c.
    random_part = np.random.choice(
        middle,
        max_configs - 2 * third,
        replace=False,
    )

    keep_indices = np.concatenate([good, bad, random_part])

    return config_runtime[keep_indices], keep_indices


In [None]:
# NOTE: despite the name, this is used as the DataLoader `batch_size`.
NUM_BATCHES = 16
norm_path = "../../data/normalizer.json"

# Global runtime statistics for the selected collection / search space.
runtime_std = CONFIG_RUNTIME_MEAN_STD[source][search]["std"]
runtime_mean = CONFIG_RUNTIME_MEAN_STD[source][search]["mean"]

# Training set: keep at most 10240 configs per graph at load time and draw
# `num_configs` of them on every access.
trainset = LayoutDataset(
    files=xla_default["train"],
    normalizer=Normalizer.from_json(norm_path, source=source, search=search),
    max_configs=10240,
    num_configs=num_configs,
    runtime_mean=runtime_mean,
    runtime_std=runtime_std,
)

# Validation set: no sub-sampling, every config of every graph is kept.
valset = LayoutDataset(
    files=xla_default["valid"],
    normalizer=Normalizer.from_json(norm_path, source=source, search=search),
    runtime_mean=runtime_mean,
    runtime_std=runtime_std,
)

In [None]:
from torch_sparse import SparseTensor
from torch_geometric.data import Batch
import copy


def get_adj(batch):
    """Rebuild ``batch`` with a sparse adjacency ``adj`` on every graph.

    The batch is split into its individual graphs, each graph gets a square
    ``SparseTensor`` built from its ``edge_index``, and the graphs are
    collated into a new batch.
    """
    graphs = batch.to_data_list()

    for graph in graphs:
        sources, targets = graph.edge_index[0], graph.edge_index[1]
        graph.adj = SparseTensor(
            row=sources,
            col=targets,
            sparse_sizes=(graph.num_nodes, graph.num_nodes),
        )

    return Batch.from_data_list(graphs)


In [None]:
from typing import Optional

import torch
from torch import Tensor


class History(torch.nn.Module):
    r"""A historical embedding storage module.

    Holds a ``(num_embeddings, embedding_dim)`` tensor ``emb`` that is a
    plain attribute (not a parameter or buffer), so it stays on the device it
    was created on. ``Module.to`` / ``.cuda()`` only change the device that
    :meth:`pull` returns results on.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Pin CPU storage to speed up host-to-GPU copies, but only when CUDA
        # is present: requesting pinned memory without it raises.
        on_cpu = device is None or str(device) == "cpu"
        pin_memory = on_cpu and torch.cuda.is_available()
        self.emb = torch.empty(
            num_embeddings,
            embedding_dim,
            device=device,
            pin_memory=pin_memory,
            requires_grad=False,
        )

        self._device = torch.device("cpu")

        self.reset_parameters()

    def reset_parameters(self):
        """Zero the whole storage."""
        self.emb.fill_(0)

    def _apply(self, fn):
        # Set the `_device` of the module without transferring `self.emb`.
        self._device = fn(torch.zeros(1)).device
        return self

    @torch.no_grad()
    def pull(self, n_id: Optional[Tensor] = None) -> Tensor:
        """Return rows ``n_id`` (or all rows) moved to the module device."""
        out = self.emb
        if n_id is not None:
            assert n_id.device == self.emb.device
            out = out.index_select(0, n_id)
        return out.to(device=self._device)

    @torch.no_grad()
    def push(
        self,
        x,
        n_id: Optional[Tensor] = None,
        offset: Optional[Tensor] = None,
        count: Optional[Tensor] = None,
    ):
        """Write ``x`` into the storage.

        * ``n_id is None``: ``x`` must cover every row and replaces them all.
        * ``n_id`` given without both ``offset`` and ``count``: rows ``n_id``
          are overwritten with ``x``.
        * ``n_id``, ``offset`` and ``count`` given: ``x`` is copied in
          contiguous chunks, chunk ``k`` going to rows
          ``offset[k] : offset[k] + count[k]``.

        Raises:
            ValueError: if ``n_id`` is ``None`` and ``x`` does not have
                ``num_embeddings`` rows.
        """
        if n_id is None and x.size(0) != self.num_embeddings:
            raise ValueError(
                f"Expected {self.num_embeddings} rows when `n_id` is None, "
                f"got {x.size(0)}"
            )

        elif n_id is None and x.size(0) == self.num_embeddings:
            self.emb.copy_(x)

        elif offset is None or count is None:
            assert n_id.device == self.emb.device
            self.emb[n_id] = x.to(self.emb.device)

        else:  # Push in chunks:
            src_o = 0
            x = x.to(self.emb.device)
            for dst_o, c in zip(offset.tolist(), count.tolist()):
                self.emb[dst_o : dst_o + c] = x[src_o : src_o + c]
                src_o += c

    def forward(self, *args, **kwargs):
        """Not supported: this module only stores embeddings."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.num_embeddings}, "
            f"{self.embedding_dim}, emb_device={self.emb.device}, "
            f"device={self._device})"
        )


# Global store of one scalar per (graph partition, config) slot; the training
# loop addresses it as partition_idx + config_index * num_parts + segment.
emb_table = History(num_embeddings=500_000_000, embedding_dim=1)

In [None]:
import gfos.model.gnn as gfos_gnn

import importlib

# Reload the module so edits made to it since the first import take effect
# without restarting the kernel.
importlib.reload(gfos_gnn)

model = gfos_gnn.LayoutModel(
    # input dimensions
    node_feat_dim=112,
    node_config_dim=14,
    # architecture
    head_dim=128,
    config_neighbor_layer="GATConv",
    num_config_neighbor_layers=2,
    # regularisation
    dropout=0.2,
    # config_neighbor_dropout_between_layers=0.2,
    # config_dropout_between_layers=0.2,
)

In [None]:
from gfos.loss import BatchMultiElementRankLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR


# Optimiser, ranking loss and a cosine learning-rate schedule. The scheduler
# is stepped once per epoch by the training loop, and T_max matches that
# loop's epoch count.
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.85, 0.9),
    weight_decay=1e-6,
)
criterion = BatchMultiElementRankLoss(0.5, 50)
scheduler = CosineAnnealingLR(optimizer, eta_min=1e-7, T_max=10000)

In [None]:
# Train on the first CUDA device.
device = torch.device("cuda", 0)
model = model.to(device)


In [None]:
import wandb

# Start the Weights & Biases run that the training loop logs to.
run = wandb.init(
    entity="edenn0",
    project="gfos",
    name=f"{source}_{search}_gst",
)

In [None]:
# Per-batch training losses, kept for later inspection.
losses = []

from tqdm import tqdm

# Segment-wise training: for every graph in a batch one node partition
# ("segment") is picked at random and run through the model; the outputs of
# the remaining partitions come from the historical table `emb_table`.
# Validation runs every `epoch_infer` epochs on the full graphs.
for epoch in range(10000):
    # Validation (below) puts the model in eval mode; switch back so dropout
    # is active in every training epoch, not only before the first validation.
    model.train()

    loader = DataLoader(
        trainset,
        batch_size=NUM_BATCHES,
        shuffle=True,
        follow_batch=["node_config_feat", "node_feat"],
    )

    pbar = tqdm(enumerate(loader), leave=False, desc=f"Epoch: {epoch}")
    # `step` (not `iter`) so the builtin is not shadowed.
    for step, batch in pbar:
        batch = get_adj(batch)
        true = batch.config_runtime

        batch_list = batch.to_data_list()
        batch_train_list = []  # trained sub-graphs, one per kept graph
        batch_other = []  # historical outputs of the other segments
        batch_num_parts = []
        segments_to_train = []  # trained segment id per kept graph
        skipped_batch = []  # positions in `batch_list` with nothing to train
        for i in range(len(batch_list)):
            num_parts = len(batch_list[i].partptr) - 1

            segment_to_train = np.random.randint(num_parts)

            batch_other_ = []
            add_target = False
            for j in range(num_parts):
                start = int(batch_list[i].partptr.cpu().numpy()[j])
                length = int(batch_list[i].partptr.cpu().numpy()[j + 1]) - start

                # filter out nodes that are not in the current partition
                cidx = torch.where(
                    (batch_list[i].node_config_ids >= start)
                    & (batch_list[i].node_config_ids < start + length)
                )[0]
                if len(cidx) == 0:
                    # No configurable node in this partition: if it was the
                    # one chosen for training, the whole graph is skipped.
                    if j == segment_to_train:
                        break
                    continue

                N, E, NC, EC = (
                    batch_list[i].num_nodes,
                    batch_list[i].num_edges,
                    batch_list[i].num_config_nodes,
                    batch_list[i].num_config_edges,
                )

                data = copy.copy(batch_list[i])
                del data.num_nodes

                # select the subgraph
                adj, data.adj = data.adj, None
                adj = adj.narrow(0, start, length).narrow(1, start, length)

                # select the sub config graph
                data.node_config_ids = data.node_config_ids[cidx] - start
                data.node_config_feat = data.node_config_feat.index_select(1, cidx)

                # Keep config edges whose two endpoints are both in the
                # partition.
                ei = data.config_edge_index
                mask = torch.isin(ei, cidx).all(dim=0)
                filtered_edge_index = ei[:, mask]

                # Old config-node position -> position inside the partition
                # (-1 for config nodes outside of it).
                new_indices = torch.zeros(data.num_config_nodes, dtype=torch.long) - 1
                new_indices[cidx] = torch.arange(cidx.size(0))

                data.config_edge_index = new_indices[
                    filtered_edge_index
                ]  # map to new indices

                for key, item in data:
                    if (
                        isinstance(item, torch.Tensor) and item.size(0) == N
                    ):  # node_feat, node_opcode
                        data[key] = item.narrow(0, start, length)
                    else:
                        data[key] = item

                row, col, _ = adj.coo()
                data.edge_index = torch.stack([row, col], dim=0)

                if j == segment_to_train:
                    batch_train_list.append(
                        LayoutData(
                            node_feat=data.node_feat,
                            node_opcode=data.node_opcode,
                            edge_index=data.edge_index,
                            node_config_feat=data.node_config_feat,
                            node_config_ids=data.node_config_ids,
                            config_runtime=data.config_runtime,
                            config_edge_index=data.config_edge_index,
                            num_config_nodes=len(data.node_config_ids),
                            num_nodes=len(data.node_feat),
                            model_id=data.model_id,
                        )
                    )
                    add_target = True
                else:
                    # Historical output of this (untrained) segment for the
                    # sampled configs.
                    batch_other_.append(
                        emb_table.pull(
                            batch_list[i].partition_idx.cpu()
                            + data.config_indices * num_parts
                            + j
                        )
                    )

            if len(batch_other_) > 0 and add_target:
                batch_other_ = torch.mean(torch.stack(batch_other_, dim=0), dim=0)
                batch_other.append(batch_other_)
                batch_num_parts.extend([num_parts] * num_configs)
                segments_to_train.append(segment_to_train)
            elif add_target:  # only training segment contains configurable nodes
                batch_other.append(
                    torch.zeros_like(batch_train_list[-1].config_runtime).unsqueeze(1)
                )
                batch_num_parts.extend([num_parts] * num_configs)
                segments_to_train.append(segment_to_train)
            else:
                skipped_batch.append(i)

        if len(batch_train_list) == 0:
            continue

        batch_seg = Batch.from_data_list(
            batch_train_list, follow_batch=["node_config_feat", "node_feat"]
        )

        node_feat = batch_seg["node_feat"]
        node_opcode = batch_seg["node_opcode"]
        edge_index = batch_seg["edge_index"]
        node_config_feat = batch_seg["node_config_feat"]
        node_config_ids = batch_seg["node_config_ids"]
        config_runtime = batch_seg["config_runtime"]
        config_edge_index = batch_seg["config_edge_index"]
        node_config_feat_batch = batch_seg["node_config_feat_batch"]
        batch_size = len(batch_seg.model_id)

        (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            config_edge_index,
            config_runtime,
            node_config_feat_batch,
        ) = (
            node_feat.to(device),
            node_opcode.to(device),
            edge_index.to(device),
            node_config_feat.to(device),
            node_config_ids.to(device),
            config_edge_index.to(device),
            config_runtime.to(device),
            node_config_feat_batch.to(device),
        )

        out = model(
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            config_edge_index,
            node_config_feat_batch,
            batch_size,
        )

        out = out.reshape(num_configs, -1).T.contiguous().reshape(-1, 1)

        binomial = torch.distributions.binomial.Binomial(probs=0.5)
        if len(batch_other) > 0:
            batch_other = torch.cat([b.to(device) for b in batch_other], dim=0)
            # Randomly drop half of the historical contributions.
            mask = binomial.sample((batch_other.shape[0], 1)).to(device)
            batch_other = batch_other.to(device)
            batch_other_embed = batch_other * mask

            batch_num_parts = torch.Tensor(batch_num_parts).to(device)
            batch_num_parts = batch_num_parts.view(-1, 1)
            multiplier_num = (batch_num_parts - 1) / 2 + 1
            pred = out * multiplier_num + batch_other_embed
        else:
            pred = out

        # NOTE(review): `pred` is computed above but the loss below is taken
        # on `out`; confirm whether the loss was meant to use `pred`.
        # pred = pred.reshape(-1, num_sample_config)
        out = out.reshape(-1, num_configs)
        config_runtime = config_runtime.reshape(-1, num_configs)

        loss = criterion(out, config_runtime)
        loss.backward()
        pbar.set_postfix_str(f"loss={loss.item():.4f}")
        losses.append(loss.item())

        # Optimizer step every `accum_iter` batches and on the last batch of
        # the epoch. `step` counts batches, so compare with len(loader), not
        # len(trainset) which counts samples.
        if ((step + 1) % accum_iter == 0) or (step + 1 == len(loader)):
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()

            run.log(
                {
                    "train/loss": loss.item(),
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "epoch": epoch,
                }
            )

        # Store the fresh outputs of the trained segments in the history.
        # Row i of `out` belongs to the i-th graph that was NOT skipped, so
        # every lookup goes through `used_batch`; indexing `batch_list`
        # directly is misaligned as soon as one graph was skipped.
        used_batch = [
            batch_list[i] for i in range(len(batch_list)) if i not in skipped_batch
        ]
        for i in range(out.shape[0]):
            config_idx = used_batch[i].config_indices
            push_idx = (
                used_batch[i].partition_idx.cpu()
                + config_idx * (len(used_batch[i].partptr) - 1)
                + segments_to_train[i]
            )
            emb_table.push(out[i].unsqueeze(1).cpu(), push_idx)

    scheduler.step()

    # Validate only every `epoch_infer` epochs (and never after epoch 0).
    if epoch == 0 or (epoch + 1) % epoch_infer != 0:
        continue

    model.eval()
    metrics = LayoutMetrics()
    val_outs = {}  # save the output for each model

    for record in tqdm(
        valset,
        desc=f"Valid epoch: {epoch}",
        leave=False,
    ):
        config_runtime: torch.Tensor = record["config_runtime"]
        with torch.no_grad():
            node_feat = record["node_feat"]
            node_opcode = record["node_opcode"]
            edge_index = record["edge_index"]
            node_config_feat = record["node_config_feat"]
            node_config_ids = record["node_config_ids"]
            config_runtime = record["config_runtime"]
            config_edge_index = record["config_edge_index"]

            (
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat,
                node_config_ids,
                config_edge_index,
            ) = (
                node_feat.to(device),
                node_opcode.to(device),
                edge_index.to(device),
                node_config_feat.to(device),
                node_config_ids.to(device),
                config_edge_index.to(device),
            )

            c = len(config_runtime)
            outs = []

            # Score the configs in chunks of 30 to bound memory use.
            for i in range(0, c, 30):
                end_i = min(i + 30, c)
                out: torch.Tensor = model(
                    node_feat,
                    node_opcode,
                    edge_index,
                    node_config_feat[i:end_i],
                    node_config_ids,
                    config_edge_index,
                )
                outs.append(out.detach().cpu())
            outs = torch.concat(outs)

        metrics.add(
            record["model_id"],
            outs.numpy(),
            config_runtime.numpy(),
        )

        val_outs[record["model_id"]] = outs.numpy()

    prefix = "val/"
    scores = metrics.compute_scores(prefix=prefix)

    run.log(scores)

In [None]:
# Close the Weights & Biases run opened earlier in the notebook.
run.finish()

In [None]:
# Scratch data: a 100-node chain graph of which nodes 30..39 are configurable
# (also connected as a chain), with 3 configs of 5 features per config node.
x = torch.randn(100, 5)
edge_index = torch.cat([torch.arange(99)[None], torch.arange(1, 100)[None]])
node_config_ids = torch.arange(30, 40)
node_config_feat = torch.randn(node_config_ids.size(0), 5).repeat(3, 1, 1)
config_edge_index = torch.cat([torch.arange(9)[None], torch.arange(1, 10)[None]])

# `x` is rebound from the raw feature tensor to the assembled sample.
x = LayoutData(
    model_id=None,
    node_feat=x,
    node_opcode=None,
    edge_index=edge_index,
    node_config_feat=node_config_feat,
    node_config_ids=node_config_ids,
    config_runtime=None,
    config_edge_index=config_edge_index,
    num_config_nodes=len(node_config_ids),
    num_nodes=len(x),
)

In [None]:
# Scratch check of the partition-extraction logic used in the training loop,
# applied to the toy sample `x` for the node range [start, start + length).
start = 35
length = 10

# filter out nodes that are not in the current partition
# (`cidx` holds positions into `node_config_ids`, not node ids)
cidx = torch.where(
    (x.node_config_ids >= start)
    & (x.node_config_ids < start + length)
)[0]
# if len(cidx) == 0:
#     if j == segment_to_train:
#         break
#     continue

# Sizes of the full graph, captured before `data` is modified below.
N, E, NC = (
    x.num_nodes,
    x.num_edges,
    x.num_config_nodes,
)

# Shallow copy so the original sample `x` keeps its attributes.
data = copy.copy(x)
data.adj = SparseTensor(
    row=data.edge_index[0],
    col=data.edge_index[1],
    sparse_sizes=(data.num_nodes, data.num_nodes),
)
del data.num_nodes

# select the subgraph
adj, data.adj = data.adj, None
adj = adj.narrow(0, start, length).narrow(1, start, length)

# select the sub config graph
# (node ids become relative to the partition start)
data.node_config_ids = data.node_config_ids[cidx] - start
data.node_config_feat = data.node_config_feat.index_select(1, cidx)

# Keep only config edges whose two endpoints are both inside the partition.
ei = data.config_edge_index
mask = torch.isin(ei, cidx).all(dim=0)
filtered_edge_index = ei[:, mask]

# Lookup table: old config-node position -> new position (-1 if dropped).
new_indices = torch.zeros(data.num_config_nodes, dtype=torch.long) - 1
new_indices[cidx] = torch.arange(cidx.size(0))

data.config_edge_index = new_indices[filtered_edge_index]  # map to new indices

# Narrow every tensor whose first dimension equals the full node count.
for key, item in data:
    if (
        isinstance(item, torch.Tensor) and item.size(0) == N
    ):  # node_feat, node_opcode
        data[key] = item.narrow(0, start, length)
    else:
        data[key] = item

# Rebuild the edge list of the sub-graph from the narrowed adjacency.
row, col, _ = adj.coo()
data.edge_index = torch.stack([row, col], dim=0)
