Skip to content

Commit

Permalink
Merge pull request #950 from apmasell/llvm14_attributes
Browse files Browse the repository at this point in the history
Fix incorrect `byval` and other attributes on LLVM 14
  • Loading branch information
esc committed Jun 5, 2023
2 parents e2dd7b0 + efb856a commit 70f057b
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 24 deletions.
2 changes: 1 addition & 1 deletion llvmlite/binding/analysis.py
Expand Up @@ -4,7 +4,6 @@

from ctypes import POINTER, c_char_p, c_int

from llvmlite import ir
from llvmlite.binding import ffi
from llvmlite.binding.module import parse_assembly

Expand All @@ -17,6 +16,7 @@ def get_function_cfg(func, show_inst=True):
are printed.
"""
assert func is not None
from llvmlite import ir
if isinstance(func, ir.Function):
mod = parse_assembly(str(func.module))
func = mod.get_function(func.name)
Expand Down
12 changes: 9 additions & 3 deletions llvmlite/ir/instructions.py
Expand Up @@ -133,7 +133,7 @@ def called_function(self):
def _descr(self, buf, add_metadata):
def descr_arg(i, a):
if i in self.arg_attributes:
attrs = ' '.join(self.arg_attributes[i]._to_list()) + ' '
attrs = ' '.join(self.arg_attributes[i]._to_list(a.type)) + ' '
else:
attrs = ''
return '{0} {1}{2}'.format(a.type, attrs, a.get_reference())
Expand All @@ -155,13 +155,19 @@ def descr_arg(i, a):
if self.tail:
tail_marker = "{0} ".format(self.tail)

fn_attrs = ' ' + ' '.join(self.attributes._to_list(fnty.return_type))\
if self.attributes else ''

fm_attrs = ' ' + ' '.join(self.fastmath._to_list(fnty.return_type))\
if self.fastmath else ''

buf.append("{tail}{op}{fastmath} {callee}({args}){attr}{meta}\n".format(
tail=tail_marker,
op=self.opname,
callee=callee_ref,
fastmath=''.join([" " + attr for attr in self.fastmath]),
fastmath=fm_attrs,
args=args,
attr=''.join([" " + attr for attr in self.attributes]),
attr=fn_attrs,
meta=(self._stringify_metadata(leading_comma=True)
if add_metadata else ""),
))
Expand Down
81 changes: 65 additions & 16 deletions llvmlite/ir/values.py
Expand Up @@ -6,12 +6,12 @@
import functools
import string
import re
from types import MappingProxyType

from llvmlite.ir import values, types, _utils
from llvmlite.ir._utils import (_StrCaching, _StringReferenceCaching,
_HasMetadata)


_VALID_CHARS = (frozenset(map(ord, string.ascii_letters)) |
frozenset(map(ord, string.digits)) |
frozenset(map(ord, ' !#$%&\'()*+,-./:;<=>?@[]^_`{|}~')))
Expand Down Expand Up @@ -858,19 +858,26 @@ class AttributeSet(set):
_known = ()

def __init__(self, args=()):
super().__init__()
if isinstance(args, str):
args = [args]
for name in args:
self.add(name)

def _expand(self, name, typ):
return name

def add(self, name):
if name not in self._known:
raise ValueError('unknown attr {!r} for {}'.format(name, self))
self._validate_add(name)
return super(AttributeSet, self).add(name)

def __iter__(self):
# In sorted order
return iter(sorted(super(AttributeSet, self).__iter__()))
def _validate_add(self, name):
pass

def _to_list(self, typ):
return [self._expand(i, typ) for i in sorted(self)]


class FunctionAttributes(AttributeSet):
Expand Down Expand Up @@ -914,15 +921,15 @@ def personality(self, val):
assert val is None or isinstance(val, GlobalValue)
self._personality = val

def __repr__(self):
attrs = list(self)
def _to_list(self, ret_type):
attrs = super()._to_list(ret_type)
if self.alignstack:
attrs.append('alignstack({0:d})'.format(self.alignstack))
if self.personality:
attrs.append('personality {persty} {persfn}'.format(
persty=self.personality.type,
persfn=self.personality.get_reference()))
return ' '.join(attrs)
return attrs


class Function(GlobalValue):
Expand Down Expand Up @@ -975,8 +982,8 @@ def descr_prototype(self, buf):
ret = self.return_value
args = ", ".join(str(a) for a in self.args)
name = self.get_reference()
attrs = self.attributes
attrs = ' {}'.format(attrs) if attrs else ''
attrs = ' ' + ' '.join(self.attributes._to_list(
self.ftype.return_type)) if self.attributes else ''
if any(self.args):
vararg = ', ...' if self.ftype.var_arg else ''
else:
Expand Down Expand Up @@ -1018,16 +1025,58 @@ def is_declaration(self):


class ArgumentAttributes(AttributeSet):
_known = frozenset(['byval', 'inalloca', 'inreg', 'nest', 'noalias',
'nocapture', 'nonnull', 'returned', 'signext',
'sret', 'zeroext'])
# List from
# https://releases.llvm.org/14.0.0/docs/LangRef.html#parameter-attributes
_known = MappingProxyType({
# Each tuple is LLVM 11 vs 14 behaviour:
# None (unsupported),
# True (emit type),
# False (emit name only)
'byref': (None, True),
'byval': (True, True),
'elementtype': (None, True),
'immarg': (False, False),
'inalloca': (False, True),
'inreg': (False, False),
'nest': (False, False),
'noalias': (False, False),
'nocapture': (False, False),
'nofree': (False, False),
'nonnull': (False, False),
'noundef': (False, False),
'preallocated': (True, True),
'returned': (False, False),
'signext': (False, False),
'sret': (False, True),
'swiftasync': (None, False),
'swifterror': (False, False),
'swiftself': (False, False),
'zeroext': (False, False),
})

def __init__(self, args=()):
self._align = 0
self._dereferenceable = 0
self._dereferenceable_or_null = 0
super(ArgumentAttributes, self).__init__(args)

def _validate_add(self, name):
import llvmlite.binding
llvm_major = llvmlite.binding.llvm_version_info[0]
requires_type = self._known.get(name)[llvm_major >= 14]
if requires_type is None:
raise ValueError(
f"Attribute {name} is not supported on current LLVM version")

def _expand(self, name, typ):
import llvmlite.binding
llvm_major = llvmlite.binding.llvm_version_info[0]
requires_type = self._known.get(name)[llvm_major >= 14]
if requires_type:
return f"{name}({typ.pointee})"
else:
return name

@property
def align(self):
return self._align
Expand Down Expand Up @@ -1055,8 +1104,8 @@ def dereferenceable_or_null(self, val):
assert isinstance(val, int) and val >= 0
self._dereferenceable_or_null = val

def _to_list(self):
attrs = sorted(self)
def _to_list(self, typ):
attrs = super()._to_list(typ)
if self.align:
attrs.append('align {0:d}'.format(self.align))
if self.dereferenceable:
Expand Down Expand Up @@ -1088,7 +1137,7 @@ class Argument(_BaseArgument):
"""

def __str__(self):
attrs = self.attributes._to_list()
attrs = self.attributes._to_list(self.type)
if attrs:
return "{0} {1} {2}".format(self.type, ' '.join(attrs),
self.get_reference())
Expand All @@ -1102,7 +1151,7 @@ class ReturnValue(_BaseArgument):
"""

def __str__(self):
attrs = self.attributes._to_list()
attrs = self.attributes._to_list(self.type)
if attrs:
return "{0} {1}".format(' '.join(attrs), self.type)
else:
Expand Down
84 changes: 80 additions & 4 deletions llvmlite/tests/test_ir.py
Expand Up @@ -84,13 +84,20 @@ def _normalize_asm(self, asm):
asm = asm.replace("\n ", "\n ")
return asm

def check_descr_regex(self, descr, asm):
expected = self._normalize_asm(asm)
self.assertRegex(descr, expected)

def check_descr(self, descr, asm):
expected = self._normalize_asm(asm)
self.assertEqual(descr, expected)

def check_block(self, block, asm):
self.check_descr(self.descr(block), asm)

def check_block_regex(self, block, asm):
self.check_descr_regex(self.descr(block), asm)

def check_module_body(self, module, asm):
expected = self._normalize_asm(asm)
actual = module._stringify_body()
Expand Down Expand Up @@ -1366,11 +1373,11 @@ def test_call_attributes(self):
2: 'noalias'
}
)
self.check_block(block, """\
self.check_block_regex(block, """\
my_block:
%"retval" = alloca i32
%"other" = alloca i32
call void @"fun"(i32* noalias sret %"retval", i32 42, i32* noalias %"other")
call void @"fun"\\(i32\\* noalias sret(\\(i32\\))? %"retval", i32 42, i32\\* noalias %"other"\\)
""") # noqa E501

def test_call_tail(self):
Expand Down Expand Up @@ -1449,11 +1456,11 @@ def test_invoke_attributes(self):
2: 'noalias'
}
)
self.check_block(block, """\
self.check_block_regex(block, """\
my_block:
%"retval" = alloca i32
%"other" = alloca i32
invoke fast fastcc void @"fun"(i32* noalias sret %"retval", i32 42, i32* noalias %"other") noinline
invoke fast fastcc void @"fun"\\(i32\\* noalias sret(\\(i32\\))? %"retval", i32 42, i32\\* noalias %"other"\\) noinline
to label %"normal" unwind label %"unwind"
""") # noqa E501

Expand Down Expand Up @@ -1779,6 +1786,75 @@ def test_fma_mixedtypes(self):
"expected types to be the same, got float, double, float",
str(raises.exception))

def test_arg_attributes(self):
def gen_code(attr_name):
fnty = ir.FunctionType(ir.IntType(32), [ir.IntType(32).as_pointer(),
ir.IntType(32)])
module = ir.Module()

func = ir.Function(module, fnty, name="sum")

bb_entry = func.append_basic_block()
bb_loop = func.append_basic_block()
bb_exit = func.append_basic_block()

builder = ir.IRBuilder()
builder.position_at_end(bb_entry)

builder.branch(bb_loop)
builder.position_at_end(bb_loop)

index = builder.phi(ir.IntType(32))
index.add_incoming(ir.Constant(index.type, 0), bb_entry)
accum = builder.phi(ir.IntType(32))
accum.add_incoming(ir.Constant(accum.type, 0), bb_entry)

func.args[0].add_attribute(attr_name)
ptr = builder.gep(func.args[0], [index])
value = builder.load(ptr)

added = builder.add(accum, value)
accum.add_incoming(added, bb_loop)

indexp1 = builder.add(index, ir.Constant(index.type, 1))
index.add_incoming(indexp1, bb_loop)

cond = builder.icmp_unsigned('<', indexp1, func.args[1])
builder.cbranch(cond, bb_loop, bb_exit)

builder.position_at_end(bb_exit)
builder.ret(added)

return str(module)

llvm_major = llvm.llvm_version_info[0]
if llvm_major >= 14:
supplemental = ('byref', 'swiftasync', 'elementtype')
else:
supplemental = ()

for attr_name in (
'byval',
'immarg',
'inalloca',
'inreg',
'nest',
'noalias',
'nocapture',
'nofree',
'nonnull',
'noundef',
'preallocated',
'returned',
'signext',
'swifterror',
'swiftself',
'zeroext',
) + supplemental:
# If this parses, we emitted the right byval attribute format
llvm.parse_assembly(gen_code(attr_name))
# sret doesn't fit this pattern and is tested in test_call_attributes


class TestBuilderMisc(TestBase):
"""
Expand Down

0 comments on commit 70f057b

Please sign in to comment.