Skip to content

Commit

Permalink
bind,internal/importers: add Unwrap methods to unwrap Java wrappers
Browse files Browse the repository at this point in the history
For Java classes implemented in Go, it is useful to take a Java instance
and extract its wrapped Go instance. For example, consider the
java.lang.Runnable implementation wrapping a Go function:

package somepkg

type GoRunnable struct {
    lang.Runnable
    f func()
}

Java methods that take a java.lang.Runnable cannot directly take a
*GoRunnable, so this CL adds a Unwrap method:

import gorun "Java/somepkg/GoRunnable"

...

r := gorun.New()
r.Unwrap().(*GoRunnable).f = func() { ... }
javapkg.Run(r)

The extra interface conversion is unfortunately needed to avoid
import cycles.

Change-Id: Ib775a5712cd25aa75a19d364a55d76b1e11dce77
Reviewed-on: https://go-review.googlesource.com/35295
Reviewed-by: David Crawshaw <crawshaw@golang.org>
  • Loading branch information
Elias Naur committed Jan 18, 2017
1 parent a0f998b commit c243211
Show file tree
Hide file tree
Showing 13 changed files with 142 additions and 61 deletions.
17 changes: 7 additions & 10 deletions bind/bind_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"log"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strings"
Expand Down Expand Up @@ -65,6 +66,10 @@ func fileRefs(t *testing.T, filename string, pkgPrefix string) *importers.Refere
if err != nil {
t.Fatalf("%s: %v", filename, err)
}
fakePath := path.Dir(filename)
for i := range refs.Embedders {
refs.Embedders[i].PkgPath = fakePath
}
return refs
}

Expand Down Expand Up @@ -299,11 +304,7 @@ func TestGenJava(t *testing.T) {
Buf: new(bytes.Buffer),
},
}
var genNames []string
for _, emb := range refs.Embedders {
genNames = append(genNames, emb.Pkg+"."+emb.Name)
}
cg.Init(classes, genNames)
cg.Init(classes, refs.Embedders)
genJavaPackages(t, tmpGopath, cg)
cg.Buf = &buf
}
Expand Down Expand Up @@ -419,11 +420,7 @@ func TestGenGoJavaWrappers(t *testing.T) {
Buf: &buf,
},
}
var genNames []string
for _, emb := range refs.Embedders {
genNames = append(genNames, emb.Pkg+"."+emb.Name)
}
cg.Init(classes, genNames)
cg.Init(classes, refs.Embedders)
genJavaPackages(t, tmpGopath, cg)
pkg := typeCheck(t, filename, tmpGopath)
cg.GenGo()
Expand Down
64 changes: 48 additions & 16 deletions bind/genclasses.go
Expand Up @@ -12,6 +12,7 @@ import (
"unicode"
"unicode/utf8"

"golang.org/x/mobile/internal/importers"
"golang.org/x/mobile/internal/importers/java"
)

Expand All @@ -25,6 +26,9 @@ type (
// will work.
ClassGen struct {
*Printer
// JavaPkg is the Java package prefix for the generated classes. The prefix is prepended to the Go
// package name to create the full Java package name.
JavaPkg string
imported map[string]struct{}
// The list of imported Java classes
classes []*java.Class
Expand All @@ -35,8 +39,12 @@ type (
// For each Go package path, the Java class with static functions
// or constants.
clsPkgs map[string]*java.Class
// supers is the map of classes that need Super methods
supers map[string]struct{}
// goClsMap is the map of Java class names to Go type names, qualified with package name. Go types
// that implement Java classes need Super methods and Unwrap methods.
goClsMap map[string]string
// goClsImports is the list of imports of user packages that contains the Go types implementing Java
// classes.
goClsImports []string
}
)

Expand Down Expand Up @@ -110,12 +118,22 @@ func (g *ClassGen) goType(t *java.Type, local bool) string {
}

// Init initializes the class wrapper generator. Classes is the
// list of classes to wrap, supers is the list of class names
// that need Super methods.
func (g *ClassGen) Init(classes []*java.Class, supers []string) {
g.supers = make(map[string]struct{})
for _, s := range supers {
g.supers[s] = struct{}{}
// list of classes to wrap, goClasses is the list of Java classes
// implemented in Go.
func (g *ClassGen) Init(classes []*java.Class, goClasses []importers.Struct) {
g.goClsMap = make(map[string]string)
impMap := make(map[string]struct{})
for _, s := range goClasses {
n := s.Pkg + "." + s.Name
jn := n
if g.JavaPkg != "" {
jn = g.JavaPkg + "." + jn
}
g.goClsMap[jn] = n
if _, exists := impMap[s.PkgPath]; !exists {
impMap[s.PkgPath] = struct{}{}
g.goClsImports = append(g.goClsImports, s.PkgPath)
}
}
g.classes = classes
g.imported = make(map[string]struct{})
Expand Down Expand Up @@ -194,6 +212,9 @@ func (g *ClassGen) GenGo() {
pkgName := strings.Replace(cls.Name, ".", "/", -1)
g.Printf("import %q\n", "Java/"+pkgName)
}
for _, imp := range g.goClsImports {
g.Printf("import %q\n", imp)
}
if len(g.classes) > 0 {
g.Printf("import \"unsafe\"\n\n")
g.Printf("import \"reflect\"\n\n")
Expand Down Expand Up @@ -235,7 +256,7 @@ func (g *ClassGen) GenH() {
g.Printf("extern ")
g.genCMethodDecl("cproxy", cls.JNIName, f)
g.Printf(";\n")
if _, ok := g.supers[cls.Name]; ok {
if _, ok := g.goClsMap[cls.Name]; ok {
g.Printf("extern ")
g.genCMethodDecl("csuper", cls.JNIName, f)
g.Printf(";\n")
Expand All @@ -252,7 +273,7 @@ func (g *ClassGen) GenC() {
g.Printf(classesCHeader)
for _, cls := range g.classes {
g.Printf("static jclass class_%s;\n", cls.JNIName)
if _, ok := g.supers[cls.Name]; ok {
if _, ok := g.goClsMap[cls.Name]; ok {
g.Printf("static jclass sclass_%s;\n", cls.JNIName)
}
for _, fs := range cls.Funcs {
Expand All @@ -267,7 +288,7 @@ func (g *ClassGen) GenC() {
for _, f := range fs.Funcs {
if g.isFuncSupported(f) {
g.Printf("static jmethodID m_%s_%s;\n", cls.JNIName, f.JNIName)
if _, ok := g.supers[cls.Name]; ok {
if _, ok := g.goClsMap[cls.Name]; ok {
g.Printf("static jmethodID sm_%s_%s;\n", cls.JNIName, f.JNIName)
}
}
Expand All @@ -283,7 +304,7 @@ func (g *ClassGen) GenC() {
for _, cls := range g.classes {
g.Printf("clazz = (*env)->FindClass(env, %q);\n", strings.Replace(cls.FindName, ".", "/", -1))
g.Printf("class_%s = (*env)->NewGlobalRef(env, clazz);\n", cls.JNIName)
if _, ok := g.supers[cls.Name]; ok {
if _, ok := g.goClsMap[cls.Name]; ok {
g.Printf("sclass_%s = (*env)->GetSuperclass(env, clazz);\n", cls.JNIName)
g.Printf("sclass_%s = (*env)->NewGlobalRef(env, sclass_%s);\n", cls.JNIName, cls.JNIName)
}
Expand All @@ -304,7 +325,7 @@ func (g *ClassGen) GenC() {
for _, f := range fs.Funcs {
if g.isFuncSupported(f) {
g.Printf("m_%s_%s = go_seq_get_method_id(clazz, %q, %q);\n", cls.JNIName, f.JNIName, f.Name, f.Desc)
if _, ok := g.supers[cls.Name]; ok {
if _, ok := g.goClsMap[cls.Name]; ok {
g.Printf("sm_%s_%s = go_seq_get_method_id(sclass_%s, %q, %q);\n", cls.JNIName, f.JNIName, cls.JNIName, f.Name, f.Desc)
}
}
Expand All @@ -322,7 +343,7 @@ func (g *ClassGen) GenC() {
}
g.genCMethodDecl("cproxy", cls.JNIName, f)
g.genCMethodBody(cls, f, false)
if _, ok := g.supers[cls.Name]; ok {
if _, ok := g.goClsMap[cls.Name]; ok {
g.genCMethodDecl("csuper", cls.JNIName, f)
g.genCMethodBody(cls, f, true)
}
Expand Down Expand Up @@ -561,11 +582,17 @@ func (g *ClassGen) genGo(cls *java.Class) {
g.Printf(" return p.ToString()\n")
g.Printf("}\n")
}
if _, ok := g.supers[cls.Name]; ok {
if goName, ok := g.goClsMap[cls.Name]; ok {
g.Printf("func (p *proxy_class_%s) Super() Java.%s {\n", cls.JNIName, goClsName(cls.Name))
g.Printf(" return &super_%s{p}\n", cls.JNIName)
g.Printf("}\n\n")
g.Printf("type super_%s struct {*proxy_class_%[1]s}\n\n", cls.JNIName)
g.Printf("func (p *proxy_class_%s) Unwrap() interface{} {\n", cls.JNIName)
g.Indent()
g.Printf("goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))\n")
g.Printf("return _seq.FromRefNum(int32(goRefnum)).Get().(*%s)\n", goName)
g.Outdent()
g.Printf("}\n\n")
for _, fs := range cls.AllMethods {
if !g.isFuncSetSupported(fs) {
continue
Expand Down Expand Up @@ -847,8 +874,13 @@ func (g *ClassGen) genInterface(cls *java.Class) {
g.genFuncDecl(true, fs)
g.Printf("\n")
}
if _, ok := g.supers[cls.Name]; ok {
if goName, ok := g.goClsMap[cls.Name]; ok {
g.Printf("Super() %s\n", goClsName(cls.Name))
g.Printf("// Unwrap returns the Go object this Java instance\n")
g.Printf("// is wrapping.\n")
g.Printf("// The return value is a %s, but the delclared type is\n", goName)
g.Printf("// interface{} to avoid import cycles.\n")
g.Printf("Unwrap() interface{}\n")
}
if cls.Throwable {
g.Printf("Error() string\n")
Expand Down
6 changes: 6 additions & 0 deletions bind/java/ClassesTest.java
Expand Up @@ -18,6 +18,7 @@
import javapkg.GoRunnable;
import javapkg.GoSubset;
import javapkg.GoInputStream;
import javapkg.GoArrayList;

public class ClassesTest extends InstrumentationTestCase {
public void testConst() {
Expand Down Expand Up @@ -148,4 +149,9 @@ public void testCast() {
Runnable r4c = Javapkg.castRunnable(new Object());
assertTrue("Invalid cast", r4c == null);
}

public void testUnwrap() {
GoArrayList l = new GoArrayList();
Javapkg.unwrapGoArrayList(l);
}
}
3 changes: 3 additions & 0 deletions bind/java/seq.h
Expand Up @@ -34,6 +34,9 @@ typedef jlong nint;

extern void go_seq_dec_ref(int32_t ref);
extern void go_seq_inc_ref(int32_t ref);
// go_seq_unwrap takes a reference number to a Java wrapper and returns
// a reference number to its wrapped Go object.
extern int32_t go_seq_unwrap(jint refnum);
extern int32_t go_seq_to_refnum(JNIEnv *env, jobject o);
extern int32_t go_seq_to_refnum_go(JNIEnv *env, jobject o);
extern jobject go_seq_from_refnum(JNIEnv *env, int32_t refnum, jclass proxy_class, jmethodID proxy_cons);
Expand Down
8 changes: 8 additions & 0 deletions bind/java/seq_android.c.support
Expand Up @@ -233,6 +233,14 @@ int32_t go_seq_to_refnum(JNIEnv *env, jobject o) {
return (int32_t)(*env)->CallStaticIntMethod(env, seq_class, seq_incRef, o);
}

int32_t go_seq_unwrap(jint refnum) {
JNIEnv *env = go_seq_push_local_frame(0);
jobject jobj = go_seq_from_refnum(env, refnum, NULL, NULL);
int32_t goref = go_seq_to_refnum_go(env, jobj);
go_seq_pop_local_frame(env);
return goref;
}

jobject go_seq_from_refnum(JNIEnv *env, int32_t refnum, jclass proxy_class, jmethodID proxy_cons) {
if (refnum == NULL_REFNUM) {
return NULL;
Expand Down
41 changes: 41 additions & 0 deletions bind/testdata/classes.go.golden
Expand Up @@ -475,22 +475,42 @@ type Java_lang_System interface {
type Java_Future interface {
Get(a0 ...interface{}) (Java_lang_Object, error)
Super() Java_Future
// Unwrap returns the Go object this Java instance
// is wrapping.
// The return value is a java.Future, but the delclared type is
// interface{} to avoid import cycles.
Unwrap() interface{}
}

type Java_InputStream interface {
Read(a0 ...interface{}) (int32, error)
ToString() string
Super() Java_InputStream
// Unwrap returns the Go object this Java instance
// is wrapping.
// The return value is a java.InputStream, but the delclared type is
// interface{} to avoid import cycles.
Unwrap() interface{}
}

type Java_Object interface {
ToString() string
Super() Java_Object
// Unwrap returns the Go object this Java instance
// is wrapping.
// The return value is a java.Object, but the delclared type is
// interface{} to avoid import cycles.
Unwrap() interface{}
}

type Java_Runnable interface {
Run()
Super() Java_Runnable
// Unwrap returns the Go object this Java instance
// is wrapping.
// The return value is a java.Runnable, but the delclared type is
// interface{} to avoid import cycles.
Unwrap() interface{}
}

type Java_util_Iterator interface {
Expand Down Expand Up @@ -559,6 +579,7 @@ import "Java/java/util/Spliterator/OfLong"
import "Java/java/util/PrimitiveIterator/OfDouble"
import "Java/java/util/Spliterator/OfDouble"
import "Java/java/io/Console"
import "testdata"
import "unsafe"

import "reflect"
Expand Down Expand Up @@ -1213,6 +1234,11 @@ func (p *proxy_class_java_Future) Super() Java.Java_Future {

type super_java_Future struct {*proxy_class_java_Future}

func (p *proxy_class_java_Future) Unwrap() interface{} {
goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
return _seq.FromRefNum(int32(goRefnum)).Get().(*java.Future)
}

func (p *super_java_Future) Get(a0 ...interface{}) (Java.Java_lang_Object, error) {
switch 0 + len(a0) {
case 0:
Expand Down Expand Up @@ -1374,6 +1400,11 @@ func (p *proxy_class_java_InputStream) Super() Java.Java_InputStream {

type super_java_InputStream struct {*proxy_class_java_InputStream}

func (p *proxy_class_java_InputStream) Unwrap() interface{} {
goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
return _seq.FromRefNum(int32(goRefnum)).Get().(*java.InputStream)
}

func (p *super_java_InputStream) Read(a0 ...interface{}) (int32, error) {
switch 0 + len(a0) {
case 0:
Expand Down Expand Up @@ -1494,6 +1525,11 @@ func (p *proxy_class_java_Object) Super() Java.Java_Object {

type super_java_Object struct {*proxy_class_java_Object}

func (p *proxy_class_java_Object) Unwrap() interface{} {
goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
return _seq.FromRefNum(int32(goRefnum)).Get().(*java.Object)
}

func (p *super_java_Object) ToString() string {
res := C.csuper_java_Object_toString(C.jint(p.Bind_proxy_refnum__()))
_res := decodeString(res.res)
Expand Down Expand Up @@ -1555,6 +1591,11 @@ func (p *proxy_class_java_Runnable) Super() Java.Java_Runnable {

type super_java_Runnable struct {*proxy_class_java_Runnable}

func (p *proxy_class_java_Runnable) Unwrap() interface{} {
goRefnum := C.go_seq_unwrap(C.jint(p.Bind_proxy_refnum__()))
return _seq.FromRefNum(int32(goRefnum)).Get().(*java.Runnable)
}

func (p *super_java_Runnable) Run() {
res := C.csuper_java_Runnable_run(C.jint(p.Bind_proxy_refnum__()))
var _exc error
Expand Down
4 changes: 4 additions & 0 deletions bind/testpkg/javapkg/classes.go
Expand Up @@ -138,6 +138,10 @@ func NewGoArrayListWithCap(_ int32) *GoArrayList {
return new(GoArrayList)
}

func UnwrapGoArrayList(l gopkg.GoArrayList) {
_ = l.Unwrap().(*GoArrayList)
}

func CallSubset(s Character.Subset) {
s.ToString()
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/gobind/gen.go
Expand Up @@ -20,6 +20,7 @@ import (
"unicode/utf8"

"golang.org/x/mobile/bind"
"golang.org/x/mobile/internal/importers"
"golang.org/x/mobile/internal/importers/java"
)

Expand Down Expand Up @@ -140,15 +141,16 @@ func genPkg(p *types.Package, allPkg []*types.Package, classes []*java.Class) {
}
}

func genJavaPackages(ctx *build.Context, dir string, classes []*java.Class, genNames []string) error {
func genJavaPackages(ctx *build.Context, dir string, classes []*java.Class, embedders []importers.Struct) error {
var buf bytes.Buffer
cg := &bind.ClassGen{
JavaPkg: *javaPkg,
Printer: &bind.Printer{
IndentEach: []byte("\t"),
Buf: &buf,
},
}
cg.Init(classes, genNames)
cg.Init(classes, embedders)
pkgBase := filepath.Join(dir, "src", "Java")
if err := os.MkdirAll(pkgBase, 0700); err != nil {
return err
Expand Down

0 comments on commit c243211

Please sign in to comment.