Skip to content

Commit

Permalink
Improve support for updaters and schedules
Browse files Browse the repository at this point in the history
  • Loading branch information
enragedginger committed Apr 8, 2018
1 parent c19f231 commit 2ed8d7e
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 26 deletions.
123 changes: 112 additions & 11 deletions src/jutsu/ai/core.clj
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@
GaussianDistribution]
[java.io File]
[java.util Random]
(org.deeplearning4j.nn.layers.recurrent GravesBidirectionalLSTM)
(org.nd4j.linalg.schedule StepSchedule MapSchedule ScheduleType)
(org.nd4j.linalg.schedule StepSchedule MapSchedule ScheduleType ExponentialSchedule InverseSchedule PolySchedule SigmoidSchedule)
(org.deeplearning4j.nn.conf.layers.variational VariationalAutoencoder VariationalAutoencoder$Builder)
(org.nd4j.linalg.learning.config Nesterovs Adam)))
(org.nd4j.linalg.learning.config Nesterovs Adam RmsProp Sgd AdaDelta AdaGrad AdaMax Nadam NoOp)))

(defn regression-csv-iterator [filename batch-size label-index]
(let [path (-> (ClassPathResource. filename)
Expand Down Expand Up @@ -110,10 +109,6 @@
split-config (split-at layers-index (partition 2 edn-config))]
split-config))

(def schedule-type-map
{:iteration (ScheduleType/ITERATION)
:epoch (ScheduleType/EPOCH)})

(def options
{:sgd (OptimizationAlgorithm/STOCHASTIC_GRADIENT_DESCENT)
:tanh (Activation/TANH)
Expand All @@ -128,15 +123,11 @@
:xavier (WeightInit/XAVIER)
:mcxent (LossFunctions$LossFunction/MCXENT)
:truncated-bptt (BackpropType/TruncatedBPTT)
:map-schedule (fn [schedule-type key-value-pairs] (MapSchedule. (get schedule-type-map schedule-type) key-value-pairs))
:pooling-type-max (SubsamplingLayer$PoolingType/MAX)
:distribution (WeightInit/DISTRIBUTION)
:renormalize-l2-per-layer (GradientNormalization/RenormalizeL2PerLayer)
:workspace-single (WorkspaceMode/SINGLE)
:workspace-separate (WorkspaceMode/SEPARATE)
:step-schedule (fn [schedule-type initial-value decay-rate step] (StepSchedule. (get schedule-type-map schedule-type) initial-value decay-rate step))
:nesterovs (Nesterovs.)
:adam (Adam.)
})

(defn get-option [arg]
Expand Down Expand Up @@ -291,3 +282,113 @@

(defn gaussian-distribution
  "Builds a dl4j GaussianDistribution from the two numeric arguments, which
  are passed straight through to the constructor. NOTE(review): the historic
  parameter names suggest a (min, max) range, but the GaussianDistribution
  constructor takes distribution parameters — presumably (mean, std); confirm
  against the dl4j API docs."
  [min max]
  (GaussianDistribution. min max))

;; Backward-compatible alias preserving the historical misspelling so existing
;; callers of `guassian-distribution` keep working.
(def guassian-distribution gaussian-distribution)

(defmulti build-updater (fn [updater-key opts] updater-key))

(defmethod build-updater :rms-prop
  ;; Supported keys: :learning-rate, :rms-decay, :epsilon,
  ;; :learning-rate-schedule. A schedule and a fixed rate are mutually
  ;; exclusive in the dl4j constructors; the fixed rate wins here.
  [_ {:keys [learning-rate rms-decay learning-rate-schedule epsilon]}]
  (cond
    ;; The 3-arg RmsProp constructor is (learningRate, rmsDecay, epsilon), all
    ;; doubles. The previous code passed a learning-rate *schedule* object as
    ;; the rmsDecay argument, which cannot succeed at runtime.
    (and learning-rate rms-decay epsilon) (RmsProp. learning-rate rms-decay epsilon)
    learning-rate (RmsProp. learning-rate)
    learning-rate-schedule (RmsProp. learning-rate-schedule)
    :else (RmsProp.)))

(defmethod build-updater :adam
  ;; Supported keys: :learning-rate, :learning-rate-schedule, :beta1, :beta2,
  ;; :epsilon. A fully-specified (lr, beta1, beta2, epsilon) combination wins;
  ;; otherwise a plain learning rate, then a schedule, then dl4j defaults.
  ;; Partial beta/epsilon specs without all four are ignored, as before.
  [_ {:keys [learning-rate learning-rate-schedule beta1 beta2 epsilon]}]
  (if (and learning-rate beta1 beta2 epsilon)
    (Adam. learning-rate beta1 beta2 epsilon)
    (if learning-rate
      (Adam. learning-rate)
      (if learning-rate-schedule
        (Adam. learning-rate-schedule)
        (Adam.)))))

(defmethod build-updater :nesterovs
  ;; Supported keys: :learning-rate, :momentum, :learning-rate-schedule,
  ;; :momentum-schedule. Branch order defines precedence: the most specific
  ;; combination present wins, and schedules are preferred over fixed values.
  [_ {:keys [learning-rate momentum-schedule learning-rate-schedule momentum]}]
  (cond
    (and learning-rate momentum-schedule) (Nesterovs. learning-rate momentum-schedule)
    (and learning-rate-schedule momentum-schedule) (Nesterovs. learning-rate-schedule momentum-schedule)
    (and learning-rate-schedule momentum) (Nesterovs. learning-rate-schedule momentum)
    learning-rate-schedule (Nesterovs. learning-rate-schedule)
    (and learning-rate momentum) (Nesterovs. learning-rate momentum)
    ;; The single-double Nesterovs constructor takes a *learning rate*, so the
    ;; old `(Nesterovs. momentum)` silently misread momentum as the learning
    ;; rate. Apply momentum via its setter on a default instance instead.
    momentum (doto (Nesterovs.) (.setMomentum (double momentum)))
    :else (Nesterovs.)))

(defmethod build-updater :sgd
  ;; Supported keys: :learning-rate-schedule, :learning-rate. A schedule takes
  ;; precedence over a fixed rate; with neither, dl4j defaults apply.
  [_ {:keys [learning-rate-schedule learning-rate]}]
  (if learning-rate-schedule
    (Sgd. learning-rate-schedule)
    (if learning-rate
      (Sgd. learning-rate)
      (Sgd.))))

(defmethod build-updater :ada-delta
  ;; No options are read; the opts map is accepted for interface uniformity.
  [_ _opts]
  (AdaDelta.))

(defmethod build-updater :ada-grad
  ;; Supported keys: :learning-rate, :learning-rate-schedule, :epsilon.
  ;; A schedule takes precedence over a fixed learning rate; :epsilon is only
  ;; honored when paired with one of the two.
  [_ {:keys [learning-rate learning-rate-schedule epsilon]}]
  (if learning-rate-schedule
    (if epsilon
      (AdaGrad. learning-rate-schedule epsilon)
      (AdaGrad. learning-rate-schedule))
    (if learning-rate
      (if epsilon
        (AdaGrad. learning-rate epsilon)
        (AdaGrad. learning-rate))
      (AdaGrad.))))

(defmethod build-updater :ada-max
  ;; Supported keys: :learning-rate, :learning-rate-schedule, :beta1, :beta2,
  ;; :epsilon. The fully-specified 4-arg form wins; otherwise a schedule, then
  ;; a plain learning rate, then dl4j defaults. Partial beta/epsilon specs are
  ;; ignored, as before.
  [_ {:keys [learning-rate learning-rate-schedule beta1 beta2 epsilon]}]
  (if (and learning-rate beta1 beta2 epsilon)
    (AdaMax. learning-rate beta1 beta2 epsilon)
    (if learning-rate-schedule
      (AdaMax. learning-rate-schedule)
      (if learning-rate
        (AdaMax. learning-rate)
        (AdaMax.)))))

(defmethod build-updater :nadam
  ;; Supported keys: :learning-rate, :learning-rate-schedule, :beta1, :beta2,
  ;; :epsilon. Mirrors the :ada-max precedence: full 4-arg spec, then
  ;; schedule, then plain learning rate, then dl4j defaults.
  [_ {:keys [learning-rate learning-rate-schedule beta1 beta2 epsilon]}]
  (if (and learning-rate beta1 beta2 epsilon)
    (Nadam. learning-rate beta1 beta2 epsilon)
    (if learning-rate-schedule
      (Nadam. learning-rate-schedule)
      (if learning-rate
        (Nadam. learning-rate)
        (Nadam.)))))

(defmethod build-updater :no-op
  ;; No options are read; the opts map is accepted for interface uniformity.
  [_ _opts]
  (NoOp.))

;; Maps config keywords to dl4j ScheduleType enum values used by the schedule
;; constructors below: :iteration advances the schedule per minibatch
;; iteration, :epoch once per epoch.
(def schedule-type-map
  {:iteration (ScheduleType/ITERATION)
   :epoch (ScheduleType/EPOCH)})

(defmulti build-schedule-internal (fn [schedule-key opts] schedule-key))

(defmethod build-schedule-internal :exponential
  ;; Exponential decay schedule; numeric args are coerced to the primitive
  ;; doubles the dl4j constructor expects.
  [_ {:keys [schedule-type initial-value gamma]}]
  (let [init (double initial-value)
        g    (double gamma)]
    (ExponentialSchedule. schedule-type init g)))

(defmethod build-schedule-internal :inverse
  ;; Inverse decay schedule; numeric args are coerced to the primitive
  ;; doubles the dl4j constructor expects.
  [_ {:keys [schedule-type initial-value gamma power]}]
  (let [init (double initial-value)
        g    (double gamma)
        p    (double power)]
    (InverseSchedule. schedule-type init g p)))

(defmethod build-schedule-internal :map
  ;; :values is handed directly to MapSchedule (a map of iteration/epoch
  ;; number -> value); no coercion is applied.
  [_ opts]
  (MapSchedule. (:schedule-type opts) (:values opts)))

(defmethod build-schedule-internal :poly
  ;; Polynomial decay schedule; :max-iter is coerced to the primitive int the
  ;; dl4j constructor expects, the rest to doubles.
  [_ {:keys [schedule-type initial-value power max-iter]}]
  (let [init  (double initial-value)
        p     (double power)
        iters (int max-iter)]
    (PolySchedule. schedule-type init p iters)))

(defmethod build-schedule-internal :sigmoid
  ;; Sigmoid decay schedule; :step-size is coerced to the primitive int the
  ;; dl4j constructor expects, the rest to doubles.
  [_ {:keys [schedule-type initial-value gamma step-size]}]
  (let [init (double initial-value)
        g    (double gamma)
        step (int step-size)]
    (SigmoidSchedule. schedule-type init g step)))

(defmethod build-schedule-internal :step
  ;; Step decay schedule; all numeric args are coerced to the primitive
  ;; doubles the dl4j constructor expects (including :step).
  [_ {:keys [schedule-type initial-value decay-rate step]}]
  (let [init  (double initial-value)
        decay (double decay-rate)
        s     (double step)]
    (StepSchedule. schedule-type init decay s)))

(defn build-schedule
  "Builds a dl4j ISchedule. schedule-type selects the schedule family
  (:exponential, :inverse, :map, :poly, :sigmoid, :step); opts must contain
  :schedule-type (:iteration or :epoch) plus the family's parameters.
  Throws ex-info when :schedule-type is missing or unrecognized, instead of
  silently passing a nil enum into a dl4j constructor and failing with an
  opaque NPE there."
  [schedule-type opts]
  (let [type-key  (:schedule-type opts)
        enum-type (get schedule-type-map type-key)]
    (when-not enum-type
      (throw (ex-info "Unknown :schedule-type"
                      {:schedule-type type-key
                       :valid (keys schedule-type-map)})))
    (build-schedule-internal schedule-type (assoc opts :schedule-type enum-type))))
31 changes: 16 additions & 15 deletions test/jutsu/ai/core_test.clj
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
(ns jutsu.ai.core-test
(:require [clojure.test :refer :all]
[jutsu.ai.core :as ai]
[jutsu.matrix.core :as m])
(:import (org.nd4j.linalg.learning.config Sgd Adam Nesterovs)
(org.nd4j.linalg.schedule MapSchedule ScheduleType StepSchedule)))
[jutsu.matrix.core :as m]))

(def n (ai/network [:optimization-algo :sgd
:updater (Nesterovs. 0.5 0.9)
:updater (ai/build-updater :nesterovs {:learning-rate 0.5 :momentum 0.9})
:layers [[:dense [:n-in 1 :n-out 2 :activation :tanh]]
[:dense [:n-in 2 :n-out 2 :activation :tanh]]
[:output :mse [:n-in 2 :n-out 1
Expand All @@ -15,7 +13,7 @@
:backprop true]))

(def layer-config-test-2 [:optimization-algo :sgd
:updater (Nesterovs. 0.5 0.9)
:updater (ai/build-updater :nesterovs {:learning-rate 0.5 :momentum 0.9})
:layers [[:dense [:n-in 4 :n-out 4 :activation :relu]]
[:dense [:n-in 4 :n-out 4 :activation :relu]]
[:output :negative-log-likelihood [:n-in 4 :n-out 10
Expand Down Expand Up @@ -93,12 +91,12 @@
:l2 0.0005
:weight-init :xavier
:optimization-algo :sgd
;todo it'd be cool if this line
:updater (Adam. (MapSchedule. ScheduleType/EPOCH {(int 0) 0.01 (int 1000) 0.005 (int 3000) 0.001}))
;could be this line, instead
;:updater [:adam [:learning-rate-schedule [:map-schedule [:epoch {0 0.01 1000 0.005 3000 0.001}]]]]
;or something like this
;:updater [:adam [:learning-rate 0.01]]
:updater (ai/build-updater
:adam
{:learning-rate-schedule (ai/build-schedule
:map
{:schedule-type :epoch
:values {(int 0) 0.01 (int 1000) 0.005 (int 3000) 0.001}})})
:layers [[:convolution [5 5] [:n-in 1 :stride [1 1] :n-out 20 :activation :identity]]
[:sub-sampling :pooling-type-max [:kernel-size [2 2] :stride [2 2]]]
[:convolution [5 5] [:stride [1 1] :n-out 50 :activation :identity]]
Expand All @@ -121,11 +119,14 @@
:weight-init :distribution
:dist (ai/normal-distribution 0.0 0.1)
:activation :relu
:updater (Nesterovs. (StepSchedule. ScheduleType/ITERATION 1e-2 0.1 100000))
:updater (ai/build-updater :nesterovs {:learning-rate-schedule (ai/build-schedule :step {:schedule-type :iteration
:initial-value 1e-2
:decay-rate 0.1
:step 100000
})})
:bias-updater (ai/build-updater :nesterovs {:learning-rate (* 1e-2 2)})
:gradient-normalization :renormalize-l2-per-layer
:optimization-algo :sgd
;todo: where did this go? is this part of the updater now too?
;:bias-learning-rate (* 1e-2 2)
:l2 (* 5 1e-4)
:mini-batch false
:layers [[:convolution [11 11] [4 4] [3 3] [:name "cnn1" :n-in 3 :n-out 96 :bias-init 0.0]]
Expand All @@ -152,7 +153,7 @@
(def mnist-config
[:seed 123
:activation :relu
:updater (Nesterovs. 0.006)
:updater (ai/build-updater :nesterovs {:learning-rate 0.006})
:optimization-algo :sgd
:l2 1e-4
:layers
Expand Down

0 comments on commit 2ed8d7e

Please sign in to comment.