/
h2.go
156 lines (138 loc) · 4.97 KB
/
h2.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright 2021 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package h2 contains basic HTTP/2 handling for Martian.
package h2
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"fmt"
"io"
"net/url"
"sync"
"github.com/google/martian/v3/log"
"golang.org/x/net/http2"
)
// connectionPreface is the fixed 24-octet sequence that every HTTP/2 client must send before any
// frames on a new connection. See https://tools.ietf.org/html/rfc7540#section-3.5.
var connectionPreface = []byte("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
// Config stores the configuration information needed for HTTP/2 processing.
type Config struct {
// AllowedHostsFilter is a function returning true if the argument is a host for which H2 is
// permitted.
// NOTE(review): Proxy does not consult this filter; presumably the caller checks it before
// invoking Proxy — confirm.
AllowedHostsFilter func(string) bool
// RootCAs is the pool of CA certificates used by the MitM client to authenticate the server.
// Proxy passes it as tls.Config.RootCAs when dialing the upstream server; per crypto/tls, a nil
// pool means the host's system root set is used.
RootCAs *x509.CertPool
// StreamProcessorFactories is a list of factories used to instantiate a chain of HTTP/2 stream
// processors. A chain is created for every stream.
// Factories are applied from last to first, each wrapping the chain built so far, so the factory
// at index 0 produces the outermost processors. A factory may return nil for either direction to
// leave that direction untouched by it.
StreamProcessorFactories []StreamProcessorFactory
// EnableDebugLogs turns on fine-grained debug logging for HTTP/2.
// Proxy hands a pointer to this field to its relays rather than a copy.
EnableDebugLogs bool
}
// Proxy proxies HTTP/2 traffic between a client connection, `cc`, and the HTTP/2 `url` assuming
// h2 is being used. Since no browsers use h2c, it's safe to assume all traffic uses TLS.
//
// Proxy does the following:
//   - Dials url.Host over TLS, offering only the "h2" ALPN protocol and trusting c.RootCAs.
//   - Forwards the client's connection preface.
//   - Relays frames in both directions until both relays have returned.
//
// Relay errors are logged rather than returned, so a nil result means only that the connection
// was set up successfully. `closing` is handed unchanged to both relays (presumably a shutdown
// signal — confirm against relayFrames).
//
// The upstream server connection is owned by Proxy and is closed before Proxy returns. cc
// belongs to the caller and is not closed here.
func (c *Config) Proxy(closing chan bool, cc io.ReadWriter, url *url.URL) error {
	if c.EnableDebugLogs {
		log.Infof("\u001b[1;35mProxying %v with HTTP/2\u001b[0m", url)
	}
	sc, err := tls.Dial("tcp", url.Host, &tls.Config{
		RootCAs:    c.RootCAs,
		NextProtos: []string{"h2"},
	})
	if err != nil {
		return fmt.Errorf("connecting h2 to %v: %w", url, err)
	}
	// Release the upstream socket on every exit path, including a failed preface. The close
	// error is deliberately ignored: both relays are done (or never started), so there is no
	// one left to report it to.
	defer sc.Close()

	if err := forwardPreface(sc, cc); err != nil {
		return fmt.Errorf("initializing h2 with %v: %w", url, err)
	}

	cf, sf := http2.NewFramer(cc, cc), http2.NewFramer(sc, sc)
	cToS := newRelay(ClientToServer, "client", url.String(), cf, sf, &c.EnableDebugLogs)
	sToC := newRelay(ServerToClient, url.String(), "client", sf, cf, &c.EnableDebugLogs)
	// Completes circular parts of the initialization.
	// The client-to-server relay depends on the server-to-client relay and vice versa.
	cToS.peer, sToC.peer = sToC, cToS

	// Creating processors is circular because the create function references the relays and the
	// relays need to call create. Both relays share one cache so that a stream gets a single
	// processor chain covering both directions.
	cToS.processors = &streamProcessors{
		create: func(id uint32) *Processors {
			// The innermost processors write straight to the relays.
			p := &Processors{cToS: &relayAdapter{id, cToS}, sToC: &relayAdapter{id, sToC}}
			// Build the chain back to front so that StreamProcessorFactories[0] ends up
			// outermost. Each factory receives the already-built downstream processors.
			for i := len(c.StreamProcessorFactories) - 1; i >= 0; i-- {
				toServer, toClient := c.StreamProcessorFactories[i](url, p)
				// A nil processor means the factory does not intercept that direction, so
				// it is bypassed in favor of the downstream processor.
				if toServer == nil {
					toServer = p.ForDirection(ClientToServer)
				}
				if toClient == nil {
					toClient = p.ForDirection(ServerToClient)
				}
				p = &Processors{cToS: toServer, sToC: toClient}
			}
			return p
		},
	}
	sToC.processors = cToS.processors

	var wg sync.WaitGroup
	wg.Add(2)
	go func() { // Forwards frames from client to server.
		defer wg.Done()
		if err := cToS.relayFrames(closing); err != nil {
			log.Errorf("relaying frame from client to %v: %v", url, err)
		}
	}()
	go func() { // Forwards frames from server to client.
		defer wg.Done()
		if err := sToC.relayFrames(closing); err != nil {
			log.Errorf("relaying frame from %v to client: %v", url, err)
		}
	}()
	wg.Wait()
	return nil
}
// forwardPreface reads the HTTP/2 connection preface from client, checks that it matches
// connectionPreface byte for byte, and writes it to server.
//
// It reads exactly len(connectionPreface) bytes and leaves anything after that unread on client.
// An error is returned if client ends early, sends a different preface, or server rejects the
// write.
func forwardPreface(server io.Writer, client io.Reader) error {
	preface := make([]byte, len(connectionPreface))
	// A single Read may legally return fewer bytes than requested (e.g. when the preface is
	// split across TCP segments or TLS records), so keep reading until the buffer is full.
	if _, err := io.ReadFull(client, preface); err != nil {
		return fmt.Errorf("reading preface: %w", err)
	}
	if !bytes.Equal(preface, connectionPreface) {
		return fmt.Errorf("client sent unexpected preface: %s", hex.Dump(preface))
	}
	// Per the io.Writer contract, a short write must come with a non-nil error. The explicit
	// check guards against non-conforming writers, which previously could spin forever here.
	n, err := server.Write(preface)
	if err == nil && n != len(preface) {
		err = io.ErrShortWrite
	}
	if err != nil {
		return fmt.Errorf("writing preface: %w", err)
	}
	return nil
}
// streamProcessors lazily creates and caches per-stream processor chains. In Proxy a single
// instance is shared by both relays of a connection.
type streamProcessors struct {
// processors stores `*Processors` instances keyed by uint32 stream ID.
// NOTE(review): nothing visible here ever deletes entries, so a long-lived connection keeps one
// entry per stream it has seen — confirm whether cleanup happens elsewhere.
processors sync.Map
// create creates `*Processors` for the given stream ID.
// Get calls it on a cache miss. If two callers miss on the same ID at once, it may run more than
// once, and only the first stored result is kept.
create func(uint32) *Processors
}
// Get returns the Processor handling stream id in direction dir. The first time an ID is seen,
// its processor chain is built with s.create and cached; later calls reuse the cached chain.
func (s *streamProcessors) Get(id uint32, dir Direction) Processor {
	if cached, hit := s.processors.Load(id); hit {
		return cached.(*Processors).ForDirection(dir)
	}
	// LoadOrStore keeps whichever chain was stored first if another goroutine raced us here.
	winner, _ := s.processors.LoadOrStore(id, s.create(id))
	return winner.(*Processors).ForDirection(dir)
}