In [1]:
import torch
import wandb
import numpy as np
from torch import nn
from torch import optim
from torch_geometric.nn import GCNConv, GATConv
from torch.utils.data import Dataset
from tqdm import tqdm
from sklearn.model_selection import KFold

from gfos.data.utils import load_layout
from gfos.utils.scheduler import CosineAnnealingWarmupRestarts
from gfos.metrics import metric_for_layout_collections
from gfos.loss import MultiElementRankLoss

# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


## Configs

In [2]:
LAYOUT_DIR = "../../data/layout"

# Hyper-parameters of the run. The dict is also passed to wandb as the run
# config, so keep the key order stable.
configs = {
    "conv_layer": "GAT",  # label only; SimpleModel does not read this key
    # optimisation / schedule
    "num_epochs": 100,
    "learning_rate": 5e-4,
    "weight_decay": 1e-6,
    "min_lr": 1e-7,
    "warmup_ratio": 0.05,
    # data: configs sampled per graph
    "max_configs": 100,
    # model
    "graph_hidden": [16, 32, 16, 48],
    "graph_out": 64,
    "num_encoder": 1,
    "num_feedforward": 512,
    "nhead": 1,
    # pairwise rank loss settings
    "loss_margin": 0.1,
    "loss_num_permutations": 1000,
    # gradient accumulation (graphs per optimizer step)
    "accum_grad": 4,
}

# Weights & Biases run metadata.
WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = "gcn_layout_xla"
TAGS = ["train", "layout", "xla"]


In [4]:
# Preview the validation split of the XLA / "random" layout collection (a list of .npz paths).
load_layout(LAYOUT_DIR, compile_type="random", model_type="xla")["valid"]

['../../data/layout/xla/random/valid/resnet_v1_50_official_batch_128_bf16.npz',
 '../../data/layout/xla/random/valid/unet_3d.4x4.bf16.npz',
 '../../data/layout/xla/random/valid/resnet50.4x4.fp16.npz',
 '../../data/layout/xla/random/valid/tf2_bert_pretrain_dynamic_batch_size.npz',
 '../../data/layout/xla/random/valid/inception_v3_batch_128_train.npz',
 '../../data/layout/xla/random/valid/bert_pretraining.4x4.fp16.npz',
 '../../data/layout/xla/random/valid/mlperf_bert_batch_24_2x2.npz']

## Dataset

In [5]:
class LayoutDataset(Dataset):
    """One item per layout graph: the whole graph plus a random sample of its configs.

    Items are meant to be fetched directly by index rather than through a
    ``DataLoader``, because each item already is a full batch of configs for a
    single graph.

    Args:
        files: paths to ``.npz`` files. Each must contain ``config_runtime``,
            ``node_feat``, ``node_opcode``, ``edge_index``, ``node_config_feat``
            and ``node_config_ids``.
        max_configs: when > 0, sample at most this many configs (without
            replacement) per item; otherwise return every config, shuffled.

    Each item is the tuple ``(node_feat, node_opcode, edge_index,
    node_config_feat, node_config_ids, target)``. ``target`` holds the
    standardized runtime of each returned config and is aligned with
    ``node_config_feat`` along dim 0.
    """

    def __init__(self, files: list[str], max_configs: int = -1):
        self.max_configs = max_configs
        self.files = files
        # np.load on an .npz returns a lazy NpzFile that keeps its archive open.
        self.npzs = [np.load(file) for file in self.files]
        self.num_records = [len(npz["config_runtime"]) for npz in self.npzs]

    def __len__(self):
        return len(self.npzs)

    def close(self):
        """Close all underlying ``.npz`` archives."""
        for npz in self.npzs:
            npz.close()

    def __getitem__(self, idx):
        row = self.npzs[idx]
        # Read each key once: every `row[key]` access decompresses the array again.
        runtime = row["config_runtime"]
        num_configs = len(runtime)
        max_configs = self.max_configs if self.max_configs > 0 else num_configs

        # Standardize runtimes over *all* configs of the graph, before sampling.
        # torch.std of a single element is NaN, so treat that case as zero spread.
        target = torch.tensor(runtime, dtype=torch.float)
        std = target.std() if num_configs > 1 else torch.zeros(())
        target = (target - target.mean()) / (std + 1e-7)

        # Random sample (without replacement) of up to `max_configs` configs.
        config_indices = torch.randperm(num_configs)[:max_configs]
        target = target[config_indices]

        node_feat = torch.tensor(row["node_feat"], dtype=torch.float)
        node_opcode = torch.tensor(row["node_opcode"], dtype=torch.long)
        # Swap axes 0/1 so edges run along dim 1 (PyG-style (2, E) layout;
        # assumes the file stores edges as (E, 2)).
        edge_index = torch.tensor(
            np.swapaxes(row["edge_index"], 0, 1), dtype=torch.long
        )

        # Select in numpy first so only the sampled configs are copied into a tensor.
        node_config_feat = torch.tensor(
            row["node_config_feat"][config_indices.numpy()], dtype=torch.float
        )

        node_config_ids = torch.tensor(
            row["node_config_ids"], dtype=torch.long
        )

        return (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            target,
        )

## Model

In [6]:
# def transform_node_positional_embeddings(
#     embeddings_output: torch.Tensor,
#     node_config_ids: torch.Tensor,
#     num_nodes: int,
# ) -> torch.Tensor:
#     num_configs, _, dim = embeddings_output.shape
#     idxs = node_config_ids.unsqueeze(0).repeat(num_configs, 1) # [c, nc]
#     zeros = torch.zeros(
#         num_configs,
#         num_nodes,
#         dim,
#         device=embeddings_output.device,
#         dtype=embeddings_output.dtype,
#     )
#     idxs = idxs.unsqueeze(-1).repeat(1, 1, dim) # [c, nc, dim]
#     zeros.scatter_reduce_(1, idxs, embeddings_output, reduce="sum")
#     return zeros

In [7]:
class SimpleModel(torch.nn.Module):
    """Graph + Transformer model that outputs one scalar score per config of a graph.

    Pipeline: opcode embedding concatenated with node features -> one GCNConv ->
    GATConv layers -> gather configurable nodes -> concat with projected
    per-config features -> TransformerEncoder over the configurable nodes ->
    MLP on the last configurable node's token.

    Args:
        hidden_channels: widths of the graph layers. The first is a GCNConv and
            the rest are GATConv. Must be non-empty.
        graph_out: width of the final GATConv output.
        num_encoder: number of TransformerEncoder layers.
        num_feedforward: feed-forward width inside each encoder layer.
        nhead: attention heads; must divide ``graph_out + config_dim``.
        num_opcodes: size of the opcode vocabulary.
        op_embedding_dim: width of the opcode embedding.
        node_feat_dim: width of ``node_feat``.
        config_feat_dim: width of the last dim of ``node_config_feat``.
        config_dim: width the config features are projected to.
    """

    def __init__(
        self,
        hidden_channels,
        graph_out,
        num_encoder=1,
        num_feedforward=256,
        nhead=1,
        num_opcodes=120,
        op_embedding_dim=4,
        node_feat_dim=140,
        config_feat_dim=18,
        config_dim=64,
    ):
        super().__init__()

        merged_node_dim = graph_out + config_dim

        # NOTE: submodules are created in the same order as before so that
        # random initialisation and state_dict keys are unchanged.
        self.embedding = torch.nn.Embedding(num_opcodes, op_embedding_dim)
        assert len(hidden_channels) > 0
        in_channels = op_embedding_dim + node_feat_dim

        # One GCNConv on the raw node features, then GATConv layers through
        # `hidden_channels`, and a final GATConv to `graph_out`.
        self.convs = torch.nn.ModuleList([GCNConv(in_channels, hidden_channels[0])])
        for c_in, c_out in zip(hidden_channels[:-1], hidden_channels[1:]):
            self.convs.append(GATConv(c_in, c_out))
        self.convs.append(GATConv(hidden_channels[-1], graph_out))

        # Transformer encoder mixing graph features with config features.
        layer = nn.TransformerEncoderLayer(
            d_model=merged_node_dim,
            dim_feedforward=num_feedforward,
            nhead=nhead,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_encoder)

        self.layernorm = nn.LayerNorm(merged_node_dim)

        self.config_prj = nn.Sequential(
            nn.Linear(config_feat_dim, config_dim),
            nn.LayerNorm(config_dim),
        )

        # Per-config scoring head.
        self.dense = torch.nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(merged_node_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        node_config_feat: torch.Tensor,
        node_config_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return a 1-D tensor with one score per config (length ``node_config_feat.size(0)``)."""
        c = node_config_feat.size(0)

        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)
        for conv in self.convs:
            x = conv(x, edge_index).relu()

        x = x[node_config_ids]  # (N, graph_out) -> (NC, graph_out)

        node_config_feat = self.config_prj(node_config_feat)  # (C, NC, F) -> (C, NC, config_dim)
        # expand() broadcasts the shared graph features to every config without copying them C times.
        x = torch.cat([x.unsqueeze(0).expand(c, -1, -1), node_config_feat], dim=-1)

        x = self.layernorm(x)
        # Use the encoded token of the last configurable node as the config summary.
        x = self.encoder(x)[:, -1, :]  # (C, NC, D) -> (C, D)
        return self.dense(x).flatten()

## Loss

In [8]:
def listMLE(y_pred, y_true, eps=1e-10, padded_value_indicator=-1):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".

    Items are ranked by ``y_true`` in descending order, and the loss is the
    negative log-likelihood of that permutation under a Plackett-Luce model
    parameterised by ``y_pred``.

    :param y_pred: predictions from the model, shape [batch_size, slate_length]
        (a 1-D [slate_length] tensor is treated as a batch of one)
    :param y_true: ground truth labels, same shape as ``y_pred``
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: value in ``y_true`` that marks a padded item, e.g. -1;
        such items are excluded from the loss
    :return: loss value (mean over the batch), a scalar torch.Tensor
    """
    if y_pred.dim() == 1:
        y_pred = y_pred.unsqueeze(0)
    if y_true.dim() == 1:
        y_true = y_true.unsqueeze(0)

    # Shuffle for randomised tie resolution. Build the permutation on the inputs'
    # device so indexing GPU tensors does not need a host -> device copy.
    random_indices = torch.randperm(y_pred.shape[-1], device=y_pred.device)
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]

    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)

    mask = y_true_sorted == padded_value_indicator

    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    preds_sorted_by_true[mask] = float("-inf")

    # Subtract the row max before exp() for numerical stability.
    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)
    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values

    # Reverse cumulative sum: for position i, sum of exp over items ranked i..end.
    cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])

    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max
    observation_loss[mask] = 0.0

    return torch.mean(torch.sum(observation_loss, dim=1))

## Training

In [None]:
layout_random = load_layout(
    LAYOUT_DIR,
    model_type="xla",
    compile_type="random",
)

num_epochs = configs["num_epochs"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
warmup_ratio = configs["warmup_ratio"]
max_configs = configs["max_configs"]
graph_hidden = configs["graph_hidden"]
graph_out = configs["graph_out"]
num_encoder = configs["num_encoder"]
num_feedforward = configs["num_feedforward"]
nhead = configs["nhead"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]
accum_grad = configs["accum_grad"]

model = SimpleModel(
    hidden_channels=graph_hidden,
    graph_out=graph_out,
    num_encoder=num_encoder,
    num_feedforward=num_feedforward,
    nhead=nhead,
).to(device)

# TODO: K-fold cross-validation, e.g. KFold(n_splits=5, shuffle=True, random_state=42).
# criterion = MultiElementRankLoss(margin=margin, number_permutations=number_permutations)

# Built once. LayoutDataset re-samples configs on every __getitem__, and the
# order in which training graphs are visited is shuffled per epoch below.
train_dataset = LayoutDataset(layout_random["train"], max_configs=max_configs)
val_dataset = LayoutDataset(layout_random["valid"], max_configs=max_configs)

# Optimizer steps: one per `accum_grad` training graphs (the last group of an epoch may be smaller).
steps_per_epoch = (len(train_dataset) + accum_grad - 1) // accum_grad
num_steps = steps_per_epoch * num_epochs
warmup_steps = int(num_steps * warmup_ratio)

optimizer = optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
# scheduler = CosineAnnealingWarmupRestarts(
#     optimizer=optimizer,
#     first_cycle_steps=num_steps,
#     min_lr=min_lr,
#     max_lr=learning_rate,
#     warmup_steps=warmup_steps,
# )


def _batch_to_device(tensors):
    """Return the given tensors moved to `device`, as a tuple."""
    return tuple(t.to(device) for t in tensors)


run = wandb.init(
    project=WANDB_PROJECT,
    dir=WANDB_DIR,
    name=WANDB_RUN_NAME,
    config=configs,
    tags=TAGS,
)
run.watch(model, log="all")

best_score_mean = -np.inf

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    order = np.random.permutation(len(train_dataset))
    pbar = tqdm(order, leave=False)
    optimizer.zero_grad()

    for step, idx in enumerate(pbar, start=1):
        # Items are (node_feat, node_opcode, edge_index, node_config_feat, node_config_ids, target).
        *inputs, target = _batch_to_device(train_dataset[idx])
        out = model(*inputs)

        loss = listMLE(out[None, ...], target[None, ...])
        # Scale so the accumulated gradient is the average over `accum_grad` graphs.
        (loss / accum_grad).backward()

        # Step every `accum_grad` graphs and flush the remainder at the end of the epoch.
        if step % accum_grad == 0 or step == len(order):
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
            optimizer.step()
            optimizer.zero_grad()
            # scheduler.step()

        loss_value = loss.item()
        wandb.log(
            {
                "epoch": epoch,
                # "train/lr": scheduler.get_lr()[0],
                "train/loss": loss_value,
            }
        )
        pbar.set_description(f"epoch: {epoch} loss: {loss_value:.2f}")

    pbar.close()

    # ---- validation (on `max_configs` randomly sampled configs per graph) ----
    model.eval()
    layout_xla_scores = []

    with torch.no_grad():
        for i in tqdm(range(len(val_dataset)), leave=False):
            *inputs, target = val_dataset[i]
            out = model(*_batch_to_device(inputs))

            score = metric_for_layout_collections(out.cpu().numpy(), target.numpy())
            # NaN never compares equal to anything, so it must be tested with np.isnan.
            if np.isnan(score):
                print(f"score is nan at step {i} with file {val_dataset.files[i]}")
                print(out)

            layout_xla_scores.append(score)

    score_mean = np.mean(layout_xla_scores)
    score_max = np.max(layout_xla_scores)

    wandb.log(
        {
            "val/kendalltau": score_mean,
        }
    )

    print(
        f"epoch {epoch}, max_score = {score_max:.4f}, mean_score = {score_mean:.4f},"
    )

    # Update best scores and save the model if the mean score improves
    if score_mean > best_score_mean:
        best_score_mean = score_mean
        best_score_max = score_max
        print(f"Best score updated: {best_score_mean:.4f}")
        torch.save(model.state_dict(), f"{epoch}_{score_mean:.4f}.pth")

run.finish()

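## Validation

Score a saved checkpoint on **every** config of each validation graph. The per-epoch validation in the training cell scores only `max_configs` randomly sampled configs per graph, so its numbers are not directly comparable with the ones below.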

In [27]:
layout_random = load_layout(
    LAYOUT_DIR,
    model_type="xla",
    compile_type="random",
)

# Checkpoint to evaluate. The training cell writes files named
# f"{epoch}_{score_mean:.4f}.pth"; point this at the one you want to score.
CHECKPOINT_PATH = "../../models/45_0.0455.pth"

# Rebuild the architecture from the same hyper-parameters the training cell uses.
model = SimpleModel(
    hidden_channels=configs["graph_hidden"],
    graph_out=configs["graph_out"],
    num_encoder=configs["num_encoder"],
    num_feedforward=configs["num_feedforward"],
    nhead=configs["nhead"],
).to(device)

# The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

SimpleModel(
  (embedding): Embedding(120, 4)
  (convs): ModuleList(
    (0): GCNConv(144, 16)
    (1): GATConv(16, 32, heads=1)
    (2): GATConv(32, 16, heads=1)
    (3): GATConv(16, 48, heads=1)
    (4): GATConv(48, 64, heads=1)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (layernorm): LayerNorm((128,), eps=1e-05, elementwise_affine=Tru

In [28]:
# Number of configs scored per forward pass during full validation.
_INFERENCE_CONFIGS_BATCH_SIZE = 100
# max_configs is left at its default, so every config of each validation graph is returned.
val_dataset = LayoutDataset(files=layout_random["valid"])

In [29]:
layout_xla_scores = []

with torch.no_grad():
    for idx in tqdm(range(len(val_dataset))):
        (
            node_feat,
            node_opcode,
            edge_index,
            node_config_feat,
            node_config_ids,
            target,
        ) = val_dataset[idx]

        # Graph tensors are shared by every chunk, so move them to the device once.
        node_feat = node_feat.to(device)
        node_opcode = node_opcode.to(device)
        edge_index = edge_index.to(device)
        node_config_ids = node_config_ids.to(device)
        # node_config_feat stays on the host. Only the current chunk is moved, so
        # device memory is bounded by _INFERENCE_CONFIGS_BATCH_SIZE configs.

        target = target.numpy()
        num_configs = target.shape[-1]
        outs = []

        for start in range(0, num_configs, _INFERENCE_CONFIGS_BATCH_SIZE):
            end = min(start + _INFERENCE_CONFIGS_BATCH_SIZE, num_configs)
            out: torch.Tensor = model(
                node_feat,
                node_opcode,
                edge_index,
                node_config_feat[start:end].to(device),
                node_config_ids,
            )
            outs.append(out.cpu())

        score = metric_for_layout_collections(torch.cat(outs).numpy(), target)
        layout_xla_scores.append(score)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:22<00:00,  3.19s/it]


In [30]:
# Mean score over all validation graphs (unweighted average).
np.average(layout_xla_scores)

0.02273833870531285