Skip to content

Commit 50e0b29

Browse files
committed
[flang] Implement DOT_PRODUCT in the runtime
API, implementation, and basic tests for the transformational reduction intrinsic function DOT_PRODUCT in the runtime support library. Differential Revision: https://reviews.llvm.org/D102351
1 parent 7c57a9b commit 50e0b29

File tree

8 files changed

+481
-21
lines changed

8 files changed

+481
-21
lines changed

flang/runtime/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ add_flang_library(FortranRuntime
3939
connection.cpp
4040
derived.cpp
4141
descriptor.cpp
42+
dot-product.cpp
4243
edit-input.cpp
4344
edit-output.cpp
4445
environment.cpp

flang/runtime/complex-reduction.c

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,34 +75,51 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
7575
*/
7676

7777
#define CPP_NAME(name) Cpp##name
78-
#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro) \
79-
struct cpptype RTNAME(CPP_NAME(name))(struct cpptype *, REDUCTION_ARGS); \
80-
cComplex RTNAME(name)(REDUCTION_ARGS) { \
78+
#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro, ARGS, ARG_NAMES) \
79+
struct cpptype RTNAME(CPP_NAME(name))(struct cpptype *, ARGS); \
80+
cComplex RTNAME(name)(ARGS) { \
8181
struct cpptype result; \
82-
RTNAME(CPP_NAME(name))(&result, REDUCTION_ARG_NAMES); \
82+
RTNAME(CPP_NAME(name))(&result, ARG_NAMES); \
8383
return cmplxMacro(result.r, result.i); \
8484
}
8585

8686
/* TODO: COMPLEX(2 & 3) */
8787

8888
/* SUM() */
89-
ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
90-
ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX)
89+
ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
90+
REDUCTION_ARGS, REDUCTION_ARG_NAMES)
91+
ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
92+
REDUCTION_ARGS, REDUCTION_ARG_NAMES)
9193
#if LONG_DOUBLE == 80
92-
ADAPT_REDUCTION(
93-
SumComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
94+
ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
95+
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
9496
#elif LONG_DOUBLE == 128
95-
ADAPT_REDUCTION(
96-
SumComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
97+
ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
98+
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
9799
#endif
98100

99101
/* PRODUCT() */
100-
ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
101-
ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX)
102+
ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
103+
REDUCTION_ARGS, REDUCTION_ARG_NAMES)
104+
ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
105+
REDUCTION_ARGS, REDUCTION_ARG_NAMES)
102106
#if LONG_DOUBLE == 80
103-
ADAPT_REDUCTION(
104-
ProductComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
107+
ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
108+
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
105109
#elif LONG_DOUBLE == 128
106-
ADAPT_REDUCTION(
107-
ProductComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
110+
ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
111+
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
112+
#endif
113+
114+
/* DOT_PRODUCT() */
115+
ADAPT_REDUCTION(DotProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
116+
DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
117+
ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
118+
DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
119+
#if LONG_DOUBLE == 80
120+
ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
121+
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
122+
#elif LONG_DOUBLE == 128
123+
ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
124+
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
108125
#endif

flang/runtime/complex-reduction.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,17 @@ double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
4949
long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
5050
long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
5151

52+
#define DOT_PRODUCT_ARGS \
53+
const struct CppDescriptor *x, const struct CppDescriptor *y, \
54+
const char *source, int line, int dim /*=0*/, \
55+
const struct CppDescriptor *mask /*=NULL*/
56+
#define DOT_PRODUCT_ARG_NAMES x, y, source, line, dim, mask
57+
58+
float_Complex_t RTNAME(DotProductComplex2)(DOT_PRODUCT_ARGS);
59+
float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
60+
float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
61+
double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
62+
long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
63+
long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
64+
5265
#endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_

flang/runtime/dot-product.cpp

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
//===-- runtime/dot-product.cpp -------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "cpp-type.h"
10+
#include "descriptor.h"
11+
#include "reduction.h"
12+
#include "terminator.h"
13+
#include "tools.h"
14+
#include <cinttypes>
15+
16+
namespace Fortran::runtime {
17+
18+
template <typename ACCUMULATOR>
19+
static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
20+
Terminator &terminator) -> typename ACCUMULATOR::Result {
21+
RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
22+
SubscriptValue n{x.GetDimension(0).Extent()};
23+
if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
24+
terminator.Crash(
25+
"DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
26+
static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
27+
}
28+
SubscriptValue xAt{x.GetDimension(0).LowerBound()};
29+
SubscriptValue yAt{y.GetDimension(0).LowerBound()};
30+
ACCUMULATOR accumulator{x, y};
31+
for (SubscriptValue j{0}; j < n; ++j) {
32+
accumulator.Accumulate(xAt++, yAt++);
33+
}
34+
return accumulator.GetResult();
35+
}
36+
37+
template <TypeCategory RCAT, int RKIND,
38+
template <typename, TypeCategory, typename, typename> class ACCUM>
39+
struct DotProduct {
40+
using Result = CppTypeFor<RCAT, RKIND>;
41+
template <TypeCategory XCAT, int XKIND> struct DP1 {
42+
template <TypeCategory YCAT, int YKIND> struct DP2 {
43+
Result operator()(const Descriptor &x, const Descriptor &y,
44+
Terminator &terminator) const {
45+
if constexpr (constexpr auto resultType{
46+
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
47+
if constexpr (resultType->first == RCAT &&
48+
resultType->second <= RKIND) {
49+
using Accum = ACCUM<Result, XCAT, CppTypeFor<XCAT, XKIND>,
50+
CppTypeFor<YCAT, YKIND>>;
51+
return DoDotProduct<Accum>(x, y, terminator);
52+
}
53+
}
54+
terminator.Crash(
55+
"DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
56+
static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
57+
static_cast<int>(YCAT), YKIND);
58+
}
59+
};
60+
Result operator()(const Descriptor &x, const Descriptor &y,
61+
Terminator &terminator, TypeCategory yCat, int yKind) const {
62+
return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
63+
}
64+
};
65+
Result operator()(const Descriptor &x, const Descriptor &y,
66+
const char *source, int line) const {
67+
Terminator terminator{source, line};
68+
auto xCatKind{x.type().GetCategoryAndKind()};
69+
auto yCatKind{y.type().GetCategoryAndKind()};
70+
RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
71+
return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
72+
x, y, terminator, yCatKind->first, yCatKind->second);
73+
}
74+
};
75+
76+
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
77+
class NumericAccumulator {
78+
public:
79+
using Result = RESULT;
80+
NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
81+
void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
82+
if constexpr (XCAT == TypeCategory::Complex) {
83+
sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
84+
static_cast<Result>(*y_.Element<YT>(&yAt));
85+
} else {
86+
sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
87+
static_cast<Result>(*y_.Element<YT>(&yAt));
88+
}
89+
}
90+
Result GetResult() const { return sum_; }
91+
92+
private:
93+
const Descriptor &x_, &y_;
94+
Result sum_{0};
95+
};
96+
97+
template <typename, TypeCategory, typename XT, typename YT>
98+
class LogicalAccumulator {
99+
public:
100+
using Result = bool;
101+
LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
102+
void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
103+
result_ = result_ ||
104+
(IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
105+
}
106+
bool GetResult() const { return result_; }
107+
108+
private:
109+
const Descriptor &x_, &y_;
110+
bool result_{false};
111+
};
112+
113+
extern "C" {
114+
std::int8_t RTNAME(DotProductInteger1)(
115+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
116+
return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
117+
x, y, source, line);
118+
}
119+
std::int16_t RTNAME(DotProductInteger2)(
120+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
121+
return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
122+
x, y, source, line);
123+
}
124+
std::int32_t RTNAME(DotProductInteger4)(
125+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
126+
return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
127+
x, y, source, line);
128+
}
129+
std::int64_t RTNAME(DotProductInteger8)(
130+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
131+
return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
132+
x, y, source, line);
133+
}
134+
#ifdef __SIZEOF_INT128__
135+
common::int128_t RTNAME(DotProductInteger16)(
136+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
137+
return DotProduct<TypeCategory::Integer, 16, NumericAccumulator>{}(
138+
x, y, source, line);
139+
}
140+
#endif
141+
142+
// TODO: REAL/COMPLEX(2 & 3)
143+
float RTNAME(DotProductReal4)(
144+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
145+
return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
146+
x, y, source, line);
147+
}
148+
double RTNAME(DotProductReal8)(
149+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
150+
return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
151+
x, y, source, line);
152+
}
153+
#if LONG_DOUBLE == 80
154+
long double RTNAME(DotProductReal10)(
155+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
156+
return DotProduct<TypeCategory::Real, 10, NumericAccumulator>{}(
157+
x, y, source, line);
158+
}
159+
#elif LONG_DOUBLE == 128
160+
long double RTNAME(DotProductReal16)(
161+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
162+
return DotProduct<TypeCategory::Real, 16, NumericAccumulator>{}(
163+
x, y, source, line);
164+
}
165+
#endif
166+
167+
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
168+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
169+
auto z{DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
170+
x, y, source, line)};
171+
result = std::complex<float>{
172+
static_cast<float>(z.real()), static_cast<float>(z.imag())};
173+
}
174+
void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
175+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
176+
result = DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
177+
x, y, source, line);
178+
}
179+
#if LONG_DOUBLE == 80
180+
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
181+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
182+
result = DotProduct<TypeCategory::Complex, 10, NumericAccumulator>{}(
183+
x, y, source, line);
184+
}
185+
#elif LONG_DOUBLE == 128
186+
void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
187+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
188+
result = DotProduct<TypeCategory::Complex, 16, NumericAccumulator>{}(
189+
x, y, source, line);
190+
}
191+
#endif
192+
193+
bool RTNAME(DotProductLogical)(
194+
const Descriptor &x, const Descriptor &y, const char *source, int line) {
195+
return DotProduct<TypeCategory::Logical, 1, LogicalAccumulator>{}(
196+
x, y, source, line);
197+
}
198+
} // extern "C"
199+
} // namespace Fortran::runtime

flang/runtime/reduction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
// Implements ALL, ANY, COUNT, IPARITY, & PARITY for all required operand
1010
// types and shapes.
1111
//
12-
// FINDLOC, SUM, and PRODUCT are in their own eponymous source files;
13-
// NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
12+
// DOT_PRODUCT, FINDLOC, SUM, and PRODUCT are in their own eponymous source
13+
// files; NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
1414

1515
#include "reduction.h"
1616
#include "reduction-templates.h"

flang/runtime/reduction.h

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
// Defines the API for the reduction transformational intrinsic functions.
10-
// (Except the complex-valued total reduction forms of SUM and PRODUCT;
11-
// the API for those is in complex-reduction.h so that C's _Complex can
12-
// be used for their return types.)
10+
// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
11+
// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
12+
// C's _Complex can be used for their return types.)
1313

1414
#ifndef FORTRAN_RUNTIME_REDUCTION_H_
1515
#define FORTRAN_RUNTIME_REDUCTION_H_
@@ -275,6 +275,48 @@ bool RTNAME(Parity)(
275275
void RTNAME(ParityDim)(Descriptor &result, const Descriptor &, int dim,
276276
const char *source, int line);
277277

278+
// DOT_PRODUCT
279+
std::int8_t RTNAME(DotProductInteger1)(const Descriptor &, const Descriptor &,
280+
const char *source = nullptr, int line = 0);
281+
std::int16_t RTNAME(DotProductInteger2)(const Descriptor &, const Descriptor &,
282+
const char *source = nullptr, int line = 0);
283+
std::int32_t RTNAME(DotProductInteger4)(const Descriptor &, const Descriptor &,
284+
const char *source = nullptr, int line = 0);
285+
std::int64_t RTNAME(DotProductInteger8)(const Descriptor &, const Descriptor &,
286+
const char *source = nullptr, int line = 0);
287+
#ifdef __SIZEOF_INT128__
288+
common::int128_t RTNAME(DotProductInteger16)(const Descriptor &,
289+
const Descriptor &, const char *source = nullptr, int line = 0);
290+
#endif
291+
float RTNAME(DotProductReal2)(const Descriptor &, const Descriptor &,
292+
const char *source = nullptr, int line = 0);
293+
float RTNAME(DotProductReal3)(const Descriptor &, const Descriptor &,
294+
const char *source = nullptr, int line = 0);
295+
float RTNAME(DotProductReal4)(const Descriptor &, const Descriptor &,
296+
const char *source = nullptr, int line = 0);
297+
double RTNAME(DotProductReal8)(const Descriptor &, const Descriptor &,
298+
const char *source = nullptr, int line = 0);
299+
long double RTNAME(DotProductReal10)(const Descriptor &, const Descriptor &,
300+
const char *source = nullptr, int line = 0);
301+
long double RTNAME(DotProductReal16)(const Descriptor &, const Descriptor &,
302+
const char *source = nullptr, int line = 0);
303+
void RTNAME(CppDotProductComplex2)(std::complex<float> &, const Descriptor &,
304+
const Descriptor &, const char *source = nullptr, int line = 0);
305+
void RTNAME(CppDotProductComplex3)(std::complex<float> &, const Descriptor &,
306+
const Descriptor &, const char *source = nullptr, int line = 0);
307+
void RTNAME(CppDotProductComplex4)(std::complex<float> &, const Descriptor &,
308+
const Descriptor &, const char *source = nullptr, int line = 0);
309+
void RTNAME(CppDotProductComplex8)(std::complex<double> &, const Descriptor &,
310+
const Descriptor &, const char *source = nullptr, int line = 0);
311+
void RTNAME(CppDotProductComplex10)(std::complex<long double> &,
312+
const Descriptor &, const Descriptor &, const char *source = nullptr,
313+
int line = 0);
314+
void RTNAME(CppDotProductComplex16)(std::complex<long double> &,
315+
const Descriptor &, const Descriptor &, const char *source = nullptr,
316+
int line = 0);
317+
bool RTNAME(DotProductLogical)(const Descriptor &, const Descriptor &,
318+
const char *source = nullptr, int line = 0);
319+
278320
} // extern "C"
279321
} // namespace Fortran::runtime
280322
#endif // FORTRAN_RUNTIME_REDUCTION_H_

0 commit comments

Comments
 (0)