In [1]:
;; Notebook bootstrap: load extra libraries into the running kernel, then
;; require the namespaces used by the cells below.  Order matters: each
;; dependency/classpath entry has to be added before the namespace that
;; lives in it is required.
(require '[clojupyter.misc.helper :as helper])
;; Oz supplies the `oz/view!` plotting call used throughout the notebook.
(helper/add-dependencies '[metasoarous/oz "1.6.0-SNAPSHOT"])
;; Pomegranate is only needed for `add-classpath` below.
(helper/add-dependencies '[com.cemerick/pomegranate "1.1.0"])
(require '[cemerick.pomegranate :as pome])
;; NOTE(review): these are absolute paths on the original author's machine;
;; they must be changed to a local checkout of the `neuro` project before
;; this cell can run anywhere else.
(pome/add-classpath "/Users/deltam/Dropbox/code_snippets/clojure_toybox/neuro/src")
(pome/add-classpath "/Users/deltam/Dropbox/code_snippets/clojure_toybox/neuro/examples/spiral/src")

;; Tufte is not required by name in this cell; presumably it is a dependency
;; of the neuro sources added above -- TODO confirm.
(helper/add-dependencies '[com.taoensso/tufte "2.0.1"])

(require '[oz.notebook.clojupyter :as oz]
         '[neuro.core :as nc]
         '[neuro.vol :as vl]
         '[spiral.core :as score])

;; End the cell on nil so the value of the last `require` is not what the
;; cell evaluates to.
nil

## Dataset

In [2]:
;; Scatter-plot spec for the raw spiral samples in `score/raw`.
;; Each sample is positioned by its "x"/"y" fields and coloured by its
;; "cat" field, which is treated as a nominal (categorical) value.
(def spiral-plot
  (let [samples  {:values score/raw}
        channels {:x     {:field "x"}
                  :y     {:field "y"}
                  :color {:field "cat" :type "nominal"}}]
    {:width    400
     :height   400
     :data     samples
     :encoding channels
     :mark     "point"}))

;; Render the spec inline in the notebook.
(oz/view! spiral-plot)

## Train

In [3]:
;; Untrained network description handed to `nc/gen-net` as a flat token list.
;; Presumably this reads as: 2 inputs -> fully connected -> sigmoid layer of
;; 10 units -> fully connected -> softmax over 3 outputs; verify against
;; `neuro.core/gen-net`.
(def net
    (nc/gen-net
        :input 2 :fc
;       Alternative hidden activations, kept disabled for experimentation:
;        :relu 10 :fc
;        :tanh 10 :fc
        :sigmoid 10 :fc
        :softmax 3))

;; Train `net` with `score/train` under the settings given to
;; `nc/with-params`, and keep the result as `trained-net`.
;; The recorded run below printed a value every 10 epochs up to epoch 300
;; and reported about 18.7 s elapsed.
(def trained-net
    (nc/with-params [:train-status-var score/train-status ; read back later via @score/train-status
                     :epoch-limit 300
                     :mini-batch-size 30
                     :learning-rate 1.0
                     :epoch-reporter score/report]        ; presumably what prints the "epoch N: ..." lines -- TODO confirm
        (score/train net)))

epoch 10: 0.753423
epoch 20: 0.724003
epoch 30: 0.700519
epoch 40: 0.675203
epoch 50: 0.639751
epoch 60: 0.584275
epoch 70: 0.517802
epoch 80: 0.456228
epoch 90: 0.405971
epoch 100: 0.366043
epoch 110: 0.333593
epoch 120: 0.306276
epoch 130: 0.282531
epoch 140: 0.261337
epoch 150: 0.241960
epoch 160: 0.223897
epoch 170: 0.206996
epoch 180: 0.191389
epoch 190: 0.177241
epoch 200: 0.164640
epoch 210: 0.153578
epoch 220: 0.143953
epoch 230: 0.135592
epoch 240: 0.128300
epoch 250: 0.121891
epoch 260: 0.116204
epoch 270: 0.111113
epoch 280: 0.106518
epoch 290: 0.102343
epoch 300: 0.098528
elapsed 18.675 sec


#'user/trained-net

In [4]:
;; Turn the recorded training-loss history into plot rows of the form
;; {:epoch <0-based index>, :value <loss>}.
(def loss
  (map (fn [idx loss-value] {:epoch idx, :value loss-value})
       (range)
       (:train-loss-history @score/train-status)))

;; Print the final row as a quick sanity check.
(println (last loss))

;; Line chart of loss value against epoch index.
(def line-plot
  (let [channels {:x {:field "epoch"}
                  :y {:field "value"}}]
    {:width    400
     :height   400
     :data     {:values loss}
     :encoding channels
     :mark     "line"}))

(oz/view! line-plot)

{:epoch 299, :value 0.09852775299280354}


## Plot Decision Boundary

In [11]:
;; Sample points [x y] over the square [-1, 1] x [-1, 1] in steps of 0.01.
;; y is the outer loop and runs from 1.0 downwards; x runs from -1.0 upwards
;; within each y.
(def grid
  (for [y (range 1.0 -1.01 -0.01)
        x (range -1.0 1.01 0.01)]
    [x y]))

(println (count grid))

;; Pack the flattened coordinates into a volume and transpose it to get the
;; network input.
(def xy-vol
  (-> (vl/vol 2 (count grid) (flatten grid))
      vl/T))

;; Run the whole grid through the trained network and print the wall-clock
;; time taken, in seconds.
(def start (System/currentTimeMillis))
(def out-vol (nc/feedforward trained-net xy-vol))
(-> (System/currentTimeMillis)
    (- start)
    (/ 1000.0)
    println)

;; One map per grid point: its coordinates plus the class the network
;; predicts there (`vl/argmax` of the matching output row).
;;
;; The coordinates come straight from `grid`.  The earlier version read them
;; from `in-vol`, a var that is never defined in this notebook (the input
;; volume is `xy-vol`), so the cell failed on a fresh kernel.
;;
;; `grid` and the output rows are walked in lockstep rather than indexed with
;; `nth`, so each row is visited once.  This relies on the i-th output row
;; belonging to the i-th grid point, the same pairing the index-based version
;; assumed.  `mapv` realises the result eagerly so any failure shows up here
;; instead of later inside a plot.
(def mesh
  (mapv (fn [[x y] row]
          {:x   x
           :y   y
           :val (vl/argmax row)})
        grid
        (vl/rows out-vol)))


;; Decision-boundary plot: start from the base spec in
;; `score/decision-boundary-plot` and override its size and signals.
;; The "meshGrid" signal carries the predicted classes from `mesh`, with
;; width/height taken as the number of distinct x and y coordinates.
(def decision-boundary-plot
  (let [distinct-count (fn [k] (count (distinct (map k mesh))))
        mesh-grid      {:width  (distinct-count :x)
                        :height (distinct-count :y)
                        :values (map :val mesh)}
        overrides      {:width   600
                        :height  600
                        :signals [{:name "classes" :value [0 1 2]}
                                  {:name "meshGrid" :value mesh-grid}]}]
    (merge score/decision-boundary-plot overrides)))

(oz/view! decision-boundary-plot)

40401
153.896


In [12]:
;; Point-cloud view of the classifier output: every entry of `mesh` is drawn
;; as a small translucent circle at (x, y), filled according to its predicted
;; class ("val") through the "class" scale.
(def contours-plot
  (let [dataset "contours"

        ;; Linear position scale for one coordinate field, pinned to [-1, 1].
        linear-scale (fn [field extent]
                       {:name      field
                        :type      "linear"
                        :zero      false
                        :domain    {:data dataset, :field field}
                        :domainMax 1.0
                        :domainMin -1.0
                        :range     extent})

        ;; Axis with grid lines and no domain line, titled after its scale.
        grid-axis (fn [field side]
                    {:scale  field
                     :grid   true
                     :domain false
                     :orient side
                     :title  field})

        ;; NOTE(review): a "threshold" scale with :zero, and a legend of type
        ;; "category" below, look unusual for Vega -- confirm against the
        ;; Vega docs before relying on the legend colours.
        class-scale {:name   "class"
                     :type   "threshold"
                     :zero   true
                     :domain [0 1 2]
                     :range  "category"}

        points {:type   "symbol"
                :from   {:data dataset}
                :encode {:enter {:x           {:scale "x" :field "x"}
                                 :y           {:scale "y" :field "y"}
                                 :shape       {:value "circle"}
                                 :fill        {:scale "class", :field "val"}
                                 :fillOpacity {:value 0.3}
                                 :size        {:value 10}}}}]
    {:width    600
     :height   600
     :autosize "pad"
     :data     [{:name dataset :values mesh}]
     :scales   [(linear-scale "x" "width")
                (linear-scale "y" "height")
                class-scale]
     :axes     [(grid-axis "x" "bottom")
                (grid-axis "y" "left")]
     :legends  [{:fill "class" :type "category"}]
     :marks    [points]}))

(oz/view! contours-plot)