In [37]:
(use '[clojure.core.matrix]
     '[clojure.core.matrix.operators]
     '[clojure.core.matrix.stats]
     '[clojure.core.matrix.random]
     '[clojure.core.matrix.dataset])

In [38]:
;; Input dataset
(def X (array [[0 0 1]
               [0 1 1]
               [1 0 1]
               [1 1 1]]))

;; output dataset
(def y (transpose (array [[0 1 1 1]])))

;; initialize weights randomly with mean 0
;; with seed 1 ( just a good practice )
(def syn0 (- (* 2 (sample-uniform [3 4] 1)) 1))
(def syn1 (- (* 2 (sample-uniform [4 1] 1)) 1))

#'clj-nn.two-layer/syn1

In [39]:
(defn nonlin
  ([x] (nonlin x false))
  ([x deriv] (if deriv
              (* x (- 1 x))
              (/ 1 (+ 1 (exp (- x)))))))

#'clj-nn.two-layer/nonlin

In [40]:
(defn forward
  [syn1 syn0]
  (let [l0 X
        l1 (nonlin (dot l0 syn0))
        l2 (nonlin (dot l1 syn1))

        l2_error (- y l2)
        l2_delta (* l2_error (nonlin l2 true))

        l1_error (dot l2_delta (transpose syn1))
        l1_delta (* l1_error (nonlin l1 true))]
      [(+ syn1 (dot (transpose l1) l2_delta))
       (+ syn0 (dot (transpose l0) l1_delta))
       l2_delta]))

#'clj-nn.two-layer/forward

In [41]:
(defn train
  []
  (loop [param [syn1 syn0 nil]
         cnt 10001]
      (when (== (mod cnt 1000) 0)
        (println (format "%f" (get (array (mean (abs (get param 2)))) 0))))
      (if (== cnt 0)
       (do
         (println "done")
         [syn1 syn0])
       (recur
         (forward (first param) (second param))
         (dec cnt)))))

#'clj-nn.two-layer/train

In [42]:
(def result(train))
(def syn1 (get result 0))
(def syn0 (get result 1))

0.128360
0.000428
0.000186
0.000116
0.000084
0.000065
0.000053
0.000045
0.000039
0.000034
0.000030
done


#'clj-nn.two-layer/syn0

In [43]:
syn1

[[0.4617563814065817] [-0.17983837701559668] [-0.5845703173805659] [-0.33456588808097765]]

In [44]:
syn0

[[0.4617563814065817 -0.17983837701559668 -0.5845703173805659 -0.33456588808097765] [0.9355118188482414 -0.9877656354684774 0.9274095940464153 0.8797307775638197] [0.8943898353263877 0.8741642977919393 -0.2056513156305888 -0.3049639415937795]]