# Phase 1: Data Factory — Generate (PTX, AST, CUDA) Pairs

**Important:** `nvcc` compilation is CPU-only (GPU does NOT help). Use a **CPU runtime** for this notebook to avoid wasting GPU credits. Switch to GPU runtime only for training (notebook 02).

Strategy: write a batch of `.cu` files (`COMPILE_BATCH`, 200 by default) and compile them with many concurrent `nvcc -O0` processes (a few per CPU core), then read all results back in bulk. This is much faster than compiling one file at a time.

In [None]:
# --- Run this cell first on Google Colab to clone the repo ---
import os
if not os.path.exists("/content/DeepPTX"):
    !git clone https://github.com/ns-1456/DeepPTX.git /content/DeepPTX
%cd /content/DeepPTX

## Setup: Install deps and mount Drive (optional)

In Colab: Runtime → Change runtime type → CPU is enough. Install `pyarrow` for Parquet. Optionally mount Google Drive to save the dataset.

In [1]:
!pip install -q pyarrow tqdm

# Optional: mount Google Drive to save the dataset persistently
# from google.colab import drive
# drive.mount("/content/drive")
# OUTPUT_DIR = "/content/drive/MyDrive/NeuralPTX"

OUTPUT_DIR = "."  # saves to repo root; uncomment above for Drive

# ============================================================
# TARGET_PAIRS: start with 10k (~15 min). Scale to 100k later.
# GPU does NOT help here -- nvcc is CPU-only.
# Use a CPU runtime to avoid wasting GPU credits for this step.
# ============================================================
TARGET_PAIRS = 10_000
COMPILE_BATCH = 200  # how many .cu files to write+compile per round

## Add project root and import

In [2]:
import os
import random
import sys
from pathlib import Path

import pandas as pd


def _find_repo_root():
    """Return the DeepPTX repo root as a Path.

    Order: the Colab clone at /content/DeepPTX if present; otherwise the nearest
    directory at or above the current working directory that contains the
    `ptx_decompiler` package; otherwise the parent of cwd (the original
    assumption that the notebook runs from notebooks/).
    """
    colab_root = Path("/content/DeepPTX")
    if colab_root.exists():
        return colab_root
    cwd = Path.cwd().resolve()
    for candidate in (cwd, *cwd.parents):
        if (candidate / "ptx_decompiler").is_dir():
            return candidate
    return cwd.parent


# Add repo root to path (works in Colab after clone, and locally from the repo or notebooks/)
REPO_ROOT = str(_find_repo_root())
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from ptx_decompiler.data import (
    get_tier_generator,
    TIER_WEIGHTS,
    Tier1SimpleBinary,
    Tier2NestedArithmetic,
    Tier3UnaryMath,
    Tier4Ternary,
    Tier5TypeDiversity,
    Tier6MultiStatement,
    Tier7SharedMemory,
    ast_to_cuda,
    compile_cuda_to_ptx_silent,
    normalize_ptx,
)
from ptx_decompiler.data.grammar import TIER_CLASSES, sample_tier
print("Imports OK")

## Generation loop

In [3]:
import tempfile, subprocess, time
from tqdm.auto import tqdm

# =================================================================
# Strategy: launch up to NUM_PARALLEL nvcc processes at once using
# subprocess.Popen (no shell piping, no xargs, no GNU parallel --
# just direct calls), wait for that wave to finish, repeat, then
# read all results.
# -O0 = skip optimization (3x faster per file).
# =================================================================

WORK_DIR = tempfile.mkdtemp(prefix="ptx_batch_")
NUM_PARALLEL = max(os.cpu_count() or 2, 2) * 3  # e.g. 2 cores -> 6 simultaneous nvcc
print(f"CPU cores: {os.cpu_count()} | Parallel nvcc: {NUM_PARALLEL}")
print(f"Target: {TARGET_PAIRS} pairs | Batch size: {COMPILE_BATCH}")

_PROBE_KERNEL = 'extern "C" __global__ void k(float* a) { a[0] = 1.0f; }'


def _nvcc_probe(extra_flags):
    """Compile a one-line probe kernel with `nvcc -ptx -O0 <extra_flags>`.

    Returns (CompletedProcess, ptx_size_in_bytes). Returns (None, 0) when nvcc
    is not on PATH or the probe times out. Probe files are always removed.
    """
    cu = Path(WORK_DIR) / "test.cu"
    ptx = Path(WORK_DIR) / "test.ptx"
    cu.write_text(_PROBE_KERNEL)
    try:
        proc = subprocess.run(
            ["nvcc", "-ptx", "-O0", *extra_flags, str(cu), "-o", str(ptx)],
            capture_output=True, text=True, timeout=30,
        )
        size = ptx.stat().st_size if ptx.exists() else 0
        return proc, size
    except (FileNotFoundError, subprocess.TimeoutExpired) as exc:
        print(f"nvcc probe {extra_flags or '(no extra flags)'} could not run: {exc}")
        return None, 0
    finally:
        cu.unlink(missing_ok=True)
        ptx.unlink(missing_ok=True)


# ---- Sanity check: can nvcc compile anything at all? ----
# Fail loudly here: if nvcc is broken, every batch compile fails and the
# generation loop would never reach TARGET_PAIRS.
probe, probe_bytes = _nvcc_probe([])
if probe is None or probe.returncode != 0:
    details = "" if probe is None else f"\nstderr: {probe.stderr}\nstdout: {probe.stdout}"
    raise RuntimeError(f"ERROR: nvcc sanity check failed!{details}")
print(f"nvcc sanity check OK ({probe_bytes} bytes)")

# Detect correct arch: try sm_75 (T4), fallback to no arch flag
arch_probe, _ = _nvcc_probe(["-arch=sm_75"])
if arch_probe is not None and arch_probe.returncode == 0:
    NVCC_ARCH = ["-arch=sm_75"]
    print("Using -arch=sm_75 (T4)")
else:
    NVCC_ARCH = []
    print("Using default arch (no -arch flag)")

def generate_batch_sources(n):
    """Sample `n` random kernels from the tier grammar (pure Python, no compilation).

    Each element of the returned list is a tuple
    (ast_sexp, cuda_source, tier_id, complexity_score), where cuda_source is
    rendered by ast_to_cuda from the AST's s-expression.
    """
    def _sample_one():
        tier_id, generator = sample_tier()
        sexp = generator.generate().to_sexp()
        return sexp, ast_to_cuda(sexp), tier_id, generator.complexity_score

    return [_sample_one() for _ in range(n)]

def compile_batch_parallel(batch, work_dir, max_concurrent, timeout=30):
    """
    Compile every CUDA source in `batch` to PTX with concurrent nvcc processes.

    Writes `<i>.cu` files into `work_dir`, runs nvcc in waves of at most
    `max_concurrent` processes (direct Popen calls, no shell piping), and reads
    back the `.ptx` of every process that exited with status 0.

    Args:
        batch: list of (ast_sexp, cuda_source, tier_id, complexity_score) tuples,
            as produced by generate_batch_sources.
        work_dir: scratch directory for the temporary .cu/.ptx files.
        max_concurrent: maximum number of nvcc processes alive at once.
        timeout: seconds to wait on each process before killing it.

    Returns:
        list of dicts with keys "ptx_normalized", "ast_sexp", "cuda_source",
        "tier", "complexity_score" -- one per source that compiled to non-empty
        normalized PTX. Failed or timed-out compiles are dropped.

    Temporary files are removed, and still-running nvcc processes are killed,
    even if an exception (e.g. KeyboardInterrupt) escapes.
    """
    work = Path(work_dir)
    n = len(batch)
    cu_paths = [work / f"{i}.cu" for i in range(n)]
    ptx_paths = [work / f"{i}.ptx" for i in range(n)]
    succeeded = [False] * n
    in_flight = []

    try:
        for cu, (_, cuda_src, _, _) in zip(cu_paths, batch):
            cu.write_text(cuda_src, encoding="utf-8")

        # Fire nvcc processes in waves of max_concurrent
        for start in range(0, n, max_concurrent):
            in_flight = []
            for i in range(start, min(start + max_concurrent, n)):
                cmd = ["nvcc", "-ptx", "-O0", *NVCC_ARCH, str(cu_paths[i]), "-o", str(ptx_paths[i])]
                in_flight.append(
                    (i, subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL))
                )
            for i, proc in in_flight:
                try:
                    succeeded[i] = proc.wait(timeout=timeout) == 0
                except subprocess.TimeoutExpired:
                    # Kill AND reap (second wait) so no zombie is left behind; a
                    # killed nvcc may have written a truncated .ptx, so it stays a failure.
                    proc.kill()
                    proc.wait()
        in_flight = []

        results = []
        for i, (ast_sexp, cuda_src, tier_id, score) in enumerate(batch):
            ptx = ptx_paths[i]
            if not succeeded[i] or not ptx.exists() or ptx.stat().st_size == 0:
                continue
            ptx_norm = normalize_ptx(ptx.read_text(encoding="utf-8"))
            if ptx_norm.strip():
                results.append({
                    "ptx_normalized": ptx_norm,
                    "ast_sexp": ast_sexp,
                    "cuda_source": cuda_src,
                    "tier": tier_id,
                    "complexity_score": score,
                })
        return results
    finally:
        # Only non-empty if we are unwinding from an exception mid-wave.
        for _, proc in in_flight:
            if proc.poll() is None:
                proc.kill()
                proc.wait()
        for path in (*cu_paths, *ptx_paths):
            path.unlink(missing_ok=True)

# ======================== Main loop ========================
# Abort after this many consecutive rounds that yield zero usable pairs.
# Without this guard a broken toolchain makes the while-loop spin forever.
MAX_EMPTY_ROUNDS = 3

random.seed(42)
data = []
failures = 0   # sources that did not produce usable PTX
attempted = 0  # sources sent to nvcc (includes overshoot rows not kept)
round_num = 0
empty_rounds = 0
start_time = time.perf_counter()

pbar = tqdm(total=TARGET_PAIRS, desc="Generating pairs", unit="pair",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] fail={postfix}")
pbar.set_postfix_str("0")

try:
    while len(data) < TARGET_PAIRS:
        round_num += 1
        # Over-request ~15% (+10) to absorb compile failures, capped at one batch.
        need = min(COMPILE_BATCH, int((TARGET_PAIRS - len(data)) * 1.15) + 10)

        batch = generate_batch_sources(need)
        results = compile_batch_parallel(batch, WORK_DIR, NUM_PARALLEL)

        attempted += len(batch)
        failures += len(batch) - len(results)

        if results:
            empty_rounds = 0
        else:
            empty_rounds += 1
            if empty_rounds >= MAX_EMPTY_ROUNDS:
                raise RuntimeError(
                    f"{empty_rounds} consecutive rounds produced no PTX "
                    f"({failures} failures so far) -- check the nvcc installation."
                )

        # Keep only as many rows as still needed to hit TARGET_PAIRS exactly.
        data.extend(results[: TARGET_PAIRS - len(data)])
        pbar.n = len(data)
        pbar.set_postfix_str(str(failures))
        pbar.refresh()
finally:
    pbar.close()

elapsed = time.perf_counter() - start_time
total = attempted
rate = len(data) / max(elapsed, 1e-9)
print(f"\nDone! {len(data)} pairs in {round_num} rounds ({elapsed:.1f}s)")
print(f"Compile failures: {failures} ({failures/max(total,1)*100:.1f}%)")
print(f"Effective rate: {rate:.1f} pairs/sec ({NUM_PARALLEL} parallel nvcc, -O0)")

KeyboardInterrupt: 

## Save to Parquet and validate

In [None]:
# Persist the generated pairs as Parquet and show the per-tier distribution.
df = pd.DataFrame(data)
assert not df.empty, "No pairs were generated -- run the generation cell first."

out_dir = Path(OUTPUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)  # e.g. a fresh Google Drive folder may not exist yet
# NOTE(review): the filename is fixed even when TARGET_PAIRS != 100k; presumably
# the training notebook loads this exact name -- confirm before renaming.
out_path = out_dir / "dataset_100k.parquet"
df.to_parquet(out_path, index=False)
print(f"Saved {len(df)} rows to {out_path}")
print(df["tier"].value_counts().sort_index())
df.head(2)

In [None]:
# Quick validation: for a deterministic sample of rows, parse the stored AST
# s-expression, re-render it to CUDA and check it matches the stored source.
from ptx_decompiler.data import parse_sexp
from ptx_decompiler.data.renderer import CUDARenderer

N_ROUNDTRIP_CHECKS = 20
roundtrip_sample = df.sample(n=min(N_ROUNDTRIP_CHECKS, len(df)), random_state=0)
for row_idx, row in roundtrip_sample.iterrows():
    tree = parse_sexp(row["ast_sexp"])
    # Fresh renderer per row, as before, in case the renderer carries state.
    rendered = CUDARenderer().kernel_source(tree)
    assert row["cuda_source"].strip() == rendered.strip(), f"Round-trip mismatch at row {row_idx}"
print(f"Round-trip OK ({len(roundtrip_sample)} rows).")