In [1]:
from numba import jit, njit, cuda , prange
import numba
import numpy as np

# Small sample machine state kept around for manual experiments:
# registers = [729, 0, 0]
# program = [0,1,5,4,3,0]

# Positions of the A, B and C registers inside the `registers` array used by process_batch.
register_a, register_b, register_c = 0, 1, 2

# enabling the JIT causes print statements to not print lists as expected
@njit(parallel=True)
def loop_em(start, end, blocks_per_grid, threads_per_block, program, return_array):
    """Fan the candidate range out over `blocks_per_grid * threads_per_block` workers.

    Every worker receives the same number of candidates (the range size divided by
    the worker count, plus one so a remainder is never dropped) and hands them to
    `process_batch`, which writes any hit into `return_array[worker]`.
    The block/thread split only mirrors GPU terminology; everything runs under
    numba's `prange` on the CPU.
    """
    # inside an f-string the array shows only its numba type, so print it on its own first
    print(program)

    span = end - start
    worker_count = blocks_per_grid * threads_per_block
    # the extra +1 cycle covers a span that is not an exact multiple of worker_count
    per_worker = span // worker_count + 1
    print(f'total_range={span} start={start} end={end} total_num_threads:{worker_count} cycles_per_thread={per_worker} program = {program}')

    # numba spreads prange iterations over its worker thread pool
    for worker in prange(worker_count):
        blk = worker // threads_per_block
        lane = worker % threads_per_block
        print(f'starting block:{blk} thread_in_block:{lane} global_thread:{worker} planned_cycles:{per_worker}')
        process_batch(program, return_array, start, end, blocks_per_grid, blk, threads_per_block, worker, per_worker)

# the program.  The start of everything. The end for overflow protection, which batch, and the batch size
@njit()
def process_batch(program, return_array, start, end, blocks_per_grid, block_id, threads_per_block, global_thread_id, cycles_per_thread):
    for cycle_index in range(cycles_per_thread):
        # which thread block are we in?
        global_cycle_id =  cycle_index + (global_thread_id * cycles_per_thread)
        current = global_cycle_id
        # handle the fact the batch size is not an exact factor of the total range
        if (current <= end):
            # This program is 8 octal operands so the A register needs to be 8 octal digits.
            registers = np.array([current, 0, 0])
            # print(f'process_batch {start}:{current}:{end} block:{block_id} thread:{global_thread_id}  cycle:{cycle_index}:{global_cycle_id} cycles_per_thread:{cycles_per_thread} program: {program} registers: {registers}')
            output = []
            resolve_operand = lambda operand: operand if (operand<4) else registers[operand-4]
            # if (i%1000000 == 0):
            #     print(f'at: {current} : {i}/{loop_iterations}')
            
            address_ptr = 0
            while (address_ptr < len(program)):
                # numba says these are int64
                operator = program[address_ptr]
                operand = program[address_ptr+1]
                next_address_ptr = address_ptr+2
                # print (f'address: {address_ptr} operator: {operator} operand: {operand} registers: {registers} ')
                match (operator):
                    case 0: # adv division register_a ~/ 2^comboOperand
                        registers[register_a] = registers[register_a] // 2 ** resolve_operand(operand)
                    case 1: # bxl bitwise XOR (registerB , operand)
                        registers[register_b] = registers[register_b] ^ operand
                    case 2: # bst operand modulo 8
                        registers[register_b] = resolve_operand(operand) % 8
                    case 3: # jnz jump not zero
                        if (registers[register_a] != 0):
                            next_address_ptr =  operand
                    case 4: #bxc bitwise xor reg b, reg c
                        registers[register_b] = registers[register_b] ^ registers[register_c]
                    case 5: # out % modulo 8
                        output.append(resolve_operand(operand) %8)
                    case 6: # BDV integer division on A , stored in B
                        divisor = 2 ** resolve_operand(operand)
                        registers[register_b] = registers[register_a] // divisor
                    case 7: # CDV
                        divisor = 2 ** resolve_operand(operand)
                        registers[register_c] = registers[register_a] // divisor
                    case _:
                        #print('oh no')
                        result = -1
                
                address_ptr = next_address_ptr
                # print(f'now at: {address_ptr} output after {output}')
                # print(f'final registers: {registers} output {np.array(output)} 
                # This exists because I played with different lengths while experimenting
                if (len(output) == 16
                    and output[0]==program[0] 
                    and output[1]==program[1] 
                    and output[2]==program[2] 
                    # and output[3]==program[3]
                    # and output[4]==program[4]
                    # and output[5]==program[5]
                    # and output[6]==program[6]
                    # and output[7]==program[7]
                    # and output[8]==program[8]
                    # and output[9]==program[9]
                    # and output[10]==program[10]
                    # and output[11]==program[11]
                    # and output[12]==program[12]
                    # and output[13]==program[13]
                    # and output[14]==program[14]
                    # and output[15]==program[15]
                    ):
                    return_array[global_thread_id]= global_cycle_id
                    # will not print if njit is enabled
                    # print(f'matches {oct(current)} - {output} -{current}')
                    # use with njit
                    print(output)
                    print(f'matches {output} - block:{block_id} thread:{global_thread_id}  cycle:{cycle_index}:{global_cycle_id} ')
                    # print(output)
                    # from before we put the loop in
                    return
        # print(f'{oct(current)} - {output}')


In [None]:
%%time
# --- launch geometry ---------------------------------------------------------
# Terminology: a grid is made of blocks, a block is made of threads; the thread id
# is the index within a block and the block id is the index within the grid.
# RTX 3060 Ti: 38 SMs x 128 cores each = 4864 cores. A CUDA launch would use e.g.
#   num_blocks_per_grid = 38; num_threads_per_block = 256
# loop_em runs on the CPU via numba prange, so a CPU-sized geometry is used here.
num_blocks_per_grid = 8     # My ryzen
num_threads_per_block = 3

# --- search range --------------------------------------------------------------
# Each pass of the program loop emits one value and divides A by 8 (`0,3`), so a
# 16-value output needs an A with exactly 16 octal digits (leading digit non-zero);
# with fewer digits the output is shorter than the program.
# Only a sub-range whose leading octal digit is 1 is searched here.
search_start = 0o1000000000000000   # 8**15, smallest 16-digit octal value
search_end = 0o1111111111111111     # inclusive upper bound

# explicit int64 keeps numba's integer types stable across platforms
puzzle_program = np.array([2, 4, 1, 3, 7, 5, 4, 7, 0, 3, 1, 5, 5, 5, 3, 0], dtype=np.int64)

total_num_threads = num_blocks_per_grid * num_threads_per_block
# int64 explicitly: dtype=int is 32-bit on Windows with NumPy < 2, too small for A values.
# -1 marks "this worker found no match" (0 could be a legitimate value).
return_array = np.full(total_num_threads, -1, dtype=np.int64)

loop_em(search_start,
        search_end,
        num_blocks_per_grid,
        num_threads_per_block,
        puzzle_program,
        return_array)
print(return_array)



[2 4 1 3 7 5 4 7 0 3 1 5 5 5 3 0]
total_range=5026338869833 start=35184372088832 end=40210710958665 total_num_threads:24 cycles_per_thread=209430786244 program = <object type:array(int64, 1d, C)>
starting block:0 thread_in_block:0 global_thread:0 planned_cycles:209430786244
starting block:2 thread_in_block:0 global_thread:6 planned_cycles:209430786244
starting block:3 thread_in_block:2 global_thread:11 planned_cycles:209430786244
starting block:7 thread_in_block:2 global_thread:23 planned_cycles:209430786244
starting block:5 thread_in_block:0 global_thread:15 planned_cycles:209430786244
starting block:4 thread_in_block:1 global_thread:13 planned_cycles:209430786244
starting block:7 thread_in_block:0 global_thread:21 planned_cycles:209430786244
starting block:5 thread_in_block:2 global_thread:17 planned_cycles:209430786244
starting block:2 thread_in_block:2 global_thread:8 planned_cycles:209430786244
starting block:6 thread_in_block:0 global_thread:18 planned_cycles:209430786244
startin