In [1]:
import os

# Project root. Every relative path in this notebook (data/..., models/...) is resolved
# against it. Override with the FAST_OR_SLOW_ROOT environment variable on other machines.
PROJECT_ROOT = os.environ.get("FAST_OR_SLOW_ROOT", "/home/matt/projects/kaggle-fast-or-slow")
# Kept before the `ml.*` imports below; presumably they are found via the working directory.
os.chdir(PROJECT_ROOT)

from ml.layout_v1.model import SAGEMLP
from ml.layout_v1.dataset import LayoutDataset
from ml.layout_v1.preprocessors import reduce_to_config_node_communities
import torch_geometric
import torch

import wandb
torch.set_float32_matmul_precision('high')

In [76]:
# The `default` and `random` XLA layout test directories to run inference on.
DATA_DIRS = ["data/layout/xla/default/test", "data/layout/xla/random/test"]
# NOTE(review): neither constant is referenced in the visible cells; confirm before removing.
INPUT_DIM = 261
GLOBAL_INPUT_DIM = 24

_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Previously also passed `processed_dir="data/processed_layout"`; that option is currently disabled.
dataset = LayoutDataset(
    data_transform=reduce_to_config_node_communities,
    mode="lazy",
    directories=DATA_DIRS,
)
dataset.load()


# |%%--%%| <ViXrO6qwi1|O6LNSSPtuo>
# W&B run whose logged config supplies the hyperparameters used to rebuild the model below.
RUN_ID = "1plqtaqr"
# Numeric suffix of the checkpoint file to evaluate (presumably the training step).
CHECKPOINT_STEP = 1995000
WANDB_RUN_ID = "kaggle-fast-or-slow/" + RUN_ID
WEIGHTS_PATH = f"models/{RUN_ID}/checkpoint_{CHECKPOINT_STEP}.pth"

api = wandb.Api()
run = api.run(WANDB_RUN_ID)
# Hyperparameters such as sage_layers / sage_channels / linear_* are read from here.
config = run.config


In [80]:
# Inspect one processed sample (index 9): shapes of the node-feature matrix `x` and `edge_index`.
dataset.get(9)

Data(x=[6411, 279], edge_index=[2, 4774], y=0)

In [60]:
# Per-node feature width fed to the model (processed samples show x=[num_nodes, 279]).
# NOTE(review): presumably must match the width the checkpoint was trained with.
GRAPH_DIM = 279

model = SAGEMLP(
    graph_input_dim=GRAPH_DIM,
    sage_layers=config["sage_layers"],
    sage_channels=config["sage_channels"],
    linear_channels=config["linear_channels"],
    linear_layers=config["linear_layers"],
    dropout=0,  # inference only
)
model = torch_geometric.compile(model)
# Use the device selected when the dataset was loaded instead of hard-coding CUDA,
# so the notebook also runs on CPU-only machines.
model = model.to(device)

In [61]:
# The file holds a checkpoint dict; the model weights are under "model_state_dict".
# map_location lets a GPU-saved checkpoint load on whatever `device` is available.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(WEIGHTS_PATH, map_location=device)
model.load_state_dict(state_dict["model_state_dict"])
model = model.to(device)

In [62]:
def make_id_from_file(filepath: str) -> str:
    """Convert a dataset file path into its submission ID.

    ``"data/layout/xla/default/test/<hash>.npz"`` becomes ``"layout:xla:default:<hash>"``.
    The leading ``data/`` and trailing ``.npz`` are stripped, directory components named
    exactly ``test`` are dropped, and the remaining components are joined with ``:``.

    Only whole ``test`` directory components are removed. A file name that merely starts
    with "test" (e.g. ``test_model.npz``) is kept intact instead of being truncated.
    """
    stem = filepath.removeprefix("data/").removesuffix(".npz")
    # split() always yields at least one element, so this unpacking cannot fail.
    *dirs, name = stem.split("/")
    return ":".join([d for d in dirs if d != "test"] + [name])

In [63]:
from collections import defaultdict
from tqdm.auto import tqdm
from torch_geometric.data import Batch


# results[file_id][config_idx] -> model output for that (file, config) pair.
results = defaultdict(dict)

batch_size = 64

# nn.Module starts in training mode and nothing above switches it off; put the model in
# eval mode so train/eval-dependent layers behave as at inference time.
model.eval()

# Accumulate (graph, file_id, config_idx) triples and run the model once per batch.
next_batch = []
for i in tqdm(range(len(dataset))):
    file_path, config_idx = dataset.idx_to_config[i]
    file_id = make_id_from_file(file_path)
    data = dataset.get(i)
    next_batch.append((data, file_id, config_idx))

    # Flush on a full batch, or on the final (possibly partial) batch.
    if len(next_batch) == batch_size or i == len(dataset) - 1:
        batch_data = [d[0] for d in next_batch]
        file_ids = [d[1] for d in next_batch]
        config_ids = [d[2] for d in next_batch]

        with torch.no_grad():
            batch = Batch.from_data_list(batch_data).to(device)
            output = model(batch).flatten()

        for o, f, c in zip(output.tolist(), file_ids, config_ids):
            results[f][c] = o

        next_batch = []

# Each dataset item must fill its own (file_id, config_idx) slot; a shortfall means two
# items mapped to the same ID/config and one prediction was overwritten.
assert sum(len(v) for v in results.values()) == len(dataset)
            
    
    

  0%|          | 0/16001 [00:00<?, ?it/s]

In [72]:
# Inspect the first processed sample: shapes of the node-feature matrix `x` and `edge_index`.
dataset.get(0)

Data(x=[6411, 279], edge_index=[2, 4774], y=0)

In [64]:
import numpy as np


def rank_configs(config_scores: dict) -> str:
    """Order one file's config indices from lowest to highest model output.

    Args:
        config_scores: mapping of config index -> model output for a single file.

    Returns:
        The config indices joined with ";" (the ``TopConfigs`` format).

    The argsort positions are mapped back through the dict's keys, so the result is
    correct even if configs were not inserted in index order. A stable sort keeps
    tied scores in a deterministic order.
    """
    config_ids = list(config_scores.keys())
    scores = np.fromiter(config_scores.values(), dtype=np.float64, count=len(config_ids))
    order = np.argsort(scores, kind="stable")
    return ";".join(str(int(config_ids[j])) for j in order)


# file_id -> ";"-joined ranking, lowest predicted value first (presumably fastest config).
processed_results = {file_id: rank_configs(scores) for file_id, scores in results.items()}

In [66]:
import pandas as pd

# Existing submission that the model's rankings are merged into further down.
BASE_FILE = "data/layout_nlp_w_tile.csv"
base_submissions = pd.read_csv(filepath_or_buffer=BASE_FILE)

In [67]:
# Preview the base submission (ID, TopConfigs) that model predictions are merged into.
base_submissions.head()

Unnamed: 0,ID,TopConfigs
0,tile:xla:d6f5f54247bd1e58a10b9e7062c636ab,0;22;21;20;19
1,tile:xla:e3a655daa38e34ec240df959b650ac16,1016;252;99;1037;807
2,tile:xla:f8c2c1a1098b2a361c26df668b286c87,112;49;166;121;15
3,tile:xla:4dd1716853ed46ee4e7d09ede1732de8,5487;5660;7723;1906;7311
4,tile:xla:d0a69155b6340748c36724e4bfc34be3,576;554;215;236;624


In [68]:
# One row per predicted file: its submission ID and ";"-joined config ranking. Naming the
# columns here (rather than relying on a later rename) also works when no results exist.
processed_df = pd.DataFrame(list(processed_results.items()), columns=["ID", "TopConfigs"])

In [69]:
# Label the columns to match the submission format, then preview.
processed_df = processed_df.set_axis(["ID", "TopConfigs"], axis="columns")
processed_df.head()

Unnamed: 0,ID,TopConfigs
0,layout:xla:default:fbaa8bb6a1aed9988281085c910...,380;427;460;443;653;359;912;418;684;762;103;39...
1,layout:xla:default:cd708819d3f5103afd6460b15e7...,430;604;601;4;503;739;455;463;115;911;979;509;...
2,layout:xla:default:937ee0eb0d5d6151b7b8252933b...,425;605;29;426;126;789;271;681;347;221;573;721...
3,layout:xla:default:5335ed13823b0a518ee3c79ba44...,688;191;415;881;959;240;760;652;249;412;620;97...
4,layout:xla:default:05ae41e26dd3c4c06390371a042...,969;253;127;184;300;866;95;129;905;939;317;442...


In [70]:
OUTPUT_FILENAME = "xla_oos.csv"

# Every predicted ID must exist in the base submission. Otherwise the left merge below
# silently drops that prediction and the base ranking is written in its place.
missing_ids = set(processed_df["ID"]) - set(base_submissions["ID"])
assert not missing_ids, (
    f"{len(missing_ids)} predicted IDs are absent from {BASE_FILE}, e.g. {sorted(missing_ids)[:3]}"
)

# validate="one_to_one" fails fast on duplicated IDs instead of duplicating submission rows.
joined = base_submissions.merge(processed_df, on="ID", how="left", validate="one_to_one")

# Prefer the model's ranking where one exists; otherwise keep the base submission's.
joined["TopConfigs"] = joined["TopConfigs_y"].fillna(joined["TopConfigs_x"])
joined = joined.drop(columns=["TopConfigs_x", "TopConfigs_y"])

joined.to_csv(f"data/{OUTPUT_FILENAME}", index=False)


In [71]:
# Display the merged submission: the same frame that was written to data/{OUTPUT_FILENAME}.
joined

Unnamed: 0,ID,TopConfigs
0,tile:xla:d6f5f54247bd1e58a10b9e7062c636ab,0;22;21;20;19
1,tile:xla:e3a655daa38e34ec240df959b650ac16,1016;252;99;1037;807
2,tile:xla:f8c2c1a1098b2a361c26df668b286c87,112;49;166;121;15
3,tile:xla:4dd1716853ed46ee4e7d09ede1732de8,5487;5660;7723;1906;7311
4,tile:xla:d0a69155b6340748c36724e4bfc34be3,576;554;215;236;624
...,...,...
889,layout:nlp:random:60880ed76de53f4d7a1b960b24f2...,645;356;320;270;219;742;836;623;703;301;143;45...
890,layout:nlp:random:23559853d9702baaaacbb0c83fd3...,703;204;214;924;350;215;377;51;126;297;624;436...
891,layout:nlp:random:f6c146fc5cf10be4f3accbaca989...,749;158;352;634;487;846;926;911;726;908;575;89...
892,layout:nlp:random:32531d07a084b319dce484f53a4c...,907;853;26;691;939;176;786;22;325;977;18;724;4...


In [82]:
# Spot-check the raw per-config model outputs (config_idx -> score) for one layout file.
results["layout:xla:default:fbaa8bb6a1aed9988281085c91065c05"]

{0: 1498.820556640625,
 1: 1499.6942138671875,
 2: 1500.4295654296875,
 3: 1461.2281494140625,
 4: 1505.1829833984375,
 5: 1499.2025146484375,
 6: 1499.6873779296875,
 7: 1500.2806396484375,
 8: 1499.5762939453125,
 9: 1461.1785888671875,
 10: 1498.8541259765625,
 11: 1499.44580078125,
 12: 1460.783935546875,
 13: 1499.640380859375,
 14: 1505.178466796875,
 15: 1499.44873046875,
 16: 1499.1961669921875,
 17: 1499.583984375,
 18: 1500.3450927734375,
 19: 1460.81298828125,
 20: 1498.9427490234375,
 21: 1499.2239990234375,
 22: 1499.53271484375,
 23: 1505.08984375,
 24: 1499.4168701171875,
 25: 1505.31103515625,
 26: 1499.7974853515625,
 27: 1498.839599609375,
 28: 1499.5762939453125,
 29: 1498.8199462890625,
 30: 1499.3616943359375,
 31: 1499.021240234375,
 32: 1460.983154296875,
 33: 1499.399169921875,
 34: 1460.98876953125,
 35: 1498.794677734375,
 36: 1505.2801513671875,
 37: 1499.626953125,
 38: 1461.1456298828125,
 39: 1460.487548828125,
 40: 1499.4849853515625,
 41: 1499.5860595703