In [311]:
import math
import numpy as np

In [312]:
# Target prices, one entry per token. Passed as `prices` to
# computeBalancesForPrices; p[0] is used as the base price when the final
# prices are recomputed from simulated swaps.
p = [1.0, 2.0]
# Contract amplification parameter (equal to A * n^(n-1))
amplificationParameter = 1
# Invariant handed to computeBalancesForPrices ("old invariant" in the output);
# it is compared against the invariant recomputed from the resulting balances.
D = 10000

In [313]:
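# Alternative scenario (disabled): five tokens with a large amplification
# parameter. Uncomment to override the inputs defined in the previous cell.
# p = [
#     999.999000000000002679, 
#     996.803511277021540000,
#     999.999000000000000000,
#     999.999000000000000000,
#     999.999000000000000000
# ]
# amplificationParameter = 20555
# D = 495000000.000000881116305497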

In [314]:
# Stable Oracle Functions

def computeBalancesForPrices(amplificationParameter, invariant, prices):
    """Compute one balance per token from the target prices and the invariant.

    Parameters
    ----------
    amplificationParameter : number
        Contract amplification parameter, forwarded to computeAandB/computeK.
    invariant : number
        Invariant value that scales the returned balances.
    prices : sequence of numbers
        Target price of each token.

    Returns
    -------
    numpy.ndarray
        Float array shaped like ``prices`` where entry ``i`` equals
        ``(b*invariant/(a - k*prices[i])) / (1 - sum_j a/(k*prices[j] - a))``,
        with ``a, b`` from computeAandB and ``k`` from computeK.

    Raises
    ------
    RuntimeError
        Propagated from computeK when its iteration does not converge.
    """
    a, b = computeAandB(amplificationParameter, prices)
    n = len(prices)
    k = computeK(amplificationParameter, a, b, n, prices)

    # Normalisation term shared by every balance.
    sumPriceDivision = 0
    for pi in prices:
        sumPriceDivision += a/(k*pi - a)

    # Force a float result: without an explicit dtype, zeros_like inherits the
    # dtype of `prices`, so integer prices would silently truncate balances.
    balancesForPrices = np.zeros_like(prices, dtype=float)
    for i in range(n):
        # Index the `prices` argument, not the notebook-global `p`, so the
        # function is correct for any input and on a fresh kernel.
        balancesForPrices[i] = (b*invariant/(a - k*prices[i]))/(1-sumPriceDivision)

    return balancesForPrices

def computeK(amplificationParameter, a, b, n, prices):
    """Iteratively solve for ``k`` using the terms from computeKparams.

    Starting from an initial guess built from ``a``, ``b`` and the smallest
    price, each step evaluates ``f = T - alphaDivPTn`` and the slope term
    ``(n+1)*dTdk + T*dPdkDivP`` and applies ``k <- k - f/slope``.

    Returns the first iterate whose relative change from the previous one is
    below 1e-15. Raises RuntimeError("KDidntConverge()") after 255 steps
    without meeting that criterion.
    """
    relativeTolerance = math.pow(10, -15)

    # Initial guess.
    startFactor = min(1 + 1/(1+b), 2-b/a)
    k = startFactor * a/min(prices)

    for _ in range(255):
        T, dTdk, dPdkDivP, alphaDivPTn = computeKparams(
            amplificationParameter, k, a, b, n, prices
        )

        residual = T - alphaDivPTn
        slope = (n+1)*dTdk + T*dPdkDivP
        candidate = k - residual / slope

        # Stop once the relative step size drops below 1e-15.
        if abs(candidate - k)/k < relativeTolerance:
            return candidate
        k = candidate

    raise RuntimeError("KDidntConverge()")

# Notice that this function returns P'/P as `dPdkDivP`, and alpha/(P*T^n) as `alphaDivPTn`. This is a numerical device used by Solidity to avoid overflows.
def computeKparams(amplificationParameter, k, a, b, n, prices):
    """Evaluate the per-iteration terms consumed by computeK at a given ``k``.

    With ``r_i = prices[i]/a`` and ``den_i = k*r_i - 1`` the function returns
    the tuple ``(T, dTdk, dPdkDivP, alphaDivPTn)`` where

    * ``T``           = sum(1/den_i) - 1
    * ``dTdk``        = -sum(r_i/den_i**2)
    * ``dPdkDivP``    = sum(r_i/den_i)   (P'/P rather than P' itself)
    * ``alphaDivPTn`` = b * prod(b/a) / prod(den_i * T)   (alpha/(P*T^n))

    The last two are returned in ratio form, mirroring the numerical device
    the Solidity code uses to avoid overflows. ``amplificationParameter`` and
    ``n`` are accepted for signature compatibility but not read here.
    """
    T = 0
    dTdk = 0
    dPdkDivP = 0
    alphaDivPTn = b

    # Single pass over the prices; denominators are cached for the second pass.
    denominators = []
    for price in prices:
        ratio = price/a
        denominator = (k*ratio - 1)
        denominators.append(denominator)

        T += 1/denominator
        dTdk -= ratio/(denominator*denominator)
        dPdkDivP += ratio/denominator
        alphaDivPTn = alphaDivPTn * b/a
    T -= 1

    # T is only final after the first pass, so the division happens here.
    for denominator in denominators:
        alphaDivPTn = alphaDivPTn / (denominator * T)

    return T, dTdk, dPdkDivP, alphaDivPTn

def computeAandB(amplificationParameter, prices):
    """Return the coefficients ``(a, b)`` for a pool with ``len(prices)`` tokens.

    With ``n = len(prices)``, the amplification parameter is first divided by
    ``n**(n-1)`` (Curve's A); then ``a = A * n**n * n**n`` and ``b = a - n**n``.
    Only the length of ``prices`` is used.
    """
    tokenCount = len(prices)
    # Curve's A, recovered from the contract amplification parameter.
    curveA = amplificationParameter/math.pow(tokenCount, tokenCount-1)
    nPowN = math.pow(tokenCount, tokenCount)
    a = curveA * nPowN * nPowN
    return a, a - nPowN

In [315]:
# Stable Math functions

def compute_invariant(A, balances):
    """Solve for the invariant D of the given balances by Newton iteration.

    The root sought is that of
    ``f(D) = D**(n+1)/(n**n * prod) + A*n**n*D - D - A*n**n*total``
    where ``total`` and ``prod`` are the sum and product of ``balances``.
    Iteration starts at ``D = total`` and returns the first iterate that
    differs from the previous one by less than 1e-5 (absolute).

    Raises RuntimeError("StableInvariantDidntConverge()") after 500 steps.
    """
    total_balance = sum(balances)
    n = len(balances)

    product = 1
    for balance in balances:
        product = product * balance

    # Loop-invariant pieces of f and f'.
    nn = math.pow(n, n)
    denominator = nn*product
    tolerance = math.pow(10, -5)

    # Initial guess: the plain sum of balances.
    D = total_balance
    for _ in range(500):
        fD = math.pow(D, n+1)/denominator + A*nn*D - D - A*nn*total_balance
        flD = (n+1)*math.pow(D, n)/denominator + A*nn - 1
        candidate = D - fD/flD

        if abs(candidate - D) < tolerance:
            return candidate
        D = candidate

    raise RuntimeError("StableInvariantDidntConverge()")

def computeBalance(A, balances, D, tokenIndex):
    """Solve for the balance of one token given the others and the invariant.

    Newton iteration on
    ``f(x) = A*n**n*(others_sum + x) + D - A*n**n*D - D**(n+1)/(n**n*others_prod*x)``
    starting from ``x = balances[tokenIndex]``; ``others_sum`` / ``others_prod``
    run over every index different from ``tokenIndex``.

    Returns the first iterate within 1e-15 (absolute) of the previous one.
    Raises RuntimeError("StableInvariantDidntConverge()") after 255 steps.
    """
    x = balances[tokenIndex]
    n = len(balances)
    nn = math.pow(n, n)

    # Sum and product of every balance except the one being solved for.
    othersProduct = 1
    othersSum = 0
    for index, balance in enumerate(balances):
        if index == tokenIndex:
            continue
        othersProduct = othersProduct * balance
        othersSum += balance

    # Loop-invariant pieces of f and f'.
    ann = A*nn
    dPower = math.pow(D, n+1)
    scale = nn*othersProduct
    tolerance = math.pow(10, -15)

    for _ in range(255):
        fx = ann*(othersSum+x) + D - ann*D - dPower/(scale*x)
        flx = ann + dPower/(scale*math.pow(x,2))
        candidate = x - fx/flx
        if abs(x - candidate) < tolerance:
            return candidate
        x = candidate

    raise RuntimeError("StableInvariantDidntConverge()")

def computeOutGivenIn(A, balances, D, tokenInIndex, tokenOutIndex, amountIn):
    """Return how much of the output token is released for ``amountIn``.

    Works on a copy of ``balances``: the input token's balance is increased by
    ``amountIn``, the output token's balance is re-solved with computeBalance
    for invariant ``D``, and the drop in the output balance is returned.
    ``balances`` itself is not modified.
    """
    postSwap = balances.copy()
    balanceInAfter = postSwap[tokenInIndex] + amountIn
    postSwap[tokenInIndex] = balanceInAfter

    balanceOutAfter = computeBalance(A, postSwap, D, tokenOutIndex)
    postSwap[tokenOutIndex] = balanceOutAfter

    return balances[tokenOutIndex] - postSwap[tokenOutIndex]

In [317]:
# Step 1: balances derived from the target prices `p` and invariant `D`.
balancesForPrices = computeBalancesForPrices(amplificationParameter, D, p)
print('balancesForPrices:', balancesForPrices)

# Step 2: round-trip check. Recompute the invariant from those balances and
# show it next to the one we started from.
n = len(p)
A = amplificationParameter / math.pow(n, n - 1)
newD = compute_invariant(A, balancesForPrices)
print('old invariant:', D)
print('new invariant:', newD)

# Step 3: total value of the balances at the target prices.
tvl = 0
for i, price in enumerate(p):
    tvl += price * balancesForPrices[i]

print('TVL:', tvl)

# Step 4: simulate a small swap of token 0 into every other token and derive
# the price implied by amountIn/amountOut, relative to the price of token 0.
amountIn = 0.01
pBase = p[0]
pFinal = np.zeros_like(p)
pFinal[0] = pBase
for i in range(len(p) - 1):
    amountOut = computeOutGivenIn(
        A, balancesForPrices, D, 0, i + 1, amountIn
    )
    pFinal[i + 1] = pBase * amountIn / amountOut

print('original prices:', p)
print('final prices   :', pFinal)


balancesForPrices: [8478.1038478  2476.17812932]
old invariant: 10000
new invariant: 9999.999999999998
TVL: 13430.460106436785
original prices: [1.0, 2.0]
final prices   : [1.         2.00000216]
