-
Notifications
You must be signed in to change notification settings - Fork 0
/
write_to_conn.go
82 lines (74 loc) · 1.7 KB
/
write_to_conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package io2
import (
"io"
"io/ioutil"
"net"
"sync"
"sync/atomic"
)
// WriteConn is an io.WriteCloser over a ConnWithShutdown. Write sends bytes
// to the connection, while a background goroutine (copyDownstream, started
// by NewWriteConn) copies everything read from the connection into
// downstream. Close half-closes the connection, waits for that goroutine,
// and then releases the socket and downstream.
type WriteConn struct {
// isDoneAtomic is incremented atomically by copyDownstream once the copy and
// any error cleanup have finished; Write polls it with atomic.LoadUint32.
isDoneAtomic uint32
// wait holds one count added in NewWriteConn and released by copyDownstream;
// Write and Close block on it before reading writeErr.
wait sync.WaitGroup
// conn is the underlying connection: Write forwards to it and copyDownstream
// reads from it.
conn ConnWithShutdown
// downstream receives the bytes read from conn; Close closes it when it also
// implements io.WriteCloser.
downstream io.Writer
// writeErr is the error returned by the downstream io.Copy. It is stored by
// copyDownstream, handed out (and cleared) once by Write, or reported by Close.
// NOTE(review): no mutex guards it; ordering relies on isDoneAtomic/wait.
writeErr error
}
// ConnWithShutdown is a net.Conn whose write direction can be shut down on
// its own via CloseWrite. WriteConn.Close calls CloseWrite before waiting for
// the peer's remaining data to be copied downstream.
type ConnWithShutdown interface {
net.Conn
CloseWrite() error
}
// NewWriteConn returns a *WriteConn (as an io.WriteCloser) wrapping
// connection. Bytes passed to Write go to the connection, and a background
// goroutine streams everything the connection yields into downstream until
// that copy ends. Close waits for the goroutine, so callers must Close the
// result. The returned error is always nil.
func NewWriteConn(connection ConnWithShutdown, downstream io.Writer) (io.WriteCloser, error) {
	wc := &WriteConn{downstream: downstream, conn: connection}
	// One count for the copy goroutine; copyDownstream releases it when done.
	wc.wait.Add(1)
	go wc.copyDownstream()
	return wc, nil
}
// copyDownstream is the background goroutine started by NewWriteConn. It
// copies everything read from the connection into downstream until reading
// returns EOF or either side fails, and stores the result in writeErr.
//
// After a failure it stops further inbound traffic. A connection that also
// satisfies ConnWithDuplexShutdown has both directions shut down. Any other
// connection is drained into ioutil.Discard. Only after that cleanup is the
// done flag raised and the WaitGroup released, so Write and Close see the
// stored writeErr.
func (wcself *WriteConn) copyDownstream() {
	// Deferred calls run last-in-first-out: the flag is set first, then the
	// WaitGroup is released.
	defer wcself.wait.Done()
	defer atomic.AddUint32(&wcself.isDoneAtomic, 1)

	_, wcself.writeErr = io.Copy(wcself.downstream, wcself.conn)
	if wcself.writeErr == nil {
		return
	}

	if duplex, canShutBoth := wcself.conn.(ConnWithDuplexShutdown); canShutBoth {
		// Best effort: the copy error is what gets reported, not these.
		_ = duplex.CloseRead()
		_ = duplex.CloseWrite()
		return
	}

	// No way to refuse further input, so consume and drop it instead.
	_, _ = io.Copy(ioutil.Discard, wcself.conn)
}
// Write forwards data to the connection for as long as the background
// downstream copy is still running.
//
// Once that copy has finished, Write stops touching the socket and reports
// zero bytes written:
//   - the first call after a failed copy returns the copy error and clears
//     it, so Close will not report it a second time;
//   - every other call returns io.ErrShortWrite.
//
// NOTE(review): writeErr is read and cleared with no lock beyond wait.Wait(),
// so concurrent Write calls after the copy ends would race on it.
func (wcself *WriteConn) Write(data []byte) (int, error) {
	if atomic.LoadUint32(&wcself.isDoneAtomic) == 0 {
		return wcself.conn.Write(data)
	}

	wcself.wait.Wait()
	pending := wcself.writeErr
	if pending == nil {
		return 0, io.ErrShortWrite
	}
	// Surface the copy failure to the writer as early as possible, exactly once.
	wcself.writeErr = nil
	return 0, pending
}
// Close shuts down the write half of the connection and waits for the
// background downstream copy to finish. It then closes the socket and, when
// downstream implements io.WriteCloser, the downstream writer. These
// resources are released on every path, including when CloseWrite fails.
//
// The returned error is chosen by precedence:
//   - a CloseWrite failure;
//   - otherwise a downstream copy error that Write has not already returned,
//     whatever the type of downstream;
//   - otherwise the downstream writer's Close error;
//   - otherwise the socket's Close error.
func (wcself *WriteConn) Close() error {
	if shutdownErr := wcself.conn.CloseWrite(); shutdownErr != nil {
		// The peer never received our end-of-stream, so it may never finish
		// its side and copyDownstream could block in Read indefinitely.
		// Closing the socket first unblocks that goroutine so it can be
		// reaped instead of leaking together with the descriptor.
		_ = wcself.conn.Close()
		wcself.wait.Wait()
		if writeCloser, ok := wcself.downstream.(io.WriteCloser); ok {
			// Best effort: the shutdown failure is what the caller sees.
			_ = writeCloser.Close()
		}
		return shutdownErr
	}

	// Our end-of-stream is on the wire; let the peer's remaining data drain
	// into downstream before tearing anything down.
	wcself.wait.Wait()
	socketErr := wcself.conn.Close()

	var downstreamErr error
	if writeCloser, ok := wcself.downstream.(io.WriteCloser); ok {
		downstreamErr = writeCloser.Close()
	}

	switch {
	case wcself.writeErr != nil && wcself.writeErr != io.EOF:
		// io.Copy does not report a clean EOF as an error; the check is
		// defensive. A pending copy failure outranks the close errors.
		return wcself.writeErr
	case downstreamErr != nil:
		return downstreamErr
	default:
		return socketErr
	}
}