Updating to new “Clojurey” syntax
gigasquid committed Jan 17, 2020
1 parent b268ad6 commit 64d114d
Showing 3 changed files with 42 additions and 41 deletions.
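
As a quick orientation, this commit swaps libpython-clj's older "$" helpers for interop macros that mirror Clojure's Java interop sugar. A minimal sketch of the correspondence, assuming libpython-clj 1.31-SNAPSHOT on the classpath and an installed numpy (used here purely as an example module):

(require '[libpython-clj.require :refer [require-python]]
         '[libpython-clj.python :as py :refer [py. py.. py.-]])
(require-python 'numpy)

(def arr (numpy/array [[1 2] [3 4]]))

;; method calls: py/$a -> py. (reads like Clojure's (.method obj ...))
(py/$a arr reshape 1 4)
(py. arr reshape 1 4)

;; attribute access: py/get-attr -> py.- (like Clojure's (.-field obj))
(py/get-attr arr "shape")
(py.- arr shape)

;; py.. chains method calls and attribute lookups, like Clojure's ..
(py.. arr (reshape 4 1) -shape)
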
2 changes: 1 addition & 1 deletion deps.edn
@@ -1,6 +1,6 @@
{:paths ["src" "resources"]
:deps {org.clojure/clojure {:mvn/version "1.10.1"}
cnuernber/libpython-clj {:mvn/version "1.30"}}
cnuernber/libpython-clj {:mvn/version "1.31-SNAPSHOT"}}
:aliases
{:test {:extra-paths ["test"]
:extra-deps {org.clojure/test.check {:mvn/version "0.10.0"}}}
35 changes: 19 additions & 16 deletions src/gigasquid/gpt2.clj
@@ -1,22 +1,22 @@
(ns gigasquid.gpt2
(:require [libpython-clj.require :refer [require-python]]
-[libpython-clj.python :as py]))
+[libpython-clj.python :as py :refer [py. py.. py.-]]))

;;; sudo pip3 install torch
;;; sudo pip3 install transformers

;https://huggingface.co/transformers/quickstart.html - OpenAI GPT-2

-(require-python '(transformers))
-(require-python '(torch))
+(require-python 'transformers)
+(require-python 'torch)


;;; Load pre-trained model tokenizer (vocabulary)

-(def tokenizer (py/$a transformers/GPT2Tokenizer from_pretrained "gpt2"))
+(def tokenizer (py. transformers/GPT2Tokenizer "from_pretrained" "gpt2"))
(def text "Who was Jim Henson ? Jim Henson was a")
;; encode text input
-(def indexed-tokens (py/$a tokenizer encode text))
+(def indexed-tokens (py. tokenizer encode text))
indexed-tokens ;=>[8241, 373, 5395, 367, 19069, 5633, 5395, 367, 19069, 373, 257]

;; convert indexed tokens to pytorch tensor
@@ -27,11 +27,11 @@ tokens-tensor

;;; Load pre-trained model (weights)
;;; Note: this will take a few minutes to download everything
-(def model (py/$a transformers/GPT2LMHeadModel from_pretrained "gpt2"))
+(def model (py. transformers/GPT2LMHeadModel from_pretrained "gpt2"))

;;; Set the model in evaluation mode to deactivate the DropOut modules
;;; This is IMPORTANT to have reproducible results during evaluation!
-(py/$a model eval)
+(py. model eval)


;;; Predict all tokens
@@ -41,11 +41,11 @@ tokens-tensor
;;; get the predicted next sub-word
(def predicted-index (let [last-word-predictions (-> predictions first last)
arg-max (torch/argmax last-word-predictions)]
-(py/$a arg-max item)))
+(py. arg-max item)))

predicted-index ;=>582

-(py/$a tokenizer decode (-> (into [] indexed-tokens)
+(py. tokenizer decode (-> (into [] indexed-tokens)
(conj predicted-index)))

;=> "Who was Jim Henson? Jim Henson was a man"
@@ -57,19 +57,19 @@ predicted-index ;=>582

;; Here is a fully-working example using the past with GPT2LMHeadModel and argmax decoding (which should only be used as an example, as argmax decoding introduces a lot of repetition):

-(def tokenizer (py/$a transformers/GPT2Tokenizer from_pretrained "gpt2"))
-(def model (py/$a transformers/GPT2LMHeadModel from_pretrained "gpt2"))
+(def tokenizer (py. transformers/GPT2Tokenizer from_pretrained "gpt2"))
+(def model (py. transformers/GPT2LMHeadModel from_pretrained "gpt2"))

-(def generated (into [] (py/$a tokenizer encode "The Manhattan bridge")))
+(def generated (into [] (py. tokenizer encode "The Manhattan bridge")))
(def context (torch/tensor [generated]))


(defn generate-sequence-step [{:keys [generated-tokens context past]}]
(let [[output past] (model context :past past)
token (torch/argmax (first output))
-new-generated (conj generated-tokens (py/$a token tolist))]
+new-generated (conj generated-tokens (py. token tolist))]
{:generated-tokens new-generated
-:context (py/$a token unsqueeze 0)
+:context (py. token unsqueeze 0)
:past past
:token token}))

@@ -90,7 +90,7 @@ predicted-index ;=>582
;;; Let's make a nice function to generate text

(defn generate-text [starting-text num-of-words-to-predict]
-(let [tokens (into [] (py/$a tokenizer encode starting-text))
+(let [tokens (into [] (py. tokenizer encode starting-text))
context (torch/tensor [tokens])
result (reduce
(fn [r i]
@@ -116,7 +116,7 @@ predicted-index ;=>582

;;; from https://github.com/huggingface/transformers/issues/1725

-(require-python '(torch.nn.functional))
+(require-python 'torch.nn.functional)

(defn sample-sequence-step [{:keys [generated-tokens context past temp]
:or {temp 0.8}}]
@@ -164,3 +164,6 @@ predicted-index ;=>582
0.8)
"Rich Hickey developed Clojure because he wanted a modern Lisp for functional programming, symbiotic with the established Java platform. He knew that Clojure would make it hard to access any memory through Java, and code a good amount of Lisp. He had much to learn about programming at the time, and Clojure was perfect for him. It was important to understand the dominant language of Lisp, which was Clojure and JVM. Because of this, JVM was named 'Snack: No Slobs in Clojure'. This was a very important order of things, for JVM. Clojure had a major advantage over JVM in"

+(generate-text2 "What is the average rainfall in Florida?"
+                100
+                0.8)
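
For reference, a REPL sketch of driving the updated gpt2 namespace with the new syntax. A sketch only: generate-text2 is defined in the portion of the file elided above, the model downloads on first load, and sampled output varies from run to run.

(require '[gigasquid.gpt2])
(in-ns 'gigasquid.gpt2)

;; greedy argmax decoding (repetitive by design, per the comment above)
(generate-text "The Manhattan bridge" 20)

;; temperature sampling via the elided generate-text2
(generate-text2 "What is the average rainfall in Florida?" 100 0.8)
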
46 changes: 22 additions & 24 deletions src/gigasquid/mxnet.clj
@@ -1,23 +1,20 @@
(ns gigasquid.mxnet
(:require [libpython-clj.require :refer [require-python]]
-[libpython-clj.python :as py]
+[libpython-clj.python :as py :refer [py. py.. py.-]]
[clojure.string :as string]))

;;; sudo pip3 install mxnet

(require-python '(mxnet mxnet.ndarray mxnet.module mxnet.io))
-(require-python '(mxnet.test_utils))
-(require-python '(mxnet.initializer))
-(require-python '(mxnet.metric))
-(require-python '(mxnet.symbol))
+(require-python '(mxnet.test_utils mxnet.initializer mxnet.metric mxnet.symbol))


;;; get the mnist data and format it

(def mnist (mxnet.test_utils/get_mnist))
-(def train-x (mxnet.ndarray/array (py/$a (py/get-item mnist "train_data") "reshape" -1 784)))
+(def train-x (mxnet.ndarray/array (py. (py/get-item mnist "train_data") "reshape" -1 784)))
(def train-y (mxnet.ndarray/array (py/get-item mnist "train_label")))
-(def test-x (mxnet.ndarray/array (py/$a (py/get-item mnist "test_data") "reshape" -1 784)))
+(def test-x (mxnet.ndarray/array (py. (py/get-item mnist "test_data") "reshape" -1 784)))
(def test-y (mxnet.ndarray/array (py/get-item mnist "test_label")))

(def batch-size 100)
@@ -31,8 +28,8 @@
:batch_size batch-size))


-(def data-shapes (py/get-attr train-dataset "provide_data"))
-(def label-shapes (py/get-attr train-dataset "provide_label"))
+(def data-shapes (py.- train-dataset "provide_data"))
+(def label-shapes (py.- train-dataset "provide_label"))

data-shapes ;=> [DataDesc[data,(10, 784),<class 'numpy.float32'>,NCHW]]
label-shapes ;=> [DataDesc[softmax_label,(10,),<class 'numpy.float32'>,NCHW]]
@@ -53,37 +50,37 @@ label-shapes ;=> [DataDesc[softmax_label,(10,),<class 'numpy.float32'>,NCHW]]


(def model (py/call-kw mxnet.module/Module [] {:symbol net :context (mxnet/cpu)}))
-(py/$a model bind :data_shapes data-shapes :label_shapes label-shapes)
-(py/$a model init_params)
-(py/$a model init_optimizer :optimizer "adam")
+(py. model bind :data_shapes data-shapes :label_shapes label-shapes)
+(py. model init_params)
+(py. model init_optimizer :optimizer "adam")
(def acc-metric (mxnet.metric/Accuracy))


(defn end-of-data-error? [e]
(string/includes? (.getMessage e) "StopIteration"))

(defn reset [iter]
-(py/$a iter reset))
+(py. iter reset))

(defn next-batch [iter]
-(try (py/$a iter next)
+(try (py. iter next)
(catch Exception e
(when-not (end-of-data-error? e)
(throw e)))))

(defn get-metric [metric]
-(py/$a metric get))
+(py. metric get))

(defn train-epoch [model dataset metric]
(reset dataset)
(loop [batch (next-batch dataset)
i 0]
(if batch
(do
-(py/$a model forward batch :is_train true)
-(py/$a model backward)
-(py/$a model update)
-(py/$a model update_metric metric (py/get-attr batch "label"))
+(py. model forward batch :is_train true)
+(py. model backward)
+(py. model update)
+(py. model update_metric metric (py/get-attr batch "label"))
(when (zero? (mod i 100)) (println "i-" i " Training Accuracy " (py/$a metric get)))
(recur (next-batch dataset) (inc i)))
(println "Final Training Accuracy " (get-metric metric)))))
@@ -94,8 +91,8 @@ label-shapes ;=> [DataDesc[softmax_label,(10,),<class 'numpy.float32'>,NCHW]]
i 0]
(if batch
(do
-(py/$a model forward batch)
-(py/$a model update_metric metric (py/get-attr batch "label"))
+(py. model forward batch)
+(py. model update_metric metric (py/get-attr batch "label"))
(when (zero? (mod i 100)) (println "i-" i " Test Accuracy " (py/$a metric get)))
(recur (next-batch dataset) (inc i)))
(println "Final Test Accuracy " (get-metric metric)))))
@@ -116,17 +113,18 @@ label-shapes ;=> [DataDesc[softmax_label,(10,),<class 'numpy.float32'>,NCHW]]

;;visualization

+(py. train-dataset "reset")
(def bd (next-batch train-dataset))
(def data (first (py/get-attr bd "data")))

(def image (mxnet.ndarray/slice data :begin 0 :end 1))
-(def image2 (py/$a image "reshape" [28 28]))
+(def image2 (py. image "reshape" [28 28]))
(def image3 (-> (mxnet.ndarray/multiply image2 256)
(mxnet.ndarray/cast :dtype "uint8")))
-(def npimage (py/$a image3 asnumpy))
+(def npimage (py. image3 "asnumpy"))


-(require-python '(cv2))
+(require-python 'cv2)
(cv2/imwrite "number.jpg" npimage)
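
The mxnet edits apply the same two substitutions throughout the file. A condensed recap against the vars defined above, for side-by-side comparison (in both the old and new method-call forms, keywords pass through as Python keyword arguments):

;; attribute access: py/get-attr -> py.-
(py/get-attr train-dataset "provide_data") ; old
(py.- train-dataset "provide_data")        ; new, mirrors Clojure's .-

;; method call with kwargs: py/$a -> py.
(py/$a model bind :data_shapes data-shapes :label_shapes label-shapes) ; old
(py. model bind :data_shapes data-shapes :label_shapes label-shapes)   ; new
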

