-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdiscriminator.clj
More file actions
145 lines (115 loc) · 4.37 KB
/
discriminator.clj
File metadata and controls
145 lines (115 loc) · 4.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
(ns clojure-mxnet-autoencoder.discriminator
(:require [clojure-mxnet-autoencoder.viz :as viz]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.initializer :as initializer]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.module :as m]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.optimizer :as optimizer]
[org.apache.clojure-mxnet.symbol :as sym]))
(def data-dir
  "Directory holding the MNIST idx files."
  "data/")

(def batch-size
  "Number of examples per mini-batch for both MNIST iterators."
  100)

;;; Fetch MNIST on first load. Only the training-images file is checked; if it
;;; is missing, the helper script is run.
(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte")))
  (let [{:keys [exit out err]} (sh "./get_mnist_data.sh")]
    ;; `sh` reports failure through :exit rather than throwing. Fail loudly here
    ;; instead of letting the iterators below choke on missing files.
    (when-not (zero? exit)
      (throw (ex-info "get_mnist_data.sh failed"
                      {:exit exit :out out :err err})))))
;;; MNIST iterators. Both share one configuration and differ only in the
;;; image/label files they read.
(defn- mnist-iterator
  "Returns a shuffled, flattened (784-value rows) MNIST iterator reading
  `image-file` and `label-file` from `data-dir`, batched by `batch-size`."
  [image-file label-file]
  (mx-io/mnist-iter {:image       (str data-dir image-file)
                     :label       (str data-dir label-file)
                     :input-shape [784]
                     :flat        true
                     :batch-size  batch-size
                     :shuffle     true}))

(def train-data
  "Iterator over the MNIST training split."
  (mnist-iterator "train-images-idx3-ubyte" "train-labels-idx1-ubyte"))

(def test-data
  "Iterator over the MNIST test split."
  (mnist-iterator "t10k-images-idx3-ubyte" "t10k-labels-idx1-ubyte"))
(def input
  "Graph variable for the image batch. Its name \"input\" must match the
  module's :data-names and the bound data shape."
  (sym/variable "input"))

(def output
  "Graph variable for the label batch. Despite the Clojure name, the symbol is
  called \"input_\", which must match the module's :label-names and the bound
  label shape."
  (sym/variable "input_"))
(defn get-symbol
  "Builds the classifier graph:
    input -> FC 100 (\"encode1\") -> sigmoid (\"sigmoid1\")
          -> FC 50  (\"encode2\") -> sigmoid (\"sigmoid2\")
          -> FC 10  (\"result\")  -> softmax output scored against `output`."
  []
  (let [fc1    (sym/fully-connected "encode1" {:data input :num-hidden 100})
        act1   (sym/activation "sigmoid1" {:data fc1 :act-type "sigmoid"})
        fc2    (sym/fully-connected "encode2" {:data act1 :num-hidden 50})
        act2   (sym/activation "sigmoid2" {:data fc2 :act-type "sigmoid"})
        ;; 10-way classification head, one unit per digit class
        logits (sym/fully-connected "result" {:data act2 :num-hidden 10})]
    (sym/softmax-output {:data logits :label output})))
;;; Data and label descriptors taken from the training iterator. The labels
;;; are fed to the network, so both descriptors are needed.
(def data-desc (first (mx-io/provide-data-desc train-data)))
(def label-desc (first (mx-io/provide-label-desc train-data)))

;;; Build, bind and initialise the module. The descriptors are renamed to the
;;; graph's variable names: "input" for images and "input_" for labels.
(def model (-> (m/module (get-symbol) {:data-names ["input"] :label-names ["input_"]})
               (m/bind {:data-shapes [(assoc data-desc :name "input")]
                        :label-shapes [(assoc label-desc :name "input_")]})
               (m/init-params {:initializer (initializer/uniform 1)})
               ;; The option key is :learning-rate. It was previously misspelled
               ;; :learning-rage, which adam does not recognise, so the intended
               ;; rate of 0.001 was silently not applied.
               (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate 0.001})})))
(def my-metric
  "Classification accuracy (used here instead of mse). `train` updates it once
  per batch and reads and resets it at the end of each epoch."
  (eval-metric/accuracy))
(defn train
  "Trains `model` for `num-epochs` epochs over `train-data`.
  Every batch goes through forward (with labels), a `my-metric` update,
  backward, and one optimizer update. After each epoch, prints the epoch
  number and the value from `eval-metric/get-and-reset` on `my-metric`."
  [num-epochs]
  (let [step! (fn [batch]
                ;; The labels feed both the softmax output in the forward
                ;; pass and the accuracy metric.
                (let [data  (mx-io/batch-data batch)
                      label (mx-io/batch-label batch)]
                  (-> model
                      (m/forward {:data data :label label})
                      (m/update-metric my-metric label)
                      (m/backward)
                      (m/update))))]
    (doseq [epoch-num (range num-epochs)]
      (println "starting epoch " epoch-num)
      (mx-io/do-batches train-data step!)
      (println {:epoch epoch-num
                :metric (eval-metric/get-and-reset my-metric)}))))
(comment
;; REPL walkthrough. It is wrapped in `comment`, so none of this runs when the
;; namespace loads. The `;=>` and `;;` output lines record one session. Expect
;; different numbers on a re-run: both iterators shuffle and the weights are
;; randomly initialised (uniform initializer).
;;; Take one training batch and write it to results/ via viz/im-sav
;;; (presumably as an image grid). Each flat 784-value row is reshaped to
;;; 1x28x28; 100 matches batch-size.
(def my-batch (mx-io/next train-data))
(def images (mx-io/batch-data my-batch))
(viz/im-sav {:title "originals"
:output-path "results/"
:x (-> images
first
(ndarray/reshape [100 1 28 28]))})
;;; before training
;; Take one test batch, write it out the same way, and predict on it with
;; the freshly initialised model.
(def my-test-batch (mx-io/next test-data))
(def test-images (mx-io/batch-data my-test-batch))
(viz/im-sav {:title "test-images"
:output-path "results/"
:x (-> test-images
first
(ndarray/reshape [100 1 28 28]))})
(def preds (m/predict-batch model {:data test-images} ))
;; `first` takes the softmax output. `argmax-channel` picks the highest-scoring
;; class per example, giving predicted digits as floats; show the first 10.
(->> preds
first
(ndarray/argmax-channel)
(ndarray/->vec)
(take 10))
;=> (1.0 8.0 8.0 8.0 8.0 8.0 2.0 8.0 8.0 1.0)
;; Three epochs; `train` prints the training accuracy after each one.
(train 3)
;; starting epoch 0
;; {:epoch 0, :metric [accuracy 0.83295]}
;; starting epoch 1
;; {:epoch 1, :metric [accuracy 0.9371333]}
;; starting epoch 2
;; {:epoch 2, :metric [accuracy 0.9547667]}
;;; after training
;; Same test batch, now predicted with the trained weights.
(def preds (m/predict-batch model {:data test-images} ))
(->> preds
first
(ndarray/argmax-channel)
(ndarray/->vec)
(take 10))
;=> (6.0 1.0 0.0 0.0 3.0 1.0 4.0 8.0 0.0 9.0)
;;; save model
;; Checkpoint under the "model/discriminator" prefix, tagged epoch 2 (the last
;; epoch index of `(train 3)`).
(m/save-checkpoint model {:prefix "model/discriminator"
:epoch 2})
)