# Composing Contracts

Toying with ideas from [Composing Contracts:An Adventure in Financial Engineering](https://www.cs.tufts.edu/~nr/cs257/archive/simon-peyton-jones/contracts.pdf) by Jones, Eber and Seward.

In [1]:
from typing import TypeVar, Generic
from dataclasses import dataclass, asdict, astuple
from abc import ABC, abstractmethod
from datetime import datetime

In [2]:
class Observable(ABC):
    pass

In [3]:
@dataclass
class Konst(Observable):
    constant: float

@dataclass
class Stock(Observable):
    ticker: str

In [4]:
T = TypeVar('T')

class ObservableVisitor(ABC, Generic[T]):
    
    @abstractmethod
    def konst(self, constant: float) -> T: pass

    def stock(self, ticker: str) -> T: pass
    
    def __call__(self, observable: Observable) -> T:
        if isinstance(observable, Konst):
            return self.konst(observable.constant)
        elif isinstance(observable, Stock):
            return self.stock(observable.ticker)
        else:
            raise TypeError(f'Unknown observable type "{type(contract).__name__}"')

In [5]:
class Contract(ABC):
    pass

In [6]:
@dataclass
class Zero(Contract):
    pass

@dataclass
class One(Contract):
    currency: str

@dataclass
class Give(Contract):
    contract: Contract

@dataclass
class And(Contract):
    contract1: Contract
    contract2: Contract

@dataclass
class Or(Contract):
    contract1: Contract
    contract2: Contract

@dataclass
class Truncate(Contract):
    horizon: datetime
    contract: Contract

@dataclass
class Then(Contract):
    contract1: Contract
    contract2: Contract

@dataclass
class Scale(Contract):
    observable: Observable
    contract: Contract

@dataclass
class Get(Contract):
    contract: Contract

@dataclass
class Anytime(Contract):
    contract: Contract

In [7]:
T = TypeVar('T')

class ContractVisitor(ABC, Generic[T]):

    @abstractmethod
    def zero(self) -> T: pass

    @abstractmethod
    def one(self, currency: str) -> T: pass

    @abstractmethod
    def give(self, contract: Contract) -> T: pass

    @abstractmethod
    def and_(self, contract1: Contract, contract2: Contract) -> T: pass

    @abstractmethod
    def or_(self, contract1: Contract, contract2: Contract) -> T: pass

    @abstractmethod
    def truncate(self, horizon: datetime, contract: Contract) -> T: pass

    @abstractmethod
    def then(self, contract1: Contract, contract2: Contract) -> T: pass

    @abstractmethod
    def scale(self, observable: Observable, contract: Contract) -> T: pass

    @abstractmethod
    def get(self, contract: Contract) -> T: pass

    @abstractmethod
    def anytime(self, contract: Contract) -> T: pass

    def __call__(self, contract: Contract) -> T:
        if isinstance(contract, Zero):
            return self.zero()
        elif isinstance(contract, One):
            return self.one(**contract.__dict__)
        elif isinstance(contract, Give):
            return self.give(**contract.__dict__)
        elif isinstance(contract, And):
            return self.and_(**contract.__dict__)
        elif isinstance(contract, Or):
            return self.or_(**contract.__dict__)
        elif isinstance(contract, Truncate):
            return self.truncate(**contract.__dict__)
        elif isinstance(contract, Then):
            return self.then(**contract.__dict__)
        elif isinstance(contract, Scale):
            return self.scale(**contract.__dict__)
        elif isinstance(contract, Get):
            return self.get(**contract.__dict__)
        elif isinstance(contract, Anytime):
            return self.anytime(**contract.__dict__)
        else:
            raise TypeError(f'Unknown contract type "{type(contract).__name__}"')

In [8]:
class ObservablePrinter(ObservableVisitor[str]):
    def konst(self, constant: float) -> str: return str(constant)
    def stock(self, ticker: str) -> str: return f'"{ticker}"'

In [9]:
class ContractPrinter(ContractVisitor[str]):
    
    def __init__(self, observable_visitor: ObservableVisitor) -> None:
        self.observable_visitor = observable_visitor
    
    def zero(self) -> str: return 'zero'
    def one(self, currency: str) -> str: return f'one {currency}'
    def give(self, contract: Contract) -> str: return f'give ({self(contract)})'
    def and_(self, contract1: Contract, contract2: Contract) -> str: return f'and ({self(contract1)}) ({self(contract2)})'
    def or_(self, contract1: Contract, contract2: Contract) -> str: return f'or ({self(contract1)}) ({self(contract2)})'
    def truncate(self, horizon: datetime, contract: Contract) -> str: return f'truncate "{horizon}" ({self(contract)})'
    def then(self, contract1: Contract, contract2: Contract) -> str: return f'then ({self(contract1)}) ({self(contract2)})'
    def scale(self, observable: Observable, contract: Contract) -> str: return f'scale {self.observable_visitor(observable)} ({self(contract)})'
    def get(self, contract: Contract) -> str: return f'get ({self(contract)})'
    def anytime(self, contract: Contract) -> str: return f'anytime ({self(contract)})'

In [10]:
def scaleK(constant: float, contract: Contract):
    return Scale(Konst(constant), contract)

In [11]:
def zcb(maturity: datetime, notional: float, currency: str):
    return scaleK(notional, Get(Truncate(maturity, One(currency))))

In [12]:
def european(maturity: datetime, contract: Contract):
    return Get(Truncate(maturity, Or(contract, Zero())))

In [13]:
mybond1 = zcb(datetime(2030, 7, 14), 1000000, 'EUR')
mybond1

Scale(observable=Konst(constant=1000000), contract=Get(contract=Truncate(horizon=datetime.datetime(2030, 7, 14, 0, 0), contract=One(currency='EUR'))))

In [14]:
ContractPrinter(ObservablePrinter())(mybond1)

'scale 1000000 (get (truncate "2030-07-14 00:00:00" (one EUR)))'

In [15]:
asdict(mybond1)

{'observable': {'constant': 1000000},
 'contract': {'contract': {'horizon': datetime.datetime(2030, 7, 14, 0, 0),
   'contract': {'currency': 'EUR'}}}}

In [16]:
astuple(mybond1)

((1000000,), ((datetime.datetime(2030, 7, 14, 0, 0), ('EUR',)),))

In [17]:
dict(**mybond1.__dict__)

{'observable': Konst(constant=1000000),
 'contract': Get(contract=Truncate(horizon=datetime.datetime(2030, 7, 14, 0, 0), contract=One(currency='EUR')))}

In [18]:
def european_put(ticker: str, currency: str, maturity: datetime, strike: float):
    return Get(Truncate(maturity, Or(Scale(Stock(ticker), One(currency)), scaleK(strike, One(currency)))))

In [19]:
myput1 = european_put('ABC Eqty', 'USD', datetime(2030, 7, 14), 123.45)
myput1

Get(contract=Truncate(horizon=datetime.datetime(2030, 7, 14, 0, 0), contract=Or(contract1=Scale(observable=Stock(ticker='ABC Eqty'), contract=One(currency='USD')), contract2=Scale(observable=Konst(constant=123.45), contract=One(currency='USD')))))

In [20]:
ContractPrinter(ObservablePrinter())(myput1)

'get (truncate "2030-07-14 00:00:00" (or (scale "ABC Eqty" (one USD)) (scale 123.45 (one USD))))'

In [21]:
import numpy as np

In [22]:
class SimulationModel(ObservableVisitor[np.array]):
    def konst(self, constant: float) -> np.array: raise NotImplementedError('konst')
    def stock(self, ticker: str) -> np.array: raise NotImplementedError('stock')

In [23]:
class Evaluator(ContractVisitor[np.array]):
    
    def __init__(self, model: SimulationModel):
        self.model = model
    
    def zero(self) -> np.array: raise NotImplementedError('zero')
    def one(self, currency: str) -> np.array: raise NotImplementedError('one')
    def give(self, contract: Contract) -> np.array: raise NotImplementedError('give')
    def and_(self, contract1: Contract, contract2: Contract) -> np.array: raise NotImplementedError('and_')
    def or_(self, contract1: Contract, contract2: Contract) -> np.array: raise NotImplementedError('or_')
    def truncate(self, horizon: datetime, contract: Contract) -> np.array: raise NotImplementedError('truncate')
    def then(self, contract1: Contract, contract2: Contract) -> np.array: raise NotImplementedError('then')
    def scale(self, observable: Observable, contract: Contract) -> np.array: raise NotImplementedError('scale')
    def get(self, contract: Contract) -> np.array: raise NotImplementedError('get')
    def anytime(self, contract: Contract) -> np.array: raise NotImplementedError('anytime')

In [24]:
evaluator = Evaluator(SimulationModel())

In [25]:
evaluator(myput1)

NotImplementedError: get