# Predict AI Model Runtime (Tile Only)

[Google - Fast or Slow? Predict AI Model Runtime](https://www.kaggle.com/competitions/predict-ai-model-runtime): a Kaggle research prediction competition hosted by Google.

Train a model that predicts runtime rankings for different tile configurations when executing machine learning models on TPUs. A machine learning model can be represented as a graph, with nodes being operations and edges being tensors, so a graph neural network (GNN) can be trained to predict the execution times of different configurations and rank them accordingly. The problem is described in detail in the paper [here](https://arxiv.org/abs/2308.13490).

## Prerequisites & Setup

 - [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/stable/) (aka PyG) is used for most of the graph-based neural network tasks.
 - [pytorch_scatter](https://github.com/rusty1s/pytorch_scatter) lets the `global_max_pool` function used in the [model](#The-Model) run efficiently. Note: it does not compile on Kaggle, so it is not installed here.
 - Some warnings are intentionally suppressed, as they are harmless. This may change in the future.

In [1]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
import os
iskaggle = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '')
if iskaggle:
    !pip -q install torch_geometric

In [3]:
import multiprocessing
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.init as init
from IPython.display import display, HTML
from pathlib import Path
from tqdm.auto import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import MLP, GraphSAGE
from torch_geometric.nn.pool import global_mean_pool, global_max_pool
from sklearn.preprocessing import MinMaxScaler
from tqdm.contrib.concurrent import process_map

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Import Data

In [5]:
base_path = Path('./')
path = Path(base_path/'predict-ai-model-runtime/npz_all/npz/tile/xla')

Each model file is loaded and the results are collected into a dictionary of DataFrames, one per split. Loading runs in parallel across processes using `process_map` from `tqdm`, which gives a noticeable speed-up because loading thousands of files one at a time with `np.load` is I/O-bound.
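A minimal, self-contained sketch of the same fan-out/collect pattern, assuming only NumPy and the standard library. It writes a few tiny synthetic `.npz` files (hypothetical stand-ins for the competition files) to a temporary directory and loads them with `concurrent.futures.ThreadPoolExecutor` in place of `process_map`. Threads are a swapped-in choice that avoids pickling results between processes while still overlapping file I/O; a process pool may still be faster when parsing is CPU-heavy.

```python
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np


def load_npz(file: Path) -> dict:
    # Read every array in the archive into memory and remember the source file.
    with np.load(file) as archive:
        record = {key: archive[key] for key in archive.files}
    record["file"] = file.stem
    return record


def load_split(split_dir: Path, max_workers: int = 4) -> list:
    files = sorted(split_dir.glob("*.npz"))
    # pool.map preserves input order, so records line up with `files`.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(load_npz, files))


with tempfile.TemporaryDirectory() as tmp:
    split_dir = Path(tmp) / "train"
    split_dir.mkdir()
    for i in range(3):
        # Hypothetical toy graph: 4 nodes with 140 features, 5 configurations.
        np.savez(split_dir / f"model_{i}.npz",
                 node_feat=np.random.rand(4, 140).astype(np.float32),
                 config_runtime=np.arange(5) + i)
    records = load_split(split_dir)

print([r["file"] for r in records])  # → ['model_0', 'model_1', 'model_2']
```

Wrapping the list of dicts in `pd.DataFrame.from_dict`, as `load_df` does, would then give one row per model file.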

In [6]:
def process_split(file):
    d = dict(np.load(file))
    d['file'] = file.stem
    return d

def load_df(path):
    splits = ['train', 'valid', 'test']
    cpu_count = multiprocessing.cpu_count()
    df_dict = {}
    for i, split in enumerate(splits, start=1):
        split_path = path/split
        files = list(split_path.glob('*.npz'))
        df_list = []
        pbar = process_map(process_split, files, unit='file',
                           desc=f'({i}/{len(splits)}) {split}, Processing',
                           max_workers=cpu_count, chunksize=1, leave=False)
        for result in pbar:
            df_list.append(result)
        df_dict[split] = pd.DataFrame.from_dict(df_list)
    return df_dict

Define a custom PyTorch `Dataset` for shaping the data from the aforementioned DataFrame dictionary.
- A PyG `Data` object is returned, as it allows for the simplified management of the graph data.
- `torch.from_numpy` is used where possible, as it directly maps memory from a numpy array to a PyTorch tensor without copying.
- `edge_index` has its columns swapped, as it is stored as `(dest, source)`, and PyG expects edges as `(source, dest)`.
- `edge_index` is transposed, as PyG expects a shape of `(2, num_edges)` instead of `(num_edges, 2)`.
- `Tensor.contiguous()` is called on `edge_index`, because the transpose returns a non-contiguous view, and PyG recommends a contiguous index tensor for graph connectivity in [COO format](https://pytorch.org/docs/stable/sparse.html#sparse-coo-docs).
- `config_feat` is multiplied by `100.0`. This helps with training as it results in the magnitude of the graph-level configuration features matching the magnitude of the node-level features.
- `MinMaxScaler` is applied to `target`, as only the ranking order matters and min-max scaling preserves it. It also keeps the targets in `[0, 1]` instead of potentially large runtime values. The result is cast to `float32` so that it matches the dtype of the model's predictions, which `MSELoss` requires.
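The two index and target transformations above can be checked on a toy example with NumPy alone. The arrays are made up; the min-max step mimics `MinMaxScaler` with its default `feature_range=(0, 1)`, including its handling of a constant column (scale of 1, giving all zeros).

```python
import numpy as np

# Hypothetical toy graph, stored as rows of [dest, source].
stored_edges = np.array([[1, 0],   # node 0 feeds node 1
                         [2, 1],   # node 1 feeds node 2
                         [2, 0]])  # node 0 feeds node 2

# Swap columns to (source, dest), then transpose to PyG's (2, num_edges) layout.
edge_index = np.ascontiguousarray(stored_edges[:, [1, 0]].T)
print(edge_index.tolist())  # → [[0, 1, 0], [1, 2, 2]]

# Normalize runtimes, then min-max scale them to [0, 1].
config_runtime = np.array([120, 80, 200, 100], dtype=np.int64)
normalizers = np.full(4, 40, dtype=np.int64)
ratio = config_runtime / normalizers
span = ratio.max() - ratio.min()
if span == 0:
    span = 1.0  # constant column: MinMaxScaler also falls back to a scale of 1
target = ((ratio - ratio.min()) / span).astype(np.float32)

# The ranking order is unchanged by the scaling.
print(np.argsort(target).tolist())  # → [1, 3, 0, 2]
```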

In [7]:
class TileDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        node_feat = torch.from_numpy(row['node_feat'])
        node_opcode = torch.from_numpy(row['node_opcode'].astype(np.int32))
        edge_index = torch.from_numpy(row['edge_index'][:, [1, 0]].T).contiguous()
        config_feat = torch.from_numpy(row['config_feat'] * 100.0)
        config_feat_length = torch.tensor(config_feat.shape[0])
        config_runtime = torch.from_numpy(row['config_runtime'])
        # Normalize runtimes and min-max scale them to [0, 1]; only the order matters.
        target = row['config_runtime'] / row['config_runtime_normalizers']
        target = MinMaxScaler().fit_transform(target.reshape(-1, 1)).flatten()
        target = torch.from_numpy(target.astype(np.float32))
        return Data(node_feat=node_feat, node_opcode=node_opcode, edge_index=edge_index,
                    config_feat=config_feat, target=target, config_runtime=config_runtime,
                   config_feat_length=config_feat_length).to(device)

## The Model

This model closely resembles the 'late-join' tile model described in the [paper](https://arxiv.org/abs/2308.13490). It is a three-layer GraphSAGE model followed by an MLP head that outputs one ranking score per configuration:

1. All linear layers use Kaiming initialization, which is designed for networks with ReLU activations such as this one. This should help prevent exploding / vanishing gradients, especially at the start of training.
2. Concatenate the node features with an opcode embedding for each node. An embedding size of `128` was chosen.
3. Feed the concatenated tensor along with the edge index into a PyG `GraphSAGE` model with the following characteristics:
   - Three layers
   - ReLU activation functions (except for the final layer)
   - No normalization
   - Mean aggregation
   - Dropout applied after each aggregation
4. 'Concat pool' the GNN's output. This is similar to `AdaptiveConcatPool2d` as used by [fastai](https://github.com/fastai/fastai). PyG's `global_mean_pool` and `global_max_pool` functions aggregate the output of the GNN into two separate graph embeddings, which are then concatenated. This gives the MLP more information to work with in the next step than either pooling alone.
5. The concatenated graph embedding is then fed into an MLP, which gradually reduces the input into a scalar ranking value for each graph configuration. It has the following features:
   - Four layers
   - ReLU activation functions (except the final layer)
   - No normalization
   - Dropout applied after each hidden layer
6. The scores are then z-normalized. This keeps the magnitude of the predictions close to that of the min-max scaled target, so the loss function always receives values on a sane scale and the risk of exploding / vanishing gradients is reduced.

In [8]:
class TileModel(nn.Module):
    def __init__(self):
        super().__init__()
        emb_dim = 128
        self.embedding = nn.Embedding(120, emb_dim)
        self.gnn = GraphSAGE(in_channels=emb_dim + 140, hidden_channels=128,
                             num_layers=3, out_channels=64, dropout=p)
        self.postnet = MLP(channel_list=[128 + 24, 64, 32, 16, 1], dropout=p, norm=None)
        self.initialize_parameters()

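    def initialize_parameters(self):
        # PyG's GraphSAGE and MLP build their layers from torch_geometric's own `Linear`,
        # which subclasses nn.Module rather than nn.Linear, so match both types.
        from torch_geometric.nn.dense.linear import Linear as PyGLinear
        for m in self.modules():
            if isinstance(m, (nn.Linear, PyGLinear)):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)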

    def forward(self, data):
        x = torch.cat([data.node_feat, self.embedding(data.node_opcode)], dim=-1)
        x = self.gnn(x, data.edge_index)
        x = torch.cat([global_mean_pool(x, data.batch), global_max_pool(x, data.batch)], dim=-1)
        x = torch.cat([data.config_feat, x.repeat_interleave(data.config_feat_length, dim=0)], dim=-1)
        x = self.postnet(x)
        x = torch.flatten(x)
        x = (x - x.mean()) / (x.std() + 1e-8)
        return x

## Training

### Hyperparameters

These closely match the hyperparameters used in the [paper](https://arxiv.org/abs/2308.13490). Notable differences are outlined below in [Setting the Scheduler](#Setting-the-Scheduler).

In [9]:
num_epochs = 6
lr = 1e-4
wd = 1e-4
p = 0.1
max_norm = 1e-2
mom = 0.99
pct_start = 0.7
model = TileModel().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

### Loading the Data

In [10]:
tile_df = load_df(path)

(1/3) train, Processing:   0%|          | 0/5709 [00:00<?, ?file/s]

(2/3) valid, Processing:   0%|          | 0/676 [00:00<?, ?file/s]

(3/3) test, Processing:   0%|          | 0/844 [00:00<?, ?file/s]

In [11]:
train_ds = TileDataset(tile_df['train'])
valid_ds = TileDataset(tile_df['valid'])
test_ds = TileDataset(tile_df['test'])

- `shuffle` is enabled on the training `DataLoader`. This changes the order in which graphs are seen each epoch, so the model does not fit to a fixed sequence, which should help it generalize.
- No batching is used (the default `batch_size` of `1`). Batching added a lot of complexity, since tensors would otherwise have to be stacked and padded, and experiments showed that it made training unstable for the current model.
- Batching did give a large speed increase thanks to better GPU and memory utilization, but it was not worth the loss in accuracy in this case.
- Memory is not pinned: `TileDataset.__getitem__` already moves every `Data` object to `device`, and only dense CPU tensors can be pinned.
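For context on what batching involves, here is a simplified NumPy sketch of how PyG's `DataLoader` combines several graphs into one disjoint graph. Node features are concatenated, each graph's edge indices are shifted by the number of nodes before it, and a `batch` vector records which graph each node came from. `collate` is a hypothetical helper illustrating the idea, not PyG's implementation; per-graph tensors of different lengths, such as `config_feat`, are where the extra bookkeeping comes in.

```python
import numpy as np


def collate(graphs):
    xs, edges, batch = [], [], []
    offset = 0
    for g, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)   # shift indices past earlier graphs' nodes
        batch.append(np.full(len(x), g))    # graph id for every node
        offset += len(x)
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)


g0 = (np.ones((3, 2)), np.array([[0, 1], [1, 2]]))  # 3 nodes, edges 0→1, 1→2
g1 = (np.zeros((2, 2)), np.array([[0], [1]]))       # 2 nodes, edge 0→1
x, edge_index, batch = collate([g0, g1])
print(edge_index.tolist())  # → [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # → [0, 0, 0, 1, 1]
```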

In [12]:
train_dl = DataLoader(train_ds, shuffle=True)
valid_dl = DataLoader(valid_ds)
test_dl = DataLoader(test_ds)

### Setting the Scheduler

- A high value of [`max_momentum`](#Hyperparameters) is used. In theory, this should help with training because only a single graph is used at each step, without any batching. This produces noisy gradients, as the loss can vary greatly from step to step. A high momentum smooths out this noise. PyTorch's `OneCycleLR` scheduler cycles momentum inversely to the learning rate, so momentum is highest when the learning rate is lowest. This should make the model converge more consistently than it would otherwise.
- The warm-up fraction (`pct_start`) of the [One-Cycle Scheduler](https://arxiv.org/abs/1708.07120) (`OneCycleLR`) is set to [`0.7`](#Hyperparameters), which is very high (PyTorch's default is `0.3`). In theory, this helps because the learning rate should be low and the momentum high at the start of training. This ranking task is easy to overfit early on. For example, if a series of similar configurations happened to be fed into the model first (which will often happen), the model would overfit to them and have a hard time recovering. With a low initial learning rate, the model instead sees a 'coarse' representation of the rankings at the beginning of training. Only after multiple epochs, once the data has been 'seen' several times, is the learning rate raised to its peak, making the model more likely to converge on a better solution.
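To see what `pct_start=0.7` does to the schedule, here is a small pure-Python re-derivation of a two-phase cosine one-cycle schedule. It assumes PyTorch's documented `OneCycleLR` defaults (`div_factor=25`, `final_div_factor=1e4`, `base_momentum=0.85`, `anneal_strategy='cos'`, `three_phase=False`) and is an illustration, not a replacement for the scheduler. With Adam, `OneCycleLR` cycles `betas[0]` in place of momentum.

```python
import math


def cos_anneal(start, end, pct):
    # Cosine interpolation: returns `start` at pct=0 and `end` at pct=1.
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


def one_cycle(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0,
              final_div_factor=1e4, base_mom=0.85, max_mom=0.95):
    """Return (learning rate, momentum) at `step`."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1
    if step <= warmup_end:
        # Phase 1: learning rate rises to max_lr while momentum falls.
        pct = step / warmup_end
        return cos_anneal(initial_lr, max_lr, pct), cos_anneal(max_mom, base_mom, pct)
    # Phase 2: learning rate decays to min_lr while momentum rises again.
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cos_anneal(max_lr, min_lr, pct), cos_anneal(base_mom, max_mom, pct)


# Values from this notebook: 6 epochs x 5709 training graphs, batch size 1.
total_steps = 6 * 5709
schedule = [one_cycle(s, total_steps, max_lr=1e-4, pct_start=0.7, max_mom=0.99)
            for s in range(total_steps)]
lrs = [lr for lr, _ in schedule]
peak_step = max(range(total_steps), key=lrs.__getitem__)
print(f"start lr {lrs[0]:.1e}, peak at {peak_step / total_steps:.0%} of training, "
      f"final lr {lrs[-1]:.1e}")
```

The learning rate starts at `max_lr / 25`, spends most of training warming up, and only reaches its peak around the fifth epoch.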

In [13]:
scheduler = OneCycleLR(optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dl),
                       max_momentum=mom, pct_start=pct_start)

### Evaluation Metric

Each model in the validation set is given a score by the `score_preds` function, which is called inside the [`eval`](#The-Training-Loop) function, according to the following evaluation metric:

$$1 - \left(\frac{\text{The best runtime of the top-k predictions}}{\text{The best runtime of all configurations}} - 1\right) = 2 - \frac{\min_{i \in K} y_i}{\min_{i \in A} y_i}$$

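Here $K$ is the set of top-k predicted configurations (k = `5` in this case), $A$ is the set of all configurations, and $y_i$ is the runtime of configuration $i$. The score measures the slowdown incurred by picking the best of the top-k predictions instead of the true optimum: it equals `1.0` when the fastest configuration is among the predictions and drops by `0.01` for every 1% of slowdown. At the end of each epoch, the average score over the validation set is returned so the model's performance can be evaluated. Values closer to `1.0` are good. A value near or below `0.0` means the chosen configurations are about twice as slow as the optimum or worse, which suggests something has likely gone wrong during training. A value above `1.0` is impossible.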
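A worked example of the metric for a single model, using made-up runtimes. `slowdown_score` is a hypothetical per-model helper equivalent to one iteration of the loop in `score_preds`.

```python
import numpy as np


def slowdown_score(config_runtime, top_k):
    # 2 - (best runtime among the top-k picks) / (best runtime overall)
    best_pred = config_runtime[np.asarray(top_k)].min()
    best_targ = config_runtime.min()
    return 2.0 - best_pred / best_targ


runtimes = np.array([10.0, 12.0, 8.0, 20.0, 9.0, 15.0])
print(slowdown_score(runtimes, [0, 1, 4]))  # best pick 9.0 vs optimum 8.0 → 0.875
print(slowdown_score(runtimes, [2, 3, 5]))  # optimum is among the picks → 1.0
print(slowdown_score(runtimes, [3]))        # 20.0 is 2.5x the optimum → -0.5
```

A 12.5% slowdown costs 0.125 points, and a pick more than twice as slow as the optimum drives the score below zero.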

In [14]:
def score_preds(preds, df):
    score = 0.0
    df_length = len(df)
    for index in range(df_length):
        config_runtime = df.iloc[index]['config_runtime']
        best_pred = min(config_runtime[preds[index]])
        best_targ = min(config_runtime)
        score += 2.0 - best_pred / best_targ
    score /= df_length
    return score

### The Training Loop

In [15]:
def train():
    model.train()
    pbar = tqdm(train_dl, unit='model', leave=False)
    loss_total = 0.0
    n = 0
    optimizer.zero_grad()
    for data in pbar:
        assert len(data) == 1, 'Batching is unsupported.'
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.target)
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        loss_total += loss.item()
        n += 1
        pbar.set_description(f'Training: lr: {current_lr:.2e} loss: {(loss_total / n):.4f}')
    return loss_total / n

In [16]:
@torch.no_grad()  # evaluation needs no autograd graph, which saves memory and time
def eval():
    preds = []
    model.eval()
    loss_total = 0.0
    pbar = tqdm(valid_dl, desc='Evaluating', unit='model', leave=False)
    for data in pbar:
        out = model(data)
        loss = criterion(out, data.target)
        loss_total += loss.item()
        out = list(out.split(data.config_feat_length.tolist()))
        top_five = [torch.argsort(t.to('cpu').detach())[:5] for t in out]
        preds.extend(top_five)
    loss_total /= len(valid_dl)
    score = score_preds(preds, tile_df['valid'])
    return loss_total, score

In [17]:
for epoch in range(num_epochs):
    train_loss = train()
    valid_loss, score = eval()
    display(HTML(f'<div style="font-size: 12px;"><b>Epoch {epoch + 1}/{num_epochs}</b>, '
                 f'Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, '
                 f'Score: {score:.4f}</div>'))

  0%|          | 0/5709 [00:00<?, ?model/s]

Evaluating:   0%|          | 0/676 [00:00<?, ?model/s]

The model typically scores `0.90` or above. This means that, on average, the best of its top-5 predicted configurations for each model in the validation set is roughly 10% slower than the optimal configuration, or better.

Assuming the test set contains models different to the validation set, it will hopefully perform about the same on the test set.

## Inference

In [18]:
preds = []
model.eval()
pbar = tqdm(test_dl, desc='Inference', leave=False)
with torch.no_grad():  # no gradients are needed for inference
    for data in pbar:
        out = model(data)
        out = list(out.split(data.config_feat_length.tolist()))
        top_five = [torch.argsort(t.to('cpu').detach())[:5].tolist() for t in out]
        preds.extend(top_five)

Inference:   0%|          | 0/844 [00:00<?, ?it/s]

In [19]:
for index, pred in enumerate(preds[:10]):
    file = tile_df['test'].iloc[index]['file']
    print(f'{file}: {pred}')

04ae9238c653f8ae08f60f2c03615f0b: [368, 29, 408, 305, 549]
85d157d3b1848c6b6fff0c633876e2e6: [1493, 3806, 2579, 6078, 4516]
862900d42397d03be2762e1bf7518bea: [1048, 161, 935, 287, 1344]
0afa527a7022415fda1dd69d11e908a4: [158, 69, 210, 212, 122]
2d09e3ab92e184c561abaf8d9efe7b87: [170, 147, 24, 89, 210]
b33abc6bdac6068f71711aa602e2e67e: [611, 565, 1111, 719, 912]
fa22caf7b94c2c10a419f99cdd516bcd: [184, 216, 146, 84, 150]
55cebbadd2b1e32f6c06779449bd5f9e: [3649, 2697, 5107, 1728, 3474]
4a1d9f054c89c29b6ec14101ec66336c: [107, 12, 66, 2, 55]
148fe6237924d9a64faf8289115d6050: [61, 45, 136, 48, 15]


### Submitting the Results

In [20]:
submission = pd.read_csv(base_path/'predict-ai-model-runtime/sample_submission.csv')
for i, file in enumerate(tile_df['test']['file'].values):
    sub_id = f'tile:xla:{file}'  # avoid shadowing the built-in `id`
    submission.loc[submission['ID'] == sub_id, 'TopConfigs'] = ';'.join(map(str, preds[i]))
submission.to_csv('./submission.csv', index=False)

In [21]:
submission

| | ID | TopConfigs |
|---|---|---|
| 0 | tile:xla:d6f5f54247bd1e58a10b9e7062c636ab | 12;24;23;22;21 |
| 1 | tile:xla:e3a655daa38e34ec240df959b650ac16 | 171;624;630;890;574 |
| 2 | tile:xla:f8c2c1a1098b2a361c26df668b286c87 | 41;101;116;202;44 |
| 3 | tile:xla:4dd1716853ed46ee4e7d09ede1732de8 | 1474;9852;4171;7300;6406 |
| 4 | tile:xla:d0a69155b6340748c36724e4bfc34be3 | 641;264;252;640;992 |
| ... | ... | ... |
| 889 | layout:nlp:random:60880ed76de53f4d7a1b960b24f2... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |
| 890 | layout:nlp:random:23559853d9702baaaacbb0c83fd3... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |
| 891 | layout:nlp:random:f6c146fc5cf10be4f3accbaca989... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |
| 892 | layout:nlp:random:32531d07a084b319dce484f53a4c... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |


The above results look sane to me, and match the format of the sample submission. I will use the [Kaggle API](https://www.kaggle.com/docs/api) to directly submit the predictions for the [competition](https://www.kaggle.com/competitions/predict-ai-model-runtime).

In [22]:
submit = False
if submit:
    from kaggle import api
    comp = 'predict-ai-model-runtime'
    api.competition_submit_cli('./submission.csv', 'Tile GraphSAGE v2', comp)

## Improvements / Reflections

- Batching seemed to make the model perform worse with the current implementation. I only have theories: perhaps it has trouble finding a 'better' solution, as it seems to converge very early on, even with a low learning rate.
- Any normalization while batching also made the model perform worse: it would converge on a single solution for every configuration. So I removed all batch normalization, and then all batching. I have tried several experiments, but I cannot figure out what's going wrong.
- Having the option for 'early-joining' the graph configuration features can improve performance according to the [paper](https://arxiv.org/abs/2308.13490). This would require quite a few changes, and a different notebook may be a better option for pursuing this implementation.
- Using a ranking loss such as [ListMLE](http://icml2008.cs.helsinki.fi/papers/167.pdf) may also improve the model's performance. However, it would also increase complexity: I found most common ranking loss functions difficult to implement with the current model, as they were not straightforward to make differentiable, and PyTorch's [Autograd](https://pytorch.org/tutorials/beginner/introyt/autogradyt_tutorial.html) was very disagreeable at every step of the process.
- Performance could be increased very slightly in the `forward` function of the model, as well as in the `score_preds` function, as some residual batching code is still left over (`repeat_interleave`, some extra tensor dimensions).
- Ensembling could improve the accuracy of the predictions with relatively little effort: swap `GraphSAGE` for `GAT` or `GCN` in the [model](#The-Model), train each variant, and aggregate their results.
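As a starting point for the ranking-loss idea, here is a NumPy sketch of the ListMLE loss (forward pass only) under the Plackett-Luce model. It assumes a higher score means "ranked earlier"; because this notebook's model outputs lower values for faster configurations, its predictions would be negated before being passed in. `list_mle` is a hypothetical helper. In PyTorch, the suffix log-sum-exp could be written with `torch.logcumsumexp` on a flipped tensor, which autograd can differentiate, though that variant is untested with this model.

```python
import numpy as np


def list_mle(scores, runtimes):
    # Order the scores by ground truth: fastest configuration first.
    s = scores[np.argsort(runtimes)]
    # suffix_lse[i] = log(sum over j >= i of exp(s[j])), accumulated from the right.
    suffix_lse = np.logaddexp.accumulate(s[::-1])[::-1]
    # Negative Plackett-Luce log-likelihood of the true ordering.
    return float(np.sum(suffix_lse - s))


runtimes = np.array([1.0, 2.0, 3.0])
good = list_mle(np.array([2.0, 1.0, 0.0]), runtimes)  # scores agree with the ordering
bad = list_mle(np.array([0.0, 1.0, 2.0]), runtimes)   # scores reversed
print(good < bad)  # → True
```

Unlike `MSELoss` on min-max targets, this loss only depends on the relative order of the scores within each graph.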
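For the ensembling idea, a minimal sketch of one way to aggregate predictions, assuming each model (for example GraphSAGE, GAT, and GCN variants) has already produced a score vector for the same graph, with lower meaning predicted faster. `ensemble_top_k` is a hypothetical helper and the scores are made up; averaging z-normalized scores is just one choice, with rank averaging or voting as alternatives.

```python
import numpy as np


def ensemble_top_k(predictions, k=5):
    # z-normalize each model's scores so no single model dominates the average.
    z = [(p - p.mean()) / (p.std() + 1e-8) for p in predictions]
    mean = np.mean(z, axis=0)
    # Lower mean score = predicted faster, matching this notebook's convention.
    return np.argsort(mean)[:k].tolist()


sage = np.array([0.1, 0.9, 0.3, 0.5, 0.2, 0.8])
gat = np.array([0.2, 0.7, 0.1, 0.6, 0.3, 0.9])
print(ensemble_top_k([sage, gat], k=3))  # → [0, 2, 4]
```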