In [1]:
%%writefile MCTS.cu

Writing MCTS.cu


In [7]:
%%writefile MCTS.cu
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#define NUM_SIMULATIONS 1024
#define MAX_DEPTH 100

// Evaluates a CUDA runtime call exactly once and aborts the process if it
// did not return cudaSuccess, printing the file/line of the call site and the
// runtime's error string to stderr.
//
// The status is stored in a deliberately unusual local name so that a call
// expression mentioning the caller's own `err` variable is not shadowed by
// the macro's temporary. The do/while(0) wrapper makes the macro behave as a
// single statement (safe inside un-braced if/else).
#define CUDA_CHECK(call) do {                                         \
    const cudaError_t cuda_check_err_ = (call);                       \
    if (cuda_check_err_ != cudaSuccess) {                             \
        fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__, \
                cudaGetErrorString(cuda_check_err_));                 \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Toy game position used by the rollouts.
//
// Fields:
//   moves       - candidate actions; only the first `num_moves` entries are
//                 considered by get_random_action.
//   num_moves   - number of valid entries in `moves`.
//   is_terminal - set by next_state once the reward leaves [-10, 10].
//   reward      - running score accumulated by next_state.
struct GameState {
    int moves[10];
    int num_moves;
    bool is_terminal;
    float reward;

    // Returns a copy of this state after applying `action`.
    // An even action adds 1 to the reward, any other action subtracts 1.
    // The copy is marked terminal when its reward is strictly above 10 or
    // strictly below -10. `moves` and `num_moves` are carried over unchanged.
    __device__ GameState next_state(int action) const {
        GameState new_state = *this;
        new_state.reward += (action % 2 == 0) ? 1.0f : -1.0f;
        new_state.is_terminal = (new_state.reward > 10.0f || new_state.reward < -10.0f);
        return new_state;
    }

    // Picks one entry of `moves` using the next value from `state`.
    // Returns -1 (without consuming a random number) when there is no valid
    // move to pick.
    __device__ int get_random_action(curandState* state) const {
        // A non-positive count means "no moves"; a negative value would
        // otherwise be converted to a huge unsigned modulus below.
        if (num_moves <= 0) return -1;

        // Never index past the end of `moves`, even if num_moves was set
        // larger than the array.
        const int capacity = static_cast<int>(sizeof(moves) / sizeof(moves[0]));
        const unsigned int count =
            static_cast<unsigned int>(num_moves < capacity ? num_moves : capacity);

        const unsigned int r = curand(state);
        return moves[r % count];
    }
};

// One search node: a game position plus the statistics attached to it.
struct Node {
    GameState state;  // position that rollouts for this node start from
    int visits;       // visit counter; initialised by the host, not read by mcts_kernel
    float value;      // overwritten by run_mcts with the node's mean rollout reward
};

// Plays random moves from `state` (taken by value, so the caller's copy is
// untouched) and returns the reward of the position where play stopped.
// Play stops at a terminal position, after MAX_DEPTH moves, or as soon as
// no action is available (get_random_action returned -1).
__device__ float rollout(GameState state, curandState* rand_state) {
    for (int step = 0; step < MAX_DEPTH; ++step) {
        if (state.is_terminal) {
            break;
        }

        const int chosen = state.get_random_action(rand_state);
        if (chosen == -1) {
            break;
        }

        state = state.next_state(chosen);
    }
    return state.reward;
}

// Estimates the value of each node by averaging NUM_SIMULATIONS random
// rollouts that all start from the node's state.
//
// Launch layout: 1D grid of 1D blocks, one thread per node. Threads whose
// global index is >= num_nodes exit immediately, so the grid may be rounded
// up past num_nodes. No shared memory is used.
//
// Parameters:
//   nodes     - num_nodes input nodes (only `state` is read).
//   num_nodes - number of entries in `nodes` and `results`.
//   results   - output; results[i] receives the mean rollout reward of nodes[i].
__global__ void mcts_kernel(const Node* nodes, int num_nodes, float* results) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_nodes) return;

    // Uncomment if you really need a debug print per active thread:
    // printf("mcts_kernel running for node %d\n", idx);

    curandState rand_state;
    // Fixed seed keeps runs reproducible; using idx as the sequence number
    // gives every thread its own subsequence of the generator.
    curand_init(/*seed=*/1234ULL, /*sequence=*/idx, /*offset=*/0ULL, &rand_state);

    // The starting position does not change between simulations, so load it
    // once instead of re-reading nodes[idx] on every iteration.
    const GameState start_state = nodes[idx].state;

    float total_reward = 0.0f;
    for (int i = 0; i < NUM_SIMULATIONS; ++i) {
        total_reward += rollout(start_state, &rand_state);
    }
    results[idx] = total_reward / NUM_SIMULATIONS;
}

// Runs mcts_kernel over `num_nodes` host-side nodes and stores each node's
// mean rollout reward back into host_nodes[i].value.
//
// Parameters:
//   host_nodes - array of num_nodes nodes in host memory; `value` is
//                overwritten, the other fields are left as they were.
//   num_nodes  - number of nodes; a null array or a non-positive count is
//                treated as "nothing to do".
//
// Any CUDA failure or host allocation failure prints a message to stderr and
// terminates the process.
void run_mcts(Node* host_nodes, int num_nodes) {
    printf("Entering run_mcts\n");

    // An empty batch would otherwise launch a grid of 0 blocks, which is an
    // invalid launch configuration; a negative count would wrap to a huge size.
    if (host_nodes == nullptr || num_nodes <= 0) {
        return;
    }

    const size_t node_count = static_cast<size_t>(num_nodes);
    const size_t nodes_bytes = node_count * sizeof(Node);
    const size_t results_bytes = node_count * sizeof(float);

    float* host_results = static_cast<float*>(malloc(results_bytes));
    if (host_results == nullptr) {
        fprintf(stderr, "run_mcts: failed to allocate %zu bytes of host memory\n",
                results_bytes);
        exit(EXIT_FAILURE);
    }

    Node* device_nodes = nullptr;
    float* device_results = nullptr;
    CUDA_CHECK(cudaMalloc(&device_nodes, nodes_bytes));
    CUDA_CHECK(cudaMalloc(&device_results, results_bytes));
    CUDA_CHECK(cudaMemcpy(device_nodes, host_nodes, nodes_bytes, cudaMemcpyHostToDevice));

    // One thread per node; round the grid up so every node is covered.
    const int threadsPerBlock = 128;
    const int blocksPerGrid = (num_nodes + threadsPerBlock - 1) / threadsPerBlock;

    mcts_kernel<<<blocksPerGrid, threadsPerBlock>>>(device_nodes, num_nodes, device_results);
    CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // errors raised while the kernel ran

    CUDA_CHECK(cudaMemcpy(host_results, device_results, results_bytes, cudaMemcpyDeviceToHost));

    for (int i = 0; i < num_nodes; ++i) {
        host_nodes[i].value = host_results[i];
    }

    free(host_results);
    CUDA_CHECK(cudaFree(device_nodes));
    CUDA_CHECK(cudaFree(device_results));
}

// Builds a single root node with ten moves (0..9), a zero reward and a
// non-terminal state, evaluates it with run_mcts, and prints the result.
int main() {
    printf("Entering main\n");

    Node root{};

    // Starting position: not terminal, no reward yet, moves numbered 0..9.
    GameState& start = root.state;
    start.reward = 0.0f;
    start.is_terminal = false;
    start.num_moves = 10;
    int slot = 0;
    while (slot < start.num_moves) {
        start.moves[slot] = slot;
        ++slot;
    }

    // Search statistics start from zero.
    root.value = 0.0f;
    root.visits = 0;

    run_mcts(&root, 1);

    printf("MCTS result: %f\n", root.value);
    return 0;
}


Overwriting MCTS.cu


In [8]:
!nvcc MCTS.cu -o MCTS -gencode arch=compute_75,code=sm_75 -lcublas

!./MCTS

Entering main
Entering run_mcts
MCTS result: 0.042969
