# Attention!!!

This is a very simple but bad quality notebook. 
 - I do not use any sort of ranking loss, which would be better.
 - My strategy instead is to min-max scale the relative times (time/normalized) and apply L1-loss
 - My model is also not optimized. It is a relatively simple GNN that embeds the graph and only processes 1 datapoint at a time and is only trained on 1 epoch.
 - The public score would be much better if you paired this submission with a trained model for layout. Since this only contributes to half of the score.
 - Have fun playing around with it!

In [1]:
!pip install torch-geometric torch-scatter

Collecting torch-geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l- \ | / done
[?25h  Getting requirements to build wheel ... [?25l- done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- done
[?25hCollecting torch-scatter
  Downloading torch_scatter-2.1.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
Building wheels for collected packages: torch-geometric, torch-scatter
  Building wheel for torch-geometric (pyproject.toml) ... [?25l- \ | / - \ done
[?25h  Created wheel for torch-geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910454 sha256=6eede70db732e7c4aa7466ae217e67b8e145f026b51078b843b75880

In [2]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm 

import torch
from torch import nn
from torch import Tensor
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch.utils.data import DataLoader, Dataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'



We can now load all the data in dataframes to make working with it easier

In [3]:
def load_df(directory):
    splits = ["train", "valid", "test"]
    dfs = dict()
    
    for split in splits:
        path = os.path.join(directory, split)
        files = os.listdir(path)
        list_df = []
        
        for file in files:
            d = dict(np.load(os.path.join(path,file)))
            d['file'] = file
            list_df.append(d)
        dfs[split] = pd.DataFrame.from_dict(list_df)
    return dfs

If you try to run the following cell completely uncommented the Kaggle kernel will run out of memory and crash, so we will have to study the datasets individually

In [4]:
tile_xla = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/tile/xla/")
#layout_nlp_random = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/nlp/random/")
#layout_nlp_default = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/nlp/default/")
#layout_xla_random = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/xla/random/")
#layout_xla_random = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/xla/default/")

# Define Dataset and Model

In [5]:
class TileDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        config_feat = torch.tensor(row['config_feat'].astype(np.float32))
        node_feat = torch.tensor(row['node_feat'].astype(np.float32))
        node_opcode = torch.tensor(row['node_opcode'].astype(np.int32))
        edge_index = torch.tensor(np.swapaxes(row['edge_index'],0,1).astype(np.int32))
        target = (row['config_runtime']/row['config_runtime_normalizers']).astype(np.float32)
        # minmax scale the target, we only care about order
        target = (target-min(target))/(max(target) -min(target))
        target = torch.tensor(target)
        return config_feat,node_feat,node_opcode,edge_index,target

In [6]:
class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_channels, graph_feats, hidden_dim):
        super().__init__()
        op_embedding_dim = 4 # I choose 4-dimensional embedding
        self.embedding = torch.nn.Embedding(120, #120 different op-codes
                                            op_embedding_dim,
                                           )
        assert len(hidden_channels)>0
        in_channels = op_embedding_dim+140
        self.convs = torch.nn.ModuleList()
        last_dim = hidden_channels[-1]
        self.convs.append(GCNConv(in_channels, hidden_channels[0]))
        for i in range(len(hidden_channels)-1):
            self.convs.append(GCNConv(hidden_channels[i], hidden_channels[i+1]))
        self.convs.append(GCNConv(last_dim, graph_feats))
        
        self.dense = torch.nn.Sequential(nn.Linear(graph_feats+24, 64),
                                         nn.ReLU(),
                                         nn.Linear(64, 64),
                                         nn.ReLU(),
#                                          nn.Linear(64, 32),
#                                          nn.ReLU(),
#                                          nn.Linear(32, 64),
#                                          nn.ReLU(),
                                         nn.Linear(64, 1),
                                        )

        self.norms = torch.nn.ModuleList()
        for i in range(len(hidden_channels)):
            self.norms.append(torch.nn.BatchNorm1d(hidden_channels[i]))
        self.norms.append(torch.nn.BatchNorm1d(graph_feats))

    def forward(self, x_cfg: Tensor,x_feat: Tensor, x_op: Tensor, edge_index: Tensor) -> Tensor:
        
        #get graph features
        x = torch.concat([x_feat,self.embedding(x_op)],dim = 1)
        #pass though conv layers
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index).relu()
#             print(x.shape)
#             if i == 0:
#                 x = conv(x, edge_index).relu()
#             else:
#                 x = conv(x, edge_index).relu() + conv(x, edge_index).relu()
            x = self.norms[i](x)
#             print(x.shape)
        # get 1d graph embedding using average pooling
        x_graph = torch.mean(x,0)
        
        
        #put graph data into config data
        x = torch.concat([x_cfg,x_graph.repeat((len(x_cfg),1))],axis=1)
        #put into dense nn
        x = torch.flatten(self.dense(x))
        return x

model = SimpleModel(hidden_channels = [16,32,16,48],graph_feats = 64,hidden_dim=64).to(device)

# Train few Epoches

In [7]:
dataset = TileDataset(tile_xla["train"])

# criterion = torch.nn.SmoothL1Loss()
criterion = torch.nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4,weight_decay = 0.01)

model.train()
loss_sum = 0
n = 0
epoch_num = 30
for now_epoch in range(epoch_num):
    pbar = tqdm(range(len(dataset)))
    print('--------------epoch {}: ------------------'.format(now_epoch))
    for i in pbar:
        cfg_ft,nd_ft,nd_op,ind,target = dataset[i]
        cfg_ft,nd_ft,nd_op,ind,target = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device),target.to(device)

        out = model(cfg_ft,nd_ft,nd_op,ind)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
        optimizer.step()

        loss_sum+=loss.item()
        n+=1
        pbar.set_description(f'running loss: {(loss_sum/n):.6f},current loss: {(loss.item()):.6f}')

  0%|          | 0/5709 [00:00<?, ?it/s]

--------------epoch 0: ------------------


running loss: 0.027577,current loss: 0.058381: 100%|██████████| 5709/5709 [01:05<00:00, 87.24it/s]
running loss: 0.027512,current loss: 0.002847:   0%|          | 10/5709 [00:00<01:02, 91.43it/s]

--------------epoch 1: ------------------


running loss: 0.018582,current loss: 0.058407: 100%|██████████| 5709/5709 [01:03<00:00, 90.28it/s]
running loss: 0.018564,current loss: 0.002843:   0%|          | 10/5709 [00:00<01:01, 92.16it/s]

--------------epoch 2: ------------------


running loss: 0.015584,current loss: 0.058405: 100%|██████████| 5709/5709 [01:04<00:00, 88.70it/s]
running loss: 0.015577,current loss: 0.014617:   0%|          | 8/5709 [00:00<01:15, 75.60it/s]

--------------epoch 3: ------------------


running loss: 0.014085,current loss: 0.058412: 100%|██████████| 5709/5709 [01:04<00:00, 87.90it/s]
running loss: 0.014079,current loss: 0.003080:   0%|          | 10/5709 [00:00<01:01, 92.56it/s]

--------------epoch 4: ------------------


running loss: 0.013185,current loss: 0.058432: 100%|██████████| 5709/5709 [01:05<00:00, 86.82it/s]
running loss: 0.013181,current loss: 0.002870:   0%|          | 9/5709 [00:00<01:04, 88.91it/s]

--------------epoch 5: ------------------


running loss: 0.012584,current loss: 0.058417: 100%|██████████| 5709/5709 [01:04<00:00, 88.53it/s]
running loss: 0.012582,current loss: 0.007027:   0%|          | 7/5709 [00:00<01:22, 69.25it/s]

--------------epoch 6: ------------------


running loss: 0.012155,current loss: 0.058426: 100%|██████████| 5709/5709 [01:04<00:00, 88.09it/s]
running loss: 0.012153,current loss: 0.003075:   0%|          | 10/5709 [00:00<00:58, 96.70it/s]

--------------epoch 7: ------------------


running loss: 0.011833,current loss: 0.058429: 100%|██████████| 5709/5709 [01:05<00:00, 86.92it/s]
running loss: 0.011831,current loss: 0.003077:   0%|          | 10/5709 [00:00<00:58, 97.47it/s]

--------------epoch 8: ------------------


running loss: 0.011583,current loss: 0.058424: 100%|██████████| 5709/5709 [01:06<00:00, 85.65it/s]
running loss: 0.011581,current loss: 0.003075:   0%|          | 10/5709 [00:00<01:00, 94.36it/s]

--------------epoch 9: ------------------


running loss: 0.011383,current loss: 0.058421: 100%|██████████| 5709/5709 [01:04<00:00, 87.93it/s]
running loss: 0.011381,current loss: 0.002864:   0%|          | 10/5709 [00:00<01:02, 91.19it/s]

--------------epoch 10: ------------------


running loss: 0.011219,current loss: 0.058417: 100%|██████████| 5709/5709 [01:05<00:00, 87.30it/s]
running loss: 0.011217,current loss: 0.003076:   0%|          | 10/5709 [00:00<01:00, 94.44it/s]

--------------epoch 11: ------------------


running loss: 0.011082,current loss: 0.058415: 100%|██████████| 5709/5709 [01:05<00:00, 87.16it/s]
running loss: 0.011081,current loss: 0.002874:   0%|          | 9/5709 [00:00<01:03, 89.36it/s]

--------------epoch 12: ------------------


running loss: 0.010966,current loss: 0.058430: 100%|██████████| 5709/5709 [01:05<00:00, 87.35it/s]
running loss: 0.010965,current loss: 0.012386:   0%|          | 9/5709 [00:00<01:09, 81.97it/s]

--------------epoch 13: ------------------


running loss: 0.010867,current loss: 0.058429: 100%|██████████| 5709/5709 [01:05<00:00, 86.73it/s]
running loss: 0.010866,current loss: 0.003076:   0%|          | 10/5709 [00:00<00:59, 95.49it/s]

--------------epoch 14: ------------------


running loss: 0.010781,current loss: 0.058421: 100%|██████████| 5709/5709 [01:05<00:00, 86.97it/s]
running loss: 0.010780,current loss: 0.003076:   0%|          | 10/5709 [00:00<01:01, 93.12it/s]

--------------epoch 15: ------------------


running loss: 0.010706,current loss: 0.058439: 100%|██████████| 5709/5709 [01:06<00:00, 86.14it/s]
running loss: 0.010705,current loss: 0.002889:   0%|          | 10/5709 [00:00<01:01, 92.25it/s]

--------------epoch 16: ------------------


running loss: 0.010640,current loss: 0.058439: 100%|██████████| 5709/5709 [01:06<00:00, 85.89it/s]
running loss: 0.010639,current loss: 0.002888:   0%|          | 10/5709 [00:00<01:01, 92.73it/s]

--------------epoch 17: ------------------


running loss: 0.010580,current loss: 0.058423: 100%|██████████| 5709/5709 [01:06<00:00, 86.10it/s]
running loss: 0.010580,current loss: 0.001352:   0%|          | 7/5709 [00:00<01:27, 65.26it/s]

--------------epoch 18: ------------------


running loss: 0.010528,current loss: 0.058434: 100%|██████████| 5709/5709 [01:06<00:00, 85.52it/s]
running loss: 0.010527,current loss: 0.007059:   0%|          | 8/5709 [00:00<01:18, 72.51it/s]

--------------epoch 19: ------------------


running loss: 0.010480,current loss: 0.058407: 100%|██████████| 5709/5709 [01:04<00:00, 89.01it/s]
running loss: 0.010480,current loss: 0.001342:   0%|          | 7/5709 [00:00<01:24, 67.55it/s]

--------------epoch 20: ------------------


running loss: 0.010437,current loss: 0.058426: 100%|██████████| 5709/5709 [01:06<00:00, 86.02it/s]
running loss: 0.010436,current loss: 0.003074:   0%|          | 10/5709 [00:00<01:00, 93.69it/s]

--------------epoch 21: ------------------


running loss: 0.010398,current loss: 0.058433: 100%|██████████| 5709/5709 [01:03<00:00, 90.56it/s]
running loss: 0.010397,current loss: 0.003073:   0%|          | 10/5709 [00:00<01:00, 94.56it/s]

--------------epoch 22: ------------------


running loss: 0.010362,current loss: 0.058429: 100%|██████████| 5709/5709 [01:07<00:00, 84.71it/s]
running loss: 0.010362,current loss: 0.003074:   0%|          | 10/5709 [00:00<00:59, 95.55it/s]

--------------epoch 23: ------------------


running loss: 0.010330,current loss: 0.058418: 100%|██████████| 5709/5709 [01:02<00:00, 90.85it/s]
running loss: 0.010329,current loss: 0.003075:   0%|          | 10/5709 [00:00<00:59, 96.50it/s]

--------------epoch 24: ------------------


running loss: 0.010300,current loss: 0.058429: 100%|██████████| 5709/5709 [01:08<00:00, 83.84it/s]
running loss: 0.010299,current loss: 0.003074:   0%|          | 10/5709 [00:00<00:59, 96.31it/s]

--------------epoch 25: ------------------


running loss: 0.010272,current loss: 0.058436: 100%|██████████| 5709/5709 [01:02<00:00, 91.09it/s]
running loss: 0.010271,current loss: 0.003073:   0%|          | 10/5709 [00:00<00:59, 95.24it/s]

--------------epoch 26: ------------------


running loss: 0.010246,current loss: 0.058419: 100%|██████████| 5709/5709 [01:08<00:00, 83.23it/s]
running loss: 0.010246,current loss: 0.003075:   0%|          | 10/5709 [00:00<01:00, 94.54it/s]

--------------epoch 27: ------------------


running loss: 0.010222,current loss: 0.058422: 100%|██████████| 5709/5709 [01:02<00:00, 90.72it/s]
running loss: 0.010222,current loss: 0.002883:   0%|          | 9/5709 [00:00<01:03, 89.42it/s]

--------------epoch 28: ------------------


running loss: 0.010200,current loss: 0.058429: 100%|██████████| 5709/5709 [01:08<00:00, 83.52it/s]
running loss: 0.010200,current loss: 0.003074:   0%|          | 10/5709 [00:00<01:01, 91.96it/s]

--------------epoch 29: ------------------


running loss: 0.010179,current loss: 0.058427: 100%|██████████| 5709/5709 [01:03<00:00, 90.57it/s]


# Evaluate on Validation Dataset

In [8]:
dataset = TileDataset(tile_xla["valid"])
tile_xla_predictions = []
model.eval()

pbar = tqdm(range(len(dataset)))
for i in pbar:
    cfg_ft,nd_ft,nd_op,ind,target = dataset[i]
    cfg_ft,nd_ft,nd_op,ind,target = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device),target.to(device)
    
    out = model(cfg_ft,nd_ft,nd_op,ind)
    tile_xla_predictions.append(np.argsort(out.detach().cpu().numpy())[:5])

def score_tile(predictions, df):
    score = 0
    for i in range(len(df)):
        predbest = min(df.iloc[i]['config_runtime'][predictions[i]])
        best = min(df.iloc[i]['config_runtime'])
        score +=2 - predbest/best
    score /= len(df)
    return score
score_tile(tile_xla_predictions, tile_xla["valid"])

100%|██████████| 676/676 [00:03<00:00, 178.72it/s]


0.5168656886202452

**0.31 is not bad considering that this model only trained on 1 epoch and is not on a ranking loss!**

# Predict and Submit (only tile:xla predictions)

In [9]:
dataset = TileDataset(tile_xla["test"])
tile_xla_predictions = []
model.eval()
pbar = tqdm(range(len(dataset)))
for i in pbar:
    cfg_ft,nd_ft,nd_op,ind,target = dataset[i]
    cfg_ft,nd_ft,nd_op,ind,target = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device),target.to(device)
    
    out = model(cfg_ft,nd_ft,nd_op,ind)
    tile_xla_predictions.append(np.argsort(out.detach().cpu().numpy())[:5])

  target = (target-min(target))/(max(target) -min(target))
100%|██████████| 844/844 [00:04<00:00, 178.08it/s]


In [10]:
sub = pd.read_csv('/kaggle/input/predict-ai-model-runtime/sample_submission.csv')
for i,filename in enumerate(tile_xla["test"]['file'].values):
    id = 'tile:xla:' +filename[:-4]
    sub.loc[sub.ID == id,'TopConfigs'] = ';'.join(tile_xla_predictions[i].astype(str))
sub.to_csv('submission.csv',index=False)
sub

Unnamed: 0,ID,TopConfigs
0,tile:xla:d6f5f54247bd1e58a10b9e7062c636ab,0;22;21;20;19
1,tile:xla:e3a655daa38e34ec240df959b650ac16,252;1016;99;618;1037
2,tile:xla:f8c2c1a1098b2a361c26df668b286c87,41;116;101;166;202
3,tile:xla:4dd1716853ed46ee4e7d09ede1732de8,8766;1946;1474;8565;6580
4,tile:xla:d0a69155b6340748c36724e4bfc34be3,655;624;151;159;215
...,...,...
889,layout:nlp:random:60880ed76de53f4d7a1b960b24f2...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
890,layout:nlp:random:23559853d9702baaaacbb0c83fd3...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
891,layout:nlp:random:f6c146fc5cf10be4f3accbaca989...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
892,layout:nlp:random:32531d07a084b319dce484f53a4c...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
