Skip to content

Commit

Permalink
Merge branch 'sritchie/nested' into sritchie/jvp
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 23, 2024
2 parents 002d575 + 3cf0f00 commit d096f70
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 139 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

## [unreleased]

- #156:

- Makes forward- and reverse-mode automatic differentiation compatible with
each other, allowing for proper mixed-mode AD

- Adds support for derivatives of literal functions in reverse-mode

- #165:

- Fixes Alexey's Amazing Bug for our tape implementation
Expand Down
48 changes: 43 additions & 5 deletions src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,21 @@
(->Function
fexp (f/arity f) (domain-types f) (range-type f))))

(defn- forward-mode-fold [f primal-s tag]
(defn- forward-mode-fold
"Takes
- a literal function `f`
- a structure `primal-s` of the primal components of the args to `f` (with
respect to `tag`)
- the `tag` of the innermost active derivative call
And returns a folding function (designed for use
with [[emmy.structure/fold-chain]]) that
generates a new [[emmy.differential/Dual]] by applying the chain rule and
summing the partial derivatives for each perturbed argument in the input
structure."
[f primal-s tag]
(fn
([] 0)
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
Expand All @@ -263,7 +277,18 @@
(g/+ tangent (g/* (literal-apply partial primal-s)
dx))))))))

(defn- reverse-mode-fold [f primal-s tag]
(defn- reverse-mode-fold
"Takes
- a literal function `f`
- a structure `primal-s` of the primal components of the args to `f` (with
respect to `tag`)
- the `tag` of the innermost active derivative call
And returns a folding function (designed for use
with [[emmy.structure/fold-chain]]) that assembles all partial derivatives of
`f` into a new [[emmy.tape/TapeCell]]."
[f primal-s tag]
(fn
([] [])
([partials]
Expand All @@ -275,9 +300,22 @@
partials))))

(defn- literal-derivative
"Takes a literal function `f` and a sequence of arguments `xs`, and generates an
expanded `((D f) xs)` by applying the chain rule and summing the partial
derivatives for each perturbed argument in the input structure."
"Takes
- a literal function `f`
- a structure `s` of arguments
- the `tag` of the innermost active derivative call
- an instance of a perturbation `dx` associated with `tag`
and generates the proper return value for `((D f) xs)`.
In forward-mode AD this is a new [[emmy.differential/Dual]] generated by
applying the chain rule and summing the partial derivatives for each perturbed
argument in the input structure.
In reverse-mode, this is a new [[emmy.tape/TapeCell]] containing a sequence of
pairs of each input paired with the partial derivative of `f` with respect to
that input."
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(d/dual? dx) forward-mode-fold
Expand Down
20 changes: 8 additions & 12 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -596,17 +596,6 @@
(o/make-operator #(g/partial-derivative % [])
g/derivative-symbol))

(def D-rev
"Reverse-mode derivative operator..."
(o/make-operator #(tape/gradient % [])
g/derivative-symbol))

(defn partial-rev
"Reverse-mode partial derivative."
[& selectors]
(o/make-operator #(tape/gradient % selectors)
`(~'partial ~@selectors)))

(defn D-as-matrix [F]
(fn [s]
(matrix/s->m
Expand Down Expand Up @@ -687,7 +676,14 @@
(d/tag x))

(tape/tape? x)
(u/illegal "TODO implement this using fmap style.")
(tape/->TapeCell
(tape/tape-tag x)
(tape/tape-id x)
(rec (tape/tape-primal x))
(mapv (fn [[node partial]]
[(rec node)
(rec partial)])
(tape/tape-partials x)))

:else (-> (g/simplify x)
(x/substitute replace-m))))
Expand Down
48 changes: 14 additions & 34 deletions src/emmy/differential.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
;;
;; $$f(a+b\varepsilon) = f(a)+ (Df(a)b)\varepsilon$$
;;
;; > NOTE: See [[lift-1]] for an implementation of this idea.
;; > NOTE: See [[emmy.tape/lift-1]] for an implementation of this idea.
;;
;; This justifies our claim above: applying a function to some dual number
;; $a+\varepsilon$ returns a new dual number, where
Expand Down Expand Up @@ -257,9 +257,9 @@
;; ### What Return Values are Allowed?
;;
;; Before we discuss the implementation of dual
;; numbers (called [[Differential]]), [[lift-1]], [[lift-2]] and the rest of the
;; machinery that makes this all possible; what sorts of objects is `f` allowed
;; to return?
;; numbers (called [[Differential]]), [[emmy.tape/lift-1]], [[emmy.tape/lift-2]]
;; and the rest of the machinery that makes this all possible; what sorts of
;; objects is `f` allowed to return?
;;
;; The dual number approach is beautiful because we can bring to bear all sorts
;; of operations in Clojure that never even _see_ dual numbers. For example,
Expand Down Expand Up @@ -482,16 +482,6 @@
(-> (primal-tangent-pair dx tag)
(nth 0))))

(defn deep-primal
"Version of [[primal]] that will descend recursively into any [[Dual]] instance
returned by [[primal]] until encountering a non-[[Dual]].
Given a non-[[Dual]], acts as identity."
[dx]
(if (dual? dx)
(recur (.-primal ^Dual dx))
dx))

(defn tangent
"If `dx` is an instance of [[Dual]] returns the `tangent` component. Else, returns 0.
Expand Down Expand Up @@ -668,29 +658,19 @@

;; ## Chain Rule and Lifted Functions
;;
;; Finally, we come to the heart of it! [[lift-1]] and [[lift-2]] "lift", or
;; augment, unary or binary functions with the ability to
;; handle [[Dual]] instances in addition to whatever other types they
;; previously supported.
;;
;; These functions are implementations of the single and multivariable Taylor
;; series expansion methods discussed at the beginning of the namespace.
;; For the rest of the story, please see the implementations
;; of [[emmy.tape/lift-1]] and [[emmy.tape/lift-2]]. These functions "lift", or
;; augment, unary or binary functions with the ability to handle [[Dual]]
;; instances in addition to whatever other types they previously supported.
;;
;; There is yet another subtlety here, noted in the docstrings below. [[lift-1]]
;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that
;; can't accept [[Dual]]s. But the first-order derivatives that you have
;; to supply _do_ have to be able to take [[Dual]] instances.
;;
;; This is because the [[tangent]] of [[Dual]] might still be a [[Dual]], and
;; for `Df` to handle this we need to be able to take the second-order
;; derivative.
;;
;; Magically this will all Just Work if you pass an already-lifted function, or
;; a function built out of already-lifted components, as `df:dx` or `df:dy`.

;; TODO port docs above...
;; The [[dual?]] branches inside these functions are implementations of the
;; single and multivariable Taylor series expansion methods discussed at the
;; beginning of the namespace.

;; ## Generic Method Installation
;;
;; These generic methods don't need to be lifted, so live here alongside
;; the [[Dual]] type definition.

(defmethod g/zero-like [::dual] [_] 0)
(defmethod g/one-like [::dual] [_] 1)
Expand Down
98 changes: 33 additions & 65 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,12 @@

(defn tag-of
"More permissive version of [[tape-tag]] that returns `nil` when passed a
non-[[TapeCell]] instance.
TODO note what we handle now."
non-perturbation."
[x]
(cond (tape? x) (tape-tag x)
(d/dual? x) (d/tag x)
:else nil))

;; TODO move tag stuff here?

(defn inner-tag
"Given any number of `tags`, returns the tag most recently bound
via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call
Expand All @@ -330,11 +326,6 @@
d/*active-tags*)
(apply max tags)))

;; TODO we could change `perturbed?` into something like
;; `possible-perturbations`, to get collection types to return sequence of
;; inputs for this. Then we could handle map-shaped inputs etc into literal
;; functions, if we had the proper descriptor language for it.

(defn tag+perturbation
"Given any number of `dxs`, returns a pair of the form
Expand All @@ -346,20 +337,18 @@
If none of `dxs` has an active tag, returns `nil`."
([& dxs]
(let [m (into {} (mapcat
(fn [dx]
(when-let [t (tag-of dx)]
{t dx})))
dxs)]
(let [xform (map
(fn [dx]
(when-let [t (tag-of dx)]
[t dx])))
m (into {} xform dxs)]
(when (seq m)
(let [tag (apply inner-tag (keys m))]
[tag (m tag)])))))

(defn primal-of
"More permissive version of [[tape-primal]] that returns `v` when passed a
non-[[TapeCell]]-or-[[emmy.differential/Dual]] instance.
TODO fix docstring"
non-perturbation."
([v]
(primal-of v (tag-of v)))
([v tag]
Expand All @@ -368,12 +357,11 @@
:else v)))

(defn deep-primal
"Version of [[tape-primal]] that will descend recursively into any [[TapeCell]]
instance returned by [[tape-primal]] until encountering a non-[[TapeCell]].
"Version of [[tape-primal]] that will descend recursively into any perturbation
instance returned by [[tape-primal]] or [[emmy.differential/primal]] until
encountering a non-perturbation.
Given a non-[[TapeCell]], acts as identity.
TODO say what we really do now"
Given a non-perturbation, acts as identity."
([v]
(cond (tape? v) (recur (tape-primal v))
(d/dual? v) (recur (d/primal v))
Expand Down Expand Up @@ -491,20 +479,6 @@
;;
(defrecord Completed [v->partial]
d/IPerturbed
;; TODO note that this can happen because these can pop out from inside of
;; ->partial-fn. And that is currently where the tag-rewriting has to occur.
;;
;; But that is going to be inefficient for lots of intermediate values...
;; ideally we could call this AFTER we select out the IDs. That implies that
;; we want to shove that inside of extract.
;;
;; TODO TODO TODO definitely do this, we definitely want that to happen, don't
;; have those stacked levels, otherwise super inefficient to walk multiple
;; times.
;;
;; TODO AND THEN if that's true then we can delete this implementation, since
;; we'll already be pulled OUT of the completed map.

;; NOTE that it's a problem that `replace-tag` is called on [[Completed]]
;; instances now. In a future refactor I want `get` calls out of
;; a [[Completed]] map to occur before tag replacement needs to happen.
Expand Down Expand Up @@ -550,10 +524,6 @@
- the partial derivative of the output with respect to that value."
[root]
(let [nodes (topological-sort root)

;; TODO this is the spot where we want to wire in many sensitivities. So
;; how would it work, if we set all of the sensitivities for the outputs
;; at once? What would the ordering be as we walked backwards?
sensitivities {(tape-id root) 1}]
(->Completed
(reduce process sensitivities nodes))))
Expand All @@ -570,10 +540,6 @@

(declare ->partials)

;; TODO fix the docstring, and think of how we can combine this into the
;; narrative of what we find in derivative. Maybe this should be the main
;; version?

(defn- ->partials-fn
"Returns a new function that composes a 'tag extraction' step with `f`. The
returned fn will
Expand Down Expand Up @@ -616,11 +582,6 @@
(vector? output)
(mapv #(->partials % tag) output)

;; Here is an example of the subtlety. We MAY want to go one at a
;; time... or we may want to insert some sensitivity entry into the
;; entire structure and roll the entire structure back. We don't do that
;; YET so I bet we can get away with ignoring it for this first PR. But
;; we are close to needing that.
(s/structure? output)
(s/mapr #(->partials % tag) output)

Expand Down Expand Up @@ -755,11 +716,14 @@
(matrix/seq-> (cons x more)))))))

;; ## Lifted Functions

;; [[lift-1]] and [[lift-2]] "lift", or augment, unary or binary functions with
;; the ability to handle [[emmy.differential/Dual]] and [[TapeCell]] instances
;; in addition to whatever other types they previously supported.
;;
;; NOTE these next two functions are similar to the functions
;; in [[emmy.differential]]; both of these should be merged and install methods
;; that can handle the interaction between [[TapeCell]]
;; and [[emmy.differential/Differential]] instances.
;; Forward-mode support for [[emmy.differential/Dual]] is an implementation of
;; the single and multivariable Taylor series expansion methods discussed at the
;; beginning of [[emmy.differential]].
;;
;; To support reverse-mode automatic differentiation, When a unary or binary
;; function `f` encounters a [[TapeCell]] `x` (and `y` in the binary case) it
Expand All @@ -778,10 +742,19 @@
;; ````
;;
;; in the binary case.

;; There is a subtlety here, noted in the docstrings below. [[lift-1]]
;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that
;; can't accept [[emmy.differential/Dual]] and [[TapeCell]]s. But the
;; first-order derivatives that you have to supply _do_ have to be able to take
;; instances of these types.
;;
;; The partial derivative implementations are passed in directly or retrieved
;; from the generic implementation using the same method as in
;; the [[emmy.differential]] versions, hinting again that we should unify these.
;; This is because, for example, the [[emmy.differential/tangent]] of [[Dual]]
;; might still be a [[Dual]], and will hit the first-order derivative via the
;; chain rule.
;;
;; Magically this will all Just Work if you pass an already-lifted function, or
;; a function built out of already-lifted components, as `df:dx` or `df:dy`.

(defn lift-1
"Given:
Expand Down Expand Up @@ -875,7 +848,7 @@
(cond (tape? dx) (operate-reverse tag)
(d/dual? dx) (operate-forward tag)
:else
(u/illegal "Non-tape or differential perturbation!"))
(u/illegal "Non-tape or dual perturbation!"))
(f x y))))))

(defn lift-n
Expand Down Expand Up @@ -948,13 +921,8 @@

(defn ^:no-doc by-primal
"Given some unary or binary function `f`, returns an augmented `f` that acts on
the primal entries of any [[TapeCell]] arguments encountered, irrespective of
tag.
Given a [[TapeCell]] with a [[TapeCell]] in its [[primal-part]], the returned
`f` will recursively descend until it hits a non-[[TapeCell]].
TODO fix docs"
the primal entries of any perturbed arguments encountered, irrespective of
tag."
[f]
(fn
([x] (f (deep-primal x)))
Expand Down
Loading

0 comments on commit d096f70

Please sign in to comment.