In [1]:
import os
from functools import partial
from datetime import datetime
from typing import Literal

import torch
import wandb
import numpy as np
import tensorflow as tf
import tensorflow_ranking as tfr
from torch import nn
from torch import optim
import torch_geometric.nn as geonn
from torch_geometric.data import Data, Batch
from tqdm import tqdm

from gfos.data.utils import load_layout
from gfos.data.dataset import LayoutDataset, Normalizer
from gfos.metrics import kendall, topk_error
from gfos.loss import MultiElementRankLoss, listMLE, lambdaLoss
from gfos.utils.misc import seed_everything


SEED = 42
DEBUG = False

seed_everything(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

os.environ["WANDB_NOTEBOOK_NAME"] = "layout_gnn_gnn_2_reranking.ipynb"


## Configs

In [2]:
LAYOUT_DIR = r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout"
NORMALIZER_PATH = "../../data/normalizer.json"

# configs = dict(
#     conv_layer="SAGEConv",
#     op_embedding_dim=32,
#     config_dim=64,
#     graph_dim=64,
#     num_epochs=4000,
#     learning_rate=5e-3,
#     weight_decay=1e-5,
#     min_lr=1e-7,
#     max_configs=2000,
#     num_configs=32,
#     loss_margin=0.2,
#     loss_num_permutations=10,
#     accum_iter=8,
#     grad_clip=1e-1,
# )

configs = dict(
    conv_layer="SAGEConv",
    op_embedding_dim=32,
    config_dim=64,
    graph_dim=64,
    num_epochs=4000,
    learning_rate=5e-3,
    weight_decay=1e-7,
    min_lr=1e-7,
    max_configs=10000,
    num_configs=128,
    loss_margin=0.5,
    loss_num_permutations=100,
    accum_iter=4,
    grad_clip=1e-1,
)

SOURCE = "xla"
SEARCH = "random"
WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = f"layout_{SOURCE}_{SEARCH}"
TAGS = ["train", "layout", SOURCE, SEARCH]

NUM_VAL_EPOCHS = 25
INFERENCE_CONFIGS_BATCH_SIZE = 50


## Model

In [3]:
def aggregate_neighbors(node_feat: torch.Tensor, edge_index: torch.Tensor):
    source_nodes = edge_index[0]
    target_nodes = edge_index[1]

    in_degree_features = torch.zeros_like(node_feat, device=node_feat.device)
    out_degree_features = torch.zeros_like(node_feat, device=node_feat.device)

    source_node_features = node_feat[source_nodes]
    target_node_features = node_feat[target_nodes]

    in_degree_features.scatter_reduce_(
        0,
        target_nodes.unsqueeze(-1).expand_as(source_node_features),
        source_node_features,
        reduce="mean",
        include_self=False,
    )

    out_degree_features.scatter_reduce_(
        0,
        source_nodes.unsqueeze(-1).expand_as(target_node_features),
        target_node_features,
        reduce="mean",
        include_self=False,
    )

    return out_degree_features - in_degree_features

In [4]:
class LayoutModel(torch.nn.Module):
    def __init__(
        self,
        conv_layer: Literal["GATConv", "GCNConv", "SAGEConv"],
        op_embedding_dim: int = 32,
        config_dim: int = 64,
        graph_dim: int = 64,
        node_feat_dim: int = 140,
        node_config_dim: int = 18,
        num_bins: int = 0,
    ):
        super().__init__()

        NUM_OPCODE = 120

        conv_layer = getattr(geonn, conv_layer)

        merged_node_dim = 2 * graph_dim + config_dim

        self.embedding = torch.nn.Embedding(
            NUM_OPCODE,
            op_embedding_dim,
        )
        in_channels = op_embedding_dim + node_feat_dim

        self.model_gnn = geonn.Sequential(
            "x, edge_index",
            [
                (conv_layer(in_channels, graph_dim), "x, edge_index -> x1"),
                nn.LeakyReLU(inplace=True),
                (conv_layer(graph_dim, graph_dim), "x1, edge_index -> x2"),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),
                nn.LeakyReLU(inplace=True),
                (conv_layer(graph_dim, graph_dim), "x3, edge_index -> x4"),
                nn.LeakyReLU(inplace=True),
                (conv_layer(graph_dim, graph_dim), "x4, edge_index -> x5"),
                (lambda x4, x5: x4 + x5, "x4, x5 -> x6"),
                nn.LeakyReLU(inplace=True),
            ],
        )

        self.config_mp = geonn.Sequential(
            "x, edge_index",
            [
                (geonn.GATConv(graph_dim, graph_dim), "x, edge_index -> x1"),
                nn.LeakyReLU(inplace=True),
                (geonn.GATConv(graph_dim, graph_dim), "x1, edge_index -> x2"),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),
                nn.LeakyReLU(inplace=True),
            ],
        )

        self.config_gnn = geonn.Sequential(
            "x, edge_index",
            [
                (nn.Dropout(p=0.2), "x -> x"),
                (
                    conv_layer(merged_node_dim, config_dim),
                    "x, edge_index -> x1",
                ),
                nn.LeakyReLU(inplace=True),
                (conv_layer(config_dim, config_dim), "x1, edge_index -> x2"),
                (lambda x1, x2: x1 + x2, "x1, x2 -> x3"),
                nn.LeakyReLU(inplace=True),
                (conv_layer(config_dim, config_dim), "x3, edge_index -> x4"),
                nn.LeakyReLU(inplace=True),
                (
                    conv_layer(config_dim, config_dim),
                    "x4, edge_index -> x5",
                ),
                (lambda x4, x5: x4 + x5, "x4, x5 -> x6"),
                nn.LeakyReLU(inplace=True),
            ],
        )

        self.config_prj = nn.Sequential(
            nn.Linear(node_config_dim, config_dim),
            nn.LeakyReLU(),
        )
        
        self.clf = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(config_dim, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, num_bins, bias=False),
        )

        self.reranking = geonn.Sequential(
            "x, xc", [
                (nn.Dropout(0.2), "x -> x"),
                (nn.Linear(config_dim, num_bins, bias=False), "x -> x"),
                nn.LeakyReLU(inplace=True),
                (lambda x, xc: x * xc, "x, xc -> x"), 
                (nn.Linear(num_bins, 64, bias=False), "x -> x"),
                nn.LeakyReLU(inplace=True),
                (nn.Linear(64, 1, bias=False), "x -> x"),
            ]
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
        config_edge_index: torch.Tensor,
    ) -> torch.Tensor:
        # Get graph features
        c = node_config_feat.size(0)

        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)

        # Get graph features
        x = self.model_gnn(x, edge_index)

        config_neighbors = aggregate_neighbors(x, edge_index)[node_config_ids]
        config_neighbors = nn.functional.normalize(config_neighbors, dim=-1)
        config_neighbors = self.config_mp(config_neighbors, config_edge_index)

        # (N, graph_out) -> (NC, graph_out)
        x = x[node_config_ids]
        # x += config_neighbors

        # Merge graph features with config features
        # (C, NC, 18) -> (C, NC, config_dim)
        node_config_feat = self.config_prj(node_config_feat)
        # pos_embedding = self.deg_prj(neighbor_feat)

        # (C, NC, 2*graph_out + config_dim)
        x = torch.cat(
            [
                config_neighbors.repeat((c, 1, 1)),
                x.repeat((c, 1, 1)),
                node_config_feat,
            ],
            dim=-1,
        )
        # x += pos_embedding
        x = nn.functional.normalize(x, dim=-1)

        datas = [
            Data(x=x[i], edge_index=config_edge_index)
            for i in range(x.shape[0])
        ]
        batch = Batch.from_data_list(datas)

        x = self.config_gnn(batch.x, batch.edge_index)
        xf = geonn.pool.global_mean_pool(x, batch.batch) # (C, config_dim)
        
        xc = self.clf(xf) # (C, num_bins)

        # TODO: split by cls
        x = self.reranking(xf, xc).flatten()

        return x, xc, xf  # predicted ranking, predicted bins, config features

## Prepare for training

### Expand configs

In [5]:
conv_layer = configs["conv_layer"]
op_embedding_dim = configs["op_embedding_dim"]
config_dim = configs["config_dim"]
graph_dim = configs["graph_dim"]
num_epochs = configs["num_epochs"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
max_configs = configs["max_configs"]
num_configs = configs["num_configs"]
accum_iter = configs["accum_iter"]
grad_clip = configs["grad_clip"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]
bins = np.concatenate(
    [
        np.arange(1e7, 1e8, 1e6),
        np.arange(1e8, 1e9, 2e7),
        np.arange(1e9, 1e10, 4e8),
    ]
)

### Read data

In [6]:
normalizer = Normalizer.from_json(NORMALIZER_PATH, SOURCE, SEARCH)

layout_data = load_layout(
    LAYOUT_DIR,
    model_type=SOURCE,
    compile_type=SEARCH,
)

train_dataset = LayoutDataset(
    layout_data["train"],
    max_configs=max_configs,
    num_configs=num_configs,
    config_edges="simple",
    normalizer=normalizer,
    bins=bins,
)
val_dataset = LayoutDataset(
    layout_data["valid"],
    config_edges="simple",
    normalizer=normalizer,
)


# pickle.dump(train_dataset, open("../../data/train_dataset.pkl", "wb"))
# pickle.dump(val_dataset, open("../../data/val_dataset.pkl", "wb"))

# train_dataset = pickle.load(open("../../data/train_dataset.pkl", "rb"))
# val_dataset = pickle.load(open("../../data/val_dataset.pkl", "rb"))

Loading data: 100%|██████████| 69/69 [00:53<00:00,  1.30it/s]
Loading data: 100%|██████████| 7/7 [00:07<00:00,  1.11s/it]


In [7]:
num_data_cls = torch.concat([train_dataset.data[i]["cls_label"] for i in range(len(train_dataset.data))]).bincount()
num_cls = len(num_data_cls)
weights = num_data_cls.sum() / (num_cls * num_data_cls)

In [8]:
node_feat_dim = train_dataset[0]["node_feat"].shape[-1]
node_config_dim = train_dataset[0]["node_config_feat"].shape[-1]


### Model, loss, optimizer, scheduler

In [9]:
model = LayoutModel(
    conv_layer=conv_layer,
    op_embedding_dim=op_embedding_dim,
    config_dim=config_dim,
    graph_dim=graph_dim,
    node_feat_dim=node_feat_dim,
    node_config_dim=node_config_dim,
    num_bins=num_cls,
).to(device)

local_rank_criterion = MultiElementRankLoss(margin, 10)
rank_criterion = MultiElementRankLoss(
    margin=margin, number_permutations=number_permutations
)
cls_critertion = nn.CrossEntropyLoss(weight=weights.to(device))


optimizer = optim.AdamW([
    {"name": "lr_embed", 'params': model.embedding.parameters(), 'lr': learning_rate / 10},
    {"name": "lr_model_gnn", 'params': model.model_gnn.parameters(), 'lr': learning_rate / 10},
    {"name": "lr_config_prj", 'params': model.config_prj.parameters(), 'lr': learning_rate / 10},
    {"name": "lr_config_mp", 'params': model.config_mp.parameters(), 'lr': learning_rate / 10},
    {"name": "lr_config_gnn", 'params': model.config_gnn.parameters(), 'lr': learning_rate},
    {"name": "lr_cls", 'params': model.clf.parameters(), 'lr': learning_rate},
    {"name": "lr_rank", 'params': model.reranking.parameters(), 'lr': learning_rate},
], betas=[0.85, 0.9], weight_decay=weight_decay)

# optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# num_steps = len(layout_data["train"]) * num_epochs
# warmup_steps = int(num_steps * warmup_ratio)

# scheduler = CosineAnnealingWarmupRestarts(
#     optimizer=optimizer,
#     first_cycle_steps=num_steps,
#     min_lr=min_lr,
#     max_lr=learning_rate,
#     warmup_steps=warmup_steps,
# )
# scheduler = optim.lr_scheduler.CosineAnnealingLR(
#     optimizer,
#     T_max=num_steps,
#     eta_min=min_lr,
# )
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="max",
    factor=0.1,
    patience=2, # 5 times evaluation = 5 * NUM_VAL_EPOCHS epochs
    threshold=0.01,
    min_lr=min_lr,
)

In [10]:
# tf.config.set_visible_devices([], 'GPU')

# # Don't know why but this cost huge amount of VRAM
# opa = tfr.keras.metrics.OPAMetric()


### Init wandb and train&valid dataset

In [11]:
if not DEBUG:
    run = wandb.init(
        project=WANDB_PROJECT,
        dir=WANDB_DIR,
        name=WANDB_RUN_NAME,
        config=configs,
        tags=TAGS,
    )
    run.watch(model, log="all")
    run.log_code("../")

    time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"../../logs/{WANDB_RUN_NAME}/{time_str}"
    os.makedirs(log_dir, exist_ok=True)




[34m[1mwandb[0m: Currently logged in as: [33medenn0[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Training

In [12]:
best_score = -1

epoch_steps = [100, 500]  # train cls, train local ranking

# scaler = GradScaler()
loss_mean = 0
for epoch in range(num_epochs):
    # Shuffle the training dataset
    permutation = np.random.permutation(len(train_dataset))

    # Training phase
    model.train()
    pbar = tqdm(permutation, leave=False)

    for i in pbar:
        record = train_dataset[i]
        node_feat = record["node_feat"]
        node_opcode = record["node_opcode"]
        edge_index = record["edge_index"]
        node_config_feat = record["node_config_feat"]
        node_config_ids = record["node_config_ids"]
        config_runtime = record["config_runtime"]
        config_edge_index = record["config_edge_index"]
        cls_label = record["cls_label"]

        (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            config_runtime,
            config_edge_index,
            cls_label,
        ) = (
            node_feat.to(device),
            node_opcode.to(device),
            edge_index.to(device),
            node_config_feat.to(device),
            node_config_ids.to(device),
            config_runtime.to(device),
            config_edge_index.to(device),
            cls_label.to(device),
        )

        # with autocast():
        pred_runtime, pred_logits, _ = model(
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            config_edge_index,
        )
        
        if epoch < epoch_steps[0]:
            loss = cls_critertion(pred_logits, cls_label)
        elif epoch < epoch_steps[1]:
            cls_rank_loss = []
            for c in torch.unique(cls_label):
                mask = cls_label == c
                cls_rank_loss.append(local_rank_criterion(pred_runtime[mask], config_runtime[mask]))
            local_rank_loss = sum(cls_rank_loss) / len(cls_rank_loss)
            loss = local_rank_loss + cls_critertion(pred_logits, cls_label)
        else:
            loss = rank_criterion(pred_runtime, config_runtime)

        loss = loss / accum_iter
        loss_mean += loss.item()
        loss.backward()
        # scaler.scale(loss).backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        
        pbar.set_description(f"epoch: {epoch} loss: {loss_mean:.4f}")

        if ((i + 1) % accum_iter == 0) or (i + 1 == len(train_dataset)):
            # scaler.step(optimizer)
            # scaler.update()
            optimizer.step()
            optimizer.zero_grad()

            if not DEBUG:
                if epoch < epoch_steps[0]:
                    suffix = "_cls"
                elif epoch < epoch_steps[1]:
                    suffix = "_local_rank"
                else:
                    suffix = ""
                wandb.log(
                    {
                        "epoch": epoch,
                        "train/lr_slow": optimizer.param_groups[0]["lr"],
                        "train/lr_fast": optimizer.param_groups[-1]["lr"],
                        f"train/loss{suffix}": loss_mean,
                        # "train/cls_loss": cls_loss.item(),
                        # "train/rank_loss": all_rank_loss.item(),
                        # "train/cls_rank_loss": sum(cls_rank_loss) / len(cls_rank_loss),
                    }
                )
            loss_mean = 0

    pbar.close()

    if (epoch + 1) % NUM_VAL_EPOCHS != 0 and epoch != num_epochs - 1:
        continue

    model.eval()

    # Validation phase
    # Scores placeholder
    val_loss = []
    kendalltau_scores = []
    raw_kendalltau_scores = []
    idx_kendalltau_scores = []
    opa_scores = []
    top500_scores = []
    top100_scores = []
    score_per_model = {}

    with torch.no_grad():
        for record in tqdm(val_dataset, desc="valid", leave=False):
            model_id = record["model_id"]
            node_feat = record["node_feat"]
            node_opcode = record["node_opcode"]
            edge_index = record["edge_index"]
            node_config_feat = record["node_config_feat"]
            node_config_ids = record["node_config_ids"]
            config_runtime = record["config_runtime"]
            config_edge_index = record["config_edge_index"]

            (
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat,
                node_config_ids,
                config_edge_index,
            ) = (
                node_feat.to(device),
                node_opcode.to(device),
                edge_index.to(device),
                node_config_feat.to(device),
                node_config_ids.to(device),
                config_edge_index.to(device),
            )
            # config_runtime = config_runtime.numpy()
            c = config_runtime.shape[-1]
            outs = []

            for i in range(0, c, INFERENCE_CONFIGS_BATCH_SIZE):
                end_i = min(i + INFERENCE_CONFIGS_BATCH_SIZE, c)
                # with autocast():
                out, _, _ = model(
                    node_feat,
                    node_opcode,
                    edge_index,
                    node_config_feat[i:end_i],
                    node_config_ids,
                    config_edge_index,
                )
                outs.append(out.detach().cpu())

            outs = torch.concat(outs).numpy()

            kendalltau_scores.append(kendall(np.argsort(outs), np.argsort(config_runtime)))
            raw_kendalltau_scores.append(kendall(outs, config_runtime))
            
            sorted_runtime, sorted_idx = torch.sort(config_runtime)
            sorted_outs = outs[sorted_idx]
            idx_kendalltau_score = kendall(np.argsort(sorted_outs), np.argsort(sorted_runtime))
            idx_kendalltau_scores.append(idx_kendalltau_score)
            
            # opa_scores.append(opa(config_runtime[None], outs[None]))
            top100_scores.append(topk_error(outs, config_runtime, top_k=100))
            top500_scores.append(topk_error(outs, config_runtime, top_k=500))
            
            score_per_model[model_id] = kendall(outs, config_runtime)

    kendalltau_mean = np.mean(kendalltau_scores)
    raw_kendalltau_mean = np.mean(raw_kendalltau_scores)
    idx_kendalltau_mean = np.mean(idx_kendalltau_scores)
    top100_mean = np.mean(top100_scores)
    top500_mean = np.mean(top500_scores)
    
    if epoch >= epoch_steps[1]:
        scheduler.step(idx_kendalltau_mean)

    if not DEBUG:
        wandb.log(
            {
                "val/kendalltau": kendalltau_mean,
                "val/raw_kendalltau": raw_kendalltau_mean,
                "val/idx_kendalltau": idx_kendalltau_mean,
                "val/top100_error": top100_mean,
                "val/top500_error": top500_mean,
            }
        )
        
        wandb.log(
            {f"val/kendall_{model}": score for model, score in score_per_model.items()}
        )

    print(
        f"epoch {epoch}, kendall = {idx_kendalltau_mean:.4f}"
    )

    # Update best scores and save the model if the mean score improves
    if idx_kendalltau_mean > best_score:
        best_score = idx_kendalltau_mean
        print(f"Best score updated: {best_score:.4f}")
        if not DEBUG:
            filename = f"{epoch}_{best_score:.4f}.pth"
            path = os.path.join(wandb.run.dir, filename)
            torch.save(
                model.state_dict(),
                path,
            )

if not DEBUG:
    run.finish()

                                                                        

epoch 24, kendall = 0.0836
Best score updated: 0.0836


                                                                        

epoch 49, kendall = -0.0590


                                                                        

epoch 74, kendall = 0.0569


                                                                        

epoch 99, kendall = -0.0042


                                                                         

epoch 124, kendall = 0.0101


                                                                         

epoch 149, kendall = -0.0011


                                                                         

epoch 174, kendall = -0.0060


                                                                         

epoch 199, kendall = 0.0922
Best score updated: 0.0922


                                                                         

epoch 224, kendall = 0.1342
Best score updated: 0.1342


                                                                         

epoch 249, kendall = 0.2442
Best score updated: 0.2442


                                                                         

epoch 274, kendall = 0.2877
Best score updated: 0.2877


                                                                         

epoch 299, kendall = 0.3167
Best score updated: 0.3167


                                                                         

epoch 324, kendall = 0.1533


epoch: 342 loss: 2.3905:  22%|██▏       | 15/69 [00:02<00:08,  6.13it/s]

In [None]:
run.finish()