Permalink
Browse files

Implement struct-by-value parameter and return type support for FFI c…

…allbacks.
  • Loading branch information...
1 parent d48c37b commit eff3e427ba9e260df2d23f18fd077d5c74a06022 Wayne Meissner committed with headius Apr 24, 2010
Showing with 134 additions and 62 deletions.
  1. +134 −62 src/org/jruby/ext/ffi/jffi/CallbackManager.java
@@ -17,15 +17,21 @@
import org.jruby.RubyProc;
import org.jruby.anno.JRubyClass;
import org.jruby.ext.ffi.AllocatedDirectMemoryIO;
+import org.jruby.ext.ffi.ArrayMemoryIO;
import org.jruby.ext.ffi.CallbackInfo;
+import org.jruby.ext.ffi.DirectMemoryIO;
import org.jruby.ext.ffi.InvalidMemoryIO;
import org.jruby.ext.ffi.MemoryIO;
+import org.jruby.ext.ffi.NullMemoryIO;
import org.jruby.ext.ffi.Platform;
import org.jruby.ext.ffi.Pointer;
+import org.jruby.ext.ffi.Struct;
+import org.jruby.ext.ffi.StructByValue;
import org.jruby.ext.ffi.Type;
import org.jruby.ext.ffi.Util;
import org.jruby.runtime.Block;
import org.jruby.runtime.ObjectAllocator;
+import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
@@ -178,17 +184,18 @@ public ClosureInfo(Ruby runtime, Type returnType, Type[] paramTypes, CallingConv
com.kenai.jffi.Type[] ffiParameterTypes = new com.kenai.jffi.Type[paramTypes.length];
for (int i = 0; i < paramTypes.length; ++i) {
- if (!isParameterTypeValid(paramTypes[i])) {
- throw runtime.newArgumentError("Invalid callback parameter type: " + paramTypes[i]);
+ if (!isParameterTypeValid(paramTypes[i]) || (ffiParameterTypes[i] = FFIUtil.getFFIType(paramTypes[i])) == null) {
+ throw runtime.newTypeError("invalid callback parameter type: " + paramTypes[i]);
}
- ffiParameterTypes[i] = FFIUtil.getFFIType(paramTypes[i].getNativeType());
}
- if (!isReturnTypeValid(returnType)) {
- runtime.newArgumentError("Invalid callback return type: " + returnType);
+ com.kenai.jffi.Type ffiReturnType = null;
+ if (!isReturnTypeValid(returnType) || (ffiReturnType = FFIUtil.getFFIType(returnType)) == null) {
+ runtime.newTypeError("invalid callback return type: " + returnType);
}
- this.callContext = new com.kenai.jffi.CallContext(FFIUtil.getFFIType(returnType.getNativeType()), ffiParameterTypes, convention);
+
+ this.callContext = new com.kenai.jffi.CallContext(ffiReturnType, ffiParameterTypes, convention);
this.returnType = returnType;
this.parameterTypes = (Type[]) paramTypes.clone();
}
@@ -228,17 +235,20 @@ void dispose() {
}
protected final void invoke(Closure.Buffer buffer, Object recv) {
+ ThreadContext context = runtime.getCurrentContext();
+
IRubyObject[] params = new IRubyObject[closureInfo.parameterTypes.length];
for (int i = 0; i < params.length; ++i) {
params[i] = fromNative(runtime, closureInfo.parameterTypes[i], buffer, i);
}
+
IRubyObject retVal;
if (recv instanceof RubyProc) {
- retVal = ((RubyProc) recv).call(runtime.getCurrentContext(), params);
+ retVal = ((RubyProc) recv).call(context, params);
} else if (recv instanceof Block) {
- retVal = ((Block) recv).call(runtime.getCurrentContext(), params);
+ retVal = ((Block) recv).call(context, params);
} else {
- retVal = ((IRubyObject) recv).callMethod(runtime.getCurrentContext(), "call", params);
+ retVal = ((IRubyObject) recv).callMethod(context, "call", params);
}
setReturnValue(runtime, closureInfo.returnType, buffer, retVal);
@@ -421,9 +431,46 @@ private static final void setReturnValue(Ruby runtime, Type type,
buffer.setAddressReturn(addressValue(cb));
} else {
buffer.setAddressReturn(0L);
+ throw runtime.newTypeError("invalid callback return value, expected Proc or callable object");
}
+
+ } else if (type instanceof StructByValue) {
+
+ if (value instanceof Struct) {
+ Struct s = (Struct) value;
+ MemoryIO memory = s.getMemory().getMemoryIO();
+
+ if (memory instanceof DirectMemoryIO) {
+ long address = ((DirectMemoryIO) memory).getAddress();
+ if (address != 0) {
+ buffer.setStructReturn(address);
+ } else {
+ // Zero it out
+ buffer.setStructReturn(new byte[type.getNativeSize()], 0);
+ }
+
+ } else if (memory instanceof ArrayMemoryIO) {
+ ArrayMemoryIO arrayMemory = (ArrayMemoryIO) memory;
+ if (arrayMemory.arrayLength() < type.getNativeSize()) {
+ throw runtime.newRuntimeError("size of struct returned from callback too small");
+ }
+
+ buffer.setStructReturn(arrayMemory.array(), arrayMemory.arrayOffset());
+
+ } else {
+ throw runtime.newRuntimeError("struct return value has illegal backing memory");
+ }
+ } else if (value.isNil()) {
+ // Zero it out
+ buffer.setStructReturn(new byte[type.getNativeSize()], 0);
+
+ } else {
+ throw runtime.newTypeError(value, runtime.fastGetModule("FFI").fastGetClass("Struct"));
+ }
+
} else {
buffer.setLongReturn(0L);
+ throw runtime.newRuntimeError("unsupported return type from struct: " + type);
}
}
@@ -438,63 +485,81 @@ private static final void setReturnValue(Ruby runtime, Type type,
*/
private static final IRubyObject fromNative(Ruby runtime, Type type,
Closure.Buffer buffer, int index) {
- switch (type.getNativeType()) {
- case VOID:
- return runtime.getNil();
- case CHAR:
- return Util.newSigned8(runtime, buffer.getByte(index));
- case UCHAR:
- return Util.newUnsigned8(runtime, buffer.getByte(index));
- case SHORT:
- return Util.newSigned16(runtime, buffer.getShort(index));
- case USHORT:
- return Util.newUnsigned16(runtime, buffer.getShort(index));
- case INT:
- return Util.newSigned32(runtime, buffer.getInt(index));
- case UINT:
- return Util.newUnsigned32(runtime, buffer.getInt(index));
- case LONG_LONG:
- return Util.newSigned64(runtime, buffer.getLong(index));
- case ULONG_LONG:
- return Util.newUnsigned64(runtime, buffer.getLong(index));
-
- case LONG:
- return LONG_SIZE == 32
- ? Util.newSigned32(runtime, buffer.getInt(index))
- : Util.newSigned64(runtime, buffer.getLong(index));
- case ULONG:
- return LONG_SIZE == 32
- ? Util.newUnsigned32(runtime, buffer.getInt(index))
- : Util.newUnsigned64(runtime, buffer.getLong(index));
-
- case FLOAT:
- return runtime.newFloat(buffer.getFloat(index));
- case DOUBLE:
- return runtime.newFloat(buffer.getDouble(index));
- case POINTER: {
- final long address = buffer.getAddress(index);
- if (type instanceof CallbackInfo) {
- CallbackInfo cbInfo = (CallbackInfo) type;
- if (address != 0) {
- return new JFFIInvoker(runtime, address,
- cbInfo.getReturnType(), cbInfo.getParameterTypes(),
- cbInfo.isStdcall() ? CallingConvention.STDCALL : CallingConvention.DEFAULT);
- } else {
- return runtime.getNil();
- }
- } else {
- return new Pointer(runtime, NativeMemoryIO.wrap(runtime, address));
- }
+ if (type instanceof Type.Builtin) {
+ switch (type.getNativeType()) {
+ case VOID:
+ return runtime.getNil();
+ case CHAR:
+ return Util.newSigned8(runtime, buffer.getByte(index));
+ case UCHAR:
+ return Util.newUnsigned8(runtime, buffer.getByte(index));
+ case SHORT:
+ return Util.newSigned16(runtime, buffer.getShort(index));
+ case USHORT:
+ return Util.newUnsigned16(runtime, buffer.getShort(index));
+ case INT:
+ return Util.newSigned32(runtime, buffer.getInt(index));
+ case UINT:
+ return Util.newUnsigned32(runtime, buffer.getInt(index));
+ case LONG_LONG:
+ return Util.newSigned64(runtime, buffer.getLong(index));
+ case ULONG_LONG:
+ return Util.newUnsigned64(runtime, buffer.getLong(index));
+
+ case LONG:
+ return LONG_SIZE == 32
+ ? Util.newSigned32(runtime, buffer.getInt(index))
+ : Util.newSigned64(runtime, buffer.getLong(index));
+ case ULONG:
+ return LONG_SIZE == 32
+ ? Util.newUnsigned32(runtime, buffer.getInt(index))
+ : Util.newUnsigned64(runtime, buffer.getLong(index));
+
+ case FLOAT:
+ return runtime.newFloat(buffer.getFloat(index));
+ case DOUBLE:
+ return runtime.newFloat(buffer.getDouble(index));
+
+ case POINTER:
+ return new Pointer(runtime, NativeMemoryIO.wrap(runtime, buffer.getAddress(index)));
+
+ case STRING:
+ return getStringParameter(runtime, buffer, index);
+
+ case BOOL:
+ return runtime.newBoolean(buffer.getByte(index) != 0);
+
+ default:
+ throw runtime.newTypeError("invalid callback parameter type " + type);
}
- case STRING:
- return getStringParameter(runtime, buffer, index);
+
+ } else if (type instanceof CallbackInfo) {
+ final CallbackInfo cbInfo = (CallbackInfo) type;
+ final long address = buffer.getAddress(index);
+
+ return address != 0
+ ? new Function(runtime, cbInfo.getMetaClass(),
+ new CodeMemoryIO(runtime, address),
+ cbInfo.getReturnType(), cbInfo.getParameterTypes(),
+ cbInfo.isStdcall() ? CallingConvention.STDCALL : CallingConvention.DEFAULT, runtime.getNil())
+
+ : runtime.getNil();
+
+ } else if (type instanceof StructByValue) {
+ StructByValue sbv = (StructByValue) type;
+ final long address = buffer.getStruct(index);
+ DirectMemoryIO memory = address != 0
+ ? new BoundedNativeMemoryIO(runtime, address, type.getNativeSize())
+ : new NullMemoryIO(runtime);
- case BOOL:
- return runtime.newBoolean(buffer.getByte(index) != 0);
+ return sbv.getStructClass().newInstance(runtime.getCurrentContext(),
+ new IRubyObject[] { new Pointer(runtime, memory) },
+ Block.NULL_BLOCK);
- default:
- throw new IllegalArgumentException("Invalid type " + type);
+ } else {
+ throw runtime.newTypeError("unsupported callback parameter type: " + type);
}
+
}
/**
@@ -535,8 +600,12 @@ private static final boolean isReturnTypeValid(Type type) {
case BOOL:
return true;
}
+
} else if (type instanceof CallbackInfo) {
return true;
+
+ } else if (type instanceof StructByValue) {
+ return true;
}
return false;
}
@@ -569,6 +638,9 @@ private static final boolean isParameterTypeValid(Type type) {
}
} else if (type instanceof CallbackInfo) {
return true;
+
+ } else if (type instanceof StructByValue) {
+ return true;
}
return false;
}

0 comments on commit eff3e42

Please sign in to comment.