diff --git a/output.go b/output.go index e22d369..c52db33 100644 --- a/output.go +++ b/output.go @@ -178,9 +178,12 @@ func (o *Output) HasDarkBackground() bool { return l < 0.5 } -// TTY returns the terminal's file descriptor. This may be nil if the output is +// TTY returns the underlying terminal output. This may be nil if the output is // not a terminal. -func (o Output) TTY() File { +func (o Output) TTY() io.Writer { + if o.assumeTTY || o.unsafe { + return o.tty + } if f, ok := o.tty.(File); ok { return f } diff --git a/termenv.go b/termenv.go index 4ceb271..930fd4e 100644 --- a/termenv.go +++ b/termenv.go @@ -31,11 +31,11 @@ func (o *Output) isTTY() bool { if len(o.environ.Getenv("CI")) > 0 { return false } - if o.TTY() == nil { - return false + if f, ok := o.TTY().(File); ok { + return isatty.IsTerminal(f.Fd()) } - return isatty.IsTerminal(o.TTY().Fd()) + return false } // ColorProfile returns the supported color profile: diff --git a/termenv_unix.go b/termenv_unix.go index 11746d2..e965cea 100644 --- a/termenv_unix.go +++ b/termenv_unix.go @@ -109,8 +109,8 @@ func (o Output) backgroundColor() Color { return ANSIColor(0) } -func (o *Output) waitForData(timeout time.Duration) error { - fd := o.TTY().Fd() +func (o *Output) waitForData(tty File, timeout time.Duration) error { + fd := tty.Fd() tv := unix.NsecToTimeval(int64(timeout)) var readfds unix.FdSet readfds.Set(int(fd)) @@ -133,15 +133,15 @@ func (o *Output) waitForData(timeout time.Duration) error { return nil } -func (o *Output) readNextByte() (byte, error) { - if !o.unsafe { - if err := o.waitForData(OSCTimeout); err != nil { +func (o *Output) readNextByte(tty io.ReadWriter) (byte, error) { + if f, ok := tty.(File); ok && !o.unsafe { + if err := o.waitForData(f, OSCTimeout); err != nil { return 0, err } } var b [1]byte - n, err := o.TTY().Read(b[:]) + n, err := tty.Read(b[:]) if err != nil { return 0, err } @@ -156,15 +156,15 @@ func (o *Output) readNextByte() (byte, error) { // readNextResponse reads either an OSC response or a cursor position response: // - OSC response: "\x1b]11;rgb:1111/1111/1111\x1b\\" // - cursor position response: "\x1b[42;1R" -func (o *Output) readNextResponse() (response string, isOSC bool, err error) { - start, err := o.readNextByte() +func (o *Output) readNextResponse(tty io.ReadWriter) (response string, isOSC bool, err error) { + start, err := o.readNextByte(tty) if err != nil { return "", false, err } // first byte must be ESC for start != ESC { - start, err = o.readNextByte() + start, err = o.readNextByte(tty) if err != nil { return "", false, err } @@ -173,7 +173,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) { response += string(start) // next byte is either '[' (cursor position response) or ']' (OSC response) - tpe, err := o.readNextByte() + tpe, err := o.readNextByte(tty) if err != nil { return "", false, err } @@ -191,7 +191,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) { } for { - b, err := o.readNextByte() + b, err := o.readNextByte(tty) if err != nil { return "", false, err } @@ -227,13 +227,17 @@ func (o Output) termStatusReport(sequence int) (string, error) { return "", ErrStatusReport } - tty := o.TTY() - if tty == nil { + tty, ok := o.TTY().(io.ReadWriter) + if tty == nil || !ok { return "", ErrStatusReport } if !o.unsafe { - fd := int(tty.Fd()) + f, ok := tty.(File) + if !ok { + return "", ErrStatusReport + } + fd := int(f.Fd()) // if in background, we can't control the terminal if !isForeground(fd) { return "", ErrStatusReport @@ -260,7 +264,7 @@ func (o Output) termStatusReport(sequence int) (string, error) { fmt.Fprintf(tty, CSI+"6n") // read the next response - res, isOSC, err := o.readNextResponse() + res, isOSC, err := o.readNextResponse(tty) if err != nil { return "", fmt.Errorf("%s: %s", ErrStatusReport, err) } @@ -271,7 +275,7 @@ func (o Output) termStatusReport(sequence int) (string, error) { } // read the cursor query response next and discard the result - _, _, err = o.readNextResponse() + _, _, err = o.readNextResponse(tty) if err != nil { return "", err } diff --git a/termenv_windows.go b/termenv_windows.go index 1d9c618..ee755d4 100644 --- a/termenv_windows.go +++ b/termenv_windows.go @@ -103,8 +103,8 @@ func EnableVirtualTerminalProcessing(o *Output) (restoreFunc func() error, err e } // If o is not a tty, then there is nothing to do. - tty := o.TTY() - if tty == nil { + tty, ok := o.TTY().(File) + if tty == nil || !ok { return }