Skip to content

Commit

Permalink
fix(output): status report ignored when assumeTTY or unsafe is true
Browse files Browse the repository at this point in the history
Assuming that tty is of type File makes termStatusReport fail.
  • Loading branch information
aymanbagabas committed Mar 9, 2023
1 parent ada8a8c commit 446900e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 23 deletions.
7 changes: 5 additions & 2 deletions output.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,12 @@ func (o *Output) HasDarkBackground() bool {
return l < 0.5
}

// TTY returns the terminal's file descriptor. This may be nil if the output is
// TTY returns the underlying terminal output. This may be nil if the output is
// not a terminal.
func (o Output) TTY() File {
func (o Output) TTY() io.Writer {
if o.assumeTTY || o.unsafe {
return o.tty
}
if f, ok := o.tty.(File); ok {
return f
}
Expand Down
6 changes: 3 additions & 3 deletions termenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ func (o *Output) isTTY() bool {
if len(o.environ.Getenv("CI")) > 0 {
return false
}
if o.TTY() == nil {
return false
if f, ok := o.TTY().(File); ok {
return isatty.IsTerminal(f.Fd())
}

return isatty.IsTerminal(o.TTY().Fd())
return false
}

// ColorProfile returns the supported color profile:
Expand Down
36 changes: 20 additions & 16 deletions termenv_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ func (o Output) backgroundColor() Color {
return ANSIColor(0)
}

func (o *Output) waitForData(timeout time.Duration) error {
fd := o.TTY().Fd()
func (o *Output) waitForData(tty File, timeout time.Duration) error {
fd := tty.Fd()
tv := unix.NsecToTimeval(int64(timeout))
var readfds unix.FdSet
readfds.Set(int(fd))
Expand All @@ -133,15 +133,15 @@ func (o *Output) waitForData(timeout time.Duration) error {
return nil
}

func (o *Output) readNextByte() (byte, error) {
if !o.unsafe {
if err := o.waitForData(OSCTimeout); err != nil {
func (o *Output) readNextByte(tty io.ReadWriter) (byte, error) {
if f, ok := tty.(File); ok && !o.unsafe {
if err := o.waitForData(f, OSCTimeout); err != nil {
return 0, err
}
}

var b [1]byte
n, err := o.TTY().Read(b[:])
n, err := tty.Read(b[:])
if err != nil {
return 0, err
}
Expand All @@ -156,15 +156,15 @@ func (o *Output) readNextByte() (byte, error) {
// readNextResponse reads either an OSC response or a cursor position response:
// - OSC response: "\x1b]11;rgb:1111/1111/1111\x1b\\"
// - cursor position response: "\x1b[42;1R"
func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
start, err := o.readNextByte()
func (o *Output) readNextResponse(tty io.ReadWriter) (response string, isOSC bool, err error) {
start, err := o.readNextByte(tty)
if err != nil {
return "", false, err
}

// first byte must be ESC
for start != ESC {
start, err = o.readNextByte()
start, err = o.readNextByte(tty)
if err != nil {
return "", false, err
}
Expand All @@ -173,7 +173,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
response += string(start)

// next byte is either '[' (cursor position response) or ']' (OSC response)
tpe, err := o.readNextByte()
tpe, err := o.readNextByte(tty)
if err != nil {
return "", false, err
}
Expand All @@ -191,7 +191,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
}

for {
b, err := o.readNextByte()
b, err := o.readNextByte(tty)
if err != nil {
return "", false, err
}
Expand Down Expand Up @@ -227,13 +227,17 @@ func (o Output) termStatusReport(sequence int) (string, error) {
return "", ErrStatusReport
}

tty := o.TTY()
if tty == nil {
tty, ok := o.TTY().(io.ReadWriter)
if tty == nil || !ok {
return "", ErrStatusReport
}

if !o.unsafe {
fd := int(tty.Fd())
f, ok := tty.(File)
if !ok {
return "", ErrStatusReport
}
fd := int(f.Fd())
// if in background, we can't control the terminal
if !isForeground(fd) {
return "", ErrStatusReport
Expand All @@ -260,7 +264,7 @@ func (o Output) termStatusReport(sequence int) (string, error) {
fmt.Fprintf(tty, CSI+"6n")

// read the next response
res, isOSC, err := o.readNextResponse()
res, isOSC, err := o.readNextResponse(tty)
if err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
}
Expand All @@ -271,7 +275,7 @@ func (o Output) termStatusReport(sequence int) (string, error) {
}

// read the cursor query response next and discard the result
_, _, err = o.readNextResponse()
_, _, err = o.readNextResponse(tty)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions termenv_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func EnableVirtualTerminalProcessing(o *Output) (restoreFunc func() error, err e
}

// If o is not a tty, then there is nothing to do.
tty := o.TTY()
if tty == nil {
tty, ok := o.TTY().(File)
if tty == nil || !ok {
return
}

Expand Down

0 comments on commit 446900e

Please sign in to comment.