Browse files

added support for missing values

  • Loading branch information...
1 parent 50aa2b0 commit b3a25e30ed0a6c628287bef8f0c3e9d2153c954f @eandrejko committed Apr 28, 2012
Showing with 24 additions and 11 deletions.
  1. +7 −6 src/random_forests/core.clj
  2. +17 −5 src/random_forests/train.clj
@@ -88,15 +88,15 @@
{:feature feature :value value :text (str/join " and " text)}))
(= :continuous (:type feature))
- (fn [example] (<= (nth example i) value))
+ (fn [example] (and (nth example i) (<= (nth example i) value)))
{:feature feature :value value :text (str (:name feature) "<=" value)})
(= :text (:type feature))
(fn [example] (contains? (nth example i) value))
{:feature feature :value value :text (str (:name feature) " contains " (if (:dict feature) ((:dict feature) value) value))})
- (fn [example] (= (nth example i) value))
+ (fn [example] (and (nth example i) (= (nth example i) value)))
{:feature feature :value value :text (str (:name feature) "==" value)}))))
(defn pairs
@@ -117,6 +117,7 @@
(= :continuous (:type feature))
(let [values (->> (map #(nth % (:index feature)) examples)
+ (filter (comp not nil?))
(set (concat
(map #(/ (+ (last %) (first %)) 2) (pairs values))
@@ -220,18 +221,18 @@
(defn combine-predictions
"combines predictions of examples of the form {example [prediction] ...} and returns [target prediction] pairs"
- [predictions]
+ [eval-fn predictions]
(->> predictions
(reduce (partial merge-with concat))
- (map (fn [[example preds]] (vector (target example) (avg preds))))))
+ (map (fn [[example preds]] (vector (target example) (eval-fn preds))))))
(defn evaluate-forest
"evaluates collection of trees by averaging predictions on held out data within trees meta-data :eval
returns collection of [target prediction] pairs"
- [forest]
+ [forest eval-fn]
(->> forest
(map (comp :eval meta))
- (combine-predictions)))
+ (combine-predictions eval-fn)))
(defn votes
"determines vote of each decision tree in forest"
@@ -58,6 +58,14 @@
(map (fn [[a b]] (Math/abs (- a b))))
+(defn mean-classification-error
+ "measures l1 loss from forest evaluation"
+ [evaluation]
+ (->> evaluation
+ (map (fn [[a b]] (if (= a b) 1 0)))
+ (rf/avg)
+ (float)))
(defn -main
[& args]
(let [[options args banner] (cli args
@@ -68,6 +76,7 @@
["-o" "--output" "Write detailed training error output in CSV format to output file"]
["-t" "--target" "Prediction target name"]
["-b" "--binary" "Perform binary classification of target (measures AUC loss)" :default false :flag true]
+ ["-u" "--multi" "Perform multi-class classification of target (measures classification rate)" :default false :flag true]
["-l" "--limit" "Number of trees to build" :parse-fn #(Integer/parseInt %) :default 100])]
(when (or (not (first args)) (:help options))
(println banner)
@@ -80,9 +89,11 @@
(keep-indexed (fn [i x] (if (= x target-name) i)))
examples (->> (named-examples header input)
- (map #(map (fn [[name val]] ((get encoding name identity) val)) %))
+ (map #(map (fn [[name val]] (try ((get encoding name identity) val) (catch java.lang.NumberFormatException e nil))) %))
(map vec)
- (map (fn [z] (conj z (nth z target-index))))) ;; target is at end
+ (map (fn [z] (conj z (nth z target-index)))) ;; target is at end
+ (filter #(not (nil? (last %))))
+ )
features (set (features header (:features options)))]
(let [forest (take (:limit options)
(rf/build-random-forest examples features (:split options) (:size options)))
@@ -91,14 +102,15 @@
(if (:output options)
(spit (:output options) "tree_count,target,prediction,error\n"))
(doseq [trees sub-forests]
- (let [evaluation (rf/evaluate-forest trees)
+ (let [combiner (if (:multi options) rf/mode rf/avg)
+ evaluation (rf/evaluate-forest trees combiner)
loss (-> evaluation
- ((if (:binary options) auc-loss mean-absolute-loss)))]
+ ((if (:binary options) auc-loss (if (:multi options) mean-classification-error mean-absolute-loss))))]
(println (format "%d: %f" (count trees) loss))
(if (:output options)
(spit (:output options)
(->> evaluation
- (map (fn [[a b]] [(count trees) a b (- a b)]))
+ (map (fn [[a b]] [(count trees) a b (if (:multi options) (if (= a b) 1 0) (- a b))]))
(map #(str (clojure.string/join "," %) "\n"))
(reduce str))
:append true)))))

0 comments on commit b3a25e3

Please sign in to comment.