Permalink
Browse files

First working iteration of user types.

-It is now possible to:
  -Define new types
  -Initialize user types
  • Loading branch information...
1 parent b12bedc commit 4ca7bac907adf583a67d123ceec1c9bdbefa1c5b @fimad committed Jun 9, 2012
Showing with 205 additions and 25 deletions.
  1. +6 −2 ast.sml
  2. +5 −2 exec.sml
  3. +4 −0 lexer.lex
  4. +59 −14 llvm-translate.sml
  5. +76 −6 llvm.sml
  6. +32 −1 parser.grm
  7. +13 −0 user_types.eval
  8. +10 −0 user_types_def.eval
View
@@ -5,7 +5,8 @@ sig
(* Definition of the AST
* ast is a datatype encapsulating the abstract syntax of the language e *)
datatype ast
- = Var of string
+ = Program of (LLVM.UserType list)*ast
+ | Var of string
| Dim of int*string (*statically gets the int level dimension of an array*) (*0 for non arrays*)
| Block of ast list
| Print of ast
@@ -31,6 +32,7 @@ sig
| More of ast*ast
| MoreEq of ast*ast
| Apply of ast*(ast list)
+ | Case of ast*(string*string list*ast) list
| If of ast*ast*ast
| For of ast*ast*ast*ast
| Assign of string*ast (*like a let, but assumes the variable is already defined*)
@@ -64,7 +66,8 @@ end
structure Ast :> AST =
struct
datatype ast
- = Var of string
+ = Program of (LLVM.UserType list)*ast
+ | Var of string
| Int of int
| Float of real
| Bool of int
@@ -88,6 +91,7 @@ struct
| More of ast*ast
| MoreEq of ast*ast
| Apply of ast*(ast list)
+ | Case of ast*(string*string list*ast) list
| If of ast*ast*ast
| For of ast*ast*ast*ast
| Assign of string*ast (*like a let, but assumes the variable is already defined*)
View
@@ -57,8 +57,11 @@ val optimizeLevel =
val shouldDot = shouldArg "-dot"
val _ = LLVM_Translate.compile result handle (LLVM_Translate.TranslationError what) => let
- val _ = (print (concat ["Translation Error: ",what,"\n"]), OS.Process.exit OS.Process.failure)
- in () end
+ val _ = (print (concat ["Translation Error: ",what,"\n"]), OS.Process.exit OS.Process.failure)
+ in () end
+ handle (LLVM.LLVMError what) => let
+ val _ = (print (concat ["Translation Error: ",what,"\n"]), OS.Process.exit OS.Process.failure)
+ in () end
val program = LLVM_Translate.getProgram ()
fun optimizeMethod (name,ty,args,code) = let
val bbGraph = SSA.completeSSA (BB.createBBGraph code)
View
@@ -26,6 +26,10 @@ ws = [\ \t];
%%
\n => (pos := (!pos) + 1; lex());
{ws}+ => (lex());
+"type" => (Tokens.TYPE_WORD(!pos,!pos));
+"of" => (Tokens.OF(!pos,!pos));
+"|" => (Tokens.BAR(!pos,!pos));
+"case" => (Tokens.CASE(!pos,!pos));
"int" => (Tokens.TYPE_INT(!pos,!pos));
"float" => (Tokens.TYPE_FLOAT(!pos,!pos));
"bool" => (Tokens.TYPE_BOOL(!pos,!pos));
View
@@ -408,19 +408,63 @@ struct
in
(l,ty,code1@code2@alias_code@cast_code@[LLVM.Div (l,ty,LLVM.Variable var1,LLVM.Variable var2)])
end
- | translate (Ast.Apply ((Ast.Var v),exps)) scope fscope = let
- val argsAndCodes = map (evalArg scope fscope) exps
- val code = (foldr (op @) [] (map (#1) argsAndCodes))
- val args = (map (fn (_,r,t) => (r,t)) argsAndCodes)
- (*change arrays so that they are passed as pointers*)
- val args = map (fn (x,ty) => case ty of
- LLVM.array _ => (x,LLVM.ptr ty)
- | _ => (x,ty)) args
- val l = makenextvar ()
- val ty = typeForFunc v fscope
- in
- (l,ty,code@[LLVM.Call (l,ty,v,args)])
- end
+ | translate (Ast.Apply ((Ast.Var v),exps)) scope fscope =
+ if LLVM.isUserTypeForm v then
+ let
+ val form = v
+
+ val var_malloc = makenextvar ()
+ val var_form_type = makenextvar ()
+ val var_T_type = makenextvar ()
+ val var_form_ptr = makenextvar ()
+
+ val (user_type as (LLVM.usertype name)) = LLVM.getTypeForForm form
+ val form_type = LLVM.usertype_form (name,form)
+ val types = LLVM.getFormTypes form
+ (* error checking, wut?*)
+ val _ = if length types <> length exps then raise (TranslationError (concat["Incorrect number of expressions for '",form,"'"])) else ()
+
+ fun initialize i [] = []
+ | initialize i ((t,exp)::xs) = let
+ val (code,arg,ety) = evalArg scope fscope exp
+ val (alias_code,[var]) = ensureVars [arg]
+ val (cast_code,ty) = resolveType' [] t [(var,ety)]
+ val var_val_ptr = makenextvar ()
+ in
+ code @alias_code @cast_code
+ (* store the value in the struct *)
+ @[ LLVM.GetElementPtr (var_val_ptr,LLVM.ptr form_type,LLVM.Variable var_form_type,LLVM.Int i)
+ , LLVM.Store (t,LLVM.Variable var,LLVM.Variable var_val_ptr)
+ ]@(initialize (i+1) xs)
+ end
+
+ val code = [
+ LLVM.Call (var_malloc, LLVM.ptr LLVM.i8, "malloc", [(LLVM.Int (LLVM.sizeOfType user_type),LLVM.i32)])
+ , LLVM.Bitcast (var_T_type, LLVM.ptr LLVM.i8, LLVM.Variable var_malloc, LLVM.ptr LLVM.usertype_parent)
+ (* set the form tracker *)
+ , LLVM.GetElementPtr (var_form_ptr,LLVM.ptr LLVM.usertype_parent,LLVM.Variable var_T_type,LLVM.Int 0)
+ , LLVM.Store (LLVM.i32,LLVM.Int (LLVM.getFormIndex form),LLVM.Variable var_form_ptr)
+ (* grab a pointer to the data portion *)
+ , LLVM.GetElementPtr (var_form_type,LLVM.ptr LLVM.usertype_parent,LLVM.Variable var_T_type,LLVM.Int 1)
+ , LLVM.Bitcast (var_form_type, LLVM.ptr LLVM.i8, LLVM.Variable var_malloc, LLVM.ptr form_type)
+ ]@(initialize 0 (ListPair.zip (types,exps)))
+ in
+ (var_malloc,user_type,code)
+ end
+ else
+ let
+ val argsAndCodes = map (evalArg scope fscope) exps
+ val code = (foldr (op @) [] (map (#1) argsAndCodes))
+ val args = (map (fn (_,r,t) => (r,t)) argsAndCodes)
+ (*change arrays so that they are passed as pointers*)
+ val args = map (fn (x,ty) => case ty of
+ LLVM.array _ => (x,LLVM.ptr ty)
+ | _ => (x,ty)) args
+ val l = makenextvar ()
+ val ty = typeForFunc v fscope
+ in
+ (l,ty,code@[LLVM.Call (l,ty,v,args)])
+ end
| translate (Ast.Apply _) scope fscope = raise (TranslationError "Can only apply on variables")
| translate (Ast.For (init_exp,cond_exp,step_exp,doexp)) scope fscope = let
val cnd_label = makenextlabel ()
@@ -556,7 +600,8 @@ struct
| getFunScope (Ast.Let (_,_,exp)) = getFunScope exp
| getFunScope _ = []
- fun compile ast = let
+ fun compile (Ast.Program (types,ast)) = let
+ val _ = map LLVM.addUserType types (*add the user types to the scope*)
val funScope = getFunScope ast
val (mainBody,vres,vty) = evalArg [] funScope ast
val res = case vres of
View
@@ -2,7 +2,18 @@
structure LLVM =
struct
- datatype Type = notype | i8 | i32 | i1 | float | array of int*Type | ptr of Type
+ datatype Type
+ = notype
+ | usertype of string
+ | i1
+ | i8
+ | i32
+ | float
+ | array of int*Type
+ | ptr of Type
+ | usertype_form of string*string (*used in intermediate operations when dealing with forms*)
+ | usertype_parent
+ type UserType = string*((string*(Type list)) list)(*type name, forms*) (*forms = (form name, types in form)*)
type Result = string
datatype Arg =
Int of int
@@ -47,21 +58,77 @@ struct
(* An entire program is just a collection of Methods *)
type Program = Method list
- fun sizeOfType i32 = 32
- | sizeOfType i8 = 8
- | sizeOfType (ptr _) = 32
- | sizeOfType float = 64
+ fun isPrimitive i32 = true
+ | isPrimitive i8 = true
+ | isPrimitive i1 = true
+ | isPrimitive float = true
+ | isPrimitive (ptr _) = true (*pointers are primitives?*)
+ | isPrimitive _ = false
+
+ (*hacks for user types*)
+ exception LLVMError of string;
+ val userTypeScope = ref []
+ fun addUserType ut = userTypeScope := ut::(!userTypeScope)
+ fun isUserType name = List.exists (fn (name',_) => name = name') (!userTypeScope)
+ fun isUserTypeForm form = List.exists (fn (_,forms) => List.exists (fn (form',_) => form = form') forms) (!userTypeScope)
+ fun getUserType name = (case (List.filter (fn (name',_) => name = name') (!userTypeScope)) of
+ [] => raise (LLVMError (concat ["User type '",name,"' does not exist!"]))
+ | (x::_) => x)
+ fun getTypeForForm form = (case (List.filter (fn (_,forms) => List.exists (fn (f,_) => f=form) forms) (!userTypeScope)) of
+ [] => raise (LLVMError (concat ["No user type has the form '",form,"'!"]))
+ | ((name,_)::_) => usertype name)
+ fun getUserTypeForForm form = let
+ val (usertype name) = getTypeForForm form
+ in getUserType name end
+ fun getFormTypes form = (case (List.filter (fn (f,_) => f=form) (#2 ((getUserTypeForForm) form))) of
+ [] => raise (LLVMError (concat ["No user type has the form '",form,"'!"]))
+ | ((_,types)::_) => types)
+ fun getFormIndex form = let
+ fun getIndex i ((f,_)::fs) = if form = f then i else getIndex (i+1) fs (*should throw error before it reaches end of list*)
+ val (_,forms) = (getUserTypeForForm) form
+ in getIndex 0 forms end
+ fun getIndexForForm name i = let
+ val (name,forms) = getUserType name
+ val _ = if length forms < i then raise (LLVMError "Index is TOO HIGH! should never happen though...") else ()
+ in
+ List.nth (forms,i)
+ end
+
+ fun sizeOfType i32 = 4
+ | sizeOfType i8 = 1
+ | sizeOfType (ptr _) = 4
+ | sizeOfType float = 8
| sizeOfType (array (size,ty)) = size*(sizeOfType ty)
+ | sizeOfType (usertype name) = let
+ val (ut,forms) = getUserType name
+ val sizes = (map (fn (_,ts) => foldr (op +) 0 (map sizeOfType (map (fn t => if isPrimitive t then t else ptr t) ts))) forms)
+ val maxSize = foldr Int.max 0 sizes
+ in
+ maxSize + 4 (* the 4 is for the form int *)
+ end
fun printType i32 = "i32"
| printType i8 = "i8"
| printType i1 = "i1"
| printType float = "double"
+ | printType usertype_parent = "%T"
| printType (ptr ty) = concat [printType ty,"*"]
(*| printType (array (size,ty)) = printType (ptr ty)*)
| printType (array (size,ty)) = concat ["[",(Int.toString size)," x ",(printType ty),"]"]
+ | printType (usertype _) = printType (ptr i8)
+ | printType (usertype_form (name,form)) = concat ["%T.",name,".",form]
| printType notype = ""
+ fun printUserType (name,forms) = let
+ fun printTypes [] = []
+ | printTypes [t] = [printType t]
+ | printTypes (t::ts) = (printType t)::","::(printTypes ts)
+ fun printForm name (form,types) =
+ concat (["%T.",name,".",form," = type {"]@(printTypes types)@["}\n"])
+ in
+ concat (map (printForm name) forms)
+ end
+
fun arrayType (array (size,array sub)) = arrayType (array sub)
| arrayType (array (size,ty)) = SOME ty
| arrayType _ = NONE
@@ -225,7 +292,10 @@ struct
end
fun printProgram program = concat [
- "@.print_int_str = private constant [4 x i8] c\"%d\\0A\\00\", align 1\n"
+ "%T = type { i32, i8 }\n"
+ , concat (map printUserType (!userTypeScope))
+ , "\n"
+ , "@.print_int_str = private constant [4 x i8] c\"%d\\0A\\00\", align 1\n"
, "@.print_float_str = private constant [4 x i8] c\"%f\\0A\\00\", align 1\n\n"
, (foldl (fn (a,b) => concat [a,"\n",b]) "" (map printMethod program))
, "declare i32 @printf(i8*, ...)\n"
View
@@ -31,6 +31,7 @@ structure Ast = Ast
| EMPTY_ARRAY | START_ARRAY | END_ARRAY | POUND
| PRINT
| DIM
+ | OF | BAR | TYPE_WORD | CASE | NEW
| TYPE_BOOL | TYPE_FLOAT | TYPE_INT | COLON
| EOF
%nonterm LETDEF of Ast.ast
@@ -45,10 +46,32 @@ structure Ast = Ast
| DO_BLOCK of Ast.ast
| PRIMITIVE_TYPE of LLVM.Type
| TYPE of LLVM.Type
+ | TYPES of LLVM.UserType list
+ | TYPE_DEF of LLVM.UserType
+ | TYPE_FORMS of (string*(LLVM.Type list)) list
+ | TYPE_FORM of string*(LLVM.Type list)
+ | TYPE_LIST of LLVM.Type list
+ | ID_LIST of string list
+ | CASE_THENS of (string*string list*Ast.ast) list
+ | CASE_THEN of string*string list*Ast.ast
%%
-START : STATEMENT EOF (STATEMENT)
+START : TYPES STATEMENT EOF (Ast.Program (TYPES,STATEMENT))
+ | STATEMENT EOF (Ast.Program ([],STATEMENT))
+
+TYPES : TYPE_DEF ([TYPE_DEF])
+ | TYPE_DEF TYPES (TYPE_DEF::TYPES)
+
+TYPE_DEF : TYPE_WORD ID EQ TYPE_FORMS ((ID,TYPE_FORMS))
+
+TYPE_FORMS : TYPE_FORM ([TYPE_FORM])
+ | TYPE_FORM BAR TYPE_FORMS (TYPE_FORM::TYPE_FORMS)
+
+TYPE_FORM : ID OF TYPE_LIST ((ID,TYPE_LIST))
+
+TYPE_LIST : TYPE ([TYPE])
+ | TYPE MULT TYPE_LIST (TYPE::TYPE_LIST)
STATEMENT : EXP (EXP)
| DO_BLOCK (DO_BLOCK)
@@ -78,6 +101,11 @@ EXP : TERM (TERM)
| ID ARRAY_INDEX ASSIGN EXP (Ast.AssignArray(ID,ARRAY_INDEX,EXP))
| IF EXP THEN STATEMENT ELSE STATEMENT (Ast.If(EXP,STATEMENT1,STATEMENT2))
| FOR LPAREN EXP SEMICOLON EXP SEMICOLON EXP RPAREN STATEMENT (Ast.For(EXP1,EXP2,EXP3,STATEMENT1))
+ | CASE EXP OF CASE_THENS (Ast.Case(EXP,CASE_THENS))
+
+CASE_THENS : CASE_THEN ([CASE_THEN])
+ | CASE_THEN CASE_THENS (CASE_THEN::CASE_THENS)
+CASE_THEN : ID LPAREN ID_LIST RPAREN THEN EXP (ID,ID_LIST,EXP)
TERM : ID (Ast.Var(ID))
| ID ARRAY_INDEX (Ast.ArrayIndex(ID,ARRAY_INDEX))
@@ -101,6 +129,8 @@ EXPLIST : EXP ([EXP])
| EXP COMMA EXPLIST (EXP::EXPLIST)
| EXP SEMICOLON EXPLIST (EXP::EXPLIST)
+ID_LIST : ID ([ID])
+ | ID COMMA ID_LIST (ID::ID_LIST)
LETDEF : LET ID EQ STATEMENT IN STATEMENT (Ast.Let(ID,STATEMENT1,STATEMENT2))
| FUN ID LPAREN PARAMS RPAREN COLON TYPE EQ STATEMENT IN STATEMENT (Ast.Fun(ID,PARAMS,TYPE,STATEMENT1,STATEMENT2))
@@ -110,6 +140,7 @@ PARAMS : ID COLON TYPE ([(ID,TYPE)])
TYPE : PRIMITIVE_TYPE (PRIMITIVE_TYPE)
| START_ARRAY PRIMITIVE_TYPE END_ARRAY ARRAY_DIM (foldr (fn (d,t) => LLVM.array (d,t)) PRIMITIVE_TYPE ARRAY_DIM)
+ | ID (LLVM.usertype ID)
PRIMITIVE_TYPE : TYPE_INT (LLVM.i32)
| TYPE_FLOAT (LLVM.float)
View
@@ -0,0 +1,13 @@
+type myType
+ = IntPair of int*int
+ | FloatPair of float*float
+type weird
+ = weird of myType*myType
+
+let t = IntPair (1,4) in
+let f = FloatPair (5,6) in
+case t of
+ IntPair (a,b) then
+ print a
+ FloatPair (a,b) then
+ print b
View
@@ -0,0 +1,10 @@
+type myType
+ = IntPair of int*int
+ | FloatPair of float*float
+type weird
+ = weird of myType*myType
+
+do
+ weird (IntPair(1,2),FloatPair(0.47,0.42)) ;
+ pass
+end

0 comments on commit 4ca7bac

Please sign in to comment.