In [None]:
#use "topfind"
#require "torch"
open Base
open Torch

In [19]:
(* Training hyper-parameters. With these values the two-layer network below
   is expected to end up at roughly 97% test accuracy. *)
let hidden_nodes = 128
let epochs = 1_000
let learning_rate = 0.001

(* MNIST train/test tensors, read via [Mnist_helper.read_files]. *)
let mnist = Mnist_helper.read_files ~with_caching:true ()

(* The training split is used on every epoch, so bind it at top level. *)
let train_images = mnist.Dataset_helper.train_images
let train_labels = mnist.Dataset_helper.train_labels

val hidden_nodes : int = 128


val epochs : int = 1000


val learning_rate : float = 0.001


val mnist : Torch.Mnist_helper.t =
  {Torch.Mnist_helper.train_images = <abstr>; train_labels = <abstr>;
   test_images = <abstr>; test_labels = <abstr>}


val train_images : Torch.Tensor.t = <abstr>
val train_labels : Torch.Tensor.t = <abstr>


In [20]:
(* Two-layer perceptron: [Mnist_helper.image_dim] inputs -> [hidden_nodes]
   ReLU units -> [Mnist_helper.label_count] outputs. All parameters are
   registered in the variable store [vs], which is what [adam] optimizes. *)
let vs = Var_store.create ~name:"nn" ()

let linear1 =
  Layer.linear vs hidden_nodes ~activation:Relu
    ~input_dim:Mnist_helper.image_dim

let linear2 =
  Layer.linear vs Mnist_helper.label_count ~input_dim:hidden_nodes

(* Forward pass; the result is consumed as logits by the training loss. *)
let model xs =
  let hidden = Layer.forward linear1 xs in
  Layer.forward linear2 hidden

let adam = Optimizer.adam vs ~learning_rate

val vs : Torch.Layer.Var_store.t = <abstr>


val linear1 : Torch.Layer.t = <abstr>


val linear2 : Torch.Layer.t = <abstr>


val model : Torch.Tensor.t -> Torch.Tensor.t = <fun>


val adam : Torch.Optimizer.t = <abstr>


In [21]:
(* Full-batch training: each epoch runs the whole training set through the
   model and takes a single Adam step on the cross-entropy loss. *)
for index = 1 to epochs do
  let logits = model train_images in
  let loss = Tensor.cross_entropy_for_logits logits ~targets:train_labels in
  Optimizer.backward_step adam ~loss;
  (* Every 50 epochs, print the epoch, the current training loss and the
     accuracy on the test split (evaluated in batches of 1000). *)
  if index % 50 = 0
  then (
    let test_accuracy =
      Dataset_helper.batch_accuracy mnist `test ~batch_size:1000 ~predict:model
    in
    Stdio.printf "%d %f %.2f%%\n%!" index (Tensor.float_value loss)
      (100. *. test_accuracy));
  (* Force a major collection each epoch so the tensors allocated above are
     reclaimed promptly. Presumably this is needed because tensor storage is
     not accounted for by the OCaml heap; TODO confirm. *)
  Stdlib.Gc.full_major ()
done

50 0.041288 89.45%
100 0.027718 92.59%
150 0.021788 93.96%
200 0.017728 94.83%
250 0.014741 95.38%
300 0.012470 95.85%
350 0.010663 96.25%
400 0.009181 96.50%
450 0.007949 96.70%
500 0.006925 96.92%
550 0.006074 96.96%
600 0.005358 97.05%
650 0.004746 97.06%
700 0.004215 97.11%
750 0.003748 97.14%
800 0.003337 97.23%
850 0.002973 97.26%
900 0.002651 97.35%
950 0.002368 97.39%
1000 0.002119 97.38%


- : unit = ()
