Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
branch: master
Fetching contributors…

Cannot retrieve contributors at this time

file 193 lines (177 sloc) 7.841 kb
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
(* Register the host's native target at module load time, before [compile]
   creates a JIT that emits native code. Its result is discarded. *)
let () = ignore (Llvm_executionengine.initialize_native_target ())

module StringMap = Map.Make (String)

(* Maps variable names in the DSL to the LLVM values (registers) that hold
   them. *)
type name_env = Llvm.llvalue StringMap.t

(* A compiler this simple only needs one LLVM context. *)
let context : Llvm.llcontext = Llvm.global_context ()

(* The DSL's only numeric type, created once and shared.
   OCaml's [float] is an IEEE double, and the generated functions are meant
   to take and return float64 values, so this is LLVM's 64-bit [double].
   Do not use the 32-bit [float] type here:
   - it silently rounds every constant and every input;
   - a [Sum] counter stops advancing once it reaches 2^24 (2^24 + 1.0
     rounds back to 2^24), so a loop with a larger bound never ends. *)
let float_t : Llvm.lltype = Llvm.double_type context

(* Frequently used constants, created once. *)
let zero : Llvm.llvalue = Llvm.const_float float_t 0.0
let one : Llvm.llvalue = Llvm.const_float float_t 1.0

(* Code-generation state for the one function currently being compiled.
   It is created by [init] and threaded through [compile_exp]. *)
type llvm_state = {
  (* Instruction builder. Its insertion point is where the next
     instruction is emitted, and [compile_exp] moves it as it creates
     blocks. *)
  builder : Llvm.llbuilder;
  (* The module that [llvm_fn] was declared in. *)
  llvm_module : Llvm.llmodule;
  (* The function being generated. New basic blocks are appended to it. *)
  llvm_fn : Llvm.llvalue;
}

(* Compile a DSL expression to an LLVM value of type [float_t].
   Instructions are emitted at the current position of [state.builder].

   [names] maps DSL variable names to the LLVM values that hold them.
   Operands of binary operators are compiled left to right.
   @raise Failure if the expression references a variable that is not
   bound in [names]. *)
let rec compile_exp state names = function
  | Dsl.Num f -> Llvm.const_float float_t f
  | Dsl.Var x ->
    (* A single lookup, rather than [mem] followed by [find]. *)
    (try StringMap.find x names
     with Not_found -> failwith ("Undefined variable " ^ x))
  | Dsl.Add (x, y) ->
    compile_binop state names Llvm.build_fadd "add_result" x y
  | Dsl.Sub (x, y) ->
    compile_binop state names Llvm.build_fsub "sub_result" x y
  | Dsl.Div (x, y) ->
    compile_binop state names Llvm.build_fdiv "div_result" x y
  | Dsl.Mult (x, y) ->
    compile_binop state names Llvm.build_fmul "mult_result" x y
  | Dsl.Sum (loop_var_name, start, stop, body) ->
    compile_sum state names loop_var_name start stop body

(* Compile [x] and then [y], and combine them with the instruction builder
   [build] (e.g. [Llvm.build_fadd]). The new instruction is named [label]. *)
and compile_binop state names build label x y =
  let x' = compile_exp state names x in
  let y' = compile_exp state names y in
  build x' y' label state.builder

(* Compile [Sum (loop_var_name, start, stop, body)]. The result is the sum
   of [body] with [loop_var_name] bound to start, start + 1, ..., for as long
   as the loop variable is <= [stop].

   Control-flow graph produced:
     current block -> loop_header
     loop_header:   phi [result] and phi [loop var]; branch to after_loop
                    if the loop var > [stop] (or either is NaN),
                    otherwise to loop_body
     loop_body ...: result += body; loop var += 1; branch to loop_header
   On return, the builder is positioned at the end of after_loop. The
   returned value is the [result] phi. *)
and compile_sum state names loop_var_name start stop body =
  (* The bounds are evaluated once, before entering the loop. *)
  let start = compile_exp state names start in
  let stop = compile_exp state names stop in
  (* The block that falls into the loop. It supplies the phis' initial
     values. *)
  let entry_block = Llvm.insertion_block state.builder in
  let loop_header = Llvm.append_block context "loop_header" state.llvm_fn in
  ignore (Llvm.build_br loop_header state.builder);
  Llvm.position_at_end loop_header state.builder;
  (* Both phis start with only the edge from [entry_block]. The back edge
     is added once the body has been compiled. *)
  let result = Llvm.build_phi [ (zero, entry_block) ] "result" state.builder in
  let loop_var =
    Llvm.build_phi [ (start, entry_block) ] loop_var_name state.builder
  in
  (* An unordered greater-than is true when either operand is NaN, so a NaN
     bound leaves the loop. An ordered [Ogt] would be false on NaN and the
     loop would never terminate. *)
  let cond =
    Llvm.build_fcmp Llvm.Fcmp.Ugt loop_var stop "loop_cond" state.builder
  in
  let loop_body = Llvm.append_block context "loop_body" state.llvm_fn in
  let after_loop = Llvm.append_block context "after_loop" state.llvm_fn in
  (* Exit once the loop variable has passed [stop]; otherwise run the
     body. *)
  ignore (Llvm.build_cond_br cond after_loop loop_body state.builder);
  Llvm.position_at_end loop_body state.builder;
  let names' = StringMap.add loop_var_name loop_var names in
  let curr_val = compile_exp state names' body in
  let next_result =
    Llvm.build_fadd result curr_val "next_result" state.builder
  in
  (* Compiling [body] may have moved the builder: a nested Sum ends in its
     own after_loop block. The back edge must therefore come from the
     current block, not necessarily from [loop_body]. *)
  let latch = Llvm.insertion_block state.builder in
  Llvm.add_incoming (next_result, latch) result;
  let next_loop_var =
    Llvm.build_fadd loop_var one "next_loop_var" state.builder
  in
  Llvm.add_incoming (next_loop_var, latch) loop_var;
  ignore (Llvm.build_br loop_header state.builder);
  (* Code that follows the Sum is emitted after the loop. *)
  Llvm.position_at_end after_loop state.builder;
  result

module LLE = Llvm_executionengine.ExecutionEngine

(* Run a fixed pipeline of scalar optimizations over [llvm_fn], in place.

   [llvm_module] must be the module that contains [llvm_fn].
   [execution_engine] supplies the target data layout the passes optimize
   for. The pass manager is disposed even if a pass raises. *)
let optimize llvm_fn llvm_module execution_engine =
  let pm = Llvm.PassManager.create_function llvm_module in
  (* Register how the target lays out data structures before adding any
     passes. *)
  Llvm_target.TargetData.add (LLE.target_data execution_engine) pm;

  (* THROW EVERY OPTIMIZATION UNDER THE SUN AT THE CODE *)
  List.iter (fun add_pass -> add_pass pm) Llvm_scalar_opts.([
    add_memory_to_register_promotion ;
    add_sccp ;
    add_aggressive_dce ;
    add_instruction_combination ;
    add_cfg_simplification ;
    add_ind_var_simplification ;
    add_dead_store_elimination ;
    add_gvn ;
    add_licm ;
  ]);

  let run_passes () =
    (* A function pass manager must be initialized before it runs on any
       function, and finalized afterwards (the doInitialization and
       doFinalization steps). The original code skipped [initialize]. *)
    ignore (Llvm.PassManager.initialize pm);
    ignore (Llvm.PassManager.run_function llvm_fn pm);
    ignore (Llvm.PassManager.finalize pm)
  in
  (* Do not leak the pass manager if a pass raises. *)
  (try run_passes () with e -> Llvm.PassManager.dispose pm; raise e);
  Llvm.PassManager.dispose pm

(* Set up code generation for the DSL function [f].

   Creates a fresh module "M" and declares in it a function named
   [f.Dsl.name]. The function takes one [float_t] parameter per entry of
   [f.Dsl.inputs] and returns [float_t]. The function is given an "entry"
   block, and the returned builder is positioned at its end. *)
let init (f : Dsl.fn) : llvm_state =
  (* Modules are barely used here, but every function must live in one. *)
  let llvm_module = Llvm.create_module context "M" in
  let param_types = Array.make (List.length f.Dsl.inputs) float_t in
  let llvm_fn =
    Llvm.declare_function f.Dsl.name
      (Llvm.function_type float_t param_types)
      llvm_module
  in
  let builder = Llvm.builder context in
  (* Appending a block gives the declared function a body to fill in. *)
  Llvm.position_at_end (Llvm.append_block context "entry" llvm_fn) builder;
  { builder; llvm_module; llvm_fn }

(* A DSL function after [compile]: the generated LLVM function and the JIT
   engine used to call it via [run]. *)
type compiled_fn = {
  (* The generated LLVM function. *)
  fn_val : Llvm.llvalue;
  (* The JIT execution engine created over the function's module. *)
  execution_engine : LLE.t;
}

(* Compile the DSL function [f] and prepare it for execution.

   The steps are:
   - generate LLVM IR for [f.Dsl.body] (returned from the function);
   - print the unoptimized IR;
   - validate it;
   - create a JIT at optimization level 3;
   - run [optimize] on the function.
   @raise Failure if the body references an unbound variable (raised by
   [compile_exp]). *)
let compile (f : Dsl.fn) : compiled_fn =
  let state = init f in
  (* Pair each DSL input name with the LLVM parameter in the same
     position. *)
  let params = Array.to_list (Llvm.params state.llvm_fn) in
  let bind env name param = StringMap.add name param env in
  let names = List.fold_left2 bind StringMap.empty f.Dsl.inputs params in
  (* The body's value is what the function returns. *)
  let body_val = compile_exp state names f.Dsl.body in
  ignore (Llvm.build_ret body_val state.builder);
  print_endline "Compiled LLVM Code:";
  Llvm.dump_value state.llvm_fn;
  print_endline "Validating function:";
  Llvm_analysis.assert_valid_function state.llvm_fn;
  (* Optimizing JIT (opt-level = 3), followed by our own pass pipeline. *)
  let execution_engine = LLE.create_jit state.llvm_module 3 in
  optimize state.llvm_fn state.llvm_module execution_engine;
  { fn_val = state.llvm_fn; execution_engine }

module GV = Llvm_executionengine.GenericValue

(* Call the compiled function [f] through its JIT and return its result as
   an OCaml float. [inputs] are passed positionally, one per parameter,
   each wrapped as a [float_t] generic value. *)
let run (f : compiled_fn) (inputs : float list) : float =
  let args = Array.map (GV.of_float float_t) (Array.of_list inputs) in
  let ret = LLE.run_function f.fn_val args f.execution_engine in
  GV.as_float float_t ret
Something went wrong with that request. Please try again.