-
Notifications
You must be signed in to change notification settings - Fork 0
/
ptywrapper.go
323 lines (266 loc) · 7.57 KB
/
ptywrapper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
package ptywrapper
import (
// Modules in GOROOT
"os"
"os/exec"
"os/signal"
"syscall"
"sync"
"time"
"context"
"bufio"
"bytes"
"strings"
"regexp"
"io"
// External modules
"github.com/creack/pty"
"golang.org/x/term"
"golang.org/x/sys/unix"
)
//
//// STRINGS
//
// ansiCSISequence matches ANSI/ECMA-48 CSI escape sequences: ESC '[' followed
// by parameter bytes (0x30-0x3F, e.g. digits, ';', '?'), intermediate bytes
// (0x20-0x2F) and a single final byte (0x40-0x7E). This covers colour codes
// ("\x1b[31m") as well as private-mode sequences such as "\x1b[?25l".
// It is compiled once here rather than on every cleanupString call.
var ansiCSISequence = regexp.MustCompile(`\x1b\[[0-?]*[ -/]*[@-~]`)

// cleanupString normalises raw terminal output so it is easier to store,
// parse, read and use as input for other programs. It:
//   - removes ANSI CSI escape sequences;
//   - removes every '\r' character;
//   - trims all leading and trailing '\n' characters.
//
// Carriage returns are removed before trimming, so output that starts or ends
// with "\r\n" line endings is trimmed completely.
func cleanupString(originalString string) string {
	cleaned := ansiCSISequence.ReplaceAllString(originalString, "")
	cleaned = strings.ReplaceAll(cleaned, "\r", "")
	return strings.Trim(cleaned, "\n")
}
//
//// CONTEXT
//
// contextWrapper keeps a cancellable context together with the function that
// cancels it, so both can be handed around as a single value.
type contextWrapper struct {
	Ctx    context.Context
	Cancel context.CancelFunc
}

// generateContextWrapper returns a fresh contextWrapper whose context derives
// from context.Background() and is therefore cancelled only via Cancel.
func generateContextWrapper() contextWrapper {
	var wrapper contextWrapper
	wrapper.Ctx, wrapper.Cancel = context.WithCancel(context.Background())
	return wrapper
}
//
//// I/O WRITERS
//
// Writer is an io.Writer that forwards everything to dest. It also carries
// the file its relay loop reads from (src) and the context whose cancellation
// tells that loop to stop (ctx).
type Writer struct {
	src  *os.File       // where the relay loop reads data from
	dest *os.File       // where Write sends data
	ctx  contextWrapper // cancelled to stop the relay loop
}

// Write sends p to w.dest and reports that write's result unchanged.
func (w *Writer) Write(p []byte) (n int, err error) {
	n, err = w.dest.Write(p)
	return n, err
}
//
//// COMMAND
//
type Command struct {
Entry string
Args []string
Env []string
Discard bool // Discard output (will still save it as a variable)
Completed bool
Output string
ExitCode int
}
func (command *Command) RunInPTY() (Command, error) {
// Create a command
c := exec.Command(command.Entry, command.Args...)
c.SysProcAttr = &syscall.SysProcAttr{
Setctty: true, // Set controlling terminal to the pseudo terminal
Setsid: true, // Start the command in a new session
}
// Set environment (use custom environment if available)
if command.Env != nil {
c.Env = command.Env
} else {
c.Env = os.Environ()
}
// Open a pty
// Terminology:
// - primary => ptm (master)
// - secondary => pts (slave)
primary, secondary, err := pty.Open()
if err != nil {
return *command, err
}
defer primary.Close()
defer secondary.Close()
// Set stdin, stdout and sterr for the command
c.Stdin = secondary
c.Stdout = secondary
c.Stderr = secondary
// Get the file descriptor for stdin
fd := int(os.Stdin.Fd())
// Make stdin raw and save the old state
oldState, err := term.MakeRaw(fd)
if err != nil {
return *command, err
}
defer func() { _ = term.Restore(fd, oldState) }() // Ensure the old state is restored when the function returns
// Enable non-blocking I/O on stdin
flag, err := unix.FcntlInt(uintptr(fd), unix.F_GETFL, 0)
if err != nil {
return *command, err
}
flag, err = unix.FcntlInt(uintptr(fd), unix.F_SETFL, flag|unix.O_NONBLOCK)
if err != nil {
return *command, err
}
// Resize the pty
ch := make(chan os.Signal, 1)
errCh := make(chan error, 1)
signal.Notify(ch, syscall.SIGWINCH)
go func() {
for range ch {
if err := pty.InheritSize(os.Stdin, primary); err != nil {
errCh <- err // Send the error to the error channel
return
}
}
}()
ch <- syscall.SIGWINCH // Initial resize
select {
case err := <-errCh:
return *command, err
default:
// No error, continue execution
}
func() { signal.Stop(ch); close(ch); close(errCh)}() // Cleanup signal and channels when done
// Start the command
err = c.Start()
if err != nil {
return *command, err
}
// Create a context to track if the command is still running
cmdExecutionContext := generateContextWrapper()
// Create a bytes buffer to capture the command's output
var cmdOutput bytes.Buffer
// Start goroutine to copy data from os.Stdin to ptm (via a custom writer)
var ptyWriterWaitGroup sync.WaitGroup
ptyWriterWaitGroup.Add(1)
stdinWriter := &Writer{
src: os.Stdin,
dest: primary,
ctx: cmdExecutionContext,
}
go func() {
defer ptyWriterWaitGroup.Done()
// Create a reader to get data from os.Stdin
reader := bufio.NewReader(stdinWriter.src)
// Create a bytes buffer
buf := make([]byte, 4096)
// Start loop
Loop:
for {
select {
case <-stdinWriter.ctx.Ctx.Done():
break Loop
default:
// Read buffer
n, err := reader.Read(buf)
// Verify if there is no data available to read
// NOTE: EAGAIN means that there is no data available to read, so it's not really an error in this case
if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == syscall.EAGAIN {
// Sleep for 50 milliseconds (to slow down the loop but without removing too much fluidity from user input)
time.Sleep(time.Millisecond * 10)
continue
} else {
// Write data
_, err = stdinWriter.Write(buf[:n])
}
}
}
return
}()
// Start goroutine to copy data from ptm to os.Stdout
var stdoutWriterWaitGroup sync.WaitGroup
stdoutWriterWaitGroup.Add(1)
stdoutWriter := &Writer{
src: primary,
dest: os.Stdout,
ctx: cmdExecutionContext,
}
go func() {
defer stdoutWriterWaitGroup.Done()
// Create a bytes buffer
buf := make([]byte, 4096)
// Start loop
Loop:
for {
select {
case <-stdoutWriter.ctx.Ctx.Done():
break Loop
default:
// Read buffer
n, err := stdoutWriter.src.Read(buf)
// Verify if there is no data available to read
// NOTE: EAGAIN means that there is no data available to read, so it's not really an error in this case
if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == syscall.EAGAIN {
// Sleep for 50 milliseconds (to slow down the loop but without removing too much fluidity from the output)
time.Sleep(time.Millisecond * 10)
continue
} else {
// Write data
if command.Discard == false {
_, err = stdoutWriter.Write(buf[:n])
}
// Copy bytes to output bytes buffer
_, err = io.Copy(&cmdOutput, bytes.NewReader(buf[:n]))
}
}
}
return
}()
// Wait for the command to exit
cmdExitCh := make(chan error, 1)
var cmdExecutionWaitGroup sync.WaitGroup
cmdExecutionWaitGroup.Add(1)
go func() {
defer cmdExecutionWaitGroup.Done()
cmdExitCh <- c.Wait()
// Cancel context
cmdExecutionContext.Cancel()
// Close pty
primary.Close()
secondary.Close()
// Wait for output writer to return
stdoutWriterWaitGroup.Wait()
// Wait for pty writer to return
ptyWriterWaitGroup.Wait()
return
}()
cmdExecutionWaitGroup.Wait()
// Get command exit code and save it
cmdExit := <-cmdExitCh
close(cmdExitCh)
if exitError, ok := cmdExit.(*exec.ExitError); ok {
// The command exited with a non-zero status (an error)
command.ExitCode = exitError.ExitCode()
} else if cmdExit != nil {
// Some other error occurred
return *command, cmdExit
} else {
// The command exited successfully
command.ExitCode = 0
}
// Convert command output from bytes to string
cmdOutputString := cmdOutput.String()
// Clean up command output
cmdOutputStringCleaned := cleanupString(cmdOutputString)
// Save cleaned up command output
command.Output = cmdOutputStringCleaned
// Mark command as completed and return
command.Completed = true
return *command, nil
}