Skip to content

Commit

Permalink
Code review feedback
Browse files Browse the repository at this point in the history
Also:

1. Avoid leaking handles returned from WaitForDebugEvent

2. Avoid accidentally swallowing first-chance exceptions.
  • Loading branch information
lukehoban committed Jan 21, 2016
1 parent 59257e5 commit ce025fe
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 67 deletions.
114 changes: 73 additions & 41 deletions proc/proc_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,34 @@ func Launch(cmd []string) (*Process, error) {
return nil, err
}

argv0, _ := syscall.UTF16PtrFromString(argv0Go)
// Duplicate the stdin/stdout/stderr handles
files := []uintptr{uintptr(syscall.Stdin), uintptr(syscall.Stdout), uintptr(syscall.Stderr)}
p, _ := syscall.GetCurrentProcess()
fd := make([]syscall.Handle, len(files))
for i := range files {
err := syscall.DuplicateHandle(p, syscall.Handle(files[i]), p, &fd[i], 0, true, syscall.DUPLICATE_SAME_ACCESS)
if err != nil {
return nil, err
}
defer syscall.CloseHandle(syscall.Handle(fd[i]))
}
// Initialize the startup info and create process
si := new(sys.StartupInfo)
si.Cb = uint32(unsafe.Sizeof(*si))
si.Flags = syscall.STARTF_USESTDHANDLES
si.StdInput = sys.Handle(fd[0])
si.StdOutput = sys.Handle(fd[1])
si.StdErr = sys.Handle(fd[2])
argv0, _ := syscall.UTF16PtrFromString(argv0Go)

// Duplicate the stdin/stdout/stderr handles
files := []uintptr{uintptr(syscall.Stdin), uintptr(syscall.Stdout), uintptr(syscall.Stderr)}
p, _ := syscall.GetCurrentProcess()
fd := make([]syscall.Handle, len(files))
for i := range files {
err := syscall.DuplicateHandle(p, syscall.Handle(files[i]), p, &fd[i], 0, true, syscall.DUPLICATE_SAME_ACCESS)
if err != nil {
return nil, err
}
defer syscall.CloseHandle(syscall.Handle(fd[i]))
}

// Initialize the startup info and create process
si := new(sys.StartupInfo)
si.Cb = uint32(unsafe.Sizeof(*si))
si.Flags = syscall.STARTF_USESTDHANDLES
si.StdInput = sys.Handle(fd[0])
si.StdOutput = sys.Handle(fd[1])
si.StdErr = sys.Handle(fd[2])
pi := new(sys.ProcessInformation)
err = sys.CreateProcess(argv0, nil, nil, nil, true, DEBUGONLYTHISPROCESS, nil, nil, si, pi)
if err != nil {
return nil, err
}
defer sys.CloseHandle(sys.Handle(pi.Thread))
sys.CloseHandle(sys.Handle(pi.Process))
sys.CloseHandle(sys.Handle(pi.Thread))

dbp := New(int(pi.ProcessId))

Expand Down Expand Up @@ -104,12 +105,11 @@ func Launch(cmd []string) (*Process, error) {

// Attach to an existing process with the given PID.
func Attach(pid int) (*Process, error) {
fmt.Println("Attach")
return nil, fmt.Errorf("Not implemented: Attach")
}

// Kill kills the process.
func (dbp *Process) Kill() (err error) {
func (dbp *Process) Kill() error {
if dbp.exited {
return nil
}
Expand All @@ -121,13 +121,13 @@ func (dbp *Process) Kill() (err error) {
// this to fail on second attempt.
_ = C.TerminateProcess(C.HANDLE(dbp.os.hProcess), 1)
dbp.exited = true
return
return nil
}

func (dbp *Process) requestManualStop() (err error) {
func (dbp *Process) requestManualStop() error {
res := C.DebugBreakProcess(C.HANDLE(dbp.os.hProcess))
if res == C.FALSE {
return fmt.Errorf("Failed to break process %d", dbp.Pid)
return fmt.Errorf("failed to break process %d", dbp.Pid)
}
return nil
}
Expand Down Expand Up @@ -286,7 +286,7 @@ func (dbp *Process) findExecutable(path string) (*pe.File, error) {
if path == "" {
// TODO: Find executable path from PID/handle on Windows:
// https://msdn.microsoft.com/en-us/library/aa366789(VS.85).aspx
return nil, fmt.Errorf("Not yet implemented")
return nil, fmt.Errorf("not yet implemented")
}
f, err := os.OpenFile(path, 0, os.ModePerm)
if err != nil {
Expand All @@ -306,18 +306,27 @@ func (dbp *Process) findExecutable(path string) (*pe.File, error) {

func (dbp *Process) waitForDebugEvent() (threadID, exitCode int, err error) {
var debugEvent C.DEBUG_EVENT
var continueStatus C.DWORD
for {
continueStatus = C.DBG_CONTINUE
// Wait for a debug event...
res := C.WaitForDebugEvent(&debugEvent, C.INFINITE)
if res == C.WINBOOL(0) {
return 0, 0, fmt.Errorf("Could not WaitForDebugEvent")
if res == C.FALSE {
return 0, 0, fmt.Errorf("could not WaitForDebugEvent")
}

// ... handle each event kind ...
unionPtr := unsafe.Pointer(&debugEvent.u[0])
switch debugEvent.dwDebugEventCode {
case C.CREATE_PROCESS_DEBUG_EVENT:
debugInfo := (*C.CREATE_PROCESS_DEBUG_INFO)(unionPtr)
hFile := debugInfo.hFile
if hFile != C.HANDLE(uintptr(0)) /* NULL */ && hFile != C.HANDLE(uintptr(0xFFFFFFFFFFFFFFFF)) /* INVALID_HANDLE_VALUE */ {
res = C.CloseHandle(hFile)
if res == C.FALSE {
return 0, 0, fmt.Errorf("could not close create process file handle")
}
}
dbp.os.hProcess = sys.Handle(debugInfo.hProcess)
_, err = dbp.addThread(sys.Handle(debugInfo.hThread), int(debugEvent.dwThreadId), false)
if err != nil {
Expand All @@ -331,24 +340,49 @@ func (dbp *Process) waitForDebugEvent() (threadID, exitCode int, err error) {
return 0, 0, err
}
break
case C.LOAD_DLL_DEBUG_EVENT, C.UNLOAD_DLL_DEBUG_EVENT, C.EXIT_THREAD_DEBUG_EVENT, C.OUTPUT_DEBUG_STRING_EVENT, C.RIP_EVENT:
// TODO: Clean up exited threads, handle debug output strings, and maybe more?
case C.EXIT_THREAD_DEBUG_EVENT:
delete(dbp.Threads, int(debugEvent.dwThreadId))
break
case C.OUTPUT_DEBUG_STRING_EVENT:
//TODO: Handle debug output strings
break
case C.LOAD_DLL_DEBUG_EVENT:
debugInfo := (*C.LOAD_DLL_DEBUG_INFO)(unionPtr)
hFile := debugInfo.hFile
if hFile != C.HANDLE(uintptr(0)) /* NULL */ && hFile != C.HANDLE(uintptr(0xFFFFFFFFFFFFFFFF)) /* INVALID_HANDLE_VALUE */ {
res = C.CloseHandle(hFile)
if res == C.FALSE {
return 0, 0, fmt.Errorf("could not close DLL load file handle")
}
}
break
case C.UNLOAD_DLL_DEBUG_EVENT:
break
case C.RIP_EVENT:
break
case C.EXCEPTION_DEBUG_EVENT:
debugInfo := (*C.EXCEPTION_DEBUG_INFO)(unionPtr)
switch debugInfo.ExceptionRecord.ExceptionCode {
case C.EXCEPTION_BREAKPOINT, C.EXCEPTION_SINGLE_STEP:
continueStatus = C.DBG_CONTINUE
default:
continueStatus = C.DBG_EXCEPTION_NOT_HANDLED
}
tid := int(debugEvent.dwThreadId)
dbp.os.breakThread = tid
return tid, 0, nil
break
case C.EXIT_PROCESS_DEBUG_EVENT:
debugInfo := (*C.EXIT_PROCESS_DEBUG_INFO)(unionPtr)
return 0, int(debugInfo.dwExitCode), nil
default:
return 0, 0, fmt.Errorf("Unknown debug event code: %d", debugEvent.dwDebugEventCode)
return 0, 0, fmt.Errorf("unknown debug event code: %d", debugEvent.dwDebugEventCode)
}

// .. and then continue unless we received an event that indicated we should break into debugger.
res = C.ContinueDebugEvent(debugEvent.dwProcessId, debugEvent.dwThreadId, C.DBG_CONTINUE)
res = C.ContinueDebugEvent(debugEvent.dwProcessId, debugEvent.dwThreadId, continueStatus)
if res == C.WINBOOL(0) {
return 0, 0, fmt.Errorf("Could not ContinueDebugEvent")
return 0, 0, fmt.Errorf("could not ContinueDebugEvent")
}
}
}
Expand All @@ -375,8 +409,7 @@ func (dbp *Process) loadProcessInformation(wg *sync.WaitGroup) {
}

func (dbp *Process) wait(pid, options int) (int, *sys.WaitStatus, error) {
fmt.Println("wait")
return 0, nil, fmt.Errorf("Not implemented: wait")
return 0, nil, fmt.Errorf("not implemented: wait")
}

func (dbp *Process) setCurrentBreakpoints(trapthread *Thread) error {
Expand All @@ -390,12 +423,11 @@ func (dbp *Process) setCurrentBreakpoints(trapthread *Thread) error {
// in this case, one for each thread, so we should only handle the BP hit
// on the thread that the debugger was evented on.

err := trapthread.SetCurrentBreakpoint()
return err
return trapthread.SetCurrentBreakpoint()
}

func (dbp *Process) exitGuard(err error) error {
return err
return err
}

func (dbp *Process) resume() error {
Expand All @@ -419,5 +451,5 @@ func (dbp *Process) resume() error {

func killProcess(pid int) error {
fmt.Println("killProcess")
return fmt.Errorf("Not implemented: killProcess")
return fmt.Errorf("not implemented: killProcess")
}
4 changes: 2 additions & 2 deletions proc/ptrace_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
)

func PtraceAttach(pid int) error {
return fmt.Errorf("Not implemented: PtraceAttach")
return fmt.Errorf("not implemented: PtraceAttach")
}

func PtraceDetach(tid, sig int) error {
return fmt.Errorf("Not implemented: PtraceDetach")
return fmt.Errorf("not implemented: PtraceDetach")
}
14 changes: 6 additions & 8 deletions proc/registers_windows_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ func (r *Regs) SetPC(thread *Thread, pc uint64) error {

res := C.GetThreadContext(C.HANDLE(thread.os.hThread), &context)
if res == C.FALSE {
return fmt.Errorf("Could not GetThreadContext")
return fmt.Errorf("could not GetThreadContext")
}

context.Rip = C.DWORD64(pc)

res = C.SetThreadContext(C.HANDLE(thread.os.hThread), &context)
if res == C.FALSE {
return fmt.Errorf("Could not SetThreadContext")
return fmt.Errorf("could not SetThreadContext")
}

return nil
Expand All @@ -118,13 +118,13 @@ func registers(thread *Thread) (Registers, error) {
context.ContextFlags = C.CONTEXT_ALL
res := C.GetThreadContext(C.HANDLE(thread.os.hThread), &context)
if res == C.FALSE {
return nil, fmt.Errorf("Failed to read ThreadContext")
return nil, fmt.Errorf("failed to read ThreadContext")
}

var threadInfo C.THREAD_BASIC_INFORMATION
res = C.thread_basic_information(C.HANDLE(thread.os.hThread), &threadInfo)
if res == C.FALSE {
return nil, fmt.Errorf("Failed to get thread_basic_information")
return nil, fmt.Errorf("failed to get thread_basic_information")
}
tls := uintptr(threadInfo.TebBaseAddress)

Expand Down Expand Up @@ -156,11 +156,9 @@ func registers(thread *Thread) (Registers, error) {
}

func (thread *Thread) saveRegisters() (Registers, error) {
fmt.Println("registers.saveRegisters")
return nil, fmt.Errorf("Not implemented: saveRegisters")
return nil, fmt.Errorf("not implemented: saveRegisters")
}

func (thread *Thread) restoreRegisters() error {
fmt.Println("registers.restoreRegisters")
return fmt.Errorf("Not implemented: restoreRegisters")
return fmt.Errorf("not implemented: restoreRegisters")
}
12 changes: 6 additions & 6 deletions proc/threads_windows.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
typedef NTSTATUS (WINAPI *pNtQIT)(HANDLE, LONG, PVOID, ULONG, PULONG);

WINBOOL thread_basic_information(HANDLE h, THREAD_BASIC_INFORMATION* addr) {
static pNtQIT NtQueryInformationThread = NULL;
static pNtQIT NtQueryInformationThread = NULL;
if(NtQueryInformationThread == NULL) {
NtQueryInformationThread = (pNtQIT)GetProcAddress(GetModuleHandle("ntdll.dll"), "NtQueryInformationThread");
if(NtQueryInformationThread == NULL) {
return 0;
}
}
NtQueryInformationThread = (pNtQIT)GetProcAddress(GetModuleHandle("ntdll.dll"), "NtQueryInformationThread");
if(NtQueryInformationThread == NULL) {
return 0;
}
}

NTSTATUS status = NtQueryInformationThread(h, ThreadBasicInformation, addr, 48, 0);
return NT_SUCCESS(status);
Expand Down
20 changes: 10 additions & 10 deletions proc/threads_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ func (t *Thread) singleStep() error {
// Set the processor TRAP flag
res := C.GetThreadContext(C.HANDLE(t.os.hThread), &context)
if res == C.FALSE {
return fmt.Errorf("Could not GetThreadContext")
return fmt.Errorf("could not GetThreadContext")
}

context.EFlags |= 0x100

res = C.SetThreadContext(C.HANDLE(t.os.hThread), &context)
if res == C.FALSE {
return fmt.Errorf("Could not SetThreadContext")
return fmt.Errorf("could not SetThreadContext")
}

// Suspend all threads except this one
Expand All @@ -52,7 +52,7 @@ func (t *Thread) singleStep() error {
}
res := C.SuspendThread(C.HANDLE(thread.os.hThread))
if res == C.DWORD(0xFFFFFFFF) {
return fmt.Errorf("Could not suspend thread: %d", thread.ID)
return fmt.Errorf("could not suspend thread: %d", thread.ID)
}
}

Expand All @@ -61,7 +61,7 @@ func (t *Thread) singleStep() error {
res = C.ContinueDebugEvent(C.DWORD(t.dbp.Pid), C.DWORD(t.ID), C.DBG_CONTINUE)
})
if res == C.FALSE {
return fmt.Errorf("Could not ContinueDebugEvent.")
return fmt.Errorf("could not ContinueDebugEvent.")
}
_, err := t.dbp.trapWait(0)
if err != nil {
Expand All @@ -75,21 +75,21 @@ func (t *Thread) singleStep() error {
}
res := C.ResumeThread(C.HANDLE(thread.os.hThread))
if res == C.DWORD(0xFFFFFFFF) {
return fmt.Errorf("Could not resume thread: %d", thread.ID)
return fmt.Errorf("ould not resume thread: %d", thread.ID)
}
}

// Unset the processor TRAP flag
res = C.GetThreadContext(C.HANDLE(t.os.hThread), &context)
if res == C.FALSE {
return fmt.Errorf("Could not GetThreadContext")
return fmt.Errorf("could not GetThreadContext")
}

context.EFlags &= ^C.DWORD(0x100)

res = C.SetThreadContext(C.HANDLE(t.os.hThread), &context)
if res == C.FALSE {
return fmt.Errorf("Could not SetThreadContext")
return fmt.Errorf("could not SetThreadContext")
}

return nil
Expand All @@ -104,13 +104,13 @@ func (t *Thread) resume() error {
res = C.ContinueDebugEvent(C.DWORD(t.dbp.Pid), C.DWORD(t.ID), C.DBG_CONTINUE)
})
if res == C.FALSE {
return fmt.Errorf("Could not ContinueDebugEvent.")
return fmt.Errorf("could not ContinueDebugEvent.")
}
return nil
}

func (t *Thread) blocked() bool {
// TODO: Probably incorrect - what are teh runtime functions that
// TODO: Probably incorrect - what are the runtime functions that
// indicate blocking on Windows?
pc, err := t.PC()
if err != nil {
Expand All @@ -121,7 +121,7 @@ func (t *Thread) blocked() bool {
return false
}
switch fn.Name {
case "runtime.kevent", "runtime.mach_semaphore_wait", "runtime.usleep":
case "runtime.kevent", "runtime.usleep":
return true
default:
return false
Expand Down

0 comments on commit ce025fe

Please sign in to comment.