-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
solver_kernels_ffi.cc
487 lines (431 loc) · 21.8 KB
/
solver_kernels_ffi.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/solver_kernels_ffi.h"
#include <algorithm>
#include <cstdint>
#include <string_view>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/make_batch_pointers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace ffi = ::xla::ffi;
namespace {
// Carves a typed workspace of `size` elements of `T` out of the XLA scratch
// allocator.
//
// `name` is used only to build the error message. On success the result holds
// the scratch allocation cast to `T*`; if the allocator cannot satisfy the
// request the result holds a kResourceExhausted status.
template <typename T>
inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
                                            int64_t size,
                                            std::string_view name) {
  const auto num_bytes = sizeof(T) * size;
  auto buffer = scratch.Allocate(num_bytes);
  if (buffer.has_value()) {
    return static_cast<T*>(*buffer);
  }
  return absl::Status(
      absl::StatusCode::kResourceExhausted,
      absl::StrFormat("Unable to allocate workspace for %s", name));
}
} // namespace
// Element-type dispatch helpers.
//
// Each macro expands to a single if/else-if statement that reads a local
// variable named `dataType` at the expansion site. For F32, F64, C64 and C128
// it returns `impl<T>(...)`, forwarding the remaining macro arguments. For any
// other element type no branch is taken and control falls through to the code
// after the expansion. Every caller in this file follows the expansion with an
// InvalidArgument return.
//
// SOLVER_DISPATCH_IMPL instantiates the complex cases with gpuComplex /
// gpuDoubleComplex, the scalar types used by the gpusolverDn* kernels below.
#define SOLVER_DISPATCH_IMPL(impl, ...) \
if (dataType == ffi::F32) { \
return impl<float>(__VA_ARGS__); \
} else if (dataType == ffi::F64) { \
return impl<double>(__VA_ARGS__); \
} else if (dataType == ffi::C64) { \
return impl<gpuComplex>(__VA_ARGS__); \
} else if (dataType == ffi::C128) { \
return impl<gpuDoubleComplex>(__VA_ARGS__); \
}
// SOLVER_BLAS_DISPATCH_IMPL is identical except that the complex cases are
// instantiated with gpublasComplex / gpublasDoubleComplex, the scalar types
// used by the gpublas* batched kernels below.
#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \
if (dataType == ffi::F32) { \
return impl<float>(__VA_ARGS__); \
} else if (dataType == ffi::F64) { \
return impl<double>(__VA_ARGS__); \
} else if (dataType == ffi::C64) { \
return impl<gpublasComplex>(__VA_ARGS__); \
} else if (dataType == ffi::C128) { \
return impl<gpublasDoubleComplex>(__VA_ARGS__); \
}
// LU decomposition: getrf
namespace {
// getrf adapters. `GetrfKernel<T>` is specialized below for each supported
// scalar type, and each specialization exposes two static functions:
//   * BufferSize(handle, m, n) returns the workspace size reported by
//     `<fn>_bufferSize` for an m x n problem. It is queried with a null
//     matrix pointer.
//   * Run(handle, m, n, a, workspace, lwork, ipiv, info) calls `<fn>` and
//     converts its return code to an absl::Status.
// In both functions the leading dimension passed to the solver is `m`.
template <typename T>
struct GetrfKernel;

#define GETRF_KERNEL_IMPL(scalar_t, fn)                                      \
  template <>                                                                \
  struct GetrfKernel<scalar_t> {                                             \
    static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
                                          int n) {                           \
      int lwork;                                                             \
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                     \
          fn##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
      return lwork;                                                          \
    }                                                                        \
    static absl::Status Run(gpusolverDnHandle_t handle, int m, int n,        \
                            scalar_t* a, scalar_t* workspace, int lwork,     \
                            int* ipiv, int* info) {                          \
      return JAX_AS_STATUS(                                                  \
          fn(handle, m, n, a, m, workspace, lwork, ipiv, info));             \
    }                                                                        \
  }

GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf);
GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf);
GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf);
GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf);
#undef GETRF_KERNEL_IMPL
// Computes the LU factorization of each of the `batch` rows x cols matrices in
// `a` by calling the dense-solver getrf routine once per matrix.
//
// Steps:
//   1. Copy `a` into `out`, unless the two buffers already alias.
//   2. Factorize each matrix in place in `out`. Each matrix occupies
//      m * n contiguous elements and is passed with leading dimension m.
// Each matrix consumes min(m, n) entries of `ipiv` and one entry of `info`.
// A single workspace, sized by GetrfKernel<T>::BufferSize, is reused for
// every matrix.
//
// Returns an error in these cases:
//   * `rows` or `cols` does not fit in an `int` (checked by
//     MaybeCastNoOverflow).
//   * Handle or scratch acquisition fails.
//   * Any GPU call fails.
template <typename T>
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
                     gpuStream_t stream, ffi::ScratchAllocator& scratch,
                     ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                     ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
                     ffi::Result<ffi::Buffer<ffi::S32>> info) {
  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
  FFI_ASSIGN_OR_RETURN(int lwork,
                       GetrfKernel<T>::BufferSize(handle.get(), m, n));
  FFI_ASSIGN_OR_RETURN(auto workspace,
                       AllocateWorkspace<T>(scratch, lwork, "getrf"));

  auto a_data = static_cast<T*>(a.untyped_data());
  auto out_data = static_cast<T*>(out->untyped_data());
  auto ipiv_data = ipiv->typed_data();
  auto info_data = info->typed_data();
  // The kernel works in place on `out`, so seed it with the input first.
  if (a_data != out_data) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
  }

  // Compute strides in 64-bit arithmetic. `m` and `n` each fit in an `int`,
  // but `m * n` can exceed INT_MAX for large matrices, which would be signed
  // overflow.
  const int64_t out_step = static_cast<int64_t>(m) * n;
  const int64_t ipiv_step = std::min(m, n);
  for (int64_t i = 0; i < batch; ++i) {
    FFI_RETURN_IF_ERROR_STATUS(GetrfKernel<T>::Run(
        handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));
    out_data += out_step;
    ipiv_data += ipiv_step;
    ++info_data;
  }
  return ffi::Error::Success();
}
// Batched getrf adapters over the BLAS `<fn>` routines. Run factorizes `batch`
// n x n matrices addressed through the pointer array `a`, each with leading
// dimension `lda`. It converts the routine's return code to an absl::Status.
template <typename T>
struct GetrfBatchedKernel;

#define GETRF_BATCHED_KERNEL_IMPL(scalar_t, fn)                              \
  template <>                                                                \
  struct GetrfBatchedKernel<scalar_t> {                                      \
    static absl::Status Run(gpublasHandle_t handle, int n, scalar_t** a,     \
                            int lda, int* ipiv, int* info, int batch) {      \
      return JAX_AS_STATUS(fn(handle, n, a, lda, ipiv, info, batch));        \
    }                                                                        \
  }

GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
#undef GETRF_BATCHED_KERNEL_IMPL
// Computes LU factorizations of `batch` square cols x cols matrices with a
// single call to the BLAS batched getrf routine.
//
// Steps:
//   1. Copy `a` into `out`, unless the two buffers already alias.
//   2. Build a scratch array of per-matrix pointers into `out`, using a
//      stride of n * n elements.
//   3. Factorize all matrices in place in `out`.
// The `ipiv` and `info` result buffers are passed straight through to the
// kernel.
//
// Returns an error in these cases:
//   * `cols` or `batch` does not fit in an `int`. GetrfBatchedKernel::Run
//     takes both as `int`.
//   * Handle or scratch acquisition fails.
//   * Any GPU call fails.
template <typename T>
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
                            ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
                            ffi::Result<ffi::AnyBuffer> out,
                            ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
                            ffi::Result<ffi::Buffer<ffi::S32>> info) {
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
  // The batched kernel takes the batch count as an `int`. Reject batches that
  // would otherwise be silently truncated by the implicit narrowing.
  FFI_ASSIGN_OR_RETURN(auto batch_count, MaybeCastNoOverflow<int>(batch));
  FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
  FFI_ASSIGN_OR_RETURN(auto batch_ptrs,
                       AllocateWorkspace<T*>(scratch, batch, "batched getrf"));
  auto a_data = a.untyped_data();
  auto out_data = out->untyped_data();
  auto ipiv_data = ipiv->typed_data();
  auto info_data = info->typed_data();
  if (a_data != out_data) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
  }
  // The byte stride is computed in size_t, since sizeof(T) is size_t.
  MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
                         sizeof(T) * n * n);
  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
  FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
      handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch_count));
  return ffi::Error::Success();
}
// Entry point for the getrf FFI handler. It validates the buffers, then
// dispatches on element type to either the BLAS batched implementation or the
// per-matrix solver implementation.
//
// Checks, in order:
//   * `out` has the same element type as `a`.
//   * `out` has the shape of `a`, viewed as (batch, rows, cols).
//   * `ipiv` has shape (batch, min(rows, cols)).
//   * `info` has shape (batch,).
// An unsupported element type yields InvalidArgument.
ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                         ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                         ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
                         ffi::Result<ffi::Buffer<ffi::S32>> info) {
  // The SOLVER_*DISPATCH_IMPL macros read a local named exactly `dataType`.
  const auto dataType = a.element_type();
  if (out->element_type() != dataType) {
    return ffi::Error::InvalidArgument(
        "The input and output to getrf must have the same element type");
  }

  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  const int64_t pivots_per_matrix = std::min(rows, cols);
  FFI_RETURN_IF_ERROR(
      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "getrf"));
  FFI_RETURN_IF_ERROR(CheckShape(ipiv->dimensions(),
                                 {batch, pivots_per_matrix}, "ipiv", "getrf"));
  FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf"));

  // Heuristic: use the BLAS batched kernel only for a batch (> 1) of square
  // matrices with rows / batch <= 128. Otherwise, loop over the solver kernel.
  const bool use_batched_kernel =
      batch > 1 && rows == cols && rows / batch <= 128;
  if (use_batched_kernel) {
    SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a,
                              out, ipiv, info);
  } else {
    SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out,
                         ipiv, info);
  }
  return ffi::Error::InvalidArgument("Unsupported element type for getrf");
}
} // namespace
// Exports the getrf handler under the symbol `GetrfFfi`. The binding order
// below must match GetrfDispatch's parameter list: platform stream, scratch
// allocator, one input buffer, then three result buffers.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Arg<ffi::AnyBuffer>() // a: matrices to factorize
.Ret<ffi::AnyBuffer>() // out: factorized matrices, same type/shape as a
.Ret<ffi::Buffer<ffi::S32>>() // ipiv: (batch, min(rows, cols)) pivot output
.Ret<ffi::Buffer<ffi::S32>>() // info: (batch,) info output, one per matrix
);
// QR decomposition: geqrf
namespace {
// geqrf adapters. `GeqrfKernel<T>` is specialized below for each supported
// scalar type, and each specialization exposes two static functions:
//   * BufferSize(handle, m, n) returns the workspace size reported by
//     `<fn>_bufferSize` for an m x n problem. It is queried with a null
//     matrix pointer.
//   * Run(handle, m, n, a, tau, workspace, lwork, info) calls `<fn>` and
//     converts its return code to an absl::Status.
// In both functions the leading dimension passed to the solver is `m`.
template <typename T>
struct GeqrfKernel;

#define GEQRF_KERNEL_IMPL(scalar_t, fn)                                      \
  template <>                                                                \
  struct GeqrfKernel<scalar_t> {                                             \
    static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
                                          int n) {                           \
      int lwork;                                                             \
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                     \
          fn##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
      return lwork;                                                          \
    }                                                                        \
    static absl::Status Run(gpusolverDnHandle_t handle, int m, int n,        \
                            scalar_t* a, scalar_t* tau, scalar_t* workspace, \
                            int lwork, int* info) {                          \
      return JAX_AS_STATUS(                                                  \
          fn(handle, m, n, a, m, tau, workspace, lwork, info));              \
    }                                                                        \
  }

GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf);
GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf);
GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf);
GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf);
#undef GEQRF_KERNEL_IMPL
// Computes the QR factorization of each of the `batch` rows x cols matrices in
// `a` by calling the dense-solver geqrf routine once per matrix.
//
// Steps:
//   1. Copy `a` into `out`, unless the two buffers already alias.
//   2. Factorize each matrix in place in `out`. Each matrix occupies
//      m * n contiguous elements and is passed with leading dimension m.
// Each matrix consumes min(m, n) entries of `tau`. A single workspace, sized
// by GeqrfKernel<T>::BufferSize, is reused for every matrix.
//
// Returns an error in these cases:
//   * `rows` or `cols` does not fit in an `int` (checked by
//     MaybeCastNoOverflow).
//   * Handle or scratch acquisition fails.
//   * Any GPU call fails.
template <typename T>
ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
                     gpuStream_t stream, ffi::ScratchAllocator& scratch,
                     ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                     ffi::Result<ffi::AnyBuffer> tau) {
  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
  FFI_ASSIGN_OR_RETURN(int lwork,
                       GeqrfKernel<T>::BufferSize(handle.get(), m, n));
  FFI_ASSIGN_OR_RETURN(auto workspace,
                       AllocateWorkspace<T>(scratch, lwork, "geqrf"));
  // Note: We ignore the returned value of info because it is only used for
  // shape checking (which we already do ourselves), but it is expected to be
  // in device memory, so we need to allocate it.
  FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace<int>(scratch, 1, "geqrf"));

  auto a_data = static_cast<T*>(a.untyped_data());
  auto out_data = static_cast<T*>(out->untyped_data());
  auto tau_data = static_cast<T*>(tau->untyped_data());
  // The kernel works in place on `out`, so seed it with the input first.
  if (a_data != out_data) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
  }

  // Compute strides in 64-bit arithmetic. `m` and `n` each fit in an `int`,
  // but `m * n` can exceed INT_MAX for large matrices, which would be signed
  // overflow.
  const int64_t out_step = static_cast<int64_t>(m) * n;
  const int64_t tau_step = std::min(m, n);
  for (int64_t i = 0; i < batch; ++i) {
    FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel<T>::Run(
        handle.get(), m, n, out_data, tau_data, workspace, lwork, info));
    out_data += out_step;
    tau_data += tau_step;
  }
  return ffi::Error::Success();
}
// Batched geqrf adapters over the BLAS `<fn>` routines. Run factorizes `batch`
// m x n matrices addressed through the pointer array `a`, with leading
// dimension `m`. The matching `tau` pointers come from the `tau` pointer
// array. Run converts the routine's return code to an absl::Status.
template <typename T>
struct GeqrfBatchedKernel;

#define GEQRF_BATCHED_KERNEL_IMPL(scalar_t, fn)                              \
  template <>                                                                \
  struct GeqrfBatchedKernel<scalar_t> {                                      \
    static absl::Status Run(gpublasHandle_t handle, int m, int n,            \
                            scalar_t** a, scalar_t** tau, int* info,         \
                            int batch) {                                     \
      return JAX_AS_STATUS(fn(handle, m, n, a, m, tau, info, batch));        \
    }                                                                        \
  }

GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched);
GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched);
GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched);
GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched);
#undef GEQRF_BATCHED_KERNEL_IMPL
// Computes QR factorizations of `batch` rows x cols matrices with a single call
// to the BLAS batched geqrf routine.
//
// Steps:
//   1. Copy `a` into `out`, unless the two buffers already alias.
//   2. Build two scratch arrays of per-matrix pointers:
//        - into `out`, with a stride of m * n elements;
//        - into `tau`, with a stride of min(m, n) elements.
//   3. Factorize all matrices in place in `out`.
//
// Returns an error in these cases:
//   * `rows`, `cols` or `batch` does not fit in an `int`.
//     GeqrfBatchedKernel::Run takes all three as `int`.
//   * Handle or scratch acquisition fails.
//   * Any GPU call fails.
template <typename T>
ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
                            gpuStream_t stream, ffi::ScratchAllocator& scratch,
                            ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                            ffi::Result<ffi::AnyBuffer> tau) {
  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
  // The batched kernel takes the batch count as an `int`. Reject batches that
  // would otherwise be silently truncated by the implicit narrowing.
  FFI_ASSIGN_OR_RETURN(auto batch_count, MaybeCastNoOverflow<int>(batch));
  FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
  FFI_ASSIGN_OR_RETURN(auto out_batch_ptrs,
                       AllocateWorkspace<T*>(scratch, batch, "batched geqrf"));
  FFI_ASSIGN_OR_RETURN(auto tau_batch_ptrs,
                       AllocateWorkspace<T*>(scratch, batch, "batched geqrf"));
  auto a_data = a.untyped_data();
  auto out_data = out->untyped_data();
  auto tau_data = tau->untyped_data();
  if (a_data != out_data) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
  }
  MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch,
                         sizeof(T) * m * n);
  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
  MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch,
                         sizeof(T) * std::min(m, n));
  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
  // We ignore the output value of `info` because it is only used for shape
  // checking. It lives on the host stack here and is initialized so that it is
  // never read indeterminate.
  int info = 0;
  FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel<T>::Run(
      handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch_count));
  return ffi::Error::Success();
}
// Entry point for the geqrf FFI handler. It validates the buffers, then
// dispatches on element type to either the BLAS batched implementation or the
// per-matrix solver implementation.
//
// Checks, in order:
//   * `out` and `tau` share `a`'s element type.
//   * `out` has the shape of `a`, viewed as (batch, rows, cols).
//   * `tau` has shape (batch, min(rows, cols)).
// An unsupported element type yields InvalidArgument.
ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                         ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                         ffi::Result<ffi::AnyBuffer> tau) {
  // The SOLVER_*DISPATCH_IMPL macros read a local named exactly `dataType`.
  const auto dataType = a.element_type();
  const bool same_types =
      out->element_type() == dataType && tau->element_type() == dataType;
  if (!same_types) {
    return ffi::Error::InvalidArgument(
        "The inputs and outputs to geqrf must have the same element type");
  }

  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  const int64_t tau_size = std::min(rows, cols);
  FFI_RETURN_IF_ERROR(
      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "geqrf"));
  FFI_RETURN_IF_ERROR(
      CheckShape(tau->dimensions(), {batch, tau_size}, "tau", "geqrf"));

  // Heuristic: use the BLAS batched kernel only for a batch (> 1) with both
  // rows / batch <= 128 and cols / batch <= 128. Otherwise, loop over the
  // solver kernel.
  const bool use_batched_kernel =
      batch > 1 && rows / batch <= 128 && cols / batch <= 128;
  if (use_batched_kernel) {
    SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream,
                              scratch, a, out, tau);
  } else {
    SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out,
                         tau);
  }
  return ffi::Error::InvalidArgument("Unsupported element type for geqrf");
}
} // namespace
// Exports the geqrf handler under the symbol `GeqrfFfi`. The binding order
// below must match GeqrfDispatch's parameter list: platform stream, scratch
// allocator, one input buffer, then two result buffers.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Arg<ffi::AnyBuffer>() // a: matrices to factorize
.Ret<ffi::AnyBuffer>() // out: factorized matrices, same type/shape as a
.Ret<ffi::AnyBuffer>() // tau: (batch, min(rows, cols)), same type as a
);
// Householder transformations: orgqr
namespace {
// orgqr adapters. For complex types these map to the `ungqr` routines.
// `OrgqrKernel<T>` is specialized below for each supported scalar type, and
// each specialization exposes two static functions:
//   * BufferSize(handle, m, n, k) returns the workspace size reported by
//     `<fn>_bufferSize`. It is queried with null matrix and tau pointers.
//   * Run(handle, m, n, k, a, tau, workspace, lwork, info) calls `<fn>` and
//     converts its return code to an absl::Status.
// In both functions the leading dimension passed to the solver is `m`.
template <typename T>
struct OrgqrKernel;

#define ORGQR_KERNEL_IMPL(scalar_t, fn)                                      \
  template <>                                                                \
  struct OrgqrKernel<scalar_t> {                                             \
    static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
                                          int n, int k) {                    \
      int lwork;                                                             \
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                     \
          fn##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m,         \
                          /*tau=*/nullptr, &lwork)));                        \
      return lwork;                                                          \
    }                                                                        \
    static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \
                            scalar_t* a, scalar_t* tau, scalar_t* workspace, \
                            int lwork, int* info) {                          \
      return JAX_AS_STATUS(                                                  \
          fn(handle, m, n, k, a, m, tau, workspace, lwork, info));           \
    }                                                                        \
  }

ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr);
ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr);
ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr);
ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr);
#undef ORGQR_KERNEL_IMPL
// Applies the dense-solver orgqr/ungqr routine once per matrix to each of the
// `batch` rows x cols matrices in `a`. Each call uses `size` (k) entries of
// `tau`.
//
// Steps:
//   1. Copy `a` into `out`, unless the two buffers already alias.
//   2. Overwrite each matrix in place in `out`. Each matrix occupies m * n
//      contiguous elements and is passed with leading dimension m.
// `tau` advances by k elements per matrix. A single workspace, sized by
// OrgqrKernel<T>::BufferSize, is reused for every matrix.
//
// Returns an error in these cases:
//   * `rows`, `cols` or `size` does not fit in an `int` (checked by
//     MaybeCastNoOverflow).
//   * Handle or scratch acquisition fails.
//   * Any GPU call fails.
template <typename T>
ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
                     gpuStream_t stream, ffi::ScratchAllocator& scratch,
                     ffi::AnyBuffer a, ffi::AnyBuffer tau,
                     ffi::Result<ffi::AnyBuffer> out) {
  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
  FFI_ASSIGN_OR_RETURN(auto k, MaybeCastNoOverflow<int>(size));
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
  FFI_ASSIGN_OR_RETURN(int lwork,
                       OrgqrKernel<T>::BufferSize(handle.get(), m, n, k));
  FFI_ASSIGN_OR_RETURN(auto workspace,
                       AllocateWorkspace<T>(scratch, lwork, "orgqr"));
  // Note: We ignore the returned value of info because it is only used for
  // shape checking (which we already do ourselves), but it is expected to be
  // in device memory, so we need to allocate it.
  FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace<int>(scratch, 1, "orgqr"));

  auto a_data = static_cast<T*>(a.untyped_data());
  auto tau_data = static_cast<T*>(tau.untyped_data());
  auto out_data = static_cast<T*>(out->untyped_data());
  // The kernel works in place on `out`, so seed it with the input first.
  if (a_data != out_data) {
    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
  }

  // Compute the stride in 64-bit arithmetic. `m` and `n` each fit in an
  // `int`, but `m * n` can exceed INT_MAX for large matrices, which would be
  // signed overflow.
  const int64_t out_step = static_cast<int64_t>(m) * n;
  for (int64_t i = 0; i < batch; ++i) {
    FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel<T>::Run(
        handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info));
    out_data += out_step;
    tau_data += k;
  }
  return ffi::Error::Success();
}
// Entry point for the orgqr FFI handler. It validates `a`, `tau` and `out`,
// then dispatches on element type to OrgqrImpl.
//
// Checks, in order:
//   * All three buffers share an element type.
//   * The batch dimensions of `tau` match those of `a`.
//   * The trailing dimension of `tau` is at most the column count of `a`.
//   * `out` has the shape of `a`, viewed as (batch, rows, cols).
// An unsupported element type yields InvalidArgument.
ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                         ffi::AnyBuffer a, ffi::AnyBuffer tau,
                         ffi::Result<ffi::AnyBuffer> out) {
  // SOLVER_DISPATCH_IMPL reads a local named exactly `dataType`.
  const auto dataType = a.element_type();
  const bool same_types =
      tau.element_type() == dataType && out->element_type() == dataType;
  if (!same_types) {
    return ffi::Error::InvalidArgument(
        "The inputs and outputs to orgqr must have the same element type");
  }

  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  FFI_ASSIGN_OR_RETURN((auto [tau_batch, size]),
                       SplitBatch1D(tau.dimensions()));
  if (batch != tau_batch) {
    return ffi::Error::InvalidArgument(
        "The batch dimensions of the inputs to orgqr must match");
  }
  if (cols < size) {
    return ffi::Error::InvalidArgument(
        "The trailing dimension of the tau input to orgqr must be less than or "
        "equal to the number of columns of the input matrix");
  }
  FFI_RETURN_IF_ERROR(
      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr"));

  SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream, scratch, a,
                       tau, out);
  return ffi::Error::InvalidArgument("Unsupported element type for orgqr");
}
} // namespace
// Exports the orgqr handler under the symbol `OrgqrFfi`. The binding order
// below must match OrgqrDispatch's parameter list: platform stream, scratch
// allocator, two input buffers, then one result buffer.
XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Arg<ffi::AnyBuffer>() // a: (batch, rows, cols) input matrices
.Arg<ffi::AnyBuffer>() // tau: (batch, size) with size <= cols, same type as a
.Ret<ffi::AnyBuffer>() // out: same type/shape as a
);
// Both element-type dispatch helpers are local to this file. Undefine both so
// that neither leaks into code that follows in the translation unit.
#undef SOLVER_BLAS_DISPATCH_IMPL
#undef SOLVER_DISPATCH_IMPL
} // namespace JAX_GPU_NAMESPACE
} // namespace jax