In [1]:
import os
from datetime import datetime
from typing import Literal

import torch
import wandb
import numpy as np
from torch import nn
from torch import optim
import torch_geometric.nn as geonn
from tqdm import tqdm

from gfos.data.utils import load_layout
from gfos.data.dataset import LayoutDataset
from gfos.utils.scheduler import CosineAnnealingWarmupRestarts
from gfos.metrics import metric_for_layout_collections
from gfos.loss import MultiElementRankLoss, listMLE
from gfos.utils.misc import seed_everything


SEED = 42
DEBUG = False

seed_everything(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

## Configs

In [2]:
LAYOUT_DIR = r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout"

configs = dict(
    conv_layer="GATConv",
    num_epochs=100,
    learning_rate=1e-3,
    weight_decay=1e-6,
    min_lr=1e-7,
    warmup_ratio=0.0,
    max_configs=100,
    graph_hidden=[16, 32, 16, 48, 64],
    num_encoder=1,
    num_feedforward=512,
    nhead=1,
    loss_margin=0.1,
    loss_num_permutations=2000,
)

WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = "gcn_layout_xla_default"
TAGS = ["train", "layout", "xla", "default"]

In [3]:
def min_max_normalize(feature_matrix: torch.Tensor) -> torch.Tensor:
    max_feat, _ = torch.max(feature_matrix, dim=0, keepdim=True)
    min_feat, _ = torch.min(feature_matrix, dim=0, keepdim=True)
    used_columns = min_feat[0] != max_feat[0]

    feature_matrix = feature_matrix[:, used_columns]
    min_feat = min_feat[:, used_columns]
    max_feat = max_feat[:, used_columns]
    return (feature_matrix - min_feat) / (max_feat - min_feat)

## Model

In [4]:
def edges_adjacency(edges: torch.Tensor, add_diagonal=True) -> torch.Tensor:
    """
    Generate an adjacency matrix from the edges
    Args:
        edges: Tensor of shape (num_edges, 2) with the edges
        add_diagonal: Boolean indicating if the diagonal should be added to the adjacency matrix
    Returns:
        adjacency_matrix: Tensor of shape (num_nodes, num_nodes) with the adjacency matrix
    """
    adjacency_matrix = torch.zeros(
        (edges.max() + 1, edges.max() + 1), device=edges.device
    )
    adjacency_matrix[edges[:, 0], edges[:, 1]] = 1
    if add_diagonal:
        diag_idx = torch.arange(adjacency_matrix.shape[0])
        adjacency_matrix[diag_idx, diag_idx] = 1
    return adjacency_matrix


def aggregate_neighbors(node_feat: torch.Tensor, edge_index: torch.Tensor):
    source_nodes = edge_index[0]
    target_nodes = edge_index[1]

    in_degree_features = torch.zeros_like(node_feat, device=node_feat.device)
    out_degree_features = torch.zeros_like(node_feat, device=node_feat.device)

    source_node_features = node_feat[source_nodes]
    target_node_features = node_feat[target_nodes]

    in_degree_features.scatter_reduce_(
        0,
        target_nodes.unsqueeze(-1).expand_as(source_node_features),
        source_node_features,
        reduce="sum",
    )
    
    out_degree_features.scatter_reduce_(
        0,
        source_nodes.unsqueeze(-1).expand_as(target_node_features),
        target_node_features,
        reduce="mean",
    )
    
    return out_degree_features - in_degree_features

In [5]:
class LayoutModel(torch.nn.Module):
    def __init__(
        self,
        conv_layer: Literal["GATConv", "GCNConv", "SAGEConv"],
        hidden_channels: list[int],
        num_encoder: int = 1,
        num_feedforward: int = 256,
        nhead: int = 1,
    ):
        super().__init__()

        conv_layer = getattr(geonn, conv_layer)

        op_embedding_dim = 32
        config_dim = 64
        graph_out = hidden_channels[-1]
        merged_node_dim = graph_out + config_dim

        self.embedding = torch.nn.Embedding(
            120,
            op_embedding_dim,
        )
        assert len(hidden_channels) > 0
        in_channels = op_embedding_dim + 140
        self.convs = torch.nn.ModuleList()

        hidden_channels = [in_channels] + hidden_channels
        for i in range(len(hidden_channels) - 1):
            self.convs += [
                conv_layer(hidden_channels[i], hidden_channels[i + 1]),
            ]
            nn.LeakyReLU()

        # Transformer encoder to merge configs features with graph features
        layer = nn.TransformerEncoderLayer(
            d_model=merged_node_dim,
            dim_feedforward=num_feedforward,
            nhead=nhead,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_encoder)

        self.layernorm = nn.LayerNorm(merged_node_dim)

        self.config_prj = nn.Sequential(
            nn.Linear(18, config_dim),
            nn.LayerNorm(config_dim),
            nn.LeakyReLU(),
        )

        # Define a sequential dense neural network
        self.dense = torch.nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(merged_node_dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
    ) -> torch.Tensor:
        # Get graph features
        c = node_config_feat.size(0)

        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)

        # Get graph features
        for conv in self.convs:
            x = conv(x, edge_index)

        # neighbor_feat = aggregate_neighbors(x, edge_index)

        # (N, graph_out) -> (NC, graph_out)
        x = x[node_config_ids]
        # neighbor_feat = neighbor_feat[node_config_ids]
        # x += neighbor_feat

        # Merge graph features with config features
        # (C, NC, 18) -> (C, NC, config_dim)
        node_config_feat = self.config_prj(node_config_feat)

        # (C, NC, graph_out + config_dim)
        x = torch.cat([x.repeat((c, 1, 1)), node_config_feat], dim=-1)
        x = nn.functional.normalize(x, dim=-1)

        # (C, NC, graph_out + config_dim) -> (C, graph_out + config_dim)
        x = self.encoder(x)[:, -1, :]
        x = self.dense(x).flatten()

        return x

## Training

In [6]:
xla_random_layouts = load_layout(
    LAYOUT_DIR,
    model_type="xla",
    compile_type="random",
)

conv_layer = configs["conv_layer"]
num_epochs = configs["num_epochs"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
warmup_ratio = configs["warmup_ratio"]
max_configs = configs["max_configs"]
graph_hidden = configs["graph_hidden"]
num_encoder = configs["num_encoder"]
num_feedforward = configs["num_feedforward"]
nhead = configs["nhead"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]
_INFERENCE_CONFIGS_BATCH_SIZE = 100

model = LayoutModel(
    conv_layer=conv_layer,
    hidden_channels=graph_hidden,
    num_encoder=num_encoder,
    num_feedforward=num_feedforward,
    nhead=nhead,
).to(device)

criterion = MultiElementRankLoss(
    margin=margin, number_permutations=number_permutations
)
num_steps = len(xla_random_layouts["train"]) * num_epochs
warmup_steps = int(num_steps * warmup_ratio)

optimizer = optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)

scheduler = CosineAnnealingWarmupRestarts(
    optimizer=optimizer,
    first_cycle_steps=num_steps,
    min_lr=min_lr,
    max_lr=learning_rate,
    warmup_steps=warmup_steps,
)

In [7]:
if not DEBUG:
    run = wandb.init(
        project=WANDB_PROJECT,
        dir=WANDB_DIR,
        name=WANDB_RUN_NAME,
        config=configs,
        tags=TAGS,
    )
    run.watch(model, log="all")

    time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"../../logs/{WANDB_RUN_NAME}/{time_str}"
    os.makedirs(log_dir, exist_ok=True)

best_score_mean = -1

train_dataset = LayoutDataset(xla_random_layouts["train"], max_configs=2000, num_configs=100)
val_dataset = LayoutDataset(xla_random_layouts["valid"], permute=False)

for epoch in range(num_epochs):
    # Shuffle the training dataset
    permutation = np.random.permutation(len(train_dataset))
    # train_layouts = [xla_random_layouts["train"][i] for i in permutation]
    # train_dataset = LayoutDataset(train_layouts, max_configs=max_configs)

    model.train()
    pbar = tqdm(permutation, leave=False)

    for i in pbar:
        record = train_dataset[i]
        node_feat = record["node_feat"]
        node_opcode = record["node_opcode"]
        edge_index = record["edge_index"]
        node_config_feat = record["node_config_feat"]
        node_config_ids = record["node_config_ids"]
        config_runtime = record["target"]

        (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            config_runtime,
        ) = (
            node_feat.to(device),
            node_opcode.to(device),
            edge_index.to(device),
            node_config_feat.to(device),
            node_config_ids.to(device),
            config_runtime.to(device),
        )

        out = model(
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
        )

        loss = criterion(out, config_runtime)
        # loss = listMLE(out[None], target[None])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)

        scheduler.step()
        optimizer.step()
        optimizer.zero_grad()

        if not DEBUG:
            wandb.log(
                {
                    "epoch": epoch,
                    "train/lr": scheduler.get_lr()[0],
                    "train/loss": loss.item(),
                }
            )

        pbar.set_description(f"epoch: {epoch} loss: {(loss.item()):.2f}")

    pbar.close()

    model.eval()
    layout_xla_scores = []

    with torch.no_grad():
        for record in tqdm(val_dataset, desc="valid", leave=False):
            node_feat = record["node_feat"]
            node_opcode = record["node_opcode"]
            edge_index = record["edge_index"]
            node_config_feat = record["node_config_feat"]
            node_config_ids = record["node_config_ids"]
            config_runtime = record["config_runtime"]

            (
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat,
                node_config_ids,
            ) = (
                node_feat.to(device),
                node_opcode.to(device),
                edge_index.to(device),
                node_config_feat.to(device),
                node_config_ids.to(device),
            )
            config_runtime = config_runtime.numpy()
            num_configs = config_runtime.shape[-1]
            outs = []

            for i in range(0, num_configs, _INFERENCE_CONFIGS_BATCH_SIZE):
                end_i = min(i + _INFERENCE_CONFIGS_BATCH_SIZE, num_configs)
                out: torch.Tensor = model(
                    node_feat,
                    node_opcode,
                    edge_index,
                    node_config_feat[i:end_i],
                    node_config_ids,
                )
                outs.append(out.detach().cpu())

            pred_idx = np.argsort(torch.concat(outs).numpy())
            true_idx = np.argsort(config_runtime)

            score = metric_for_layout_collections(pred_idx, true_idx)
            layout_xla_scores.append(score)

    score_mean = np.mean(layout_xla_scores)
    score_max = np.max(layout_xla_scores)

    if not DEBUG:
        wandb.log(
            {
                "val/kendalltau": score_mean,
            }
        )

    print(
        f"epoch {epoch}, max_score = {score_max:.4f}, mean_score = {score_mean:.4f},"
    )

    # Update best scores and save the model if the mean score improves
    if score_mean > best_score_mean:
        best_score_mean = score_mean
        best_score_max = score_max
        print(f"Best score updated: {best_score_mean:.4f}")
        if not DEBUG:
            torch.save(
                model.state_dict(), f"{log_dir}/{epoch}_{score_mean:.4f}.pth"
            )

if not DEBUG:
    run.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33medenn0[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loading data: 100%|██████████| 69/69 [00:26<00:00,  2.58it/s]
Loading data: 100%|██████████| 7/7 [00:03<00:00,  1.85it/s]
                                                                    

epoch 0, max_score = 0.0751, mean_score = 0.0139,
Best score updated: 0.0139


                                                                    

epoch 1, max_score = 0.0047, mean_score = -0.0127,


                                                                    

epoch 2, max_score = 0.0306, mean_score = 0.0053,


                                                                    

epoch 3, max_score = 0.0031, mean_score = -0.0051,


                                                                    

epoch 4, max_score = 0.0337, mean_score = 0.0021,


                                                                    

epoch 5, max_score = 0.0215, mean_score = 0.0058,


                                                                    

epoch 6, max_score = 0.0332, mean_score = 0.0022,


                                                                    

epoch 7, max_score = 0.0190, mean_score = 0.0006,


                                                                    

epoch 8, max_score = 0.0100, mean_score = 0.0014,


                                                                    

epoch 9, max_score = 0.0287, mean_score = 0.0080,


                                                                     

epoch 10, max_score = 0.0201, mean_score = -0.0043,


                                                                     

epoch 11, max_score = 0.0425, mean_score = 0.0060,


                                                                     

epoch 12, max_score = 0.0179, mean_score = -0.0030,


                                                                     

epoch 13, max_score = 0.0118, mean_score = 0.0001,


                                                                     

epoch 14, max_score = 0.0111, mean_score = -0.0031,


                                                                     

epoch 15, max_score = 0.0450, mean_score = 0.0078,


                                                                     

epoch 16, max_score = 0.0165, mean_score = 0.0029,


                                                                     

epoch 17, max_score = 0.0143, mean_score = -0.0049,


                                                                     

epoch 18, max_score = 0.0068, mean_score = -0.0039,


                                                                     

epoch 19, max_score = -0.0003, mean_score = -0.0111,


                                                                     

epoch 20, max_score = 0.0049, mean_score = -0.0020,


                                                                     

epoch 21, max_score = 0.0046, mean_score = -0.0031,


                                                                     

epoch 22, max_score = 0.0274, mean_score = 0.0099,


                                                                     

epoch 23, max_score = 0.0078, mean_score = -0.0041,


                                                                     

epoch 24, max_score = 0.0131, mean_score = 0.0010,


                                                                     

epoch 25, max_score = 0.0553, mean_score = 0.0105,


                                                                     

epoch 26, max_score = 0.0161, mean_score = -0.0051,


                                                                     

epoch 27, max_score = 0.0240, mean_score = -0.0008,


                                                                     

epoch 28, max_score = 0.0266, mean_score = -0.0000,


                                                                     

epoch 29, max_score = 0.0057, mean_score = -0.0028,


                                                                     

epoch 30, max_score = 0.0067, mean_score = -0.0019,


                                                                     

epoch 31, max_score = 0.0191, mean_score = 0.0077,


                                                                     

epoch 32, max_score = 0.0073, mean_score = -0.0114,


                                                                     

epoch 33, max_score = 0.0151, mean_score = 0.0039,


                                                                     

epoch 34, max_score = 0.0045, mean_score = -0.0051,


                                                                     

epoch 35, max_score = 0.0310, mean_score = 0.0019,


                                                                     

epoch 36, max_score = 0.0200, mean_score = -0.0029,


                                                                     

epoch 37, max_score = -0.0012, mean_score = -0.0039,


                                                                     

epoch 38, max_score = 0.0102, mean_score = 0.0006,


                                                                     

epoch 39, max_score = 0.0117, mean_score = 0.0041,


                                                                     

epoch 40, max_score = 0.0169, mean_score = 0.0014,


                                                                     

epoch 41, max_score = 0.0137, mean_score = 0.0034,


                                                                     

epoch 42, max_score = 0.0233, mean_score = 0.0043,


                                                                     

epoch 43, max_score = 0.0120, mean_score = -0.0042,


                                                                     

epoch 44, max_score = 0.0056, mean_score = -0.0051,


                                                                     

epoch 45, max_score = 0.0151, mean_score = -0.0025,


                                                                     

epoch 46, max_score = 0.0082, mean_score = -0.0123,


                                                                     

epoch 47, max_score = 0.0094, mean_score = 0.0007,


                                                                     

epoch 48, max_score = 0.0373, mean_score = 0.0118,


                                                                     

epoch 49, max_score = 0.0077, mean_score = -0.0081,


                                                                     

epoch 50, max_score = 0.0330, mean_score = 0.0101,


                                                                     

epoch 51, max_score = 0.0312, mean_score = 0.0085,


                                                                     

epoch 52, max_score = 0.0475, mean_score = 0.0038,


                                                                     

epoch 53, max_score = 0.0165, mean_score = -0.0077,


                                                                     

epoch 54, max_score = 0.0216, mean_score = 0.0040,


                                                                     

epoch 55, max_score = 0.0310, mean_score = 0.0034,


                                                                     

epoch 56, max_score = 0.0083, mean_score = -0.0016,


                                                                     

epoch 57, max_score = 0.0028, mean_score = -0.0064,


                                                                     

epoch 58, max_score = 0.0139, mean_score = -0.0007,


                                                                     

epoch 59, max_score = 0.0163, mean_score = -0.0062,


                                                                     

epoch 60, max_score = 0.0055, mean_score = -0.0041,


                                                                     

epoch 61, max_score = 0.0145, mean_score = -0.0057,


                                                                     

epoch 62, max_score = 0.0207, mean_score = 0.0004,


                                                                     

epoch 63, max_score = 0.0082, mean_score = -0.0057,


                                                                     

epoch 64, max_score = 0.0176, mean_score = -0.0011,


                                                                     

epoch 65, max_score = 0.0118, mean_score = -0.0088,


                                                                     

epoch 66, max_score = 0.0157, mean_score = 0.0032,


                                                                     

epoch 67, max_score = 0.0134, mean_score = 0.0020,


                                                                     

epoch 68, max_score = 0.0105, mean_score = -0.0009,


                                                                     

epoch 69, max_score = 0.0126, mean_score = -0.0031,


                                                                     

epoch 70, max_score = 0.0267, mean_score = 0.0024,


                                                                     

epoch 71, max_score = 0.0176, mean_score = 0.0005,


                                                                     

epoch 72, max_score = 0.0124, mean_score = -0.0033,


                                                                     

epoch 73, max_score = 0.0052, mean_score = -0.0061,


                                                                     

epoch 74, max_score = 0.0154, mean_score = -0.0004,


                                                                     

epoch 75, max_score = 0.0074, mean_score = -0.0019,


                                                                     

epoch 76, max_score = 0.0189, mean_score = -0.0029,


                                                                     

epoch 77, max_score = 0.0185, mean_score = 0.0051,


                                                                     

epoch 78, max_score = 0.0068, mean_score = -0.0032,


                                                                     

epoch 79, max_score = 0.0137, mean_score = 0.0006,


                                                                     

KeyboardInterrupt: 