Skip to content

Commit

Permalink
Slightly improve error message for Values library
Browse files Browse the repository at this point in the history
  • Loading branch information
jewelltaylor committed Apr 20, 2024
1 parent 9334088 commit b9d5ce1
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 22 deletions.
4 changes: 2 additions & 2 deletions lib/values/types.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Bigarray

exception TypeException
exception SizeException
exception InvalidArgumentException of string
exception SizeException of string

type standard_array = float array array

Expand Down
57 changes: 39 additions & 18 deletions lib/values/valuesOps.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,34 @@ let reciprocal a = unary_op vvrecf a
let abs a = unary_op vvfabsf a

let vforce_elementwise_binary_op f a b =
if (Array2.dim1 a, Array2.dim2 a) <> (Array2.dim1 b, Array2.dim2 b) then raise SizeException;
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let y = Array2.create Float32 c_layout dim1 dim2 in
f (bigarray_start array2 y) (bigarray_start array2 a) (bigarray_start array2 b) (allocate int (dim1 * dim2));
y
if (Array2.dim1 a, Array2.dim2 a) <> (Array2.dim1 b, Array2.dim2 b) then begin
let adim1, adim2, bdim1, bdim2 = (Array2.dim1 a), (Array2.dim2 a), (Array2.dim1 b), (Array2.dim2 b) in
let error_msg = Printf.sprintf "a dim (%d, %d) <> b dim (%d, %d) \n" adim1 adim2 bdim1 bdim2 in
raise (SizeException error_msg);
end
else begin
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let y = Array2.create Float32 c_layout dim1 dim2 in
f (bigarray_start array2 y) (bigarray_start array2 a) (bigarray_start array2 b) (allocate int (dim1 * dim2));
y
end

let pow2 a =
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let x = Array2.init Float32 c_layout dim1 dim2 (fun _ _ -> 2.0) in
vforce_elementwise_binary_op vvpowf x a

let dot a b =
if (Array2.dim1 a, Array2.dim2 a) <> (Array2.dim1 b, Array2.dim2 b) then raise SizeException;
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let r = cblas_sdot (dim1 * dim2) (bigarray_start array2 a) 1 (bigarray_start array2 b) 1 in
Array2.init Float32 c_layout 1 1 (fun _ _ -> r)
if (Array2.dim1 a, Array2.dim2 a) <> (Array2.dim1 b, Array2.dim2 b) then begin
let adim1, adim2, bdim1, bdim2 = (Array2.dim1 a), (Array2.dim2 a), (Array2.dim1 b), (Array2.dim2 b) in
let error_msg = Printf.sprintf "a dim (%d, %d) <> b dim (%d, %d) \n" adim1 adim2 bdim1 bdim2 in
raise (SizeException error_msg);
end
else begin
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let r = cblas_sdot (dim1 * dim2) (bigarray_start array2 a) 1 (bigarray_start array2 b) 1 in
Array2.init Float32 c_layout 1 1 (fun _ _ -> r)
end

let sum a =
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
Expand All @@ -48,39 +60,48 @@ let neg a =
mul a x

let add a b =
if (Array2.dim1 a, Array2.dim2 a) <> (Array2.dim1 b, Array2.dim2 b) then raise SizeException;
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let y = Array2.create Float32 c_layout dim1 dim2 in Array2.blit b y;
cblas_saxpy (dim1 * dim2) 1.0 (bigarray_start array2 a) 1 (bigarray_start array2 y) 1;
y
if (Array2.dim1 a, Array2.dim2 a) <> (Array2.dim1 b, Array2.dim2 b) then begin
let adim1, adim2, bdim1, bdim2 = (Array2.dim1 a), (Array2.dim2 a), (Array2.dim1 b), (Array2.dim2 b) in
let error_msg = Printf.sprintf "a dim (%d, %d) <> b dim (%d, %d) \n" adim1 adim2 bdim1 bdim2 in
raise (SizeException error_msg);
end
else begin
let (dim1, dim2) = (Array2.dim1 a, Array2.dim2 a) in
let y = Array2.create Float32 c_layout dim1 dim2 in Array2.blit b y;
cblas_saxpy (dim1 * dim2) 1.0 (bigarray_start array2 a) 1 (bigarray_start array2 y) 1;
y
end

let matmul ?(trans_a = 111) ?(trans_b = 111) a b =
match (trans_a, trans_b) with
| (111, 111) -> begin
if (Array2.dim2 a <> Array2.dim1 b) then raise SizeException;
let error_msg = Printf.sprintf "a dim2 %d <> b dim1 %d" (Array2.dim2 a) (Array2.dim1 b) in
if (Array2.dim2 a <> Array2.dim1 b) then raise (SizeException error_msg);
let adim1, adim2, bdim2 = Array2.dim1 a, Array2.dim2 a, Array2.dim2 b in
let y = Array2.init Float32 c_layout adim1 bdim2 (fun _ _ -> 0.0) in
cblas_sgemm 101 trans_a trans_b adim1 bdim2 adim2 1.0 (bigarray_start array2 a)
adim2 (bigarray_start array2 b) bdim2 0.0 (bigarray_start array2 y) bdim2;
y
end
| (111, 112) -> begin
if (Array2.dim2 a <> Array2.dim2 b) then raise SizeException;
let error_msg = Printf.sprintf "a dim2 %d <> b dim2 %d" (Array2.dim2 a) (Array2.dim2 b) in
if (Array2.dim2 a <> Array2.dim2 b) then raise (SizeException error_msg);
let adim1, adim2, bdim1, bdim2 = Array2.dim1 a, Array2.dim2 a, Array2.dim1 b, Array2.dim2 b in
let y = Array2.init Float32 c_layout adim1 bdim1 (fun _ _ -> 0.0) in
cblas_sgemm 101 trans_a trans_b adim1 bdim1 adim2 1.0 (bigarray_start array2 a)
adim2 (bigarray_start array2 b) bdim2 0.0 (bigarray_start array2 y) bdim1;
y
end
| (112, 111) -> begin
if (Array2.dim1 a <> Array2.dim1 b) then raise SizeException;
let error_msg = Printf.sprintf "a dim1 %d <> b dim1 %d" (Array2.dim1 a) (Array2.dim1 b) in
if (Array2.dim1 a <> Array2.dim1 b) then raise (SizeException error_msg);
let adim1, adim2, _, bdim2 = Array2.dim1 a, Array2.dim2 a, Array2.dim1 b, Array2.dim2 b in
let y = Array2.init Float32 c_layout adim2 bdim2 (fun _ _ -> 0.0) in
cblas_sgemm 101 trans_a trans_b adim2 bdim2 adim1 1.0 (bigarray_start array2 a)
adim2 (bigarray_start array2 b) bdim2 0.0 (bigarray_start array2 y) bdim2;
y
end
| _ -> raise TypeException
| _ -> raise (InvalidArgumentException "Valid trans_a * trans_b : (111, 111), (111, 112), (112, 111)")


let sub a b = add a (neg b)
Expand Down
4 changes: 2 additions & 2 deletions lib/values/valuesUtils.ml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
open Bigarray
exception InvalidArgumentException
open Types

let range start stop =
if start >= stop then raise InvalidArgumentException;
if start >= stop then raise (InvalidArgumentException "start must be less or equal to stop");

let rec rangeHelper start stop =
if start == stop then []
Expand Down

0 comments on commit b9d5ce1

Please sign in to comment.