diff --git a/spec/compiler/codegen/const_spec.cr b/spec/compiler/codegen/const_spec.cr index 22c5ef26dfdc..e94e7af3aa41 100644 --- a/spec/compiler/codegen/const_spec.cr +++ b/spec/compiler/codegen/const_spec.cr @@ -380,4 +380,63 @@ describe "Codegen: const" do Foo.new.z )).to_i.should eq(42) end + + it "inlines simple const" do + mod = codegen(%( + CONST = 1 + CONST + )) + + mod.to_s.should_not contain("CONST") + end + + it "inlines enum value" do + mod = codegen(%( + enum Foo + CONST + end + + Foo::CONST + )) + + mod.to_s.should_not contain("CONST") + end + + it "inlines const with math" do + mod = codegen(%( + CONST = (1 + 2) * 3 + )) + + mod.to_s.should_not contain("CONST") + end + + it "inlines const referencing another const" do + mod = codegen(%( + OTHER = 1 + + CONST = OTHER + CONST + )) + + mod.to_s.should_not contain("CONST") + mod.to_s.should_not contain("OTHER") + end + + it "inlines bool const" do + mod = codegen(%( + CONST = true + CONST + )) + + mod.to_s.should_not contain("CONST") + end + + it "inlines char const" do + mod = codegen(%( + CONST = 'a' + CONST + )) + + mod.to_s.should_not contain("CONST") + end end diff --git a/src/compiler/crystal/codegen/codegen.cr b/src/compiler/crystal/codegen/codegen.cr index 943eb0f990d9..6f13d02ae2b6 100644 --- a/src/compiler/crystal/codegen/codegen.cr +++ b/src/compiler/crystal/codegen/codegen.cr @@ -235,6 +235,8 @@ module Crystal @program.class_var_and_const_initializers.each do |initializer| case initializer when Const + # Simple constants are never initialized: they are always inlined + next if initializer.compile_time_value next unless initializer.simple? initialize_simple_const(initializer) @@ -873,7 +875,7 @@ module Crystal def codegen_assign(target : Path, value, node) const = target.target_const.not_nil! - if const.used? && !const.simple? + if const.used? && !const.simple? && !const.compile_time_value initialize_const(const) end @last = llvm_nil diff --git a/src/compiler/crystal/codegen/const.cr b/src/compiler/crystal/codegen/const.cr index 7b4cde59631a..78ab68effa7a 100644 --- a/src/compiler/crystal/codegen/const.cr +++ b/src/compiler/crystal/codegen/const.cr @@ -154,12 +154,27 @@ class Crystal::CodeGenVisitor end def read_const(const) - @last = read_const_pointer(const) - @last = to_lhs @last, const.value.type + # We inline constants. Otherwise we use an LLVM const global. + @last = + case value = const.compile_time_value + when Bool then int1(value ? 1 : 0) + when Char then int32(value.ord) + when Int8 then int8(value) + when Int16 then int16(value) + when Int32 then int32(value) + when Int64 then int64(value) + when UInt8 then int8(value) + when UInt16 then int16(value) + when UInt32 then int32(value) + when UInt64 then int64(value) + else + last = read_const_pointer(const) + to_lhs last, const.value.type + end end def read_const_pointer(const) - if const == @program.argc || const == @program.argv + if const == @program.argc || const == @program.argv || const.initializer global_name = const.llvm_name global = declare_const(const) diff --git a/src/compiler/crystal/codegen/types.cr b/src/compiler/crystal/codegen/types.cr index fea9856e0141..95f3a6d7e384 100644 --- a/src/compiler/crystal/codegen/types.cr +++ b/src/compiler/crystal/codegen/types.cr @@ -171,8 +171,36 @@ module Crystal "#{llvm_name}:init" end + # Returns `true` if this constant's value is a simple literal, like + # `nil`, a number, char, string or symbol literal. def simple? value.simple_literal? end + + @compile_time_value : (Int16 | Int32 | Int64 | Int8 | UInt16 | UInt32 | UInt64 | UInt8 | Bool | Char | Nil) + @computed_compile_time_value = false + + # Returns a value if this constant's value can be evaluated at + # compile time (things like `1 + 2` and such). Returns nil otherwise. + def compile_time_value + unless @computed_compile_time_value + @computed_compile_time_value = true + + case value = self.value + when BoolLiteral + @compile_time_value = value.value + when CharLiteral + @compile_time_value = value.value + else + case type = value.type? + when IntegerType, EnumType + interpreter = MathInterpreter.new(namespace, visitor) + @compile_time_value = interpreter.interpret(value) rescue nil + end + end + end + + @compile_time_value + end end end