In [None]:
import os
from datetime import datetime
from typing import Literal

import torch
import wandb
import numpy as np
from torch import nn
from torch import optim
import torch_geometric.nn as geonn
from tqdm import tqdm

from gfos.data.utils import load_layout
from gfos.data.dataset import LayoutDataset
from gfos.model.utils import aggregate_neighbors
from gfos.utils.scheduler import CosineAnnealingWarmupRestarts
from gfos.metrics import kendall
from gfos.loss import MultiElementRankLoss, listMLE
from gfos.utils.misc import seed_everything


# Run-wide switches.
SEED = 42  # value handed to seed_everything() just below
DEBUG = False  # debug flag; not read by any of the cells shown in this notebook

# Seed everything up front, before any model or dataset is constructed.
seed_everything(SEED)

# Prefer the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Configs

In [None]:
# Directory that holds the layout files passed to load_layout().
# The default is the original machine-specific location; set the
# GFOS_LAYOUT_DIR environment variable to run the notebook elsewhere without
# editing this cell.
LAYOUT_DIR = os.environ.get(
    "GFOS_LAYOUT_DIR",
    r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout",
)

# Hyper-parameters of the run. Only conv_layer, graph_hidden, num_encoder,
# num_feedforward and nhead are consumed when the model is rebuilt for
# inference; the remaining entries are unpacked later but not used by the
# inference cells (presumably training-time settings kept so the config matches
# the run that produced the checkpoint -- TODO confirm).
configs = dict(
    conv_layer="GATConv",  # name of the torch_geometric.nn layer class
    num_epochs=2000,
    learning_rate=1e-3,
    weight_decay=1e-7,
    min_lr=1e-7,
    warmup_ratio=0.0,
    max_configs=2000,
    num_configs=32,
    graph_hidden=[16, 32, 16, 48, 64],  # output width of each graph conv
    num_encoder=1,  # number of Transformer encoder layers
    num_feedforward=512,  # dim_feedforward of each encoder layer
    nhead=1,  # attention heads per encoder layer
    loss_margin=0.1,
    loss_num_permutations=50,
)

# Weights & Biases bookkeeping (not referenced by the inference cells).
WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = "layout_xla_default_listMLE"
TAGS = ["train", "layout", "xla", "default"]

NUM_VAL_EPOCHS = 10
# NOTE(review): the inference cell defines its own batch size (64) and does
# not read this constant -- confirm which value is intended.
INFERENCE_CONFIGS_BATCH_SIZE = 100
TOP_K = 500

## Model

In [None]:
class LayoutModel(torch.nn.Module):
    """Graph-convolution + Transformer model emitting one scalar per configuration.

    Node features are concatenated with a learned opcode embedding and passed
    through a stack of graph convolutions. The rows of the configurable nodes
    are then paired with projected per-configuration features, offset by a
    projection of the aggregated neighbour features, L2-normalised, run through
    a Transformer encoder, and the last token of each configuration is reduced
    to a single value by a small MLP.

    Args:
        conv_layer: Name of the ``torch_geometric.nn`` layer class used for
            every graph convolution.
        hidden_channels: Output width of each graph convolution, in order.
            Must contain at least one entry.
        num_encoder: Number of Transformer encoder layers.
        num_feedforward: ``dim_feedforward`` of each encoder layer.
        nhead: Number of attention heads of each encoder layer.
        num_opcodes: Size of the opcode embedding table.
        node_feat_dim: Width of the raw per-node feature vector.
        config_feat_dim: Width of the raw per-node configuration feature vector.
        op_embedding_dim: Width of the opcode embedding.
        config_dim: Width the configuration features are projected to.

    Raises:
        ValueError: If ``hidden_channels`` is empty.
    """

    def __init__(
        self,
        conv_layer: Literal["GATConv", "GCNConv", "SAGEConv"],
        hidden_channels: list[int],
        num_encoder: int = 1,
        num_feedforward: int = 256,
        nhead: int = 1,
        num_opcodes: int = 120,
        node_feat_dim: int = 140,
        config_feat_dim: int = 18,
        op_embedding_dim: int = 32,
        config_dim: int = 64,
    ):
        super().__init__()

        # Validate before hidden_channels[-1] is read below.
        if not hidden_channels:
            raise ValueError(
                "hidden_channels must contain at least one layer width"
            )

        conv_cls = getattr(geonn, conv_layer)

        graph_out = hidden_channels[-1]
        merged_node_dim = graph_out + config_dim

        self.embedding = torch.nn.Embedding(num_opcodes, op_embedding_dim)

        # Widths seen by consecutive convolutions: the input width first, then
        # the requested hidden widths. A new list is built, so the caller's
        # list is left untouched.
        widths = [op_embedding_dim + node_feat_dim] + list(hidden_channels)

        # NOTE: no activation is applied between the convolutions (see
        # forward()). Earlier code constructed an nn.LeakyReLU() here and
        # immediately discarded it; that no-op has been removed. Inserting a
        # real activation would change the function the network computes, so
        # previously saved weights would no longer reproduce their outputs.
        self.convs = torch.nn.ModuleList(
            [
                conv_cls(width_in, width_out)
                for width_in, width_out in zip(widths[:-1], widths[1:])
            ]
        )

        # Transformer encoder to merge configs features with graph features
        layer = nn.TransformerEncoderLayer(
            d_model=merged_node_dim,
            dim_feedforward=num_feedforward,
            nhead=nhead,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_encoder)

        # Projects raw configuration features to config_dim.
        self.config_prj = nn.Sequential(
            nn.Linear(config_feat_dim, config_dim),
            nn.LayerNorm(config_dim),
            nn.LeakyReLU(),
        )

        # Projects aggregated neighbour features to the merged token width so
        # they can be added to the tokens as an offset.
        self.deg_prj = nn.Sequential(
            nn.Linear(graph_out, merged_node_dim, bias=False),
            nn.LayerNorm(merged_node_dim),
            nn.LeakyReLU(),
        )

        # Dense head reducing one merged token to a single scalar.
        self.dense = torch.nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(merged_node_dim, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, 1, bias=False),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Compute one scalar per configuration.

        Args:
            node_feat: Per-node features, ``(N, node_feat_dim)``.
            node_opcode: Per-node opcode ids used to index the embedding
                table, ``(N,)``.
            edge_index: Graph connectivity handed unchanged to every
                convolution (assumed to be the usual torch_geometric
                ``(2, E)`` layout -- TODO confirm).
            node_config_feat: Configuration features,
                ``(C, NC, config_feat_dim)``.
            node_config_ids: Indices of the ``NC`` configurable nodes among
                the ``N`` graph nodes.

        Returns:
            Flat tensor with ``C`` values, one per configuration.
        """
        num_cfg = node_config_feat.size(0)

        # (N, node_feat_dim + op_embedding_dim)
        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)

        # Graph features; the convolutions are chained without an activation.
        for conv in self.convs:
            x = conv(x, edge_index)

        # aggregate_neighbors is an external helper; it is assumed to return
        # one graph_out-wide row per node, which is what deg_prj expects.
        neighbor_feat = aggregate_neighbors(x, edge_index)[node_config_ids]

        # (N, graph_out) -> (NC, graph_out)
        x = x[node_config_ids]

        # (C, NC, config_feat_dim) -> (C, NC, config_dim)
        node_config_feat = self.config_prj(node_config_feat)
        # (NC, graph_out + config_dim); broadcast over the C axis when added.
        pos_embedding = self.deg_prj(neighbor_feat)

        # (C, NC, graph_out + config_dim)
        x = torch.cat([x.repeat((num_cfg, 1, 1)), node_config_feat], dim=-1)
        x = x + pos_embedding
        x = nn.functional.normalize(x, dim=-1)

        # Keep only the last token of each configuration:
        # (C, NC, graph_out + config_dim) -> (C, graph_out + config_dim)
        x = self.encoder(x)[:, -1, :]

        # (C, 1) -> (C,)
        return self.dense(x).flatten()

## Inference

In [None]:
# Locate the layout files of the XLA / default subset.
layout_files = load_layout(
    LAYOUT_DIR,
    model_type="xla",
    compile_type="default",
)
# NOTE: despite the variable name, this dataset is built from the "valid"
# split.
test_dataset = LayoutDataset(files=layout_files["valid"])

# Unpack the run configuration. Only the architecture entries are needed to
# rebuild the model; the remaining names are kept so any cell relying on them
# keeps working.
conv_layer = configs["conv_layer"]
num_epochs = configs["num_epochs"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
warmup_ratio = configs["warmup_ratio"]
max_configs = configs["max_configs"]
num_configs = configs["num_configs"]
graph_hidden = configs["graph_hidden"]
num_encoder = configs["num_encoder"]
num_feedforward = configs["num_feedforward"]
nhead = configs["nhead"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]

# Rebuild the architecture with the same hyper-parameters as the checkpoint.
model = LayoutModel(
    conv_layer=conv_layer,
    hidden_channels=graph_hidden,
    num_encoder=num_encoder,
    num_feedforward=num_feedforward,
    nhead=nhead,
).to(device)

# Checkpoint to evaluate, named once so it is easy to swap.
CHECKPOINT_PATH = (
    "../../logs/wandb/run-20230930_061557-bxlnki1i/files/1419_0.3547.pth"
)

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, and avoids executing arbitrary pickled
# code from the checkpoint file.
static_dict = torch.load(
    CHECKPOINT_PATH,
    map_location=device,
    weights_only=True,
)
model.load_state_dict(static_dict)

# Switch to evaluation mode (disables dropout).
model = model.eval()

In [None]:
layout_xla_scores = []  # left in place for other cells; not filled in here
results = {}  # model_id -> config indices ordered by ascending model output

# Number of configurations scored per forward pass.
# NOTE(review): this local value is used instead of the notebook-level
# INFERENCE_CONFIGS_BATCH_SIZE (100) -- confirm which one is intended.
_INFERENCE_CONFIGS_BATCH_SIZE = 64

# NOTE(review): `scores` feeds kendall() with argsort index permutations,
# `_scores` feeds it the raw outputs and runtimes. Only `_scores` is
# summarised afterwards -- confirm `scores` is still wanted.
scores = []
_scores = []

with torch.no_grad():
    for data in tqdm(test_dataset):
        model_id = data["model_id"]
        config_runtime = data["config_runtime"]

        # Move the model inputs to the target device.
        node_feat = data["node_feat"].to(device)
        node_opcode = data["node_opcode"].to(device)
        edge_index = data["edge_index"].to(device)
        node_config_feat = data["node_config_feat"].to(device)
        node_config_ids = data["node_config_ids"].to(device)

        # Score the configurations in chunks along the first axis; slicing
        # clamps the final chunk to the number of configurations.
        c = node_config_feat.size(0)
        outs = []
        for i in range(0, c, _INFERENCE_CONFIGS_BATCH_SIZE):
            out: torch.Tensor = model(
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat[i : i + _INFERENCE_CONFIGS_BATCH_SIZE],
                node_config_ids,
            )
            # No .detach() needed: gradients are disabled in this context.
            outs.append(out.cpu())
        outs = torch.cat(outs)

        # Ascending order of the model outputs, computed once and reused for
        # both the metric and the stored ranking.
        pred_idx = np.argsort(outs.numpy())

        scores.append(kendall(pred_idx, np.argsort(config_runtime)))
        _scores.append(kendall(outs, config_runtime))

        results[model_id] = pred_idx.tolist()


In [None]:
# Average of the per-graph values collected in `_scores`.
np.asarray(_scores).mean()

In [None]:

# Export the rankings as CSV: one row per graph, with the config indices
# joined by ";" in the second column.
output_path = "../../output/xla_default_bxlnki1i.csv"

# Create the output directory if needed so open() does not fail on a fresh
# checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, "w") as f:
    f.write("ID,TopConfigs\n")
    for k, v in results.items():
        # f-string formatting yields the same text as string concatenation
        # for str ids and also tolerates non-str ids.
        model_id = f"layout:xla:default:{k}"
        values = ";".join(map(str, v))
        f.write(f"{model_id},{values}\n")
        