In [1]:
(* Required definitions *)
(** [asrt (cond, msg)] is a lightweight assertion: it returns [()] when
    [cond] holds.
    @raise Failure with ["Assertion failure: " ^ msg] when [cond] is false. *)
let asrt (ok, str) =
  if not ok then failwith ("Assertion failure: "^str)

val asrt : bool * string -> unit = <fun>


In [2]:
(** Matrix helpers on row-major [float list list] values, plus the parameter
    initialisation and index-sampling utilities used by the training code. *)
module Util = struct
  (** A matrix stored as a list of rows. The type parameter is unused. *)
  type 'a t = float list list

  (** [shape l] is [(rows, cols)], where [cols] is the length of the first
      row. The empty matrix has shape [(0, 0)]. *)
  let shape (l: 'a t) =
    match l with
    | [] -> (0, 0)
    | first_row :: _ -> (List.length l, List.length first_row)

  (** [float_1d size] is a list of [size] random floats, each drawn with
      [Random.float 1.0] and shifted down by [0.5]. *)
  let float_1d size = List.init size (fun _ -> (Random.float 1.0) -. 0.5)

  (** [float_2d rows cols] is a [rows] x [cols] matrix of random floats drawn
      the same way as {!float_1d}. *)
  let float_2d rows cols = List.init rows (fun _ -> float_1d cols)

  (** [float_2d_zeros rows cols] is a [rows] x [cols] matrix of [0.0]. *)
  let float_2d_zeros rows cols = List.init rows (fun _ -> List.init cols (fun _ -> 0.0))

  (** [transpose lst] swaps rows and columns. The number of output rows is
      the length of the first input row.
      @raise Failure if a later row is shorter than the first one. *)
  let transpose (lst : 'a list list) : 'a list list =
    let rec peel acc rows = match rows with
    | [] | [] :: _ -> List.rev acc
    | _ ->
        (* Heads form the next output row; tails are what is left to peel. *)
        let column = List.map List.hd rows in
        let remainder = List.map List.tl rows in
        peel (column :: acc) remainder
    in
    peel [] lst

  (** [dot_product lst1 lst2] is the matrix product [lst1 * lst2].
      @raise Invalid_argument if a row of [lst1] and a column of [lst2]
      have different lengths. *)
  let dot_product (lst1 : float list list) (lst2 : float list list) : float list list =
    let dot_helper (u : float list) (v : float list) : float =
      List.fold_left2 (fun acc x y -> acc +. x *. y) 0. u v in
    (* Transpose once so every column of [lst2] is available as a list. *)
    let columns = transpose lst2 in
    List.map (fun row -> List.map (dot_helper row) columns) lst1

  (** [add_scalar s lst] adds [s] to every element of [lst]. *)
  let add_scalar (s : float) (lst : float list list) = List.map (fun row -> List.map (fun x -> x +. s) row) lst

  (** [sub_scalar lst s] subtracts [s] from every element of [lst]. *)
  let sub_scalar (lst: float list list) (s: float) = List.map (fun row -> List.map (fun x -> x -. s) row) lst

  (** [mult_scalar s lst] multiplies every element of [lst] by [s]. *)
  let mult_scalar (s : float) (lst : float list list) = List.map (fun row -> List.map (fun x -> x *. s) row) lst

  (** [add_mat m1 m2] is the element-wise sum of [m1] and [m2]. When the two
      shapes differ, the first element of every row of [m1] is repeated
      across the column count of [m2] before adding (column-vector
      broadcast).
      @raise Invalid_argument if the row counts still differ.
      @raise Failure if broadcasting is needed and a row of [m1] is empty. *)
  let add_mat (m1 : float list list) (m2 : float list list) : float list list =
    if shape m1 = shape m2 then
      List.map2 (List.map2 (+.)) m1 m2
    else
      let _, n_cols2 = shape m2 in
      let broadcast row =
        let v = List.hd row in
        List.init n_cols2 (fun _ -> v)
      in
      List.map2 (List.map2 (+.)) (List.map broadcast m1) m2

  (** [sub_mat l1 l2] is the element-wise difference [l1 - l2].
      @raise Invalid_argument if the shapes differ. *)
  let sub_mat (l1: float list list) (l2: float list list) =
    List.map2 (List.map2 (-.)) l1 l2

  (** [mult_mat l1 l2] is the element-wise product of [l1] and [l2].
      @raise Invalid_argument if the shapes differ. *)
  let mult_mat (l1: float list list) (l2: float list list) =
    List.map2 (List.map2 ( *. )) l1 l2

  (** [sum_2d_matrix matrix] is the sum of all elements of [matrix]. *)
  let sum_2d_matrix (matrix : float list list) : float =
    List.fold_left (fun acc row -> List.fold_left (+.) acc row) 0. matrix

  (** [one_hot y] encodes the labels [y] as a matrix with one row per class
      ([0] up to the largest label) and one column per label; the entry is
      [1.0] where the row index equals the label and [0.0] elsewhere.
      An empty label list yields the empty matrix. *)
  let one_hot (y: int list): float list list =
    match y with
    | [] -> []
    | _ ->
      let n = List.fold_left max 0 y + 1 in
      (* Build class rows directly; this avoids indexing [y] with List.nth. *)
      List.init n (fun cls ->
        List.map (fun label -> if label = cls then 1.0 else 0.0) y)

  (** Initial network parameters and velocity terms, as the tuple
      [(w1, b1, w2, b2, v_w1, v_b1, v_w2, v_b2)]. This is a plain value, so
      the random weights are drawn once, when the module is defined. *)
  let init_params =
    let w1 = float_2d 10 784 in
    let b1 = float_2d 10 1 in
    let v_w1 = float_2d_zeros 10 784 in
    let v_b1 = 0.0 in
    let w2 = float_2d 10 10 in
    let b2 = float_2d 10 1 in
    let v_w2 = float_2d_zeros 10 10 in
    let v_b2 = 0.0 in
    (w1, b1, w2, b2, v_w1, v_b1, v_w2, v_b2)

  (** [random_indices n x y] draws [n] distinct random integers from the
      inclusive range [x..y], by rejection sampling.
      @raise Invalid_argument if [n] is negative or exceeds the number of
      integers in the range (the draw could otherwise never finish). *)
  let random_indices n x y =
    let span = y - x + 1 in
    if n < 0 || n > max span 0 then
      invalid_arg "Util.random_indices: cannot draw that many distinct indices"
    else
      let rec aux acc remaining =
        if remaining = 0 then acc
        else
          let r = Random.int span + x in
          (* Redraw on duplicates so the result holds distinct values. *)
          if List.mem r acc then aux acc remaining
          else aux (r :: acc) (remaining - 1)
      in aux [] n

  (** [select_indices matrix indices] picks the rows of [matrix] at the given
      positions, in the order of [indices]. *)
  let select_indices matrix indices =
    List.map (fun i -> List.nth matrix i) indices
end

module Util :
  sig
    type 'a t = float list list
    val shape : 'a t -> int * int
    val float_1d : int -> float list
    val float_2d : int -> int -> float list list
    val float_2d_zeros : int -> int -> float list list
    val transpose : 'a list list -> 'a list list
    val dot_product : float list list -> float list list -> float list list
    val add_scalar : float -> float list list -> float list list
    val sub_scalar : float list list -> float -> float list list
    val mult_scalar : float -> float list list -> float list list
    val add_mat : float list list -> float list list -> float list list
    val sub_mat : float list list -> float list list -> float list list
    val mult_mat : float list list -> float list list -> float list list
    val sum_2d_matrix : float list list -> float
    val one_hot : int list -> float list list
    val init_params :
      float list list * float list list * float list list * float list list *
      float list list * float * float li

In [3]:
(** Element-wise activation functions (and two derivatives) on row-major
    [float list list] matrices. *)
module Activation = struct
  (** A matrix stored as a list of rows. The type parameter is unused. *)
  type 'a t = float list list

  (** [relu z] replaces every element [x] of [z] by [max x 0.0]. *)
  let relu (z : 'a t) : 'a t =
    List.map (List.map (fun x -> max x 0.0)) z

  (** [relu_derive z] is [1.0] where the element of [z] is strictly positive
      and [0.0] elsewhere. *)
  let relu_derive (z : 'a t) : 'a t =
    List.map (List.map (fun x -> if x > 0.0 then 1.0 else 0.0)) z

  (** [softmax z] applies softmax down each column of [z] (the matrix is
      transposed, normalised row by row, and transposed back). The column
      maximum is subtracted before exponentiating so that large scores do
      not overflow [exp] to infinity and produce [nan]; the result is
      mathematically unchanged. *)
  let softmax (z : 'a t) : 'a t =
    let softmax_row row =
      let peak = List.fold_left max neg_infinity row in
      let exp_row = List.map (fun x -> exp (x -. peak)) row in
      let sum_exp_row = List.fold_left (+.) 0.0 exp_row in
      List.map (fun e -> e /. sum_exp_row) exp_row
    in
    z |> Util.transpose |> List.map softmax_row |> Util.transpose

  (** [tanh z] applies the hyperbolic tangent to every element of [z]. *)
  let tanh (z : 'a t) : 'a t =
    (* The binding is not recursive, so the inner [tanh] is still the scalar
       function that was in scope before this definition. *)
    List.map (List.map (fun x -> tanh x)) z

  (** [sigmoid z] applies [1 / (1 + exp (-x))] to every element of [z]. *)
  let sigmoid (z : 'a t) : 'a t =
    let sigmoid_scalar y = 1.0 /. (1.0 +. exp (-. y)) in
    List.map (List.map sigmoid_scalar) z

  (** [sigmoid_derive z] is the element-wise sigmoid derivative
      [s * (1 - s)], where [s = sigmoid z]. *)
  let sigmoid_derive (z : 'a t) : 'a t =
    List.map (List.map (fun s -> s *. (1.0 -. s))) (sigmoid z)
end

module Activation :
  sig
    type 'a t = float list list
    val relu : 'a t -> 'a t
    val relu_derive : 'a t -> 'a t
    val softmax : 'a t -> 'a t
    val tanh : 'a t -> 'a t
    val sigmoid : 'a t -> 'a t
    val sigmoid_derive : 'a t -> 'a t
  end


In [4]:
#require "Csv"

(** [load_csv file_name f] reads the CSV file [file_name] with [Csv.load]
    and converts every cell with [f], giving one inner list per CSV row. *)
let load_csv (file_name: string) (f: string -> 'a) =
  let rows = Csv.load file_name in
  List.map (List.map f) rows

val load_csv : string -> (string -> 'a) -> 'a list list = <fun>


In [5]:
(** Training inputs: the CSV is parsed as floats, transposed, and every value
    is divided by [255.0]. *)
let train_x : 'a list list =
  let scale pixel = pixel /. 255.0 in
  load_csv "../data/train_small.csv" float_of_string
  |> Util.transpose
  |> List.map (List.map scale)
  
(** Training labels: the first column of the labels CSV, parsed as ints. *)
let train_y : 'a list =
  let columns =
    Util.transpose (load_csv "../data/train_small_labels.csv" int_of_string)
  in
  List.nth columns 0

val train_x : float list list =
  [[0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 

val train_y : int list =
  [1; 0; 1; 4; 0; 0; 7; 3; 5; 3; 8; 9; 1; 3; 3; 1; 2; 0; 7; 5; 8; 6; 2; 0; 2;
   3; 6; 9; 9; 7; 8; 9; 4; 9; 2; 1; 3; 1; 1; 4; 9; 1; 4; 4; 2; 6; 3; 7; 7; 4;
   7; 5; 1; 9; 0; 2; 2; 3; 9; 1; 1; 1; 5; 0; 6; 3; 4; 8; 1; 0; 3; 9; 6; 2; 6;
   4; 7; 1; 4; 1; 5; 4; 8; 9; 2; 9; 9; 8; 9; 6; 3; 6; 4; 6; 2; 9; 1; 2; 0; 5;
   9; 2; 7; 7; 2; 8; 8; 5; 0; 6; 0; 0; 2; 9; 0; 4; 7; 7; 1; 5; 7; 9; 4; 6; 1;
   5; 7; 6; 5; 0; 4; 8; 7; 6; 1; 8; 7; 3; 7; 3; 1; 0; 3; 4; 5; 4; 0; 5; 4; 0;
   3; 5; 1; 0; 8; 3; 7; 0; 9; 6; 6; 9; 5; 4; 6; 9; 3; 5; 4; 2; 4; 8; 7; 7; 5;
   8; 8; 8; 2; 6; 9; 3; 1; 0; 4; 1; 5; 9; 0; 6; 2; 1; 3; 0; 6; 0; 0; 8; 3; 2;
   0; 0; 6; 0; 0; 4; 7; 2; 7; 1; 9; 9; 3; 9; 8; 4; 6; 6; 5; 3; 8; 1; 8; 7; 1;
   3; 7; 6; 3; 6; 3; 6; 3; 2; 3; 2; 2; 7; 9; 2; 3; 2; 7; 5; 5; 8; 8; 2; 0; 1;
   4; 0; 6; 3; 7; 1; 1; 1; 4; 7; 0; 2; 9; 2; 0; 5; 6; 0; 8; 9; 6; 2; 0; 0; 7;
   2; 0; 4; 2; 0; 9; 1; 6; 9; 3; 0; 0; 2; 0; 6; 8; 4; 0; 7; 2; 1; 9; 5; 2;
   ...]


In [6]:
(** Forward pass, backward pass and parameter updates for the two-layer
    network built from [Util] and [Activation]. *)
module Prop  = struct
  (** [forward_prop w1 b1 w2 b2 train_x] runs the network on [train_x] and
      returns [(z1, a1, a2)]: the first-layer pre-activation
      [z1 = w1 * train_x + b1], its ReLU [a1], and the sigmoid [a2] of
      [w2 * a1 + b2]. Biases are added with [Util.add_mat]. *)
  let forward_prop (w1: float list list) (b1: float list list) (w2: float list list) (b2: float list list) (train_x: float list list) =
    let z1 = Util.add_mat b1 (Util.dot_product w1 train_x) in
    let a1 = Activation.relu z1 in
    let z2 = Util.add_mat b2 (Util.dot_product w2 a1) in
    let a2 = Activation.sigmoid z2 in
    (z1, a1, a2)

  (** [backward_prop z1 a1 a2 w2 train_x train_y] returns the gradients
      [(dw1, db1, dw2, db2)] for the batch. The weight gradients are
      matrices; each bias gradient is a single float (the scaled sum of the
      corresponding error matrix).

      Gradients are averaged over the number of columns of [train_y]. The
      caller passes a one-hot matrix with one column per sample, so its
      outer length is the class count, not the batch size. *)
  let backward_prop z1 a1 a2 w2 train_x train_y:
    (float list list * float * float list list * float) =
    let _, n_samples = Util.shape train_y in
    let m_inv = 1. /. float_of_int n_samples in
    let dz2 = Util.sub_mat a2 train_y in
    let dw2 =
      Util.mult_scalar m_inv (Util.dot_product dz2 (Util.transpose a1)) in
    let db2 = m_inv *. Util.sum_2d_matrix dz2 in
    (* Propagate the output error back through w2, gated by the ReLU slope. *)
    let dz1 =
      Util.mult_mat (Activation.relu_derive z1)
        (Util.dot_product (Util.transpose w2) dz2) in
    let dw1 =
      Util.mult_scalar m_inv (Util.dot_product dz1 (Util.transpose train_x)) in
    let db1 = m_inv *. Util.sum_2d_matrix dz1 in
    (dw1, db1, dw2, db2)

  (** [update_params_gd w1 b1 w2 b2 dw1 db1 dw2 db2 lr] takes one plain
      gradient-descent step: every parameter has [lr] times its gradient
      subtracted. The scalar bias gradients are subtracted from every bias
      entry. Returns [(w1, b1, w2, b2)]. *)
  let update_params_gd w1 b1 w2 b2 dw1 db1 dw2 db2 lr =
    let w1 = Util.sub_mat w1 (Util.mult_scalar lr dw1) in
    let b1 = Util.sub_scalar b1 (lr *. db1) in
    let w2 = Util.sub_mat w2 (Util.mult_scalar lr dw2) in
    let b2 = Util.sub_scalar b2 (lr *. db2) in
    (w1, b1, w2, b2)

  (** [update_params_gdm ... lr beta1 beta2] takes one momentum step. Each
      velocity becomes [beta1 * old_velocity + beta2 * gradient], and each
      parameter then has [lr * velocity] subtracted. Returns the updated
      parameters followed by the updated velocities:
      [(w1, b1, w2, b2, v_dw1, v_db1, v_dw2, v_db2)]. *)
  let update_params_gdm w1 b1 w2 b2 dw1 db1 dw2 db2 v_dw1 v_db1 v_dw2 v_db2 lr beta1 beta2 =
    let v_dw1 = Util.add_mat (Util.mult_scalar beta1 v_dw1) (Util.mult_scalar beta2 dw1) in
    let v_db1 = (beta2 *. db1) +. (beta1 *. v_db1) in
    let v_dw2 = Util.add_mat (Util.mult_scalar beta1 v_dw2) (Util.mult_scalar beta2 dw2) in
    let v_db2 = (beta2 *. db2) +. (beta1 *. v_db2) in
    let w1 = Util.sub_mat w1 (Util.mult_scalar lr v_dw1) in
    let b1 = Util.sub_scalar b1 (lr *. v_db1) in
    let w2 = Util.sub_mat w2 (Util.mult_scalar lr v_dw2) in
    let b2 = Util.sub_scalar b2 (lr *. v_db2) in
    (w1, b1, w2, b2, v_dw1, v_db1, v_dw2, v_db2)
end

module Prop :
  sig
    val forward_prop :
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float list list -> float list list * float list list * float list list
    val backward_prop :
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float list list -> float list list * float * float list list * float
    val update_params_gd :
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float ->
      float list list ->
      float ->
      float ->
      float list list * float list list * float list list * float list list
    val update_params_gdm :
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float list list ->
      float ->
      float list list ->
      float ->
      float list list ->
      float ->
      floa

In [7]:
(** Minimal monad interface used to sequence the steps of the training loop. *)
module type MONAD = sig
  (** A computation that yields a value of type ['a]. *)
  type 'a t
  (** [return x] wraps the plain value [x] as a computation. *)
  val return : 'a -> 'a t
  (** [m >>= f] passes the result of [m] to [f] and continues with the
      computation [f] returns. *)
  val (>>=)  : 'a t -> ('a -> 'b t) -> 'b t
  (** [run m] extracts the final value of the computation [m]. *)
  val run : 'a t -> 'a
end

(** Identity implementation of [MONAD]: a computation is just its value, so
    binding is plain function application. The signature ascription keeps
    the representation abstract for users of the module. *)
module Monad : MONAD = struct
  type 'a t = 'a

  (* Wrapping and unwrapping are both the identity. *)
  let return value = value
  let run computation = computation

  (* Feed the value straight into the continuation. *)
  let ( >>= ) computation continue = computation |> continue
end

module type MONAD =
  sig
    type 'a t
    val return : 'a -> 'a t
    val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
    val run : 'a t -> 'a
  end


module Monad : MONAD


In [8]:
(* List.nth (Util.transpose train_x) ((0+1000) mod 1000) *)

In [9]:
(** Batch selection for the different gradient-descent variants. *)
module Gd = struct
  (** How much of the data each step uses: [GD] takes everything, [SGD] one
      random column, [MBGD] four random columns. *)
  type algorithm_type = GD | SGD | MBGD

  (** [data x y type_] returns the [(x, y)] pair to train on for one step.
      [GD] returns the inputs unchanged. [SGD] and [MBGD] pick the same
      randomly chosen columns from both matrices (1 and 4 columns
      respectively). The batch size is capped at the number of available
      columns, because asking [Util.random_indices] for more distinct
      indices than exist cannot succeed.
      @raise Failure (via [asrt]) if [x] and [y] have different column
      counts when sampling. *)
  let data x y type_ =
    let sample_columns batch_size =
      let size = snd (Util.shape x) in
      let () = asrt (size = snd (Util.shape y), "Shape error..") in
      let ri = Util.random_indices (min batch_size size) 0 (size-1) in
      (* Columns become rows after transposing, so they can be picked by
         index and then transposed back. *)
      let pick m = Util.transpose (Util.select_indices (Util.transpose m) ri) in
      (pick x, pick y)
    in
    match type_ with
    | GD -> (x, y)
    | SGD -> sample_columns 1
    | MBGD -> sample_columns 4
end

module Gd :
  sig
    type algorithm_type = GD | SGD | MBGD
    val data :
      float list list ->
      float list list -> algorithm_type -> float list list * float list list
  end


In [10]:
(** Parameter-update rule selected in [gradient_descent]. *)
type optim =
  | VGD  (** every value other than [GDM] uses [Prop.update_params_gd] *)
  | GDM  (** uses the momentum update [Prop.update_params_gdm] *)

type optim = VGD | GDM


In [11]:
(** [gradient_descent train_x train_y lr iterations type_ optim_ beta1 beta2]
    trains the network starting from [Util.init_params] and returns the
    final [(w1, b1, w2, b2)] wrapped in [Monad.t].

    - [type_] selects how each step's batch is drawn (see [Gd.data]).
    - [optim_] selects the update rule; [beta1] and [beta2] are only used by
      [GDM].
    - [lr] is the learning rate and [iterations] the number of steps; a
      non-positive [iterations] returns the initial parameters.

    A batch is drawn at the start of each step, so none is sampled and then
    thrown away after the last step. Progress is printed every 10 steps for
    both optimisers. *)
let gradient_descent train_x train_y lr iterations type_ optim_ beta1 beta2 =
  let open Monad in
  let w1, b1, w2, b2, v_w1, v_b1, v_w2, v_b2 = Util.init_params in
  (* Only the parameters and velocities change between steps; everything
     else is captured from the enclosing scope. *)
  let rec batch_loop i w1 b1 w2 b2 v_w1 (v_b1: float) v_w2 (v_b2: float) =
    if i >= iterations then return (w1, b1, w2, b2)
    else begin
      let x, y = Gd.data train_x train_y type_ in
      if i mod 10 = 0 then Printf.printf "Iteration: %d\n" i;
      return (Prop.forward_prop w1 b1 w2 b2 x) >>= fun (z1, a1, a2) ->
      return (Prop.backward_prop z1 a1 a2 w2 x y) >>= fun (dw1, db1, dw2, db2) ->
      match optim_ with
      | GDM ->
        (return
           (Prop.update_params_gdm w1 b1 w2 b2 dw1 db1 dw2 db2
              v_w1 v_b1 v_w2 v_b2 lr beta1 beta2)
         >>= fun (w1', b1', w2', b2', v_dw1', v_db1', v_dw2', v_db2') ->
         batch_loop (i + 1) w1' b1' w2' b2' v_dw1' v_db1' v_dw2' v_db2')
      | VGD ->
        (return (Prop.update_params_gd w1 b1 w2 b2 dw1 db1 dw2 db2 lr)
         >>= fun (w1', b1', w2', b2') ->
         (* Velocities are untouched by the plain update. *)
         batch_loop (i + 1) w1' b1' w2' b2' v_w1 v_b1 v_w2 v_b2)
    end
  in
  batch_loop 0 w1 b1 w2 b2 v_w1 v_b1 v_w2 v_b2

val gradient_descent :
  float list list ->
  float list list ->
  float ->
  int ->
  Gd.algorithm_type ->
  optim ->
  float ->
  float ->
  (float list list * float list list * float list list * float list list)
  Monad.t = <fun>


In [12]:
(* Sanity check: print the (rows, cols) shape of the training inputs and of
   the one-hot encoded labels before training. *)
(Util.shape train_x);; 
(Util.shape (Util.one_hot train_y));;

- : int * int = (784, 4000)


- : int * int = (10, 4000)


In [13]:
(* Train with learning rate 0.06 for 2 iterations, using MBGD batches and the
   VGD update rule. The trailing 0.9 and 0.1 are beta1/beta2, which the VGD
   branch of gradient_descent does not use. *)
let result = gradient_descent train_x (Util.one_hot train_y) 0.06 2 MBGD VGD 0.9 0.1;;
(* Unwrap the trained parameters from the monadic result. *)
let w1, b1, w2, b2 = Monad.run result;;

val result :
  (float list list * float list list * float list list * float list list)
  Monad.t = <abstr>


val w1 : float list list =
  [[0.426038888360971146; -0.158361411401767904; 0.0978229458531059137;
    0.366881591014789565; 0.286668034537743921; 0.125315015859070122;
    -0.409011015216012297; -0.188990030213231575; 0.059742196386998736;
    -0.34615215119734255; -0.490233784604748313; 0.256637587325766692;
    -0.251656740474212159; 0.152787092477089326; 0.48717731495272254;
    -0.332554522532503172; 0.0574582248835058262; 0.258516786217372951;
    0.18532904647625259; -0.0324781508558906196; -0.171649560924790556;
    0.225258114528568254; 0.260843272481697697; 0.444543207513039773;
    -0.0466829106219362466; 0.0396486806025953; 0.36819664704757471;
    0.0449662633954558411; 0.130257917730493555; -0.42187402863574619;
    -0.375256835869513194; -0.022730071996965362; -0.229233104415120292;
    0.38945947157575922; -0.199121535032691421; -0.217841553708043145;
    0.0565473408327369365; -0.395298547224043; 0.103989915708647573;
    0.360497981031593695; 0.48943067314257882; -0.3

In [14]:
(* Flush all open output channels so buffered Printf output becomes visible. *)
flush_all();

- : unit = ()


In [15]:
(** [get_predictions a2] maps every row of [a2] to the index of its largest
    element. Ties go to the earliest index, and an empty row yields [0].

    The running maximum starts at [neg_infinity] rather than [0.], so a row
    whose values are all zero or negative still reports its true maximum
    instead of always index [0]. *)
let get_predictions (a2 : float list list) : int list =
  let argmax row =
    let _, best_index, _ =
      List.fold_left
        (fun (best_val, best_idx, idx) x ->
          if x > best_val then (x, idx, idx + 1)
          else (best_val, best_idx, idx + 1))
        (neg_infinity, 0, 0) row
    in
    best_index
  in
  (* rev_map followed by rev keeps the traversal tail-recursive. *)
  List.rev (List.rev_map argmax a2)

val get_predictions : float list list -> int list = <fun>


In [16]:
(** [get_accuracy pred y] is the fraction of positions at which [pred] and
    [y] hold equal values.
    @raise Invalid_argument if the two lists have different lengths. *)
let get_accuracy pred y : float =
  let matches = List.map2 (fun p t -> p = t) pred y in
  let correct = List.length (List.filter (fun hit -> hit) matches) in
  float_of_int correct /. float_of_int (List.length matches)

val get_accuracy : 'a list -> 'a list -> float = <fun>


In [17]:
(** Test inputs: the CSV is parsed as floats, transposed, and every value is
    divided by [255.0]. *)
let test_x =
  let scale pixel = pixel /. 255.0 in
  load_csv "../data/test_small.csv" float_of_string
  |> Util.transpose
  |> List.map (List.map scale)
  
(** Test labels: the first column of the labels CSV, parsed as ints. *)
let test_y =
  let columns =
    Util.transpose (load_csv "../data/test_small_labels.csv" int_of_string)
  in
  List.nth columns 0

val test_x : float list list =
  [[0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;
    0.];
   [0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.

val test_y : int list =
  [1; 0; 1; 4; 0; 0; 7; 3; 5; 3; 8; 9; 1; 3; 3; 1; 2; 0; 7]


In [18]:
(* Number of test labels loaded. *)
List.length test_y

- : int = 19


In [19]:
(* Evaluate the trained parameters on the test set: run the forward pass,
   keep only the output activations, and report predictions with accuracy. *)
let _, _, a2 = Prop.forward_prop w1 b1 w2 b2 test_x in
(* get_predictions works row by row, so a2 is transposed first. *)
let preds = get_predictions (Util.transpose a2) in
(preds, get_accuracy preds test_y)

- : int list * float =
([9; 2; 0; 9; 2; 8; 9; 6; 9; 0; 9; 7; 0; 6; 9; 6; 5; 0; 2],
 0.0526315789473684181)


In [20]:
(* let gradient_descent train_x train_y lr iterations m_inv =
  let w1, b1, w2, b2 = Util.init_params in
  let rec loop i w1 b1 w2 b2 iterations train_x train_y lr m_inv =
    if i = iterations then (w1, b1, w2, b2)
    else
      let z1, a1, z2, a2 = Prop.forward_prop w1 b1 w2 b2 train_x in
      let dw1, db1, dw2, db2 = Prop.backward_prop z1 a1 z2 a2 w1 w2 train_x train_y m_inv in
      let w1', b1', w2', b2' = Prop.update_params w1 b1 w2 b2 dw1 db1 dw2 db2 lr in
      let _ = print_endline "Iteration" in
      loop (i + 1) w1' b1' w2' b2' iterations train_x train_y lr m_inv
  in
  loop 0 w1 b1 w2 b2 iterations train_x train_y lr m_inv *)