# CUDA 测试代码

In [4]:
!pip install numba -i https://mirrors.ustc.edu.cn/pypi/simple
from numba import cuda, njit
import numpy as np

# Check whether a usable CUDA GPU is available. The resulting flag selects
# between the GPU kernel path and the CPU fallback further down.
cuda_available = cuda.is_available()
print("CUDA is available: ", cuda_available)


def add_kernel_cpu(io_array, value):
    """Add ``value`` to every element of ``io_array`` in place (CPU fallback).

    Mutates the caller's NumPy array and returns ``None``. The update is a
    single vectorized ``np.add`` that writes straight back into
    ``io_array``. It therefore works for arrays of any shape, including
    empty arrays and non-contiguous views, without a per-element Python loop.

    Args:
        io_array: NumPy array to update in place.
        value: Scalar to add to every element.
    """
    # casting="unsafe" allows a float result to be stored back into an
    # integer array (truncating toward zero). This matches the element-wise
    # ``io_array[idx] += value`` semantics. NumPy's default "same_kind" rule
    # would raise instead.
    np.add(io_array, value, out=io_array, casting="unsafe")


# A simple element-wise CUDA kernel, defined only when a CUDA device was
# detected above.
if cuda_available:
    @cuda.jit
    def add_kernel_gpu(io_array, value):
        """Add ``value`` in place to the one element owned by this thread.

        Each thread uses its absolute 1-D grid index as the element index.
        Threads whose index lies past the end of ``io_array`` do nothing.
        """
        pos = cuda.grid(1)
        if pos >= io_array.size:
            # Out-of-range thread (the grid may be rounded up to whole blocks).
            return
        io_array[pos] += value

# Amount added to every element. Defined once so the kernel launch, the CPU
# fallback and the expected result below cannot drift apart.
ADD_VALUE = 10

# Create the input NumPy array.
data = np.arange(100, dtype=np.float32)
data_original = data.copy()  # keep an untouched copy for the check below

if cuda_available:
    # Allocate device memory and copy the host array to the GPU.
    data_gpu = cuda.to_device(data)

    # Launch configuration: ceiling division gives enough 256-thread blocks
    # to cover every element. Extra threads in the last block are skipped by
    # the kernel's `idx < io_array.size` bounds check.
    threadsperblock = 256
    blockspergrid = (data.size + (threadsperblock - 1)) // threadsperblock

    # Launch the kernel.
    add_kernel_gpu[blockspergrid, threadsperblock](data_gpu, ADD_VALUE)

    # Copy the modified data from the GPU back into the host array.
    data_gpu.copy_to_host(data)
else:
    # No GPU: compute on the CPU, updating `data` in place.
    add_kernel_cpu(data, ADD_VALUE)

print("Modified array (using CUDA when available): ", data)

# Verify the result against plain NumPy addition of the same value.
expected_result = data_original + ADD_VALUE
print("Expected result: ", expected_result)
assert np.allclose(data, expected_result), "The results do not match the expected output."

Looking in indexes: https://mirrors.ustc.edu.cn/pypi/simple

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
CUDA is available:  False
Modified array (using CUDA when available):  [ 10.  11.  12.  13.  14.  15.  16.  17.  18.  19.  20.  21.  22.  23.
  24.  25.  26.  27.  28.  29.  30.  31.  32.  33.  34.  35.  36.  37.
  38.  39.  40.  41.  42.  43.  44.  45.  46.  47.  48.  49.  50.  51.
  52.  53.  54.  55.  56.  57.  58.  59.  60.  61.  62.  63.  64.  65.
  66.  67.  68.  69.  70.  71.  72.  73.  74.  75.  76.  77.  78.  79.
  80.  81.  82.  83.  84.  85.  86.  87.  88.  89.  90.  91.  92.  93.
  94.  95.  96.  97.  98.  99. 100. 101. 102. 103. 104. 105. 106. 107.
 108. 109.]
Expected result:  [ 10.  11.  12.  13.  14.  15.  16.  17.  18.  19.  20.  21.  22.  23.
  24.  25.  26.  2