Skip to content

Commit

Permalink
try to fix issue with long function names
Browse files Browse the repository at this point in the history
  • Loading branch information
hughperkins committed Dec 1, 2016
1 parent de5b094 commit 4b3d457
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 20 deletions.
36 changes: 21 additions & 15 deletions src/kernel_dumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "function_dumper.h"
#include "shims.h"
#include "mutations.h"
#include "EasyCL/util/easycl_stringhelper.h"

#include "llvm/IR/Constants.h"

Expand Down Expand Up @@ -122,6 +123,24 @@ std::string KernelDumper::toCl() {
if(F == 0) {
throw runtime_error("Couldnt find kernel " + kernelName);
}
// kernel name will simply be truncated to 32 characters
// other names will fit around it
if(kernelName.size() > 32) {
kernelName = kernelName.substr(0, 31);
F->setName(kernelName);
}
cout << "F name " << F->getName().str() << endl;

int i = 0;
for(auto it = M->begin(); it != M->end(); it++) {
Function *F = &*it;
string name = F->getName().str();
if(name.size() > 32) {
name = name.substr(0, 28) + easycl::toString(i);
i++;
F->setName(name);
}
}

// GlobalNames globalNames;
// TypeDumper typeDumper(&globalNames);
Expand All @@ -130,28 +149,14 @@ std::string KernelDumper::toCl() {

ostringstream moduleClStream;

// set<Function *> dumpedFunctions;
set<Function *> neededFunctions;
set<Function *> isKernel;
map<Function *, Type *> returnTypeByFunction;
map<string, string> oldNameByNewName;

isKernel.insert(F);
// dumpedFunctions.insert(F);
neededFunctions.insert(F);

// FunctionDumper functionDumper(F, true, &globalNames, &typeDumper, &functionNamesMap);
// if(_addIRToCl) {
// functionDumper.addIRToCl();
// }
// string kernelFunctionCl = functionDumper.toCl();

// functionDeclarations.insert(functionDumper.getDeclaration());
// // cout << "kernelFunctionCl:\n" << kernelFunctionCl << endl;
// moduleClStream << kernelFunctionCl;
// shimFunctionsNeeded.insert(functionDumper.shimFunctionsNeeded.begin(), functionDumper.shimFunctionsNeeded.end());
// neededFunctions.insert(functionDumper.neededFunctions.begin(), functionDumper.neededFunctions.end());
// structsToDefine.insert(functionDumper.structsToDefine.begin(), functionDumper.structsToDefine.end());

int nothingHappenedCount = 0;
while(returnTypeByFunction.size() < neededFunctions.size()) {
bool changedSomething = false;
Expand Down Expand Up @@ -182,6 +187,7 @@ std::string KernelDumper::toCl() {
if(_addIRToCl) {
childFunctionDumper.addIRToCl();
}
cout << " running generation on " << childF->getName().str() << endl;
if(!childFunctionDumper.runGeneration(returnTypeByFunction)) {
// cout << "couldnt run generation to completion yet for " << childF->getName().str() << endl;
neededFunctions.insert(childFunctionDumper.neededFunctions.begin(), childFunctionDumper.neededFunctions.end());
Expand Down
7 changes: 7 additions & 0 deletions src/new_instruction_dumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,14 @@ void NewInstructionDumper::dumpCall(LocalValueInfo *localValueInfo, const std::m
CallInst *instr = cast<CallInst>(localValueInfo->value);

// string gencode = "";
Value *calledValue = instr->getCalledValue();
string calledName = calledValue->getName().str();
// if(calledName.size() > 32) {
// calledName = calledName.substr(0, 31);
// calledValue->setName(calledName);
// }
string functionName = instr->getCalledValue()->getName().str();
// cout << "called function: [" << functionName << "]" << endl;
bool internalfunc = false;
if(functionName == "llvm.ptx.read.tid.x") {
localValueInfo->setAddressSpace(0);
Expand Down
42 changes: 42 additions & 0 deletions test/gtest/test_kernel_dumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "type_dumper.h"
#include "GlobalNames.h"
#include "LocalNames.h"
#include "EasyCL/util/easycl_stringhelper.h"

#include "llvm/IRReader/IRReader.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -233,4 +234,45 @@ v1:;
EXPECT_FALSE(cl.find(" = returnsVoid") != string::npos);
}

TEST(test_kernel_dumper, test_long_conflicting_names) {
GlobalWrapper G("mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamec");
KernelDumper *kernelDumper = G.kernelDumper.get();
// Module *M = getM();

string cl = kernelDumper->toCl();
cout << "kernel cl: [" << cl << "]" << endl;
EXPECT_TRUE(cl.find(" void mysuperlongfunctionnamemysuperlo(") != string::npos); // kernel name should be exactly 32 characters, simply truncated
vector<string>splitLine = easycl::split(cl, "\n");
for(auto it = splitLine.begin(); it != splitLine.end(); it++) {
string line = *it;
// cout << "line [" << *it << "]" << endl;
EXPECT_LE(line.size(), 128u);
}
EXPECT_EQ(R"(
kernel void mysuperlongfunctionnamemysuperl(global float* d, uint d_offset, local int *scratch);
void mysuperlongfunctionnamemysup0_g(global float* d, local int *scratch);
void mysuperlongfunctionnamemysup1_g(global float* d, local int *scratch);
kernel void mysuperlongfunctionnamemysuperl(global float* d, uint d_offset, local int *scratch) {
d += d_offset;
v1:;
mysuperlongfunctionnamemysup0_g(d, scratch);
mysuperlongfunctionnamemysup1_g(d, scratch);
return;
}
void mysuperlongfunctionnamemysup0_g(global float* d, local int *scratch) {
v1:;
return;
}
void mysuperlongfunctionnamemysup1_g(global float* d, local int *scratch) {
v1:;
return;
}
)", cl);
}

} // namespace
14 changes: 14 additions & 0 deletions test/gtest/test_kernel_dumper.ll
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,17 @@ define void @usesFunctionReturningVoid(float *%in) {
call void @returnsVoid(float *%in)
ret void
}

define void @mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamea(float *%d) {
ret void
}

define void @mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnameb(float *%d) {
ret void
}

define void @mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamec(float *%d) {
call void @mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamea(float *%d)
call void @mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnameb(float *%d)
ret void
}
2 changes: 1 addition & 1 deletion test/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def try_build(context, filepath, kernelname):
with open(filepath, 'r') as f:
cucode = f.read()
clcode = test_common.cu_to_cl(cucode, kernelname)
test_common.build_kernel(context, clcode, kernelname)
test_common.build_kernel(context, clcode, kernelname[:31])


def test_program_compiles(context):
Expand Down
59 changes: 59 additions & 0 deletions test/test_cloutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,62 @@ def test_float4_test2(cuSourcecode, context, ctx, q, float_data, float_data_gpu)
print('float_data[:8]', float_data[:8])
for i in range(4):
assert float_data[i] == float_data_orig[i + 4]


def test_long_conflicting_names(context, q):
cu_source = """
__device__ void mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionname(float *d) {
d[1] = 1.0f;
}
__device__ void mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnameb(float *d) {
d[2] = 3.0f;
}
__global__ void mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamec(float *data) {
data[0] = 123.0f;
mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionname(data);
mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnameb(data);
}
"""
mangled_name = test_common.mangle('mysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamemysuperlongfunctionnamec', ['float *'])
cl_source = test_common.cu_to_cl(cu_source, mangled_name)
print('cl_source', cl_source)
for line in cl_source.split("\n"):
if line.strip().startswith('/*'):
continue
if not line.strip().replace('kernel ', '').strip().startswith('void'):
continue
name = line.replace('kernel ', '').replace('void ', '').split('(')[0]
if name != '':
print('name', name)
assert len(name) <= 32
test_common.build_kernel(context, cl_source, mangled_name[:31])


def test_short_names(context):
cu_source = """
__device__ void funca(float *d);
__device__ void funca(float *d) {
d[1] = 1.0f;
}
__device__ void funcb(float *d, int c) {
d[2] = 3.0f + 5 - d[c];
}
__global__ void funck(float *data) {
data[0] = 123.0f;
funca(data);
funcb(data, (int)data[6]);
for(int i = 0; i < 1000; i++) {
funcb(data + i, (int)data[i + 100]);
}
}
"""
mangled_name = test_common.mangle('funck', ['float *'])
cl_source = test_common.cu_to_cl(cu_source, mangled_name)
print('cl_source', cl_source)

test_common.build_kernel(context, cl_source, mangled_name[:31])
2 changes: 1 addition & 1 deletion test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@


def run_process(cmdline_list):
print('running [%s]' % ' '.join(cmdline_list))
fout = open('/tmp/pout.txt', 'w')
res = subprocess.run(cmdline_list, stdout=fout, stderr=subprocess.STDOUT)
fout.close()
with open('/tmp/pout.txt', 'r') as f:
output = f.read()
print(' '.join(res.args))
print(output)
assert res.returncode == 0
return output
Expand Down
2 changes: 1 addition & 1 deletion test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_compile(context, cu_filepath, kernelname):

cl_code = test_common.cu_to_cl(cu_code, mangledname)

test_common.build_kernel(context, cl_code, mangledname)
test_common.build_kernel(context, cl_code, mangledname[:31])


def test_no_pointer_struct_ointer(context):
Expand Down
4 changes: 2 additions & 2 deletions test/tf/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_cwise_sqrt(context, q, float_data, float_data_gpu):
scratch = workgroup_size * 4

print('running kernel...')
prog.__getattr__('_ZN5Eigen8internal15EigenMetaKernelINS_15TensorEvaluatorIKNS_14TensorAssignOpINS_9TensorMapINS_6TensorIfLi1ELi1EiEELi16ENS_11MakePointerEEEKNS_18TensorCwiseUnaryOpINS0_14scalar_sqrt_opIfEEKNS4_INS5_IKfLi1ELi1EiEELi16ES7_EEEEEENS_9GpuDeviceEEEiEEvT_T0_')(
prog.__getattr__('_ZN5Eigen8internal15EigenMetaKernelINS_15TensorEvaluatorIKNS_14TensorAssignOpINS_9TensorMapINS_6TensorIfLi1ELi1EiEELi16ENS_11MakePointerEEEKNS_18TensorCwiseUnaryOpINS0_14scalar_sqrt_opIfEEKNS4_INS5_IKfLi1ELi1EiEELi16ES7_EEEEEENS_9GpuDeviceEEEiEEvT_T0_'[:31])(
q, (global_size,), (workgroup_size,),
eval_nopointers_gpu,
eval_ptr0_gpu, offset_type(eval_ptr0_offset),
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_cwise_sqrt_singlebuffer(context, queue, float_data, float_data_gpu):
workgroup_size = 256
scratch = workgroup_size * 4

prog.__getattr__('_ZN5Eigen8internal15EigenMetaKernelINS_15TensorEvaluatorIKNS_14TensorAssignOpINS_9TensorMapINS_6TensorIfLi1ELi1EiEELi16ENS_11MakePointerEEEKNS_18TensorCwiseUnaryOpINS0_14scalar_sqrt_opIfEEKNS4_INS5_IKfLi1ELi1EiEELi16ES7_EEEEEENS_9GpuDeviceEEEiEEvT_T0_')(
prog.__getattr__('_ZN5Eigen8internal15EigenMetaKernelINS_15TensorEvaluatorIKNS_14TensorAssignOpINS_9TensorMapINS_6TensorIfLi1ELi1EiEELi16ENS_11MakePointerEEEKNS_18TensorCwiseUnaryOpINS0_14scalar_sqrt_opIfEEKNS4_INS5_IKfLi1ELi1EiEELi16ES7_EEEEEENS_9GpuDeviceEEEiEEvT_T0_'[:31])(
queue, (global_size,), (workgroup_size,),
eval_nopointers_gpu,
eval_ptr0_gpu, offset_type(eval_ptr0_offset),
Expand Down

0 comments on commit 4b3d457

Please sign in to comment.