# Composing Contracts

Toying with ideas from [Composing Contracts: An Adventure in Financial Engineering](https://www.cs.tufts.edu/~nr/cs257/archive/simon-peyton-jones/contracts.pdf) by Peyton Jones, Eber and Seward.

In [1]:
from typing import TypeVar, Generic
from dataclasses import dataclass, asdict, astuple, fields
from abc import ABC, abstractmethod
from datetime import datetime

In [2]:
class Observable(ABC):
    """Base class for observables: the quantities a contract is scaled by.

    Instances appear as the ``observable`` field of ``Scale`` (e.g. ``Konst``).
    No abstract interface is declared yet, so the class is still instantiable.
    """

In [3]:
@dataclass
class Konst(Observable):
    """Constant observable (the paper's ``konst``): always worth ``constant``."""

    # Scaling factor; the notebook passes ints such as 1000000 here as well.
    constant: float

In [4]:
class Contract(ABC):
    """Root of the contract combinator tree.

    Each concrete combinator is a dataclass subclass of this type, and
    ``ContractVisitor`` dispatches on those subclasses.
    """

In [5]:
@dataclass
class Zero(Contract):
    """The empty contract (the paper's ``zero``); it has no fields.

    Per the paper, it carries neither rights nor obligations.
    """

@dataclass
class One(Contract):
    """One unit of ``currency`` (the paper's ``one k``).

    Per the paper, acquiring it means immediately receiving one unit.
    """

    # Currency identifier, e.g. 'EUR'.
    currency: str

@dataclass
class Give(Contract):
    """The paper's ``give c``.

    Per the paper, it turns the rights of ``contract`` into obligations and
    vice versa.
    """

    # The contract whose sides are swapped.
    contract: Contract

@dataclass
class And(Contract):
    """Conjunction of two contracts (the paper's ``and``).

    Per the paper, both sub-contracts are acquired together.
    """

    # First conjunct.
    contract1: Contract
    # Second conjunct.
    contract2: Contract

@dataclass
class Or(Contract):
    """Choice between two contracts (the paper's ``or``).

    Per the paper, the holder immediately acquires exactly one of the two.
    """

    # First alternative.
    contract1: Contract
    # Second alternative.
    contract2: Contract

@dataclass
class Truncate(Contract):
    """The paper's ``truncate t c``.

    Per the paper, it behaves like ``contract`` but expires no later than
    ``horizon``.
    """

    # Cut-off date/time.
    horizon: datetime
    # The contract being truncated.
    contract: Contract

@dataclass
class Then(Contract):
    """The paper's ``then`` combinator.

    Per the paper, ``contract1`` is acquired if it has not yet expired;
    otherwise ``contract2`` is acquired.
    """

    # Preferred contract while it is still live.
    contract1: Contract
    # Fallback once contract1 has expired.
    contract2: Contract

@dataclass
class Scale(Contract):
    """The paper's ``scale o c``.

    Per the paper, the payments of ``contract`` are multiplied by the value
    of ``observable`` at the moment of acquisition.
    """

    # Scaling factor, e.g. a Konst.
    observable: Observable
    # The contract being scaled.
    contract: Contract

@dataclass
class Get(Contract):
    """The paper's ``get c``.

    Per the paper, ``contract`` is acquired at its horizon.
    """

    # The contract acquired at its horizon.
    contract: Contract

@dataclass
class Anytime(Contract):
    """The paper's ``anytime c``.

    Per the paper, the holder may acquire ``contract`` at any moment up to
    its horizon.
    """

    # The contract acquirable at the holder's discretion.
    contract: Contract

In [6]:
T = TypeVar('T')

class ContractVisitor(ABC, Generic[T]):
    """Fold over a ``Contract`` tree, producing a value of type ``T``.

    Subclasses implement one handler per contract constructor. Calling the
    visitor on a contract (``visitor(c)``) dispatches to the handler for
    ``c``'s constructor and passes the node's dataclass fields as keyword
    arguments. Handlers that need to descend into child contracts do so
    themselves, typically by calling ``self(child)``.
    """

    # Handler parameter names must match the dataclass field names, because
    # ``__call__`` passes the fields by keyword.

    @abstractmethod
    def zero(self) -> T: pass

    @abstractmethod
    def one(self, currency: str) -> T: pass

    @abstractmethod
    def give(self, contract: Contract) -> T: pass

    @abstractmethod
    def and_(self, contract1: Contract, contract2: Contract) -> T: pass

    @abstractmethod
    def or_(self, contract1: Contract, contract2: Contract) -> T: pass

    @abstractmethod
    def truncate(self, horizon: datetime, contract: Contract) -> T: pass

    @abstractmethod
    def then(self, contract1: Contract, contract2: Contract) -> T: pass

    @abstractmethod
    def scale(self, observable: Observable, contract: Contract) -> T: pass

    @abstractmethod
    def get(self, contract: Contract) -> T: pass

    @abstractmethod
    def anytime(self, contract: Contract) -> T: pass

    def __call__(self, contract: Contract) -> T:
        """Dispatch ``contract`` to the handler matching its constructor.

        Constructors are tested with ``isinstance`` in a fixed order. An
        instance of a subclass of a known constructor is therefore handled
        as that constructor.

        Raises:
            TypeError: if ``contract`` is not an instance of any known
                constructor.
        """
        handlers = (
            (Zero, self.zero),
            (One, self.one),
            (Give, self.give),
            (And, self.and_),
            (Or, self.or_),
            (Truncate, self.truncate),
            (Then, self.then),
            (Scale, self.scale),
            (Get, self.get),
            (Anytime, self.anytime),
        )
        for kind, handler in handlers:
            if isinstance(contract, kind):
                # Forward only the declared dataclass fields. Unpacking
                # ``contract.__dict__`` would also forward any ad-hoc
                # attribute set on the instance, making the handler fail
                # with an unexpected-keyword TypeError. It would also break
                # outright for ``slots=True`` dataclasses, which have no
                # ``__dict__``.
                kwargs = {f.name: getattr(contract, f.name) for f in fields(contract)}
                return handler(**kwargs)
        raise TypeError(f'Unknown contract type "{type(contract).__name__}"')

In [7]:
class ContractPrinter(ContractVisitor[str]):
    """Render a contract as a string in prefix notation.

    Each node is written as its combinator name followed by its arguments.
    Child contracts are rendered recursively and wrapped in parentheses,
    e.g. ``get (truncate "2030-07-14 00:00:00" (one EUR))``.

    Horizons and observables are interpolated with plain f-string
    formatting. For ``Konst`` this yields its dataclass repr.
    """

    def zero(self) -> str:
        return 'zero'

    def one(self, currency: str) -> str:
        return f'one {currency}'

    def give(self, contract: Contract) -> str:
        inner = self(contract)
        return f'give ({inner})'

    def and_(self, contract1: Contract, contract2: Contract) -> str:
        left, right = self(contract1), self(contract2)
        return f'and ({left}) ({right})'

    def or_(self, contract1: Contract, contract2: Contract) -> str:
        left, right = self(contract1), self(contract2)
        return f'or ({left}) ({right})'

    def truncate(self, horizon: datetime, contract: Contract) -> str:
        # Format the horizon before recursing, matching the left-to-right
        # evaluation of a single f-string.
        head = f'truncate "{horizon}"'
        return f'{head} ({self(contract)})'

    def then(self, contract1: Contract, contract2: Contract) -> str:
        left, right = self(contract1), self(contract2)
        return f'then ({left}) ({right})'

    def scale(self, observable: Observable, contract: Contract) -> str:
        # Format the observable before recursing, as a single f-string would.
        head = f'scale ({observable})'
        return f'{head} ({self(contract)})'

    def get(self, contract: Contract) -> str:
        inner = self(contract)
        return f'get ({inner})'

    def anytime(self, contract: Contract) -> str:
        inner = self(contract)
        return f'anytime ({inner})'

In [8]:
def scaleK(constant: float, contract: Contract):
    """Scale ``contract`` by a constant factor (the paper's ``scaleK``).

    This is shorthand for ``Scale(Konst(constant), contract)``. It lifts a
    plain number into a constant observable.
    """
    factor = Konst(constant=constant)
    return Scale(observable=factor, contract=contract)

In [9]:
def zcb(maturity: datetime, notional: float, currency: str):
    """Build the paper's zero-coupon bond ``zcb t x k``.

    The bond is composed from the inside out:

    - one unit of ``currency``;
    - its horizon cut off at ``maturity`` (``truncate``);
    - acquired at that horizon (``get``);
    - scaled by the constant ``notional`` (``scaleK``).

    Per the paper, this pays ``notional`` units of ``currency`` at
    ``maturity``.
    """
    unit = One(currency)
    until_maturity = Truncate(maturity, unit)
    at_maturity = Get(until_maturity)
    return scaleK(notional, at_maturity)

In [10]:
# A 1,000,000 EUR zero-coupon bond maturing on 14 July 2030.
mybond1 = zcb(maturity=datetime(2030, 7, 14), notional=1000000, currency='EUR')
mybond1

Scale(observable=Konst(constant=1000000), contract=Get(contract=Truncate(horizon=datetime.datetime(2030, 7, 14, 0, 0), contract=One(currency='EUR'))))

In [11]:
# Render the bond in the printer's prefix notation.
ContractPrinter()(contract=mybond1)

'scale (Konst(constant=1000000)) (get (truncate "2030-07-14 00:00:00" (one EUR)))'

In [12]:
# Recursive conversion: every nested dataclass becomes a plain dict and the
# class names (Konst, Get, Truncate, One) are lost.
asdict(mybond1, dict_factory=dict)

{'observable': {'constant': 1000000},
 'contract': {'contract': {'horizon': datetime.datetime(2030, 7, 14, 0, 0),
   'contract': {'currency': 'EUR'}}}}

In [13]:
# Recursive conversion to nested tuples of field values; field names are
# dropped as well as class names.
astuple(mybond1, tuple_factory=tuple)

((1000000,), ((datetime.datetime(2030, 7, 14, 0, 0), ('EUR',)),))

In [14]:
# Shallow copy of the instance attributes: unlike asdict, the nested
# observable and contract objects are kept as-is.
dict(vars(mybond1))

{'observable': Konst(constant=1000000),
 'contract': Get(contract=Truncate(horizon=datetime.datetime(2030, 7, 14, 0, 0), contract=One(currency='EUR')))}