Skip to content

Commit

Permalink
Merge pull request #1025 from kevpar/close-stdio
Browse files Browse the repository at this point in the history
internal/cmd: Close individual IO pipes when the relay finishes
  • Loading branch information
kevpar committed May 13, 2021
2 parents 710d704 + ce4f347 commit 0f5799e
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 22 deletions.
23 changes: 15 additions & 8 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ func Command(host cow.ProcessHost, name string, arg ...string) *Cmd {
Spec: &specs.Process{
Args: append([]string{name}, arg...),
},
Log: logrus.NewEntry(logrus.StandardLogger()),
ExitState: &ExitState{},
}
if host.OS() == "windows" {
Expand All @@ -120,7 +121,8 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg
return cmd
}

func copyAndLog(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, error) {
// relayIO is a glorified io.Copy that also logs when the copy has completed.
func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, error) {
n, err := io.Copy(w, r)
if log != nil {
lvl := logrus.DebugLevel
Expand All @@ -132,7 +134,7 @@ func copyAndLog(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64
lvl = logrus.ErrorLevel
log = log.WithError(err)
}
log.Log(lvl, "command copy complete")
log.Log(lvl, "Cmd IO relay complete")
}
return n, err
}
Expand Down Expand Up @@ -207,31 +209,36 @@ func (c *Cmd) Start() error {
// us or the caller to reliably unblock the c.Stdin read when the
// process exits.
go func() {
_, err := copyAndLog(stdin, c.Stdin, c.Log, "stdin")
_, err := relayIO(stdin, c.Stdin, c.Log, "stdin")
// Report the stdin copy error. If the process has exited, then the
// caller may never see it, but if the error was due to a failure in
// stdin read, then it is likely the process is still running.
if err != nil {
c.stdinErr.Store(err)
}
// Notify the process that there is no more input.
err = p.CloseStdin(context.TODO())
if err != nil && c.Log != nil {
c.Log.WithError(err).Warn("failed to close pod stdin")
if err := p.CloseStdin(context.TODO()); err != nil && c.Log != nil {
c.Log.WithError(err).Warn("failed to close Cmd stdin")
}
}()
}

if c.Stdout != nil {
c.iogrp.Go(func() error {
_, err := copyAndLog(c.Stdout, stdout, c.Log, "stdout")
_, err := relayIO(c.Stdout, stdout, c.Log, "stdout")
if err := p.CloseStdout(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stdout")
}
return err
})
}

if c.Stderr != nil {
c.iogrp.Go(func() error {
_, err := copyAndLog(c.Stderr, stderr, c.Log, "stderr")
_, err := relayIO(c.Stderr, stderr, c.Log, "stderr")
if err := p.CloseStderr(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stderr")
}
return err
})
}
Expand Down
8 changes: 8 additions & 0 deletions internal/cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ func (p *localProcess) CloseStdin(ctx context.Context) error {
return p.stdin.Close()
}

func (p *localProcess) CloseStdout(ctx context.Context) error {
return p.stdout.Close()
}

func (p *localProcess) CloseStderr(ctx context.Context) error {
return p.stderr.Close()
}

func (p *localProcess) ExitCode() (int, error) {
select {
case <-p.ch:
Expand Down
6 changes: 6 additions & 0 deletions internal/cow/cow.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ type Process interface {
// CloseStdin causes the process's stdin handle to receive EOF/EPIPE/whatever
// is appropriate to indicate that no more data is available.
CloseStdin(ctx context.Context) error
// CloseStdout closes the stdout connection to the process. It is used to indicate
// that we are done receiving output on the shim side.
CloseStdout(ctx context.Context) error
// CloseStderr closes the stderr connection to the process. It is used to indicate
// that we are done receiving output on the shim side.
CloseStderr(ctx context.Context) error
// Pid returns the process ID.
Pid() int
// Stdio returns the stdio streams for a process. These may be nil if a stream
Expand Down
9 changes: 1 addition & 8 deletions internal/gcs/iochannel.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gcs

import (
"io"
"net"
)

Expand Down Expand Up @@ -56,13 +55,7 @@ func (c *ioChannel) Read(b []byte) (int, error) {
if c.c == nil {
return 0, c.err
}
n, err := c.c.Read(b)
if err == io.EOF {
// Close the underlying connection so that the VM
// knows that all data has been read.
c.c.Close()
}
return n, err
return c.c.Read(b)
}

func (c *ioChannel) Write(b []byte) (int, error) {
Expand Down
31 changes: 25 additions & 6 deletions internal/gcs/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,13 @@ func (p *Process) Close() error {
trace.StringAttribute("cid", p.cid),
trace.Int64Attribute("pid", int64(p.id)))

err := p.stdin.Close()
if err != nil {
if err := p.stdin.Close(); err != nil {
log.G(ctx).WithError(err).Warn("close stdin failed")
}
err = p.stdout.Close()
if err != nil {
if err := p.stdout.Close(); err != nil {
log.G(ctx).WithError(err).Warn("close stdout failed")
}
err = p.stderr.Close()
if err != nil {
if err := p.stderr.Close(); err != nil {
log.G(ctx).WithError(err).Warn("close stderr failed")
}
return nil
Expand All @@ -158,6 +155,28 @@ func (p *Process) CloseStdin(ctx context.Context) (err error) {
return p.stdinCloseWriteErr
}

func (p *Process) CloseStdout(ctx context.Context) (err error) {
ctx, span := trace.StartSpan(ctx, "gcs::Process::CloseStdout") //nolint:ineffassign,staticcheck
defer span.End()
defer func() { oc.SetSpanStatus(span, err) }()
span.AddAttributes(
trace.StringAttribute("cid", p.cid),
trace.Int64Attribute("pid", int64(p.id)))

return p.stdout.Close()
}

func (p *Process) CloseStderr(ctx context.Context) (err error) {
ctx, span := trace.StartSpan(ctx, "gcs::Process::CloseStderr") //nolint:ineffassign,staticcheck
defer span.End()
defer func() { oc.SetSpanStatus(span, err) }()
span.AddAttributes(
trace.StringAttribute("cid", p.cid),
trace.Int64Attribute("pid", int64(p.id)))

return p.stderr.Close()
}

// ExitCode returns the process's exit code, or an error if the process is still
// running or the exit code is otherwise unknown.
func (p *Process) ExitCode() (_ int, err error) {
Expand Down
49 changes: 49 additions & 0 deletions internal/hcs/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,55 @@ func (process *Process) CloseStdin(ctx context.Context) error {
return nil
}

func (process *Process) CloseStdout(ctx context.Context) (err error) {
ctx, span := trace.StartSpan(ctx, "hcs::Process::CloseStdout") //nolint:ineffassign,staticcheck
defer span.End()
defer func() { oc.SetSpanStatus(span, err) }()
span.AddAttributes(
trace.StringAttribute("cid", process.SystemID()),
trace.Int64Attribute("pid", int64(process.processID)))

process.handleLock.Lock()
defer process.handleLock.Unlock()

if process.handle == 0 {
return nil
}

process.stdioLock.Lock()
defer process.stdioLock.Unlock()
if process.stdout != nil {
process.stdout.Close()
process.stdout = nil
}
return nil
}

func (process *Process) CloseStderr(ctx context.Context) (err error) {
ctx, span := trace.StartSpan(ctx, "hcs::Process::CloseStderr") //nolint:ineffassign,staticcheck
defer span.End()
defer func() { oc.SetSpanStatus(span, err) }()
span.AddAttributes(
trace.StringAttribute("cid", process.SystemID()),
trace.Int64Attribute("pid", int64(process.processID)))

process.handleLock.Lock()
defer process.handleLock.Unlock()

if process.handle == 0 {
return nil
}

process.stdioLock.Lock()
defer process.stdioLock.Unlock()
if process.stderr != nil {
process.stderr.Close()
process.stderr = nil

}
return nil
}

// Close cleans up any state associated with the process but does not kill
// or wait on it.
func (process *Process) Close() (err error) {
Expand Down
14 changes: 14 additions & 0 deletions internal/jobcontainers/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ func (p *JobProcess) CloseStdin(ctx context.Context) error {
return p.stdin.Close()
}

// CloseStdout closes the stdout pipe of the process.
func (p *JobProcess) CloseStdout(ctx context.Context) error {
p.stdioLock.Lock()
defer p.stdioLock.Unlock()
return p.stdout.Close()
}

// CloseStderr closes the stderr pipe of the process.
func (p *JobProcess) CloseStderr(ctx context.Context) error {
p.stdioLock.Lock()
defer p.stdioLock.Unlock()
return p.stderr.Close()
}

// Wait waits for the process to exit. If the process has already exited returns
// the previous error (if any).
func (p *JobProcess) Wait() error {
Expand Down

0 comments on commit 0f5799e

Please sign in to comment.