In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

In [3]:
def mlmc_gbm(l,Nl,T=1,M=2):
    '''
    Level-l MLMC sample sums for the payoff computed by ``calc_P_l``.

    Draws Brownian increments for ``Nl`` independent paths on a grid of
    ``M**l`` steps over [0, T] and evaluates ``calc_P_l`` on them. For l > 0
    the same increments are summed in consecutive groups of ``M`` to drive a
    coarse path with ``M**(l-1)`` steps, so fine and coarse payoffs are
    coupled and their difference has small variance.

    Parameters
    ----------
    l : int
        Level index (0 = coarsest, one time step).
    Nl : int
        Number of samples; float counts such as 1e4 are accepted.
    T : float, default 1
        Time horizon; also forwarded to ``calc_P_l`` for discounting.
    M : int, default 2
        Refinement factor between consecutive levels.

    Returns
    -------
    ndarray, shape (4,)
        ``[sum(Y), sum(Y**2), sum(P_l), sum(P_l**2)]`` where
        ``Y = P_l - P_{l-1}`` for l > 0 and ``Y = P_0`` for l == 0.
        All zeros when ``Nl == 0`` so the result can always be added to a
        running total.
    '''
    Nl = int(Nl)
    if Nl == 0:
        # No samples requested: contribute nothing to the accumulated sums.
        return np.zeros(4)

    Nsteps = M**l
    dt = T / Nsteps
    # Brownian increments, each ~ N(0, dt); one row per step, one column per path.
    dW = np.random.randn(Nsteps, Nl) * np.sqrt(dt)
    P_fine = calc_P_l(dW, dt, T=T)
    fine_sums = np.array([P_fine.sum(), (P_fine**2).sum()])

    if l == 0:
        # No coarser level: the correction is the payoff itself.
        return np.concatenate([fine_sums, fine_sums])

    # Coarse increment i = sum of fine increments i*M .. i*M+M-1. The reshape-sum
    # produces a new array instead of writing into dW through a view.
    dW_coarse = dW.reshape(Nsteps // M, M, Nl).sum(axis=1)
    P_coarse = calc_P_l(dW_coarse, M * dt, T=T)
    dP = P_fine - P_coarse
    return np.array([dP.sum(), (dP**2).sum(), fine_sums[0], fine_sums[1]])
    
def calc_P_l(W,dt,X0=100,T=1,K=0):
    '''
    Discounted payoff of Euler-Maruyama GBM paths driven by given increments.

    Integrates dX = r*X*dt + sig*X*dW with the explicit Euler scheme from X0
    on every path, then returns exp(-r*T) * max(X_T - K, 0).

    ``r`` (drift and discount rate) and ``sig`` (volatility) are read from
    module-level globals; they are not defined in this function.

    Inputs:
        W  : array, shape (Nsteps, Nl) - Brownian increments, one row per
             time step and one column per path (callers pass N(0,1) draws
             already scaled by sqrt(dt)).
        dt : time step size.
        X0 : initial value of every path.
        T  : horizon used in the discount factor.
        K  : strike, default 0. K=0 reproduces the original behaviour
             (discounted X_T floored at 0).
    Outputs:
        array, shape (Nl,) - discounted payoff for each path.

    NOTE(review): the Black-Scholes cell in this notebook prices a call with
    S0 = K = 100, T = 1; with the default K=0 this returns e^{-rT} X_T, whose
    mean is about X0 rather than that call price - confirm the intended payoff.
    '''
    W = np.asarray(W, dtype=float)
    Nl = W.shape[1]
    X = np.full(Nl, X0, dtype=float)
    for dW in W:
        # One explicit Euler step for all paths at once.
        X += r*X*dt + sig*X*dW
    # Discount with the same rate r used for the drift (was hard-coded 0.05).
    return np.exp(-r*T) * np.maximum(X - K, 0.0)

def N_opt(V,eps,l,L,M):
    '''
    Optimal number of samples on level ``l`` for target RMS error ``eps``.

    Taking the cost of one sample on level k as proportional to ``M**k``,
    returns ``int(2/eps**2 * sum_k sqrt(V[k]*M**k) * sqrt(V[l]/M**l) + 1)``
    with k running over 0..L.

    Inputs:
        V   : per-level variance estimates, length L+1.
        eps : target root-mean-square error.
        l   : level whose sample count is wanted.
        L   : finest level currently in use.
        M   : refinement factor between levels.
    Outputs:
        int - suggested total sample count for level l.
    '''
    level_cost = M ** np.arange(L + 1)
    weighted_total = sum(np.sqrt(V * level_cost))
    scale = 2 / eps**2
    level_factor = np.sqrt(V[l] / M**l)
    return int(scale * weighted_total * level_factor + 1)
    
def _level_variances(sums, N):
    '''Per-level sample variance of the MLMC correction, clamped at 0 against round-off.'''
    mean = sums[0] / N
    return np.maximum(sums[1] / N - mean**2, 0.0)

def _bias_too_large(sums, N, L, M, eps):
    '''True when the weak-error test on levels L-1 and L fails (always True for L < 1).'''
    if L < 1:
        return True
    Y_prev = abs(sums[0, L - 1] / N[L - 1])
    Y_last = abs(sums[0, L] / N[L])
    return max(Y_prev / M, Y_last) > (M - 1) * eps / np.sqrt(2)

def mlmc(mlmc_fn,eps,N0=10**4,M=2,T=1,Lmin=3,Lmax=20):
    '''
    Adaptive multilevel Monte Carlo driver (Giles-style sample allocation).

    Starting with level 0, the loop repeatedly
      1. draws the outstanding samples ``dN[l]`` on every level,
      2. re-estimates the variance of the correction on each level,
      3. sets target sample counts with ``N_opt`` for RMS error ``eps``.
    Once no level needs more samples, a finer level (with ``N0`` initial
    samples) is added while ``L < Lmin`` or the bias test on the two finest
    levels fails. The loop stops when nothing more is needed.

    Parameters
    ----------
    mlmc_fn : callable
        Called as ``mlmc_fn(l, Nl, T=T, M=M)``. It must return the 4 sums
        ``[sum(Y), sum(Y**2), sum(P), sum(P**2)]`` over ``Nl`` fresh samples
        on level ``l``, where ``Y`` is the level-l correction (as ``mlmc_gbm``
        does).
    eps : float
        Target root-mean-square error.
    N0 : int, default 10**4
        Initial number of samples on each newly added level.
    M : int, default 2
        Refinement factor between consecutive levels.
    T : float, default 1
        Time horizon, forwarded to ``mlmc_fn``.
    Lmin : int, default 3
        The finest level is at least ``Lmin``.
    Lmax : int, default 20
        Cap on the finest level. A RuntimeWarning is issued and the loop stops
        if another level would be required beyond it.

    Returns
    -------
    sums : ndarray, shape (4, L+1)
        Accumulated sums, one column per level (rows as returned by mlmc_fn).
    N : ndarray of int, shape (L+1,)
        Samples taken on each level. The MLMC estimate is ``(sums[0] / N).sum()``.

    Raises
    ------
    ValueError
        If ``N0 < 1``.
    '''
    import warnings

    N0 = int(N0)
    if N0 < 1:
        raise ValueError("N0 must be at least 1")

    L = 0
    sums = np.zeros((4, 1))
    N = np.zeros(1, dtype=np.int64)       # samples taken so far, per level
    dN = np.array([N0], dtype=np.int64)   # samples still to take, per level

    while dN.sum() > 0:
        # 1. Take the outstanding samples. Keywords avoid mismatching the
        #    (l, Nl, T, M) parameter order of mlmc_gbm.
        for lev in np.flatnonzero(dN):
            sums[:, lev] += mlmc_fn(int(lev), int(dN[lev]), T=T, M=M)
        N += dN

        # 2./3. Variance estimates and the optimal allocation they imply.
        V = _level_variances(sums, N)
        N_target = np.array([N_opt(V, eps, lev, L, M) for lev in range(L + 1)],
                            dtype=np.int64)
        dN = np.maximum(N_target - N, 0)

        # Sample counts have settled: decide whether a finer level is needed.
        # L < Lmin is tested first so the bias test never sees a single level.
        if dN.sum() == 0 and (L < Lmin or _bias_too_large(sums, N, L, M, eps)):
            if L >= Lmax:
                warnings.warn("mlmc: stopped at Lmax=%d before convergence" % Lmax,
                              RuntimeWarning)
                break
            L += 1
            sums = np.hstack((sums, np.zeros((4, 1))))
            N = np.append(N, 0)
            dN = np.append(dN, N0)

    return sums, N
            

In [7]:
# Exact result according to the Black-Scholes formula for a European call.
# r and sig are module-level globals (also used by calc_P_l); they are not set in this cell.
def bs_call_price(S0, K, r, sig, T):
    '''Black-Scholes price of a European call on a non-dividend-paying asset.

    S0: spot, K: strike, r: risk-free rate, sig: volatility, T: maturity.
    '''
    D1 = (np.log(S0 / K) + (r + sig**2 / 2) * T) / (sig * np.sqrt(T))
    D2 = D1 - sig * np.sqrt(T)
    return S0 * norm.cdf(D1) - K * np.exp(-r * T) * norm.cdf(D2)

# At-the-money call, S0 = K = 100, one-year maturity.
print(bs_call_price(100, 100, r, sig, 1))


12.335998930368717
