In [1]:
import os
from datetime import datetime
from typing import Literal

import torch
import wandb
import numpy as np
from tqdm import tqdm
from torch import nn
from torch import optim
from torch.nn import Linear, ReLU, Dropout
import torch_geometric.nn as geonn
from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
from torch_geometric.nn import global_mean_pool

from gfos.data.utils import load_layout
from gfos.data.dataset import LayoutDataset
from gfos.utils.scheduler import CosineAnnealingWarmupRestarts
from gfos.metrics import metric_for_layout_collections
from gfos.loss import MultiElementRankLoss, listMLE
from gfos.utils.misc import seed_everything


SEED = 42
DEBUG = True

seed_everything(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

## Configs

In [2]:
LAYOUT_DIR = r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout"

configs = dict(
    conv_layer="GCNConv",
    num_epochs=100,
    learning_rate=1e-3,
    weight_decay=1e-6,
    min_lr=1e-7,
    warmup_ratio=0.0,
    max_configs=5000,
    num_configs=16,
    num_encoder=1,
    num_feedforward=512,
    nhead=1,
    loss_margin=0.1,
    loss_num_permutations=50,
)

WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = "gcn_layout_xla_default"
TAGS = ["train", "layout", "xla", "default"]

In [3]:
def min_max_normalize(feature_matrix: torch.Tensor) -> torch.Tensor:
    max_feat, _ = torch.max(feature_matrix, dim=0, keepdim=True)
    min_feat, _ = torch.min(feature_matrix, dim=0, keepdim=True)
    used_columns = min_feat[0] != max_feat[0]

    feature_matrix = feature_matrix[:, used_columns]
    min_feat = min_feat[:, used_columns]
    max_feat = max_feat[:, used_columns]
    return (feature_matrix - min_feat) / (max_feat - min_feat)

## Model

In [4]:
def edges_adjacency(edges: torch.Tensor, add_diagonal=True) -> torch.Tensor:
    """
    Generate an adjacency matrix from the edges
    Args:
        edges: Tensor of shape (num_edges, 2) with the edges
        add_diagonal: Boolean indicating if the diagonal should be added to the adjacency matrix
    Returns:
        adjacency_matrix: Tensor of shape (num_nodes, num_nodes) with the adjacency matrix
    """
    adjacency_matrix = torch.zeros(
        (edges.max() + 1, edges.max() + 1), device=edges.device
    )
    adjacency_matrix[edges[:, 0], edges[:, 1]] = 1
    if add_diagonal:
        diag_idx = torch.arange(adjacency_matrix.shape[0])
        adjacency_matrix[diag_idx, diag_idx] = 1
    return adjacency_matrix


def aggregate_neighbors(node_feat: torch.Tensor, edge_index: torch.Tensor):
    source_nodes = edge_index[0]
    target_nodes = edge_index[1]

    in_degree_features = torch.zeros_like(node_feat, device=node_feat.device)
    out_degree_features = torch.zeros_like(node_feat, device=node_feat.device)

    source_node_features = node_feat[source_nodes]
    target_node_features = node_feat[target_nodes]

    in_degree_features.scatter_reduce_(
        0,
        target_nodes.unsqueeze(-1).expand_as(source_node_features),
        source_node_features,
        reduce="sum",
    )
    
    out_degree_features.scatter_reduce_(
        0,
        source_nodes.unsqueeze(-1).expand_as(target_node_features),
        target_node_features,
        reduce="mean",
    )
    
    return out_degree_features - in_degree_features

In [5]:
class LayoutModel(torch.nn.Module):
    def __init__(
        self,
        conv_layer: Literal["GATConv", "GCNConv", "SAGEConv"],
    ):
        super().__init__()

        conv_layer = getattr(geonn, conv_layer)

        op_embedding_dim = 32
        config_dim = 64

        self.embedding = torch.nn.Embedding(
            120,
            op_embedding_dim,
        )
        in_channels = config_dim + op_embedding_dim + 140

        self.convs = Sequential(
            "x, edge_index, batch",
            [
                (Dropout(p=0.2), "x -> x"),
                (
                    conv_layer(in_channels, 64, node_dim=1),
                    "x, edge_index -> x1",
                ),
                nn.LeakyReLU(inplace=True),
                (conv_layer(64, 64, node_dim=1), "x1, edge_index -> x2"),
                nn.LeakyReLU(inplace=True),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),

                (nn.Dropout(p=0.2), "x3 -> x3"),
                (
                    conv_layer(64, 128, node_dim=1),
                    "x3, edge_index -> x4",
                ),
                nn.LeakyReLU(inplace=True),
                (conv_layer(128, 128, node_dim=1), "x4, edge_index -> x5"),
                nn.LeakyReLU(inplace=True),
                (lambda x4, x5: x4 + x5, "x4, x5 -> x6"),

                (nn.Dropout(p=0.2), "x6 -> x6"),
                (
                    conv_layer(128, 64, node_dim=1),
                    "x6, edge_index -> x7",
                ),
                nn.LeakyReLU(inplace=True),
                (conv_layer(64, 64, node_dim=1), "x7, edge_index -> x8"),
                nn.LeakyReLU(inplace=True),
                (lambda x7, x8: [x7, x8], "x7, x8 -> xs"),
                (JumpingKnowledge("cat", 64, num_layers=2), "xs -> x"),
                (global_mean_pool, "x, batch -> x"),
                Linear(2 * 64, 1),
            ],
        )

        self.config_prj = nn.Sequential(
            nn.Linear(18, config_dim),
            nn.LayerNorm(config_dim),
            nn.LeakyReLU(),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
    ) -> torch.Tensor:
        # Get graph features
        c = node_config_feat.shape[0]
        n = node_feat.shape[0]

        # (C, NC, 18) -> (C, NC, config_dim)
        node_config_feat = self.config_prj(node_config_feat)

        configs = torch.zeros(
            (c, n, node_config_feat.shape[-1]), device=node_feat.device
        )
        configs[:, node_config_ids] = node_config_feat

        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)
        x = torch.cat([configs, x.repeat(c, 1, 1)], dim=2)

        # Get graph features
        x = self.convs(
            x,
            edge_index,
            torch.zeros(n, dtype=torch.long, device=node_feat.device),
        )

        return x.squeeze([1, 2])

## Training

In [6]:
xla_random_layouts = load_layout(
    LAYOUT_DIR,
    model_type="xla",
    compile_type="default",
)

conv_layer = configs["conv_layer"]
num_epochs = configs["num_epochs"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
warmup_ratio = configs["warmup_ratio"]
max_configs = configs["max_configs"]
num_configs = configs["num_configs"]
num_encoder = configs["num_encoder"]
num_feedforward = configs["num_feedforward"]
nhead = configs["nhead"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]
_INFERENCE_CONFIGS_BATCH_SIZE = 100

model = LayoutModel(conv_layer=conv_layer).to(device)

criterion = MultiElementRankLoss(
    margin=margin, number_permutations=number_permutations
)
num_steps = len(xla_random_layouts["train"]) * num_epochs
warmup_steps = int(num_steps * warmup_ratio)

optimizer = optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)

scheduler = CosineAnnealingWarmupRestarts(
    optimizer=optimizer,
    first_cycle_steps=num_steps,
    min_lr=min_lr,
    max_lr=learning_rate,
    warmup_steps=warmup_steps,
)

In [7]:
# train_dataset = LayoutDataset(xla_random_layouts["train"], max_configs=1000, num_configs=100)

# record = train_dataset[0]
# node_feat = record["node_feat"]
# node_opcode = record["node_opcode"]
# edge_index = record["edge_index"]
# node_config_feat = record["node_config_feat"]
# node_config_ids = record["node_config_ids"]
# config_runtime = record["config_runtime"]

# (
#     node_feat,
#     node_opcode,
#     edge_index,
#     node_config_feat,
#     node_config_ids,
#     config_runtime,
# ) = (
#     node_feat.to(device),
#     node_opcode.to(device),
#     edge_index.to(device),
#     node_config_feat.to(device),
#     node_config_ids.to(device),
#     config_runtime.to(device),
# )

# out = model(
#     node_feat,
#     node_opcode,
#     edge_index,
#     node_config_feat,
#     node_config_ids,
# )
# out.shape


In [8]:
NUM_VAL_EPOCHS = 10

if not DEBUG:
    run = wandb.init(
        project=WANDB_PROJECT,
        dir=WANDB_DIR,
        name=WANDB_RUN_NAME,
        config=configs,
        tags=TAGS,
    )
    run.watch(model, log="all")

    time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"../../logs/{WANDB_RUN_NAME}/{time_str}"
    os.makedirs(log_dir, exist_ok=True)

best_score_mean = -1

train_dataset = LayoutDataset(
    xla_random_layouts["train"],
    max_configs=max_configs,
    num_configs=num_configs,
)
val_dataset = LayoutDataset(xla_random_layouts["valid"], permute=False)

for epoch in range(num_epochs):
    # Shuffle the training dataset
    permutation = np.random.permutation(len(train_dataset))
    # train_layouts = [xla_random_layouts["train"][i] for i in permutation]
    # train_dataset = LayoutDataset(train_layouts, max_configs=max_configs)

    model.train()
    pbar = tqdm(permutation, leave=False)

    for i in pbar:
        record = train_dataset[i]
        node_feat = record["node_feat"]
        node_opcode = record["node_opcode"]
        edge_index = record["edge_index"]
        node_config_feat = record["node_config_feat"]
        node_config_ids = record["node_config_ids"]
        config_runtime = record["config_runtime"]

        (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            config_runtime,
        ) = (
            node_feat.to(device),
            node_opcode.to(device),
            edge_index.to(device),
            node_config_feat.to(device),
            node_config_ids.to(device),
            config_runtime.to(device),
        )

        out = model(
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
        )

        # loss = criterion(out, config_runtime)
        loss = listMLE(out.view(1, -1), config_runtime[None])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)

        scheduler.step()
        optimizer.step()
        optimizer.zero_grad()

        if not DEBUG:
            wandb.log(
                {
                    "epoch": epoch,
                    "train/lr": scheduler.get_lr()[0],
                    "train/loss": loss.item(),
                }
            )

        pbar.set_description(f"epoch: {epoch} loss: {(loss.item()):.2f}")

    pbar.close()
    if (epoch + 1) % NUM_VAL_EPOCHS != 0:
        continue

    model.eval()
    layout_xla_scores = []

    with torch.no_grad():
        for record in tqdm(val_dataset, desc="valid", leave=False):
            node_feat = record["node_feat"]
            node_opcode = record["node_opcode"]
            edge_index = record["edge_index"]
            node_config_feat = record["node_config_feat"]
            node_config_ids = record["node_config_ids"]
            config_runtime = record["config_runtime"]

            (
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat,
                node_config_ids,
            ) = (
                node_feat.to(device),
                node_opcode.to(device),
                edge_index.to(device),
                node_config_feat.to(device),
                node_config_ids.to(device),
            )
            config_runtime = config_runtime.numpy()
            num_configs = config_runtime.shape[-1]
            outs = []

            for i in range(0, num_configs, _INFERENCE_CONFIGS_BATCH_SIZE):
                end_i = min(i + _INFERENCE_CONFIGS_BATCH_SIZE, num_configs)
                out: torch.Tensor = model(
                    node_feat,
                    node_opcode,
                    edge_index,
                    node_config_feat[i:end_i],
                    node_config_ids,
                )
                outs.append(out.detach().cpu())

            pred_idx = np.argsort(torch.concat(outs).numpy())
            true_idx = np.argsort(config_runtime)

            score = metric_for_layout_collections(pred_idx, true_idx)
            layout_xla_scores.append(score)

    score_mean = np.mean(layout_xla_scores)
    score_max = np.max(layout_xla_scores)

    if not DEBUG:
        wandb.log(
            {
                "val/kendalltau": score_mean,
            }
        )

    print(
        f"epoch {epoch}, max_score = {score_max:.4f}, mean_score = {score_mean:.4f},"
    )

    # Update best scores and save the model if the mean score improves
    if score_mean > best_score_mean:
        best_score_mean = score_mean
        best_score_max = score_max
        print(f"Best score updated: {best_score_mean:.4f}")
        if not DEBUG:
            torch.save(
                model.state_dict(), f"{log_dir}/{epoch}_{score_mean:.4f}.pth"
            )

if not DEBUG:
    run.finish()

Loading data:   0%|          | 0/61 [00:00<?, ?it/s]

Loading data: 100%|██████████| 61/61 [00:31<00:00,  1.92it/s]
Loading data: 100%|██████████| 7/7 [00:04<00:00,  1.46it/s]
                                                                         

epoch 9, max_score = 0.0079, mean_score = -0.0030,
Best score updated: -0.0030


                                                                        

epoch 19, max_score = 0.0114, mean_score = -0.0031,


                                                                       

epoch 29, max_score = 0.0180, mean_score = 0.0056,
Best score updated: 0.0056


                                                                       

epoch 39, max_score = 0.0189, mean_score = -0.0002,


                                                                      

epoch 49, max_score = 0.0107, mean_score = 0.0024,


                                                                      

epoch 59, max_score = 0.0188, mean_score = 0.0048,


                                                                      

epoch 69, max_score = -0.0002, mean_score = -0.0059,


                                                                      

epoch 79, max_score = 0.0025, mean_score = -0.0043,


                                                                      

epoch 89, max_score = 0.0044, mean_score = -0.0030,


                                                                      

epoch 99, max_score = 0.0090, mean_score = -0.0009,


