-
Notifications
You must be signed in to change notification settings - Fork 3
/
jndi.go
97 lines (86 loc) · 2.01 KB
/
jndi.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Package jndi lets Gophers participate in the log4j fun (CVE-2021-44228).
//
// It would be irresponsible to use this package.
package jndi
import (
"io"
"io/ioutil"
"log"
"net/http"
"os"
"regexp"
"strings"
)
// NewLogger returns a *log.Logger that writes to standard error through Wrap,
// so ${op:arg} lookup expressions appearing in the logged text (including
// user-influenceable text) are expanded and evaluated before being written.
// The logger has an empty prefix and uses log.LstdFlags.
func NewLogger() *log.Logger {
	stderr := Wrap(os.Stderr)
	return log.New(stderr, "", log.LstdFlags)
}
// Wrap returns an io.Writer that expands ${op:arg} lookup expressions in each
// chunk written to it and forwards the expanded text to w. Lookups run against
// realEnv, the package's real network transport and process environment.
func Wrap(w io.Writer) io.Writer {
	wrapped := writer{ww: w, e: realEnv}
	return wrapped
}
// env bundles the external dependencies that lookups reach out to.
type env struct {
	// transport performs the HTTP requests issued by "jndi" lookups.
	transport http.RoundTripper

	// getEnv reads environment variables for "env" lookups.
	getEnv func(string) string
}

// realEnv is the environment used by Wrap: HTTP requests go through
// http.DefaultTransport and variables are read with os.Getenv.
var realEnv = env{
	getEnv:    os.Getenv,
	transport: http.DefaultTransport,
}
// writer is an io.Writer that runs each chunk through env.subst and writes the
// expanded text to an underlying writer.
type writer struct {
	// ww receives the expanded text.
	ww io.Writer

	// e supplies the lookup environment used during expansion.
	e env
}

// Write expands lookup expressions in p and writes the result to w.ww.
//
// The reported count is always len(p), even when the underlying write fails:
// the expanded text may be longer or shorter than p, so there is no exact
// mapping back to bytes of the caller's buffer. Any error from the underlying
// writer is returned unchanged.
func (w writer) Write(p []byte) (int, error) {
	expanded := w.e.subst(string(p))
	if _, err := io.WriteString(w.ww, expanded); err != nil {
		return len(p), err
	}
	return len(p), nil
}
// opRx matches a single ${op:arg} lookup expression. op is a run of word
// characters; arg is non-empty and, roughly, may not contain '}' or "${", so
// only innermost expressions match and subst must loop to reach outer ones.
var opRx = regexp.MustCompile(`\$\{(\w+?):(?:[^}\$]|(\$[^\{]))+}`)

// subst expands every lookup expression in s via e.lookup. It repeats the
// expansion until the text stops changing, so expressions revealed by
// expanding inner ones, or produced by a lookup result, are expanded too.
// If expansion never reaches a fixed point, this does not return.
func (e env) subst(s string) string {
	expand := func(m string) string {
		// \w cannot match ':', so the first colon ends the op name; the
		// leading "${" and trailing "}" are sliced off around it.
		colon := strings.IndexByte(m, ':')
		return e.lookup(m[2:colon], m[colon+1:len(m)-1])
	}
	for {
		next := opRx.ReplaceAllStringFunc(s, expand)
		if next == s {
			return s
		}
		s = next
	}
}
// lookup evaluates a single ${op:arg} expression, mirroring a small subset of
// the log4j lookups (see https://logging.apache.org/log4j/2.x/manual/lookups.html):
//   - "lower" / "upper": arg converted to lower / upper case.
//   - "env": the value of the environment variable named by arg, trying the
//     lower-cased name first and then the upper-cased name.
//   - "jndi": the first "ldap://" in arg is rewritten to "http://", the
//     resulting URL is fetched with a GET request through e.transport, and
//     the response body is returned (the HTTP status is not checked). On any
//     request or read error the error text is returned instead.
//
// Unknown ops, and "env" lookups that find an empty value, yield "".
//
// SECURITY: the "jndi" op issues a network request to a URL taken straight
// from log text. That is the deliberately unsafe behavior this package exists
// to demonstrate; never use it on untrusted input.
func (e env) lookup(op, arg string) string {
	switch op {
	case "lower":
		return strings.ToLower(arg)
	case "upper":
		return strings.ToUpper(arg)
	case "env":
		if s := e.getEnv(strings.ToLower(arg)); s != "" {
			return s
		}
		if s := e.getEnv(strings.ToUpper(arg)); s != "" {
			return s
		}
	case "jndi":
		// I looked at gopkg.in/ldap.v2 and got scared.
		// A GET request is enough to do some DNS data and leak
		// some environment variable secrets.
		urlStr := strings.Replace(arg, "ldap://", "http://", 1) // oh well
		req, err := http.NewRequest("GET", urlStr, nil)
		if err != nil {
			return err.Error()
		}
		res, err := e.transport.RoundTrip(req)
		if err != nil {
			return err.Error()
		}
		// A successful RoundTrip hands the body to the caller; close it so the
		// underlying connection is released rather than leaked.
		defer res.Body.Close()
		all, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return err.Error()
		}
		return string(all)
	}
	return ""
}