/
main.go
178 lines (141 loc) · 3.87 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package main
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
"unicode/utf16"
"unsafe"
)
// type workingDirectory interface {
// get() (string, error)
// }
// type osWorkingDirectory struct{}
// func (*osWorkingDirectory) get() (string, error) {
// return os.Getwd()
// }
// scriptType tells processHook how a discovered hook script must be launched.
// The zero value, none, means that no script was found.
type scriptType int

const (
	none       scriptType = 0 // no matching script exists
	powershell scriptType = 1 // *.ps1, launched with `powershell -File`
	cmdOrBat   scriptType = 2 // *.bat / *.cmd, launched with `cmd /c`
)
// kernel is kernel32.dll, loaded during package initialisation.
// MustLoadDLL panics if the DLL cannot be loaded.
var kernel = syscall.MustLoadDLL("kernel32.dll")

// getModuleFileNameProc is kernel32's GetModuleFileNameW export, called by
// getModuleFileName. MustFindProc panics if the export is missing.
var getModuleFileNameProc = kernel.MustFindProc("GetModuleFileNameW")
// getModuleFileName returns the full path of the running executable as
// reported by the Windows kernel (GetModuleFileNameW called with a NULL
// module handle).
//
// GetModuleFileNameW truncates the path and returns the buffer size when
// the buffer is too small. The buffer therefore starts at MAX_PATH and is
// doubled until the result fits. Without this, paths longer than MAX_PATH
// (possible when long-path support is enabled) would come back cut short.
func getModuleFileName() (string, error) {
	// Upper bound on Windows extended-length paths, in UTF-16 code units.
	const maxLen = 32768

	for size := syscall.MAX_PATH; ; size *= 2 {
		buf := make([]uint16, size)
		r0, _, e1 := getModuleFileNameProc.Call(0, uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
		n := int(r0)
		if n == 0 {
			return "", e1
		}
		if n < len(buf) {
			return string(utf16.Decode(buf[:n])), nil
		}
		// n == len(buf): the path was truncated. Retry with a larger buffer
		// unless we are already past the largest path Windows can produce.
		if size >= maxLen {
			return "", fmt.Errorf("GetModuleFileNameW: path longer than %d characters", maxLen)
		}
	}
}
// debugEnabled reports whether diagnostic output was requested, i.e. whether
// the GITHOOK_DEBUG environment variable holds a non-empty value.
func debugEnabled() bool {
	flag := os.Getenv("GITHOOK_DEBUG")
	return len(flag) > 0
}
// scriptNameWithoutExtension builds the path hookDir/repo/hookName. Callers
// append the script extension themselves.
func scriptNameWithoutExtension(hookDir, hookName, repo string) string {
	segments := []string{hookDir, repo, hookName}
	return filepath.Join(segments...)
}
// getHookScript locates the script to run for hookName.
//
// It looks first in hookDir/<repo>/ and then in hookDir/00-githooks-shared/
// for hookName.bat, hookName.cmd and hookName.ps1, checked in that fixed
// order. The first existing non-directory match is returned together with
// the scriptType that selects its interpreter. If nothing matches, it
// returns ("", none).
func getHookScript(hookDir, hookName, repo string) (string, scriptType) {
	const sharedDir = "00-githooks-shared"

	// An ordered slice rather than a map: Go randomises map iteration order,
	// so with several matching scripts a map would pick a different one on
	// different runs.
	candidates := []struct {
		extension string
		typ       scriptType
	}{
		{"bat", cmdOrBat},
		{"cmd", cmdOrBat},
		{"ps1", powershell},
	}

	for _, dir := range []string{repo, sharedDir} {
		base := scriptNameWithoutExtension(hookDir, hookName, dir)
		for _, c := range candidates {
			script := fmt.Sprintf("%s.%s", base, c.extension)
			info, err := os.Stat(script)
			if err != nil || info.IsDir() {
				// Missing, unreadable, or a directory: it cannot be executed.
				continue
			}
			if debugEnabled() {
				fmt.Printf("Will execute %s\n", script)
			}
			return script, c.typ
		}
	}
	return "", none
}
// defaultFailedCode is the exit code used when a script's real exit status
// cannot be determined, or when the script could not be run at all.
const defaultFailedCode = 1

// getExitCode extracts the exit status from a finished process.
// It returns defaultFailedCode when ps is nil or when its platform-specific
// status is not a syscall.WaitStatus.
func getExitCode(ps *os.ProcessState) int {
	if ps == nil {
		return defaultFailedCode
	}
	ws, ok := ps.Sys().(syscall.WaitStatus)
	if !ok {
		return defaultFailedCode
	}
	return ws.ExitStatus()
}
// RunCommand runs cmd with this process's stdin, stdout and stderr attached
// to it, waits for it to finish, and returns the exit code to propagate.
//
// If the command ran and exited non-zero, its own exit code is returned.
// If it could not be run at all (for example, the interpreter was not
// found), the error is written to stderr and defaultFailedCode is returned.
func RunCommand(cmd *exec.Cmd) int {
	// Share this process's standard streams so the script talks to the
	// caller directly.
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	err := cmd.Run()
	if err == nil {
		// Success: the status is 0, but read it from ProcessState so both
		// paths go through getExitCode.
		return getExitCode(cmd.ProcessState)
	}
	if exitErr, ok := err.(*exec.ExitError); ok {
		return getExitCode(exitErr.ProcessState)
	}
	// The command never ran. Send the reason to stderr so it cannot be
	// mistaken for script output.
	fmt.Fprintln(os.Stderr, err)
	return defaultFailedCode
}
func processHook(actualBin string, hookName string) (int, bool) {
// Hooks directory is the deirectory where this executable is found.
hookDir := filepath.Dir(actualBin)
hookArgs := os.Args[1:]
wd, _ := os.Getwd()
// Git executes the hook with the working directory set to the root of the repo.
repoName := filepath.Base(wd)
var cmd *exec.Cmd
script, typ := getHookScript(hookDir, hookName, repoName)
switch typ {
case cmdOrBat:
cmd = exec.Command("cmd", append([]string{"/c", script}, hookArgs...)...)
case powershell:
cmd = exec.Command("powershell", append([]string{"-File", script}, hookArgs...)...)
default:
if debugEnabled() {
fmt.Printf("(No script found for hook %s)\n", hookName)
}
// Exit zero (don't so action), and false to indicate no script run (for tests).
return 0, false
}
// Exit code of script, and true to indicate script was run.
return RunCommand(cmd), true
}
// hookNameFromArg0 derives the hook name from the path used to invoke the
// program. It takes the base name and drops the final extension, e.g.
// `C:\repo\.git\hooks\pre-commit.exe` -> "pre-commit". A leading dot is not
// treated as an extension separator.
func hookNameFromArg0(arg0 string) string {
	name := filepath.Base(arg0)
	if i := strings.LastIndex(name, "."); i > 0 {
		name = name[:i]
	}
	return name
}

func main() {
	// os.Args[0] is the path used to invoke the program. When invoked via a
	// symlink, the symlink's base name identifies the actual hook.
	hookName := hookNameFromArg0(os.Args[0])

	// Path of the running executable as reported by GetModuleFileNameW.
	// Its directory is where the hook scripts are looked up.
	actualBin, err := getModuleFileName()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(defaultFailedCode)
	}

	exitCode, _ := processHook(actualBin, hookName)
	os.Exit(exitCode)
}