In [2]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32
import math

In [3]:
@triton.jit
def get(current, row, col, stride_crow, stride_ccol, Nrow, Mcol):
    if (row < 0 or col < 0) or (row >= Nrow or col >= Mcol):
        return 0
    return tl.load(current + row * stride_crow + col * stride_ccol)

@triton.jit
def step(current, stride_crow, stride_ccol, step, stride_srow, stride_scol, Nrow, Mcol, BNrow : tl.constexpr, BMcol : tl.constexpr, deltaRs, deltaCs):
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    curr_ptr = current + pid_row * BNrow * stride_crow + pid_col * BMcol * stride_ccol

    total_i = BNrow
    if Nrow - pid_row * BNrow < total_i:
        total_i = Nrow - pid_row * BNrow
    
    total_j = BMcol
    if Mcol - pid_col * BMcol < total_j:
        total_j = Mcol - pid_col * BMcol
    
    #i = 0
    #while i < total_i:
    for i in tl.static_range(BNrow):
        #j = 0
        for j in tl.static_range(BMcol):
        #while j < total_j:
            start_offseti = pid_row * BNrow + i
            start_offsetj = pid_col * BMcol + j
            
            cnt = 0
            for k in tl.static_range(8):
                deltaR = tl.load(deltaRs + k)
                deltaC = tl.load(deltaCs + k)
                cnt += get(current, start_offseti + deltaR, start_offsetj + deltaC, stride_crow, stride_ccol, Nrow, Mcol)

            state = get(current, start_offseti, start_offsetj, stride_crow, stride_ccol, Nrow, Mcol)
            new_state = 0
            
            if state == 0 and cnt == 3:
                new_state = 1
            elif state == 1 and (cnt == 2 or cnt == 3):
                new_state = 1

            tl.store(step + (pid_row * BNrow + i) * stride_srow + (pid_col * BMcol + j) * stride_scol, new_state)
            #tl.store(step + (pid_row * BNrow + i) * stride_srow + (pid_col * BMcol + j) * stride_scol, deltaRs + 2)

            #j += 1
        #i += 1
            
    
    #tl.store((step + pid_row * BNrow * stride_srow + pid_col * BMcol * stride_scol) + gridblock, data)
    

In [4]:
# Allocate input/output tensors
X = torch.zeros((8, 8), device='cuda').to(torch.int32)

In [5]:
for i in range(X.shape[1]):
    X[0][i] = 1
X

tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)

In [6]:
Y = torch.empty_like(X, device='cuda').to(torch.int32)
# SPMD launch grid
grid = (1, 1)

# extra params
deltaRs = torch.asarray([-1, 0, 1, 0, -1, -1, 1, 1], device='cuda').to(torch.int32)
deltaCs = torch.asarray([0, -1, 0, 1, -1, 1, -1, 1], device='cuda').to(torch.int32)

# enqueue GPU kernel
step[grid](X, X.stride(0), X.stride(1), 
              Y, Y.stride(0), Y.stride(1),
              X.shape[0]    , X.shape[1],
              8, 8, deltaRs, deltaCs)

<triton.compiler.compiler.CompiledKernel at 0x7ff3fd87f050>

In [7]:
Y

tensor([[0, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)

In [8]:
X = torch.normal(0, 1, size=(512, 512), device='cuda')

In [9]:
X.stride(0)

512

In [10]:
X.stride(1)

1

# Running

In [11]:
from IPython.display import display, clear_output
import time

In [12]:
def to_string(X):
    ret = ""

    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            num = int(X[i][j])
            ret += ' ' if num == 0 else 'X'
            ret += ' '
        ret += '\n'

    return ret

In [13]:
# Allocate input/output tensors
X = torch.zeros((48, 48), device='cuda').to(torch.int32)

X[2][0] = 1
X[3][1] = 1
X[1][2] = 1
X[2][2] = 1
X[3][2] = 1

Y = torch.empty_like(X, device='cuda').to(torch.int32)
# SPMD launch grid
grid = (2, 2)

# extra params
deltaRs = torch.asarray([-1, 0, 1, 0, -1, -1, 1, 1], device='cuda').to(torch.int32)
deltaCs = torch.asarray([0, -1, 0, 1, -1, 1, -1, 1], device='cuda').to(torch.int32)

for i in range(100):
    clear_output(wait=True)
    print(to_string(X))
    time.sleep(0.1)
    
    # enqueue GPU kernel
    step[grid](X, X.stride(0), X.stride(1), 
                  Y, Y.stride(0), Y.stride(1),
                  X.shape[0]    , X.shape[1],
                  16, 16, deltaRs, deltaCs)

    X = Y.clone()

                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                                                                X X X X X X X X X X X X X X X X 
                              

In [14]:
# Allocate input/output tensors
X = torch.zeros((32, 100), device='cuda').to(torch.int32)

si = 16
sj = 16

X[si][sj] = 1
X[si + 1][sj] = 1
X[si + 2][sj] = 1
X[si + 1][sj + 1] = 1

X[si][sj + 4] = 1
X[si + 1][sj + 4] = 1
X[si + 2][sj + 4] = 1
X[si + 1][sj + 3] = 1

X[si + 2][sj + 2] = 1

X[si + 3][sj + 1] = 1
X[si + 3][sj + 3] = 1

X[si + 4][sj + 2] = 1

Y = torch.empty_like(X, device='cuda').to(torch.int32)
# SPMD launch grid
row_blocksize = 16
col_blocksize = 16
grid = (math.ceil(X.shape[0] / row_blocksize), math.ceil(X.shape[1] / col_blocksize))

# extra params
deltaRs = torch.asarray([-1, 0, 1, 0, -1, -1, 1, 1], device='cuda').to(torch.int32)
deltaCs = torch.asarray([0, -1, 0, 1, -1, 1, -1, 1], device='cuda').to(torch.int32)

for i in range(1000):
    clear_output(wait=True)
    print(to_string(X))
    time.sleep(0.05)
    
    # enqueue GPU kernel
    step[grid](X, X.stride(0), X.stride(1), 
                  Y, Y.stride(0), Y.stride(1),
                  X.shape[0]    , X.shape[1],
                  row_blocksize, col_blocksize, deltaRs, deltaCs)

    X = Y.clone()

                                                                                                                                                                                                        
                                                                                                                                                                                                        
                                                                                                                                                                                                        
                                                                                                                                                                                                        
                                                                                                                                                                                                    

In [284]:
@triton.jit
def get(current, row, col, stride_crow, stride_ccol, Nrow, Mcol):
    if (row < 0 or col < 0) or (row >= Nrow or col >= Mcol):
        return 0
    return tl.load(current + row * stride_crow + col * stride_ccol)

@triton.jit
def step(current, stride_crow, stride_ccol, step, stride_srow, stride_scol, Nrow, Mcol, BNrow : tl.constexpr, BMcol : tl.constexpr, deltaRs, deltaCs):
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    curr_ptr = current + pid_row * BNrow * stride_crow + pid_col * BMcol * stride_ccol

    total_i = BNrow
    if Nrow - pid_row * BNrow < total_i:
        total_i = Nrow - pid_row * BNrow
    
    total_j = BMcol
    if Mcol - pid_col * BMcol < total_j:
        total_j = Mcol - pid_col * BMcol
    
    #i = 0
    #while i < BNrow and i + pid_row * BNrow < Nrow:
    for i in tl.static_range(BNrow):
        #j = 0
        for j in tl.static_range(BMcol):
            start_offseti = pid_row * BNrow + i
            start_offsetj = pid_col * BMcol + j
                
            cnt = 0
            for k in tl.static_range(8):
                deltaR = tl.load(deltaRs + k)
                deltaC = tl.load(deltaCs + k)
                cnt += get(current, start_offseti + deltaR, start_offsetj + deltaC, stride_crow, stride_ccol, Nrow, Mcol)

            state = get(current, start_offseti, start_offsetj, stride_crow, stride_ccol, Nrow, Mcol)
            new_state = 0
            
            if state == 0 and cnt == 3:
                new_state = 1
            elif state == 1 and (cnt == 2 or cnt == 3):
                new_state = 1

            if start_offseti < Nrow and start_offsetj < Mcol:
                tl.store(step + (pid_row * BNrow + i) * stride_srow + (pid_col * BMcol + j) * stride_scol, new_state)
            #tl.store(step + (pid_row * BNrow + i) * stride_srow + (pid_col * BMcol + j) * stride_scol, deltaRs + 2)
    
            #j += 1
        #i += 1
            
    
    #tl.store((step + pid_row * BNrow * stride_srow + pid_col * BMcol * stride_scol) + gridblock, data)
    