Skip to content

Commit

Permalink
Merge 364ff89 into 390aa2b
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Sep 20, 2018
2 parents 390aa2b + 364ff89 commit 178b1aa
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 21 deletions.
7 changes: 4 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func InjectPry(filePath string) (string, error) {
}

// GenerateFile generates and executes a temp file with the given imports
func GenerateFile(imports []string) error {
func GenerateFile(imports []string, extraStatements string) error {
dir, err := ioutil.TempDir("", "pry")
if err != nil {
return err
Expand All @@ -159,7 +159,7 @@ func GenerateFile(imports []string) error {
for _, imp := range imports {
file += fmt.Sprintf("\t%#v\n", imp)
}
file += ")\nfunc main() {\n\tpry.Pry()\n}"
file += ")\nfunc main() {\n\t" + extraStatements + "\n\tpry.Pry()\n}\n"

newPath := dir + "/main.go"
ioutil.WriteFile(newPath, []byte(file), 0644)
Expand All @@ -186,6 +186,7 @@ func main() {
// FLAGS
imports := flag.String("i", "fmt,math", "packages to import, comma seperated")
revert := flag.Bool("r", true, "whether to revert changes on exit")
execute := flag.String("e", "", "statements to execute")
flag.BoolVar(&debug, "d", false, "display debug statements")

flag.CommandLine.Usage = func() {
Expand All @@ -201,7 +202,7 @@ func main() {
flag.Parse()
cmdArgs := flag.Args()
if len(cmdArgs) == 0 {
err := GenerateFile(strings.Split(*imports, ","))
err := GenerateFile(strings.Split(*imports, ","), *execute)
if err != nil {
panic(err)
}
Expand Down
41 changes: 27 additions & 14 deletions pry/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,78 +7,91 @@ import (
"github.com/pkg/errors"
)

// InterpretError is an error returned by the interpreter and shouldn't be
// passed to the user or running code.
type InterpretError struct {
err error
}

func (a *InterpretError) Error() error {
if a == nil {
return nil
}
return a.err
}

// Append is a runtime replacement for the append function
func Append(arr interface{}, elems ...interface{}) (interface{}, error) {
func Append(arr interface{}, elems ...interface{}) (interface{}, *InterpretError) {
arrVal := reflect.ValueOf(arr)
valArr := make([]reflect.Value, len(elems))
for i, elem := range elems {
if reflect.TypeOf(arr) != reflect.SliceOf(reflect.TypeOf(elem)) {
return nil, fmt.Errorf("%T cannot append to %T", elem, arr)
return nil, &InterpretError{fmt.Errorf("%T cannot append to %T", elem, arr)}
}
valArr[i] = reflect.ValueOf(elem)
}
return reflect.Append(arrVal, valArr...).Interface(), nil
}

// Make is a runtime replacement for the make function
func Make(t interface{}, args ...interface{}) (interface{}, error) {
func Make(t interface{}, args ...interface{}) (interface{}, *InterpretError) {
typ, isType := t.(reflect.Type)
if !isType {
return nil, fmt.Errorf("invalid type %#v", t)
return nil, &InterpretError{fmt.Errorf("invalid type %#v", t)}
}
switch typ.Kind() {
case reflect.Slice:
if len(args) < 1 || len(args) > 2 {
return nil, errors.New("invalid number of arguments. Missing len or extra?")
return nil, &InterpretError{errors.New("invalid number of arguments. Missing len or extra?")}
}
length, isInt := args[0].(int)
if !isInt {
return nil, errors.New("len is not int")
return nil, &InterpretError{errors.New("len is not int")}
}
capacity := length
if len(args) == 2 {
capacity, isInt = args[0].(int)
if !isInt {
return nil, errors.New("len is not int")
return nil, &InterpretError{errors.New("len is not int")}
}
}
if length < 0 || capacity < 0 {
return nil, errors.Errorf("negative length or capacity")
return nil, &InterpretError{errors.Errorf("negative length or capacity")}
}
slice := reflect.MakeSlice(typ, length, capacity)
return slice.Interface(), nil

case reflect.Chan:
if len(args) > 1 {
fmt.Printf("CHAN ARGS %#v", args)
return nil, errors.New("too many arguments")
return nil, &InterpretError{errors.New("too many arguments")}
}
size := 0
if len(args) == 1 {
var isInt bool
size, isInt = args[0].(int)
if !isInt {
return nil, errors.New("size is not int")
return nil, &InterpretError{errors.New("size is not int")}
}
}
if size < 0 {
return nil, errors.Errorf("negative buffer size")
return nil, &InterpretError{errors.Errorf("negative buffer size")}
}
buffer := reflect.MakeChan(typ, size)
return buffer.Interface(), nil

default:
return nil, fmt.Errorf("unknown kind type %T", t)
return nil, &InterpretError{fmt.Errorf("unknown kind type %T", t)}
}
}

// Close is a runtime replacement for the "close" function.
func Close(t interface{}) (interface{}, error) {
func Close(t interface{}) (interface{}, *InterpretError) {
reflect.ValueOf(t).Close()
return nil, nil
}

// Len is a runtime replacement for the len function
func Len(t interface{}) (interface{}, error) {
func Len(t interface{}) (interface{}, *InterpretError) {
return reflect.ValueOf(t).Len(), nil
}
16 changes: 12 additions & 4 deletions pry/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,12 +569,12 @@ func (scope *Scope) Interpret(expr ast.Node) (interface{}, error) {
rhs = rhs[:0]

for i := 0; i < rhsLen; i++ {
rhs[i] = rhsV.Index(i).Interface()
rhs = append(rhs, rhsV.Index(i).Interface())
}
}

if len(rhs) != len(e.Lhs) {
return nil, fmt.Errorf("assignment count mismatch: %d = %d", len(e.Lhs), len(rhs))
return nil, fmt.Errorf("assignment count mismatch: %d = %d (%+v)", len(e.Lhs), len(rhs), rhs)
}

for i, id := range e.Lhs {
Expand Down Expand Up @@ -1059,13 +1059,21 @@ func (scope *Scope) ExecuteFunc(funExpr ast.Expr, args []interface{}) (interface
return nil, errors.Errorf("number of arguments doesn't match function; expected %d; got %+v", funVal.Type().NumIn(), args)
}
values := ValuesToInterfaces(funVal.Call(valueArgs))
if len(values) > 0 {
if last, ok := values[len(values)-1].(*InterpretError); ok {
values = values[:len(values)-1]
if err := last.Error(); err != nil {
return nil, err
}
}
}

if len(values) == 0 {
return nil, nil
} else if len(values) == 1 {
return values[0], nil
}
err, _ = values[1].(error)
return values[0], err
return values, nil
}

// ConfigureTypes configures the scope type checker
Expand Down
23 changes: 23 additions & 0 deletions pry/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,29 @@ func TestAppend(t *testing.T) {
}
}

func TestMultiReturn(t *testing.T) {
t.Parallel()

scope := NewScope()
scope.Set("f", func() (int, error) {
return 0, nil
})

_, err := scope.InterpretString(`a, err := f()`)
if err != nil {
t.Error(err)
}
expected := 0
outV, found := scope.Get("a")
if !found {
t.Errorf("failed to find \"a\"")
}
out := outV.(int)
if !reflect.DeepEqual(expected, out) {
t.Errorf("Expected %#v got %#v.", expected, out)
}
}

func TestDeclareAssignVar(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 178b1aa

Please sign in to comment.