In [None]:
from multiprocessing import Pool, cpu_count
from pathlib import Path

import pandas as pd
from datasets import load_dataset
from tqdm.auto import tqdm

from deep_mca.utils import disassemble_hex

In [None]:
dataset = load_dataset("stevenhe04/x86-bb-24m")
df = dataset["train"].to_pandas()

# Normalise the hex column (missing -> "", coerce to str, trim whitespace)
# and keep only the rows that still carry something to disassemble.
normalized_hex = df["hex"].fillna("").astype(str).str.strip()
df = (
    df.assign(hex=normalized_hex)
    .loc[lambda frame: frame["hex"] != ""]
    .reset_index(drop=True)
)

In [None]:
# Disassemble every hex-encoded block in parallel. `imap` yields results in
# input order, so row i of `results` corresponds to row i of `df`.
with Pool(cpu_count()) as pool:
    results = list(
        tqdm(
            pool.imap(disassemble_hex, df["hex"], chunksize=512),
            total=len(df),
            desc="disassembling",
        )
    )

# NOTE(review): assumes disassemble_hex returns an (instructions, valid) pair
# per block, matching the two column names below — confirm in deep_mca.utils.
pretrain_x86 = pd.DataFrame(results, columns=["instructions", "valid"])

raw_path = Path("../data/x86-pretrain-raw.parquet")
# to_parquet does not create parent directories; make ../data on a fresh checkout.
raw_path.parent.mkdir(parents=True, exist_ok=True)
pretrain_x86.to_parquet(raw_path, index=False)

# Derive the invalid count by subtraction: `~col` is only a logical NOT for a
# bool dtype, whereas total - valid stays correct for 0/1 or object columns.
n_valid = int(pretrain_x86["valid"].sum())
print(f"Total rows: {len(pretrain_x86)}")
print(f"Valid rows: {n_valid}")
print(f"Invalid rows: {len(pretrain_x86) - n_valid}")

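## Remove invalid rows

Drop the blocks that failed to disassemble (`valid` is `False`), remove the now-redundant `valid` column, and write the filtered pretraining set to `../data/x86-pretrain.parquet`.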

In [None]:
raw_path = Path("../data/x86-pretrain-raw.parquet")
if raw_path.exists():
    # Prefer the file written by the disassembly step of this notebook so the
    # lineage is local and reproducible under Restart & Run All.
    df = pd.read_parquet(raw_path)
else:
    # Fall back to the published copy so this step can run without redoing
    # the expensive disassembly pass.
    # NOTE(review): assumes the Hub dataset is an upload of the same raw
    # parquet (same `instructions` / `valid` columns) — confirm.
    dataset = load_dataset("henryc13/x86-pretrain-raw", split="train")
    df = dataset.to_pandas()

# Invalid count by subtraction: correct for bool, 0/1, or object `valid` columns.
n_valid = int(df["valid"].sum())
print(f"Total rows: {len(df)}")
print(f"Valid rows: {n_valid}")
print(f"Invalid rows: {len(df) - n_valid}")

df_filtered = df.loc[df["valid"]].drop(columns=["valid"]).reset_index(drop=True)
assert len(df_filtered) == n_valid, "filtered row count does not match valid count"
assert "valid" not in df_filtered.columns
print(f"Filtered rows: {len(df_filtered)}")

out_path = Path("../data/x86-pretrain.parquet")
# to_parquet does not create parent directories; make ../data if missing.
out_path.parent.mkdir(parents=True, exist_ok=True)
df_filtered.to_parquet(out_path, index=False)

Total rows: 24442085
Valid rows: 23048454
Invalid rows: 1393631
Filtered rows: 23048454
