-
Notifications
You must be signed in to change notification settings - Fork 3
/
rbm.clj
324 lines (292 loc) · 13 KB
/
rbm.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
(ns deebn.rbm
(:refer-clojure :exclude [+ - * / ==])
(:require [deebn.protocols :refer [Testable Trainable Classify]]
[deebn.util :refer [bernoulli gen-softmax
get-min-position sigmoid]]
[clojure.core.matrix :as m]
[clojure.core.matrix.operators :refer [+ - * / ==]]
[clojure.core.matrix.select :as s]
[clojure.core.matrix.random :as rand]
[clojure.core.matrix.stats :as stats]
[clojure.set :refer [difference]]
[clojure.tools.reader.edn :as edn])
(:import java.io.Writer))
;; Use the pure-JVM vectorz-clj backend for all core.matrix operations.
(m/set-current-implementation :vectorz)
;;;===========================================================================
;;; Generate Restricted Boltzmann Machines
;;; ==========================================================================
;; We define a purely generative RBM, trained without any class
;; labels, and a classification RBM
;; Fields: w (visible x hidden weight matrix), vbias/hbias (visible and
;; hidden unit biases), *-vel (momentum velocity terms for each parameter),
;; visible/hidden (unit counts). CRBM adds `classes`, the number of
;; softmax label units folded into the visible layer.
(defrecord RBM [w vbias hbias w-vel vbias-vel hbias-vel visible hidden])
(defrecord CRBM [w vbias hbias w-vel vbias-vel hbias-vel visible hidden classes])
(defn build-rbm
  "Factory function to produce an RBM record."
  [visible hidden]
  ;; Weights start as small Gaussian noise (sigma = 0.01); velocities
  ;; start at zero so the first momentum step is a plain gradient step.
  (let [weights (m/matrix (repeatedly visible
                                      #(/ (rand/sample-normal hidden) 100)))
        visible-bias (m/zero-vector visible)
        ;; TODO: The visual biases should really be set to
        ;; log(p_i/ (1 - p_i)), where p_i is the proportion of
        ;; training vectors in which unit i is turned on.
        hidden-bias (m/array (repeat hidden -4))
        weight-velocity (m/zero-matrix visible hidden)
        vbias-velocity (m/zero-vector visible)
        hbias-velocity (m/zero-vector hidden)]
    (->RBM weights visible-bias hidden-bias
           weight-velocity vbias-velocity hbias-velocity
           visible hidden)))
(defn build-jd-rbm
  "Factory function to build a joint density RBM for testing purposes.
  This RBM has two sets of visible units - the typical set
  representing each observation in the data set, and a softmax unit
  representing the label for each observation. These are combined, and
  the label becomes part of the input vector."
  [visible hidden classes]
  ;; Build a plain RBM whose visible layer is widened by one unit per
  ;; class, then re-tag it as a CRBM carrying the class count.
  (let [rbm (build-rbm (+ visible classes) hidden)]
    (map->CRBM (assoc rbm :classes classes))))
;;;===========================================================================
;;; Train an RBM
;;; ==========================================================================
(defn update-weights
  "Determine the weight gradient from this batch"
  [ph ph2 batch pv]
  ;; Positive phase (data row x hidden probabilities) minus negative
  ;; phase (reconstruction row x re-driven hidden probabilities),
  ;; summed over every observation in the batch.
  (reduce +
          (map (fn [v0 h0 v1 h1]
                 (- (m/outer-product v0 h0) (m/outer-product v1 h1)))
               (m/rows batch) (m/rows ph)
               (m/rows pv) (m/rows ph2))))
;; TODO: Implement CD-K - currently CD-1 is hard-coded.
(defn update-rbm
  "Single batch step update of RBM parameters.
  Performs one step of contrastive divergence (CD-1) on `batch` and
  returns the RBM with updated weights, biases, and momentum
  velocities."
  [batch rbm learning-rate momentum]
  (let [batch-size (m/row-count batch)
        ;; Hidden-unit probabilities driven by the data.
        ph (m/emap sigmoid (+ (:hbias rbm) (m/mmul batch (:w rbm))))
        ;; Stochastic binary hidden states sampled from those probabilities.
        h (m/emap bernoulli ph)
        ;; Visible-unit probabilities reconstructed from the hidden sample.
        pv (m/emap sigmoid (+ (:vbias rbm)
                              (m/mmul h (m/transpose (:w rbm)))))
        ;; Sampled reconstruction of the visible layer.
        v (m/emap bernoulli pv)
        ;; Hidden probabilities re-driven by the reconstruction
        ;; (the "negative phase" of CD-1).
        ph2 (m/emap sigmoid (+ (:hbias rbm) (m/mmul v (:w rbm))))
        ;; Parameter gradients, averaged over the batch.
        delta-w (/ (update-weights ph ph2 batch pv) batch-size)
        delta-vbias (/ (reduce + (map #(- % %2)
                                      (m/rows batch)
                                      (m/rows pv)))
                       batch-size)
        delta-hbias (/ (reduce + (map #(- % %2)
                                      (m/rows h)
                                      (m/rows ph2)))
                       batch-size)
        ;; Momentum update: blend previous velocity with the new gradient.
        w-vel (+ (* momentum (:w-vel rbm)) (* learning-rate delta-w))
        vbias-vel (+ (* momentum (:vbias-vel rbm))
                     (* learning-rate delta-vbias))
        hbias-vel (+ (* momentum (:hbias-vel rbm))
                     (* learning-rate delta-hbias))]
    (assoc rbm
      :w (+ (:w rbm) w-vel)
      :vbias (+ (:vbias rbm) vbias-vel)
      :hbias (+ (:hbias rbm) hbias-vel)
      :w-vel w-vel :vbias-vel vbias-vel :hbias-vel hbias-vel)))
(defn train-epoch
  "Train a single epoch: iterate over `dataset` in mini-batches of
  `batch-size` rows, folding each batch into the RBM via `update-rbm`.
  Returns the updated RBM.

  Fixes an off-by-one in the previous version, which passed the batch
  for the *current* start/end to `recur` — training the first batch
  twice and never training the final (possibly partial) batch. Selecting
  the batch inside the loop, with `end` clamped to the row count, also
  makes datasets smaller than `batch-size` work."
  [rbm dataset learning-rate momentum batch-size]
  (let [total (m/row-count dataset)]
    (loop [rbm rbm
           batch-num 1]
      (let [start (* (dec batch-num) batch-size)
            end (min (* batch-num batch-size) total)]
        (if (>= start total)
          rbm
          (do
            ;; Progress indicator: one dot per mini-batch.
            (print ".")
            (flush)
            (recur (update-rbm (m/matrix (s/sel dataset (range start end)
                                                (s/irange)))
                               rbm learning-rate momentum)
                   (inc batch-num))))))))
(defn select-overfitting-sets
  "Given a dataset, attempt to choose reasonable validation and test
  sets to monitor overfitting."
  [dataset]
  ;; Draw ~1% of the rows (with possible duplicates collapsed by `set`)
  ;; for validation, then an independent ~1% sample — minus any overlap
  ;; with the validation rows — as the training-set sample.
  (let [obvs (m/row-count dataset)
        draw-indices #(set (repeatedly (/ obvs 100)
                                       (fn [] (rand-int obvs))))
        validation-indices (draw-indices)
        train-indices (difference (draw-indices) validation-indices)]
    {:validations (m/matrix (s/sel dataset
                                   (vec validation-indices) (s/irange)))
     :train-sample (m/matrix (s/sel dataset (vec train-indices) (s/irange)))
     ;; Remaining rows (validation rows excluded) become the working set.
     :dataset (s/sel dataset (s/exclude (vec validation-indices))
                     (s/irange))}))
(defn free-energy
  "Compute the free energy of a given visible vector and RBM. Lower is
  better."
  [x rbm]
  ;; F(x) = -x.vbias - sum_j log(1 + exp(hbias_j + (x.W)_j))
  (let [hidden-input (+ (:hbias rbm) (m/mmul x (:w rbm)))
        hidden-term (reduce (fn [acc unit-input]
                              (+ acc (Math/log (+ 1 (Math/exp unit-input)))))
                            0
                            hidden-input)]
    (- (- (m/mmul x (:vbias rbm))) hidden-term)))
(defn check-overfitting
  "Given an rbm, a sample from the training set, and a validation set,
  determine if the model is starting to overfit the data. This is
  measured by a difference in the average free energy over the
  training set sample and the validation set."
  [rbm train-sample validations]
  ;; Free energies are independent per row, so compute them in parallel.
  (let [mean-energy (fn [data]
                      (stats/mean (pmap #(free-energy % rbm)
                                        (m/rows data))))]
    (Math/abs ^Double (- (mean-energy train-sample)
                         (mean-energy validations)))))
(defn train-rbm
  "Given a training set, train an RBM

  params is a map with various options:
  learning-rate: defaults to 0.1
  initial-momentum: starting momentum. Defaults to 0.5
  momentum: momentum after `momentum-delay` epochs have passed. Defaults to 0.9
  momentum-delay: epochs after which `momentum` is used instead of
  `initial-momentum`. Defaults to 3
  batch-size: size of each mini-batch. Defaults to 10
  epochs: number of times to train the model over the entire training set.
  Defaults to 100
  gap-delay: number of epochs elapsed before early stopping is considered
  gap-stop-delay: number of sequential epochs where energy gap is increasing
  before stopping"
  [rbm dataset params]
  ;; Carve out validation and train-sample subsets for overfitting
  ;; monitoring; `dataset` is rebound to the remaining rows.
  (let [{:keys [validations train-sample dataset]}
        (select-overfitting-sets dataset)
        {:keys [learning-rate initial-momentum momentum momentum-delay
                batch-size epochs gap-delay gap-stop-delay]
         :or {learning-rate 0.1
              initial-momentum 0.5
              momentum 0.9
              momentum-delay 3
              batch-size 10
              epochs 100
              gap-delay 10
              gap-stop-delay 2}} params]
    (println "Training epoch 1")
    ;; Epoch 1 is run eagerly so the loop can start with a measured
    ;; energy gap; subsequent epochs start at epoch 2.
    (loop [rbm (train-epoch rbm dataset learning-rate
                            initial-momentum batch-size)
           epoch 2
           energy-gap (check-overfitting rbm train-sample validations)
           gap-inc-count 0]
      (if (> epoch epochs)
        rbm
        (do (println "\nTraining epoch" epoch)
            (let [;; Switch to the higher momentum once past the delay.
                  curr-momentum (if (> epoch momentum-delay)
                                  momentum initial-momentum)
                  rbm (train-epoch rbm dataset learning-rate
                                   curr-momentum batch-size)
                  gap-after-train (check-overfitting rbm train-sample
                                                     validations)
                  _ (println "\nGap pre-train:" energy-gap
                             "After train:" gap-after-train)]
              ;; Early stop: past gap-delay epochs, the gap grew this
              ;; epoch, and it has grown gap-stop-delay times in a row.
              (if (and (>= epoch gap-delay)
                       (neg? (- energy-gap gap-after-train))
                       (>= gap-inc-count gap-stop-delay))
                rbm
                (recur rbm
                       (inc epoch)
                       gap-after-train
                       ;; Count consecutive gap increases; reset on any
                       ;; epoch where the gap shrinks.
                       (if (neg? (- energy-gap gap-after-train))
                         (inc gap-inc-count)
                         0)))))))))
;; Both RBM flavours train identically; the class label (when present)
;; is simply part of each input vector.
(extend-protocol Trainable
  RBM
  (train-model [m dataset params]
    (train-rbm m dataset params))
  CRBM
  (train-model [m dataset params]
    (train-rbm m dataset params)))
;;;===========================================================================
;;; Testing an RBM trained on a data set
;;;===========================================================================
(defn get-prediction
  "For a given observation and RBM, return the predicted class."
  [x rbm num-classes labeled?]
  ;; Clamp the observation, try each one-hot class label in turn, and
  ;; pick the class whose joint visible vector has the lowest free
  ;; energy. When the observation carries its label as the final
  ;; element, strip it first.
  (let [observation (if labeled? (butlast x) x)
        trials (m/matrix (for [class-idx (range num-classes)]
                           (m/join (gen-softmax class-idx num-classes)
                                   observation)))
        energies (mapv #(free-energy % rbm) trials)]
    (get-min-position energies)))
(extend-protocol Classify
  CRBM
  (classify [m obv]
    ;; `labeled?` is false: `obv` is an unlabeled observation, so no
    ;; trailing label element needs stripping before prediction.
    (get-prediction obv m (:classes m) false)))
(defn test-rbm
  "Test a joint density RBM trained on a data set. Returns an error
  percentage.

  dataset should have the label as the last entry in each
  observation."
  [rbm dataset num-classes]
  ;; Predict every row in parallel, then score 1 per mismatch between
  ;; the predicted class and the row's trailing label.
  (let [total-obs (m/row-count dataset)
        predicted (pmap #(get-prediction % rbm num-classes true) dataset)
        error-count (m/esum (mapv #(if (== (last %1) %2) 0 1)
                                  dataset predicted))]
    (double (/ error-count total-obs))))
(extend-protocol Testable
  CRBM
  (test-model [m dataset]
    ;; The CRBM already knows its class count, so testing needs only
    ;; the labeled dataset.
    (test-rbm m dataset (:classes m))))
;;;===========================================================================
;;; Utility functions for an RBM
;;;===========================================================================
;; This is designed for EDN printing, not actually visualizing the RBM
;; at the REPL (this is only needed because similar methods are not
;; defined for clojure.core.matrix implementations)
(defmethod clojure.core/print-method RBM print-RBM [rbm ^Writer w]
  ;; Emit an EDN tagged literal; matrices are flattened to nested
  ;; vectors so plain EDN readers can parse them. `(str :w)` prints
  ;; as ":w", keeping the output identical to a hand-written map.
  (.write w (str "#deebn.rbm.RBM {"
                 (apply str
                        (for [k [:w :vbias :hbias :w-vel :vbias-vel
                                 :hbias-vel]]
                          (str " " k " " (m/to-nested-vectors (k rbm)))))
                 " :visible " (:visible rbm)
                 " :hidden " (:hidden rbm)
                 " }")))
(defmethod clojure.core/print-method CRBM print-CRBM [rbm ^Writer w]
  ;; Same EDN tagged-literal scheme as the RBM printer, with the
  ;; trailing :classes count appended.
  (.write w (str "#deebn.rbm.CRBM {"
                 (apply str
                        (for [k [:w :vbias :hbias :w-vel :vbias-vel
                                 :hbias-vel]]
                          (str " " k " " (m/to-nested-vectors (k rbm)))))
                 " :visible " (:visible rbm)
                 " :hidden " (:hidden rbm)
                 " :classes " (:classes rbm)
                 " }")))
(defn save-rbm
  "Save a RBM to disk."
  [rbm filepath]
  ;; Serialize via pr-str so the custom print-method EDN form is used.
  (->> rbm
       pr-str
       (spit filepath)))
(defn edn->RBM
  "The default map->RBM function provided by the defrecord doesn't
  provide us with the performant implementation (i.e. matrices and
  arrays from core.matrix), so this function adds a small step to
  ensure that."
  [data]
  ;; Rebuild every numeric field as a core.matrix array, then let
  ;; map->RBM assemble the record.
  (let [matrix-keys [:w :vbias :hbias :w-vel :vbias-vel :hbias-vel]]
    (map->RBM (reduce (fn [acc k] (assoc acc k (m/matrix (get acc k))))
                      data
                      matrix-keys))))
(defn edn->CRBM
  "The default map->CRBM function provided by the defrecord doesn't
  provide us with the performant implementation (i.e. matrices and
  arrays from core.matrix), so this function adds a small step to
  ensure that."
  [data]
  ;; Rebuild every numeric field as a core.matrix array; :visible,
  ;; :hidden, and :classes pass through untouched.
  (let [matrix-keys [:w :vbias :hbias :w-vel :vbias-vel :hbias-vel]]
    (map->CRBM (reduce (fn [acc k] (assoc acc k (m/matrix (get acc k))))
                       data
                       matrix-keys))))
(defn load-rbm
  "Load a RBM from disk."
  [filepath]
  ;; Register tagged-literal readers so the EDN forms produced by the
  ;; custom print-methods rebuild performant records.
  (let [readers {'deebn.rbm.RBM edn->RBM
                 'deebn.rbm.CRBM edn->CRBM}]
    (edn/read-string {:readers readers} (slurp filepath))))