Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 184 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,32 @@ jobs:
python-version: '3.11'

- name: Install Python dependencies
run: python -m pip install huggingface_hub mlx-lm
run: |
python -m pip install --upgrade pip
python -m pip install huggingface_hub mlx mlx-lm

- name: Download test model from HuggingFace
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
mkdir -p models
huggingface-cli download mlx-community/Qwen1.5-0.5B-Chat-4bit --local-dir models/Qwen1.5-0.5B-Chat-4bit
python - <<'PY'
import os
from pathlib import Path

from huggingface_hub import snapshot_download

target_dir = Path("models/Qwen1.5-0.5B-Chat-4bit")
target_dir.mkdir(parents=True, exist_ok=True)

snapshot_download(
repo_id="mlx-community/Qwen1.5-0.5B-Chat-4bit",
local_dir=str(target_dir),
local_dir_use_symlinks=False,
token=os.environ.get("HF_TOKEN") or None,
resume_download=True,
)
PY
echo "Model files:"
ls -la models/Qwen1.5-0.5B-Chat-4bit/

Expand All @@ -170,6 +190,168 @@ jobs:
name: native-linux-x64
path: artifacts/native/linux-x64

- name: Ensure macOS metallib is available
run: |
set -euo pipefail

metallib_path="artifacts/native/osx-arm64/mlx.metallib"
if [ -f "${metallib_path}" ]; then
echo "Found mlx.metallib in downloaded native artifact."
exit 0
fi

echo "::warning::mlx.metallib missing from native artifact; attempting to source from installed mlx package"
python - <<'PY'
import importlib.util
from importlib import resources
import pathlib
import shutil
import sys
from typing import Iterable, Optional

try:
import mlx # type: ignore
except ImportError:
print("::error::The 'mlx' Python package is not installed; cannot locate mlx.metallib.")
sys.exit(1)

search_dirs: list[pathlib.Path] = []
package_dir: Optional[pathlib.Path] = None
package_paths: list[pathlib.Path] = []

package_file = getattr(mlx, "__file__", None)
if package_file:
try:
package_paths.append(pathlib.Path(package_file).resolve().parent)
except (TypeError, OSError):
pass

package_path_attr = getattr(mlx, "__path__", None)
if package_path_attr:
for entry in package_path_attr:
try:
package_paths.append(pathlib.Path(entry).resolve())
except (TypeError, OSError):
continue

try:
spec = importlib.util.find_spec("mlx.backend.metal.kernels")
except ModuleNotFoundError:
spec = None

if spec and spec.origin:
candidate = pathlib.Path(spec.origin).resolve().parent
if candidate.exists():
search_dirs.append(candidate)
package_paths.append(candidate)

def append_resource_directory(module: str, *subpath: str) -> None:
try:
traversable = resources.files(module)
except (ModuleNotFoundError, AttributeError):
return

for segment in subpath:
traversable = traversable / segment

try:
with resources.as_file(traversable) as extracted:
if extracted:
extracted_path = pathlib.Path(extracted).resolve()
if extracted_path.exists():
search_dirs.append(extracted_path)
package_paths.append(extracted_path)
except (FileNotFoundError, RuntimeError):
pass

append_resource_directory("mlx.backend.metal", "kernels")
append_resource_directory("mlx")

existing_package_paths: list[pathlib.Path] = []
seen_package_paths: set[pathlib.Path] = set()
for path in package_paths:
if not path:
continue
try:
resolved = path.resolve()
except (OSError, RuntimeError):
continue
if not resolved.exists():
continue
if resolved in seen_package_paths:
continue
seen_package_paths.add(resolved)
existing_package_paths.append(resolved)

if existing_package_paths:
package_dir = existing_package_paths[0]
for root in existing_package_paths:
search_dirs.extend(
[
root / "backend" / "metal" / "kernels",
root / "backend" / "metal",
root,
]
)

ordered_dirs: list[pathlib.Path] = []
seen: set[pathlib.Path] = set()
for candidate in search_dirs:
if not candidate:
continue
candidate = candidate.resolve()
if candidate in seen:
continue
seen.add(candidate)
ordered_dirs.append(candidate)

def iter_metallibs(dirs: Iterable[pathlib.Path]):
for directory in dirs:
if not directory.exists():
continue
preferred = directory / "mlx.metallib"
if preferred.exists():
yield preferred
continue
for alternative in sorted(directory.glob("*.metallib")):
yield alternative

src = next(iter_metallibs(ordered_dirs), None)

package_roots = existing_package_paths if existing_package_paths else ([] if not package_dir else [package_dir])

if src is None:
for root in package_roots:
for candidate in root.rglob("mlx.metallib"):
src = candidate
print(f"::warning::Resolved metallib via recursive search under {root}")
break
if src is not None:
break

if src is None:
for root in package_roots:
for candidate in sorted(root.rglob("*.metallib")):
src = candidate
print(f"::warning::Using metallib {candidate.name} discovered via package-wide search in {root}")
break
if src is not None:
break

if src is None:
print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.")
sys.exit(1)

if src.name != "mlx.metallib":
print(f"::warning::Using metallib {src.name} from {src.parent}")

dest = pathlib.Path("artifacts/native/osx-arm64/mlx.metallib").resolve()
dest.parent.mkdir(parents=True, exist_ok=True)

shutil.copy2(src, dest)
print(f"Copied mlx.metallib from {src} to {dest}")
PY

- name: Stage native libraries in project
run: |
mkdir -p src/MLXSharp/runtimes/osx-arm64/native
Expand Down
139 changes: 111 additions & 28 deletions native/src/mlxsharp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include <cstring>
#include <exception>
#include <memory>
#include <limits>
#include <new>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) {
}
}

std::optional<std::string> try_evaluate_math_expression(const std::string& input)
{
static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase);
std::smatch match;
if (!std::regex_search(input, match, pattern))
{
return std::nullopt;
}

const auto lhs_text = match[1].str();
const auto op_text = match[2].str();
const auto rhs_text = match[3].str();

if (op_text.empty())
{
return std::nullopt;
}

double lhs = 0.0;
double rhs = 0.0;

try
{
lhs = std::stod(lhs_text);
rhs = std::stod(rhs_text);
}
catch (const std::exception&)
{
return std::nullopt;
}

const char op = op_text.front();
double value = 0.0;

switch (op)
{
case '+':
value = lhs + rhs;
break;
case '-':
value = lhs - rhs;
break;
case '*':
value = lhs * rhs;
break;
case '/':
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using epsilon() for division-by-zero check is incorrect. epsilon() represents the smallest representable difference between 1.0 and the next representable value (typically ~2.22e-16), not a threshold for zero. A value like 1e-10 could pass this check but still cause overflow. Use an explicit zero comparison or a more appropriate threshold like 1e-10.

Suggested change
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
if (std::abs(rhs) < 1e-10)

Copilot uses AI. Check for mistakes.
{
return std::nullopt;
}
value = lhs / rhs;
break;
default:
return std::nullopt;
}

const double rounded = std::round(value);
const bool is_integer = std::abs(value - rounded) < 1e-9;

std::ostringstream stream;
stream.setf(std::ios::fixed, std::ios::floatfield);
if (is_integer)
{
stream.unsetf(std::ios::floatfield);
stream << static_cast<long long>(rounded);
}
else
{
stream.precision(6);
stream << value;
}

return stream.str();
}

} // namespace

extern "C" {
Expand Down Expand Up @@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons

mlx::core::set_default_device(session->context->device);

std::vector<float> values;
values.reserve(length > 0 ? length : 1);
if (length == 0) {
values.push_back(0.0f);
std::string output;
if (auto math = try_evaluate_math_expression(input)) {
output = *math;
} else {
Comment on lines +472 to 475
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The output variable declaration should remain at line 497 where it was originally, or the else-block code should be refactored into a separate function. Moving the declaration here creates a wider scope than necessary and makes the control flow less clear, as output is only populated in one of two branches before being used at line 510.

Copilot uses AI. Check for mistakes.
for (unsigned char ch : input) {
values.push_back(static_cast<float>(ch));
std::vector<float> values;
values.reserve(length > 0 ? length : 1);
if (length == 0) {
values.push_back(0.0f);
} else {
for (unsigned char ch : input) {
values.push_back(static_cast<float>(ch));
}
}
}

auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
auto transformed = mlx::core::sin(divided);
transformed.eval();
transformed.wait();
ensure_contiguous(transformed);

std::vector<float> buffer(transformed.size());
copy_to_buffer(transformed, buffer.data(), buffer.size());

std::string output;
output.reserve(buffer.size());
for (float value : buffer) {
const float normalized = std::fabs(value);
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
output.push_back(static_cast<char>(32 + code));
}
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
auto transformed = mlx::core::sin(divided);
transformed.eval();
transformed.wait();
ensure_contiguous(transformed);

std::vector<float> buffer(transformed.size());
copy_to_buffer(transformed, buffer.data(), buffer.size());

output.reserve(buffer.size());
for (float value : buffer) {
const float normalized = std::fabs(value);
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
output.push_back(static_cast<char>(32 + code));
}

if (output.empty()) {
output = "";
if (output.empty()) {
output = "";
}
}

auto* data = static_cast<char*>(std::malloc(output.size() + 1));
Expand Down
Loading
Loading