In [11]:
import numpy as np
from numba import njit,float64,prange
import numba
numba.config.NUMBA_DEFAULT_NUM_THREADS=4

$$
B_{i, j}=\sin \frac{i-0.5}{N} \pi \times \cos \frac{k}{N} \pi, C_{k, j}=e^{-\frac{k-0.5}{N}\times \pi} \times \sqrt{\frac{j(j-1)}{N^{2}}}
\\
A_{i, j}=\sum_{k=1}^{N} B_{i, k} \times C_{k, j},\quad N=200
$$

In [2]:
@njit
def b(i,k):
    return np.sin( np.pi*(i-0.5)/200 )*np.cos(np.pi*k/200)

@njit
def c(k,j):
    return np.exp(-(k-0.5)/200*np.pi)*np.sqrt(j*(j-1)/200**2)

@njit("float64[:](float64[:])")
def bubble_abs_sort(to_sort):
    len_sort = len(to_sort)
    for i in range(len_sort-1):
        for j in range(len_sort-1-i):
            if( np.abs(to_sort[j])>np.abs(to_sort[j+1]) ):
                to_sort[j], to_sort[j + 1] = to_sort[j + 1], to_sort[j]
    return to_sort

@njit
def abs_sum(to_sum):
    len_sum = len(to_sum)
    result = 0.
    for i in to_sum:
        result+=i
    return result

In [259]:
@njit(nogil=True, parallel=True)
def main():
    result_m = np.zeros((200,200))
    
    for i in prange(200):
        for j in prange(200):
            to_sum = bubble_abs_sort(np.array([b(i+1,k)*c(k,j+1) for k in np.arange(1,201,1)]))
            result_m[i,j] = abs_sum(to_sum)
    
    return result_m

In [263]:
%timeit result_m = main()
display(result_m)

738 ms ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


array([[0.        , 0.00182956, 0.00316889, ..., 0.25550399, 0.25679769,
        0.25809139],
       [0.        , 0.00548823, 0.0095059 , ..., 0.76644893, 0.77032971,
        0.77421048],
       [0.        , 0.00914555, 0.01584056, ..., 1.27720475, 1.28367166,
        1.29013856],
       ...,
       [0.        , 0.00914555, 0.01584056, ..., 1.27720475, 1.28367166,
        1.29013856],
       [0.        , 0.00548823, 0.0095059 , ..., 0.76644893, 0.77032971,
        0.77421048],
       [0.        , 0.00182956, 0.00316889, ..., 0.25550399, 0.25679769,
        0.25809139]])

In [264]:
result_m[0,1]

0.0018295613429145507

本程序计算:0.0018295613429145507

直接计算：0.0018295613429145496

mma直接计算：0.00182956134291455

mma排序计算:0.001829561342914551

pytorch计算:0.0018295613429145

python默认的是17位小数的精度

高精度使用decimal模块，配合getcontext

利用float.as_integer_ratio()可以分数形式精确表示小数