Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array compound generation first pass #57

Open
wants to merge 11 commits into
base: rewrite
Choose a base branch
from
4 changes: 3 additions & 1 deletion src/checker.c
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ Operand check_expr_compound(Checker *self, Expr *expr, Ty *wanted) {
case TYPE_UNION:
case TYPE_STRUCT: {
u32 index = 0;
// TODO: Ensure each index is only set once
for (int i = 0; i < arrlen(expr->ecompound.fields); i++) {
CompoundField field = expr->ecompound.fields[i];
if (field.kind == FIELD_INDEX) {
Expand Down Expand Up @@ -489,6 +490,7 @@ Operand check_expr_compound(Checker *self, Expr *expr, Ty *wanted) {
Ty *expected_type = type ? type->tarray.eltype : NULL;
u32 index = 0;
u32 max_index = 0;
// TODO: Ensure each index is only set once
for (u32 i = 0; i < arrlen(expr->ecompound.fields); i++) {
CompoundField field = expr->ecompound.fields[i];
if (field.kind == FIELD_NAME) {
Expand All @@ -507,7 +509,7 @@ Operand check_expr_compound(Checker *self, Expr *expr, Ty *wanted) {
index = (u32) op.val.i;
}
if (type && type->tarray.length && index >= type->tarray.length) {
error(self, field.key->range,
error(self, (field.key ?: field.val)->range,
"Array index %lu is beyond the max index of %lu for type %s",
index, type->tarray.length, tyname(type));
}
Expand Down
142 changes: 83 additions & 59 deletions src/llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ extern "C" {
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
Expand Down Expand Up @@ -347,7 +348,8 @@ Type *llvm_type(IRContext *c, Ty *type, bool do_not_hide_behind_pointer = false)
case TYPE_ARRAY: {
Type *base = llvm_type(c, type->tarray.eltype);
Type *ty = ArrayType::get(base, type->tarray.length);
return ty;
if (do_not_hide_behind_pointer) return ty;
return PointerType::get(ty, 0);
}
case TYPE_SLICE: {
if (type == type_string) return c->ty.rawptr; // FIXME: Temp hack while we don't have slices
Expand All @@ -369,7 +371,7 @@ Type *llvm_type(IRContext *c, Ty *type, bool do_not_hide_behind_pointer = false)
if (type->sym) {
const char *name = type->sym->external_name ?: type->sym->name;
StructType *ty = StructType::create(c->context, elements, name);
if (type->sym) type->sym->userdata = ty;
type->sym->userdata = ty;
return ty;
}
return StructType::get(c->context, elements);
Expand Down Expand Up @@ -462,7 +464,7 @@ DIType *llvm_debug_type(IRContext *c, Ty *type, bool do_not_hide_behind_pointer
subscripts.push_back(c->dbg.builder->getOrCreateSubrange(0, eltype->tarray.length));
eltype = eltype->tarray.eltype;
}
Type *irtype = llvm_type(c, type);
Type *irtype = llvm_type(c, type, true);
u64 size = c->data_layout.getTypeSizeInBits(irtype);
u32 align = c->data_layout.getPrefTypeAlignment(irtype) * 8;
DIType *deltype = llvm_debug_type(c, type->tarray.eltype);
Expand Down Expand Up @@ -750,6 +752,7 @@ Value *create_coerce(IRContext *self, Value *val, Expr *expr, bool is_lvalue = f
goto start;
}
if (isa<FunctionType>(val_ty)) return val;

// if (val_ty->getPointerElementType() == dst_ty) return create_load(self, val); // FIXME: Do not do this....
// FIXME: IF isa<ArrayType> What do?
return self->builder.CreatePointerBitCastOrAddrSpaceCast(val, dst_ty);
Expand Down Expand Up @@ -968,29 +971,21 @@ IRValue emit_expr_compound(IRContext *self, Expr *expr) {
}
case TYPE_ARRAY: {
Type *eltype = llvm_type(self, operand.type->tarray.eltype);
ArrayType *type = (ArrayType *) llvm_type(self, operand.type);
ArrayType *type = (ArrayType *) llvm_type(self, operand.type, true);

Value *agg = Constant::getNullValue(type);
u32 index = 0;
for (i64 i = 0; i < arrlen(expr->ecompound.fields); i++) {
Value *val = emit_expr(self, expr->ecompound.fields[i].val).val;
if (expr->ecompound.fields[i].key) {
Operand op = hmget(self->package->operands, expr->ecompound.fields[i].key);
index = (u32) op.val.u;
}
agg = self->builder.CreateInsertValue(agg, val, index);
index++;
}
return irval(agg);

if (!expr->ecompound.fields)
return irval(agg);

Value *value;
if (!expr->ecompound.fields) return irval(value);
bool is_all_members_constant = true;
bool does_have_gaps = operand.type->tarray.length != arrlen(expr->ecompound.fields);

std::vector<Value *> values;
std::vector<Constant *> constants;
for (i64 i = 0; i < arrlen(expr->ecompound.fields); i++) {
Value *el = emit_expr(self, expr->ecompound.fields[i].val).val;
is_all_members_constant |= isa<Constant>(el);
is_all_members_constant &= isa<Constant>(el);
if (Constant *constant = dyn_cast<Constant>(el)) {
constants.push_back(constant);
} else {
Expand All @@ -999,45 +994,60 @@ IRValue emit_expr_compound(IRContext *self, Expr *expr) {
}
values.push_back(el);
}
u32 alignment = self->data_layout.getPrefTypeAlignment(type);
// FIXME: can't do this if we aren't in a function.
Constant *constant = ConstantArray::get(type, constants);
return irval(constant);

if (constant->isZeroValue()) {
GlobalVariable *global = new GlobalVariable(
*self->module, type, true, GlobalValue::PrivateLinkage, constant,
"compound.lit");
return irval(global);

if (!self->fn) {
return irval(llvm::ConstantArray::get(type, constants));
}

u32 alignment = self->data_layout.getPrefTypeAlignment(type);
u32 size = (u32)self->data_layout.getTypeStoreSize(type);

AllocaInst *alloca = emit_entry_alloca(self, type, "compound.lit", alignment);

if (is_all_members_constant) {
u64 target_index = 0;
for (i64 i = 0; i < arrlen(expr->ecompound.fields); i++) {
CompoundField field = expr->ecompound.fields[i];
u64 target_index = hmget(self->package->operands, field.key).val.u;
value = self->builder.CreateInsertValue(value, values[i], {(u32) target_index});
}
if (isa<Constant>(value)) {
GlobalVariable *global = new GlobalVariable(
*self->module, type, true, GlobalValue::PrivateLinkage, (Constant *)value,
"compound.lit");
self->builder.CreateMemCpy(
alloca, alignment, global, global->getAlignment(),
type->getPrimitiveSizeInBits() / 8);
value = alloca;
} else {
// create_store(self, value, alloca);
if (field.kind == FIELD_INDEX) {
target_index = (u32) hmget(self->package->operands, field.key).val.u;
}

agg = self->builder.CreateInsertValue(agg, constants[i], {(u32)target_index});
target_index++;
}

GlobalVariable *global = new GlobalVariable(
*self->module, type, true, GlobalValue::PrivateLinkage, (Constant *)agg,
"compound.lit");

self->builder.CreateMemCpy(
alloca, alignment, global, global->getAlignment(), size);
} else {
// NOTE: memset requires `zero` to be a char
llvm::Value *zero = llvm::ConstantInt::get(self->ty.i8, 0);
llvm::Value *element = self->builder.CreateInBoundsGEP(alloca, {zero, zero});

if (does_have_gaps) {
self->builder.CreateMemSet(alloca, zero, size, alignment);
}

u64 target_index = 0;

for (i64 i = 0; i < arrlen(expr->ecompound.fields); i++) {
CompoundField field = expr->ecompound.fields[i];
u64 target_index = hmget(self->package->operands, field.key).val.u;
value = self->builder.CreateInsertValue(
alloca, values[i], {0, (u32) target_index});
if (field.kind == FIELD_INDEX) {
target_index = (u32) hmget(self->package->operands, field.key).val.u;
}

llvm::Value *offset = llvm::ConstantInt::get(self->ty.i64, target_index);
self->builder.CreateInBoundsGEP(element, offset);
create_store(self, values[i], element);

target_index++;
}
}
return irval(alloca);

return irval(alloca, true);
}
case TYPE_SLICE:
default:
Expand Down Expand Up @@ -1211,7 +1221,7 @@ IRValue emit_expr_call(IRContext *self, Expr *expr) {
if (is_cvargs && last_arg) { // C ABI rules (TODO: Apply to all parameters for c calls)
if (is_integer(arg_operand.type) && arg_operand.type->size < 4) {
val = is_signed(arg_operand.type) ?
self->builder.CreateSExt(val, self->ty.i32) : self->builder.CreateZExt(val, self->ty.u32);
self->builder.CreateSExt(val, self->ty.i32) : self->builder.CreateZExt(val, self->ty.u32);
} else if (is_float(arg_operand.type) && arg_operand.type->size < 8) {
self->builder.CreateFPExt(val, self->ty.f64);
}
Expand Down Expand Up @@ -1251,6 +1261,12 @@ IRValue emit_expr_index(IRContext *self, Expr *expr) {

IRValue emit_expr_slice(IRContext *self, Expr *expr) { return {}; }

IRValue emit_expr_struct(IRContext *self, Expr *expr) {
Operand operand = hmget(self->package->operands, expr);
llvm::StructType *type = (llvm::StructType *) llvm_type(self, operand.type);
return irval((Value *)type);
}

IRValue emit_expr_func(IRContext *self, Expr *expr) {
TRACE(EMITTING);
Operand operand = hmget(self->package->operands, expr);
Expand Down Expand Up @@ -1376,7 +1392,6 @@ IRValue emit_expr_functype(IRContext *self, Expr *expr) { fatal("Unimplemented %
IRValue emit_expr_slicetype(IRContext *self, Expr *expr) { fatal("Unimplemented %s", __FUNCTION__); }
IRValue emit_expr_array(IRContext *self, Expr *expr) { fatal("Unimplemented %s", __FUNCTION__); }
IRValue emit_expr_pointer(IRContext *self, Expr *expr) { fatal("Unimplemented %s", __FUNCTION__); }
IRValue emit_expr_struct(IRContext *self, Expr *expr) { fatal("Unimplemented %s", __FUNCTION__); }
IRValue emit_expr_union(IRContext *self, Expr *expr) { fatal("Unimplemented %s", __FUNCTION__); }
IRValue emit_expr_enum(IRContext *self, Expr *expr) { fatal("Unimplemented %s", __FUNCTION__); }

Expand Down Expand Up @@ -1776,7 +1791,7 @@ void emit_decl_var_global(IRContext *self, Decl *decl) {
} else {
init = Constant::getNullValue(llvm_type(self, sym->type));
}
Type *type = llvm_type(self, sym->type);
Type *type = llvm_type(self, sym->type, true);
GlobalVariable *global = new GlobalVariable(
*self->module, type, /* IsConstant */ false, GlobalValue::ExternalLinkage,
init, sym->external_name ?: sym->name);
Expand Down Expand Up @@ -1813,7 +1828,8 @@ void emit_decl_var(IRContext *self, Decl *decl) {
Sym *sym = hmget(self->package->symbols, name);
if (sym->userdata) return; // Already emitted

Value *rhs = NULL;
Type *type = llvm_type(self, sym->type, true);
Value *rhs = Constant::getNullValue(type);
bool rhs_is_alloca = false;
if (decl->dvar.vals) {
Expr *expr = decl->dvar.vals[index];
Expand All @@ -1838,19 +1854,26 @@ void emit_decl_var(IRContext *self, Decl *decl) {
return;
}
}
Type *type = llvm_type(self, sym->type);

Value *alloca;
if (rhs_is_alloca) {
alloca = rhs;
sym->userdata = rhs;
} else {
// FIXME: Alloca can't happen at global scope instead use a global variable and add
// check that we only initialize with global variables.
alloca = emit_entry_alloca(self, type, sym->name, sym->type->align);
set_debug_pos(self, decl->range);
if (rhs) create_coerced_store(self, rhs, alloca);
if (self->fn) {
AllocaInst *alloca = emit_entry_alloca(self, type, sym->name, sym->type->align);
create_coerced_store(self, rhs, alloca);
sym->userdata = alloca;
} else {
ASSERT(isa<Constant>(rhs));
GlobalVariable *global = new GlobalVariable(
*self->module, type, false, GlobalValue::ExternalLinkage, (Constant *) rhs,
sym->external_name ?: sym->name);
global->setAlignment(sym->type->align);
global->setExternallyInitialized(false);
sym->userdata = global;
}
}
sym->userdata = alloca;
set_debug_pos(self, name->range);

if (compiler.flags.debug) declare_auto_variable(self, sym);
}
}
Expand All @@ -1865,7 +1888,7 @@ void emit_decl_val(IRContext *self, Decl *decl) {
arrpush(self->symbols, sym);
Value *value = emit_expr(self, decl->dval.val).val;
arrpop(self->symbols);
if (isa<Function>(value)) {
if (isa<Function>(value) || isa<StructType>((Type *)value)) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is bad

sym->userdata = value;
return;
}
Expand Down Expand Up @@ -2176,3 +2199,4 @@ void print(Module *module) {
puts(buf.c_str());
}
#endif

10 changes: 10 additions & 0 deletions test/array.kai
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
testArray := [3]i32 { 1, 2, 3 }

test :: fn(b: i32) -> void {
a := [..]i32 { 0, 1, 2, 3}
}

main :: fn() -> void {
a := 11
b := a + 11
}