In [1]:
import torch
from torch.utils.cpp_extension import load

In [3]:
# Start from a clean build tree, then recreate one output directory per
# extension; these paths are the build_directory values passed to load() below.
!rm -rf build
!mkdir -p build/diffpms build/diffpms_cuda

In [4]:
# Install Ninja (presumably the build backend needed by the cpp_extension JIT
# builds below — confirm).
# NOTE(review): the version is unpinned; where the hosting IPython supports it,
# `%pip install` targets the running kernel's environment.
!pip install Ninja



In [5]:
# Install the GCC/G++ 8 toolchain non-interactively (-y).
# NOTE(review): presumably chosen for compatibility with the CUDA 11.0 nvcc
# reported further down — confirm.
!apt install gcc-8 g++-8 -y

Reading package lists... Done
Building dependency tree       
Reading state information... Done
g++-8 is already the newest version (8.4.0-1ubuntu1~18.04).
gcc-8 is already the newest version (8.4.0-1ubuntu1~18.04).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.


In [6]:
# Register gcc-8 / g++-8 as alternatives for /usr/bin/gcc and /usr/bin/g++
# with priority 1000 so that plain `gcc` / `g++` resolve to version 8.
!update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 1000
!update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 1000

In [7]:
# Record the host compiler and CUDA compiler versions used for the builds below.
!gcc --version
!g++ --version
!nvcc --version

gcc (Ubuntu 8.4.0-1ubuntu1~18.04) 8.4.0
Copyright (C) 2018 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

g++ (Ubuntu 8.4.0-1ubuntu1~18.04) 8.4.0
Copyright (C) 2018 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Wed_Jul_22_19:09:09_PDT_2020
Cuda compilation tools, release 11.0, V11.0.221
Build cuda_11.0_bu.TC445_37.28845127_0


In [8]:
# Fetch a fresh checkout of the noa repository; the extension sources compiled
# below are read from noa/docs/pms and the headers from noa/include.
# NOTE(review): the clone is not pinned to a commit or tag, so results depend on
# the state of the default branch at run time.
!rm -rf noa
!git clone https://github.com/grinisrit/noa.git

Cloning into 'noa'...
remote: Enumerating objects: 1899, done.[K
remote: Counting objects: 100% (1050/1050), done.[K
remote: Compressing objects: 100% (628/628), done.[K
remote: Total 1899 (delta 631), reused 738 (delta 350), pack-reused 849[K
Receiving objects: 100% (1899/1899), 7.10 MiB | 14.81 MiB/s, done.
Resolving deltas: 100% (1120/1120), done.


In [9]:
# JIT-compile the CPU extension from the cloned sources and import it as `diffpms`.
diffpms = load(
    name='diffpms',
    sources=['noa/docs/pms/diffpms.cc'],
    extra_include_paths=['noa/include'],
    extra_cflags=['-Wall -Wextra -Wpedantic -O3 -std=c++17'],
    build_directory='./build/diffpms',
    verbose=False,
)

In [10]:
# Input grid: 10000 evenly spaced kinetic energies (generated by linspace and
# then cast to float64), with recoil energies a fixed fraction of them.
kinetic_energies = torch.linspace(1e-3, 1e6, steps=10000).double()
recoil_energies = kinetic_energies * 0.0505

In [11]:
# Evaluate the compiled CPU extension on the energy grid and preview the first
# five values of the result.
brems = diffpms.bremsstrahlung(kinetic_energies, recoil_energies)
brems[:5]

tensor([3.5293e-04, 3.9395e-06, 4.0777e-06, 4.1341e-06, 4.1650e-06],
       dtype=torch.float64)

In [12]:
# JIT-compile the CUDA extension from the cloned sources and import it as
# `diffpms_cuda`; host and device compilers get their own flag lists.
diffpms_cuda = load(
    name='diffpms_cuda',
    sources=['noa/docs/pms/diffpms.cu'],
    extra_include_paths=['noa/include'],
    extra_cflags=['-Wall -Wextra -Wpedantic -O3 -std=c++17'],
    extra_cuda_cflags=['-std=c++17 --extended-lambda'],
    build_directory='./build/diffpms_cuda',
    verbose=False,
)

In [13]:
# Copy both input grids to the current CUDA device.
kinetic_energies_gpu, recoil_energies_gpu = (
    host_tensor.cuda() for host_tensor in (kinetic_energies, recoil_energies)
)

In [14]:
# Same evaluation through the CUDA build; preview the first five values so they
# can be compared by eye with the CPU result.
brems_gpu = diffpms_cuda.bremsstrahlung(kinetic_energies_gpu, recoil_energies_gpu)
brems_gpu[:5]

tensor([3.5293e-04, 3.9395e-06, 4.0777e-06, 4.1341e-06, 4.1650e-06],
       device='cuda:0', dtype=torch.float64)

The Numba example below is based on a [gist](https://gist.github.com/t-vi/2f4fe23a5b473b9dceb95b163378b4d5#file-pytorch-numba-py) by [Thomas Viehmann](https://gist.github.com/t-vi).

In [None]:
from numba import cuda, njit, prange
import numpy as np
import math
import torch
import ctypes

@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')
def cu_exp_matrix_mul(A, c, d, u, v, b, n, m):
    """CUDA kernel writing one output element v[bi, ni] per thread.

    v[bi, ni] = sum over mi in range(m) of
                exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]

    The x axis of the launch grid selects the batch index bi (bounded by b) and
    the y axis selects the row index ni (bounded by n). Threads whose indices
    fall outside those bounds write nothing.
    """
    bi = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    ni = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y

    if bi < b and ni < n:
        acc = 0
        for mi in range(m):
            exponent = -A[ni, mi] + c[bi, mi] + d[bi, ni]
            acc += math.exp(exponent) * u[bi, mi]
        v[bi, ni] = acc


@njit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')
def gnu_exp_matrix_mul(A, c, d, u, v, b, n, m):
    """Serial Numba-compiled CPU version; fills v in place.

    v[batch, row] = sum over col in range(m) of
                    exp(-A[row, col] + c[batch, col] + d[batch, row]) * u[batch, col]
    for every batch in range(b) and row in range(n).
    """
    for batch in range(b):
        for row in range(n):
            acc = 0
            for col in range(m):
                exponent = -A[row, col] + c[batch, col] + d[batch, row]
                acc += math.exp(exponent) * u[batch, col]
            v[batch, row] = acc

@njit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)', parallel=True)
def omp_exp_matrix_mul(A, c, d, u, v, b, n, m):
    """Numba-compiled CPU version built with parallel=True; fills v in place.

    v[batch, row] = sum over col in range(m) of
                    exp(-A[row, col] + c[batch, col] + d[batch, row]) * u[batch, col]
    Both outer loops are written with prange; the reduction over col is serial.
    """
    for batch in prange(b):
        for row in prange(n):
            acc = 0
            for col in range(m):
                exponent = -A[row, col] + c[batch, col] + d[batch, row]
                acc += math.exp(exponent) * u[batch, col]
            v[batch, row] = acc

def py_exp_matrix_mul(A, c, d, u, v, b, n, m):
    """Pure-Python reference implementation; fills v in place and returns None.

    v[batch, row] = sum over col in range(m) of
                    exp(-A[row, col] + c[batch, col] + d[batch, row]) * u[batch, col]
    for every batch in range(b) and row in range(n). The arrays are indexed
    with 2-D tuple subscripts (e.g. A[row, col]).
    """
    for batch in range(b):
        for row in range(n):
            # sum() starts from 0 and adds the terms in column order.
            v[batch, row] = sum(
                math.exp(-A[row, col] + c[batch, col] + d[batch, row]) * u[batch, col]
                for col in range(m)
            )

def get_devicendarray(t):
    """Wrap a CUDA float32 torch tensor as a Numba DeviceNDArray.

    The array is built on top of the tensor's own data pointer, with the
    tensor's size and its strides converted from elements to bytes, and is
    bound to torch's current CUDA stream.

    Raises:
        AssertionError: if t is not a 'torch.cuda.FloatTensor'.
    """
    assert t.type() == 'torch.cuda.FloatTensor'
    itemsize = 4  # bytes per float32 element
    context = cuda.cudadrv.devices.get_context(t.device.index)
    device_ptr = ctypes.c_ulong(t.data_ptr())
    memory = cuda.cudadrv.driver.MemoryPointer(context, device_ptr, t.numel() * itemsize)
    byte_strides = [stride * itemsize for stride in t.stride()]
    return cuda.cudadrv.devicearray.DeviceNDArray(
        t.size(),
        byte_strides,
        np.dtype('float32'),
        gpu_data=memory,
        stream=torch.cuda.current_stream().cuda_stream,
    )

def batch_expmat_product(A, c, d, u, omp=False, py=False):
    """Compute v[bi, ni] = sum_mi exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi].

    Uses the CUDA kernel when all four inputs are CUDA tensors, otherwise one
    of the CPU implementations.

    Args:
        A: 2-D tensor of size (n, m).
        c: 2-D tensor of size (b, m).
        d: 2-D tensor of size (b, n).
        u: 2-D tensor of size (b, m).
        omp: on the CPU path, use omp_exp_matrix_mul.
        py: on the CPU path, and only when omp is false, use py_exp_matrix_mul
            instead of gnu_exp_matrix_mul.

    Returns:
        A new tensor of size (b, n), allocated from u.

    Raises:
        AssertionError: if any input is not 2-D or the sizes are inconsistent.

    The compiled kernels are declared for float32 arrays only.
    """
    BLOCK = 32  # threads per block along each grid axis

    # Check ranks before reading A.size(1), so a 1-D input fails with the
    # intended assertion message rather than an IndexError.
    assert A.dim()==2 and c.dim()==2 and d.dim()==2 and u.dim()==2, "dimension mismatch"
    b = c.size(0)
    n = A.size(0)
    m = A.size(1)
    assert c.size(1)==m and d.size(0)==b and d.size(1)==n and u.size(0)==b and u.size(1)==m, "size mismatch"
    v = u.new(d.size()).zero_()

    if A.is_cuda and c.is_cuda and d.is_cuda and u.is_cuda:
        # An empty output needs no launch, and a zero-sized grid is not a valid
        # launch configuration.
        if b > 0 and n > 0:
            Ad, cd, dd, ud, vd = (get_devicendarray(x) for x in (A, c, d, u, v))
            # The kernel maps grid x to the batch index (< b) and grid y to the
            # row index (< n), so the grid must cover (b, n). m is only the
            # length of the reduction inside each thread.
            grid = ((b - 1) // BLOCK + 1, (n - 1) // BLOCK + 1)
            cu_exp_matrix_mul[grid, (BLOCK, BLOCK)](Ad, cd, dd, ud, vd, b, n, m)
    else:
        # NOTE(review): with mixed devices where u is on CUDA, v.cpu() is a
        # copy and the values written here would not reach the returned v —
        # confirm whether mixed-device calls need to be supported.
        Ad, cd, dd, ud, vd = (x.cpu().numpy() for x in (A, c, d, u, v))
        if omp:
            omp_exp_matrix_mul(Ad, cd, dd, ud, vd, b, n, m)
        elif py:
            py_exp_matrix_mul(Ad, cd, dd, ud, vd, b, n, m)
        else:
            gnu_exp_matrix_mul(Ad, cd, dd, ud, vd, b, n, m)
    return v


In [None]:
# Fixed seed so that a fresh Restart & Run All reproduces the same inputs.
torch.manual_seed(0)

# Problem sizes: b = batch size, n = rows of A, m = columns of A.
b, n, m = 100, 200, 300
A = torch.randn(n, m)
c = torch.randn(b, m)
d = torch.randn(b, n)
u = torch.randn(b, m)
t = torch.randn(b, n)  # not used by the cells shown here; kept in case later cells rely on it

In [None]:
# Reference result from the pure-Python implementation.
w_py = batch_expmat_product(A, c, d, u, py=True)
w_py

In [None]:
# Serial Numba build: mean absolute deviation from the pure-Python reference.
w_cpu = batch_expmat_product(A, c, d, u)
(w_py - w_cpu).abs().mean()

In [None]:
# Parallel Numba build: mean absolute deviation from the pure-Python reference.
w_cpu_omp = batch_expmat_product(A, c, d, u, omp=True)
(w_py - w_cpu_omp).abs().mean()

In [None]:
# CUDA kernel: move the inputs to the GPU, then report the mean absolute
# deviation from the pure-Python reference.
Acu, ccu, dcu, ucu = (host_tensor.cuda() for host_tensor in (A, c, d, u))
w_gpu = batch_expmat_product(Acu, ccu, dcu, ucu)
(w_py.cuda() - w_gpu).abs().mean()

In [None]:
%timeit batch_expmat_product(A,c,d,u, py=True)

In [None]:
%timeit batch_expmat_product(A,c,d,u)

In [None]:
%timeit batch_expmat_product(A,c,d,u, omp=True)

In [None]:
%timeit batch_expmat_product(Acu,ccu,dcu,ucu)