# In [None]:
import torch
import torchvision.models as models 
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np 
from scipy.stats import random_correlation 
from scipy.stats import multivariate_normal as normal 
import datetime as dt
import matplotlib.pyplot as plt
import time
# --- Model / discretization parameters --------------------------------------
X_0_sc = 0             # initial value of every component of X
DIM = 1                # number of components of X
T = 1                  # time horizon
N = 100                # number of time steps on [0, T]
B_0_sc = 1             # baseline drift scalar
Sig_0_sc = 1           # baseline volatility scalar
# --- Monte Carlo sample counts ----------------------------------------------
M = 3_000_000          # size of the outer normal-sample pool (M_0)
M_w = 30_000           # inner sample count (M_1, with M_1 < M_0)
# PyTorch MPS (Metal) device plus the small buffers the run writes its results into.
mps_device = torch.device("mps")
V0_array = torch.zeros((1, 1), device="mps")      # receives func_v(XT_V)
W0_array = torch.zeros((1, 1, 2), device="mps")   # receives the two averaged w statistics
Time_array = torch.zeros((1, 1), device="mps")    # receives the elapsed run time in seconds
##################################### Fixed, uncorrelated case #########################################
# Baseline parameters as float32 NumPy arrays. For DIM > 1 the drift and the
# initial value are length-DIM vectors and the volatility is Sig_0_sc times the
# DIM x DIM identity (no cross-component correlation). The scalar case keeps
# all three as (1, 1) arrays.
if DIM != 1:
    B_0 = np.float32(np.ones(DIM) * B_0_sc)
    Sig_0 = np.float32(np.eye(DIM) * Sig_0_sc)
    X_0 = np.float32(np.ones(DIM) * X_0_sc)
else:
    B_0, Sig_0, X_0 = (
        np.float32(np.ones((1, 1)) * scalar) for scalar in (B_0_sc, Sig_0_sc, X_0_sc)
    )
################################     Torch variable def    ##############################
# Tensor factory calls inside this context default to the MPS device.
with torch.device(mps_device):
    # Copy the float32 NumPy parameters to the device. `.to()` is the documented
    # way to copy an existing tensor. The previous `torch.tensor(<Tensor>)` form
    # triggers PyTorch's copy-construct UserWarning.
    X_0_tor = torch.from_numpy(X_0).to(mps_device)
    B_0_tor = torch.from_numpy(B_0).to(mps_device)
    Sig_0_tor = torch.from_numpy(Sig_0).to(mps_device)
    # ##############################     Float increment     ##############################
    Dt_tor = torch.tensor(T/N, device="mps")         # time step dt = T/N
    DIM_tor = torch.tensor(DIM, device="mps")
    # ##############################     Random sample generation     ##############################
    # Time grid: index n in 0..N-1 denotes the left endpoint t_n = n*dt.
    T_arr_tor = torch.arange(1, N + 1) * Dt_tor      # (n+1)*dt; the last entry serves as T downstream
    R_T_arr_tor = torch.linspace(T, 0, N + 1)[:-1]   # remaining time T - t_n
    # Elapsed time t_n = n*dt. It must pair with R_T_arr_tor (T - t_n) so that
    # sqrt(t_n)*W' + sqrt(T - t_n)*W has variance exactly T. The previous
    # sqrt(T_arr_tor) = sqrt((n+1)*dt) was one step ahead and gave variance T + dt.
    Sqrt_T_arr_tor = torch.sqrt(torch.linspace(0, T, N + 1)[:-1])
    R_Sqrt_T_arr_tor = torch.sqrt(R_T_arr_tor)
    # Standard-normal draws of shape (DIM, M). They are drawn from the CPU
    # generator, as the old legacy `torch.FloatTensor(DIM, M).normal_()` did,
    # and then copied to the device.
    W_sample = torch.randn(DIM, M, device="cpu").to(mps_device)
    # Ones/zeros helpers used to broadcast parameters to full sample shapes.
    # NOTE(review): Dum_O1, Dum_Z1 and Dum_Z2 are not referenced in the code
    # visible here. Dum_O2 and Dum_O3 each hold M*N*DIM floats.
    Dum_O1, Dum_O2, Dum_O3, Dum_O4 = torch.ones([M,DIM,DIM]), torch.ones([M, N, DIM]), torch.ones([M, DIM, N]), torch.ones([N, DIM, DIM])
    Dum_Z1, Dum_Z2 = torch.zeros([DIM,DIM]), torch.zeros([M,DIM,DIM])
    # Accumulators updated by the run loop.
    V1_w = torch.zeros(1)
    V1_Jac_w = torch.zeros(1)
# Broadcast the per-time-step scale factors and the parameters to full sample tensors.
# Dum_O3 is torch.ones([M, DIM, N]). Multiplying it by a length-N vector and
# permuting (2, 0, 1) gives [N, M, DIM], with the time index first.
R_Sqrt_T_arr_tensor=torch.permute(Dum_O3*R_Sqrt_T_arr_tor, (2,0,1))
Sqrt_T_arr_tensor=torch.permute(Dum_O3*Sqrt_T_arr_tor, (2,0,1))
# Dum_O2 is torch.ones([M, N, DIM]), so X0_D and DR_D are X_0 and B_0 repeated
# over every sample and time step: shape [M, N, DIM].
X0_D = Dum_O2 * X_0_tor 
DR_D = Dum_O2 * B_0_tor
# matmul(Sig_0, W_sample) is [DIM, M]. Transposed to [M, DIM], it broadcasts
# against the [N, M, DIM] scale tensor, and the final permute (1, 0, 2) yields [M, N, DIM]:
#   DIF_D1[m, n, :] = R_Sqrt_T_arr_tor[n] * (Sig_0 @ W_sample[:, m])
#   DIF_D2[m, n, :] = Sqrt_T_arr_tor[n]   * (Sig_0 @ W_sample[:, m])
# Both tensors are built from the same W_sample columns.
# NOTE(review): each tensor in this block is materialized at M*N*DIM float32
# values (about 1.2 GB each with M=3e6, N=100, DIM=1). The code visible here only
# reads slices of them ([:, 0, :] and the first M_w samples), so consider
# building only those slices.
DIF_D1 = torch.permute(R_Sqrt_T_arr_tensor*torch.transpose(torch.matmul(Sig_0_tor,W_sample),0,1), (1,0,2))
DIF_D2 = torch.permute(Sqrt_T_arr_tensor*torch.transpose(torch.matmul(Sig_0_tor,W_sample),0,1), (1,0,2))
####################### Function for v^0 #######################
def func_v(X_T):
    """Monte Carlo estimate of E[(sum of the components of X_T)**4].

    Args:
        X_T: tensor of samples whose dim 1 is summed. As called in this file
            it is [M, DIM] (one row per sample).

    Returns:
        0-d tensor: the mean, over all remaining entries, of the fourth power
        of the dim-1 sums.
    """
    component_totals = X_T.sum(dim=1)
    return component_totals.pow(4).mean()
# Terminal samples for v^0: X_0 + B_0*T_arr_tor[-1] + DIF_D1 at time index 0.
# R_T_arr_tor[0] is T, so DIF_D1[:, 0, :] is sqrt(T) * (Sig_0 @ W_sample)^T.
XT_V = X0_D[:,0,:] + DR_D[:,0,:] * T_arr_tor[-1] + DIF_D1[:,0,:]    #### shape [M, DIM]: num of MC samples, dim of X
# NOTE(review): the run section recomputes func_v(XT_V) for V0_array, so this
# value is only kept for inspection.
MC_value=func_v(XT_V)
# Dum_O4 (ones [N, DIM, DIM]) @ Sig_0 gives [N, DIM, DIM]. Every row of each slice
# is the column sums of Sig_0. The permute moves the time axis last: [DIM, DIM, N].
# func_w scales this by a per-time-step factor.
J_w_const_mtx = torch.permute(torch.matmul(Dum_O4,Sig_0_tor),[1,2,0])
####################### Function for w and J_xw #######################
def func_w(X_T):
    """Per-time-step norms of the Monte Carlo w and Jacobian estimates.

    As called in this file, X_T has shape [paths, N, DIM]. With
    s = X_T summed over dim 2 (shape [paths, N]):

    * w[n]  = mean over paths of 4 * s**3, broadcast to [DIM, N] via Dum_O3[0];
    * Jw[n] = mean over paths of 12 * s**2, scaling J_w_const_mtx ([DIM, DIM, N]).

    Returns:
        (norm_w, norm_J_w_sig_0, norm_w summed, norm_J_w_sig_0 summed), where
        norm_w is the 2-norm over dim 0 of the [DIM, N] w tensor (shape [N]) and
        norm_J_w_sig_0 is the norm over dims (0, 1) of the [DIM, DIM, N]
        tensor (shape [N]).

    Reads the module-level tensors Dum_O3 and J_w_const_mtx.
    """
    s = X_T.sum(dim=2)
    # Jacobian factor: 12 * s**2 averaged over the paths, one value per step.
    jac_per_step = (12 * s.pow(2)).mean(dim=0)
    jac_matrix = J_w_const_mtx * jac_per_step                   # [DIM, DIM, N]
    # Gradient factor: 4 * s**3 averaged over the paths, copied to every component.
    grad_per_step = (4 * s.pow(3)).mean(dim=0)                  # [N]
    grad_by_component = Dum_O3[0, :, :] * grad_per_step         # [DIM, N]
    grad_norm = torch.linalg.norm(grad_by_component, dim=0)
    jac_norm = torch.linalg.norm(jac_matrix, dim=(0, 1))
    return grad_norm, jac_norm, grad_norm.sum(), jac_norm.sum()
# Inner-sample base shared by every outer step: X_0 + B_0*T + DIF_D1 for the
# first M_w samples, shape [M_w, N, DIM].
XT_W_FIX =  X0_D[:M_w,:,:] + DR_D[:M_w,:,:] * T_arr_tor[-1] + DIF_D1[:M_w,:,:]
####################### Run #######################
# time.perf_counter() is the monotonic clock intended for measuring intervals.
start = time.perf_counter()
V0_array[0,0] = func_v(XT_V)
pc_M = 0    # iterations completed since the last progress message
p = 1       # number of the next progress message
# NOTE(review): this outer loop runs M_w (not M) times. Because l < M_w, the
# outer draw DIF_D2[l] comes from a W_sample column that is also one of the
# inner draws in XT_W_FIX, so inner and outer samples are not independent.
# Offsetting the outer index (e.g. DIF_D2[M_w + l]) would avoid this; confirm intent.
for l in range(M_w):
    # [M_w, N, DIM] + [N, DIM] broadcasts over the inner samples. The former
    # `Dum_O2[:M_w,:,:] *` factor was all ones; it only allocated a fresh
    # [M_w, N, DIM] tensor each iteration and changed no value.
    V1_Temp = func_w(XT_W_FIX + DIF_D2[l,:,:])
    V1_w += V1_Temp[2]
    V1_Jac_w += V1_Temp[3]
    # Count this iteration before testing, so a message is printed after exactly
    # 10000, 20000, ... completed paths. The old check fired one path late.
    pc_M += 1
    if pc_M == 10000:
        print('%7.d MC-paths' %(p*10000))
        p += 1
        pc_M = 0
# V1_w and V1_Jac_w are 1-element device tensors. Concatenating them keeps the
# [2] result on the device instead of round-tripping each value through Python.
W0_array[0,0,:] = torch.cat([V1_w, V1_Jac_w]) * Dt_tor / M_w
# MPS kernels are queued asynchronously; wait for them so the timing covers the work.
torch.mps.synchronize()
end = time.perf_counter()
Time_array[0,0]= torch.tensor(end - start , device="mps")
####################### Save #######################
# Copy the device-side results to host memory as float32 NumPy arrays.
V0_array_np = V0_array[0].to(device='cpu', dtype=torch.float32).numpy()
W0_array_np = W0_array[0, 0, :].to(device='cpu', dtype=torch.float32).numpy()
Time_array_np = Time_array[0].to(device='cpu', dtype=torch.float32).numpy()
# (V0_array[0,0], W0_array[0,0,0], W0_array[0,0,1], elapsed seconds)
INFO_np = (V0_array_np[0], W0_array_np[0], W0_array_np[1], Time_array_np[0])
INFO_np