## The Stableswap equation

The Stableswap equation given in the Curve whitepaper is

$$ An^n\sum_i x_i + D = An^n D + \frac{D^{n+1}}{n^n\prod\limits_i x_i}$$

$D$ is the stableswap invariant.

$A$ is the amplification coefficient.

The $x_i$'s are the pool balances for each token, and $n$ is the number of tokens in the pool.
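
As a small sanity check on the notation, the sketch below evaluates both sides of the equation with exact rational arithmetic. `stableswap_sides` is a hypothetical helper of our own, not part of Curve's code. For a perfectly balanced pool, $D = \sum_i x_i$ satisfies the equation.

```python
from fractions import Fraction
from math import prod

def stableswap_sides(D, xp, A):
    """Return (LHS, RHS) of the Stableswap equation as exact Fractions."""
    n = len(xp)
    Ann = A * n**n  # the whitepaper's A * n^n
    lhs = Fraction(Ann * sum(xp) + D)
    rhs = Ann * D + Fraction(D**(n + 1), n**n * prod(xp))
    return lhs, rhs

# Balanced pool: the invariant equals the sum of the balances.
lhs, rhs = stableswap_sides(20, [10, 10], A=100)
print(lhs == rhs)  # True
```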

## Solving for $D$ using Newton's method

To solve for $D$ in the stableswap equation, we can use the auxiliary form of the equation

$$ f(D) = 0 $$

where

$$ f(D) = An^n D + \frac{D^{n+1}}{n^n\prod x_i} - An^n S - D $$

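Here $S = \sum_i x_i$. Write $P = n\left(\prod_i x_i\right)^{1/n}$ for $n$ times the geometric mean of the balances, so that $P^n = n^n\prod_i x_i$. Then $f(P) = An^n(P - S)$ and $f(S) = S\left(\frac{S^n}{n^n\prod_i x_i} - 1\right)$. By the AM-GM inequality $P \le S$, with equality only when all the $x_i$ are equal, so in the generic case $P < S$ we have $f(P) < 0$ and $f(S) > 0$, and somewhere in between there must be a $D$ with $f(D) = 0$.  In fact, the situation is much better than that.  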
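
Since $f(0) = -An^nS < 0 \le f(S)$ and $f$ is strictly increasing for $D > 0$ whenever $An^n \ge 1$, the integer part of the root can also be pinned down by plain bisection. This is a swapped-in technique, not what Curve does, but it gives an exact reference value to compare other methods against. The sketch multiplies $f$ by the positive constant $n^n\prod x_i$ so everything stays in integers; `scaled_f` and `floor_root` are hypothetical names of our own, and all balances are assumed positive.

```python
from math import prod

def scaled_f(D, xp, A):
    """f(D) * n^n * prod(xp): same sign as f(D), integer arithmetic only."""
    n = len(xp)
    Ann = A * n**n
    scale = n**n * prod(xp)  # positive as long as every balance is positive
    return (Ann - 1) * D * scale + D**(n + 1) - Ann * sum(xp) * scale

def floor_root(xp, A):
    """Largest integer D with f(D) <= 0, found by bisection on [0, S]."""
    lo, hi = 0, sum(xp)           # f(0) < 0 and f(S) >= 0 (AM-GM)
    if scaled_f(hi, xp, A) <= 0:  # balanced pool: S is the exact root
        return hi
    # Invariant: f(lo) <= 0 < f(hi)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if scaled_f(mid, xp, A) <= 0:
            lo = mid
        else:
            hi = mid
    return lo

print(floor_root([1, 4], A=100))    # 4 (the root lies between P = 4 and S = 5)
print(floor_root([10, 10], A=100))  # 20 (balanced pool, D = S)
```

Each probe costs one exact evaluation of $f$, so this needs about $\log_2 S$ steps, far more than Newton's method, but it cannot oscillate.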

The derivative of $f$ is

$$ f'(D) = An^n + (n+1) \frac{D^n}{n^n\prod x_i} - 1 $$

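As long as $An^n \ge 1$ (true for any $A \ge 1$), $f'$ is always positive, so $f$ is strictly increasing and the root between $P$ and $S$ is unique. Moreover, $f''(D) = n(n+1)\frac{D^{n-1}}{n^n\prod x_i} > 0$ for $D > 0$, so $f$ is convex there. Starting Newton's method at $D = S$, where $f > 0$, the iterates therefore decrease monotonically toward the root, and convergence is rapid.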
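
Because $f$ is increasing and convex for $D > 0$, a Newton step taken from a point where $f > 0$ lands closer to the root without crossing it. The sketch below checks this with exact `Fraction` arithmetic on a toy pool (function names are ours). Only a few steps are taken, because the exact fractions grow quickly.

```python
from fractions import Fraction
from math import prod

def f(D, xp, A):
    n = len(xp)
    Ann = A * n**n
    return Ann * D + Fraction(D**(n + 1), n**n * prod(xp)) - Ann * sum(xp) - D

def f_prime(D, xp, A):
    n = len(xp)
    Ann = A * n**n
    return Ann + (n + 1) * Fraction(D**n, n**n * prod(xp)) - 1

def newton_iterates(xp, A, steps=5):
    """Exact Newton iterates D_0 = S, D_1, ..., D_steps."""
    D = Fraction(sum(xp))
    iterates = [D]
    for _ in range(steps):
        D = D - f(D, xp, A) / f_prime(D, xp, A)
        iterates.append(D)
    return iterates

its = newton_iterates([1, 4], A=100)
# Strictly decreasing, and f stays positive: the iterates never cross the root.
print(all(a > b for a, b in zip(its, its[1:])))  # True
print(all(f(D, [1, 4], 100) > 0 for D in its))   # True
```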

Newton's method gives the iteration:

$$ D_{k+1} = D_k - \frac{f(D_k)}{f'(D_k)} $$

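After some cleanup (multiply the numerator and denominator of $f(D_k)/f'(D_k)$ by $D_k$ and collect terms), this gives the iterative formula:

$$ D_{k+1} = \frac{(nD_p + An^nS)D_k}{(n+1)D_p + (An^n - 1)D_k} $$

where $D_p = \frac{D_k^{n+1}}{n^n\prod x_i}$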
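
One way to convince yourself that the cleaned-up formula is right is to compare it with the raw Newton step $D - f(D)/f'(D)$ in exact arithmetic at a few points. This is a sketch with helper names of our own; the two agree exactly because the closed-form denominator is just $D f'(D)$.

```python
from fractions import Fraction
from math import prod

def newton_step_generic(D, xp, A):
    """D - f(D)/f'(D), written directly from f and f'."""
    n, S, P = len(xp), sum(xp), prod(xp)
    Ann = A * n**n
    f = Ann * D + Fraction(D**(n + 1), n**n * P) - Ann * S - D
    f_prime = Ann + (n + 1) * Fraction(D**n, n**n * P) - 1
    return D - f / f_prime

def newton_step_closed_form(D, xp, A):
    """The cleaned-up formula with D_p = D^(n+1) / (n^n prod x_i)."""
    n, S = len(xp), sum(xp)
    Ann = A * n**n
    D_p = Fraction(D**(n + 1), n**n * prod(xp))
    return (n * D_p + Ann * S) * D / ((n + 1) * D_p + (Ann - 1) * D)

for xp in ([1, 4], [3, 5, 7], [10, 10]):
    for D in (1, 7, sum(xp)):
        assert newton_step_generic(D, xp, 100) == newton_step_closed_form(D, xp, 100)
print("closed form matches the Newton step")
```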

## Convergence under integer arithmetic

Because the above iterative formula has to be implemented in integer arithmetic for the EVM, some trickiness arises.

First, note that $D_p$ is computed in the code iteratively to avoid overflow: `D_P` starts at $D$ and, once per balance, is multiplied by $D$ and divided by $n x_i$.  Each of these integer divisions truncates.
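
To see the truncation concretely, the sketch below compares the iterative `D_P` (one floor division per balance) with the exact floor of $D^{n+1}/(n^n\prod x_i)$ on toy numbers, not pool data. Each floor can only lose value, so the iterative result is never larger than the exact one, and it can be strictly smaller. The function names are ours.

```python
from math import prod

def d_p_iterative(D, xp):
    """D_P as computed in the contract: one truncating division per balance."""
    n = len(xp)
    D_P = D
    for x in xp:
        D_P = D_P * D // (x * n)
    return D_P

def d_p_exact_floor(D, xp):
    """floor(D^(n+1) / (n^n * prod(xp))) with a single division."""
    n = len(xp)
    return D**(n + 1) // (n**n * prod(xp))

print(d_p_iterative(10, [3, 3]), d_p_exact_floor(10, [3, 3]))  # 26 27
```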

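Second (and more seriously), the final integer division in the update for $D$ effectively floors $D$ and can push an iterate below the correct value.  The next iteration then gives a result above the correct value (from below the root, a Newton step on a convex increasing function lands at or above the root), leading to an oscillation around the correct value.  With a stop condition that waits for successive iterates to be within 1 of each other, this can result in an infinite loop, as the example below illustrates.  The Vyper code caps the loop at 255 iterations, so it cannot hang there, but its convergence test may then never be met.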
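
The integer update is a deterministic map on integers, so once any value repeats, the sequence will cycle forever. One possible guard (a sketch, not the 3Pool approach) is to remember every iterate and stop at the first repeat. `integer_step` mirrors the integer loop body used in the notebook code; `iterate_until_repeat` and its `max_iter` cap are hypothetical additions of our own.

```python
def integer_step(D, xp, A):
    """One integer Newton step, mirroring the notebook's loop body."""
    n, S = len(xp), sum(xp)
    Ann = A * n**n
    D_P = D
    for x in xp:
        D_P = D_P * D // (x * n)
    return (Ann * S + D_P * n) * D // ((Ann - 1) * D + (n + 1) * D_P)

def iterate_until_repeat(xp, A, max_iter=255):
    """Iterate from D = S until a value repeats; return (history, cycle)."""
    D = sum(xp)
    seen = {}      # iterate -> index of its first appearance
    history = []
    while D not in seen and len(history) < max_iter:
        seen[D] = len(history)
        history.append(D)
        D = integer_step(D, xp, A)
    cycle = history[seen[D]:] if D in seen else []  # empty if the cap was hit
    return history, cycle

print(iterate_until_repeat([1, 4], A=100))  # ([5, 4], [4])
```

A cycle of length 1 is ordinary convergence to a fixed point; a longer cycle is exactly the oscillation described above, and the loop still terminates.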

### Calculation in Vyper (from 3Pool)

```python
def get_D(xp: uint256[N_COINS], amp: uint256) -> uint256:
    S: uint256 = 0
    for _x in xp:
        S += _x
    if S == 0:
        return 0

    Dprev: uint256 = 0
    D: uint256 = S
    Ann: uint256 = amp * N_COINS
    for _i in range(255):
        D_P: uint256 = D
        for _x in xp:
            D_P = D_P * D / (_x * N_COINS)
        Dprev = D
        D = (Ann * S + D_P * N_COINS) * D / ((Ann - 1) * D + (N_COINS + 1) * D_P)
        if D > Dprev:
            if D - Dprev <= 1:
                break
        else:
            if Dprev - D <= 1:
                break
    return D
```
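
To experiment with this function off-chain, it can be transliterated to Python almost line by line. This is a sketch: `get_D_py` is our own name, and Python's unbounded integers do not reproduce `uint256` overflow behaviour. Returning the iteration count as well makes it visible when the `<= 1` test never fires and the loop simply runs out of its 255 iterations.

```python
def get_D_py(xp, amp):
    """Python transliteration of the 3Pool get_D; returns (D, iterations)."""
    N_COINS = len(xp)
    S = sum(xp)
    if S == 0:
        return 0, 0
    D = S
    Ann = amp * N_COINS  # as in the Vyper code
    for i in range(255):
        D_P = D
        for x in xp:
            D_P = D_P * D // (x * N_COINS)
        Dprev = D
        D = (Ann * S + D_P * N_COINS) * D // ((Ann - 1) * D + (N_COINS + 1) * D_P)
        # Same stopping rule as the contract: successive values within 1.
        if abs(D - Dprev) <= 1:
            return D, i + 1
    return D, 255  # loop exhausted without meeting the convergence test
```

Since `Ann = amp * N_COINS` here, passing `amp = A * n**(n - 1)` makes `Ann` equal to the $An^n$ used elsewhere in this notebook.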

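```python
from math import prod

def f(D, xp):
    """Residual of the auxiliary equation; close to 0 at the solution.

    Uses the global amplification coefficient `A` (set below). The division
    is true division, so the result is a float.
    """
    S = sum(xp)
    n = len(xp)
    Ann = A * n**n
    return Ann * D + D**(n+1) / (n**n * prod(xp)) - Ann * S - D
```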
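
Because `f` mixes exact integers with a float division, its result is rounded to double precision. With the pool below the terms are around $4 \times 10^{28}$, where one unit in the last place of a double is already a few times $10^{12}$, so a printed `f(D)` of that order cannot distinguish the exact root from a nearby integer. A variant that stays in `fractions.Fraction` (the name `f_exact` is ours) gives the exact residual, and its sign tells which side of the true root a candidate $D$ lies on.

```python
from fractions import Fraction
from math import prod

def f_exact(D, xp, A):
    """Exact residual f(D) of the auxiliary equation, as a Fraction."""
    n = len(xp)
    Ann = A * n**n
    return Ann * D + Fraction(D**(n + 1), n**n * prod(xp)) - Ann * sum(xp) - D

print(f_exact(4, [1, 4], 100))     # -400
print(f_exact(20, [10, 10], 100))  # 0
```

Since $f$ is increasing, `f_exact(D, xp, A) <= 0 < f_exact(D + 1, xp, A)` identifies $D$ as the floor of the true root.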

```python
A = 100
x_1 = 98_500_000 * 10**18

# A milder imbalance, for comparison:
# x_2 = 500_000 * 10**18

# This extreme imbalance makes the floating-point ("real") calculation break.
# The output below was produced with this value.
x_2 = 5 * 10**18
```

```python
# Iterative calculation of the D invariant
xp = [x_1, x_2]

n = len(xp)
S = sum(xp)
Ann = A * n**n

# Commented out because this can enter an infinite loop!
# D = S
# Dprev = 0
# while abs(Dprev - D) > 1:
#     print("D:", D, "diff:", Dprev - D)
#     D_P = D
#     for x in xp:
#         D_P = D_P * D // (x * n)
#     Dprev = D
#     D = (Ann * S + D_P * n) * D // ((Ann - 1) * D + (n + 1) * D_P)
#     if Dprev < D:
#         print("bounce", Dprev, D)
# print("Solution:", D, "diff:", Dprev - D)
# print("f(D):", f(D, xp))

print("")
print("Integer version with 'bounce' stop condition:\n")

# Starting from D = S the iterates should decrease; stop at the first step
# that does not decrease (a "bounce") and take the previous iterate.
D = S
Dprev = S + 1
while Dprev > D:
    print("D:", D, "diff:", Dprev - D)
    D_P = D
    for x in xp:
        D_P = D_P * D // (x * n)
    Dprev = D
    D = (Ann * S + D_P * n) * D // ((Ann - 1) * D + (n + 1) * D_P)
print("D:", D, "diff:", Dprev - D)
print("Solution:", Dprev)
print("f(D):", f(Dprev, xp))

print("")
print("Floating point version:\n")

D = S
Dprev = S + 1
while Dprev > D:
    print("D:", D, "diff:", Dprev - D)
    D_P = D**(n+1) / (prod(xp) * n**n)
    Dprev = D
    D = (Ann * S + D_P * n) * D / ((Ann - 1) * D + (n + 1) * D_P)
print("D:", D, "diff:", Dprev - D)
print("Solution:", int(Dprev))
print("f(D):", f(int(Dprev), xp))
```


```
Integer version with 'bounce' stop condition:

D: 98500005000000000000000000 diff: 1
D: 65667563309164106919876098 diff: 32832441690835893080123902
D: 43781715209300009448160977 diff: 21885848099864097471715121
D: 29197316771515518644498080 diff: 14584398437784490803662897
D: 19489237612461864965045431 diff: 9708079159053653679452649
D: 13051938117784021218456290 diff: 6437299494677843746589141
D: 8839573172235296387956430 diff: 4212364945548724830499860
D: 6203362690289300030686614 diff: 2636210481945996357269816
D: 4775397704395200157228709 diff: 1427964985894099873457905
D: 4269095004311729273569827 diff: 506302700083470883658882
D: 4205219681045563264055971 diff: 63875323266166009513856
D: 4204253928655881395185073 diff: 965752389681868870898
D: 4204253710021333751122052 diff: 218634547644063021
D: 4204253710021322547503443 diff: 11203618609
D: 4204253710021322547503433 diff: 10
D: 4204253710021322547503442 diff: -9
Solution: 4204253710021322547503433
f(D): -1726039982080.0
```

Float