# The Inverse Logit Function

In the theoretical description of the logistic regression probabilistic model, we come upon the inverse logit function:

$$
z \to \frac{1}{1 + \exp(-z)}
$$

This function is also sometimes expressed using the softmax-style representation:
    
$$
z \to \frac{\exp(z)}{\exp(z) + 1}
$$
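The two forms are algebraically identical. Multiplying the numerator and denominator of the first by $\exp(z)$ gives the second, so any difference between them in code comes purely from floating point rounding, overflow and underflow:

$$
\frac{1}{1 + \exp(-z)} = \frac{\exp(z)}{\exp(z)\,\left(1 + \exp(-z)\right)} = \frac{\exp(z)}{\exp(z) + 1}
$$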

When implementing this function in code, which of these forms (if either) should we use?

The correct answer to this question is not immediately obvious, as evidenced by a sequence of commits to the popular SciPy library:

* [A commit to reverse the direction of a branch that switches between the two expressions](https://github.com/scipy/scipy/commit/47b73a3a87b86b4e18a7f21482eeaa66e26467d0)
* [A commit to remove the branch completely](https://github.com/scipy/scipy/commit/30e181c1179177bd4e40c240ca70ce3b82dac873)

# Our Analysis Strategy

To decide between these two expressions, we need to consider how the optimal 64-bit floating point implementation would behave and then consider how these specific expressions fail relative to that unknown optimal implementation. To do that, we'll do some analysis that combines theoretical reasoning and empirical checks in Julia.

```julia
import Printf: @printf
import Statistics: mean
```

Throughout this analysis, we'll make use of Julia's built-in `BigFloat` type, which wraps the MPFR library and lets us set the working precision globally. We'll set it to 256 bits now:

```julia
setprecision(256)
```

```
256
```

Given access to 256-bit floating point numbers, we can evaluate the accuracy of an expression that operates on 64-bit floating point numbers by lifting it to operate on `BigFloat`s and then rounding the result back to `Float64`. We define a `lift` function to do this:

```julia
lift(e) = x -> Float64(e(big(x)))
```

```
lift (generic function with 1 method)
```
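As a quick illustration of what lifting buys us, consider the hypothetical expression `cancel_demo(x) = (1 + x) - 1` (the name is ours). Mathematically it returns `x`, but in `Float64` arithmetic a tiny `x` is absorbed by the `1`. The lifted version keeps the intermediate `1 + x` exact at 256 bits for this input and rounds only once at the end:

```julia
setprecision(256)

lift(e) = x -> Float64(e(big(x)))

# Hypothetical example: equal to x in exact arithmetic, but 1 + x absorbs tiny x in Float64.
cancel_demo(x) = (1 + x) - 1

cancel_demo(1e-20)        # 0.0: the 1e-20 is lost when computing 1 + x in Float64
lift(cancel_demo)(1e-20)  # 1.0e-20: exact at 256 bits, then rounded once to Float64
```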

To make use of this lifting and comparison strategy, we'll write a helper function that prints diagnostics about the accuracy of an expression over a set of input values. In particular, we'll consider the rate at which the output matches the optimal output exactly, the average absolute error, the maximum absolute error and the average number of incorrect bits in the output. Our function to do this is called `evaluate_errors`:

```julia
function evaluate_errors(e, xs)
    # Results of evaluating the expression in Float64 arithmetic.
    y_approx = map(e, xs)
    # Reference results: evaluate at 256-bit precision, then round once to Float64.
    y_optimal = map(lift(e), xs)
    # Count the bits that differ between the raw IEEE 754 encodings.
    wrong_bits = count_ones.(
        xor.(
            map(f -> reinterpret(Int64, f), y_approx),
            map(f -> reinterpret(Int64, f), y_optimal),
        )
    )
    errs = abs.(y_approx .- y_optimal)
    @printf(
        """
        Frequency of Exact Results:       %s
        Average Error:                    %s
        Maximum Error:                    %s
        Average Number of Incorrect Bits: %s
        """,
        mean(errs .== 0.0),
        mean(errs),
        maximum(errs),
        mean(wrong_bits),
    )
    return
end
```

```
evaluate_errors (generic function with 1 method)
```
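The "incorrect bits" metric counts differing bits in the raw IEEE 754 encodings, which is not the same as the distance in units in the last place (ulps). A result that is one representable float away can differ in several bits when the change carries across bit positions. A small sketch, using helper names of our own choosing (`wrong_bits` and `ulps_apart` are not part of the notebook above):

```julia
# Hypothetical helper: number of differing bits in the two Float64 encodings.
wrong_bits(a::Float64, b::Float64) =
    count_ones(xor(reinterpret(Int64, a), reinterpret(Int64, b)))

# Hypothetical helper: for two finite Float64 values of the same sign, the difference
# of their integer encodings counts how many representable floats separate them.
ulps_apart(a::Float64, b::Float64) =
    abs(reinterpret(Int64, a) - reinterpret(Int64, b))

x = 0.1  # encoded as 0x3FB999999999999A
wrong_bits(x, nextfloat(x)), ulps_apart(x, nextfloat(x))  # (1, 1): ...9A vs ...9B
wrong_bits(x, prevfloat(x)), ulps_apart(x, prevfloat(x))  # (2, 1): ...9A vs ...99
```

Both neighbors are one ulp away, yet they register as one and two incorrect bits respectively, so the bit count is best read as a rough proxy for accuracy.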

# Our Theoretical Analysis

We first want to determine the range of inputs over which an optimal implementation would produce meaningful results. That range is determined by how closely 64-bit floats can approach 0.0 and 1.0. Specifically, the smallest float above 0.0 and the largest float below 1.0 are the most extreme non-constant outputs any implementation can return; once the true result is closer to 0.0 or 1.0 than to those floats, the correctly rounded output is the constant 0.0 or 1.0 itself. We compute those closest points using the Julia functions `nextfloat` and `prevfloat`:

```julia
nextfloat(0.0)
```

```
5.0e-324
```

```julia
prevfloat(1.0)
```

```
0.9999999999999999
```

We can find the inputs that produce these outputs by computing the logit of each at 256-bit precision and then rounding to the nearest 64-bit float:

```julia
logit(p) = log(p / (1 - p))
```

```
logit (generic function with 1 method)
```

```julia
theoretical_lower = lift(logit)(nextfloat(0.0))
```

```
-744.4400719213812
```

```julia
theoretical_upper = lift(logit)(prevfloat(1.0))
```

```
36.7368005696771
```
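Both bounds have simple closed forms. Because `nextfloat(0.0) == 2^-1074` and `1 - prevfloat(1.0) == 2^-53`, the logit of the first is very nearly `log(2^-1074) = -1074 * log(2)`, and the logit of the second is `log(2^53 - 1)`, which is very nearly `53 * log(2)`. A quick check (the names `approx_lower` and `approx_upper` are ours):

```julia
# The extreme representable probabilities are exact powers of two (or 1 minus one).
@assert nextfloat(0.0) == ldexp(1.0, -1074)    # smallest positive (subnormal) Float64
@assert prevfloat(1.0) == 1 - ldexp(1.0, -53)  # largest Float64 below 1.0

# logit(p) = log(p / (1 - p)) at those two points, in closed form.
approx_lower = -1074 * log(2.0)  # log(2^-1074), since 1 - p is essentially 1
approx_upper = 53 * log(2.0)     # log(2^53), ignoring the "- 1" in log(2^53 - 1)

approx_lower, approx_upper
```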

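We know that, to within one ulp, all `z < theoretical_lower` should generate `0.0` as a result and all `z > theoretical_upper` should generate `1.0` as a result (rounding can still map inputs slightly beyond each bound to `nextfloat(0.0)` or `prevfloat(1.0)`). So we only need to understand how well any proposed implementation works within these bounds.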
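These bounds are slightly conservative under round-to-nearest: rounding only snaps to `0.0` once the true value is closer to `0.0` than to `2^-1074`, and similarly near `1.0`, so an input a little past either bound can still have `nextfloat(0.0)` or `prevfloat(1.0)` as its correctly rounded result. A sketch using the same 256-bit reference as above (redefining `lift` and `e1` so the block stands alone):

```julia
setprecision(256)
lift(e) = x -> Float64(e(big(x)))
e1 = z -> 1 / (1 + exp(-z))

# Just below theoretical_lower ≈ -744.44, the exact result still rounds to 2^-1074 ...
lift(e1)(-745.0) == nextfloat(0.0)  # true: exp(-745) ≈ 0.57 * 2^-1074
lift(e1)(-746.0) == 0.0             # true: exp(-746) ≈ 0.21 * 2^-1074
# ... and just above theoretical_upper ≈ 36.74, it still rounds to 1 - 2^-53.
lift(e1)(37.0) == prevfloat(1.0)    # true: 1 - exp(-37) is nearer 1 - 2^-53 than 1.0
lift(e1)(38.0) == 1.0               # true
```

Returning the constants outside the bounds is therefore off by at most one ulp, which is the trade-off the guard in the final implementation makes.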

## Finding a Good Expression

We'll consider the two options listed at the start:

```julia
e1 = z -> 1 / (1 + exp(-z))
```

```
#9 (generic function with 1 method)
```

```julia
e2 = z -> exp(z) / (exp(z) + 1)
```

```
#11 (generic function with 1 method)
```

To get started thinking about these expressions, let's consider how they behave as `z` heads toward `-Inf` or `Inf`.

As `z` heads towards `-Inf`, the dominant failure mode for `e1` is overflow: once `1 + exp(-z)` evaluates to `Inf`, the result is `1.0 / Inf === 0.0`. Since `1 + exp(-z) === Inf` only when `exp(-z) === Inf`, the last input before overflow satisfies roughly `exp(-z) === floatmax(Float64)`, which implies `z = -log(floatmax(Float64))`. We can calculate this to see how low `z` can go before overflow kicks in:


```julia
-log(floatmax(Float64))
```

```
-709.782712893384
```
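The overflow is easy to observe directly: `exp` of a `Float64` returns `Inf` once its argument exceeds `log(floatmax(Float64)) ≈ 709.78`, and from there `e1` collapses to exactly `0.0`:

```julia
e1 = z -> 1 / (1 + exp(-z))

exp(709.0)        # ≈ 8.2e307, still finite
exp(710.0)        # Inf: the true value exceeds floatmax(Float64) ≈ 1.8e308
e1(-709.0) > 0.0  # true: a tiny subnormal result
e1(-710.0)        # 0.0, because 1 / (1 + Inf) == 0.0
```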

We can also directly confirm that `e1` starts to generate exact `0.0` values below that point:

```julia
(
    e1(-log(floatmax(Float64))),
    e1(prevfloat(-log(floatmax(Float64))))
)
```

```
(5.562684646268137e-309, 0.0)
```

This leaves a region, from `theoretical_lower` to `prevfloat(-log(floatmax(Float64)))`, in which an optimal implementation would still return nonzero results but `e1` returns `0.0`, so we might hope to do better there. That region is:

```julia
(
    theoretical_lower,
    prevfloat(-log(floatmax(Float64))),
)
```

```
(-744.4400719213812, -709.7827128933841)
```

In contrast to `e1`, `e2` has an interesting behavior in this region: for very negative `z`, `exp(z) + 1 === 1.0`, so the overall expression for `e2` becomes identical to evaluating `exp(z)` alone. This change happens roughly when `exp(z) < eps(1.0)`, which is:

```julia
log(eps(1.0))
```

```
-36.04365338911715
```
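Two details are worth making concrete. Under round-to-nearest, `1.0 + x` stays exactly `1.0` for `0 <= x <= eps(1.0) / 2` (the halfway case rounds to even, which is `1.0`), which is why the switch happens only *roughly* at `log(eps(1.0))`. And once the denominator is exactly `1.0`, `e2` returns `exp(z)` bit for bit:

```julia
e2 = z -> exp(z) / (exp(z) + 1)

# 1.0 + x rounds back to 1.0 when x is at most half of eps(1.0).
1.0 + eps(1.0) / 2 == 1.0  # true: the tie rounds to even, i.e. to 1.0
1.0 + eps(1.0) > 1.0       # true

# Well below log(eps(1.0)) ≈ -36.04 the denominator is exactly 1.0.
exp(-40.0) + 1 == 1.0      # true
e2(-40.0) === exp(-40.0)   # true
```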

Can we convince ourselves that `exp(z)` produces useful results in the region in which `e1` overflows? As a basic sanity check, we can verify that the result is correct at `theoretical_lower`:

```julia
exp(theoretical_lower) === nextfloat(0.0)
```

```
true
```

This suggests `e2` should be superior to `e1` in the region in which we know `e1` is susceptible to overflow. Later we'll check empirically that `e2` is broadly superior.

Having considered what happens as `z` heads towards `-Inf`, let's see what happens as `z` heads towards `Inf`.

As `z` heads to `Inf`, the errors in `e1` are dominated by the point at which `1 + exp(-z)` rounds to `1.0`. This happens roughly when `exp(-z) < eps(1.0)`, which should occur around:

```julia
-log(eps(1.0))
```

```
36.04365338911715
```
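The underlying problem for `e1` on this side is that `1 + exp(-z)` can only retain `exp(-z)` as a multiple of `eps(1.0) ≈ 2.2e-16`, so most of the information in `exp(-z)` is rounded away well before it vanishes completely:

```julia
# Adding a small value to 1.0 quantizes it to a multiple of eps(1.0), or drops it.
(1.0 + 1.4e-16) - 1.0 == eps(1.0)  # true: 1.4e-16 is rounded up to 2^-52
(1.0 + 5.0e-17) - 1.0 == 0.0       # true: 5.0e-17 is below eps(1.0) / 2 and vanishes
```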

In contrast, as `z` heads to `Inf`, `e2` is dominated by errors when `exp(z) + 1 === exp(z)`. That occurs when `eps(exp(z)) > 1`.

From this [tweet](https://twitter.com/i/status/1245691382607601666), this occurs around `log(2^53)`, which is:

```julia
log(2^53)
```

```
36.7368005696771
```
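The `log(2^53)` threshold follows from `Float64` having a 53-bit significand: at `2^53` the spacing between adjacent floats becomes 2, so adding 1 rounds back to the same value and `e2` effectively computes `exp(z) / exp(z)`:

```julia
x = Float64(2^53)  # 2^53 as an Int64 converts to Float64 exactly

precision(Float64)        # 53 significand bits
eps(prevfloat(x)) == 1.0  # true: spacing is 1 just below 2^53
eps(x) == 2.0             # true: spacing is 2 from 2^53 upward
x + 1 == x                # true: the tie 2^53 + 1 rounds to even, i.e. back to 2^53
x / (x + 1) == 1.0        # true: so e2 returns exactly 1.0 from here on
```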

So we might expect `e2` to be better than `e1` in the region between `-log(eps(1.0))` and `log(2^53)`.

Before we move on to verifying our ideas, we should summarize all of the magic numbers we've encountered so far:

```julia
(
    theoretical_lower, # Value below which invlogit(z) must result in 0.0.
    -log(floatmax(Float64)), # Value below which exp(-z) overflows to Inf.
    -log(eps(1.0)), # Value above which exp(-z) drops below eps(1.0), so 1 + exp(-z) === 1.0.
    log(2^53), # Value above which exp(z) + 1 === exp(z).
    theoretical_upper, # Value above which invlogit(z) must result in 1.0.
)
```

```
(-744.4400719213812, -709.782712893384, 36.04365338911715, 36.7368005696771, 36.7368005696771)
```

There is one limitation of `e2` that we haven't called out: for large `z`, `exp(z)` overflows to `Inf`, so `e2` evaluates `Inf / (Inf + 1)`, which is `NaN`. We should guard the function to prevent this, which is easy to do using `theoretical_upper`.
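The failure is easy to reproduce. Note that it only appears once `exp(z)` itself overflows near `z ≈ 709.78`, far beyond `theoretical_upper ≈ 36.74`, so a guard at `theoretical_upper` removes it:

```julia
e2 = z -> exp(z) / (exp(z) + 1)

e2(709.0)               # 1.0: exp(709.0) is finite and absorbs the + 1
e2(710.0)               # NaN: exp(710.0) == Inf, and Inf / Inf is NaN
isnan(Inf / (Inf + 1))  # true
```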

# Empirical Assessment

So far we've used theoretical arguments. Going forward, we'll use the `evaluate_errors` function defined earlier to compare our expressions against optimal results calculated with 256-bit floating point numbers. We'll consider three regions:

* `theoretical_lower` to `-log(floatmax(Float64))`
* `-log(floatmax(Float64))` to `-log(eps(1.0))`
* `-log(eps(1.0))` to `theoretical_upper`

```julia
n = 10_000
lower = theoretical_lower
upper = -log(floatmax(Float64))
xs = range(lower, upper, length=n)
```

```
-744.4400719213812:0.00346608251105083:-709.782712893384
```

```julia
evaluate_errors(e1, xs)
```

```
Frequency of Exact Results:       0.0001
Average Error:                    1.6021112698925e-310
Maximum Error:                    5.54343729801906e-309
Average Number of Incorrect Bits: 13.1108
```

```julia
evaluate_errors(e2, xs)
```

```
Frequency of Exact Results:       0.9964
Average Error:                    0.0
Maximum Error:                    5.0e-324
Average Number of Incorrect Bits: 0.0079
```

If you're worried that we're comparing our expressions against their lifted versions and therefore potentially not making a fair comparison, note that `lift(e1)` and `lift(e2)` evaluate to the same thing in this region:

```julia
all(map(lift(e1), xs) .== map(lift(e2), xs))
```

```
true
```

Let's repeat for the other two regions:

```julia
n = 10_000
lower = -log(floatmax(Float64))
upper = -log(eps(1.0))
xs = range(lower, upper, length=n)
```

```
-709.782712893384:0.0745900956378139:36.04365338911715
```

```julia
evaluate_errors(e1, xs)
```

```
Frequency of Exact Results:       0.7058
Average Error:                    2.738026739281825e-18
Maximum Error:                    1.1102230246251565e-16
Average Number of Incorrect Bits: 0.575
```

```julia
evaluate_errors(e2, xs)
```

```
Frequency of Exact Results:       0.889
Average Error:                    1.0784458228483968e-19
Maximum Error:                    1.1102230246251565e-16
Average Number of Incorrect Bits: 0.2148
```

```julia
n = 10_000
lower = -log(eps(1.0))
upper = theoretical_upper
xs = range(lower, upper, length=n)
```

```
36.04365338911715:6.932165022101682e-5:36.7368005696771
```

```julia
evaluate_errors(e1, xs)
```

```
Frequency of Exact Results:       0.415
Average Error:                    6.494804694057165e-17
Maximum Error:                    1.1102230246251565e-16
Average Number of Incorrect Bits: 0.585
```

```julia
evaluate_errors(e2, xs)
```

```
Frequency of Exact Results:       1.0
Average Error:                    0.0
Maximum Error:                    0.0
Average Number of Incorrect Bits: 0.0
```

# Putting It All Together

Taken together, I think this suggests the best implementation should look something like the following, though potentially using `ifelse` instead of `if` to make the function amenable to SIMD optimizations:

```julia
function invlogit(z)
    if z < -744.4400719213812 # Float64(logit(big(nextfloat(0.0))))
        0.0
    elseif z > 36.7368005696771 # Float64(logit(big(prevfloat(1.0))))
        1.0
    else
        ez = exp(z) # compute exp(z) once and reuse it
        ez / (ez + 1)
    end
end
```

```
invlogit (generic function with 1 method)
```
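A minimal sketch of the `ifelse`-based variant mentioned above, under the assumption of `Float64` inputs (the name `invlogit_ifelse` is ours). `ifelse` evaluates both of its value arguments eagerly, which avoids data-dependent branches but means `exp(z)` is always computed, even where its result is then discarded:

```julia
# Hypothetical branch-free variant of invlogit using the same two cutoffs.
function invlogit_ifelse(z::Float64)
    ez = exp(z)
    core = ez / (ez + 1)  # NaN for z > ~709.78, but masked by the ifelse below
    hi = ifelse(z > 36.7368005696771, 1.0, core)    # Float64(logit(big(prevfloat(1.0))))
    return ifelse(z < -744.4400719213812, 0.0, hi)  # Float64(logit(big(nextfloat(0.0))))
end

# Broadcasting over an array gives the compiler a loop it may be able to vectorize.
zs = collect(range(-800.0, 800.0, length=1_001))
ps = invlogit_ifelse.(zs)
```

Whether this loop actually vectorizes depends on the Julia version and on whether `exp` is inlined; we have not benchmarked it against the branching version.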

```julia
invlogit(-10000), e1(-10000), e2(-10000)
```

```
(0.0, 0.0, 0.0)
```

```julia
invlogit(theoretical_lower), e1(theoretical_lower), e2(theoretical_lower)
```

```
(5.0e-324, 0.0, 5.0e-324)
```

```julia
invlogit(-log(floatmax(Float64))), e1(-log(floatmax(Float64))), e2(-log(floatmax(Float64)))
```

```
(5.562684646268137e-309, 5.562684646268137e-309, 5.562684646268137e-309)
```

```julia
invlogit(-1), e1(-1), e2(-1)
```

```
(0.2689414213699951, 0.2689414213699951, 0.2689414213699951)
```

```julia
invlogit(0), e1(0), e2(0)
```

```
(0.5, 0.5, 0.5)
```

```julia
invlogit(1), e1(1), e2(1)
```

```
(0.7310585786300049, 0.7310585786300049, 0.7310585786300049)
```

```julia
invlogit(-log(eps(1.0))), e1(-log(eps(1.0))), e2(-log(eps(1.0)))
```

```
(0.9999999999999998, 0.9999999999999998, 0.9999999999999998)
```

```julia
invlogit(10000), e1(10000), e2(10000)
```

```
(1.0, 1.0, NaN)
```