# Objectives
* Create a Python object that we can use to manage state and simulate the behavior of Curve pools. We want to do this to:
    * Understand how perturbing pool parameters affects pool metrics 
        * e.g. For a pool with two assets, I want to be able to observe the price impact on asset X when I add/subtract some amount of asset X. I want to produce the plots of reserves vs. price and delta_reserves vs. price.
    * Integrate with systems to import on-chain state into the data structures we define for modeling AMM pools. For example, I want this so I can use historical data Thrackle is gathering about Curve pools to contextualize IRL events with scenarios we want to simulate.
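
To make the first objective concrete, here is a minimal sketch of the data behind a "delta_reserves vs. price" plot. It assumes a two-asset constant-product pool (`x * y = k`) as a stand-in, which is not Curve's invariant; the function name `price_impact_curve` and its parameters are hypothetical.

```python
from typing import List, Tuple


def price_impact_curve(reserve_x: float,
                       reserve_y: float,
                       deltas_x: List[float]) -> List[Tuple[float, float, float]]:
    """For each change in the X reserve, return (delta_x, new_reserve_x, price).

    `price` is the marginal price of X in units of Y after the change.
    ASSUMPTION: constant-product invariant x * y = k, used only for illustration.
    """
    k = reserve_x * reserve_y
    rows = []
    for dx in deltas_x:
        new_x = reserve_x + dx
        if new_x <= 0:
            raise ValueError("cannot remove more X than the pool holds")
        new_y = k / new_x                         # Y reserve implied by the invariant
        rows.append((dx, new_x, new_y / new_x))   # marginal price = y / x
    return rows


rows = price_impact_curve(100.0, 100.0, [-50.0, 0.0, 100.0])
for dx, x, price in rows:
    print(f"delta_x={dx:+.0f}  reserve_x={x:.0f}  price={price:.2f}")
```

Plotting the second column against the third gives reserves vs. price; plotting the first against the third gives delta_reserves vs. price. Swapping in Curve math should only require replacing the line that derives `new_y`.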

# Brainstorm: what is the right data structure for us to design as an abstract representation of an AMM pool? 

"Bad programmers worry about the code. Good programmers worry about data structures and their relationships." - Linus Torvalds

Let's start by thinking about the smart contract structure for a pool in:
1. Curve
2. Uniswap

The framework for each of these should be extensible to modeling AMM pools from other protocols too. 

## What is the data structure (abstract class) that best represents a generic AMM pool? 
* Observation: Most of our questions tend to revolve around understanding the state machine of a single pool. 
* Observation: There are state machines at several levels: consensus, protocol/app, pool.
    * Decision: It makes sense to start specific and then broaden out for this project. Thrackle already has several projects in flight to understand the protocol/app layer. We want to start by modeling a narrower, more granular system: an individual AMM pool.  

## Exercise TODO: enumerate the types of events relevant to Curve and Uniswap contracts. 
* How do we represent these events in our data structure? 
    * Which events are common to all AMMs and belong in the abstract class?
    * Which events have unique implementations in each AMM and belong in child classes (e.g. CurveCryptoPool(Pool))? 
    * e.g. is `update_price` common to all AMM pools?
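
One possible starting point for this exercise is sketched below: a small set of generic event types for the abstract class, plus a per-protocol mapping from contract event names onto them. The event names are listed from memory of the Curve StableSwap and Uniswap V2 pair contracts and should be checked against the actual ABIs; `PoolEventType`, `EVENT_NAME_MAP`, and `PoolEvent` are hypothetical names.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List


class PoolEventType(Enum):
    """Generic events that (we assume) every AMM pool shares."""
    ADD_LIQUIDITY = "add_liquidity"
    REMOVE_LIQUIDITY = "remove_liquidity"
    EXCHANGE = "exchange"


# ASSUMPTION: event names recalled from the contracts; verify against the ABIs.
EVENT_NAME_MAP: Dict[str, Dict[str, PoolEventType]] = {
    "curve": {
        "AddLiquidity": PoolEventType.ADD_LIQUIDITY,
        "RemoveLiquidity": PoolEventType.REMOVE_LIQUIDITY,
        "RemoveLiquidityOne": PoolEventType.REMOVE_LIQUIDITY,
        "RemoveLiquidityImbalance": PoolEventType.REMOVE_LIQUIDITY,
        "TokenExchange": PoolEventType.EXCHANGE,
    },
    "uniswap_v2": {
        "Mint": PoolEventType.ADD_LIQUIDITY,
        "Burn": PoolEventType.REMOVE_LIQUIDITY,
        "Swap": PoolEventType.EXCHANGE,
    },
}


@dataclass(frozen=True)
class PoolEvent:
    protocol: str        # key into EVENT_NAME_MAP
    name: str            # event name as emitted by the contract
    amounts: List[int]   # per-token amounts, in pool token order

    @property
    def event_type(self) -> PoolEventType:
        # Raises KeyError for events with no generic counterpart,
        # which flags them as candidates for child-class handling.
        return EVENT_NAME_MAP[self.protocol][self.name]


event = PoolEvent(protocol="uniswap_v2", name="Swap", amounts=[20, 0])
print(event.event_type)
```

Events that do not fit any generic type (a price-scale update, for instance) would stay out of the map and live only in the relevant child class.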

### Base class

In [92]:
from typing import Union, Optional, List
from abc import ABC, abstractmethod

class Pool(ABC): 
    
    '''
    Abstract class for a specific protocol's flavor of AMM pool.
    '''
    
    def __init__(self, 
                 token_symbols : List[str], 
                 token_prices : List[float],
                 token_balances : Optional[List[int]] = None,
                 exchange_fee : float = 0.0) -> None:
        
        ### TODO: Think about __init__ pool state by reading on-chain data.
        ###       This is going to be child class-specific. What is the interface? 
                  
        self.token_symbols = token_symbols
        self.token_prices = token_prices
        self.token_balances = token_balances
        # Assign through the property setter so negative fees are rejected
        # and `self._exchange_fee` exists before the getter is first used.
        self.exchange_fee = exchange_fee
        
    def _get_token_id(self, symbol : str) -> int:
        return self.token_symbols.index(symbol)
    
    def __repr__(self) -> str:
        return "\n".join([
            f"{s} has balance {b} at price {p} with total value {b*p}"
            for s, b, p in zip(self.token_symbols, self.token_balances, self.token_prices)
        ])
    
    @abstractmethod
    def add_liquidity(self, add_amounts : Union[int, List[int]]) -> None:
        raise NotImplementedError
    
    @abstractmethod
    def withdraw_liquidity(self, withdraw_amounts : Union[int, List[int]]) -> None:
        raise NotImplementedError
    
    @abstractmethod
    def exchange(self, 
                 sell_token_symbol : str,
                 buy_token_symbol : str,
                 sell_amount : int) -> None:
        raise NotImplementedError
    
    @property
    def exchange_fee(self) -> float:
        return self._exchange_fee

    @exchange_fee.setter
    def exchange_fee(self, fee : float) -> None:
        if fee < 0:
            raise ValueError("Negative fees seem like questionable mechanism design...")
        self._exchange_fee = fee

### Example child class

In [93]:
from typing import Union, Optional, List

class FakePool(Pool):
    
    def __init__(self, 
                 token_symbols : List[str], 
                 token_prices : List[float],
                 token_balances : Optional[List[int]] = None,
                 exchange_fee : float = 0.0,
                 pool_params : Optional[dict] = None) -> None:

        super().__init__(token_symbols, token_prices, token_balances, exchange_fee)
        # Use a fresh dict per instance; a `{}` default would be shared across pools.
        self.pool_params = pool_params if pool_params is not None else {}
        
    def tweak_price(self):
        # There is no price update. This fake pool is an arbitrage goldmine! 
        pass
        
    def add_liquidity(self, 
                      symbol : Optional[str],
                      add_amounts : Union[int, List[int]]) -> None:
        pass
    
    def withdraw_liquidity(self, 
                           symbol : Optional[str],
                           withdraw_amounts : Union[int, List[int]]) -> None:
        pass
    
    def exchange(self, 
                 sell_token_symbol : str,
                 buy_token_symbol : str,
                 sell_amount : int) -> None:
        
        sell_token_id = self._get_token_id(sell_token_symbol)
        buy_token_id = self._get_token_id(buy_token_symbol)
        
        # There is no notion of address to send/receive funds to/from.
        self.token_balances[sell_token_id] -= sell_amount
        
        buy_amount = 0 # DUMB TRADE! This AMM is sneaky.
        
        self.token_balances[buy_token_id] -= buy_amount
        self.tweak_price()
        

fake_pool = FakePool(
    token_symbols = ["token 1", "token 2"],
    token_prices = [10, 20],
    token_balances = [200, 100],
    exchange_fee = 0.05
)

In [94]:
fake_pool._get_token_id("token 2")

1

In [95]:
fake_pool

token 1 has balance 200 at price 10 with total value 2000
token 2 has balance 100 at price 20 with total value 2000

In [96]:
fake_pool.exchange("token 1", "token 2", 20)
fake_pool

token 1 has balance 180 at price 10 with total value 1800
token 2 has balance 100 at price 20 with total value 2000

### Curve Tricrypto Pool

In [101]:
from typing import Union, Optional, List

class CurvePool(Pool):
    
    
    def __init__(self, 
                 
                 ### Curve pool params
                 A : int, # Amplification coefficient
                 D : Optional[Union[int, List[int]]] = None, # Total deposit size
                 tokens = None,
                 
                 ### Generic AMM params
                 token_symbols : Optional[List[str]] = None, 
                 token_prices : Optional[List[float]] = None,
                 token_balances : Optional[List[int]] = None,
                 exchange_fee : float = 0.0
                 
        ) -> None:   
        
        # Avoid mutable default arguments; fall back to fresh lists instead.
        token_symbols = token_symbols if token_symbols is not None else []
        token_prices = token_prices if token_prices is not None else []
        
        self.A = A  # actually A * n ** (n - 1) because it's an invariant
        self.n = len(token_symbols)
        # Swap fee in units of 1e-10 (10 ** 7 / 10 ** 10 = 0.1%).
        # Note: the math below uses this value, not `exchange_fee`.
        self.fee = 10 ** 7
        if len(token_prices) > 0:
            self.p = token_prices
        else:
            self.p = [10 ** 18] * self.n
        if D is not None:
            if isinstance(D, list):
                self.x = D
            else:
                self.x = [D // self.n * 10 ** 18 // _p for _p in self.p]
            token_balances = self.x
        else:
            self.x = token_balances
        self.tokens = tokens
        super().__init__(token_symbols, token_prices, token_balances, exchange_fee)     
        
    def xp(self):
        return [x * p // 10 ** 18 for x, p in zip(self.x, self.p)]
    
    def __repr__(self) -> str:
        return "\n".join([
            f"{s} has balance {b} at price {p} with total value {b*p}"
            for s, b, p in zip(self.token_symbols, self.x, self.p)
        ])
        

    def D(self):
        """
        D invariant calculation in non-overflowing integer operations
        iteratively
        A * sum(x_i) * n**n + D = A * D * n**n + D**(n+1) / (n**n * prod(x_i))
        Converging solution:
        D[j+1] = (A * n**n * sum(x_i) - D[j]**(n+1) / (n**n prod(x_i))) / (A * n**n - 1)
        """
        Dprev = 0
        xp = self.xp()
        S = sum(xp)
        D = S
        Ann = self.A * self.n
        while abs(D - Dprev) > 1:
            D_P = D
            for x in xp:
                D_P = D_P * D // (self.n * x)
            Dprev = D
            D = (Ann * S + D_P * self.n) * D // ((Ann - 1) * D + (self.n + 1) * D_P)

        return D
    
    def y(self, i, j, x):
        """
        Calculate x[j] if one makes x[i] = x
        Done by solving quadratic equation iteratively.
        x_1**2 + x1 * (sum' - (A*n**n - 1) * D / (A * n**n)) = D ** (n+1)/(n ** (2 * n) * prod' * A)
        x_1**2 + b*x_1 = c
        x_1 = (x_1**2 + c) / (2*x_1 + b)
        """
        D = self.D()
        xx = self.xp()
        xx[i] = x  # x is quantity of underlying asset brought to 1e18 precision
        xx = [xx[k] for k in range(self.n) if k != j]
        Ann = self.A * self.n
        c = D
        for y in xx:
            c = c * D // (y * self.n)
        c = c * D // (self.n * Ann)
        b = sum(xx) + D // Ann - D
        y_prev = 0
        y = D
        while abs(y - y_prev) > 1:
            y_prev = y
            y = (y ** 2 + c) // (2 * y + b)
        return y  # the result is in underlying units too

    def y_D(self, i, _D):
        """
        Calculate x[i] if the invariant D is changed to _D while the other balances stay fixed
        Done by solving quadratic equation iteratively.
        x_1**2 + x1 * (sum' - (A*n**n - 1) * D / (A * n**n)) = D ** (n+1)/(n ** (2 * n) * prod' * A)
        x_1**2 + b*x_1 = c
        x_1 = (x_1**2 + c) / (2*x_1 + b)
        """
        xx = self.xp()
        xx = [xx[k] for k in range(self.n) if k != i]
        S = sum(xx)
        Ann = self.A * self.n
        c = _D
        for y in xx:
            c = c * _D // (y * self.n)
        c = c * _D // (self.n * Ann)
        b = S + _D // Ann
        y_prev = 0
        y = _D
        while abs(y - y_prev) > 1:
            y_prev = y
            y = (y ** 2 + c) // (2 * y + b - _D)
        return y  # the result is in underlying units too

    def dy(self, i, j, dx):
        # dx and dy are in underlying units
        xp = self.xp()
        return xp[j] - self.y(i, j, xp[i] + dx)

    def remove_liquidity_imbalance(self, amounts):
        _fee = self.fee * self.n // (4 * (self.n - 1))

        old_balances = self.x
        new_balances = self.x[:]
        D0 = self.D()
        for i in range(self.n):
            new_balances[i] -= amounts[i]
        self.x = new_balances
        D1 = self.D()
        self.x = old_balances
        fees = [0] * self.n
        for i in range(self.n):
            ideal_balance = D1 * old_balances[i] // D0
            difference = abs(ideal_balance - new_balances[i])
            fees[i] = _fee * difference // 10 ** 10
            new_balances[i] -= fees[i]
        self.x = new_balances
        D2 = self.D()
        self.x = old_balances

        token_amount = (D0 - D2) * self.tokens // D0

        return token_amount

    def calc_withdraw_one_coin(self, token_amount, i):
        xp = self.xp()
        if self.fee:
            fee = self.fee - self.fee * xp[i] // sum(xp) + 5 * 10 ** 5
        else:
            fee = 0

        D0 = self.D()
        D1 = D0 - token_amount * D0 // self.tokens
        dy = xp[i] - self.y_D(i, D1)

        return dy - dy * fee // 10 ** 10

    def exchange(self, i, j, dx):
        xp = self.xp()
        x = xp[i] + dx
        y = self.y(i, j, x)
        dy = xp[j] - y
        fee = dy * self.fee // 10 ** 10
        assert dy > 0
        self.x[i] = x * 10 ** 18 // self.p[i]
        self.x[j] = (y + fee) * 10 ** 18 // self.p[j]
        return dy - fee
    
#     def exchange(self, 
#                  sell_token_symbol : str,
#                  buy_token_symbol : str,
#                  sell_amount : int) -> None:
        
#         sell_token_id = self._get_token_id(sell_token_symbol)
#         buy_token_id = self._get_token_id(buy_token_symbol)
        
#         # There is no notion of address to send/receive funds to/from.
#         self.token_balances[sell_token_id] -= sell_amount
#         self.token_balances[buy_token_id] -= buy_amount
        
        ### dynamic repegging logic ###
        # TODO: use EMA price oracle to update token_balances * token_prices
        # TODO: update invariant and price scale
        # TODO: "repeg" by adjusting price scale 
            # TODO: only repeg after measuring profit and comparing to actual profit
        
    def add_liquidity(self, 
                      symbol : Optional[str],
                      add_amounts : Union[int, List[int]]) -> None:
        pass
    
    def withdraw_liquidity(self, 
                           symbol : Optional[str],
                           withdraw_amounts : Union[int, List[int]]) -> None:
        pass
    

In [108]:
curve_pool = CurvePool(
    token_symbols = ["ETH", "BTC", "USDT"],
    A = 1000,
    D = 100
    # token_prices = [10, 20],
    # token_balances = [200, 100],
    # exchange_fee = 0.05
)

curve_pool

ETH has balance 33 at price 1000000000000000000 with total value 33000000000000000000
BTC has balance 33 at price 1000000000000000000 with total value 33000000000000000000
USDT has balance 33 at price 1000000000000000000 with total value 33000000000000000000

In [109]:
curve_pool.token_balances

[33, 33, 33]

In [110]:
curve_pool.exchange(0, 2, 5)
curve_pool

ETH has balance 38 at price 1000000000000000000 with total value 38000000000000000000
BTC has balance 33 at price 1000000000000000000 with total value 33000000000000000000
USDT has balance 28 at price 1000000000000000000 with total value 28000000000000000000
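
As a sanity check on the `D` method, the sketch below re-implements the same Newton iteration in floating point and confirms that the result satisfies the invariant from the docstring, `A * sum(x_i) * n**n + D = A * D * n**n + D**(n+1) / (n**n * prod(x_i))`. The parameter `ann` stands for the quantity the class calls `Ann`; the function names, tolerance, and iteration cap are assumptions made for illustration.

```python
import math
from typing import List


def solve_invariant(balances: List[float], ann: float,
                    tol: float = 1e-12, max_iter: int = 255) -> float:
    """Solve the invariant for D with the same update rule as `CurvePool.D`,
    but in floats and with a relative convergence tolerance."""
    n = len(balances)
    s = sum(balances)
    d = s  # initial guess: D equals the sum of balances
    for _ in range(max_iter):
        d_p = d
        for x in balances:
            d_p = d_p * d / (n * x)   # builds D**(n+1) / (n**n * prod(x))
        d_next = (ann * s + n * d_p) * d / ((ann - 1) * d + (n + 1) * d_p)
        if abs(d_next - d) <= tol * d:
            return d_next
        d = d_next
    raise RuntimeError("invariant iteration did not converge")


def invariant_residual(balances: List[float], ann: float, d: float) -> float:
    """Left-hand side minus right-hand side of the invariant equation."""
    n = len(balances)
    lhs = ann * sum(balances) + d
    rhs = ann * d + d ** (n + 1) / (n ** n * math.prod(balances))
    return lhs - rhs


ann = 3000.0  # corresponds to A = 1000 and n = 3 in the cell above
for balances in ([100.0, 100.0, 100.0], [38.0, 33.0, 28.0]):
    d = solve_invariant(balances, ann)
    print(balances, d, invariant_residual(balances, ann, d))
```

For a perfectly balanced pool, D equals the sum of balances; once the pool is imbalanced, D drops slightly below the sum, and the gap shrinks as `ann` grows.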

# A Different Approach - Copying
The class below is copied from the simulation script in the Curve contract repository: https://github.com/curvefi/curve-contract/blob/master/tests/simulation.py

In [57]:
class Curve:

    """
    Python model of Curve pool math.
    """

    def __init__(self, A, D, n, p=None, tokens=None):
        """
        A: Amplification coefficient
        D: Total deposit size
        n: number of currencies
        p: target prices
        """
        self.A = A  # actually A * n ** (n - 1) because it's an invariant
        self.n = n
        self.fee = 10 ** 7
        if p:
            self.p = p
        else:
            self.p = [10 ** 18] * n
        if isinstance(D, list):
            self.x = D
        else:
            self.x = [D // n * 10 ** 18 // _p for _p in self.p]
        self.tokens = tokens

    def xp(self):
        return [x * p // 10 ** 18 for x, p in zip(self.x, self.p)]

    def D(self):
        """
        D invariant calculation in non-overflowing integer operations
        iteratively
        A * sum(x_i) * n**n + D = A * D * n**n + D**(n+1) / (n**n * prod(x_i))
        Converging solution:
        D[j+1] = (A * n**n * sum(x_i) - D[j]**(n+1) / (n**n prod(x_i))) / (A * n**n - 1)
        """
        Dprev = 0
        xp = self.xp()
        S = sum(xp)
        D = S
        Ann = self.A * self.n
        while abs(D - Dprev) > 1:
            D_P = D
            for x in xp:
                D_P = D_P * D // (self.n * x)
            Dprev = D
            D = (Ann * S + D_P * self.n) * D // ((Ann - 1) * D + (self.n + 1) * D_P)

        return D

    def y(self, i, j, x):
        """
        Calculate x[j] if one makes x[i] = x
        Done by solving quadratic equation iteratively.
        x_1**2 + x1 * (sum' - (A*n**n - 1) * D / (A * n**n)) = D ** (n+1)/(n ** (2 * n) * prod' * A)
        x_1**2 + b*x_1 = c
        x_1 = (x_1**2 + c) / (2*x_1 + b)
        """
        D = self.D()
        xx = self.xp()
        xx[i] = x  # x is quantity of underlying asset brought to 1e18 precision
        xx = [xx[k] for k in range(self.n) if k != j]
        Ann = self.A * self.n
        c = D
        for y in xx:
            c = c * D // (y * self.n)
        c = c * D // (self.n * Ann)
        b = sum(xx) + D // Ann - D
        y_prev = 0
        y = D
        while abs(y - y_prev) > 1:
            y_prev = y
            y = (y ** 2 + c) // (2 * y + b)
        return y  # the result is in underlying units too

    def y_D(self, i, _D):
        """
        Calculate x[i] if the invariant D is changed to _D while the other balances stay fixed
        Done by solving quadratic equation iteratively.
        x_1**2 + x1 * (sum' - (A*n**n - 1) * D / (A * n**n)) = D ** (n+1)/(n ** (2 * n) * prod' * A)
        x_1**2 + b*x_1 = c
        x_1 = (x_1**2 + c) / (2*x_1 + b)
        """
        xx = self.xp()
        xx = [xx[k] for k in range(self.n) if k != i]
        S = sum(xx)
        Ann = self.A * self.n
        c = _D
        for y in xx:
            c = c * _D // (y * self.n)
        c = c * _D // (self.n * Ann)
        b = S + _D // Ann
        y_prev = 0
        y = _D
        while abs(y - y_prev) > 1:
            y_prev = y
            y = (y ** 2 + c) // (2 * y + b - _D)
        return y  # the result is in underlying units too

    def dy(self, i, j, dx):
        # dx and dy are in underlying units
        xp = self.xp()
        return xp[j] - self.y(i, j, xp[i] + dx)

    def exchange(self, i, j, dx):
        xp = self.xp()
        x = xp[i] + dx
        y = self.y(i, j, x)
        dy = xp[j] - y
        fee = dy * self.fee // 10 ** 10
        assert dy > 0
        self.x[i] = x * 10 ** 18 // self.p[i]
        self.x[j] = (y + fee) * 10 ** 18 // self.p[j]
        return dy - fee

    def remove_liquidity_imbalance(self, amounts):
        _fee = self.fee * self.n // (4 * (self.n - 1))

        old_balances = self.x
        new_balances = self.x[:]
        D0 = self.D()
        for i in range(self.n):
            new_balances[i] -= amounts[i]
        self.x = new_balances
        D1 = self.D()
        self.x = old_balances
        fees = [0] * self.n
        for i in range(self.n):
            ideal_balance = D1 * old_balances[i] // D0
            difference = abs(ideal_balance - new_balances[i])
            fees[i] = _fee * difference // 10 ** 10
            new_balances[i] -= fees[i]
        self.x = new_balances
        D2 = self.D()
        self.x = old_balances

        token_amount = (D0 - D2) * self.tokens // D0

        return token_amount

    def calc_withdraw_one_coin(self, token_amount, i):
        xp = self.xp()
        if self.fee:
            fee = self.fee - self.fee * xp[i] // sum(xp) + 5 * 10 ** 5
        else:
            fee = 0

        D0 = self.D()
        D1 = D0 - token_amount * D0 // self.tokens
        dy = xp[i] - self.y_D(i, D1)

        return dy - dy * fee // 10 ** 10

In [61]:
pool = Curve(A=0.5, D=100, n=2)
pool.x

[50, 50]

In [62]:
pool.exchange(0, 1, 5)
pool.x

[55, 45.0]
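
To connect this math back to the price-impact objective, here is a standalone sketch that condenses the `D` and `y` logic into two functions and reports how much of token 1 a fixed trade returns as amplification grows. It ignores fees and price scaling, takes `ann` to mean the quantity the class calls `Ann`, and adds an iteration cap that the copied code does not have; `get_d`, `get_y`, and `swap_output` are hypothetical names.

```python
from typing import List

MAX_ITER = 255  # ASSUMPTION: cap on Newton steps, not present in the copied class


def get_d(balances: List[int], ann: int) -> int:
    """Integer Newton iteration for the invariant D (same update as `Curve.D`)."""
    n = len(balances)
    s = sum(balances)
    d = s
    for _ in range(MAX_ITER):
        d_p = d
        for x in balances:
            d_p = d_p * d // (n * x)
        d_prev = d
        d = (ann * s + d_p * n) * d // ((ann - 1) * d + (n + 1) * d_p)
        if abs(d - d_prev) <= 1:
            return d
    raise RuntimeError("D did not converge")


def get_y(balances: List[int], ann: int, i: int, j: int, new_x_i: int) -> int:
    """Balance of token j that keeps D unchanged once token i holds `new_x_i`."""
    n = len(balances)
    d = get_d(balances, ann)
    others = [new_x_i if k == i else balances[k] for k in range(n) if k != j]
    c = d
    for x in others:
        c = c * d // (x * n)
    c = c * d // (n * ann)
    b = sum(others) + d // ann - d
    y = d
    for _ in range(MAX_ITER):
        y_prev = y
        y = (y * y + c) // (2 * y + b)
        if abs(y - y_prev) <= 1:
            return y
    raise RuntimeError("y did not converge")


def swap_output(balances: List[int], ann: int, i: int, j: int, dx: int) -> int:
    """Amount of token j received for `dx` of token i, before fees."""
    return balances[j] - get_y(balances, ann, i, j, balances[i] + dx)


unit = 10 ** 18
balances = [1_000_000 * unit, 1_000_000 * unit]
dx = 100_000 * unit  # trade 10% of one side of the pool
for ann in (2, 20, 2000):
    dy = swap_output(balances, ann, 0, 1, dx)
    print(f"ann={ann:>4}  dy/dx={dy / dx:.4f}")
```

Sweeping `dx` instead of `ann` would give the delta_reserves vs. price curve for a fixed pool. Working in integers with 18-decimal units, as the copied class intends, also avoids the float balance (`45.0`) that appears above when `A=0.5` is passed.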