In [1]:
import os
from datetime import datetime
from typing import Literal

import torch
import wandb
import numpy as np
from torch import nn
from torch import optim
import torch_geometric.nn as geonn
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from gfos.loss import MultiElementRankLoss
from gfos.utils.misc import seed_everything


SEED = 42
DEBUG = False

seed_everything(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

os.environ["WANDB_NOTEBOOK_NAME"] = "tile_gnn_batch.ipynb"


## Configs

In [2]:
TILE_DIR = "../../data/npz_all/npz/tile/xla/"

configs = dict(
    conv_layer="GATConv",
    num_epochs=500,
    batch_size=128,
    graph_dim=128,
    config_dim=128,
    learning_rate=1e-3,
    weight_decay=1e-6,
    min_lr=1e-7,
    max_configs=-1,
    num_configs=512,
    loss_margin=0.5,
    loss_num_permutations=10,
)

WANDB_PROJECT = "gfos"
WANDB_DIR = "../../logs/"
WANDB_RUN_NAME = "tile_xla_mle_batch"
TAGS = ["train", "tile", "xla"]

NUM_VAL_EPOCHS = 10
INFERENCE_CONFIGS_BATCH_SIZE = 1000


## Model

In [3]:
class LayoutModel(torch.nn.Module):
    def __init__(
        self,
        conv_layer: Literal["GATConv", "GCNConv", "SAGEConv"],
        graph_dim: int = 64,
        config_dim: int = 64,
        num_node_feats: int = 140,
        num_config_feats: int = 24,
    ):
        super().__init__()

        conv_layer = getattr(geonn, conv_layer)

        op_embedding_dim = 32
        merged_node_dim = graph_dim + config_dim

        self.embedding = torch.nn.Embedding(
            120,
            op_embedding_dim,
        )
        in_channels = op_embedding_dim + num_node_feats

        self.convs = geonn.Sequential('x, edge_index', [
            # (nn.Dropout(p=0.1), 'x -> x'),
            (conv_layer(in_channels, graph_dim), 'x, edge_index -> x1'),
            nn.LeakyReLU(inplace=True),
            (conv_layer(graph_dim, graph_dim), 'x1, edge_index -> x2'),
            # (lambda x1, x2: x1 + x2, 'x1, x2 -> x3'),
            nn.LeakyReLU(inplace=True),
            (conv_layer(graph_dim, graph_dim), 'x2, edge_index -> x4'),
            nn.LeakyReLU(inplace=True),
            (conv_layer(graph_dim, graph_dim), 'x4, edge_index -> x5'),
            # (lambda x4, x5: x4 + x5, 'x4, x5 -> x6'),
            nn.LeakyReLU(inplace=True),
            # (conv_layer(graph_dim, graph_dim), 'x5, edge_index -> x6'),
            # nn.LeakyReLU(inplace=True),
            # (conv_layer(graph_dim, graph_dim), 'x6, edge_index -> x7'),
            # nn.LeakyReLU(inplace=True),
            # (lambda x4, x5: x4 + x5, 'x5, x7 -> x8'),
        ])

        self.config_prj = nn.Sequential(
            nn.Linear(num_config_feats, config_dim),
            nn.LeakyReLU(),
        )

        # Define a sequential dense neural network
        self.dense = torch.nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(merged_node_dim, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(),
            nn.Linear(64, 1, bias=False),
        )

    def forward(
        self,
        node_feat: torch.Tensor,
        node_opcode: torch.Tensor,
        edge_index: torch.Tensor,
        config_feat: torch.Tensor,
        batch: torch.Tensor,
        config_batch: torch.Tensor,
    ) -> torch.Tensor:
        c = config_feat.size(0)

        x = torch.cat([node_feat, self.embedding(node_opcode)], dim=1)

        x = self.convs(x, edge_index)
        # x = x.mean(dim=0).repeat((c, 1))
        x = geonn.global_mean_pool(x, batch)
        x = x[config_batch]
                
        config_feat = self.config_prj(config_feat)

        x = torch.cat([x, config_feat], dim=-1)
        x = nn.functional.normalize(x, dim=-1)

        x = self.dense(x).flatten()
        
        # x = torch.sigmoid(x)

        return x

## Dataset

In [14]:
from pathlib import Path
from torch.utils.data import Dataset
from dataclasses import dataclass

@dataclass
class Normalizer:
    node_feat_mask: torch.Tensor
    node_feat_min: torch.Tensor
    node_feat_max: torch.Tensor
    node_config_feat_mask: torch.Tensor
    node_config_feat_min: torch.Tensor
    node_config_feat_max: torch.Tensor

    def normalize_node_feat(self, node_feat: torch.Tensor) -> torch.Tensor:
        assert node_feat.ndim == 2, "node_feat must be 2D"
        node_feat = node_feat[:, self.node_feat_mask]

        return (node_feat - self.node_feat_min) / (
            self.node_feat_max - self.node_feat_min
        )

    def normalize_config_feat(
        self, config_feat: torch.Tensor
    ) -> torch.Tensor:
        assert config_feat.ndim == 2, "node_config_feat must be 2D"
        config_feat = config_feat[:, self.node_config_feat_mask]
        return (config_feat - self.node_config_feat_min) / (
            self.node_config_feat_max - self.node_config_feat_min
        )

    @classmethod
    def from_dict(
        cls,
        configs: dict,
    ) -> "Normalizer":
    
        node_feat_mask = torch.tensor(
            configs["node_feat_mask"], dtype=torch.bool
        )
        node_feat_min = torch.tensor(
            configs["node_feat_min"], dtype=torch.float
        )[node_feat_mask]
        node_feat_max = torch.tensor(
            configs["node_feat_max"], dtype=torch.float
        )[node_feat_mask]
        config_feat_mask = torch.tensor(
            configs["config_feat_mask"], dtype=torch.bool
        )
        config_feat_min = torch.tensor(
            configs["config_feat_min"], dtype=torch.float
        )[config_feat_mask]
        config_feat_max = torch.tensor(
            configs["config_feat_max"], dtype=torch.float
        )[config_feat_mask]

        return Normalizer(
            node_feat_mask=node_feat_mask,
            node_feat_min=node_feat_min,
            node_feat_max=node_feat_max,
            node_config_feat_mask=config_feat_mask,
            node_config_feat_min=config_feat_min,
            node_config_feat_max=config_feat_max,
        )

    @classmethod
    def from_json(cls, path):
        import json

        json_data = json.load(open(path))
        return Normalizer.from_dict(json_data)


class TileDataset(Dataset):
    """Load all data in advance."""

    def __init__(
        self,
        files: list[str],
        max_configs: int = -1,
        num_configs: int = -1,
        normalizer: Normalizer = None,
        test: bool = False,  # TODO: remove
    ):
        self.max_configs = max_configs
        self.num_configs = num_configs
        self.files = files
        self.normalizer = normalizer

        self.data = []
        
        # self.files = [f for f in self.files if np.load(f)['config_runtime'].shape[0] > 10]

        for file in tqdm(self.files, desc="Loading data"):
            record = dict(np.load(file))
            model_id = Path(file).stem
            record["model_id"] = model_id
            runtime = record["config_runtime"]
            record["original_config_runtime"] = runtime.copy()

            config_runtime_normalizers = record["config_runtime_normalizers"]
            runtime = runtime / (config_runtime_normalizers + 1e-10)
            runtime = (runtime - runtime.mean()) / runtime.std()

            # TODO: invest when sampling results in better kendall scores on validation set
            # and scores after sampling are close LB scores
            # Currently, use sampling for both training and validation
            if not test:
                runtime_sampled, config_indices = sample_configs(
                    runtime, max_configs
                )
            else:
                runtime_sampled = runtime
                config_indices = torch.arange(len(runtime))

            record["config_runtime"] = runtime_sampled
            record["config_feat"] = record["config_feat"][
                config_indices
            ]
            record["original_config_runtime"] = record[
                "original_config_runtime"
            ][config_indices]

            self.data.append(record)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        record = self.data[idx]
        config_runtime = torch.tensor(
            record["config_runtime"], dtype=torch.float
        )

        if self.num_configs > 0:
            num_configs = self.num_configs
        elif self.max_configs > 0:
            num_configs = self.max_configs
        else:
            num_configs = config_runtime.size(0)

        # Shuffle
        if self.max_configs > 0 or self.num_configs > 0:
            config_indices = torch.randperm(config_runtime.size(0))[
                :num_configs
            ]
        else:
            config_indices = torch.arange(num_configs)
        config_runtime = config_runtime[config_indices]
        original_config_runtime = torch.tensor(
            record["original_config_runtime"], dtype=torch.float
        )[config_indices]

        model_id = record["model_id"]
        node_feat = torch.tensor(record["node_feat"], dtype=torch.float)
        node_opcode = torch.tensor(record["node_opcode"], dtype=torch.long)
        edge_index = torch.tensor(
            np.swapaxes(record["edge_index"], 0, 1), dtype=torch.long
        )

        config_feat = torch.tensor(
            record["config_feat"], dtype=torch.float
        )
        config_feat = config_feat[config_indices]

        if self.normalizer is not None:
            node_feat = self.normalizer.normalize_node_feat(node_feat)
            config_feat = self.normalizer.normalize_config_feat(
                config_feat
            )

        # sample = dict(
        #     model_id=model_id,
        #     node_feat=node_feat,
        #     node_opcode=node_opcode,
        #     edge_index=edge_index,
        #     config_feat=config_feat,
        #     config_runtime=config_runtime,
        #     original_config_runtime=original_config_runtime,
        # )
        
        data = Data(
            x=node_feat,
            edge_index=edge_index,
            y=config_runtime,
            config_feat=config_feat,
            original_config_runtime=original_config_runtime,
            model_id=model_id,
            node_opcode=node_opcode,
        )

        return data
    
    
def sample_configs(
    config_runtime: np.array, max_configs: int
) -> (np.array, np.array):
    """Sample 1/3 max_configs of best configs and 1/3 of worst configs,
    and the rest randomly. Return the sampled configs and indices.
    """
    c = len(config_runtime)
    max_configs = min(max_configs, c) if max_configs > 0 else c
    third = max_configs // 3

    sorted_indices = np.argsort(config_runtime)

    keep_indices = np.concatenate(
        [
            sorted_indices[:third],  # Good configs.
            sorted_indices[-third:],  # Bad configs.
            np.random.choice(
                sorted_indices[third:-third],
                max_configs - 2 * third,
            ),
        ]
    )

    return config_runtime[keep_indices], keep_indices


## Prepare for training

### Expand configs

In [15]:
conv_layer = configs["conv_layer"]
num_epochs = configs["num_epochs"]
batch_size = configs["batch_size"]
graph_dim = configs["graph_dim"]
config_dim = configs["config_dim"]
learning_rate = configs["learning_rate"]
weight_decay = configs["weight_decay"]
min_lr = configs["min_lr"]
max_configs = configs["max_configs"]
num_configs = configs["num_configs"]
margin = configs["loss_margin"]
number_permutations = configs["loss_num_permutations"]


### Model, loss, optimizer, scheduler

In [16]:
# Read data
tile_dir = Path(TILE_DIR)

train_files = list(tile_dir.glob("train/*.npz"))
valid_files = list(tile_dir.glob("valid/*.npz"))
test_files = list(tile_dir.glob("test/*.npz"))

normalizer = Normalizer.from_json(
    "../../data/tile_normalizers.json",
)

# train_dataset = TileDataset(train_files, max_configs=max_configs, num_configs=num_configs, normalizer=normalizer)
# valid_dataset = TileDataset(valid_files, normalizer=normalizer)
test_dataset = TileDataset(test_files, normalizer=normalizer, test=True)

Loading data:   0%|          | 0/844 [00:00<?, ?it/s]

  runtime = (runtime - runtime.mean()) / runtime.std()
Loading data: 100%|██████████| 844/844 [00:01<00:00, 518.99it/s]


In [17]:
num_node_feats = test_dataset[0].x.size(1)
num_config_feats = test_dataset[0].config_feat.size(1)

In [18]:
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, follow_batch=["config_feat"])

In [20]:
model = LayoutModel(
    conv_layer=conv_layer,
    graph_dim=64,
    config_dim=64,
    num_config_feats=num_config_feats,
    num_node_feats=num_node_feats,
).to(device)

criterion = MultiElementRankLoss(
    margin=margin, number_permutations=number_permutations
)


In [23]:
model_path = "/home/eden/projects/gfos/logs/wandb/run-20231012_220520-t018u75l/files/59_0.9342.pth"
static_dict = torch.load(model_path, map_location=device)
model.load_state_dict(static_dict)
model = model.eval()

## Training

In [24]:
def tile_topk(preds: torch.Tensor, targets: torch.Tensor):
    predbest_idx = torch.topk(preds, k=5, largest=False)[1]
    predbest = targets[predbest_idx].min()
    allbest = targets.min()
    
    return 2 - predbest / allbest

In [25]:
results = {}

with torch.no_grad():
    for data in tqdm(test_loader, desc="test"):
        model_id = data.model_id[0]
        node_feat = data.x
        node_opcode = data.node_opcode
        edge_index = data.edge_index
        config_feat = data.config_feat
        config_runtime = data.y
        batch = data.batch
        config_batch = data.config_feat_batch
        original_config_runtime = data.original_config_runtime

        (
            node_feat,
            node_opcode,
            edge_index,
            config_feat,
            batch,
            config_batch,
        ) = (
            node_feat.to(device),
            node_opcode.to(device),
            edge_index.to(device),
            config_feat.to(device),
            batch.to(device),
            config_batch.to(device),
        )

        num_configs = original_config_runtime.shape[-1]
        outs = []

        for i in range(0, num_configs, INFERENCE_CONFIGS_BATCH_SIZE):
            end_i = min(i + INFERENCE_CONFIGS_BATCH_SIZE, num_configs)
            out: torch.Tensor = model(
                node_feat,
                node_opcode,
                edge_index,
                config_feat[i:end_i],
                batch,
                config_batch[i:end_i],
            )
            outs.append(out.detach().cpu())

        outs = torch.concat(outs)

        pred_idx = np.argsort(outs.numpy())
        results[model_id] = pred_idx.tolist()

test: 100%|██████████| 844/844 [00:07<00:00, 108.14it/s]


In [26]:
with open(f"../../output/tile_xla_001.csv", "w") as f:
    f.write("ID,TopConfigs\n")
    for k, v in results.items():
        model_id = f"tile:xla:"+k
        values = ";".join([str(i) for i in v[:5]])
        f.write(f"{model_id},{values}\n")