diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e4713a3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# Temporary # +############# +.cache +.direnv + +# Compiled source # +################### +*.a +*.dll +*.exe +*.o +*.o.d +*.py[ocd] +*.so +*.mod + +# Logs # +######## +*.log + +# Patches # +########### +*.patch +*.diff diff --git a/.spin/cmds.py b/.spin/cmds.py new file mode 100644 index 0000000..e918884 --- /dev/null +++ b/.spin/cmds.py @@ -0,0 +1,16 @@ +import pathlib +import sys +import click + +curdir = pathlib.Path(__file__).parent +rootdir = curdir.parent +toolsdir = rootdir / "tools" +sys.path.insert(0, str(toolsdir)) + + +@click.command(help="Generate sollya c++/python based files") +@click.option("-f", "--force", is_flag=True, help="Force regenerate all files") +def sollya(*, force): + import sollya # type: ignore[import] + + sollya.main(force=force) diff --git a/npsr/trig/data/approx.h.sol b/npsr/trig/data/approx.h.sol new file mode 100644 index 0000000..8692072 --- /dev/null +++ b/npsr/trig/data/approx.h.sol @@ -0,0 +1,101 @@ +// Suppress rounding mode information messages and NaN warnings for boundary angles +suppressmessage(184, 185, 186); // suppress info no rounding, round-up, round-down +suppressmessage(419); // expected nan when angle is at specific multiples causing derivative singularities + +// Generates a 4-element lookup table for fast trigonometric function approximation. +// +// This procedure creates optimized lookup tables that store: +// - Function values split into high/low precision components +// - Derivative information with power-of-2 scaling for efficient interpolation +// +// The table format per angle is: [deriv, sigma, high, low] +// where: +// deriv = actual_derivative - sigma (reduces storage requirements) +// sigma = 2^k, where k = ceil(log2(|actual_derivative|)) +// high = main part of function value +// low = residual for extended precision +// +// Parameters: +// pT - Type descriptor with .kSize, .kDigits, .kRound +// pFunc - Function to approximate (sin or cos) +// pFuncDriv - Derivative of pFunc (cos or -sin) +procedure ApproxLut4_(pT, pFunc, pFuncDriv) { + var r, i, $; + // Table size: 512 entries for 64-bit, 256 for 32-bit + // More entries for double precision to maintain accuracy + // These sizes balance table memory usage with interpolation accuracy: + // - 256 entries = 1.4° spacing for float (sufficient for 24-bit mantissa) + // - 512 entries = 0.7° spacing for double (needed for 53-bit mantissa) + $.num_lut = match pT.kSize + with 64: (2^9) + default: (2^8); + + // Low part rounding configuration: + // - 64-bit: Round to 24 bits with round-to-zero (faster, sufficient for residual) + // - 32-bit: Use full precision with round-to-nearest + $.low_round = match pT.kSize + with 64: ([|24, RZ|]) + default: ([|pT.kDigits, RN|]); + + // Scale factor to convert table index to angle in radians + $.scale = 2.0 * pi / $.num_lut; + + r = [||]; + for i from 0 to $.num_lut - 1 do { + // Sample angle uniformly distributed from 0 to 2π + $.angle = i * $.scale; + + // Compute exact function value + $.exact = pFunc($.angle); + + // Split into high and low parts for extended precision + // High part gets the main value rounded to type precision + $.high = pT.kRound($.exact); + + // Low part stores the residual, rounded to reduced precision + // This allows accurate reconstruction: value ≈ high + low + $.low = pT.kRound(round($.exact - $.high, $.low_round[0], $.low_round[1])); + + // Compute derivative for interpolation + $.deriv_exact = pFuncDriv($.angle); + + // Find power-of-2 scale factor closest to derivative magnitude + // This allows efficient storage and reconstruction + $.k = ceil(log2(abs($.deriv_exact))); + if ($.deriv_exact < 0) then $.k = -$.k; + + // Sigma is the power-of-2 scale factor + $.sigma = 2.0^$.k; + + // Store derivative minus sigma (typically a small value) + // Actual derivative = sigma + stored_deriv + $.deriv = pT.kRound($.deriv_exact - $.sigma); + + r = r @ [|$.deriv, $.sigma, $.high, $.low|]; + }; + + // Format as C array with 4 elements per table entry + return CArrayT(pT, r, 4) @ ";"; +}; + +// Generate C++ header content with specialized lookup tables +// for both float and double precision sine and cosine +// Template declarations (empty for unsupported types) +Append( + "template inline constexpr char kSinApproxTable[] = {};", + "template <> inline constexpr float kSinApproxTable[] = ", + ApproxLut4_(Float32, sin(x), cos(x)), // sin table with cos derivative + "", + "template <> inline constexpr double kSinApproxTable[] = ", + ApproxLut4_(Float64, sin(x), cos(x)), + "", + "template inline constexpr char kCosApproxTable[] = {};", + "template <> inline constexpr float kCosApproxTable[] = ", + ApproxLut4_(Float32, cos(x), -sin(x)), // cos table with -sin derivative + "", + "template <> inline constexpr double kCosApproxTable[] = ", + ApproxLut4_(Float64, cos(x), -sin(x)), + "" +); + +WriteCPPHeader("npsr::trig::data"); diff --git a/npsr/trig/data/constants.h.sol b/npsr/trig/data/constants.h.sol new file mode 100644 index 0000000..336e19b --- /dev/null +++ b/npsr/trig/data/constants.h.sol @@ -0,0 +1,97 @@ +// Suppress rounding mode information messages that are expected during constant generation +suppressmessage(185, 186); // suppress expected info round-up, round-down + +// Helper procedure to format constants array for C++ output +// Takes a type descriptor followed by constant values and formats them +// as a C array with 4 elements per line +procedure KArray_(pArgs = ...) { + var pT; + pT = head(pArgs); + return CArrayT(pT, Constants @ tail(pArgs), 4) @ ";"; +}; + +// Generate C++ header with various π-related constants for Cody-Waite reduction +// These constants enable accurate computation of r = x - n*π with extended precision +// +// The Cody-Waite method splits π into multiple parts: +// r = x - n*π₁ - n*π₂ - n*π₃ ... +// where each πᵢ has limited precision to ensure exact multiplication +// +// Different versions are provided for: +// - FMA (Fused Multiply-Add) vs non-FMA architectures +// - float vs double precision +// - Various precision requirements +Append( + // Generic template declaration + "template inline constexpr char kPi[] = {};", + + // Float π constants for Low precision implementation + "template <> inline constexpr float kPi[] = " @ + KArray_(Float32, pi, [|RN, 24, 24, 24|]), // FMA: 3x24-bit pieces (full precision each) + + "template <> inline constexpr float kPi[] = " @ + KArray_(Float32, pi, [|RD, 11, 11, 11|], [|RN, 24|]), // no FMA: 3x11-bit + 1x24-bit + // The 11-bit pieces ensure n*πᵢ is exact (no rounding) for |n| < 2^13 + + // Double π constants for Low precision implementation + "template <> inline constexpr double kPi[] = " @ + KArray_(Float64, pi, [|RN, 53|], [|RD, 53|], [|RU, 53|]), // FMA: Different roundings for error compensation + + "template <> inline constexpr double kPi[] = " @ + KArray_(Float64, pi, [|RN, 24, 24, 24|], [|RN, 53|]), // no FMA: 3x24-bit + 1x53-bit + // The 24-bit pieces ensure n*πᵢ is exact for |n| < 2^29 + "", + + // Special 35-bit precision π for specific algorithms + "template inline constexpr double kPiPrec35[] = " @ + KArray_(Float64, pi, [|RN, 35|], [|RD, 53|]), + + "template <> inline constexpr double kPiPrec35[] = " @ + KArray_(Float64, pi, [|RN, 24, 24, 24|]), + "", + + // 2π constants for angle wrapping + "template inline constexpr char kPiMul2[] = {};", + "template <> inline constexpr float kPiMul2[] = " @ + KArray_(Float32, pi*2, [|RN, 24, 24|]), // 2x24-bit pieces + + "template <> inline constexpr double kPiMul2[] = " @ + KArray_(Float64, pi*2, [|RN, 53, 53|]), // 2x53-bit pieces + "" +); + +// Non-FMA version of π/16 for High precision implementation +// Special handling: components are reordered [0,2,3,1] for proper evaluation +// Without FMA, multiplication order matters to minimize rounding errors +vNFma = Constants(pi/16, [|RN, 27, 27|], [|RN, 29|], [|RN, 53|]); +Append( + "template inline constexpr double kPiDiv16Prec29[] = " @ + KArray_(Float64, pi/16, [|RN, 53|], [|RN, 29|], [|RN, 53|]), + + // Non-FMA version reorders components: [0], [2], [3], [1] + // This ordering ensures proper evaluation without FMA: + // r = x - n*π₁/16 - n*π₃/16 - n*π₄/16 - n*π₂/16 + "template <> inline constexpr double kPiDiv16Prec29[] = " @ + CArray([|vNFma[0], vNFma[2], vNFma[3], vNFma[1]|], 4) @ ";", + "", + + // Simple scalar constants + "template inline constexpr char kInvPi = '_';", + "template <> inline constexpr float kInvPi = " @ single(1/pi) @ "f;", + "template <> inline constexpr double kInvPi = " @ double(1/pi) @ ";", + "", + + "template inline constexpr char kHalfPi = '_';", + "template <> inline constexpr float kHalfPi = " @ single(pi/2) @ "f;", + "template <> inline constexpr double kHalfPi = " @ double(pi/2) @ ";", + "", + + "template inline constexpr char k16DivPi = '_';", + "template <> inline constexpr float k16DivPi = " @ single(16/pi) @ "f;", + "template <> inline constexpr double k16DivPi = " @ double(16/pi) @ ";", + "" +); +// Dump(); + +WriteCPPHeader("npsr::trig::data"); + diff --git a/npsr/trig/data/data.h.sol b/npsr/trig/data/data.h.sol new file mode 100644 index 0000000..a338906 --- /dev/null +++ b/npsr/trig/data/data.h.sol @@ -0,0 +1,13 @@ +{ +var header; +Append("#include \"npsr/lut-inl.h\""); +for header in [|"constants", "kpi16-inl", "approx", "reduction"|] do { + Append( + "#include \"npsr/trig/data/" @ header @ ".h\"" + ); +}; +}; + +Write(); + + diff --git a/npsr/trig/data/kpi16-inl.h.sol b/npsr/trig/data/kpi16-inl.h.sol new file mode 100644 index 0000000..445f040 --- /dev/null +++ b/npsr/trig/data/kpi16-inl.h.sol @@ -0,0 +1,86 @@ +// Generates lookup table for sin(k·π/16) and cos(k·π/16) values +// Used in the high-precision trigonometric implementation for range reduction +// +// This table supports the algorithm where input x is reduced to: +// x = n*(π/16) + r, where |r| < π/16 +// Then sin(x) and cos(x) are reconstructed using angle addition formulas +// +// Parameters: +// pT - Type descriptor (Float64 in this case) +// pFunc - Function to evaluate (sin or cos) +// pBy - Divisor for π (16 in this case, giving π/16 intervals) +procedure PiDivTable_(pT, pFunc, pBy) { + var r, i, pi_by; + pi_by = pi / pBy; + r = [||]; + + // Generate function values at k*π/16 for k = 0, 1, ..., 15 + for i from 0 to pBy - 1 do { + r = r :. pT.kRound(pFunc(i * pi_by)); + }; + + // Format as C array with 4 elements per line + return CArrayT(pT, r, 4); +}; + +// Generates packed low-precision parts of sin and cos values +// This packing scheme saves memory by storing two 32-bit values in one 64-bit word +// +// The packing works as follows for double (64-bit): +// - sin_low occupies bits [31:0] (lower 32 bits) +// - cos_low occupies bits [63:32] (upper 32 bits) +// +// This is why in the C++ code: +// - cos_lo can be used directly (it's already in the upper bits) +// - sin_lo needs to be extracted with a 32-bit left shift +procedure PiDivPackLowTable_(pT, pFunc0, pFunc1, pBy) { + var r, i, digits, $; + $.pi_by = pi / pBy; + r = [||]; + + // First, compute the low precision parts (residuals after high precision) + for i from 0 to pBy - 1 do { + $.hi0 = pT.kRound(pFunc0(i * $.pi_by)); // High precision sin + $.hi1 = pT.kRound(pFunc1(i * $.pi_by)); // High precision cos + // Low precision parts: exact value minus high precision part + $.hi0_low = pT.kRound(pFunc0(i * $.pi_by) - $.hi0); // sin_low + $.hi1_low = pT.kRound(pFunc1(i * $.pi_by) - $.hi1); // cos_low + r = r @ [|$.hi0_low, $.hi1_low|]; + }; + + // Convert to binary representation for bit manipulation + digits = ToDigits(pT, r); + $.half_size = pT.kSize / 2; // 32 for double + $.lower_bits = 2^$.half_size; // Mask for lower 32 bits + + r = [||]; + // Pack pairs of values into single 64-bit words + for i from 0 to length(digits) - 1 by 2 do { + $.hi0 = digits[i]; // sin_low bits + $.hi1 = digits[i + 1]; // cos_low bits + $.pack = mod(RightShift($.hi0, $.half_size), $.lower_bits); + $.pack = $.pack + $.hi1 - mod($.hi1, $.lower_bits); + r = r :. $.pack; + }; + + // Convert back from binary representation + r = FromDigits(pT, r); + return CArrayT(pT, r, 4); +}; + +Append( + "inline constexpr auto kKPi16Table = MakeLut(", + "// High parts of sin(k·π/16) where k = 0, 1, ..., 15", + PiDivTable_(Float64, sin(x), 16) @ ",", + "// High parts of cos(k·π/16) where k = 0, 1, ..., 15", + PiDivTable_(Float64, cos(x), 16) @ ",", + "// Lower parts of sin(k·π/16) and cos(k·π/16) packed together", + "// Format: bits [63:32] = cos_low, bits [31:0] = sin_low", + "// This packing saves 16×8 = 128 bytes of memory", + PiDivPackLowTable_(Float64, sin(x), cos(x), 16), + "", + ");" +); + +WriteHighwayHeader("npsr::HWY_NAMESPACE::trig"); + diff --git a/npsr/trig/data/reduction.h.sol b/npsr/trig/data/reduction.h.sol new file mode 100644 index 0000000..ba29b15 --- /dev/null +++ b/npsr/trig/data/reduction.h.sol @@ -0,0 +1,66 @@ +// Generates lookup tables for high-precision argument reduction in trigonometric functions +// +// For large arguments, we need to compute x mod 2π (or x mod π/2) accurately. +// This is done by multiplying x by 4/π and extracting the fractional part. +// The table stores precomputed shifted values of 4/π for different exponents. +// +// The technique is based on Payne-Hanek reduction. +// +// Parameters: +// pT - Type descriptor (Float32 or Float64) +// pOffset - Additional shift offset (70 for float, 137 for double) +// These magic constants position the bits of 4/π correctly +// for the extended precision multiplication scheme +procedure ReductionTuble_(pT, pOffset) { + var r, i, j, $; + SetDisplay(decimal); + SetPrec(pT.kDigits * 3); // Triple precision to avoid rounding errors + // Mask for extracting chunks of the specified bit size + $.mask = 2^pT.kSize; // 2^32 for float, 2^64 for double + // The constant 4/π is key to the reduction algorithm + // x mod 2π = fractional_part(x * 4/π) * π/2 + $.scalar = 4 / pi; + r = [||]; + for i from 0 to pT.kMaxExpBiased + 1 do { + // Calculate the effective shift for this exponent + // The shift positions the bits of 4/π to align with the mantissa of x + // Float: exp_shift = i - 127 + 70 = i - 57 + // Double: exp_shift = i - 1023 + 137 = i - 886 + $.exp_shift = i - pT.kBias + pOffset; + // Shift 4/π left by exp_shift bits to get the relevant portion + // This gives us the bits of 4/π that will multiply with x's mantissa + $._int = LeftShift($.scalar, $.exp_shift); + // Extract three chunks for extended precision + // Each chunk is either 32 or 64 bits depending on the type + $.chunks = [||]; + for j in [|pT.kSize * 2, pT.kSize, 0|] do { + $.rshift = RightShift($._int, j); + // Mask to get only the lower bits (32 or 64) + $.apply_mask = mod($.rshift, $.mask); + $.chunks = $.chunks @ [|$.apply_mask|]; + }; + // Format: [high_chunk, middle_chunk, low_chunk] + r = r @ $.chunks; + }; + r = CArrayTU(pT, r, 3) @ ";"; + RestorePrec(); + RestoreDisplay(); + return r; +}; + +Append( + "template inline constexpr T kLargeReductionTable[] = {};", + // The offset 70 means we extract 4/π bits starting from position (exp - 57) + // This aligns with the fractional extraction: 9 + 5 + 18 + 14 = 46 = 2×23 bits + "template <> inline constexpr uint32_t kLargeReductionTable[] = " @ + ReductionTuble_(Float32, 70) @ ";", + "", + // The offset 137 means we extract 4/π bits starting from position (exp - 886) + // This aligns with the fractional extraction: 12 + 28 + 24 + 40 = 104 = 2×52 bits + "template <> inline constexpr uint64_t kLargeReductionTable[] = " @ + ReductionTuble_(Float64, 137) @ ";", + "" +); + +WriteCPPHeader("npsr::trig::data"); + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..87aeb6a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +build-backend = "mesonpy" +requires = [ + "meson-python>=0.15.0", +] + +[project] +name = "numpy_sr" +version = "1.0.0.dev0" +# TODO: add `license-files` once PEP 639 is accepted (see meson-python#88) +license = {file = "LICENSE.txt"} + +description = "NumPy SIMD Routines" +authors = [{name = "NumPy Developers."}] +maintainers = [ + {name = "NumPy Developers", email="numpy-discussion@python.org"} +] +requires-python = ">=3.12" +readme = "README.md" +classifiers = [ + 'Intended Audience :: Science/Research', +] + +[project.urls] +homepage = "https://numpy.org" +documentation = "https://sr.numpy.org/doc/" +source = "https://github.com/numpy/numpy-simd-routines" +download = "https://pypi.org/project/numpy_sr/#files" +tracker = "https://github.com/numpy/numpy-simd-routines/issues" +"release notes" = "https://sr.numpy.org/doc/stable/release" + +[tool.spin] +package = 'numpy_sr' + +[tool.spin.commands] +"Build" = [ + ".spin/cmds.py:sollya", +] diff --git a/tools/sollya/__init__.py b/tools/sollya/__init__.py new file mode 100644 index 0000000..85781cd --- /dev/null +++ b/tools/sollya/__init__.py @@ -0,0 +1,3 @@ +from .__main__ import main + +__all__ = ["main"] diff --git a/tools/sollya/__main__.py b/tools/sollya/__main__.py new file mode 100644 index 0000000..3e99c7a --- /dev/null +++ b/tools/sollya/__main__.py @@ -0,0 +1,316 @@ +"""Generate C++ headers/Python templates from Sollya scripts.""" + +import argparse +import os +import subprocess +import sys +import tempfile +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Final +from dataclasses import dataclass + + +# ANSI color codes for terminal output +if sys.stdout.isatty(): + + class Colors: + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + BLUE = "\033[94m" + MAGENTA = "\033[95m" + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + +else: + + class Colors: + GREEN = "" + YELLOW = "" + RED = "" + BLUE = "" + MAGENTA = "" + RESET = "" + BOLD = "" + DIM = "" + + +def print_colored(text: str, color: str = "", icon: str = "", indent: int = 0) -> None: + """Print text with color, icon, and indentation.""" + prefix = " " * indent + print(f"{prefix}{color}{icon}{text}{Colors.RESET}") + + +def print_divider() -> None: + """Print a visual divider.""" + print(f"{Colors.DIM} {'─' * 82}{Colors.RESET}") + + +@dataclass +class ProcessResult: + """Result of processing a single Sollya file.""" + + sollya_file: Path + output_file: Path + success: bool = False + duration: float = 0.0 + error: str | None = None + + +def format_duration(seconds: float) -> str: + """Format duration in human-readable format.""" + match seconds: + case s if s < 1: + return f"{s * 1000:.0f}ms" + case s if s < 60: + return f"{s:.1f}s" + case s: + return f"{int(s // 60)}m {int(s % 60)}s" + + +def check_sollya_available() -> bool: + """Check if Sollya is available in PATH.""" + try: + return ( + subprocess.run( + ["sollya", "--version"], capture_output=True, check=False + ).returncode + == 0 + ) + except FileNotFoundError: + return False + + +def sollya(sollya_file: Path, output_file: Path) -> ProcessResult: + """Process a Sollya file and generate output.""" + print_colored(f"▶ Processing {sollya_file}", Colors.BLUE) + + # Setup paths + current_dir: Final = Path(__file__).parent + root_dir: Final = current_dir.parent.parent + + res = ProcessResult(sollya_file, output_file) + + try: + relative_output = output_file.resolve().relative_to(root_dir) + relative_sollya = sollya_file.resolve().relative_to(root_dir) + except ValueError as e: + res.error = f"Path resolution error: {e}" + return res + + guard_name = str(relative_output).upper().translate(str.maketrans("/.\\-", "____")) + + # Create temp files and process + with ( + tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_py, + tempfile.NamedTemporaryFile(mode="w", suffix=".sol", delete=False) as temp_sol, + ): + res.duration = time.time() + try: + # Write Sollya script + temp_sol.write( + f"""SOURCE_GUARD_NAME = "{guard_name}"; +SOURCE_FILE_PATH = "{relative_sollya}"; +OUTPUT_FILE_PATH = "{output_file}"; +PYTEMP_FILE_PATH = "{temp_py.name}"; +execute("{current_dir / "core.sol"}"); +{sollya_file.read_text().strip()} +quit; +""" + ) + temp_sol.flush() + + process = subprocess.Popen( + ["sollya", temp_sol.name], + cwd=sollya_file.parent, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + process.wait() + print_divider() + + has_error = False + for line in process.stdout: + fold = line.casefold() + # treats warnings as errors to detect NaNs, + # also some syntax errors are reported as warnings + got_error = "warning" in fold or "error" in fold + got_info = "information" in fold and not got_error + has_error = has_error or got_error + + color = ( + Colors.RED if got_error else Colors.BLUE if got_info else Colors.DIM + ) + icon = "│✗" if got_error else "│ℹ" if got_info else "│" + print_colored(f"{icon} {line.rstrip()}", color, indent=1) + + print_divider() + res.duration = time.time() - res.duration + + if process.returncode != 0 or has_error: + res.error = f"Sollya failed with code {process.returncode}" + return res + + res.success = True + return res + + finally: + # Cleanup + for path in [temp_sol.name, temp_py.name]: + Path(path).unlink(missing_ok=True) + + +def find_sollya_files(root_dir: Path) -> list[tuple[Path, Path]]: + """Find all Sollya files and their corresponding output files.""" + search_path: Final = root_dir / "npsr" + patterns: Final = [ + search_path.glob(f"**/data/{ext}") + for ext in ["*.h.sol", "*.py.sol", "*.csv.sol"] + ] + + return sorted( + [ + (sollya_file, sollya_file.with_suffix("")) + for pattern in patterns + for sollya_file in pattern + ] + ) + + +def process_files(files: list[tuple[Path, Path]]) -> list[ProcessResult]: + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + futures = {executor.submit(sollya, sf, of): (sf, of) for sf, of in files} + results = [] + + for future in as_completed(futures): + result = future.result() + results.append(result) + + # Print result + color = Colors.GREEN if result.success else Colors.RED + icon = "✓ " if result.success else "✗ " + status = "Completed" if result.success else "Failed" + file_name = result.output_file if result.success else result.sollya_file + if not result.success: + result.output_file.unlink(missing_ok=True) + duration = format_duration(result.duration) + + print_colored(f"{icon}{status}: {file_name} ({duration})", color) + + if not result.success and result.error: + print_colored(result.error, Colors.RED, indent=1) + + return results + + +def main(*, force: bool = False) -> None: + """Generate all files from Sollya sources.""" + print_colored("🔧 Sollya Code Generator", Colors.BOLD) + print_divider() + + # Check Sollya availability + if not check_sollya_available(): + print_colored("✗ Error: Sollya not found in PATH", Colors.RED) + print_colored("Please install Sollya: https://www.sollya.org/", indent=1) + sys.exit(1) + + # Find files + root_dir: Final = Path(__file__).parent.parent.parent + all_files = find_sollya_files(root_dir) + + if not all_files: + print_colored("⚠ No Sollya files found", Colors.YELLOW) + return + + print(f"Found {Colors.BOLD}{len(all_files)}{Colors.RESET} Sollya files") + + # Partition files + to_process = [] + skipped = [] + + for sollya_file, output_file in all_files: + (skipped if not force and output_file.exists() else to_process).append( + (sollya_file, output_file) + ) + + # Show skipped files + if skipped: + print_colored( + f"Skipping {len(skipped)} existing files (use -f to regenerate)", Colors.DIM + ) + for _, output_file in skipped: + print_colored(f"○ {output_file}", Colors.DIM, indent=1) + + if not to_process: + print_colored("✓ All files up to date", Colors.GREEN) + return + + print(f"Processing {Colors.BOLD}{len(to_process)}{Colors.RESET} files...") + + # Process files and measure time + start_time = time.time() + results = process_files(to_process) + total_duration = time.time() - start_time + + # Summary statistics + successful = sum(r.success for r in results) + failed = len(results) - successful + + print_divider() + print_colored("Summary:", Colors.BOLD) + print_colored(f"✓ Success: {successful}", Colors.GREEN, indent=1) + if failed > 0: + print_colored(f"✗ Failed: {failed}", Colors.RED, indent=1) + if skipped: + print_colored(f"○ Skipped: {len(skipped)}", Colors.DIM, indent=1) + print_colored(f"⏱ Time: {format_duration(total_duration)}", indent=1) + + if len(to_process) > 1: + total_sequential_time = sum(r.duration for r in results) + speedup = total_sequential_time / total_duration + avg_time = format_duration(total_sequential_time / len(results)) + print_colored(f"⚡ Speedup: {speedup:.1f}x (avg {avg_time}/file)", indent=1) + + # Show errors + if errors := [r for r in results if not r.success]: + print_colored("Errors:", Colors.RED) + for result in errors: + print_colored(f"• {result.sollya_file}", indent=1) + if result.error: + print_colored(result.error, Colors.DIM, indent=2) + + +def cli() -> None: + """Command line interface.""" + parser = argparse.ArgumentParser( + description="Generate C++ headers/Python templates from Sollya scripts.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s # Process new/modified files with Sollya output + %(prog)s -f # Force regenerate all files + """, + ) + parser.add_argument( + "-f", + "--force", + action="store_true", + help="Force regenerate all files, even if they exist", + ) + + args = parser.parse_args() + + try: + main(force=args.force) + except KeyboardInterrupt: + print_colored("⚠ Interrupted by user", Colors.YELLOW) + sys.exit(1) + + +if __name__ == "__main__": + cli() diff --git a/tools/sollya/core.sol b/tools/sollya/core.sol new file mode 100644 index 0000000..770fb83 --- /dev/null +++ b/tools/sollya/core.sol @@ -0,0 +1,410 @@ +// Sollya utility functions for generating C++ header files with mathematical constants and lookup tables + +// Initial Sollya configuration +prec = 512; // High precision for accurate constant computation +display = hexadecimal; // Hex display for exact bit representation +verbosity = 4; // Verbose output to detect NaN values +showmessagenumbers = on; + +// Global state management +THE_OUTPUT_LINES = [||]; // Accumulates generated C++ code lines +THE_DISPLAY_STACK = [||]; // Stack for display mode changes +THE_PREC_STACK = [||]; // Stack for precision changes + +// Type descriptors for IEEE 754 floating-point formats +// These objects encapsulate all properties needed for type-specific code generation +Float32 = { + .kName = "float32", + .kSize = 32, // Bits in representation + .kExpBits = 8, // Exponent bits + .kMantBits = 23, // Mantissa bits (excluding implicit bit) + .kDigits = 24, // Precision digits (including implicit bit) + .kDigits10 = 6, // Decimal digits of precision + .kMaxDigits10 = 9, // Max decimal digits for round-trip + .kMinExp = -126, // Minimum exponent (normalized) + .kMinExp10 = -37, // Minimum decimal exponent + .kBias = 127, // Exponent bias + .kMaxExp10 = 38, // Maximum decimal exponent + .kMinExpDenorm = -149, // Minimum exponent (denormalized) + .kMaxExpBiased = 254, // Maximum biased exponent + .kMin = 0x1p-126, // Smallest normalized positive value + .kLowest = -0x1.fffffep127, // Most negative finite value + .kMax = 0x1.fffffep127, // Largest finite value + .kEps = 0x1p-23, // Machine epsilon + .kDenormMin = 0x1p-149, // Smallest denormalized positive value + .kPyName = "float32_t", // Python type name + .kCSFX = "f", // C suffix for literals + .kCName = "float", // C type name + .kCUint = "uint32_t", // Corresponding unsigned integer type + .kCUintSFX = "u", // Suffix for unsigned literals + .kRound = single(x), // Rounding function + .kRoundStr = "single", // String representation of rounding + .kPrintDigits = "printsingle" // Print function for exact hex output +}; + +Float64 = { + .kName = "float64", + .kSize = 64, + .kExpBits = 11, + .kMantBits = 52, + .kDigits = 53, + .kDigits10 = 15, + .kMaxDigits10 = 17, + .kMinExp = -1022, + .kMinExp10 = -307, + .kBias = 1023, + .kMaxExp10 = 308, + .kMinExpDenorm = -1074, + .kMaxExpBiased = 2046, + .kMin = 0x1p-1022, + .kLowest = -0x1.fffffffffffffp1023, + .kMax = 0x1.fffffffffffffp1023, + .kEps = 0x1p-52, + .kDenormMin = 0x1p-1074, + .kPyName = "float64_t", + .kCSFX = "", + .kCName = "double", + .kCUint = "uint64_t", + .kCUintSFX = "ull", + .kRound = double(x), + .kRoundStr = "double", + .kPrintDigits = "printdouble" +}; + +// Bit manipulation procedures +// These emulate C-style bit operations that Sollya doesn't natively support + +// Right shift operation: equivalent to C's >> operator +procedure RightShift(pN, pK) { + return floor(pN / 2^pK); +}; + +// Left shift operation: equivalent to C's << operator +procedure LeftShift(pN, pK) { + return pN * 2^pK; +}; + +// String manipulation procedures + +// Join list elements with separator (like Python's join) +procedure Join(pList, pSep) { + var r, i, v; + r = ""; + for i in pList do { + v = i @ pSep; + r = r @ v; + }; + return r; +}; + +// Join with automatic line breaks for readability +procedure PrettyJoin(pList, pSfx, pSep, pLineEvery) { + var r, i, v, l; + r = ""; + l = 0; + for i in pList do { + v = i @ pSfx; + if (v == "0f") then { + v = "0.0f"; // Avoid printing "0f" for float zero + }; + v = v @ pSep; + r = r @ v; + if (pLineEvery > 0) then { + l = l + 1; + if (l >= pLineEvery) then { + r = r @ "\n"; + l = 0; + }; + }; + }; + return r; +}; + +// Ensure zeros are represented as 0.0 for C++ template deduction +procedure FixZero(pList) { + var r, i; + r = [||]; + for i in pList do { + if (i == 0) then { + r = r :. "0.0"; // Ensure zero is represented as 0.0 + } else { + r = r :. i; + }; + }; + return r; +}; + +// C array formatting procedures +// Generate C array initializer +procedure CArray(pList, pLineEvery) { + return "{\n" @ PrettyJoin(pList, "", ", ", pLineEvery) @ "}"; +}; + +// Generate C array with type-specific suffix (e.g., "f" for float) +procedure CArrayT(pT, pList, pLineEvery) { + return "{\n" @ PrettyJoin(FixZero(pList), pT.kCSFX, ", ", pLineEvery) @ "}"; +}; + +// Generate C array with unsigned integer suffix +procedure CArrayTU(pT, pList, pLineEvery) { + return "{\n" @ PrettyJoin(pList, pT.kCUintSFX, ", ", pLineEvery) @ "}"; +}; + +// Python array formatting +procedure PyArray(pList, pLineEvery) { + return "[\n" @ PrettyJoin(pList, "", ", ", pLineEvery) @ "]"; +}; + +// External tool integration +// These procedures allow Sollya to leverage Python for complex operations + +// Execute Python code and return output +// Uses temporary file to pass code to Python interpreter +procedure PyEval(pCode = ...) { + var code; + write(Join(pCode, "\n")) > PYTEMP_FILE_PATH; + code = bashevaluate("python3 " @ PYTEMP_FILE_PATH); + return code; +}; + +// Execute Sollya code in a subprocess +// Useful for operations that need isolated evaluation +procedure SolEval(pCode = ...) { + var code; + write(Join(pCode, "\n") @ "quit;") > PYTEMP_FILE_PATH; + code = bashevaluate("sollya " @ PYTEMP_FILE_PATH); + return code; +}; + +// Floating-point to integer bit representation conversion + +// Convert array of floats to their integer bit representations +// Uses Sollya's print functions to get exact hex, then Python to convert to decimal +procedure ToDigits(pT, pA) { + var i, code, prfunc, $; + code = ""; + prfunc = pT.kPrintDigits @ "("; + + // Generate Sollya code to print each value in hex + for i in pA do { + code = code @ (prfunc @ i @ ");"); + }; + + // Execute Sollya to get hex representations + $.hex = SolEval(code); + + // Use Python to convert hex to decimal integers + $.ints = PyEval( + "x = '''", + $.hex, + "'''", + "x = [str(int(l, base=0x10)) for l in x.splitlines() if l.strip()]", + "print('[|', ', '.join(x), '|];')" + ); + return parse($.ints); +}; + +// Convert array of integer bit representations back to floats +procedure FromDigits(pT, pA) { + var i, code, $; + SetDisplay(decimal); + // Use Python to format integers as hex strings with Sollya rounding function + $.hex = PyEval( + "x = (", + PyArray(pA, 8), + ")", + "rstr = '" @ pT.kRoundStr @ "'", + "pad = '0" @ pT.kSize / 4 @ "x'", // Pad to full hex width + "x = [f'{rstr}(0x{format(l, pad)})' for l in x]", + "print('[|', ', '.join(x), '|];')" + ); + RestoreDisplay(); + return parse($.hex); +}; + +// Multi-precision constant generation +// Splits a constant into multiple floating-point pieces for extended precision +// See usage in the trigonometric constant generation scripts +procedure Constants(pArgs = ...) { + var r, i, j, $; + r = [||]; + $.exact = head(pArgs); + $.remainder = 0; + for i in tail(pArgs) do { + $.r_mod = head(i); // Rounding mode + for j in tail(i) do { // Precision bits + $.val = round($.exact - $.remainder, j, $.r_mod); + $.remainder = $.remainder + $.val; + r = r :. $.val; + }; + }; + return r; +}; + +// Output accumulation procedures + +// Append lines to the output buffer +procedure Append(pLines = ...) { + suppressmessage(56); // Suppress assignment warnings + THE_OUTPUT_LINES = THE_OUTPUT_LINES @ pLines; + unsuppressmessage(56); +}; + +// Prepend lines to the output buffer +procedure Prepend(pLines = ...) { + suppressmessage(56); + THE_OUTPUT_LINES = pLines @ THE_OUTPUT_LINES; + unsuppressmessage(56); +}; + +// Display mode management with stack + +// Push current display mode and set new one +procedure SetDisplay(pMod) { + suppressmessage(56); + THE_DISPLAY_STACK = display .: THE_DISPLAY_STACK; + unsuppressmessage(56); + display = pMod; +}; + +// Pop and restore previous display mode +procedure RestoreDisplay() { + Assert( + length(THE_DISPLAY_STACK) > 0, + "Display stack is empty, cannot restore display." + ); + display = head(THE_DISPLAY_STACK); + suppressmessage(56); + if (length(THE_DISPLAY_STACK) == 1) then { + THE_DISPLAY_STACK = [||]; + } else { + THE_DISPLAY_STACK = tail(THE_DISPLAY_STACK); + }; + unsuppressmessage(56); +}; + +// Precision management with stack + +// Push current precision and set new one +procedure SetPrec(pPrec) { + suppressmessage(56); + THE_PREC_STACK = prec .: THE_PREC_STACK; + unsuppressmessage(56); + prec = pPrec; +}; + +// Pop and restore previous precision +procedure RestorePrec() { + Assert( + length(THE_PREC_STACK) > 0, + "Prec stack is empty, cannot restore prec." + ); + prec = head(THE_PREC_STACK); + suppressmessage(56); + if (length(THE_PREC_STACK) == 1) then { + THE_PREC_STACK = [||]; + } else { + THE_PREC_STACK = tail(THE_PREC_STACK); + }; + unsuppressmessage(56); +}; + +// Error handling + +// Assert condition with error message +// On failure, deletes output file and kills parent process +procedure Assert(pCondition, pMessage) { + if (!pCondition) then { + "Assertion failed: " @ pMessage; + PyEval( + "import os, signal; from pathlib import Path;", + "Path('" @OUTPUT_FILE_PATH@ "').unlink(missing_ok=True)", + "os.kill(os.getppid(), signal.SIGKILL)" + ); + }; +}; + +// Debug helper - prints accumulated output and exits +procedure Dump() { + var i; + for i in THE_OUTPUT_LINES do { + i; + }; + Assert(false, "Dump"); +}; + +// File writing procedures + +// Write accumulated output to file and clear buffer +procedure Write() { + write(Join(THE_OUTPUT_LINES, "\n")) > OUTPUT_FILE_PATH; + suppressmessage(56); + THE_OUTPUT_LINES = [||]; + unsuppressmessage(56); +}; + +// Generate standard C++ header with namespace and include guards +// Example: WriteCPPHeader("npsr", "trig", "data") creates nested namespaces +procedure WriteCPPHeader(pNamespace = ...) { + var i, $; + $.pre = [| + "// Auto-generated by " @ SOURCE_FILE_PATH, + "// Use `spin sollya -f` to force regeneration", + "#ifndef " @ SOURCE_GUARD_NAME, + "#define " @ SOURCE_GUARD_NAME, + "" + |]; + $.post = [||]; + + // Create nested namespace declarations + for i in pNamespace do { + vNamespace = "namespace " @ i; + $.pre = $.pre :. (vNamespace @ " {"); + $.post = $.post :. ("} // " @ vNamespace); + }; + + $.post = $.post @ [| + "", + "#endif // " @ SOURCE_GUARD_NAME + |]; + + Prepend @ $.pre; + Append @ $.post; + Write(); +}; + +// Generate Highway SIMD library header with special include guard pattern +// Highway uses a toggle pattern for target-specific includes +procedure WriteHighwayHeader(pNamespace = ...) { + var i, $; + $.pre = [| + "// Auto-generated by " @ SOURCE_FILE_PATH, + "// Use `spin sollya -f` to force regeneration", + "#if defined("@ SOURCE_GUARD_NAME @") == defined(HWY_TARGET_TOGGLE) // NOLINT", + "#ifdef " @ SOURCE_GUARD_NAME, + "#undef " @ SOURCE_GUARD_NAME, + "#else", + "#define " @ SOURCE_GUARD_NAME, + "#endif", + "", + "HWY_BEFORE_NAMESPACE();" + |]; + $.post = [||]; + + // Create nested namespace declarations + for i in pNamespace do { + vNamespace = "namespace " @ i; + $.pre = $.pre :. (vNamespace @ " {"); + $.post = $.post :. ("} // " @ vNamespace); + }; + + $.post = $.post @ [| + "HWY_AFTER_NAMESPACE();", + "#endif // " @ SOURCE_GUARD_NAME + |]; + + Prepend @ $.pre; + Append @ [|"inline HWY_ATTR void _dummy_supress_unused_target(){}"|]; // to suppress unused attribute 'target' in '#pragma clang attribute push' + Append @ $.post; + Write(); +};