Skip to content

Commit 0f0e78f

Browse files
authored
Bump to MLX v0.26.1 (#76)
* bump to MLX v0.26.1 * variant generator: check overload/name sizes before printing * mlxvariants: handle additional overload for normal() * regenerate files
1 parent 9ebe155 commit 0f0e78f

File tree

10 files changed

+151
-4
lines changed

10 files changed

+151
-4
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ else()
3535
FetchContent_Declare(
3636
mlx
3737
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
38-
GIT_TAG v0.25.1)
38+
GIT_TAG v0.26.3)
3939
FetchContent_MakeAvailable(mlx)
4040
endif()
4141

mlx/c/fft.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,25 @@ extern "C" int mlx_fft_fftn(
6868
}
6969
return 0;
7070
}
71+
extern "C" int mlx_fft_fftshift(
72+
mlx_array* res,
73+
const mlx_array a,
74+
const int* axes,
75+
size_t axes_num,
76+
const mlx_stream s) {
77+
try {
78+
mlx_array_set_(
79+
*res,
80+
mlx::core::fft::fftshift(
81+
mlx_array_get_(a),
82+
std::vector<int>(axes, axes + axes_num),
83+
mlx_stream_get_(s)));
84+
} catch (std::exception& e) {
85+
mlx_error(e.what());
86+
return 1;
87+
}
88+
return 0;
89+
}
7190
extern "C" int mlx_fft_ifft(
7291
mlx_array* res,
7392
const mlx_array a,
@@ -128,6 +147,25 @@ extern "C" int mlx_fft_ifftn(
128147
}
129148
return 0;
130149
}
150+
extern "C" int mlx_fft_ifftshift(
151+
mlx_array* res,
152+
const mlx_array a,
153+
const int* axes,
154+
size_t axes_num,
155+
const mlx_stream s) {
156+
try {
157+
mlx_array_set_(
158+
*res,
159+
mlx::core::fft::ifftshift(
160+
mlx_array_get_(a),
161+
std::vector<int>(axes, axes + axes_num),
162+
mlx_stream_get_(s)));
163+
} catch (std::exception& e) {
164+
mlx_error(e.what());
165+
return 1;
166+
}
167+
return 0;
168+
}
131169
extern "C" int mlx_fft_irfft(
132170
mlx_array* res,
133171
const mlx_array a,

mlx/c/fft.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ int mlx_fft_fftn(
4949
const int* axes,
5050
size_t axes_num,
5151
const mlx_stream s);
52+
int mlx_fft_fftshift(
53+
mlx_array* res,
54+
const mlx_array a,
55+
const int* axes,
56+
size_t axes_num,
57+
const mlx_stream s);
5258
int mlx_fft_ifft(
5359
mlx_array* res,
5460
const mlx_array a,
@@ -71,6 +77,12 @@ int mlx_fft_ifftn(
7177
const int* axes,
7278
size_t axes_num,
7379
const mlx_stream s);
80+
int mlx_fft_ifftshift(
81+
mlx_array* res,
82+
const mlx_array a,
83+
const int* axes,
84+
size_t axes_num,
85+
const mlx_stream s);
7486
int mlx_fft_irfft(
7587
mlx_array* res,
7688
const mlx_array a,

mlx/c/linalg.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ extern "C" int mlx_linalg_cross(
5757
}
5858
return 0;
5959
}
60+
extern "C" int mlx_linalg_eig(
61+
mlx_array* res_0,
62+
mlx_array* res_1,
63+
const mlx_array a,
64+
const mlx_stream s) {
65+
try {
66+
{
67+
auto [tpl_0, tpl_1] =
68+
mlx::core::linalg::eig(mlx_array_get_(a), mlx_stream_get_(s));
69+
mlx_array_set_(*res_0, tpl_0);
70+
mlx_array_set_(*res_1, tpl_1);
71+
};
72+
} catch (std::exception& e) {
73+
mlx_error(e.what());
74+
return 1;
75+
}
76+
return 0;
77+
}
6078
extern "C" int mlx_linalg_eigh(
6179
mlx_array* res_0,
6280
mlx_array* res_1,
@@ -76,6 +94,18 @@ extern "C" int mlx_linalg_eigh(
7694
}
7795
return 0;
7896
}
97+
extern "C" int
98+
mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) {
99+
try {
100+
mlx_array_set_(
101+
*res,
102+
mlx::core::linalg::eigvals(mlx_array_get_(a), mlx_stream_get_(s)));
103+
} catch (std::exception& e) {
104+
mlx_error(e.what());
105+
return 1;
106+
}
107+
return 0;
108+
}
79109
extern "C" int mlx_linalg_eigvalsh(
80110
mlx_array* res,
81111
const mlx_array a,

mlx/c/linalg.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,18 @@ int mlx_linalg_cross(
4343
const mlx_array b,
4444
int axis,
4545
const mlx_stream s);
46+
int mlx_linalg_eig(
47+
mlx_array* res_0,
48+
mlx_array* res_1,
49+
const mlx_array a,
50+
const mlx_stream s);
4651
int mlx_linalg_eigh(
4752
mlx_array* res_0,
4853
mlx_array* res_1,
4954
const mlx_array a,
5055
const char* UPLO,
5156
const mlx_stream s);
57+
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
5258
int mlx_linalg_eigvalsh(
5359
mlx_array* res,
5460
const mlx_array a,

mlx/c/ops.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2752,6 +2752,26 @@ extern "C" int mlx_scatter_prod(
27522752
}
27532753
return 0;
27542754
}
2755+
extern "C" int mlx_segmented_mm(
2756+
mlx_array* res,
2757+
const mlx_array a,
2758+
const mlx_array b,
2759+
const mlx_array segments,
2760+
const mlx_stream s) {
2761+
try {
2762+
mlx_array_set_(
2763+
*res,
2764+
mlx::core::segmented_mm(
2765+
mlx_array_get_(a),
2766+
mlx_array_get_(b),
2767+
mlx_array_get_(segments),
2768+
mlx_stream_get_(s)));
2769+
} catch (std::exception& e) {
2770+
mlx_error(e.what());
2771+
return 1;
2772+
}
2773+
return 0;
2774+
}
27552775
extern "C" int
27562776
mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) {
27572777
try {

mlx/c/ops.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,12 @@ int mlx_scatter_prod(
868868
const int* axes,
869869
size_t axes_num,
870870
const mlx_stream s);
871+
int mlx_segmented_mm(
872+
mlx_array* res,
873+
const mlx_array a,
874+
const mlx_array b,
875+
const mlx_array segments,
876+
const mlx_stream s);
871877
int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s);
872878
int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s);
873879
int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s);

mlx/c/random.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,32 @@ extern "C" int mlx_random_multivariate_normal(
195195
}
196196
return 0;
197197
}
198+
extern "C" int mlx_random_normal_broadcast(
199+
mlx_array* res,
200+
const int* shape,
201+
size_t shape_num,
202+
mlx_dtype dtype,
203+
const mlx_array loc /* may be null */,
204+
const mlx_array scale /* may be null */,
205+
const mlx_array key /* may be null */,
206+
const mlx_stream s) {
207+
try {
208+
mlx_array_set_(
209+
*res,
210+
mlx::core::random::normal(
211+
std::vector<int>(shape, shape + shape_num),
212+
mlx_dtype_to_cpp(dtype),
213+
(loc.ctx ? std::make_optional(mlx_array_get_(loc)) : std::nullopt),
214+
(scale.ctx ? std::make_optional(mlx_array_get_(scale))
215+
: std::nullopt),
216+
(key.ctx ? std::make_optional(mlx_array_get_(key)) : std::nullopt),
217+
mlx_stream_get_(s)));
218+
} catch (std::exception& e) {
219+
mlx_error(e.what());
220+
return 1;
221+
}
222+
return 0;
223+
}
198224
extern "C" int mlx_random_normal(
199225
mlx_array* res,
200226
const int* shape,

mlx/c/random.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ int mlx_random_multivariate_normal(
8888
mlx_dtype dtype,
8989
const mlx_array key /* may be null */,
9090
const mlx_stream s);
91+
int mlx_random_normal_broadcast(
92+
mlx_array* res,
93+
const int* shape,
94+
size_t shape_num,
95+
mlx_dtype dtype,
96+
const mlx_array loc /* may be null */,
97+
const mlx_array scale /* may be null */,
98+
const mlx_array key /* may be null */,
99+
const mlx_stream s);
91100
int mlx_random_normal(
92101
mlx_array* res,
93102
const int* shape,

python/mlxvariants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ def _make_variant_suffixes(name, defs, variants):
1919
print("OVL", file=sys.stderr)
2020
if name in variants:
2121
variants = variants[name]
22-
for i, d in enumerate(defs):
23-
print("OVL", i, _pretty_string_def(d), " -> ", variants[i], file=sys.stderr)
2422
if len(variants) != len(defs):
2523
print("function overloads length:", len(defs), file=sys.stderr)
2624
for i, d in enumerate(defs):
@@ -29,6 +27,8 @@ def _make_variant_suffixes(name, defs, variants):
2927
for i, v in enumerate(variants):
3028
print(i, v, file=sys.stderr)
3129
raise RuntimeError("function overloads and namings do not match")
30+
for i, d in enumerate(defs):
31+
print("OVL", i, _pretty_string_def(d), " -> ", variants[i], file=sys.stderr)
3232
newdefs = []
3333
for i, d in enumerate(defs):
3434
v = variants[i]
@@ -112,7 +112,7 @@ def mlx_core_random(name, defs):
112112
"permutation": ["", "arange"],
113113
"split": ["num", ""],
114114
"uniform": ["", None, None, None],
115-
"normal": ["", None, None, None],
115+
"normal": ["broadcast", "", None, None, None],
116116
}
117117
return _make_variant_suffixes(name, defs, variants)
118118

0 commit comments

Comments
 (0)