Skip to content

Commit

Permalink
[DSLX] Fix constexpr eval for struct parametrics.
Browse files Browse the repository at this point in the history
Note that by consolidating code for dimension evaluation this also gives
warning-as-error for type-annotated literals given as explicit parametrics, so
there are corresponding `.x` file changes for that.

Fixes #727

PiperOrigin-RevId: 534580874
  • Loading branch information
cdleary authored and Copybara-Service committed May 23, 2023
1 parent 7526816 commit 75930f7
Show file tree
Hide file tree
Showing 16 changed files with 168 additions and 89 deletions.
2 changes: 1 addition & 1 deletion xls/dslx/bytecode/bytecode_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ pub enum ImportedEnum : u32 {
import imported
type MyEnum = imported::ImportedEnum;
type MyStruct = imported::ImportedStruct<u32:16>;
type MyStruct = imported::ImportedStruct<16>;
#[test]
fn main() -> u32 {
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/stdlib/bfloat16.x
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// bfloat16 routines.
import apfloat

pub type BF16 = apfloat::APFloat<u32:8, u32:7>;
pub type BF16 = apfloat::APFloat<8, 7>;
pub type FloatTag = apfloat::APFloatTag;

pub fn qnan() -> BF16 { apfloat::qnan<u32:8, u32:7>() }
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/stdlib/float32.x
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import apfloat
// TODO(rspringer): Make u32:8 and u32:23 symbolic constants. Currently, such
// constants don't propagate correctly and fail to resolve when in parametric
// specifications.
pub type F32 = apfloat::APFloat<u32:8, u32:23>;
pub type F32 = apfloat::APFloat<8, 23>;
pub type FloatTag = apfloat::APFloatTag;

pub type TaggedF32 = (FloatTag, F32);
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/stdlib/float64.x
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import apfloat
// TODO(rspringer): Make u32:11 and u32:52 symbolic constants. Currently, such
// constants don't propagate correctly and fail to resolve when in parametric
// specifications.
pub type F64 = apfloat::APFloat<u32:11, u32:52>;
pub type F64 = apfloat::APFloat<11, 52>;
pub type FloatTag = apfloat::APFloatTag;

pub type TaggedF64 = (FloatTag, F64);
Expand Down
5 changes: 5 additions & 0 deletions xls/dslx/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ dslx_lang_test(
test_ir_equivalence = False,
)

xls_dslx_test(
name = "parametric_issue_727_test",
srcs = ["parametric_issue_727.x"],
)

dslx_lang_test(
name = "parametric_value_as_nested_loop_bound",
)
Expand Down
6 changes: 3 additions & 3 deletions xls/dslx/tests/explicit_parametric.x
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ struct Generic<X:u32, Y:u32> {
b: bits[Y]
}

pub fn foo(a: bits[4]) -> Generic<u32:4, u32:8> {
pub fn foo(a: bits[4]) -> Generic<4, 8> {
Generic<u32:4, u32:8>{ a: a, b: bits[8]:0 }
}

pub fn indirect_foo<X: u32, Y: u32 = {(X * X) as u32}>(a: bits[4]) -> Generic<{X as u32}, u32:8> {
pub fn indirect_foo<X: u32, Y: u32 = {(X * X) as u32}>(a: bits[4]) -> Generic<X, 8> {
Generic<{X as u32}, u32:8>{ a: a as bits[X], b: bits[8]:32 }
}

pub fn instantiates_indirect_foo(a: bits[16]) -> Generic<u32:16, u32:8> {
pub fn instantiates_indirect_foo(a: bits[16]) -> Generic<16, 8> {
indirect_foo<u32:16>(a as bits[4])
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ struct MyParametric<A: u32, B: u32 = {double(A)}> {

// TODO(leary): 2020-12-19 This doesn't work, we have to annotate B as well.
// We should be able to infer it.
// fn f() -> MyParametric<u32:8> {
fn main() -> MyParametric<u32:8, u32:16> {
// fn f() -> MyParametric<8> {
fn main() -> MyParametric<8, 16> {
MyParametric { x: u8:1, y: u16:2 }
}

Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/tests/parametric_importer.x
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import xls.dslx.tests.parametric_import

type LocalType = parametric_import::Type<u32:1, u32:2>;
type LocalType = parametric_import::Type<1, 2>;

#[test]
fn parametric_importer() {
Expand Down
47 changes: 47 additions & 0 deletions xls/dslx/tests/parametric_issue_727.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2023 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

struct MyStruct<WIDTH: u32> {
myfield: bits[WIDTH]
}

fn myfunc<FIELD_WIDTH: u32>(arg: MyStruct<FIELD_WIDTH>) -> u32 {
(arg.myfield as u32)
}

const WIDTH_15 = u32:15;

fn myfunc_spec1(arg: MyStruct<15>) -> u32 {
(myfunc<u32:15>(arg))
}

fn myfunc_spec2(arg: MyStruct<15>) -> u32 {
(myfunc<WIDTH_15>(arg))
}

fn myfunc_spec3(arg: MyStruct<15>) -> u32 {
(myfunc(arg))
}

fn myfunc_spec4(arg: MyStruct<WIDTH_15>) -> u32 {
(myfunc<u32:15>(arg))
}

fn myfunc_spec5(arg: MyStruct<WIDTH_15>) -> u32 {
(myfunc<WIDTH_15>(arg))
}

fn myfunc_spec6(arg: MyStruct<WIDTH_15>) -> u32 {
(myfunc(arg))
}
1 change: 0 additions & 1 deletion xls/dslx/tests/parametric_struct.x
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ struct ParametricPoint<A: u32, B: u32 = {double(A)}> {
y: bits[B]
}


struct WrapperStruct<P: u32, Q: u32> {
pp: ParametricPoint<P, Q>
}
Expand Down
7 changes: 6 additions & 1 deletion xls/dslx/type_system/concrete_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class ConcreteTypeDim {
std::variant<InterpValue, OwnedParametric> value_;
};

inline std::ostream& operator<<(std::ostream& os, const ConcreteTypeDim& ctd) {
os << ctd.ToString();
return os;
}

class EnumType;
class BitsType;
class FunctionType;
Expand Down Expand Up @@ -260,7 +265,7 @@ class MetaType : public ConcreteType {
std::vector<ConcreteTypeDim> GetAllDims() const override {
return wrapped_->GetAllDims();
}
absl::StatusOr<ConcreteTypeDim> GetTotalBitCount() const;
absl::StatusOr<ConcreteTypeDim> GetTotalBitCount() const override;
absl::StatusOr<std::unique_ptr<ConcreteType>> MapSize(
const MapFn& f) const override {
XLS_ASSIGN_OR_RETURN(auto wrapped, wrapped_->MapSize(f));
Expand Down
90 changes: 40 additions & 50 deletions xls/dslx/type_system/deduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2168,7 +2168,7 @@ static absl::StatusOr<ConcreteTypeDim> DimToConcrete(const Expr* dim_expr,

// Now we try to constexpr evaluate it.
const ParametricEnv parametric_env = GetCurrentParametricEnv(ctx);
XLS_VLOG(3) << "Attempting to evaluate dimension expression: `"
XLS_VLOG(5) << "Attempting to evaluate dimension expression: `"
<< dim_expr->ToString()
<< "` via parametric env: " << parametric_env;
XLS_RETURN_IF_ERROR(
Expand Down Expand Up @@ -2283,7 +2283,11 @@ absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceArrayTypeAnnotation(
// definition (before parametrics are applied).
static absl::StatusOr<std::unique_ptr<ConcreteType>> ConcretizeStructAnnotation(
const TypeRefTypeAnnotation* type_annotation, const StructDef* struct_def,
const ConcreteType& base_type) {
const ConcreteType& base_type, DeduceCtx* ctx) {
XLS_VLOG(5) << "ConcreteStructAnnotation; type_annotation: "
<< type_annotation->ToString()
<< " struct_def: " << struct_def->ToString();

// Note: if there are too *few* annotated parametrics, some of them may be
// derived.
if (type_annotation->parametrics().size() >
Expand All @@ -2297,40 +2301,20 @@ static absl::StatusOr<std::unique_ptr<ConcreteType>> ConcretizeStructAnnotation(
type_annotation->parametrics().size()));
}

absl::flat_hash_map<
std::string, std::variant<int64_t, std::unique_ptr<ParametricExpression>>>
defined_to_annotated;
absl::flat_hash_map<std::string, ConcreteTypeDim> parametric_env;

for (int64_t i = 0; i < type_annotation->parametrics().size(); ++i) {
ParametricBinding* defined_parametric =
struct_def->parametric_bindings()[i];
ExprOrType eot = type_annotation->parametrics()[i];
XLS_RET_CHECK(std::holds_alternative<Expr*>(eot));
Expr* annotated_parametric = std::get<Expr*>(eot);
// TODO(leary): 2020-12-13 This is kind of an ad hoc
// constexpr-evaluate-to-int implementation, unify and consolidate it.
if (auto* cast = dynamic_cast<Cast*>(annotated_parametric)) {
Expr* expr = cast->expr();
if (auto* number = dynamic_cast<Number*>(expr)) {
XLS_ASSIGN_OR_RETURN(int64_t value, number->GetAsUint64());
defined_to_annotated[defined_parametric->identifier()] = value;
} else {
auto* name_ref = dynamic_cast<NameRef*>(expr);
XLS_RET_CHECK(name_ref != nullptr);
defined_to_annotated[defined_parametric->identifier()] =
std::make_unique<ParametricSymbol>(name_ref->identifier(),
name_ref->span());
}
} else if (auto* number = dynamic_cast<Number*>(annotated_parametric)) {
XLS_ASSIGN_OR_RETURN(int value, number->GetAsUint64());
defined_to_annotated[defined_parametric->identifier()] = value;
} else {
auto* name_ref = dynamic_cast<NameRef*>(annotated_parametric);
XLS_RET_CHECK(name_ref != nullptr);
defined_to_annotated[defined_parametric->identifier()] =
std::make_unique<ParametricSymbol>(name_ref->identifier(),
name_ref->span());
}
XLS_VLOG(5) << "annotated_parametric: `" << annotated_parametric->ToString()
<< "`";

XLS_ASSIGN_OR_RETURN(ConcreteTypeDim ctd,
DimToConcrete(annotated_parametric, ctx));
parametric_env.emplace(defined_parametric->identifier(), std::move(ctd));
}

// For the remainder of the formal parameterics (i.e. after the explicitly
Expand All @@ -2350,21 +2334,17 @@ static absl::StatusOr<std::unique_ptr<ConcreteType>> ConcretizeStructAnnotation(
}
}

// Convert the defined_to_annotated map to use borrowed pointers for the
// ParametricExpressions, as required by `ParametricExpression::Env` (so we
// can `ParametricExpression::Evaluate()`).
ParametricExpression::Env env;
for (auto& item : defined_to_annotated) {
if (std::holds_alternative<int64_t>(item.second)) {
env[item.first] = InterpValue::MakeU32(std::get<int64_t>(item.second));
for (const auto& [k, ctd] : parametric_env) {
if (std::holds_alternative<InterpValue>(ctd.value())) {
env[k] = std::get<InterpValue>(ctd.value());
} else {
env[item.first] =
std::get<std::unique_ptr<ParametricExpression>>(item.second).get();
env[k] = &ctd.parametric();
}
}

// Now evaluate all the dimensions according to the values we've got.
return base_type.MapSize([&env](ConcreteTypeDim dim)
return base_type.MapSize([&](const ConcreteTypeDim& dim)
-> absl::StatusOr<ConcreteTypeDim> {
if (std::holds_alternative<ConcreteTypeDim::OwnedParametric>(dim.value())) {
auto& parametric =
Expand All @@ -2386,8 +2366,8 @@ absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceTypeRefTypeAnnotation(
if (struct_def_or.ok()) {
auto* struct_def = struct_def_or.value();
if (struct_def->IsParametric() && !node->parametrics().empty()) {
XLS_ASSIGN_OR_RETURN(
base_type, ConcretizeStructAnnotation(node, struct_def, *base_type));
XLS_ASSIGN_OR_RETURN(base_type, ConcretizeStructAnnotation(
node, struct_def, *base_type, ctx));
}
}
XLS_RET_CHECK(base_type->IsMeta());
Expand Down Expand Up @@ -2687,12 +2667,14 @@ absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceJoin(const Join* node,

// Deduces the concrete types of the arguments to a parametric function or
// proc and returns them to the caller.
absl::Status InstantiateParametricArgs(
static absl::Status InstantiateParametricArgs(
const Instantiation* inst, const Expr* callee, absl::Span<Expr* const> args,
DeduceCtx* ctx, std::vector<InstantiateArg>* instantiate_args) {
for (Expr* arg : args) {
XLS_ASSIGN_OR_RETURN(std::unique_ptr<ConcreteType> type,
DeduceAndResolve(arg, ctx));
XLS_VLOG(5) << "InstantiateParametricArgs; arg: `" << arg->ToString()
<< "` deduced: `" << type->ToString() << "` @ " << arg->span();
XLS_RET_CHECK(!type->IsMeta()) << "parametric arg: " << arg->ToString()
<< " type: " << type->ToString();
instantiate_args->push_back(InstantiateArg{std::move(type), arg->span()});
Expand Down Expand Up @@ -3143,26 +3125,34 @@ absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceInternal(
return std::move(visitor.result());
}

} // namespace

absl::StatusOr<std::unique_ptr<ConcreteType>> Resolve(const ConcreteType& type,
DeduceCtx* ctx) {
XLS_RET_CHECK(!ctx->fn_stack().empty());
const FnStackEntry& entry = ctx->fn_stack().back();
const ParametricEnv& fn_parametric_env = entry.parametric_env();
absl::StatusOr<std::unique_ptr<ConcreteType>> ResolveViaEnv(
const ConcreteType& type, const ParametricEnv& parametric_env) {
ParametricExpression::Env env;
for (const auto& [k, v] : parametric_env.bindings()) {
env[k] = v;
}

return type.MapSize([&fn_parametric_env](ConcreteTypeDim dim)
return type.MapSize([&](const ConcreteTypeDim& dim)
-> absl::StatusOr<ConcreteTypeDim> {
if (std::holds_alternative<ConcreteTypeDim::OwnedParametric>(dim.value())) {
const auto& parametric =
std::get<ConcreteTypeDim::OwnedParametric>(dim.value());
ParametricExpression::Env env = ToParametricEnv(fn_parametric_env);
return ConcreteTypeDim(parametric->Evaluate(env));
}
return dim;
});
}

} // namespace

absl::StatusOr<std::unique_ptr<ConcreteType>> Resolve(const ConcreteType& type,
DeduceCtx* ctx) {
XLS_RET_CHECK(!ctx->fn_stack().empty());
const FnStackEntry& entry = ctx->fn_stack().back();
const ParametricEnv& fn_parametric_env = entry.parametric_env();
return ResolveViaEnv(type, fn_parametric_env);
}

absl::StatusOr<std::unique_ptr<ConcreteType>> Deduce(const AstNode* node,
DeduceCtx* ctx) {
XLS_RET_CHECK(node != nullptr);
Expand Down
6 changes: 5 additions & 1 deletion xls/dslx/type_system/parametric_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ absl::Status ParametricBindConcreteTypeDim(const ConcreteType& param_type,
const ConcreteType& arg_type,
const ConcreteTypeDim& arg_dim,
ParametricBindContext& ctx) {
XLS_RET_CHECK(!arg_dim.IsParametric());
XLS_VLOG(5) << "ParametricBindConcreteTypeDim;"
<< " param_type: " << param_type << " param_dim: " << param_dim
<< " arg_type: " << arg_type << " arg_dim: " << arg_dim;

XLS_RET_CHECK(!arg_dim.IsParametric()) << arg_dim.ToString();

// See if there's a parametric symbol in the formal argument we need to bind
// vs the actual argument.
Expand Down
31 changes: 30 additions & 1 deletion xls/dslx/type_system/typecheck_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ fn f() -> u32 { p<u32:28>() }
XLS_EXPECT_OK(Typecheck(program));
}

TEST(TypecheckTest, ParametricStructInstantiatedByGlobal) {
std::string program = R"(
struct MyStruct<WIDTH: u32> {
f: bits[WIDTH]
}
fn p<FIELD_WIDTH: u32>(s: MyStruct<FIELD_WIDTH>) -> u15 {
s.f
}
const GLOBAL = u32:15;
fn f(s: MyStruct<GLOBAL>) -> u15 { p(s) }
)";
XLS_EXPECT_OK(Typecheck(program));
}

TEST(TypecheckErrorTest, ParametricInvocationConflictingArgs) {
std::string program = R"(
fn id<N: u32>(x: bits[N], y: bits[N]) -> bits[N] { x }
Expand Down Expand Up @@ -1318,6 +1332,21 @@ fn f(p: Point<3>) -> uN[6] {
XLS_EXPECT_OK(Typecheck(kProgram));
}

// TODO(https://github.com/google/xls/issues/978) Enable types other than u32 to
// be used in struct parametric instantiation.
TEST(TypecheckParametricStructInstanceTest, DISABLED_NonU32Parametric) {
const std::string_view kProgram = R"(
struct Point<N: u5, N_U32: u32 = {N as u32}> {
x: uN[N_U32],
}
fn f(p: Point<u5:3>) -> uN[3] {
p.y
}
)";
XLS_EXPECT_OK(Typecheck(kProgram));
}

// Helper for parametric struct instance based tests.
static absl::Status TypecheckParametricStructInstance(std::string program) {
program = R"(
Expand Down Expand Up @@ -1447,7 +1476,7 @@ struct S<X: u32, Y: u32> {
x: bits[X],
y: bits[Y],
}
type MyS = S<u32:3, u32:4>;
type MyS = S<3, 4>;
fn f() -> MyS { MyS{x: bits[3]:3, y: bits[4]:4 } }
)"));
}
Expand Down
Loading

0 comments on commit 75930f7

Please sign in to comment.