diff --git a/pkg/winquit/channels_windows.go b/pkg/winquit/channels_windows.go index a9cd6be..2d7e48a 100644 --- a/pkg/winquit/channels_windows.go +++ b/pkg/winquit/channels_windows.go @@ -1,50 +1,50 @@ package winquit import ( - "os" - "syscall" + "os" + "syscall" ) type baseChannelType interface { - getKey() any - notifyNonBlocking() - notifyBlocking() + getKey() any + notifyNonBlocking() + notifyBlocking() } type boolChannelType struct { - channel chan bool + channel chan bool } func (b *boolChannelType) getKey() any { - return b.channel + return b.channel } func (b *boolChannelType) notifyNonBlocking() { - select { - case b.channel <- true: - default: - } + select { + case b.channel <- true: + default: + } } func (s *boolChannelType) notifyBlocking() { - s.channel <- true + s.channel <- true } type sigChannelType struct { - channel chan os.Signal + channel chan os.Signal } func (s *sigChannelType) getKey() any { - return s.channel + return s.channel } func (s *sigChannelType) notifyNonBlocking() { - select { - case s.channel <- syscall.SIGTERM: - default: - } + select { + case s.channel <- syscall.SIGTERM: + default: + } } func (s *sigChannelType) notifyBlocking() { - s.channel <- syscall.SIGTERM + s.channel <- syscall.SIGTERM } diff --git a/pkg/winquit/client.go b/pkg/winquit/client.go index 1df46ce..95b813c 100644 --- a/pkg/winquit/client.go +++ b/pkg/winquit/client.go @@ -1,7 +1,7 @@ package winquit import ( - "time" + "time" ) // RequestQuit sends a Windows quit notification to the specified process id. @@ -17,7 +17,7 @@ import ( // Callers must have appropriate security permissions, otherwise an error will // be returned. See the notes in the package documentation for more details. func RequestQuit(pid int) error { - return requestQuit(pid) + return requestQuit(pid) } // QuitProcess first sends a Windows quit notification to the specified process id, @@ -27,5 +27,5 @@ func RequestQuit(pid int) error { // Callers must have appropriate security permissions, otherwise an error will // be returned. See the notes in the package documentation for more details. func QuitProcess(pid int, waitNicely time.Duration) error { - return quitProcess(pid, waitNicely) + return quitProcess(pid, waitNicely) } diff --git a/pkg/winquit/client_unsupported.go b/pkg/winquit/client_unsupported.go index bb383f6..8aa5ed0 100644 --- a/pkg/winquit/client_unsupported.go +++ b/pkg/winquit/client_unsupported.go @@ -4,14 +4,14 @@ package winquit import ( - "fmt" - "time" + "fmt" + "time" ) func requestQuit(pid int) error { - return fmt.Errorf("not implemented on non-Windows") + return fmt.Errorf("not implemented on non-Windows") } func quitProcess(pid int, waitNicely time.Duration) error { - return fmt.Errorf("not implemented on non-Windows") + return fmt.Errorf("not implemented on non-Windows") } diff --git a/pkg/winquit/client_windows.go b/pkg/winquit/client_windows.go index 43eddef..e03919b 100644 --- a/pkg/winquit/client_windows.go +++ b/pkg/winquit/client_windows.go @@ -1,47 +1,47 @@ package winquit import ( - "os" - "time" + "os" + "time" - "github.com/containers/winquit/pkg/winquit/win32" - "github.com/sirupsen/logrus" + "github.com/containers/winquit/pkg/winquit/win32" + "github.com/sirupsen/logrus" ) func requestQuit(pid int) error { - threads, err := win32.GetProcThreads(uint32(pid)) - if err != nil { - return err - } + threads, err := win32.GetProcThreads(uint32(pid)) + if err != nil { + return err + } - for _, thread := range threads { - logrus.Debugf("Closing windows on thread %d", thread) - win32.CloseThreadWindows(uint32(thread)) - } + for _, thread := range threads { + logrus.Debugf("Closing windows on thread %d", thread) + win32.CloseThreadWindows(uint32(thread)) + } - return nil + return nil } func quitProcess(pid int, waitNicely time.Duration) error { - _ = RequestQuit(pid) + _ = RequestQuit(pid) - proc, err := os.FindProcess(pid) - if err != nil { - return nil - } + proc, err := os.FindProcess(pid) + if err != nil { + return nil + } - done := make(chan bool) + done := make(chan bool) - go func() { - proc.Wait() - done <- true - }() + go func() { + proc.Wait() + done <- true + }() - select { - case <-done: - return nil - case <-time.After(waitNicely): - } + select { + case <-done: + return nil + case <-time.After(waitNicely): + } - return proc.Kill() + return proc.Kill() } diff --git a/pkg/winquit/doc.go b/pkg/winquit/doc.go index 7c1804b..079794c 100644 --- a/pkg/winquit/doc.go +++ b/pkg/winquit/doc.go @@ -17,64 +17,64 @@ // The following example demonstrates usage of NotifyOnQuit() to wait for a // windows quit event before shutting down: // -// func server() { -// fmt.Println("Starting server") +// func server() { +// fmt.Println("Starting server") // -// // Create a channel, and register it -// done := make(chan bool, 1) -// winquit.NotifyOnQuit(done) +// // Create a channel, and register it +// done := make(chan bool, 1) +// winquit.NotifyOnQuit(done) // -// // Wait until we receive a quit event -// <-done +// // Wait until we receive a quit event +// <-done // -// fmt.Println("Shutting down") -// // Perform cleanup tasks -// } +// fmt.Println("Shutting down") +// // Perform cleanup tasks +// } // // # Blended signal example // // The following example demonstrates usage of SimulateSigTermOnQuit() in // concert with signal.Notify(): // -// func server() { -// fmt.Println("Starting server") +// func server() { +// fmt.Println("Starting server") // -// // Create a channel, and register it -// done := make(chan os.Signal, 1) +// // Create a channel, and register it +// done := make(chan os.Signal, 1) // -// // Wait on console interrupt events -// signal.Notify(done, syscall.SIGINT) +// // Wait on console interrupt events +// signal.Notify(done, syscall.SIGINT) // -// // Simulate SIGTERM when a quit occurs -// winquit.SimulateSigTermOnQuit(done) +// // Simulate SIGTERM when a quit occurs +// winquit.SimulateSigTermOnQuit(done) // -// // Wait until we receive a signal or quit event -// <-done +// // Wait until we receive a signal or quit event +// <-done // -// fmt.Println("Shutting down") -// // Perform cleanup tasks -// } +// fmt.Println("Shutting down") +// // Perform cleanup tasks +// } // // # Client example // // The following example demonstrates how an application can ask or // force other windows programs to quit: // -// func client() { -// // Ask nicely for program "one" to quit. This request may not -// // be honored if its a console application, or if the program -// // is hung -// if err := winquit.RequestQuit(pidOne); err != nil { -// fmt.Printf("error sending quit request, %s", err.Error()) -// } -// -// // Force program "two" to quit, but give it 20 seconds to -// // perform any cleanup tasks and quit on it's own -// timeout := time.Second * 20 -// if err := winquit.QuitProcess(pidTwo, timeout); err != nil { -// fmt.Printf("error killing process, %s", err.Error()) -// } -// } +// func client() { +// // Ask nicely for program "one" to quit. This request may not +// // be honored if its a console application, or if the program +// // is hung +// if err := winquit.RequestQuit(pidOne); err != nil { +// fmt.Printf("error sending quit request, %s", err.Error()) +// } +// +// // Force program "two" to quit, but give it 20 seconds to +// // perform any cleanup tasks and quit on it's own +// timeout := time.Second * 20 +// if err := winquit.QuitProcess(pidTwo, timeout); err != nil { +// fmt.Printf("error killing process, %s", err.Error()) +// } +// } // // # How it works // diff --git a/pkg/winquit/server.go b/pkg/winquit/server.go index 42fe3fa..79f3456 100644 --- a/pkg/winquit/server.go +++ b/pkg/winquit/server.go @@ -1,7 +1,7 @@ package winquit import ( - "os" + "os" ) // NotifyOnQuit relays a Windows quit notification to the boolean done channel. @@ -20,7 +20,7 @@ import ( // If this function is called after a Windows quit notification has occurred, it // will immediately deliver a "true" value. func NotifyOnQuit(done chan bool) { - notifyOnQuit(done) + notifyOnQuit(done) } // SimulateSigTermOnQuit relays a Windows quit notification following the same @@ -30,5 +30,5 @@ func NotifyOnQuit(done chan bool) { // This function allows for the reuse of the same underlying channel used with // in a separate os.signal.Notify method call. func SimulateSigTermOnQuit(handler chan os.Signal) { - simulateSigTermOnQuit(handler) + simulateSigTermOnQuit(handler) } diff --git a/pkg/winquit/server_unsupported.go b/pkg/winquit/server_unsupported.go index b5e8d06..fc440ab 100644 --- a/pkg/winquit/server_unsupported.go +++ b/pkg/winquit/server_unsupported.go @@ -4,7 +4,7 @@ package winquit import ( - "os" + "os" ) func notifyOnQuit(done chan bool) { diff --git a/pkg/winquit/server_windows.go b/pkg/winquit/server_windows.go index 9e77df8..d00a3bc 100644 --- a/pkg/winquit/server_windows.go +++ b/pkg/winquit/server_windows.go @@ -1,140 +1,140 @@ package winquit import ( - "os" - "path/filepath" - "runtime" - "strings" - "sync" - "syscall" - - "github.com/containers/winquit/pkg/winquit/win32" - "github.com/sirupsen/logrus" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "syscall" + + "github.com/containers/winquit/pkg/winquit/win32" + "github.com/sirupsen/logrus" ) type receiversType struct { - sync.Mutex + sync.Mutex - result bool - channels map[any]baseChannelType + result bool + channels map[any]baseChannelType } var ( - receivers *receiversType = &receiversType{ - channels: make(map[any]baseChannelType), - } + receivers *receiversType = &receiversType{ + channels: make(map[any]baseChannelType), + } - loopInit sync.Once + loopInit sync.Once ) func (r *receiversType) add(channel baseChannelType) { - r.Lock() - defer r.Unlock() + r.Lock() + defer r.Unlock() - if _, ok := r.channels[channel.getKey()]; ok { - return - } + if _, ok := r.channels[channel.getKey()]; ok { + return + } - if r.result { - go func() { - channel.notifyBlocking() - }() - return - } + if r.result { + go func() { + channel.notifyBlocking() + }() + return + } - r.channels[channel.getKey()] = channel + r.channels[channel.getKey()] = channel } func (r *receiversType) notifyAll() { - r.Lock() - defer r.Unlock() - r.result = true - for _, channel := range r.channels { - channel.notifyNonBlocking() - delete(r.channels, channel.getKey()) - } - for _, channel := range r.channels { - channel.notifyBlocking() - delete(r.channels, channel) - } + r.Lock() + defer r.Unlock() + r.result = true + for _, channel := range r.channels { + channel.notifyNonBlocking() + delete(r.channels, channel.getKey()) + } + for _, channel := range r.channels { + channel.notifyBlocking() + delete(r.channels, channel) + } } func initLoop() { - loopInit.Do(func() { - go messageLoop() - }) + loopInit.Do(func() { + go messageLoop() + }) } func notifyOnQuit(done chan bool) { - receivers.add(&boolChannelType{done}) - initLoop() + receivers.add(&boolChannelType{done}) + initLoop() } func simulateSigTermOnQuit(handler chan os.Signal) { - receivers.add(&sigChannelType{handler}) - initLoop() + receivers.add(&sigChannelType{handler}) + initLoop() } func messageLoop() { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - registerDummyWindow() - - logrus.Info("Entering loop for quit") - for { - ret, msg, err := win32.GetMessage(0, 0, 0) - if err != nil { - logrus.Debugf("Error receiving win32 message, %s", err.Error()) - continue - } - if ret == 0 { - logrus.Debug("Received QUIT notification") - receivers.notifyAll() - - return - } - logrus.Debugf("Unhandled message: %d", msg.Message) - win32.TranslateMessage(msg) - win32.DispatchMessage(msg) - } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + registerDummyWindow() + + logrus.Info("Entering loop for quit") + for { + ret, msg, err := win32.GetMessage(0, 0, 0) + if err != nil { + logrus.Debugf("Error receiving win32 message, %s", err.Error()) + continue + } + if ret == 0 { + logrus.Debug("Received QUIT notification") + receivers.notifyAll() + + return + } + logrus.Debugf("Unhandled message: %d", msg.Message) + win32.TranslateMessage(msg) + win32.DispatchMessage(msg) + } } func getAppName() (string, error) { - exeName, err := os.Executable() - if err != nil { - return "", err - } - suffix := filepath.Ext(exeName) - return strings.TrimSuffix(filepath.Base(exeName), suffix), nil + exeName, err := os.Executable() + if err != nil { + return "", err + } + suffix := filepath.Ext(exeName) + return strings.TrimSuffix(filepath.Base(exeName), suffix), nil } func registerDummyWindow() error { - var app syscall.Handle - var err error + var app syscall.Handle + var err error - app, err = win32.GetModuleHandle("") - if err != nil { - return err - } + app, err = win32.GetModuleHandle("") + if err != nil { + return err + } - appName, err := getAppName() - if err != nil { - return err - } + appName, err := getAppName() + if err != nil { + return err + } - className := appName + "-rclass" - winName := appName + "-root" + className := appName + "-rclass" + winName := appName + "-root" - _, err = win32.RegisterDummyWinClass(className, app) - if err != nil { - return err - } + _, err = win32.RegisterDummyWinClass(className, app) + if err != nil { + return err + } - _, err = win32.CreateDummyWindow(winName, className, app) - if err != nil { - return err - } + _, err = win32.CreateDummyWindow(winName, className, app) + if err != nil { + return err + } - return nil + return nil } diff --git a/pkg/winquit/win32/common.go b/pkg/winquit/win32/common.go index 5f08578..dbb1365 100644 --- a/pkg/winquit/win32/common.go +++ b/pkg/winquit/win32/common.go @@ -4,14 +4,14 @@ package win32 import ( - "syscall" + "syscall" ) const ( - ERROR_NO_MORE_ITEMS = 259 + ERROR_NO_MORE_ITEMS = 259 ) var ( - kernel32 = syscall.NewLazyDLL("kernel32.dll") - user32 = syscall.NewLazyDLL("user32.dll") + kernel32 = syscall.NewLazyDLL("kernel32.dll") + user32 = syscall.NewLazyDLL("user32.dll") ) diff --git a/pkg/winquit/win32/msg.go b/pkg/winquit/win32/msg.go index 67ec0ab..f6063a5 100644 --- a/pkg/winquit/win32/msg.go +++ b/pkg/winquit/win32/msg.go @@ -4,84 +4,84 @@ package win32 import ( - "syscall" - "unsafe" + "syscall" + "unsafe" ) type MSG struct { - HWnd uintptr - Message uint32 - WParam uintptr - LParam uintptr - Time uint32 - Pt struct{ X, Y int32 } + HWnd uintptr + Message uint32 + WParam uintptr + LParam uintptr + Time uint32 + Pt struct{ X, Y int32 } } const ( - WM_QUIT = 0x12 - WM_DESTROY = 0x02 - WM_CLOSE = 0x10 + WM_QUIT = 0x12 + WM_DESTROY = 0x02 + WM_CLOSE = 0x10 ) var ( - postQuitMessage = user32.NewProc("PostQuitMessage") - procGetMessage = user32.NewProc("GetMessageW") - procTranslateMessage = user32.NewProc("TranslateMessage") - procDispatchMessage = user32.NewProc("DispatchMessageW") - procSendMessage = user32.NewProc("SendMessageW") + postQuitMessage = user32.NewProc("PostQuitMessage") + procGetMessage = user32.NewProc("GetMessageW") + procTranslateMessage = user32.NewProc("TranslateMessage") + procDispatchMessage = user32.NewProc("DispatchMessageW") + procSendMessage = user32.NewProc("SendMessageW") ) func TranslateMessage(msg *MSG) bool { - ret, _, _ := - procTranslateMessage.Call( // BOOL TranslateMessage() - uintptr(unsafe.Pointer(msg)), // [in] const MSG *lpMsg - ) + ret, _, _ := + procTranslateMessage.Call( // BOOL TranslateMessage() + uintptr(unsafe.Pointer(msg)), // [in] const MSG *lpMsg + ) - return ret != 0 + return ret != 0 } func DispatchMessage(msg *MSG) uintptr { - ret, _, _ := - procDispatchMessage.Call( // LRESULT DispatchMessage() - uintptr(unsafe.Pointer(msg)), // [in] const MSG *lpMsg - ) + ret, _, _ := + procDispatchMessage.Call( // LRESULT DispatchMessage() + uintptr(unsafe.Pointer(msg)), // [in] const MSG *lpMsg + ) - return ret + return ret } func SendMessage(handle syscall.Handle, message uint, wparm uintptr, lparam uintptr) uintptr { - ret, _, _ := - procSendMessage.Call( // LRESULT SendMessage() - uintptr(handle), // [in] HWND hWnd - uintptr(message), // [in] UINT Msg - wparm, // [in] WPARAM wParam - lparam, // [in] LPARAM lParam - ) - - return ret + ret, _, _ := + procSendMessage.Call( // LRESULT SendMessage() + uintptr(handle), // [in] HWND hWnd + uintptr(message), // [in] UINT Msg + wparm, // [in] WPARAM wParam + lparam, // [in] LPARAM lParam + ) + + return ret } func PostQuitMessage(code int) { - _, _, _ = - postQuitMessage.Call( // void PostQuitMessage() - uintptr(code), // [in] int nExitCode - ) + _, _, _ = + postQuitMessage.Call( // void PostQuitMessage() + uintptr(code), // [in] int nExitCode + ) } func GetMessage(handle syscall.Handle, int, max int) (int32, *MSG, error) { - var msg MSG - ret, _, err := - procGetMessage.Call( // // BOOL GetMessage() - uintptr(unsafe.Pointer(&msg)), // [out] LPMSG lpMsg, - uintptr(handle), // [in, optional] HWND hWnd, - 0, // [in] UINT wMsgFilterMin, - 0, // [in] UINT wMsgFilterMax - ) - - if int32(ret) == -1 { - return 0, nil, err - } - - return int32(ret), &msg, nil + var msg MSG + ret, _, err := + procGetMessage.Call( // // BOOL GetMessage() + uintptr(unsafe.Pointer(&msg)), // [out] LPMSG lpMsg, + uintptr(handle), // [in, optional] HWND hWnd, + 0, // [in] UINT wMsgFilterMin, + 0, // [in] UINT wMsgFilterMax + ) + + if int32(ret) == -1 { + return 0, nil, err + } + + return int32(ret), &msg, nil } diff --git a/pkg/winquit/win32/proc.go b/pkg/winquit/win32/proc.go index 094d2d9..6f7ccfc 100644 --- a/pkg/winquit/win32/proc.go +++ b/pkg/winquit/win32/proc.go @@ -4,56 +4,56 @@ package win32 import ( - "fmt" - "syscall" + "fmt" + "syscall" ) const ( - MAXIMUM_ALLOWED = 0x02000000 + MAXIMUM_ALLOWED = 0x02000000 ) var ( - procOpenProcess = kernel32.NewProc("OpenProcess") - procCloseHandle = kernel32.NewProc("CloseHandle") - procGetModuleHandle = kernel32.NewProc("GetModuleHandleW") + procOpenProcess = kernel32.NewProc("OpenProcess") + procCloseHandle = kernel32.NewProc("CloseHandle") + procGetModuleHandle = kernel32.NewProc("GetModuleHandleW") ) func OpenProcess(pid uint32) (syscall.Handle, error) { - ret, _, err := - procOpenProcess.Call( // HANDLE OpenProcess() - MAXIMUM_ALLOWED, // [in] DWORD dwDesiredAccess, - 0, // [in] BOOL bInheritHandle, - uintptr(pid), // [in] DWORD dwProcessId - ) - - if ret == 0 { - return 0, err - } - - return syscall.Handle(ret), nil + ret, _, err := + procOpenProcess.Call( // HANDLE OpenProcess() + MAXIMUM_ALLOWED, // [in] DWORD dwDesiredAccess, + 0, // [in] BOOL bInheritHandle, + uintptr(pid), // [in] DWORD dwProcessId + ) + + if ret == 0 { + return 0, err + } + + return syscall.Handle(ret), nil } func CloseHandle(handle syscall.Handle) error { - ret, _, err := - procCloseHandle.Call( // BOOL CloseHandle() - uintptr(handle), // [in] HANDLE hObject - ) - if ret != 0 { - return fmt.Errorf("error closing handle: %w", err) - } - - return nil + ret, _, err := + procCloseHandle.Call( // BOOL CloseHandle() + uintptr(handle), // [in] HANDLE hObject + ) + if ret != 0 { + return fmt.Errorf("error closing handle: %w", err) + } + + return nil } func GetProcThreads(pid uint32) ([]uint, error) { - process, err := OpenProcess(pid) - if err != nil { - return nil, err - } + process, err := OpenProcess(pid) + if err != nil { + return nil, err + } - defer func() { - _ = CloseHandle(process) - }() + defer func() { + _ = CloseHandle(process) + }() - return GetProcThreadIds(process) + return GetProcThreadIds(process) } diff --git a/pkg/winquit/win32/pss.go b/pkg/winquit/win32/pss.go index c146095..bd03959 100644 --- a/pkg/winquit/win32/pss.go +++ b/pkg/winquit/win32/pss.go @@ -4,157 +4,157 @@ package win32 import ( - "fmt" - "syscall" - "unsafe" + "fmt" + "syscall" + "unsafe" ) type PSS_THREAD_ENTRY struct { - ExitStatus uint32 - TebBaseAddress uintptr - ProcessId uint32 - ThreadId uint32 - AffinityMask uintptr - Priority int32 - BasePriority int32 - LastSyscallFirstArgument uintptr - LastSyscallNumber uint16 - CreateTime uint64 - ExitTime uint64 - KernelTime uint64 - UserTime uint64 - Win32StartAddress uintptr - CaptureTime uint64 - Flags uint32 - SuspendCount uint16 - SizeOfContextRecord uint16 - ContextRecord uintptr + ExitStatus uint32 + TebBaseAddress uintptr + ProcessId uint32 + ThreadId uint32 + AffinityMask uintptr + Priority int32 + BasePriority int32 + LastSyscallFirstArgument uintptr + LastSyscallNumber uint16 + CreateTime uint64 + ExitTime uint64 + KernelTime uint64 + UserTime uint64 + Win32StartAddress uintptr + CaptureTime uint64 + Flags uint32 + SuspendCount uint16 + SizeOfContextRecord uint16 + ContextRecord uintptr } const ( - PSS_CAPTURE_THREADS = 0x00000080 - PSS_WALK_THREADS = 3 - PSS_QUERY_THREAD_INFORMATION = 5 + PSS_CAPTURE_THREADS = 0x00000080 + PSS_WALK_THREADS = 3 + PSS_QUERY_THREAD_INFORMATION = 5 ) var ( - procPssCaptureSnapshot = kernel32.NewProc("PssCaptureSnapshot") - procPssFreeSnapshot = kernel32.NewProc("PssFreeSnapshot") - procPssWalkMarkerCreate = kernel32.NewProc("PssWalkMarkerCreate") - procPssWalkMarkerFree = kernel32.NewProc("PssWalkMarkerFree") - procPssWalkSnapshot = kernel32.NewProc("PssWalkSnapshot") + procPssCaptureSnapshot = kernel32.NewProc("PssCaptureSnapshot") + procPssFreeSnapshot = kernel32.NewProc("PssFreeSnapshot") + procPssWalkMarkerCreate = kernel32.NewProc("PssWalkMarkerCreate") + procPssWalkMarkerFree = kernel32.NewProc("PssWalkMarkerFree") + procPssWalkSnapshot = kernel32.NewProc("PssWalkSnapshot") ) func PssCaptureSnapshot(process syscall.Handle, flags int32, contextFlags int32) (syscall.Handle, error) { - var snapshot syscall.Handle - ret, _, err := - procPssCaptureSnapshot.Call( // DWORD PssCaptureSnapshot() - uintptr(process), // [in] HANDLE ProcessHandle, - uintptr(flags), // [in] PSS_CAPTURE_FLAGS CaptureFlags, - uintptr(contextFlags), // [in, optional] DWORD ThreadContextFlags, - uintptr(unsafe.Pointer(&snapshot)), // [out] HPSS *SnapshotHandle - ) - - if ret != 0 { - return 0, err - } - - return snapshot, nil + var snapshot syscall.Handle + ret, _, err := + procPssCaptureSnapshot.Call( // DWORD PssCaptureSnapshot() + uintptr(process), // [in] HANDLE ProcessHandle, + uintptr(flags), // [in] PSS_CAPTURE_FLAGS CaptureFlags, + uintptr(contextFlags), // [in, optional] DWORD ThreadContextFlags, + uintptr(unsafe.Pointer(&snapshot)), // [out] HPSS *SnapshotHandle + ) + + if ret != 0 { + return 0, err + } + + return snapshot, nil } func PssFreeSnapshot(process syscall.Handle, snapshot syscall.Handle) error { - ret, _, _ := - procPssFreeSnapshot.Call( // DWORD PssFreeSnapshot() - uintptr(process), // [in] HANDLE ProcessHandle, - uintptr(snapshot), // [in] HPSS SnapshotHandle - ) - if ret != 0 { - return fmt.Errorf("error freeing snapshot: %d", ret) - } - - return nil + ret, _, _ := + procPssFreeSnapshot.Call( // DWORD PssFreeSnapshot() + uintptr(process), // [in] HANDLE ProcessHandle, + uintptr(snapshot), // [in] HPSS SnapshotHandle + ) + if ret != 0 { + return fmt.Errorf("error freeing snapshot: %d", ret) + } + + return nil } func PssWalkMarkerCreate() (syscall.Handle, error) { - var walkptr uintptr - - ret, _, _ := - procPssWalkMarkerCreate.Call( // // DWORD PssWalkMarkerCreate() - 0, // [in, optional] PSS_ALLOCATOR const *Allocator - uintptr(unsafe.Pointer(&walkptr)), // [out] HPSSWALK *WalkMarkerHandle - ) - if ret != 0 { - return 0, fmt.Errorf("error creating process walker mark: %d", ret) - } - - return syscall.Handle(walkptr), nil + var walkptr uintptr + + ret, _, _ := + procPssWalkMarkerCreate.Call( // // DWORD PssWalkMarkerCreate() + 0, // [in, optional] PSS_ALLOCATOR const *Allocator + uintptr(unsafe.Pointer(&walkptr)), // [out] HPSSWALK *WalkMarkerHandle + ) + if ret != 0 { + return 0, fmt.Errorf("error creating process walker mark: %d", ret) + } + + return syscall.Handle(walkptr), nil } func PssWalkMarkerFree(handle syscall.Handle) error { - ret, _, _ := - procPssWalkMarkerFree.Call( // DWORD PssWalkMarkerFree() - uintptr(handle), // [in] HPSSWALK WalkMarkerHandle - ) - if ret != 0 { - return fmt.Errorf("error freeing process walker mark: %d", ret) - } - - return nil + ret, _, _ := + procPssWalkMarkerFree.Call( // DWORD PssWalkMarkerFree() + uintptr(handle), // [in] HPSSWALK WalkMarkerHandle + ) + if ret != 0 { + return fmt.Errorf("error freeing process walker mark: %d", ret) + } + + return nil } func PssWalkThreadSnapshot(snapshot syscall.Handle, marker syscall.Handle) (*PSS_THREAD_ENTRY, error) { - var thread PSS_THREAD_ENTRY - ret, _, err := - procPssWalkSnapshot.Call( // // DWORD PssWalkSnapshot() - uintptr(snapshot), // [in] HPSS SnapshotHandle, - PSS_WALK_THREADS, // [in] PSS_WALK_INFORMATION_CLASS InformationClass, - uintptr(marker), // [in] HPSSWALK WalkMarkerHandle, - uintptr(unsafe.Pointer(&thread)), // [out] void *Buffer, - unsafe.Sizeof(thread), // [in] DWORD BufferLength - ) - - if ret == ERROR_NO_MORE_ITEMS { - return nil, nil - } - - if ret != 0 { - return nil, fmt.Errorf("error waling thread snapshot: %d (%w)", ret, err) - } - - return &thread, nil + var thread PSS_THREAD_ENTRY + ret, _, err := + procPssWalkSnapshot.Call( // // DWORD PssWalkSnapshot() + uintptr(snapshot), // [in] HPSS SnapshotHandle, + PSS_WALK_THREADS, // [in] PSS_WALK_INFORMATION_CLASS InformationClass, + uintptr(marker), // [in] HPSSWALK WalkMarkerHandle, + uintptr(unsafe.Pointer(&thread)), // [out] void *Buffer, + unsafe.Sizeof(thread), // [in] DWORD BufferLength + ) + + if ret == ERROR_NO_MORE_ITEMS { + return nil, nil + } + + if ret != 0 { + return nil, fmt.Errorf("error waling thread snapshot: %d (%w)", ret, err) + } + + return &thread, nil } func GetProcThreadIds(process syscall.Handle) ([]uint, error) { - snapshot, err := PssCaptureSnapshot(process, PSS_CAPTURE_THREADS, 0) - if err != nil { - return nil, err - } - defer func() { - _ = PssFreeSnapshot(process, snapshot) - }() - - marker, err := PssWalkMarkerCreate() - if err != nil { - return nil, err - } - - defer func() { - _ = PssWalkMarkerFree(marker) - }() - - var threads []uint - - for { - thread, err := PssWalkThreadSnapshot(snapshot, marker) - if err != nil { - return nil, err - } - if thread == nil { - break - } - - threads = append(threads, uint(thread.ThreadId)) - } - - return threads, nil + snapshot, err := PssCaptureSnapshot(process, PSS_CAPTURE_THREADS, 0) + if err != nil { + return nil, err + } + defer func() { + _ = PssFreeSnapshot(process, snapshot) + }() + + marker, err := PssWalkMarkerCreate() + if err != nil { + return nil, err + } + + defer func() { + _ = PssWalkMarkerFree(marker) + }() + + var threads []uint + + for { + thread, err := PssWalkThreadSnapshot(snapshot, marker) + if err != nil { + return nil, err + } + if thread == nil { + break + } + + threads = append(threads, uint(thread.ThreadId)) + } + + return threads, nil } diff --git a/pkg/winquit/win32/win.go b/pkg/winquit/win32/win.go index 38f237c..b243b0b 100644 --- a/pkg/winquit/win32/win.go +++ b/pkg/winquit/win32/win.go @@ -4,159 +4,159 @@ package win32 import ( - "fmt" - "syscall" - "unsafe" + "fmt" + "syscall" + "unsafe" ) type WNDCLASSEX struct { - cbSize uint32 - style uint32 - lpfnWndProc uintptr - cbClsExtra int32 - cbWndExtra int32 - hInstance syscall.Handle - hIcon syscall.Handle - hCursor syscall.Handle - hbrBackground syscall.Handle - menuName *uint16 - className *uint16 - hIconSm syscall.Handle + cbSize uint32 + style uint32 + lpfnWndProc uintptr + cbClsExtra int32 + cbWndExtra int32 + hInstance syscall.Handle + hIcon syscall.Handle + hCursor syscall.Handle + hbrBackground syscall.Handle + menuName *uint16 + className *uint16 + hIconSm syscall.Handle } const ( - COLOR_WINDOW = 0x05 - CW_USEDEFAULT = ^0x7fffffff + COLOR_WINDOW = 0x05 + CW_USEDEFAULT = ^0x7fffffff ) var ( - procEnumThreadWindows = user32.NewProc("EnumThreadWindows") - procRegisterClassEx = user32.NewProc("RegisterClassExW") - procCreateWindowEx = user32.NewProc("CreateWindowExW") - procDefWinProc = user32.NewProc("DefWindowProcW") + procEnumThreadWindows = user32.NewProc("EnumThreadWindows") + procRegisterClassEx = user32.NewProc("RegisterClassExW") + procCreateWindowEx = user32.NewProc("CreateWindowExW") + procDefWinProc = user32.NewProc("DefWindowProcW") - callbackEnumThreadWindows = syscall.NewCallback(wndProcCloseWindow) + callbackEnumThreadWindows = syscall.NewCallback(wndProcCloseWindow) ) func DefWindowProc(hWnd syscall.Handle, msg uint32, wParam uintptr, lParam uintptr) int32 { - ret, _, _ := - procDefWinProc.Call( // LRESULT DefWindowProcW() - uintptr(hWnd), // [in] HWND hWnd, - uintptr(msg), // [in] UINT Msg, - wParam, // [in] WPARAM wParam, - lParam, // [in] LPARAM lParam - ) - return int32(ret) + ret, _, _ := + procDefWinProc.Call( // LRESULT DefWindowProcW() + uintptr(hWnd), // [in] HWND hWnd, + uintptr(msg), // [in] UINT Msg, + wParam, // [in] WPARAM wParam, + lParam, // [in] LPARAM lParam + ) + return int32(ret) } func GetModuleHandle(name string) (syscall.Handle, error) { - var name16 *uint16 - var err error - - if len(name) > 0 { - if name16, err = syscall.UTF16PtrFromString(name); err != nil { - return 0, err - } - } - - ret, _, err := - procGetModuleHandle.Call( // HMODULE GetModuleHandleW() - uintptr(unsafe.Pointer(name16)), // [in, optional] LPCWSTR lpModuleName - ) - if ret == 0 { - return 0, fmt.Errorf("could not obtain module handle: %w", err) - } - - return syscall.Handle(ret), nil + var name16 *uint16 + var err error + + if len(name) > 0 { + if name16, err = syscall.UTF16PtrFromString(name); err != nil { + return 0, err + } + } + + ret, _, err := + procGetModuleHandle.Call( // HMODULE GetModuleHandleW() + uintptr(unsafe.Pointer(name16)), // [in, optional] LPCWSTR lpModuleName + ) + if ret == 0 { + return 0, fmt.Errorf("could not obtain module handle: %w", err) + } + + return syscall.Handle(ret), nil } func RegisterClassEx(class *WNDCLASSEX) (uint16, error) { - ret, _, err := - procRegisterClassEx.Call( // ATOM RegisterClassExW() - uintptr(unsafe.Pointer(class)), // [in] const WNDCLASSEXW *unnamedParam1 - ) - if ret == 0 { - return 0, fmt.Errorf("could not register window class: %w", err) - } + ret, _, err := + procRegisterClassEx.Call( // ATOM RegisterClassExW() + uintptr(unsafe.Pointer(class)), // [in] const WNDCLASSEXW *unnamedParam1 + ) + if ret == 0 { + return 0, fmt.Errorf("could not register window class: %w", err) + } - return uint16(ret), nil + return uint16(ret), nil } func wndProc(hWnd syscall.Handle, msg uint32, wParam uintptr, lParam uintptr) uintptr { - switch msg { - case WM_DESTROY: - PostQuitMessage(0) - return 0 - default: - return uintptr(DefWindowProc(hWnd, msg, wParam, lParam)) - } + switch msg { + case WM_DESTROY: + PostQuitMessage(0) + return 0 + default: + return uintptr(DefWindowProc(hWnd, msg, wParam, lParam)) + } } func CloseThreadWindows(threadId uint32) { - _, _, _ = - procEnumThreadWindows.Call( // // BOOL EnumThreadWindows() - uintptr(threadId), // [in] DWORD dwThreadId, - callbackEnumThreadWindows, // [in] WNDENUMPROC lpfn, - 0, // [in] LPARAM lParam - ) + _, _, _ = + procEnumThreadWindows.Call( // // BOOL EnumThreadWindows() + uintptr(threadId), // [in] DWORD dwThreadId, + callbackEnumThreadWindows, // [in] WNDENUMPROC lpfn, + 0, // [in] LPARAM lParam + ) } func wndProcCloseWindow(hwnd uintptr, lparam uintptr) uintptr { - SendMessage(syscall.Handle(hwnd), WM_CLOSE, 0, 0) + SendMessage(syscall.Handle(hwnd), WM_CLOSE, 0, 0) - return 1 + return 1 } func RegisterDummyWinClass(name string, appInstance syscall.Handle) (uint16, error) { - var class16 *uint16 - var err error - if class16, err = syscall.UTF16PtrFromString(name); err != nil { - return 0, err - } + var class16 *uint16 + var err error + if class16, err = syscall.UTF16PtrFromString(name); err != nil { + return 0, err + } - class := WNDCLASSEX{ - className: class16, - hInstance: appInstance, - lpfnWndProc: syscall.NewCallback(wndProc), - } + class := WNDCLASSEX{ + className: class16, + hInstance: appInstance, + lpfnWndProc: syscall.NewCallback(wndProc), + } - class.cbSize = uint32(unsafe.Sizeof(class)) + class.cbSize = uint32(unsafe.Sizeof(class)) - return RegisterClassEx(&class) + return RegisterClassEx(&class) } func CreateDummyWindow(name string, className string, appInstance syscall.Handle) (syscall.Handle, error) { - var name16, class16 *uint16 - var err error - - cwDefault := CW_USEDEFAULT - - if name16, err = syscall.UTF16PtrFromString(name); err != nil { - return 0, err - } - if class16, err = syscall.UTF16PtrFromString(className); err != nil { - return 0, err - } - ret, _, err := procCreateWindowEx.Call( //HWND CreateWindowExW( - 0, // [in] DWORD dwExStyle, - uintptr(unsafe.Pointer(class16)), // [in, optional] LPCWSTR lpClassName, - uintptr(unsafe.Pointer(name16)), // [in, optional] LPCWSTR lpWindowName, - 0, // [in] DWORD dwStyle, - uintptr(cwDefault), // [in] int X, - uintptr(cwDefault), // [in] int Y, - 0, // [in] int nWidth, - 0, // [in] int nHeight, - 0, // [in, optional] HWND hWndParent, - 0, // [in, optional] HMENU hMenu, - uintptr(appInstance), // [in, optional] HINSTANCE hInstance, - 0, // [in, optional] LPVOID lpParam - ) - - if ret == 0 { - return 0, fmt.Errorf("could not create window: %w", err) - } - - return syscall.Handle(ret), nil + var name16, class16 *uint16 + var err error + + cwDefault := CW_USEDEFAULT + + if name16, err = syscall.UTF16PtrFromString(name); err != nil { + return 0, err + } + if class16, err = syscall.UTF16PtrFromString(className); err != nil { + return 0, err + } + ret, _, err := procCreateWindowEx.Call( //HWND CreateWindowExW( + 0, // [in] DWORD dwExStyle, + uintptr(unsafe.Pointer(class16)), // [in, optional] LPCWSTR lpClassName, + uintptr(unsafe.Pointer(name16)), // [in, optional] LPCWSTR lpWindowName, + 0, // [in] DWORD dwStyle, + uintptr(cwDefault), // [in] int X, + uintptr(cwDefault), // [in] int Y, + 0, // [in] int nWidth, + 0, // [in] int nHeight, + 0, // [in, optional] HWND hWndParent, + 0, // [in, optional] HMENU hMenu, + uintptr(appInstance), // [in, optional] HINSTANCE hInstance, + 0, // [in, optional] LPVOID lpParam + ) + + if ret == 0 { + return 0, fmt.Errorf("could not create window: %w", err) + } + + return syscall.Handle(ret), nil } diff --git a/test/client_test.go b/test/client_test.go index 1701832..9441f66 100644 --- a/test/client_test.go +++ b/test/client_test.go @@ -4,14 +4,14 @@ package e2e import ( - "os" - "os/exec" - "path/filepath" - "time" - - "github.com/containers/winquit/pkg/winquit" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/containers/winquit/pkg/winquit" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" ) var WINQUIT_PATH = filepath.Join("..", "bin", "winquit.exe") @@ -20,91 +20,91 @@ const SHOULD_TIME = 10 const WONT_TIME = 5 var _ = Describe("perquisites", func() { - It("winquit binary is built", func() { - _, err := os.Stat(WINQUIT_PATH) - Expect(err).ShouldNot(HaveOccurred()) - }) + It("winquit binary is built", func() { + _, err := os.Stat(WINQUIT_PATH) + Expect(err).ShouldNot(HaveOccurred()) + }) }) var _ = Describe("client", func() { - It("request quit should kill thidparty(winver) process", func() { - cmd := exec.Command("winver") - verifyRequestQuit(cmd, SHOULD_TIME, true) - }) + It("request quit should kill thidparty(winver) process", func() { + cmd := exec.Command("winver") + verifyRequestQuit(cmd, SHOULD_TIME, true) + }) }) var _ = Describe("client", func() { - It("request quit kills winquit simple server", func() { - cmd := exec.Command(WINQUIT_PATH, "simple-server") - verifyRequestQuit(cmd, SHOULD_TIME, true) - }) + It("request quit kills winquit simple server", func() { + cmd := exec.Command(WINQUIT_PATH, "simple-server") + verifyRequestQuit(cmd, SHOULD_TIME, true) + }) }) var _ = Describe("client", func() { - It("request quit kills winquit multi-server", func() { - cmd := exec.Command(WINQUIT_PATH, "multi-server") - verifyRequestQuit(cmd, SHOULD_TIME, true) - }) + It("request quit kills winquit multi-server", func() { + cmd := exec.Command(WINQUIT_PATH, "multi-server") + verifyRequestQuit(cmd, SHOULD_TIME, true) + }) }) var _ = Describe("client", func() { - It("request quit kills winquit signal server", func() { - cmd := exec.Command(WINQUIT_PATH, "signal-server") - verifyRequestQuit(cmd, SHOULD_TIME, true) - }) + It("request quit kills winquit signal server", func() { + cmd := exec.Command(WINQUIT_PATH, "signal-server") + verifyRequestQuit(cmd, SHOULD_TIME, true) + }) }) var _ = Describe("client", func() { - It("request quit does not kill winquit hang server", func() { - cmd := exec.Command(WINQUIT_PATH, "hang-server") - verifyRequestQuit(cmd, WONT_TIME, false) - }) + It("request quit does not kill winquit hang server", func() { + cmd := exec.Command(WINQUIT_PATH, "hang-server") + verifyRequestQuit(cmd, WONT_TIME, false) + }) }) var _ = Describe("client", func() { - It("demand quit does kill winquit hang server", func() { - cmd := exec.Command(WINQUIT_PATH, "hang-server") - verifyForceQuit(cmd, WONT_TIME, SHOULD_TIME, true) - }) + It("demand quit does kill winquit hang server", func() { + cmd := exec.Command(WINQUIT_PATH, "hang-server") + verifyForceQuit(cmd, WONT_TIME, SHOULD_TIME, true) + }) }) func verifyRequestQuit(cmd *exec.Cmd, timeout int, outcome bool) { - verifyStart(cmd) - winquit.RequestQuit(cmd.Process.Pid) - verifyExit(cmd, timeout, outcome) + verifyStart(cmd) + winquit.RequestQuit(cmd.Process.Pid) + verifyExit(cmd, timeout, outcome) } func verifyForceQuit(cmd *exec.Cmd, forceTimeout int, timeout int, outcome bool) { - verifyStart(cmd) - winquit.QuitProcess(cmd.Process.Pid, time.Duration(forceTimeout)*time.Second) - verifyExit(cmd, timeout, outcome) + verifyStart(cmd) + winquit.QuitProcess(cmd.Process.Pid, time.Duration(forceTimeout)*time.Second) + verifyExit(cmd, timeout, outcome) } func verifyStart(cmd *exec.Cmd) { - err := cmd.Start() - Expect(err).ShouldNot(HaveOccurred()) - time.Sleep(100 * time.Millisecond) - Expect(cmd.ProcessState).To(BeNil()) - _, err = os.FindProcess(cmd.Process.Pid) - Expect(err).ShouldNot(HaveOccurred()) + err := cmd.Start() + Expect(err).ShouldNot(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + Expect(cmd.ProcessState).To(BeNil()) + _, err = os.FindProcess(cmd.Process.Pid) + Expect(err).ShouldNot(HaveOccurred()) } func verifyExit(cmd *exec.Cmd, timeout int, outcome bool) { - completed := make(chan bool) - go func() { - cmd.Wait() - completed <- true - }() - - result := false - select { - case <-completed: - result = true - case <-time.After(time.Duration(timeout) * time.Second): - } - - Expect(result).To(Equal(outcome)) - if !outcome { - cmd.Process.Kill() - } + completed := make(chan bool) + go func() { + cmd.Wait() + completed <- true + }() + + result := false + select { + case <-completed: + result = true + case <-time.After(time.Duration(timeout) * time.Second): + } + + Expect(result).To(Equal(outcome)) + if !outcome { + cmd.Process.Kill() + } } diff --git a/test/suite_test.go b/test/suite_test.go index 6d9387d..759f659 100644 --- a/test/suite_test.go +++ b/test/suite_test.go @@ -4,13 +4,13 @@ package e2e import ( - "testing" + "testing" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" ) func TestTest(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Test Suite") + RegisterFailHandler(Fail) + RunSpecs(t, "Test Suite") }