# Padé appproximation

In [1]:
import logging

import numpy

import cicada
from cicada.additive import AdditiveProtocolSuite
from cicada.communicator import SocketCommunicator

from statistics import mean, stdev
from tqdm import *

logging.basicConfig(level=logging.INFO)

def main(communicator):
    log = cicada.Logger(logging.getLogger(), communicator)
    protocol = AdditiveProtocolSuite(communicator)

    values = numpy.array(-1)
    log.info(f"Player {communicator.rank} values: {values}")

    values_share = protocol.share(src=0, secret=values, shape=())
    pade_tanh_share = protocol._pade_approx(numpy.tanh, values_share)
    pade_tanh = pade_tanh_share
    log.info(f"Player {communicator.rank} pade_tanh({values}): {pade_tanh} vs tanh({values}): {numpy.tanh(values)}")
    abs_err_list = []
    for i in numpy.arange(-1, 1, .1):
        cic_pade_tanh = protocol._pade_approx(numpy.tanh, protocol.share(src=0, secret=numpy.array(i), shape=()))
        legit_tanh = numpy.tanh(i)
        abs_err_list.append((legit_tanh-cic_pade_tanh))
        if communicator.rank ==1:
            print(legit_tanh, cic_pade_tanh, legit_tanh-cic_pade_tanh)
    

    
    if communicator.rank ==1:
        print(mean(abs_err_list), stdev(abs_err_list))

SocketCommunicator.run(world_size=3, fn=main);

INFO:root:Player 0 values: -1
INFO:root:Player 1 values: -1
INFO:root:Player 2 values: -1
INFO:root:Player 0 pade_tanh(-1): -0.7627071286262695 vs tanh(-1): -0.7615941559557649
INFO:root:Player 1 pade_tanh(-1): -0.7627071286262695 vs tanh(-1): -0.7615941559557649
INFO:root:Player 2 pade_tanh(-1): -0.7627071286262695 vs tanh(-1): -0.7615941559557649


-0.7615941559557649 -0.7627071286262695 0.0011129726705046972
-0.7162978701990245 -0.803611258629846 0.08731338843082148
-0.664036770267849 -0.8452953030537901 0.18125853278594106
-0.6043677771171635 -0.8882166310042057 0.2838488538870422
-0.5370495669980353 -0.9362973760932944 0.39924780909525914
-0.4621171572600098-1.0002931691586046  0.5381760118985948
-0.379948962255225 -1.0926885434657456 0.7127395812105206
-0.29131261245159107 -1.18698347107438 0.8956708586227891
-0.19737532022490417 -1.144215530903328 0.9468402106784237
-0.09966799462495601 -0.780439121756487 0.680771127131531
-2.220446049250313e-16 -0.13814432989690723 0.138144329896907
0.09966799462495547 0.4903339191564148 -0.3906659245314593
0.19737532022490373 0.8638613861386139 -0.6664860659137102
0.29131261245159074 0.9911045218680504 -0.6997919094164597
0.3799489622552246 0.998417095370004 -0.6184681331147794
0.4621171572600094 0.9700923323966278 -0.5079751751366184
0.537049566998035 0.9323926504923358 -0.395343083494300

In [2]:
from scipy.interpolate import approximate_taylor_polynomial, pade
degree = 9
func_taylor = approximate_taylor_polynomial(numpy.tanh, 0, degree, 1)
paden, paded = pade([x for x in func_taylor][::-1], degree//2, degree%2+degree//2)
print(paden(1)/paded(1), func_taylor(1), numpy.tanh(1))

0.7623234211429005 0.7615941559557651 0.7615941559557649


In [3]:
for i in tqdm(numpy.arange(-1, 1, .1)):
    degree = 9
    func_taylor = approximate_taylor_polynomial(numpy.tanh, 0, degree, 5)
    paden, paded = pade([x for x in func_taylor][::-1], degree//2, degree%2+degree//2)
    print(paden(i)/paded(i), func_taylor(i), numpy.tanh(i), numpy.tanh(i)-paden(i)/paded(i))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 2507.21it/s]

-0.69183672695528 -0.6918365835570223 -0.7615941559557649 -0.06975742900048487
-0.636635354665084 -0.6366353089798658 -0.7162978701990245 -0.07966251553394044
-0.5773159842445603 -0.5773159715690709 -0.664036770267849 -0.0867207860232887
-0.5141746063369771 -0.5141746033842407 -0.6043677771171635 -0.0901931707801864
-0.44756121347358013 -0.4475612129262017 -0.5370495669980353 -0.08948835352445517
-0.37787491413125773 -0.37787491405694207 -0.4621171572600098 -0.08424224312875206
-0.3055583236432489 -0.3055583236368196 -0.379948962255225 -0.07439063861197609
-0.23109131870020064 -0.2310913186999276 -0.29131261245159107 -0.06022129375139043
-0.15498425028761556 -0.15498425028761237 -0.19737532022490417 -0.04239106993728861
-0.07777071630940041 -0.0777707163094004 -0.09966799462495601 -0.0218972783155556
-1.3191618985888415e-16 -1.3191618985888415e-16 -2.220446049250313e-16 -9.012841506614716e-17
0.07777071630940006 0.07777071630940005 0.09966799462495547 0.021897278315555407
0.15498425028


