## The Stableswap equation

The Stableswap equation given in the Curve whitepaper is

$$ An^n\sum_i x_i + D = An^n D + \frac{D^{n+1}}{n^n\prod\limits_i x_i}$$

$D$ is the stableswap invariant.

$A$ is the amplification coefficient.

The $x_i$ are the pool balances of the $n$ tokens.
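As a quick sanity check of the equation, the sketch below evaluates both sides exactly with Python's `fractions.Fraction`. The helper `invariant_sides` is a hypothetical name, not part of the Curve code. For a balanced pool (all $x_i$ equal), $D = \sum_i x_i$ satisfies the equation exactly. For an imbalanced pool, that same choice makes the right-hand side too large.

```python
from fractions import Fraction
from math import prod


def invariant_sides(D, xp, A):
    """Return (lhs, rhs) of the Stableswap equation, computed exactly."""
    n = len(xp)
    lhs = A * n**n * sum(xp) + D
    rhs = A * n**n * D + Fraction(D**(n + 1), n**n * prod(xp))
    return lhs, rhs


# Balanced three-coin pool: D = S solves the equation exactly.
lhs, rhs = invariant_sides(15, [5, 5, 5], A=100)
print(lhs == rhs)  # True

# Imbalanced two-coin pool: D = S makes the right-hand side too large.
lhs, rhs = invariant_sides(10, [9, 1], A=100)
print(lhs < rhs)   # True
```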

## Solving for $D$ using Newton's method

To solve for $D$ in the stableswap equation, we can use the auxiliary form of the equation

$$ f(D) = 0 $$

where

$$ f(D) = An^n D + \frac{D^{n+1}}{n^n\prod x_i} - An^n S - D $$

and $S = \sum_i x_i$.

Let $P = n\left(\prod_i x_i\right)^{1/n}$. Then $f(P) = An^n(P - S) \le 0$ and $f(S) = S\left(\frac{S^n}{n^n\prod x_i} - 1\right) \ge 0$. Both follow from the AM-GM inequality, which gives $P \le S$ (equivalently $S^n \ge n^n\prod x_i$), with equality only when all the $x_i$ are equal. Since $P < S$ in the generic case, some $D$ between them satisfies $f(D) = 0$. In fact, the situation is much better than that.
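Here is a small floating-point check of this bracket. It is a sketch with arbitrary illustrative balances ($x = (1000, 10)$ and $A = 100$ are chosen here, not taken from the notebook). It evaluates $f$ at $P = n(\prod_i x_i)^{1/n}$ and at $S$.

```python
from math import prod


def f(D, xp, A):
    """f(D) from the text, in floating point."""
    n = len(xp)
    Ann = A * n**n
    return Ann * D + D**(n + 1) / (n**n * prod(xp)) - Ann * sum(xp) - D


xp = [1000.0, 10.0]
A = 100
n = len(xp)
P = n * prod(xp) ** (1 / n)   # n times the geometric mean
S = sum(xp)                   # sum of balances
print(P, S)                   # P < S for this imbalanced pool
print(f(P, xp, A) < 0, f(S, xp, A) > 0)  # True True
```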

The derivative of $f$ is

$$ f'(D) = An^n + (n+1) \frac{D^n}{n^n\prod x_i} - 1 $$

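Since $An^n \ge 1$ for any $A \ge 1$, $f'$ is always positive. Therefore $f$ is strictly increasing and the root is unique. Moreover,

$$ f''(D) = n(n+1)\frac{D^{n-1}}{n^n\prod x_i} > 0 $$

for $D > 0$, so $f$ is convex. Start Newton's method from $D_0 = S$, where $f(S) \ge 0$. Every iterate then stays at or above the root, and the iterates decrease monotonically toward it, converging rapidly.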
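The sketch below shows the monotone behavior concretely. It runs the Newton iteration in exact rational arithmetic (`fractions.Fraction`, so there is no rounding at all), starting from $D_0 = S$ for an arbitrary two-coin pool. Because $f' > 0$ and $f'' > 0$ for $D > 0$, a step taken from the right of the root lands to the right of it again. We therefore expect strictly decreasing iterates with $f(D_k) > 0$ throughout. The helper names are hypothetical.

```python
from fractions import Fraction
from math import prod


def f_exact(D, xp, A):
    """f(D) evaluated exactly with rationals."""
    n = len(xp)
    Ann = A * n**n
    return Ann * D + D**(n + 1) / Fraction(n**n * prod(xp)) - Ann * sum(xp) - D


def f_prime_exact(D, xp, A):
    """f'(D) evaluated exactly with rationals."""
    n = len(xp)
    Ann = A * n**n
    return Ann + (n + 1) * D**n / Fraction(n**n * prod(xp)) - 1


xp, A = [1000, 10], 100
D = Fraction(sum(xp))                  # D_0 = S
iterates = [D]
for _ in range(6):
    D = D - f_exact(D, xp, A) / f_prime_exact(D, xp, A)
    iterates.append(D)

print(all(a > b for a, b in zip(iterates, iterates[1:])))  # True: strictly decreasing
print(all(f_exact(d, xp, A) > 0 for d in iterates))        # True: never below the root
print([round(float(d), 6) for d in iterates])
```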

Newton's method gives the iteration:

$$ D_{k+1} = D_k - \frac{f(D_k)}{f'(D_k)} $$

Substituting $f$ and $f'$ and simplifying gives the iterative formula:

$$ D_{k+1} = \frac{(nD_p + An^nS)D_k}{(n+1)D_p + (An^n - 1)D_k} $$

where $D_p = \frac{D_k^{n+1}}{n^n\prod x_i}$.
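One way to double-check the algebra is to compare the simplified update with the raw Newton step $D - f(D)/f'(D)$ in exact rational arithmetic. The sketch below does this for a few arbitrary pools and starting points. The function names are illustrative.

```python
from fractions import Fraction
from math import prod


def newton_step(D, xp, A):
    """Raw Newton step D - f(D)/f'(D)."""
    n = len(xp)
    Ann = A * n**n
    c = Fraction(1, n**n * prod(xp))
    f = Ann * D + D**(n + 1) * c - Ann * sum(xp) - D
    f_prime = Ann + (n + 1) * D**n * c - 1
    return D - f / f_prime


def simplified_step(D, xp, A):
    """(n D_p + A n^n S) D / ((n+1) D_p + (A n^n - 1) D)."""
    n = len(xp)
    Ann = A * n**n
    D_p = Fraction(D**(n + 1), n**n * prod(xp))
    return (n * D_p + Ann * sum(xp)) * D / ((n + 1) * D_p + (Ann - 1) * D)


for xp, A, D in [([1000, 10], 100, 1010), ([3, 5, 7], 50, 20), ([2, 9], 1, 7)]:
    assert newton_step(D, xp, A) == simplified_step(D, xp, A)
print("simplified update matches D - f/f'")
```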

## Convergence under integer arithmetic

Because the iterative formula has to be implemented in integer arithmetic on the EVM, some complications arise.

First, the code computes $D_p$ iteratively, dividing by $n x_i$ once per coin, to avoid overflow. This requires several integer divisions, each of which truncates.
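The sketch below shows what the stepwise truncation costs. It compares the contract-style computation of $D_p$ (one floor division per coin) with a single exact floor of $D^{n+1}/(n^n\prod_i x_i)$. Python integers have arbitrary precision, so the one-shot version can be computed here for comparison, even though $D^{n+1}$ exceeds $2^{256}$ for the balances used in this notebook. Both helper names are hypothetical. Each truncation only rounds down, so the stepwise value never exceeds the exact floor.

```python
from math import prod


def d_p_stepwise(D, xp):
    """Contract-style: start from D, then D_P = D_P * D // (x * n) for each coin."""
    n = len(xp)
    D_P = D
    for x in xp:
        D_P = D_P * D // (x * n)
    return D_P


def d_p_exact_floor(D, xp):
    """A single floor of D**(n+1) / (n**n * prod(xp))."""
    n = len(xp)
    return D**(n + 1) // (n**n * prod(xp))


print(d_p_stepwise(10, [9, 1]), d_p_exact_floor(10, [9, 1]))  # 25 27

xp = [98_500_000 * 10**18, 500 * 10**18]
D = sum(xp)
print(d_p_exact_floor(D, xp) - d_p_stepwise(D, xp))  # non-negative
```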

Second (and more seriously), the iterative formula for $D$ ends with an integer division. This division effectively floors $D$ and can push an iterate below the correct value. The next iteration then lands above the correct value, so the iterates oscillate around it instead of settling. With a stop condition that waits for consecutive iterates to agree, this can result in an infinite loop, as we demonstrate below. The contract itself caps the loop at 255 iterations.
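The same flooring effect is easy to see in a much smaller example, swapped in here and not part of the Curve code: Newton's method for an integer square root, $x_{k+1} = \lfloor (x_k + \lfloor N/x_k \rfloor)/2 \rfloor$. For $N = 15$, the iterates bounce between 4 and 3 forever. The standard remedy is to start above the root and stop as soon as an iterate fails to decrease, which is the same idea as the "bounce" stop condition in the notebook code. The function names are hypothetical.

```python
def isqrt_iterates(N, x0, steps):
    """Return the first `steps` integer Newton iterates for sqrt(N), starting at x0."""
    xs = [x0]
    for _ in range(steps):
        x = xs[-1]
        xs.append((x + N // x) // 2)   # both divisions floor
    return xs


print(isqrt_iterates(15, 4, 6))        # [4, 3, 4, 3, 4, 3, 4]


def isqrt_bounce_stop(N):
    """floor(sqrt(N)) for N >= 1: start above the root, stop once the iterate stops decreasing."""
    x = N
    while True:
        y = (x + N // x) // 2
        if y >= x:
            return x
        x = y


print(isqrt_bounce_stop(15))           # 3
```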

### Calculation in Vyper (from 3Pool)

```python
def get_D(xp: uint256[N_COINS], amp: uint256) -> uint256:
    S: uint256 = 0
    for _x in xp:
        S += _x
    if S == 0:
        return 0

    Dprev: uint256 = 0
    D: uint256 = S
    Ann: uint256 = amp * N_COINS
    for _i in range(255):
        D_P: uint256 = D
        for _x in xp:
            D_P = D_P * D / (_x * N_COINS)
        Dprev = D
        D = (Ann * S + D_P * N_COINS) * D / ((Ann - 1) * D + (N_COINS + 1) * D_P)
        if D > Dprev:
            if D - Dprev <= 1:
                break
        else:
            if Dprev - D <= 1:
                break
    return D
```
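For experimenting outside the EVM, here is a sketch of a direct Python transcription of the Vyper function. It keeps the contract's `Ann = amp * N_COINS` (note this is not `A * n**n` as in the notebook cells), its 255-iteration cap and its `|D - Dprev| <= 1` stop condition. Python's `//` stands in for the contract's truncating `uint256` division, and overflow checks are not modeled. The name `get_D_py` is hypothetical, and returning the iteration count alongside `D` is an addition for inspection.

```python
def get_D_py(xp, amp):
    """Python transcription of the Vyper get_D above; returns (D, iterations)."""
    n_coins = len(xp)            # N_COINS in the contract
    S = sum(xp)
    if S == 0:
        return 0, 0

    D = S
    Ann = amp * n_coins          # same expression as the contract
    for i in range(255):
        D_P = D
        for x in xp:
            D_P = D_P * D // (x * n_coins)
        Dprev = D
        D = (Ann * S + D_P * n_coins) * D // ((Ann - 1) * D + (n_coins + 1) * D_P)
        if abs(D - Dprev) <= 1:  # equivalent to the contract's two-branch check
            return D, i + 1
    return D, 255                # loop cap reached without meeting the stop condition


print(get_D_py([10**18, 10**18], 100))   # (2000000000000000000, 1)
print(get_D_py([98_500_000 * 10**18, 500 * 10**18], 100))
```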

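```python
from math import prod

# Note: `/` returns a float here, so these helpers are only accurate to about
# 16 significant digits for the ~10**25-scale values used below.

def f(D, xp, A):
    """f(D); useful to check the accuracy of a solution."""
    S = sum(xp)
    n = len(xp)
    Ann = A * n**n
    return Ann * D + D**(n+1) / (n**n * prod(xp)) - (Ann * S + D)


def f_prime(D, xp, A):
    """f'(D); n and Ann are derived from xp and A."""
    n = len(xp)
    Ann = A * n**n
    return Ann + (n+1) * D**n / (n**n * prod(xp)) - 1


def f_double_prime(D, xp, A):
    """f''(D)."""
    n = len(xp)
    return n * (n+1) * D**(n-1) / (n**n * prod(xp))
```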
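Because `/` returns a float, `f` above can only resolve $f(D)$ to roughly 16 significant digits. For the 26-digit values of $D$ in this notebook, it therefore cannot tell neighboring integers apart. Here is a sketch of an exact alternative, assuming only the sign of $f$ is needed. Multiplying through by $n^n\prod x_i > 0$ keeps everything in integers. We then bisect for the largest integer $D$ with $f(D) \le 0$, which is the floor of the true root because $f$ is increasing. `f_scaled` and `floor_root` are hypothetical helper names, and the convention `Ann = A * n**n` follows the notebook cells rather than the contract.

```python
from math import prod


def f_scaled(D, xp, A):
    """n**n * prod(xp) * f(D): an exact integer with the same sign as f(D)."""
    n = len(xp)
    Ann = A * n**n
    den = n**n * prod(xp)
    return (Ann * D - Ann * sum(xp) - D) * den + D**(n + 1)


def floor_root(xp, A):
    """Largest integer D with f(D) <= 0, found by bisection (f is increasing)."""
    lo, hi = 0, sum(xp) + 1   # f(0) < 0 and f(S + 1) > f(S) >= 0
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if f_scaled(mid, xp, A) <= 0:
            lo = mid
        else:
            hi = mid
    return lo


print(floor_root([1000, 10], 100))   # 957
print(floor_root([98_500_000 * 10**18, 500 * 10**18], 100))
```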


```python
# Test parameters: A and the two pool balances (scaled by 10**18).
# Uncomment one scenario at a time.
A = 100

# x_1, x_2 = 98_500_000 * 10**18, 5 * 10**18         # this causes the prod ("real") calc to bork
# x_1, x_2 = 98_500_000 * 10**18, 50 * 10**18        # does NOT bork! but does bounce back and forth
# x_1, x_2 = 98_500_000 * 10**18, 8 * 10**18         # falls into an interesting cycle
# x_1, x_2 = 98_500_000 * 10**18, 500_000 * 10**18
# x_1, x_2 = 50_000_000 * 10**18, 1_000_000 * 10**18
# x_1, x_2 = 50_000_000 * 10**18, 35_000_000 * 10**18
# x_1, x_2 = 90_000_000 * 10**18, 10_000_000 * 10**18
x_1, x_2 = 98_500_000 * 10**18, 500 * 10**18          # does NOT bork! but does bounce back and forth
```

```python
# Iterative calculation of the D invariant.
# Uses A, x_1, x_2 and the helpers f, f_prime, f_double_prime from the cells above.
from math import prod

xp = [x_1, x_2]
n = len(xp)
S = sum(xp)
Ann = A * n**n

# Integer version with the contract's stop condition |D - Dprev| <= 1.
# This can enter an infinite loop, so the iterations are capped.
D = S
Dprev = 0
i = 0
while abs(Dprev - D) > 1:
    D_P = D
    for x in xp:
        D_P = D_P * D // (x * n)   # D**(n+1) / (n**n * prod(xp)), truncated at each step
    Dprev = D
    D = (Ann * S + D_P * n) * D // ((Ann - 1) * D + (n + 1) * D_P)
    print(D, "diff:", Dprev - D)
    i += 1
    if i > 25:
        print("something is very wrong")
        break
print("Solution:", D, "Iterations:", i)
print("f(D):", f(D, xp, A))

print("")
print("Integer version with 'bounce' stop condition:\n")

# Reference value for the root, hard-coded from the run above (the lowest iterate it reached).
D_REF = 18478317168485044777940646

# Stop as soon as an iterate fails to decrease, and keep the previous one.
D = S
Dprev = S + 1
i = 0
while Dprev > D:
    D_P = D
    for x in xp:
        D_P = D_P * D // (x * n)
    Dprev = D
    D = (Ann * S + D_P * n) * D // ((Ann - 1) * D + (n + 1) * D_P)
    print("D:", D, "diff:", Dprev - D)
    print("f''(D) / f'(D):", f_double_prime(Dprev, xp, A) / f_prime(Dprev, xp, A))
    # Quadratic-convergence estimate (f''/f') * e_k**2 of the next error
    # (the exact Newton constant carries an extra factor of 1/2) ...
    print(int((D_REF - Dprev)**2 * f_double_prime(Dprev, xp, A) / f_prime(Dprev, xp, A)))
    # ... and the observed error of the new iterate.
    print(D - D_REF)
    i += 1
print("Solution:", Dprev, "Iterations:", i)
print("f(D):", f(Dprev, xp, A))
print("f'(D):", f_prime(Dprev, xp, A))
print("f''(D):", f_double_prime(Dprev, xp, A))
print("f''(D) / f'(D):", f_double_prime(Dprev, xp, A) / f_prime(Dprev, xp, A))

print("")
print("Floating point version:\n")

# Same iteration in floating point (D becomes a float after the first step).
D = S
Dprev = S + 1
i = 0
while Dprev > D:
    D_P = D**(n+1) / (prod(xp) * n**n)
    Dprev = D
    D = (Ann * S + D_P * n) * D / ((Ann - 1) * D + (n + 1) * D_P)
    print("D:", D, "diff:", Dprev - D)
    i += 1
print("Solution:", int(Dprev), "Iterations:", i)
print("f(D):", f(int(Dprev), xp, A))
```



Output for $x_1 = 98{,}500{,}000 \cdot 10^{18}$, $x_2 = 500 \cdot 10^{18}$:

```
65756092287907639975314212 diff: 32744407712092360024685788
44168124704128399818935301 diff: 21587967583779240156378911
30363857632889687695239839 diff: 13804267071238712123695462
22411924665442352735413816 diff: 7951932967447334959826023
19096103336669527471340793 diff: 3315821328772825264073023
18496762595037143446940757 diff: 599340741632384024400036
18478334248576893518569115 diff: 18428346460249928371642
18478317168499707301101656 diff: 17080077186217467459
18478317168485044777940658 diff: 14662523160998
18478317168485044777940649 diff: 9
18478317168485044777940646 diff: 3
18478317168485044777940648 diff: -2
18478317168485044777940649 diff: -1
Solution: 18478317168485044777940649 Iterations: 13
f(D): 0.0

Integer version with 'bounce' stop condition:

D: 65756092287907639975314212 diff: 32744407712092360024685788
f''(D) / f'(D): 2.0249781326079244e-26
129670482049358726472663040
47277775119422595197373566
D: 44168124704128399818935301 diff: 21587967583779240156378911
f''(D) / f'(D
```
