Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled extension types for modules #132

Merged
merged 6 commits into from Feb 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion ext/java/org/msgpack/jruby/Encoder.java
Expand Up @@ -7,6 +7,7 @@

import org.jruby.Ruby;
import org.jruby.RubyObject;
import org.jruby.RubyModule;
import org.jruby.RubyNil;
import org.jruby.RubyBoolean;
import org.jruby.RubyNumeric;
Expand Down Expand Up @@ -375,7 +376,15 @@ private void appendExtensionValue(ExtensionValue object) {

private void appendOther(IRubyObject object, IRubyObject destination) {
if (registry != null) {
IRubyObject[] pair = registry.lookupPackerByClass(object.getType());
RubyModule lookupClass;

if (object.getType() == runtime.getSymbol()) {
lookupClass = object.getType();
} else {
lookupClass = object.getSingletonClass();
}

IRubyObject[] pair = registry.lookupPackerByModule(lookupClass);
if (pair != null) {
RubyString bytes = pair[0].callMethod(runtime.getCurrentContext(), "call", object).asString();
int type = (int) ((RubyFixnum) pair[1]).getLongValue();
Expand Down
66 changes: 33 additions & 33 deletions ext/java/org/msgpack/jruby/ExtensionRegistry.java
Expand Up @@ -3,7 +3,7 @@
import org.jruby.Ruby;
import org.jruby.RubyHash;
import org.jruby.RubyArray;
import org.jruby.RubyClass;
import org.jruby.RubyModule;
import org.jruby.RubyFixnum;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
Expand All @@ -12,35 +12,35 @@
import java.util.HashMap;

public class ExtensionRegistry {
private final Map<RubyClass, ExtensionEntry> extensionsByClass;
private final Map<RubyClass, ExtensionEntry> extensionsByAncestor;
private final Map<RubyModule, ExtensionEntry> extensionsByModule;
private final Map<RubyModule, ExtensionEntry> extensionsByAncestor;
private final ExtensionEntry[] extensionsByTypeId;

public ExtensionRegistry() {
this(new HashMap<RubyClass, ExtensionEntry>());
this(new HashMap<RubyModule, ExtensionEntry>());
}

private ExtensionRegistry(Map<RubyClass, ExtensionEntry> extensionsByClass) {
this.extensionsByClass = new HashMap<RubyClass, ExtensionEntry>(extensionsByClass);
this.extensionsByAncestor = new HashMap<RubyClass, ExtensionEntry>();
private ExtensionRegistry(Map<RubyModule, ExtensionEntry> extensionsByModule) {
this.extensionsByModule = new HashMap<RubyModule, ExtensionEntry>(extensionsByModule);
this.extensionsByAncestor = new HashMap<RubyModule, ExtensionEntry>();
this.extensionsByTypeId = new ExtensionEntry[256];
for (ExtensionEntry entry : extensionsByClass.values()) {
for (ExtensionEntry entry : extensionsByModule.values()) {
if (entry.hasUnpacker()) {
extensionsByTypeId[entry.getTypeId() + 128] = entry;
}
}
}

public ExtensionRegistry dup() {
return new ExtensionRegistry(extensionsByClass);
return new ExtensionRegistry(extensionsByModule);
}

public IRubyObject toInternalPackerRegistry(ThreadContext ctx) {
RubyHash hash = RubyHash.newHash(ctx.getRuntime());
for (RubyClass extensionClass : extensionsByClass.keySet()) {
ExtensionEntry entry = extensionsByClass.get(extensionClass);
for (RubyModule extensionModule : extensionsByModule.keySet()) {
ExtensionEntry entry = extensionsByModule.get(extensionModule);
if (entry.hasPacker()) {
hash.put(extensionClass, entry.toPackerTuple(ctx));
hash.put(extensionModule, entry.toPackerTuple(ctx));
}
}
return hash;
Expand All @@ -58,9 +58,9 @@ public IRubyObject toInternalUnpackerRegistry(ThreadContext ctx) {
return hash;
}

public void put(RubyClass cls, int typeId, IRubyObject packerProc, IRubyObject packerArg, IRubyObject unpackerProc, IRubyObject unpackerArg) {
ExtensionEntry entry = new ExtensionEntry(cls, typeId, packerProc, packerArg, unpackerProc, unpackerArg);
extensionsByClass.put(cls, entry);
public void put(RubyModule mod, int typeId, IRubyObject packerProc, IRubyObject packerArg, IRubyObject unpackerProc, IRubyObject unpackerArg) {
ExtensionEntry entry = new ExtensionEntry(mod, typeId, packerProc, packerArg, unpackerProc, unpackerArg);
extensionsByModule.put(mod, entry);
extensionsByTypeId[typeId + 128] = entry;
extensionsByAncestor.clear();
}
Expand All @@ -74,54 +74,54 @@ public IRubyObject lookupUnpackerByTypeId(int typeId) {
}
}

public IRubyObject[] lookupPackerByClass(RubyClass cls) {
ExtensionEntry e = extensionsByClass.get(cls);
public IRubyObject[] lookupPackerByModule(RubyModule mod) {
ExtensionEntry e = extensionsByModule.get(mod);
if (e == null) {
e = extensionsByAncestor.get(cls);
e = extensionsByAncestor.get(mod);
}
if (e == null) {
e = findEntryByClassOrAncestor(cls);
e = findEntryByModuleOrAncestor(mod);
if (e != null) {
extensionsByAncestor.put(e.getExtensionClass(), e);
extensionsByAncestor.put(e.getExtensionModule(), e);
}
}
if (e != null && e.hasPacker()) {
return e.toPackerProcTypeIdPair(cls.getRuntime().getCurrentContext());
return e.toPackerProcTypeIdPair(mod.getRuntime().getCurrentContext());
} else {
return null;
}
}

private ExtensionEntry findEntryByClassOrAncestor(final RubyClass cls) {
ThreadContext ctx = cls.getRuntime().getCurrentContext();
for (RubyClass extensionClass : extensionsByClass.keySet()) {
RubyArray ancestors = (RubyArray) cls.callMethod(ctx, "ancestors");
if (ancestors.callMethod(ctx, "include?", extensionClass).isTrue()) {
return extensionsByClass.get(extensionClass);
private ExtensionEntry findEntryByModuleOrAncestor(final RubyModule mod) {
ThreadContext ctx = mod.getRuntime().getCurrentContext();
for (RubyModule extensionModule : extensionsByModule.keySet()) {
RubyArray ancestors = (RubyArray) mod.callMethod(ctx, "ancestors");
if (ancestors.callMethod(ctx, "include?", extensionModule).isTrue()) {
return extensionsByModule.get(extensionModule);
}
}
return null;
}

private static class ExtensionEntry {
private final RubyClass cls;
private final RubyModule mod;
private final int typeId;
private final IRubyObject packerProc;
private final IRubyObject packerArg;
private final IRubyObject unpackerProc;
private final IRubyObject unpackerArg;

public ExtensionEntry(RubyClass cls, int typeId, IRubyObject packerProc, IRubyObject packerArg, IRubyObject unpackerProc, IRubyObject unpackerArg) {
this.cls = cls;
public ExtensionEntry(RubyModule mod, int typeId, IRubyObject packerProc, IRubyObject packerArg, IRubyObject unpackerProc, IRubyObject unpackerArg) {
this.mod = mod;
this.typeId = typeId;
this.packerProc = packerProc;
this.packerArg = packerArg;
this.unpackerProc = unpackerProc;
this.unpackerArg = unpackerArg;
}

public RubyClass getExtensionClass() {
return cls;
public RubyModule getExtensionModule() {
return mod;
}

public int getTypeId() {
Expand Down Expand Up @@ -149,7 +149,7 @@ public RubyArray toPackerTuple(ThreadContext ctx) {
}

public RubyArray toUnpackerTuple(ThreadContext ctx) {
return RubyArray.newArray(ctx.getRuntime(), new IRubyObject[] {cls, unpackerProc, unpackerArg});
return RubyArray.newArray(ctx.getRuntime(), new IRubyObject[] {mod, unpackerProc, unpackerArg});
}

public IRubyObject[] toPackerProcTypeIdPair(ThreadContext ctx) {
Expand Down
15 changes: 8 additions & 7 deletions ext/java/org/msgpack/jruby/Factory.java
Expand Up @@ -2,6 +2,7 @@


import org.jruby.Ruby;
import org.jruby.RubyModule;
import org.jruby.RubyClass;
import org.jruby.RubyObject;
import org.jruby.RubyArray;
Expand Down Expand Up @@ -70,7 +71,7 @@ public IRubyObject registeredTypesInternal(ThreadContext ctx) {
public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args) {
Ruby runtime = ctx.getRuntime();
IRubyObject type = args[0];
IRubyObject klass = args[1];
IRubyObject mod = args[1];

IRubyObject packerArg;
IRubyObject unpackerArg;
Expand All @@ -94,10 +95,10 @@ public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args) {
throw runtime.newRangeError(String.format("integer %d too big to convert to `signed char'", typeId));
}

if (!(klass instanceof RubyClass)) {
throw runtime.newArgumentError(String.format("expected Class but found %s.", klass.getType().getName()));
if (!(mod instanceof RubyModule)) {
throw runtime.newArgumentError(String.format("expected Module/Class but found %s.", mod.getType().getName()));
}
RubyClass extClass = (RubyClass) klass;
RubyModule extModule = (RubyModule) mod;

IRubyObject packerProc = runtime.getNil();
IRubyObject unpackerProc = runtime.getNil();
Expand All @@ -106,15 +107,15 @@ public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args) {
}
if (unpackerArg != null) {
if (unpackerArg instanceof RubyString || unpackerArg instanceof RubySymbol) {
unpackerProc = extClass.method(unpackerArg.callMethod(ctx, "to_sym"));
unpackerProc = extModule.method(unpackerArg.callMethod(ctx, "to_sym"));
} else {
unpackerProc = unpackerArg.callMethod(ctx, "method", runtime.newSymbol("call"));
}
}

extensionRegistry.put(extClass, (int) typeId, packerProc, packerArg, unpackerProc, unpackerArg);
extensionRegistry.put(extModule, (int) typeId, packerProc, packerArg, unpackerProc, unpackerArg);

if (extClass == runtime.getSymbol()) {
if (extModule == runtime.getSymbol()) {
hasSymbolExtType = true;
}

Expand Down
13 changes: 7 additions & 6 deletions ext/java/org/msgpack/jruby/Packer.java
Expand Up @@ -2,6 +2,7 @@


import org.jruby.Ruby;
import org.jruby.RubyModule;
import org.jruby.RubyClass;
import org.jruby.RubyObject;
import org.jruby.RubyArray;
Expand Down Expand Up @@ -78,7 +79,7 @@ public IRubyObject registeredTypesInternal(ThreadContext ctx) {
public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args, final Block block) {
Ruby runtime = ctx.getRuntime();
IRubyObject type = args[0];
IRubyObject klass = args[1];
IRubyObject mod = args[1];

IRubyObject arg;
IRubyObject proc;
Expand All @@ -100,14 +101,14 @@ public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args, final Blo
throw runtime.newRangeError(String.format("integer %d too big to convert to `signed char'", typeId));
}

if (!(klass instanceof RubyClass)) {
throw runtime.newArgumentError(String.format("expected Class but found %s.", klass.getType().getName()));
if (!(mod instanceof RubyModule)) {
throw runtime.newArgumentError(String.format("expected Module/Class but found %s.", mod.getType().getName()));
}
RubyClass extClass = (RubyClass) klass;
RubyModule extModule = (RubyModule) mod;

registry.put(extClass, (int) typeId, proc, arg, null, null);
registry.put(extModule, (int) typeId, proc, arg, null, null);

if (extClass == runtime.getSymbol()) {
if (extModule == runtime.getSymbol()) {
encoder.hasSymbolExtType = true;
}

Expand Down
11 changes: 6 additions & 5 deletions ext/java/org/msgpack/jruby/Unpacker.java
Expand Up @@ -3,6 +3,7 @@
import java.util.Arrays;

import org.jruby.Ruby;
import org.jruby.RubyModule;
import org.jruby.RubyClass;
import org.jruby.RubyString;
import org.jruby.RubyObject;
Expand Down Expand Up @@ -100,7 +101,7 @@ public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args, final Blo
Ruby runtime = ctx.getRuntime();
IRubyObject type = args[0];

RubyClass extClass;
RubyModule extModule;
IRubyObject arg;
IRubyObject proc;
if (args.length == 1) {
Expand All @@ -111,11 +112,11 @@ public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args, final Blo
if (proc == null)
System.err.println("proc from Block is null");
arg = proc;
extClass = null;
extModule = null;
} else if (args.length == 3) {
extClass = (RubyClass) args[1];
extModule = (RubyModule) args[1];
arg = args[2];
proc = extClass.method(arg);
proc = extModule.method(arg);
} else {
throw runtime.newArgumentError(String.format("wrong number of arguments (%d for 1 or 3)", 2 + args.length));
}
Expand All @@ -125,7 +126,7 @@ public IRubyObject registerType(ThreadContext ctx, IRubyObject[] args, final Blo
throw runtime.newRangeError(String.format("integer %d too big to convert to `signed char'", typeId));
}

registry.put(extClass, (int) typeId, null, null, proc, arg);
registry.put(extModule, (int) typeId, null, null, proc, arg);
return runtime.getNil();
}

Expand Down
16 changes: 8 additions & 8 deletions ext/msgpack/factory_class.c
Expand Up @@ -141,7 +141,7 @@ static VALUE Factory_register_type(int argc, VALUE* argv, VALUE self)
FACTORY(self, fc);

int ext_type;
VALUE ext_class;
VALUE ext_module;
VALUE options;
VALUE packer_arg, unpacker_arg;
VALUE packer_proc, unpacker_proc;
Expand Down Expand Up @@ -170,9 +170,9 @@ static VALUE Factory_register_type(int argc, VALUE* argv, VALUE self)
rb_raise(rb_eRangeError, "integer %d too big to convert to `signed char'", ext_type);
}

ext_class = argv[1];
if(rb_type(ext_class) != T_CLASS) {
rb_raise(rb_eArgError, "expected Class but found %s.", rb_obj_classname(ext_class));
ext_module = argv[1];
if(rb_type(ext_module) != T_MODULE && rb_type(ext_module) != T_CLASS) {
rb_raise(rb_eArgError, "expected Module/Class but found %s.", rb_obj_classname(ext_module));
}

packer_proc = Qnil;
Expand All @@ -184,19 +184,19 @@ static VALUE Factory_register_type(int argc, VALUE* argv, VALUE self)

if(unpacker_arg != Qnil) {
if(rb_type(unpacker_arg) == T_SYMBOL || rb_type(unpacker_arg) == T_STRING) {
unpacker_proc = rb_obj_method(ext_class, unpacker_arg);
unpacker_proc = rb_obj_method(ext_module, unpacker_arg);
} else {
unpacker_proc = rb_funcall(unpacker_arg, rb_intern("method"), 1, ID2SYM(rb_intern("call")));
}
}

msgpack_packer_ext_registry_put(&fc->pkrg, ext_class, ext_type, packer_proc, packer_arg);
msgpack_packer_ext_registry_put(&fc->pkrg, ext_module, ext_type, packer_proc, packer_arg);

if (ext_class == rb_cSymbol) {
if (ext_module == rb_cSymbol) {
fc->has_symbol_ext_type = true;
}

msgpack_unpacker_ext_registry_put(&fc->ukrg, ext_class, ext_type, unpacker_proc, unpacker_arg);
msgpack_unpacker_ext_registry_put(&fc->ukrg, ext_module, ext_type, unpacker_proc, unpacker_arg);

return Qnil;
}
Expand Down
23 changes: 21 additions & 2 deletions ext/msgpack/packer.c
Expand Up @@ -124,8 +124,27 @@ void msgpack_packer_write_hash_value(msgpack_packer_t* pk, VALUE v)
void msgpack_packer_write_other_value(msgpack_packer_t* pk, VALUE v)
{
int ext_type;
VALUE proc = msgpack_packer_ext_registry_lookup(&pk->ext_registry,
rb_obj_class(v), &ext_type);

VALUE lookup_class;

/*
* Objects of type Integer (Fixnum, Bignum), Float, Symbol and frozen
* String have no singleton class and raise a TypeError when trying to get
* it. See implementation of #singleton_class in ruby's source code:
* VALUE rb_singleton_class(VALUE obj);
*
* Since all but symbols are already filtered out when reaching this code
* only symbols are checked here.
*/
if (SYMBOL_P(v)) {
lookup_class = rb_obj_class(v);
} else {
lookup_class = rb_singleton_class(v);
}

VALUE proc = msgpack_packer_ext_registry_lookup(&pk->ext_registry, lookup_class,
&ext_type);

if(proc != Qnil) {
VALUE payload = rb_funcall(proc, s_call, 1, v);
StringValue(payload);
Expand Down