In [1]:
import os
import pickle

from datetime import datetime
from typing import Literal

import torch
import wandb
import numpy as np
import tensorflow as tf
import tensorflow_ranking as tfr
from torch import nn
from torch import optim
import torch_geometric.nn as geonn
from torch_geometric.data import Data, Batch
from tqdm import tqdm

from gfos.data.utils import load_layout
from gfos.data.dataset import LayoutDataset, Normalizer
from gfos.metrics import kendall, topk_error
from gfos.loss import MultiElementRankLoss
from gfos.utils.misc import seed_everything
from gfos.data.constants import mask_min_max

# Global run settings.
SEED = 42
DEBUG = False

# Seed every RNG before any data or model object is created.
seed_everything(SEED)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


## Configs

In [2]:
# Root of the extracted `npz/layout` data. Set $GFOS_LAYOUT_DIR to point at a
# different copy instead of editing this machine-specific default.
LAYOUT_DIR = os.environ.get(
    "GFOS_LAYOUT_DIR",
    r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout",
)

# Hyper-parameters of the trained run. The architecture entries (conv_layer and
# the *_dim values) must match the checkpoint loaded later; the remaining
# entries are training settings.
configs = dict(
    conv_layer="SAGEConv",
    op_embedding_dim=32,
    config_dim=64,
    graph_dim=64,
    num_epochs=4000,
    learning_rate=5e-3,
    weight_decay=1e-5,
    min_lr=1e-7,
    max_configs=2000,
    num_configs=16,
    loss_margin=0.2,
    loss_num_permutations=10,
    accum_iter=8,
    grad_clip=1e-1,
)

# Dataset split selection and experiment-tracking metadata.
SOURCE = "xla"
SEARCH = "random"
WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = f"layout_{SOURCE}_{SEARCH}"
TAGS = ["train", "layout", SOURCE, SEARCH]

NUM_VAL_EPOCHS = 250
# Number of configurations scored per forward pass at inference time. This is
# the single source of truth for the inference loop's chunk size.
INFERENCE_CONFIGS_BATCH_SIZE = 100


## Model

In [3]:
def aggregate_neighbors(node_feat: torch.Tensor, edge_index: torch.Tensor):
    """Contrast each node's out-neighbour mean with its in-neighbour mean.

    Each column ``(u, v)`` of ``edge_index`` is treated as a directed edge
    ``u -> v``. For every node this computes
    ``mean(feat[w] for u -> w) - mean(feat[u] for u -> v)``. A node without
    outgoing (or incoming) edges contributes zero for that term.

    Args:
        node_feat: ``(N, D)`` floating-point node features.
        edge_index: ``(2, E)`` integer edge list (row 0 = source, row 1 = target).

    Returns:
        ``(N, D)`` tensor on the same device/dtype as ``node_feat``.
    """
    src, dst = edge_index[0], edge_index[1]

    def _mean_at(index: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        # Rows never written to keep their zero init because include_self=False.
        acc = torch.zeros_like(node_feat)
        acc.scatter_reduce_(
            0,
            index.unsqueeze(-1).expand_as(values),
            values,
            reduce="mean",
            include_self=False,
        )
        return acc

    incoming_mean = _mean_at(dst, node_feat[src])
    outgoing_mean = _mean_at(src, node_feat[dst])
    return outgoing_mean - incoming_mean

In [4]:
class LayoutModel(torch.nn.Module):
    """Predict one scalar per layout configuration of a computation graph.

    Pipeline, as implemented in ``forward``:
      1. Embed each node's opcode and concatenate it with the node features.
      2. Run ``model_gnn`` over the whole graph to get ``graph_dim`` node embeddings.
      3. For the configurable nodes, take a neighbourhood summary
         (``aggregate_neighbors``), L2-normalise it and refine it with
         ``config_mp`` on the config-node graph.
      4. For every configuration, concatenate [neighbour summary, node embedding,
         projected config features]. Run ``config_gnn`` on that configuration's
         copy of the config-node graph, mean-pool, and map to a scalar with
         ``dense``.

    Args:
        conv_layer: Name of a ``torch_geometric.nn`` convolution class, used in
            ``model_gnn`` and ``config_gnn``. NOTE(review): the last
            ``config_gnn`` layer passes ``normalize=True``; confirm the chosen
            class accepts that keyword (``SAGEConv`` does) before switching.
        op_embedding_dim: Size of the opcode embedding.
        config_dim: Hidden width of the configuration branch.
        graph_dim: Hidden width of the graph branch.
        node_feat_dim: Width of ``node_feat``.
        node_config_dim: Width of ``node_config_feat``.
        num_opcodes: Number of rows in the opcode embedding table. Opcode ids
            must be smaller than this value.
    """

    def __init__(
        self,
        conv_layer: Literal["GATConv", "GCNConv", "SAGEConv"],
        op_embedding_dim: int = 32,
        config_dim: int = 64,
        graph_dim: int = 64,
        node_feat_dim: int = 140,
        node_config_dim: int = 18,
        num_opcodes: int = 120,
    ):
        super().__init__()

        conv_cls = getattr(geonn, conv_layer)

        # Width of [config_mp output, model_gnn output, config_prj output].
        merged_node_dim = 2 * graph_dim + config_dim

        self.embedding = torch.nn.Embedding(num_opcodes, op_embedding_dim)
        in_channels = op_embedding_dim + node_feat_dim

        # Graph branch: four convolutions with two residual additions.
        self.model_gnn = geonn.Sequential(
            "x, edge_index",
            [
                (conv_cls(in_channels, graph_dim), "x, edge_index -> x1"),
                nn.LeakyReLU(inplace=True),
                (conv_cls(graph_dim, graph_dim), "x1, edge_index -> x2"),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),
                nn.LeakyReLU(inplace=True),
                (conv_cls(graph_dim, graph_dim), "x3, edge_index -> x4"),
                nn.LeakyReLU(inplace=True),
                (conv_cls(graph_dim, graph_dim), "x4, edge_index -> x5"),
                (lambda x4, x5: x4 + x5, "x4, x5 -> x6"),
                nn.LeakyReLU(inplace=True),
            ],
        )

        # Message passing between configurable nodes on their neighbourhood
        # summaries. Always GATConv, independent of `conv_layer`.
        self.config_mp = geonn.Sequential(
            "x, edge_index",
            [
                (geonn.GATConv(graph_dim, graph_dim), "x, edge_index -> x1"),
                nn.LeakyReLU(inplace=True),
                (geonn.GATConv(graph_dim, graph_dim), "x1, edge_index -> x2"),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),
                nn.LeakyReLU(inplace=True),
            ],
        )

        # Configuration branch: input dropout, four convolutions with two
        # residual additions; the last convolution normalises its output.
        self.config_gnn = geonn.Sequential(
            "x, edge_index",
            [
                (nn.Dropout(p=0.2), "x -> x"),
                (conv_cls(merged_node_dim, config_dim), "x, edge_index -> x1"),
                nn.LeakyReLU(inplace=True),
                (conv_cls(config_dim, config_dim), "x1, edge_index -> x2"),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),
                nn.LeakyReLU(inplace=True),
                (conv_cls(config_dim, config_dim), "x3, edge_index -> x4"),
                nn.LeakyReLU(inplace=True),
                (
                    conv_cls(config_dim, config_dim, normalize=True),
                    "x4, edge_index -> x5",
                ),
                (lambda x4, x5: x4 + x5, "x4, x5 -> x6"),
                nn.LeakyReLU(inplace=True),
            ],
        )

        # Projects raw per-node config features to `config_dim`.
        self.config_prj = nn.Sequential(
            nn.Linear(node_config_dim, config_dim),
            nn.LeakyReLU(),
        )

        # Head mapping one pooled vector per configuration to a scalar.
        self.dense = torch.nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(config_dim, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, 1, bias=False),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
        config_edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Return one scalar per configuration, shape ``(C,)``.

        Args:
            node_feat: ``(N, node_feat_dim)`` features of every graph node.
            node_opcode: ``(N,)`` integer opcode ids (embedding indices).
            edge_index: ``(2, E)`` edges of the computation graph.
            node_config_feat: ``(C, NC, node_config_dim)`` config features of
                the ``NC`` configurable nodes for each of ``C`` configurations.
            node_config_ids: ``(NC,)`` indices of the configurable nodes into ``N``.
            config_edge_index: ``(2, E_c)`` edges between configurable nodes,
                indexed in ``[0, NC)``.
        """
        num_configs = node_config_feat.size(0)

        # (N, node_feat_dim + op_embedding_dim) -> (N, graph_dim)
        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)
        x = self.model_gnn(x, edge_index)

        # Neighbourhood summary of each configurable node: (NC, graph_dim)
        config_neighbors = aggregate_neighbors(x, edge_index)[node_config_ids]
        config_neighbors = nn.functional.normalize(config_neighbors, dim=-1)
        config_neighbors = self.config_mp(config_neighbors, config_edge_index)

        # (N, graph_dim) -> (NC, graph_dim)
        x = x[node_config_ids]

        # (C, NC, node_config_dim) -> (C, NC, config_dim)
        node_config_feat = self.config_prj(node_config_feat)

        # (C, NC, 2 * graph_dim + config_dim); graph features are shared by
        # every configuration, so they are tiled along the first axis.
        x = torch.cat(
            [
                config_neighbors.repeat((num_configs, 1, 1)),
                x.repeat((num_configs, 1, 1)),
                node_config_feat,
            ],
            dim=-1,
        )
        x = nn.functional.normalize(x, dim=-1)

        batch_x, batch_edge_index, batch_index = self._stack_config_graphs(
            x, config_edge_index
        )
        x = self.config_gnn(batch_x, batch_edge_index)
        # (C * NC, config_dim) -> (C, config_dim): one vector per configuration.
        x = geonn.pool.global_mean_pool(x, batch_index)

        return self.dense(x).flatten()

    @staticmethod
    def _stack_config_graphs(
        x: torch.Tensor, config_edge_index: torch.Tensor
    ) -> tuple:
        """Flatten C copies of the config-node graph into one disjoint graph.

        Produces the same node features, edge list and batch vector as
        ``Batch.from_data_list([Data(x=x[i], edge_index=config_edge_index) ...])``
        without building C ``Data`` objects on every forward pass.

        Args:
            x: ``(C, NC, D)`` node features of each configuration's graph.
            config_edge_index: ``(2, E_c)`` edges shared by all copies.

        Returns:
            ``(x_flat, edge_index, batch)`` with shapes ``(C * NC, D)``,
            ``(2, C * E_c)`` and ``(C * NC,)``.
        """
        num_configs, num_nodes, feat_dim = x.shape
        offsets = (
            torch.arange(num_configs, device=config_edge_index.device) * num_nodes
        )
        # (C, 2, E_c) -> (2, C * E_c), graph-major like Batch.from_data_list.
        edge_index = (
            (config_edge_index.unsqueeze(0) + offsets.view(-1, 1, 1))
            .transpose(0, 1)
            .reshape(2, -1)
        )
        batch_index = torch.arange(num_configs, device=x.device).repeat_interleave(
            num_nodes
        )
        return x.reshape(num_configs * num_nodes, feat_dim), edge_index, batch_index

## Prepare for inference

### Expand configs

In [5]:
# Unpack the hyper-parameter dict into module-level names used by later cells.
(
    conv_layer,
    op_embedding_dim,
    config_dim,
    graph_dim,
    num_epochs,
    learning_rate,
    weight_decay,
    min_lr,
    max_configs,
    num_configs,
    accum_iter,
    grad_clip,
    margin,
    number_permutations,
) = (
    configs[key]
    for key in (
        "conv_layer",
        "op_embedding_dim",
        "config_dim",
        "graph_dim",
        "num_epochs",
        "learning_rate",
        "weight_decay",
        "min_lr",
        "max_configs",
        "num_configs",
        "accum_iter",
        "grad_clip",
        "loss_margin",
        "loss_num_permutations",
    )
)


### Read data

In [6]:
# Normalisation settings for this SOURCE/SEARCH combination.
normalizer = Normalizer.from_configs(mask_min_max, SOURCE, SEARCH)

# `load_layout` returns a mapping keyed by split name; only "test" is used below.
layout_data = load_layout(LAYOUT_DIR, model_type=SOURCE, compile_type=SEARCH)

# Test split wrapped for inference with the simple config-edge construction.
test_dataset = LayoutDataset(
    layout_data["test"],
    normalizer=normalizer,
    config_edges="simple",
    test=True,
)

  runtime = (runtime - runtime.min()) / (
Loading data: 100%|██████████| 8/8 [00:02<00:00,  3.93it/s]


In [7]:
# Input widths the model must be built with, read from the first test sample.
node_feat_dim, node_config_dim = (
    test_dataset[0][key].shape[-1] for key in ("node_feat", "node_config_feat")
)

### Model

In [8]:
# Build the network with the architecture of the trained run, then move it to `device`.
model = LayoutModel(
    conv_layer,
    node_feat_dim=node_feat_dim,
    node_config_dim=node_config_dim,
    op_embedding_dim=op_embedding_dim,
    graph_dim=graph_dim,
    config_dim=config_dim,
)
model = model.to(device)


In [9]:
# Checkpoint of the trained run. Set $GFOS_LAYOUT_CHECKPOINT to load a different
# file instead of editing this machine-specific default.
model_path = os.environ.get(
    "GFOS_LAYOUT_CHECKPOINT",
    r"G:\projects\gfos\logs\wandb\run-20231004_105418-ye4hk2ey\files\1999_0.3409.pth",
)
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints you produced yourself.
static_dict = torch.load(model_path, map_location=device)  # holds a state_dict
# strict (default) loading: raises if the architecture does not match the checkpoint.
model.load_state_dict(static_dict)
model = model.eval()

## Inference

In [10]:

results = {}
scores = []

model.eval()


@torch.no_grad()
def predict_config_scores(
    sample: dict, batch_size: int = INFERENCE_CONFIGS_BATCH_SIZE
) -> torch.Tensor:
    """Run ``model`` on every configuration of one dataset sample.

    Configurations are scored in chunks of ``batch_size`` to bound memory. The
    graph tensors are moved to ``device`` once and shared by every chunk.

    Args:
        sample: One item of ``test_dataset``.
        batch_size: Configurations per forward pass (defaults to the config cell).

    Returns:
        CPU tensor with one model output per configuration, in dataset order.
    """
    graph_inputs = (
        sample["node_feat"].to(device),
        sample["node_opcode"].to(device),
        sample["edge_index"].to(device),
    )
    node_config_feat = sample["node_config_feat"].to(device)
    node_config_ids = sample["node_config_ids"].to(device)
    config_edges = sample["config_edge_index"].to(device)

    chunks = []
    for start in range(0, node_config_feat.size(0), batch_size):
        out = model(
            *graph_inputs,
            node_config_feat[start : start + batch_size],  # slicing clamps at the end
            node_config_ids,
            config_edges,
        )
        chunks.append(out.cpu())
    return torch.cat(chunks)


for data in tqdm(test_dataset):
    outs = predict_config_scores(data)
    # NOTE(review): this iterates the test split (built with test=True); confirm
    # `config_runtime` holds real labels here, otherwise `scores` is meaningless.
    scores.append(kendall(np.argsort(outs), np.argsort(data["config_runtime"])))

    # Configuration indices ordered by ascending model output.
    results[data["model_id"]] = np.argsort(outs.numpy()).tolist()

100%|██████████| 8/8 [00:03<00:00,  2.10it/s]


In [11]:
# Write the submission CSV: one row per graph, "layout:<source>:<search>:<id>",
# followed by its configuration indices joined with ";". The prefix and file
# name follow SOURCE/SEARCH so they stay correct when the split changes.
RUN_ID = "ye4hk2ey"  # wandb run id of the checkpoint this notebook loads
submission_path = f"../../output/{SOURCE}_{SEARCH}_{RUN_ID}.csv"
os.makedirs(os.path.dirname(submission_path), exist_ok=True)

with open(submission_path, "w") as f:
    f.write("ID,TopConfigs\n")
    for k, v in results.items():
        model_id = f"layout:{SOURCE}:{SEARCH}:{k}"
        values = ";".join(str(i) for i in v)
        f.write(f"{model_id},{values}\n")
