-
Notifications
You must be signed in to change notification settings - Fork 0
/
afk.clj
186 lines (162 loc) · 6.92 KB
/
afk.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
(ns josh.meanings.initializations.afk
"Fast and Provably Good Seedings for k-Means is a paper by Olivier Bachem,
Mario Lucic, S. Hamed Hassani, and Andreas Krause which introduces an
improvement to the monte carlo markov chain approximation of k-means++
D^2 sampling. It accomplishes this by computing the D^2 sampling
distribution with respect to the first cluster. This has the practical
benefit of removing some of the assumptions, like choice of distance
metric, which were imposed in the former framing. As such the name of
this algorithm is assumption free k-mc^2. A savvy reader may note that
by computing the D^2 sampling distribution as part of the steps this
algorithm loses some of the theoretical advantages of the pure markov
chain formulation. The paper argues that this is acceptable, because
in practice computing the first D^2 sampling distribution ends up paying
for itself by reducing the chain length necessary to get convergence
guarantees."
(:require
[clojure.spec.alpha :as s]
[tech.v3.dataset :as ds]
[tech.v3.dataset.reductions :as dsr]
[clojure.tools.logging :as log]
[josh.meanings.persistence :as p]
[josh.meanings.initializations.utils :refer [centroids->dataset weighted-sample uniform-sample add-default-chain-length]]
[josh.meanings.initializations.core :refer [initialize-centroids]]
[josh.meanings.specs :as specs]))
(defn- point
  "Returns a point from a row by dropping the trailing q(x) entry.

  Rows in this namespace carry their point coordinates followed by a
  cached q(x) value; `butlast` strips that final weight."
  [row]
  (butlast row))

(s/fdef point
  :args (s/cat :row :josh.meanings.specs/row)
  :ret :josh.meanings.specs/point)
(defn- qx
  "Extracts the cached q(x) proposal density, stored as the final
  entry of a row."
  [row]
  (-> row reverse first))

(s/fdef qx
  :args (s/cat :row :josh.meanings.specs/row)
  :ret :josh.meanings.specs/distance)
(defn samples-needed
  "Total number of monte carlo samples required: each of the k - 1
  clusters chosen after the initial one consumes a chain of m samples.
  Uses arbitrary-precision multiplication so huge k * m cannot overflow."
  [k m]
  (let [chains (dec k)]
    (*' m chains)))

(s/fdef samples-needed
  :args (s/cat
         :k :josh.meanings.specs/k
         :m :josh.meanings.specs/m)
  :ret :josh.meanings.specs/sample-count)
;; In the paper they formulate sampling such that sampling is carried out
;; one weighted sample at a time. I'm not going to do that. Instead I'm going
;; to get one large sample. Doing this means we won't be doing both the CPU
;; intensive and disk intensive parts of our algorithm at the same time.
(defn- samples
  "Get all the samples we'll need for the markov chain.

  Draws (k - 1) * m rows from ds-seq with replacement, weighted by the
  cached q(x) column, so each chain of length m can consume its own slice."
  [ds-seq k m]
  (log/info "Sampling with respect to q(x)")
  (weighted-sample ds-seq qx (samples-needed k m) :replace true))

(s/fdef samples
  :args (s/cat :ds-seq :josh.meanings.specs/sampling-datasets
               :k :josh.meanings.specs/k
               :m :josh.meanings.specs/m)
  :ret :josh.meanings.specs/rows
  ;; The conformed :args of an s/cat is a map of arg-name -> value, so the
  ;; values are read directly; the previous (second (:k args)) would have
  ;; thrown on a scalar k.
  :fn (fn [{:keys [args ret]}]
        (let [k (:k args)
              m (:m args)]
          (= (samples-needed k m) (count ret)))))
(defn square
  "Squares its argument, i.e. computes x * x."
  [x]
  (* x x))

(s/fdef square
  :args (s/cat :x number?)
  :ret number?)
(defn make-weight-fn
  "Builds a weighting function over the current set of clusters: the
  returned function maps a point to its distance-fn distance from the
  nearest cluster."
  [distance-fn clusters]
  (fn [candidate]
    (->> clusters
         (map (fn [center] (distance-fn center candidate)))
         (reduce min))))
(s/fdef q-of-x
  :args (s/cat
         :conf :josh.meanings.specs/configuration
         :cluster :josh.meanings.specs/point))

(defn- q-of-x
  "Computes the q(x) distribution for all x in the dataset.

  Implements the AFK-MC^2 proposal density
  q(x) = d(x, c)^2 / (2 * sum(d(y, c)^2)) + 1 / (2n)
  with respect to the single initial cluster, where d is the configured
  distance function and n is the row count. The result is cached as a
  \"q(x)\" column written back into the :points dataset sequence."
  [conf cluster]
  (log/info "Computing q(x) distribution with respect to" cluster)
  (let [d (partial (:distance-fn conf) cluster)
        ;; squared distance of a point to the initial cluster
        dxs (comp square d)
        ;; ds/column-map hands columns as varargs; collect them into a point
        dxs-for-cmap (fn [& cols] (dxs cols))
        ;; single streaming pass computing n and sum(d(x)^2) over all datasets
        stats (dsr/aggregate
               {"n" (dsr/row-count)
                "sum(d(x)^2)" (dsr/sum "d(x)^2")}
               (map #(ds/column-map % "d(x)^2" dxs-for-cmap) (p/read-dataset-seq conf :points)))
        n (first (get stats "n"))
        d2-sum (first (get stats "sum(d(x)^2)"))
        ;; the uniform mixing term 1/(2n) from the paper
        regularization-term (/ 1 (* n 2))
        ;; instead of multiplying 1/2 by each refactoring the /2 into true-d2
        doubled-d2-sum (* 2 d2-sum)
        qx (fn [& cols] (+ (/ (dxs cols) doubled-d2-sum) regularization-term))]
    (log/info "Caching q(x) distribution in :points dataset")
    ;; second pass: persist q(x) so later sampling can read it from disk
    (p/write-dataset-seq conf :points
                         (->> (p/read-dataset-seq conf :points)
                              (map #(ds/column-map % "q(x)" qx))))))
(s/fdef cleanup-q-of-x
  :args (s/cat :conf :josh.meanings.specs/configuration))

(defn- cleanup-q-of-x
  "Drops the cached q(x) column from every dataset in the :points
  sequence, undoing the caching done by q-of-x."
  [conf]
  (log/info "Removing cached q(x) distribution in :points dataset")
  (let [without-qx (fn [dataset] (dissoc dataset "q(x)"))]
    (p/write-dataset-seq conf :points
                         (map without-qx (p/read-dataset-seq conf :points)))))
;; The spec now matches the actual 2-arity [weight-fn rsp]; it previously
;; declared three args (:distance-fn, :c, :rsp) that the function never took.
(s/fdef mcmc-sample
  :args (s/cat :weight-fn ifn?
               :rsp :josh.meanings.specs/rows)
  :ret :josh.meanings.specs/point)

(defn- mcmc-sample
  "Perform markov chain monte carlo sampling to approximate D^2 sampling.

  Walks the pre-drawn candidate rows in rsp, accepting candidate y over the
  current state x with probability min(1, (d(y)*q(y)) / (d(x)*q(x))), where
  d is the distance to the nearest existing cluster (via weight-fn) and q
  is the cached proposal density. Returns the final accepted point."
  [weight-fn rsp]
  (loop [points (map point rsp) ;; the points
         dyqyseq (map * ;; d(c, y) * q(y)
                      (map weight-fn points)
                      (map qx rsp))
         rands (repeatedly (count points) rand) ;; Unif(0, 1)
         x (first points) ;; x
         dxqx (first dyqyseq)] ;; d(c, x) * q(x)
    (if (empty? points)
      x
      ;; Accept when the acceptance ratio beats the uniform draw; a zero
      ;; denominator means the current state has no mass, so always accept.
      (let [take (or (zero? dxqx) (> (/ (first dyqyseq) dxqx) (first rands)))]
        (recur
         (rest points)
         (rest dyqyseq)
         (rest rands)
         (if take (first points) x)
         (if take (first dyqyseq) dxqx))))))
(s/fdef k-means-assumption-free-mc-initialization
  :args (s/cat :conf :josh.meanings.specs/configuration)
  :ret :josh.meanings.specs/points)

(defn- k-means-assumption-free-mc-initialization
  "Runs the full AFK-MC^2 initialization: picks one uniform initial
  cluster, caches the q(x) proposal distribution against it, then grows
  the cluster set one markov chain at a time until k centroids exist.

  Expects :m (chain length), :k (cluster count), and :distance-fn in conf;
  returns exactly k points."
  [conf]
  {:pre [(contains? conf :m) (contains? conf :k) (contains? conf :distance-fn)]
   :post [(= (:k conf) (count %))]}
  (log/info "Performing afk-mc initialization")
  (log/info "Sampling cluster from dataset for initial centroid choice")
  (let [cluster (first (uniform-sample (p/read-dataset-seq conf :points) 1))]
    (log/info "Got initial cluster" cluster)
    ;; cache q(x) on disk so the weighted batch sampling below can use it
    (q-of-x conf cluster)
    (let [k (:k conf) ;; number of clusters
          m (:m conf) ;; markov chain length
          ;; one big batch of (k-1)*m weighted samples; consumed m at a time
          sp (samples (p/read-dataset-seq conf :points) k m)
          clusters
          (loop [cs [cluster] rsp sp]
            ;; weight-fn is rebuilt each round so it reflects every cluster
            ;; accepted so far
            (let [weight-fn (make-weight-fn (:distance-fn conf) cs)]
              (log/info "Performing round of mcmc sampling")
              (if (empty? rsp)
                cs
                (let [nc (mcmc-sample weight-fn (take m rsp))]
                  (recur (conj cs nc) (drop m rsp))))))]
      ;; drop the cached q(x) column now that sampling is finished
      (cleanup-q-of-x conf)
      clusters)))
(defmethod initialize-centroids
  :afk-mc
  [conf]
  ;; The chain-length default is only needed by the initialization itself;
  ;; the original conf is what gets handed to centroids->dataset.
  (let [chain-conf (add-default-chain-length conf)
        centroids (k-means-assumption-free-mc-initialization chain-conf)]
    (centroids->dataset conf centroids)))