Skip to content

Commit

Permalink
Merge pull request #30 from d4l3k/close-tty
Browse files Browse the repository at this point in the history
if, defers, correct func scoping, struct{} support
  • Loading branch information
d4l3k committed Jul 17, 2017
2 parents 408aff5 + c343750 commit 911b948
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 55 deletions.
133 changes: 108 additions & 25 deletions pry/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,28 @@ type Scope struct {

isSelect bool
typeAssert reflect.Type
isFunction bool
defers []*Defer

sync.Mutex
}

type Defer struct {
fun ast.Expr
scope *Scope
arguments []interface{}
}

func (scope *Scope) Defer(d *Defer) error {
for ; scope != nil; scope = scope.Parent {
if scope.isFunction {
scope.defers = append(scope.defers, d)
return nil
}
}
return errors.New("defer: can't find function scope")
}

// NewScope creates a new initialized scope
func NewScope() *Scope {
s := &Scope{
Expand Down Expand Up @@ -211,38 +229,16 @@ func (scope *Scope) Interpret(expr ast.Node) (interface{}, error) {
return nil, fmt.Errorf("unknown field %#v", sel.Name)

case *ast.CallExpr:
fun, err := scope.Interpret(e.Fun)
if err != nil {
return nil, err
}

args := make([]reflect.Value, len(e.Args))
args := make([]interface{}, len(e.Args))
for i, arg := range e.Args {
interpretedArg, err := scope.Interpret(arg)
if err != nil {
return nil, err
}
args[i] = reflect.ValueOf(interpretedArg)
args[i] = interpretedArg
}

switch funV := fun.(type) {
case reflect.Type:
return args[0].Convert(funV).Interface(), nil
case *Func:
// TODO enforce func return values
return scope.Interpret(funV.Def.Body)
}

funVal := reflect.ValueOf(fun)

values := ValuesToInterfaces(funVal.Call(args))
if len(values) == 0 {
return nil, nil
} else if len(values) == 1 {
return values[0], nil
}
err, _ = values[1].(error)
return values[0], err
return scope.ExecuteFunc(e.Fun, args)

case *ast.GoStmt:
go func() {
Expand Down Expand Up @@ -816,11 +812,98 @@ func (scope *Scope) Interpret(expr ast.Node) (interface{}, error) {
}
return out, nil

case *ast.IfStmt:
currentScope := scope.NewChild()
if e.Init != nil {
if _, err := currentScope.Interpret(e.Init); err != nil {
return nil, err
}
}
cond, err := currentScope.Interpret(e.Cond)
if err != nil {
return nil, err
}
if cond == true {
return currentScope.Interpret(e.Body)
}
return currentScope.Interpret(e.Else)

case *ast.DeferStmt:
var args []interface{}
for _, arg := range e.Call.Args {
v, err := scope.Interpret(arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
scope.Defer(&Defer{
fun: e.Call.Fun,
scope: scope,
arguments: args,
})
return nil, nil

case *ast.StructType:
if len(e.Fields.List) > 0 {
return nil, errors.New("don't support non-empty structs yet")
}
return reflect.TypeOf(struct{}{}), nil

default:
return nil, fmt.Errorf("unknown node %#v", e)
}
}

func (scope *Scope) ExecuteFunc(funExpr ast.Expr, args []interface{}) (interface{}, error) {
fun, err := scope.Interpret(funExpr)
if err != nil {
return nil, err
}

switch funV := fun.(type) {
case reflect.Type:
return reflect.ValueOf(args[0]).Convert(funV).Interface(), nil
case *Func:
// TODO enforce func return values
currentScope := scope.NewChild()
i := 0
for _, arg := range funV.Def.Type.Params.List {
for _, name := range arg.Names {
currentScope.Set(name.Name, args[i])
i++
}
}
currentScope.isFunction = true
ret, err := currentScope.Interpret(funV.Def.Body)
if err != nil {
return nil, err
}
for i := len(currentScope.defers) - 1; i >= 0; i-- {
d := currentScope.defers[i]
if _, err := d.scope.ExecuteFunc(d.fun, d.arguments); err != nil {
return nil, err
}
}
return ret, nil
}

funVal := reflect.ValueOf(fun)

var valueArgs []reflect.Value
for _, v := range args {
valueArgs = append(valueArgs, reflect.ValueOf(v))
}
values := ValuesToInterfaces(funVal.Call(valueArgs))
if len(values) == 0 {
return nil, nil
} else if len(values) == 1 {
return values[0], nil
}
err, _ = values[1].(error)
return values[0], err
}

// ConfigureTypes configures the scope type checker
func (scope *Scope) ConfigureTypes(path string, line int) error {
scope.path = path
Expand Down
117 changes: 117 additions & 0 deletions pry/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,123 @@ func TestSwitchNone(t *testing.T) {
}
}

func TestIf(t *testing.T) {
t.Parallel()

scope := NewScope()

out, err := scope.InterpretString(`
a := 0
if true {
a = 1
} else {
a = 2
}
a
`)
if err != nil {
t.Error(err)
}
expected := 1
if !reflect.DeepEqual(expected, out) {
t.Errorf("Expected %#v got %#v.", expected, out)
}
}

func TestIfElse(t *testing.T) {
t.Parallel()

scope := NewScope()

out, err := scope.InterpretString(`
a := 0
if false {
a = 1
} else {
a = 2
}
a
`)
if err != nil {
t.Error(err)
}
expected := 2
if !reflect.DeepEqual(expected, out) {
t.Errorf("Expected %#v got %#v.", expected, out)
}
}

func TestIfIfElse(t *testing.T) {
t.Parallel()

scope := NewScope()

out, err := scope.InterpretString(`
a := 0
if false {
a = 1
} else if true {
a = 2
}
a
`)
if err != nil {
t.Error(err)
}
expected := 2
if !reflect.DeepEqual(expected, out) {
t.Errorf("Expected %#v got %#v.", expected, out)
}
}

func TestFunctionArgs(t *testing.T) {
t.Parallel()

scope := NewScope()

out, err := scope.InterpretString(`
f := func(b, c int) {
return b + c
}
f(10, 5)
`)
if err != nil {
t.Error(err)
}
expected := 15
if !reflect.DeepEqual(expected, out) {
t.Errorf("Expected %#v got %#v.", expected, out)
}
}

func TestDefer(t *testing.T) {
t.Parallel()

scope := NewScope()

out, err := scope.InterpretString(`
a := 0
f := func() {
defer func() {
a = 2
}()
defer func() {
a = 3
}()
a = 1
}
f()
a
`)
if err != nil {
t.Error(err)
}
expected := 2
if !reflect.DeepEqual(expected, out) {
t.Errorf("Expected %#v got %#v.", expected, out)
}
}

// TODO Packages

// TODO References
64 changes: 34 additions & 30 deletions pry/pry.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,7 @@ func Apply(scope *Scope) {
panic(err)
}

fmt.Fprintf(out, "\nFrom %s @ line %d :\n\n", filePathRaw, lineNum)
file, err := ioutil.ReadFile(filePath)
if err != nil {
fmt.Fprintln(out, err)
}
lines := strings.Split((string)(file), "\n")
lineNum--
start := lineNum - 5
if start < 0 {
start = 0
}
end := lineNum + 6
if end > len(lines) {
end = len(lines)
}
maxLen := len(fmt.Sprint(end))
for i := start; i < end; i++ {
caret := " "
if i == lineNum {
caret = "=>"
}
numStr := fmt.Sprint(i + 1)
if len(numStr) < maxLen {
numStr = " " + numStr
}
num := ansi.Color(numStr, "blue+b")
highlightedLine := Highlight(strings.Replace(lines[i], "\t", " ", -1))
fmt.Fprintf(out, " %s %s: %s\n", caret, num, highlightedLine)
}
fmt.Fprintln(out)
displayFilePosition(filePathRaw, filePath, lineNum)

history := []string{}
currentPos := 0
Expand Down Expand Up @@ -188,6 +159,39 @@ func Apply(scope *Scope) {
}
}

func displayFilePosition(filePathRaw, filePath string, lineNum int) {
fmt.Fprintf(out, "\nFrom %s @ line %d :\n\n", filePathRaw, lineNum)
file, err := ioutil.ReadFile(filePath)
if err != nil {
fmt.Fprintln(out, err)
}
lines := strings.Split((string)(file), "\n")
lineNum--
start := lineNum - 5
if start < 0 {
start = 0
}
end := lineNum + 6
if end > len(lines) {
end = len(lines)
}
maxLen := len(fmt.Sprint(end))
for i := start; i < end; i++ {
caret := " "
if i == lineNum {
caret = "=>"
}
numStr := fmt.Sprint(i + 1)
if len(numStr) < maxLen {
numStr = " " + numStr
}
num := ansi.Color(numStr, "blue+b")
highlightedLine := Highlight(strings.Replace(lines[i], "\t", " ", -1))
fmt.Fprintf(out, " %s %s: %s\n", caret, num, highlightedLine)
}
fmt.Fprintln(out)
}

// displaySuggestions renders the live autocomplete from GoCode.
func displaySuggestions(scope *Scope, line string, index, promptWidth int) {
// Suggestions
Expand Down

0 comments on commit 911b948

Please sign in to comment.