# Phase 1: Data Factory — Generate (PTX, AST, CUDA) Pairs

**Important:** `nvcc` compilation is CPU-only (GPU does NOT help). Use a **CPU runtime** for this notebook to avoid wasting GPU credits. Switch to GPU runtime only for training (notebook 02).

Strategy: batch-compile hundreds of `.cu` files at once with parallel `nvcc -O0`, reading all results back in bulk. This is 10-20x faster than sequential compilation.

In [None]:
# --- Run this cell first on Google Colab to clone the repo ---
import os
if not os.path.exists("/content/DeepPTX"):
    !git clone https://github.com/ns-1456/DeepPTX.git /content/DeepPTX
%cd /content/DeepPTX

## Setup: Install deps and mount Drive (optional)

In Colab: Runtime → Change runtime type → CPU is enough. Install `pyarrow` (for Parquet output) and `tqdm` (progress bar). Optionally mount Google Drive so the dataset persists after the runtime is recycled.

In [1]:
!pip install -q pyarrow tqdm

# Optional: mount Google Drive to save the dataset persistently
# from google.colab import drive
# drive.mount("/content/drive")
# OUTPUT_DIR = "/content/drive/MyDrive/NeuralPTX"

OUTPUT_DIR = "."  # saves to repo root; uncomment above for Drive

# ============================================================
# TARGET_PAIRS: start with 10k (~15 min). Scale to 100k later.
# GPU does NOT help here -- nvcc is CPU-only.
# Use a CPU runtime to avoid wasting GPU credits for this step.
# ============================================================
TARGET_PAIRS = 10_000
COMPILE_BATCH = 200  # how many .cu files to write+compile per round

## Add project root and import

In [2]:
import os
import sys

# Make the `ptx_decompiler` package importable. After the Colab clone cell the
# repo lives at /content/DeepPTX; otherwise assume this notebook runs from
# notebooks/ and the repo root is the parent directory.
REPO_ROOT = os.path.abspath("..")
if os.path.exists("/content/DeepPTX"):
    REPO_ROOT = "/content/DeepPTX"
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

import random
import pandas as pd
from pathlib import Path

# Project imports must come after the sys.path tweak above.
from ptx_decompiler.data import (
    get_tier_generator,
    TIER_WEIGHTS,
    Tier1SimpleBinary,
    Tier2NestedArithmetic,
    Tier3UnaryMath,
    Tier4Ternary,
    Tier5TypeDiversity,
    Tier6MultiStatement,
    Tier7SharedMemory,
    ast_to_cuda,
    compile_cuda_to_ptx_silent,
    normalize_ptx,
)
from ptx_decompiler.data.grammar import TIER_CLASSES, sample_tier
print("Imports OK")

## Generation loop

In [3]:
import tempfile, subprocess, time, glob
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm

import shutil

# =================================================================
# BATCH compilation strategy:
#   1. Generate N random ASTs + CUDA sources (instant, pure Python)
#   2. Write all N .cu files to disk at once
#   3. Compile ALL of them in one shot with GNU parallel / xargs
#      (one nvcc per file, but OS schedules them across all cores)
#   4. Read back .ptx files and normalize
# This avoids Python subprocess overhead per file.
# =================================================================

WORK_DIR = tempfile.mkdtemp(prefix="ptx_batch_")
# Oversubscribe the cores 2x (at least 4 jobs); presumably nvcc spends part of
# each run on process startup and I/O -- TODO confirm this factor is optimal.
NUM_PARALLEL = max(os.cpu_count() or 2, 2) * 2  # parallel nvcc jobs
print(f"CPU cores: {os.cpu_count()} | Batch dir: {WORK_DIR}")
print(f"Target: {TARGET_PAIRS} pairs | Batch size: {COMPILE_BATCH} | Parallel jobs: {NUM_PARALLEL}")

# Check if GNU parallel is available (faster than xargs). shutil.which does a
# portable PATH lookup in-process, and unlike running the `which` binary it
# cannot raise FileNotFoundError on images that do not ship `which`.
has_parallel = shutil.which("parallel") is not None
print(f"GNU parallel available: {has_parallel}")

# Without nvcc every compile fails and no pairs are produced; say so up front.
if shutil.which("nvcc") is None:
    print("WARNING: nvcc not found on PATH -- all compilations will fail.")

def generate_batch_sources(n):
    """Build ``n`` random kernels in memory; no disk or compiler access.

    Each element is a tuple ``(ast_sexp, cuda_source, tier_id, score)``:
    the AST serialized as an S-expression, the CUDA source rendered from it,
    the sampled tier id, and the generator's ``complexity_score``.
    """
    def _make_one():
        # Pick a tier/generator, build an AST, then render CUDA from its sexp.
        tier_id, generator = sample_tier()
        sexp = generator.generate().to_sexp()
        return (sexp, ast_to_cuda(sexp), tier_id, generator.complexity_score)

    return [_make_one() for _ in range(n)]

def write_cu_files(batch, work_dir):
    """Write each entry's CUDA source to ``<work_dir>/<index>.cu`` (UTF-8).

    Args:
        batch: sequence of ``(ast_sexp, cuda_source, tier_id, score)`` tuples,
            as built by ``generate_batch_sources``; only ``cuda_source`` is written.
        work_dir: existing directory to write into. A file with the same index
            left over from an earlier round is overwritten.

    Returns:
        ``batch`` itself, unchanged: the same 4-tuples with no index added.
        File ``<i>.cu`` corresponds to ``batch[i]``.
    """
    base = Path(work_dir)
    for i, (_ast_sexp, cuda_source, _tier_id, _score) in enumerate(batch):
        (base / f"{i}.cu").write_text(cuda_source, encoding="utf-8")
    return batch

def compile_all_cu(work_dir, n, num_parallel, arch="sm_75", timeout=300, use_parallel=None):
    """Compile ``0.cu`` .. ``{n-1}.cu`` in ``work_dir`` to ``.ptx`` in one shell call.

    The file paths are streamed on stdin to GNU parallel (or ``xargs -P`` as a
    fallback), and each nvcc job writes ``<i>.ptx`` next to its ``<i>.cu``.
    nvcc output is discarded: a file that fails to compile simply leaves no
    ``.ptx`` behind, and the caller counts that as a failure.

    Args:
        work_dir: directory holding the numbered ``.cu`` files.
        n: number of files to compile (indices ``0..n-1``); ``n <= 0`` is a no-op.
        num_parallel: maximum number of concurrent nvcc processes.
        arch: value passed to ``nvcc -arch`` (default ``"sm_75"``, as before).
        timeout: seconds to wait for the whole batch.
        use_parallel: True/False forces GNU parallel / xargs. None (default)
            uses the notebook-level ``has_parallel`` flag.

    Returns:
        True if the batch finished, or False if it hit ``timeout``. On timeout
        the whole process group (shell, parallel/xargs and every nvcc) is
        killed instead of raising. One stuck batch therefore no longer aborts
        the generation loop; its files are just reported as failures.
    """
    import signal  # only needed here; not imported elsewhere in the notebook

    if n <= 0:
        return True
    if use_parallel is None:
        use_parallel = has_parallel

    nvcc = f"nvcc -ptx -O0 -arch={arch}"
    if use_parallel:
        cmd = f'parallel -j {num_parallel} "{nvcc} {{}} -o {{.}}.ptx 2>/dev/null"'
    else:
        # xargs fallback: sh strips the .cu suffix from $1 to build the output path.
        cmd = (f"xargs -P {num_parallel} -I {{}} "
               f"sh -c '{nvcc} \"$1\" -o \"${{1%.cu}}.ptx\" 2>/dev/null' _ {{}}")

    # One path per line on stdin: no shell quoting of the list, no argv-length limit.
    cu_list = "\n".join(str(Path(work_dir) / f"{i}.cu") for i in range(n)) + "\n"

    proc = subprocess.Popen(
        cmd,
        shell=True,
        text=True,
        stdin=subprocess.PIPE,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True,  # own process group, so every nvcc can be killed together
    )
    finished = False
    try:
        proc.communicate(cu_list, timeout=timeout)
        finished = True
    except subprocess.TimeoutExpired:
        pass
    finally:
        if not finished:
            # Timeout or interrupt (e.g. KeyboardInterrupt): kill the shell,
            # parallel/xargs and all nvcc children so no orphans keep running.
            try:
                os.killpg(proc.pid, signal.SIGKILL)
            except ProcessLookupError:
                pass  # the group already exited on its own
            proc.wait()
    return finished

def read_ptx_results(batch, work_dir):
    """Collect the successfully compiled entries of ``batch`` as row dicts.

    For each index ``i``, ``<work_dir>/<i>.ptx`` is read and passed through
    ``normalize_ptx``. Entries whose PTX file is missing, or whose normalized
    text is blank, are skipped. Each kept row has the keys ``ptx_normalized``,
    ``ast_sexp``, ``cuda_source``, ``tier`` and ``complexity_score``.
    """
    base = Path(work_dir)
    rows = []
    for idx, entry in enumerate(batch):
        ast_sexp, cuda_source, tier_id, score = entry
        ptx_file = base / f"{idx}.ptx"
        if not ptx_file.exists():
            continue  # nvcc failed (or never ran) for this entry
        normalized = normalize_ptx(ptx_file.read_text(encoding="utf-8"))
        if not normalized.strip():
            continue  # empty output after normalization counts as a failure
        rows.append({
            "ptx_normalized": normalized,
            "ast_sexp": ast_sexp,
            "cuda_source": cuda_source,
            "tier": tier_id,
            "complexity_score": score,
        })
    return rows

def cleanup_batch(work_dir, n):
    """Best-effort removal of ``<i>.cu`` and ``<i>.ptx`` for ``i`` in ``0..n-1``.

    Missing files are fine. Any other ``OSError`` is ignored on purpose, so a
    cleanup hiccup never interrupts the generation loop.
    """
    base = Path(work_dir)
    stale_files = (base / f"{i}{suffix}" for i in range(n) for suffix in (".cu", ".ptx"))
    for path in stale_files:
        try:
            path.unlink(missing_ok=True)
        except OSError:
            pass

# ======================== Main generation loop ========================
# Abort after this many consecutive rounds in which not a single kernel
# compiled. Without this guard a missing or broken nvcc makes the while-loop
# below spin forever, because `data` never grows.
MAX_EMPTY_ROUNDS = 3

random.seed(42)
data = []
failures = 0
round_num = 0
empty_rounds = 0

pbar = tqdm(total=TARGET_PAIRS, desc="Generating pairs", unit="pair",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] fail={postfix}")
pbar.set_postfix_str("0")

try:
    while len(data) < TARGET_PAIRS:
        round_num += 1
        # Over-generate by ~20% (+10) to absorb compile failures, capped at one batch.
        need = min(COMPILE_BATCH, int((TARGET_PAIRS - len(data)) * 1.2) + 10)

        # Step 1: generate CUDA sources (instant)
        batch = generate_batch_sources(need)

        try:
            # Step 2: write .cu files
            write_cu_files(batch, WORK_DIR)

            # Step 3: compile all at once (one OS call, parallel nvcc)
            compile_all_cu(WORK_DIR, len(batch), NUM_PARALLEL)

            # Step 4: read results
            results = read_ptx_results(batch, WORK_DIR)
        finally:
            # Step 5: cleanup. This also runs on error or interrupt, so this
            # round's .cu/.ptx files are never left behind in WORK_DIR.
            cleanup_batch(WORK_DIR, len(batch))

        batch_fails = len(batch) - len(results)
        failures += batch_fails

        for row in results:
            if len(data) >= TARGET_PAIRS:
                break
            data.append(row)
            pbar.update(1)

        pbar.set_postfix_str(str(failures))

        empty_rounds = 0 if results else empty_rounds + 1
        if empty_rounds >= MAX_EMPTY_ROUNDS:
            raise RuntimeError(
                f"{empty_rounds} consecutive rounds produced no PTX "
                f"({failures} compile failures so far). Is nvcc installed and on PATH?"
            )
finally:
    # Close the bar even on error or KeyboardInterrupt; rows already collected stay in `data`.
    pbar.close()

total = len(data) + failures
print(f"\nDone! {len(data)} pairs in {round_num} rounds")
print(f"Compile failures: {failures} ({failures/max(total,1)*100:.1f}%)")
print(f"Strategy: batch {COMPILE_BATCH} files x {NUM_PARALLEL} parallel nvcc (-O0)")

KeyboardInterrupt: 

## Save to Parquet and validate

In [None]:
# Fail clearly instead of with a KeyError on df["tier"] when nothing was generated.
assert data, "No pairs generated -- run the generation loop first."
df = pd.DataFrame(data)
assert {"ptx_normalized", "ast_sexp", "cuda_source", "tier", "complexity_score"} <= set(df.columns)

out_dir = Path(OUTPUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)  # e.g. a Drive folder that does not exist yet
# NOTE(review): the file name stays fixed regardless of TARGET_PAIRS; presumably
# later notebooks read this exact name -- confirm before renaming it.
out_path = out_dir / "dataset_100k.parquet"
df.to_parquet(out_path, index=False)
print(f"Saved to {out_path}")
print(df["tier"].value_counts().sort_index())
df.head(2)

In [None]:
# Quick validation: round-trip one row.
# Parse the stored S-expression back into a tree and render it with
# CUDARenderer; the result must match the stored CUDA source, ignoring
# leading and trailing whitespace.
from ptx_decompiler.data import parse_sexp
from ptx_decompiler.data.renderer import CUDARenderer

r = df.iloc[0]
tree = parse_sexp(r["ast_sexp"])
rendered = CUDARenderer().kernel_source(tree)
assert rendered.strip() == r["cuda_source"].strip(), "Round-trip mismatch"
print("Round-trip OK.")