-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add interface file for tensor library
- Loading branch information
1 parent
7c67d9c
commit dfa1328
Showing
3 changed files
with
70 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
(* Tensor Interface *) | ||
|
||
(* Imports *) | ||
open Values | ||
|
||
(* Type Definitions *) | ||
|
||
type gradient = | ||
| GRAD of values | ||
| NONE | ||
|
||
type tensor = { | ||
tid : int; | ||
mutable vals : values; | ||
mutable grad : gradient; | ||
mutable acc_grad : gradient; | ||
op : operator; | ||
} and operator = | ||
| ADD of tensor * tensor | ||
| SUB of tensor * tensor | ||
| MUL of tensor * tensor | ||
| DIV of tensor * tensor | ||
| MATMUL of tensor * tensor | ||
| NEG of tensor | ||
| EXP of tensor | ||
| LOG of tensor | ||
| SQRT of tensor | ||
| POW2 of tensor | ||
| SUM of tensor | ||
| RELU of tensor | ||
| SIGMOID of tensor | ||
| CREATE | ||
|
||
(* Creating *) | ||
val create : dimensions -> float -> tensor | ||
val from_array : standard_array -> tensor | ||
val ones : dimensions -> tensor | ||
val zeros : dimensions -> tensor | ||
val random : dimensions -> tensor | ||
|
||
(* Utilities *) | ||
val printVals : tensor -> unit | ||
val printGrad : tensor -> unit | ||
val dim : tensor -> dimensions | ||
val get_grad : tensor -> values | ||
val get_acc_grad : tensor -> values | ||
|
||
(* Unary Operations *) | ||
val neg : tensor -> tensor | ||
val exp : tensor -> tensor | ||
val log : tensor -> tensor | ||
val sqrt : tensor -> tensor | ||
val pow2 : tensor -> tensor | ||
val sum : tensor -> tensor | ||
val relu : tensor -> tensor | ||
val sigmoid : tensor -> tensor | ||
|
||
(* Binary Operations *) | ||
val add : tensor -> tensor -> tensor | ||
val sub : tensor -> tensor -> tensor | ||
val mul : tensor -> tensor -> tensor | ||
val div : tensor -> tensor -> tensor | ||
val matmul : tensor -> tensor -> tensor | ||
|
||
|
||
(* Graph Operations *) | ||
val reverse_topological_sort : tensor -> (tensor -> unit) -> unit | ||
val visualize_computation_graph : ?file_name:string -> tensor -> unit | ||
val backward : tensor -> unit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
(* Values Interface *) | ||
|
||
(* Imports *) | ||
open Bigarray | ||
|
||
(* Type Definitions *) | ||
|