In [1]:
import json
import os
from pathlib import Path

from gfos.data.utils import load_layout

In [2]:
# Directory holding the layout .npz files. Set GFOS_LAYOUT_DIR to point at a local copy;
# the fallback is the original author's machine-specific path.
LAYOUT_DIR = os.environ.get(
    "GFOS_LAYOUT_DIR",
    r"H:\data\gfos\predict-ai-model-runtime\npz_all\npz\layout",
)


In [3]:
# Load the XLA / "default" compile-type layout file lists, keyed by split.
layout_files = load_layout(
    LAYOUT_DIR,
    model_type="xla",
    compile_type="default",
)
# Pool the "train" and "valid" splits so models from both get a label below.
train_valid_files = layout_files["train"] + layout_files["valid"]

In [4]:
# Model name = file name without directory and final suffix. Path.stem strips only
# the last suffix, so dotted names such as "bert_classifier.2x2.fp32" stay intact.
model_cls = []
for layout_file in train_valid_files:
    model_cls.append(Path(layout_file).stem)

In [5]:
# Inspect the extracted model names (shown as the cell output).
model_cls

['alexnet_train_batch_32',
 'bert_classifier.2x2.fp32',
 'bert_classifier.2x2.fp32.performance',
 'bert_pretraining.2x2.fp16',
 'bert_pretraining.8x16.fp16',
 'bert_pretraining.8x8.fp32.performance',
 'bert_squad.2x2.fp32',
 'brax_es',
 'efficientnet_b7_eval_batch_1',
 'inception_v2_batch_128_train',
 'inception_v2_batch_8_train',
 'inception_v3_batch_8_train',
 'inference_mlperf_resnet_batch_16',
 'inference_mlperf_resnet_batch_256',
 'inference_mlperf_ssd_1200_batch_1',
 'inference_mlperf_ssd_1200_batch_128',
 'inference_mlperf_ssd_1200_batch_2',
 'magenta',
 'magenta_dynamic',
 'mask_rcnn_batch_16_bf16_img1024',
 'mask_rcnn_batch_4_bf16_img1408',
 'mask_rcnn_resnet50.4x4.bf16.performance',
 'mlperf_nmt_batch_64',
 'mlperf_resnet',
 'mlperf_resnet_batch_128_1_shard',
 'mlperf_ssd_1_shard_batch_8_fast_epoch',
 'mlperf_ssd_2_shard_batch_8_fast_epoch',
 'mlperf_transformer',
 'mnasnet_a1_batch_128',
 'mnasnet_b1_batch_128',
 'ncf.2x2.fp32',
 'resnet50.2x2.fp16',
 'resnet50.2x2.fp32',
 '

In [6]:
# Keyword substrings that identify each architecture family inside a model name.
# Key order matters for the classification cell: when a name matches keywords of
# several families, the later family wins (e.g. mask_rcnn_resnet50 -> "rcnn", not "cnn").
# NOTE(review): alexnet_* and mnasnet_* names match none of these keywords and end up
# as "other" — confirm that is intended.
model_dict = {
    "cnn": (
        "ssd",
        "unet",
        "resnet",
        "inception",
        "xception",
        "efficientnet",
        "retinanet",
    ),
    "rcnn": ("shapemask", "mask_rcnn"),
    "transformer": ("transformer", "bert"),
}

In [7]:
# Assign every model name to a coarse architecture family by keyword substring match.
# If a name matches keywords from several families, the family that comes LAST in
# `model_dict` wins: "mask_rcnn_resnet50..." matches "resnet" (cnn) and "mask_rcnn"
# (rcnn) and is labelled "rcnn". Names matching no keyword are labelled "other".
model_type = {}
for model_name in model_cls:
    matched_families = [
        family
        for family, keywords in model_dict.items()
        if any(keyword in model_name for keyword in keywords)
    ]
    model_type[model_name] = matched_families[-1] if matched_families else "other"
                

In [8]:
# Inspect the family assigned to each model name (shown as the cell output).
model_type

{'alexnet_train_batch_32': 'other',
 'bert_classifier.2x2.fp32': 'transformer',
 'bert_classifier.2x2.fp32.performance': 'transformer',
 'bert_pretraining.2x2.fp16': 'transformer',
 'bert_pretraining.8x16.fp16': 'transformer',
 'bert_pretraining.8x8.fp32.performance': 'transformer',
 'bert_squad.2x2.fp32': 'transformer',
 'brax_es': 'other',
 'efficientnet_b7_eval_batch_1': 'cnn',
 'inception_v2_batch_128_train': 'cnn',
 'inception_v2_batch_8_train': 'cnn',
 'inception_v3_batch_8_train': 'cnn',
 'inference_mlperf_resnet_batch_16': 'cnn',
 'inference_mlperf_resnet_batch_256': 'cnn',
 'inference_mlperf_ssd_1200_batch_1': 'cnn',
 'inference_mlperf_ssd_1200_batch_128': 'cnn',
 'inference_mlperf_ssd_1200_batch_2': 'cnn',
 'magenta': 'other',
 'magenta_dynamic': 'other',
 'mask_rcnn_batch_16_bf16_img1024': 'rcnn',
 'mask_rcnn_batch_4_bf16_img1408': 'rcnn',
 'mask_rcnn_resnet50.4x4.bf16.performance': 'rcnn',
 'mlperf_nmt_batch_64': 'other',
 'mlperf_resnet': 'cnn',
 'mlperf_resnet_batch_128_1

In [9]:
# Integer class id per family; ids follow the listed order (cnn=0, ..., other=3).
mapping = {
    family: class_id
    for class_id, family in enumerate(("cnn", "rcnn", "transformer", "other"))
}

In [10]:
# Translate each model's family name into its integer class id.
model_labels = {}
for model_name, family in model_type.items():
    model_labels[model_name] = mapping[family]
model_labels

{'alexnet_train_batch_32': 3,
 'bert_classifier.2x2.fp32': 2,
 'bert_classifier.2x2.fp32.performance': 2,
 'bert_pretraining.2x2.fp16': 2,
 'bert_pretraining.8x16.fp16': 2,
 'bert_pretraining.8x8.fp32.performance': 2,
 'bert_squad.2x2.fp32': 2,
 'brax_es': 3,
 'efficientnet_b7_eval_batch_1': 0,
 'inception_v2_batch_128_train': 0,
 'inception_v2_batch_8_train': 0,
 'inception_v3_batch_8_train': 0,
 'inference_mlperf_resnet_batch_16': 0,
 'inference_mlperf_resnet_batch_256': 0,
 'inference_mlperf_ssd_1200_batch_1': 0,
 'inference_mlperf_ssd_1200_batch_128': 0,
 'inference_mlperf_ssd_1200_batch_2': 0,
 'magenta': 3,
 'magenta_dynamic': 3,
 'mask_rcnn_batch_16_bf16_img1024': 1,
 'mask_rcnn_batch_4_bf16_img1408': 1,
 'mask_rcnn_resnet50.4x4.bf16.performance': 1,
 'mlperf_nmt_batch_64': 3,
 'mlperf_resnet': 0,
 'mlperf_resnet_batch_128_1_shard': 0,
 'mlperf_ssd_1_shard_batch_8_fast_epoch': 0,
 'mlperf_ssd_2_shard_batch_8_fast_epoch': 0,
 'mlperf_transformer': 2,
 'mnasnet_a1_batch_128': 3,
 

In [11]:
# Persist the model-name -> class-id labels as JSON for use outside this notebook.
# NOTE: the path is relative to the notebook's working directory; the parent directory
# is created if missing so a fresh checkout does not fail with FileNotFoundError.
labels_path = Path("../../data/xla_model_labels.json")
labels_path.parent.mkdir(parents=True, exist_ok=True)
with labels_path.open("w", encoding="utf-8") as f:
    json.dump(model_labels, f, indent=2)