Skip to content

Commit

Permalink
Add interface file for tensor library
Browse files Browse the repository at this point in the history
  • Loading branch information
jewelltaylor committed Apr 20, 2024
1 parent 7c67d9c commit dfa1328
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
69 changes: 69 additions & 0 deletions lib/tensor/tensor.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
(* Tensor Interface *)

(* Imports *)
open Values

(* Type Definitions *)

type gradient =
| GRAD of values
| NONE

type tensor = {
tid : int;
mutable vals : values;
mutable grad : gradient;
mutable acc_grad : gradient;
op : operator;
} and operator =
| ADD of tensor * tensor
| SUB of tensor * tensor
| MUL of tensor * tensor
| DIV of tensor * tensor
| MATMUL of tensor * tensor
| NEG of tensor
| EXP of tensor
| LOG of tensor
| SQRT of tensor
| POW2 of tensor
| SUM of tensor
| RELU of tensor
| SIGMOID of tensor
| CREATE

(* Creating *)
val create : dimensions -> float -> tensor
val from_array : standard_array -> tensor
val ones : dimensions -> tensor
val zeros : dimensions -> tensor
val random : dimensions -> tensor

(* Utilities *)
val printVals : tensor -> unit
val printGrad : tensor -> unit
val dim : tensor -> dimensions
val get_grad : tensor -> values
val get_acc_grad : tensor -> values

(* Unary Operations *)
val neg : tensor -> tensor
val exp : tensor -> tensor
val log : tensor -> tensor
val sqrt : tensor -> tensor
val pow2 : tensor -> tensor
val sum : tensor -> tensor
val relu : tensor -> tensor
val sigmoid : tensor -> tensor

(* Binary Operations *)
val add : tensor -> tensor -> tensor
val sub : tensor -> tensor -> tensor
val mul : tensor -> tensor -> tensor
val div : tensor -> tensor -> tensor
val matmul : tensor -> tensor -> tensor


(* Graph Operations *)
val reverse_topological_sort : tensor -> (tensor -> unit) -> unit
val visualize_computation_graph : ?file_name:string -> tensor -> unit
val backward : tensor -> unit
2 changes: 0 additions & 2 deletions lib/values/types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,5 @@ exception InvalidArgumentException of string
exception SizeException of string

type standard_array = float array array

type values = (float, float32_elt, c_layout) Array2.t

type dimensions = int * int
1 change: 1 addition & 0 deletions lib/values/values.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
(* Values Interface *)

(* Imports *)
open Bigarray

(* Type Definitions *)
Expand Down

0 comments on commit dfa1328

Please sign in to comment.