# User-Defined Kernels

## Elementwise Kernels: Basic Usage

```python
import numpy as np
import cupy as cp
import cv2
%matplotlib inline
import matplotlib.pyplot as plt
```

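An elementwise kernel is defined by four parts: (1) the input parameter list, (2) the output parameter list, (3) the code of the loop body, and (4) the kernel name. For example, a kernel that computes $f(x, y) = (x - y)^2$ is defined as follows: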

```python
squared_diff = cp.ElementwiseKernel(
    'float32 x, float32 y',    # input parameters
    'float32 z',               # output parameter
    'z = (x - y) * (x - y);',  # loop body, executed once per element
    'squared_diff'             # kernel name
)
```

```python
x = cp.arange(1, 10).astype('f')
y = cp.arange(11, 20).astype('f')
res = squared_diff(x, y)
print('squared diff:', res)
```

```
squared diff: [100. 100. 100. 100. 100. 100. 100. 100. 100.]
```
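Conceptually, the kernel runs its loop body once per output element, with one GPU thread per index. The sketch below models that on the CPU with a sequential Python loop so it runs without a GPU. `run_elementwise` is a hypothetical helper, not part of CuPy, and it assumes both inputs already have the same shape (no broadcasting).

```python
import numpy as np

def run_elementwise(body, x, y):
    """Hypothetical CPU model of a two-input, one-output elementwise kernel.

    `body` plays the role of the kernel's operation string: it receives the
    i-th input values and returns the i-th output value.
    Assumption: x and y have identical shapes.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.empty(x.shape, dtype=np.result_type(x, y))
    # On the GPU each index i is handled by its own thread; here we loop.
    for i in range(x.size):
        z.flat[i] = body(x.flat[i], y.flat[i])
    return z

x = np.arange(1, 10, dtype=np.float32)
y = np.arange(11, 20, dtype=np.float32)
z = run_elementwise(lambda a, b: (a - b) * (a - b), x, y)
print(z)  # nine values of 100.0, matching the CuPy output
```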


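Inputs are broadcast following NumPy's rules, so arrays of different shapes can be combined; here a `(2, 5)` array is paired with a `(5,)` array: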

```python
x = cp.arange(10, dtype=np.float32).reshape(2, 5)
y = cp.arange(5, dtype=np.float32)
res = squared_diff(x, y)
print(res)
```

```
[[ 0.  0.  0.  0.  0.]
 [25. 25. 25. 25. 25.]]
```
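The broadcast shape can be checked on the CPU with NumPy, which follows the same rules: shapes are aligned from the right, so `(2, 5)` and `(5,)` combine to `(2, 5)`. This is a NumPy sketch of the semantics, not CuPy code, and `np.broadcast_shapes` assumes NumPy 1.20 or later.

```python
import numpy as np

x = np.arange(10, dtype=np.float32).reshape(2, 5)
y = np.arange(5, dtype=np.float32)

# Shapes are aligned from the right: (2, 5) and (5,) -> (2, 5).
shape = np.broadcast_shapes(x.shape, y.shape)
print(shape)  # (2, 5)

# Broadcast views repeat y along the first axis; then apply the kernel formula.
xb, yb = np.broadcast_arrays(x, y)
z = (xb - yb) * (xb - yb)
print(z)  # first row all 0, second row all 25
```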


The output array can also be passed explicitly as an extra argument:

```python
z = cp.empty((2, 5), dtype=np.float32)
squared_diff(x, y, z)
print(z)
```

```
[[ 0.  0.  0.  0.  0.]
 [25. 25. 25. 25. 25.]]
```
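Passing a preallocated output is mainly useful for reusing one buffer across many calls. As a rough CPU analogue (NumPy ufuncs with `out=`, not CuPy's own mechanism), the same computation can be written into an existing array without allocating a new result:

```python
import numpy as np

x = np.arange(10, dtype=np.float32).reshape(2, 5)
y = np.arange(5, dtype=np.float32)

# Preallocate once, with the broadcast shape and the declared output dtype.
z = np.empty((2, 5), dtype=np.float32)

# x - y is written into z, then z is squared in place.
ret = np.subtract(x, y, out=z)
np.multiply(z, z, out=z)
print(ret is z)  # True: a NumPy ufunc returns the array passed as out
print(z)
```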


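A generic kernel uses a type placeholder such as `T` instead of a concrete type. `T` is resolved from the dtypes of the actual arguments, so the same kernel also works on integer arrays: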

```python
squared_diff_generic = cp.ElementwiseKernel(
    'T x, T y', 'T z', 'z = (x - y) * (x - y)',
    'squared_diff_generic')
```

```python
x = cp.arange(10).reshape(2, 5)
y = cp.arange(5)
res = squared_diff_generic(x, y)
print(res)
```

```
[[ 0  0  0  0  0]
 [25 25 25 25 25]]
```
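The idea behind a placeholder is that every parameter typed `T` shares one dtype, taken from the arguments. The following is a simplified, hypothetical model of that binding step (`bind_placeholders` is not a CuPy function). It assumes single uppercase letters are placeholders and simply rejects mismatched dtypes; CuPy's real resolution rules (for example for Python scalars) may differ.

```python
import numpy as np

def bind_placeholders(param_spec, *arrays):
    """Hypothetical helper: map each type placeholder in param_spec to a dtype.

    param_spec uses the kernel's 'type name, type name' form, e.g. 'T x, T y'.
    Assumption: a single uppercase letter is a placeholder; anything else
    (float32, int64, ...) is a concrete type and is left alone.
    """
    bindings = {}
    for spec, arr in zip(param_spec.split(','), arrays):
        type_name, _param_name = spec.split()
        if not (len(type_name) == 1 and type_name.isupper()):
            continue  # concrete type, nothing to resolve
        dtype = np.asarray(arr).dtype
        if type_name in bindings and bindings[type_name] != dtype:
            # All occurrences of one placeholder must agree on a single dtype.
            raise TypeError(f'{type_name} bound to both {bindings[type_name]} and {dtype}')
        bindings[type_name] = dtype
    return bindings

x = np.arange(10).reshape(2, 5)
y = np.arange(5)
print(bind_placeholders('T x, T y', x, y))  # T -> NumPy's default integer dtype
```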


## Manual Array Indexing

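Elements can be indexed manually by combining the `raw` keyword, the special variable `i`, and the method `_ind.size()`. A `raw` argument is not broadcast or indexed automatically, so the loop body must index it explicitly. `i` is the index of the current element in the loop, and `_ind.size()` is the total number of elements the loop processes, which is determined by the non-raw arguments (here `x`).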

```python
add_reverse = cp.ElementwiseKernel(
    'T x, raw T y', 'T z', 'z = x + y[_ind.size() - i - 1]',
    'add_reverse'
)
```

```python
x = cp.arange(1, 10)
y = cp.arange(11, 20)
res = add_reverse(x, y)
print(res)
```

```
[20 20 20 20 20 20 20 20 20]
```
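The indexing in `add_reverse` can be traced on the CPU. In this hypothetical model, `n` stands in for `_ind.size()` and the loop variable for `i`; `y` is indexed by hand, as a `raw` argument is. It assumes 1-D inputs of equal length.

```python
import numpy as np

def add_reverse_model(x, y):
    """CPU model of add_reverse: z[i] = x[i] + y[n - i - 1].

    n plays the role of _ind.size(); it comes from the non-raw argument x.
    Assumption: x and y are 1-D with the same length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = x.size
    z = np.empty_like(x)
    for i in range(n):
        z[i] = x[i] + y[n - i - 1]  # y is read back to front
    return z

x = np.arange(1, 10)
y = np.arange(11, 20)
print(add_reverse_model(x, y))  # every element is 20
print(x + y[::-1])              # same result with NumPy slicing
```

Pairing 1 with 19, 2 with 18, and so on is why every element of the result equals 20.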


## Reduction Kernel

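A reduction kernel is defined mainly by four parts:

1. Identity value: the initial value of the reduction.
2. Mapping expression: preprocesses each input element before it is reduced.
3. Reduction expression: combines two values, referred to by the special variables `a` and `b`.
4. Post-mapping expression: transforms the reduced value, referred to by the special variable `a`, and assigns it to the output.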
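The four parts fit together as map, then reduce starting from the identity, then post-map. The sequential sketch below is a hypothetical CPU model (`reduction_model` is not a CuPy API). On the GPU, `a` and `b` can both be partial results combined in an unspecified order, so the reduction expression should be associative and the identity value neutral for it.

```python
import math

def reduction_model(values, identity, map_expr, reduce_expr, post_map_expr):
    """Hypothetical sequential model of a reduction kernel's four parts."""
    a = identity                      # 1. identity value
    for x in values:
        b = map_expr(x)               # 2. mapping expression
        a = reduce_expr(a, b)         # 3. reduction expression on a and b
    return post_map_expr(a)           # 4. post-mapping expression on a

row = [0.0, 1.0, 2.0, 3.0, 4.0]
norm = reduction_model(row, 0.0, lambda x: x * x, lambda a, b: a + b, math.sqrt)
print(norm)  # sqrt(30), about 5.4772
```

With an empty input the result is the post-mapped identity, here `sqrt(0) = 0.0`.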

```python
l2norm_kernel = cp.ReductionKernel(
    'T x',          # input parameter
    'T y',          # output parameter
    'x * x',        # mapping expression
    'a + b',        # reduction expression
    'y = sqrt(a)',  # post-mapping expression
    '0',            # identity value
    'l2norm'        # kernel name
)
```

```python
x = cp.arange(10, dtype=np.float32).reshape(2, 5)
res = l2norm_kernel(x, axis=1)
print(res)
```

```
[ 5.477226  15.9687195]
```
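The result can be cross-checked on the CPU with NumPy: map (`x * x`), reduce (`a + b`) along axis 1, then post-map (`sqrt(a)`). The two rows give $\sqrt{0+1+4+9+16} = \sqrt{30}$ and $\sqrt{25+36+49+64+81} = \sqrt{255}$. This sketch uses only NumPy and compares against `np.linalg.norm`.

```python
import numpy as np

x = np.arange(10, dtype=np.float32).reshape(2, 5)

# map x * x -> reduce a + b over axis 1 -> post-map sqrt(a)
manual = np.sqrt((x * x).sum(axis=1))
reference = np.linalg.norm(x, axis=1)
print(manual)                          # about [5.477, 15.969]
print(np.allclose(manual, reference))  # True
```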
