In [6]:
import os
from pathlib import Path

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

In [3]:
# Dataset root. Set the GFOS_DATA_ROOT environment variable to point at a
# different location instead of editing the hardcoded default below.
DATA_ROOT = Path(os.environ.get("GFOS_DATA_ROOT", r"H:\data\gfos\predict-ai-model-runtime"))

# Input: original layout graphs, laid out as <source>/<search>/<split>/<model_id>.npz
original_layout_root = DATA_ROOT / "npz_all" / "npz" / "layout"
# Input: replacement node features, laid out as <source>/<split>/<model_id>.npy
new_feat_root = DATA_ROOT / "new_node_feat"

# Output: same tree as original_layout_root, with node_feat replaced.
output_root = DATA_ROOT / "npz_all" / "npz" / "layout_new"

output_root.mkdir(parents=True, exist_ok=True)

In [4]:
def merge_new_feat(
    original_layout_root: Path,
    new_feat_root: Path,
    output_root: Path,
    source: str,
    search: str,
    split: str,
    model_id: str,
) -> None:
    """Replace ``node_feat`` in one layout ``.npz`` with a precomputed feature matrix.

    Reads ``<original_layout_root>/<source>/<search>/<split>/<model_id>.npz`` and
    ``<new_feat_root>/<source>/<split>/<model_id>.npy``, swaps the ``node_feat``
    array for the loaded ``.npy`` array, and writes all arrays to
    ``<output_root>/<source>/<search>/<split>/<model_id>.npz`` with
    ``np.savez_compressed``. Every array other than ``node_feat`` is copied as-is.

    The new-feature tree has no ``search`` level, so the same ``.npy`` file is
    used for every search space of a given source/split/model.

    Raises:
        ValueError: if the new feature array and the original ``node_feat``
            have a different number of rows.
    """
    original_file = original_layout_root / source / search / split / f"{model_id}.npz"
    new_feat_file = new_feat_root / source / split / f"{model_id}.npy"
    save_dir = output_root / source / search / split
    save_path = save_dir / f"{model_id}.npz"

    # np.load on an .npz returns an NpzFile that holds its file handle open
    # until closed; the context manager releases it once all arrays are read
    # into memory (important when thousands of files are processed).
    with np.load(original_file) as npz:
        data = dict(npz)
    new_feat = np.load(new_feat_file)

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    n_old, n_new = len(data["node_feat"]), len(new_feat)
    if n_old != n_new:
        raise ValueError(
            f"{original_file}: node_feat has {n_old} rows but "
            f"{new_feat_file} has {n_new} rows"
        )

    data["node_feat"] = new_feat
    # Create the output directory only after inputs validated, so a failure
    # does not leave empty directories behind.
    save_dir.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(save_path, **data)

In [7]:
# One positional-argument tuple per original layout file, in the order
# expected by merge_new_feat. The layout tree is
# <original_layout_root>/<source>/<search>/<split>/<model_id>.npz.
# Sorting makes the task order reproducible across filesystems, and the
# is_dir() filter skips stray files that would make a nested iterdir() crash.
params = []

for split_dir in sorted(original_layout_root.glob("*/*/*")):
    if not split_dir.is_dir():
        continue
    source, search, split = split_dir.relative_to(original_layout_root).parts
    params.extend(
        (
            original_layout_root,
            new_feat_root,
            output_root,
            source,
            search,
            split,
            npz_file.stem,
        )
        for npz_file in sorted(split_dir.glob("*.npz"))
    )

# glob() returns nothing for a missing/misspelled root; fail loudly instead of
# silently running zero tasks.
if not params:
    raise FileNotFoundError(f"no layout .npz files found under {original_layout_root}")


In [None]:
# Number of joblib workers; each task merges one model file.
N_JOBS = 8

# The progress bar wraps the task iterable, so it advances as tasks are handed
# to joblib rather than as they complete. Each task writes its own .npz, so
# the collected return values are unused.
merge_tasks = (delayed(merge_new_feat)(*task_args) for task_args in tqdm(params))
_ = Parallel(n_jobs=N_JOBS)(merge_tasks)

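## Check consistency

Spot-check the first `NUM_CHECK` merged files of every source/search/split: each array other than `node_feat` must be unchanged, and the original `node_feat` must reappear as the first 134 and last 6 columns of the new `node_feat`.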

In [20]:
NUM_CHECK = 5  # merged files spot-checked per source/search/split

# The original node_feat columns must reappear unchanged at both ends of the
# merged node_feat: its first ORIG_HEAD_COLS and its last ORIG_TAIL_COLS columns.
ORIG_HEAD_COLS = 134
ORIG_TAIL_COLS = 6


def _original_columns(merged_feat: np.ndarray) -> np.ndarray:
    """Return the columns of a merged node_feat that should equal the original node_feat."""
    n_cols = merged_feat.shape[1]
    # Explicit start index instead of `-ORIG_TAIL_COLS:`, which would select
    # every column if ORIG_TAIL_COLS were 0.
    return np.concatenate(
        [merged_feat[:, :ORIG_HEAD_COLS], merged_feat[:, n_cols - ORIG_TAIL_COLS:]],
        axis=1,
    )


def check_merged_file(old_file: Path, new_file: Path) -> None:
    """Assert that ``new_file`` equals ``old_file`` except for the extended ``node_feat``.

    Both archives must hold the same array names; every array other than
    ``node_feat`` must be identical, and the original ``node_feat`` must be
    recoverable from the merged one via ``_original_columns``.

    Raises:
        AssertionError: on any mismatch (message names the file and key).
    """
    # Context managers close the NpzFile handles once the comparison is done.
    with np.load(old_file) as old_npz, np.load(new_file) as new_npz:
        assert set(old_npz.files) == set(new_npz.files), (
            f"{new_file}: keys {sorted(new_npz.files)} != {sorted(old_npz.files)}"
        )
        for key in old_npz.files:
            actual = new_npz[key]
            if key == "node_feat":
                actual = _original_columns(actual)
            np.testing.assert_array_equal(old_npz[key], actual, err_msg=f"{new_file}: {key}")


for split_dir in sorted(original_layout_root.glob("*/*/*")):
    if not split_dir.is_dir():
        continue
    source, search, split = split_dir.relative_to(original_layout_root).parts
    for old_file in sorted(split_dir.glob("*.npz"))[:NUM_CHECK]:
        check_merged_file(old_file, output_root / source / search / split / old_file.name)
    print(f"{source} {search} {split} checked")

nlp default test checked
nlp default train checked
nlp default valid checked
nlp random test checked
nlp random train checked
nlp random valid checked
xla default test checked
xla default train checked
xla default valid checked
xla random test checked
xla random train checked
xla random valid checked
