diff --git a/interp/interp.go b/interp/interp.go index a4afd1fd..1355bfb0 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -9,7 +9,6 @@ package interp import ( "bufio" - "context" "errors" "fmt" "io" @@ -79,7 +78,10 @@ type interp struct { noExec bool noFileWrites bool noFileReads bool - timeout time.Duration + // limited the amount of runtime + runtimeLimit time.Duration + // Time of whn + startTime time.Time // Scalars, arrays, and function state globals []value @@ -196,7 +198,7 @@ type Config struct { NoFileWrites bool NoFileReads bool - Timeout time.Duration + RuntimeLimit time.Duration } // ExecProgram executes the parsed program using the given interpreter @@ -234,7 +236,10 @@ func ExecProgram(program *Program, config *Config) (int, error) { p.noExec = config.NoExec p.noFileWrites = config.NoFileWrites p.noFileReads = config.NoFileReads - p.timeout = config.Timeout + p.runtimeLimit = config.RuntimeLimit + if p.runtimeLimit > 0 { + p.startTime = time.Now() + } err := p.initNativeFuncs(config.Funcs) if err != nil { return 0, err @@ -277,26 +282,6 @@ func ExecProgram(program *Program, config *Config) (int, error) { p.scanners = make(map[string]*bufio.Scanner) defer p.closeAll() - // If timeout is set then run with timeout - if p.timeout.Nanoseconds() != 0 { - ctx, cancel := context.WithTimeout(context.Background(), p.timeout) - defer cancel() - errChan := make(chan error) - go func() { - var execError error - p.exitStatus, execError = execProgram(p, program) - errChan <- execError - }() - select { - case err := <-errChan: - if err != nil { - return p.exitStatus, err - } - case <-ctx.Done(): - return p.exitStatus, fmt.Errorf("Runtime exceeded timeout %v", p.timeout) - } - } - return execProgram(p, program) } func execProgram(p *interp, program *Program) (int, error) { @@ -318,7 +303,7 @@ func execProgram(p *interp, program *Program) (int, error) { if err != nil && err != errExit { return 0, err } - return 0, nil + return p.exitStatus, nil } // Exec provides a simple way to parse and execute an AWK program @@ -439,6 +424,10 @@ func (p *interp) executes(stmts Stmts) error { // Execute a single statement func (p *interp) execute(stmt Stmt) error { + // Check runtime limit + if runtimeError := p.exceedsRuntimeLimit(); runtimeError != nil { + return runtimeError + } switch s := stmt.(type) { case *ExprStmt: @@ -665,8 +654,23 @@ func (p *interp) execute(stmt Stmt) error { return nil } +// Check runtime limitations +func (p *interp) exceedsRuntimeLimit() error { + if p.runtimeLimit > 0 { + if p.runtimeLimit < time.Since(p.startTime) { + return fmt.Errorf("Runtime exceeded limit of %v", p.runtimeLimit) + } + } + return nil +} + // Evaluate a single expression, return expression value and error func (p *interp) eval(expr Expr) (value, error) { + // Check runtime limit + if runtimeError := p.exceedsRuntimeLimit(); runtimeError != nil { + return num(1), runtimeError + } + switch e := expr.(type) { case *NumExpr: // Number literal diff --git a/interp/interp_test.go b/interp/interp_test.go index bab3a8e2..28dd037c 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -1005,7 +1005,7 @@ func TestSafeMode(t *testing.T) { }) } } -func TestTimeout(t *testing.T) { +func TestRuntimeLimit(t *testing.T) { tests := []struct { src string in string @@ -1013,21 +1013,34 @@ func TestTimeout(t *testing.T) { err string args []string }{ - {`BEGIN { print "hi" }`, "", "hi\nhi\n", "", nil}, - {`BEGIN { while(i<1){} }`, "", "", "Runtime exceeded timeout 5ms", nil}, + {`BEGIN { print "hi" }`, "", "hi\n", "", nil}, + {`BEGIN { while(i<1){} }`, "", "", "Runtime exceeded limit of 5ms", nil}, {`BEGIN { while(i<1){i++} }`, "", "", "", nil}, + {`BEGIN { while(1){} }`, "", "", "Runtime exceeded limit of 5ms", nil}, } for _, test := range tests { testName := test.src if len(testName) > 70 { testName = testName[:70] } - t.Run(testName, func(t *testing.T) { - testGoAWK(t, test.src, test.in, test.out, test.err, nil, func(config *interp.Config) { - config.Args = test.args - config.Timeout = 5 * time.Millisecond + timeout := 5 * time.Millisecond + testRun := make(chan bool, 0) + go func() { + t.Run(testName, func(t *testing.T) { + testGoAWK(t, test.src, test.in, test.out, test.err, nil, func(config *interp.Config) { + config.Args = test.args + config.RuntimeLimit = timeout + }) }) - }) + testRun <- true + }() + + select { + case <-testRun: + break + case <-time.After(timeout): + t.Errorf("Failed to stop runtime in %v: %s", timeout, test.src) + } } }