Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: Make block, thread, and warp indices unsigned. #6112

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions analysis/commands.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

sed 's/0x[0-9a-f]*/HEX/g' annotation_without_tid.txt > nohex_annotation_without_tid.txt
sed 's/0x[0-9a-f]*/HEX/g' annotation_with_tid.txt > nohex_annotation_with_tid.txt
diff -u nohex_annotation_without_tid.txt nohex_annotation_with_tid.txt > tid_typing_change.diff
grep "^[+-]" tid_typing_change.diff > changes_only.diff
Binary file added analysis/data.xlsx
Binary file not shown.
54 changes: 54 additions & 0 deletions analysis/reg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import re
import pandas as pd


def parse(line):
_, name, attrs = line.split(' : ')
assignment = r'=([0-9]+)[,\)]'
values = [int(x) for x in re.findall(assignment, attrs)]
regs, shared, local, const, maxthreads = values
return name, regs, shared, local, const, maxthreads


def read_file(name, col_suffix):
with open(name) as f:
entries = [parse(line) for line in sorted(f.readlines())
if 'ATTRIBUTES' in line]
print(f'Total entries: {len(entries)}')

name = []
regs = []
shared = []
local = []
const = []
maxthreads = []

for values in entries:
name.append(values[0])
regs.append(values[1])
shared.append(values[2])
local.append(values[3])
const.append(values[4])
maxthreads.append(values[5])

data = {
'name': name,
f'regs_{col_suffix}': regs,
f'shared_{col_suffix}': shared,
f'local_{col_suffix}': local,
f'const_{col_suffix}': const,
f'maxthreads_{col_suffix}': maxthreads
}

return pd.DataFrame(data=data)


def read_files():
before_df = read_file('log_without_tid.txt', 'before')
after_df = read_file('log_with_tid.txt', 'after')

return pd.merge(before_df, after_df, how='outer', on='name')


if __name__ == '__main__':
read_file('log_with_tid.txt')
1 change: 1 addition & 0 deletions numba/core/typeconv/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def _init_casting_rules(tm):

tcr.promote_unsafe(types.int32, types.int64)
tcr.promote_unsafe(types.uint32, types.uint64)
tcr.promote_unsafe(types.tid, types.int64)

tcr.safe_unsafe(types.uint8, types.int16)
tcr.safe_unsafe(types.uint16, types.int32)
Expand Down
2 changes: 2 additions & 0 deletions numba/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
uint32 = Integer('uint32')
uint64 = Integer('uint64')

tid = Integer('thread_idx', 32, signed=False)

int8 = Integer('int8')
int16 = Integer('int16')
int32 = Integer('int32')
Expand Down
16 changes: 8 additions & 8 deletions numba/cuda/cudadecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def typer(ndim):
raise errors.RequireLiteralValue(ndim)
val = ndim.literal_value
if val == 1:
restype = types.int32
restype = types.tid
elif val in (2, 3):
restype = types.UniTuple(types.int32, val)
restype = types.UniTuple(types.tid, val)
else:
raise ValueError('argument can only be 1, 2, 3')
return signature(restype, types.int32)
return signature(restype, ndim)
return typer


Expand Down Expand Up @@ -449,13 +449,13 @@ class Dim3_attrs(AttributeTemplate):
key = dim3

def resolve_x(self, mod):
return types.int32
return types.tid

def resolve_y(self, mod):
return types.int32
return types.tid

def resolve_z(self, mod):
return types.int32
return types.tid


@register_attr
Expand Down Expand Up @@ -599,10 +599,10 @@ def resolve_gridDim(self, mod):
return dim3

def resolve_warpsize(self, mod):
return types.int32
return types.tid

def resolve_laneid(self, mod):
return types.int32
return types.tid

def resolve_shared(self, mod):
return types.Module(cuda.shared)
Expand Down
8 changes: 4 additions & 4 deletions numba/cuda/cudaimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ def ptx_sync_group(context, builder, sig, args):

# -----------------------------------------------------------------------------

@lower(cuda.grid, types.int32)
@lower(cuda.grid, types.Integer)
def cuda_grid(context, builder, sig, args):
restype = sig.return_type
if restype == types.int32:
if restype == types.tid:
return nvvmutils.get_global_id(builder, dim=1)
elif isinstance(restype, types.UniTuple):
ids = nvvmutils.get_global_id(builder, dim=restype.count)
Expand All @@ -111,12 +111,12 @@ def _nthreads_for_dim(builder, dim):
return builder.mul(ntid, nctaid)


@lower(cuda.gridsize, types.int32)
@lower(cuda.gridsize, types.Integer)
def cuda_gridsize(context, builder, sig, args):
restype = sig.return_type
nx = _nthreads_for_dim(builder, 'x')

if restype == types.int32:
if restype == types.tid:
return nx
elif isinstance(restype, types.UniTuple):
ny = _nthreads_for_dim(builder, 'y')
Expand Down
6 changes: 3 additions & 3 deletions numba/cuda/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
class Dim3Model(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
('x', types.int32),
('y', types.int32),
('z', types.int32)
('x', types.tid),
('y', types.tid),
('z', types.tid)
]
super().__init__(dmm, fe_type, members)

Expand Down
2 changes: 2 additions & 0 deletions numba/np/numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def as_dtype(nbtype):
NotImplementedError is if no correspondence is known.
"""
nbtype = types.unliteral(nbtype)
if nbtype == types.tid:
return np.dtype('uint32')
if isinstance(nbtype, (types.Complex, types.Integer, types.Float)):
return np.dtype(str(nbtype))
if nbtype is types.bool_:
Expand Down