/
switch.go
173 lines (155 loc) · 4.35 KB
/
switch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package link
import (
"context"
"errors"
"fmt"
"sync"
"github.com/google/gopacket"
gplayers "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
)
// SwitchConfig contains the configs for RunSwitch().
type SwitchConfig struct {
	// MetricLabels carries metric label values. RunSwitch copies
	// MetricLabels.StackName into every port config whose own
	// StackName is empty.
	MetricLabels struct {
		StackName string `yaml:"stackName"`
	} `yaml:"metricLabels"`

	// Ports configures the switch ports. The index of an entry in this
	// slice is the port number the switch uses for it.
	Ports []EthernetPortConfig `yaml:"ethernetPorts"`
}

// SwitchWaitCloseFunc is the function returned by RunSwitch. Calling it
// blocks until the switch has stopped; afterwards, if its argument is
// non-nil, the argument is invoked once per port with that port's
// receive channel so leftover frames can be consumed.
type SwitchWaitCloseFunc func(func(frame <-chan *gplayers.Ethernet))

// switchImpl is the state of a switch started by RunSwitch: its
// configuration and its ports, indexed by port number.
type switchImpl struct {
	conf  *SwitchConfig
	ports []EthernetPort
}
// RunSwitch runs a hypothetical L2 switch, which decapsulates and
// forwards ethernet frames based on a forwarding table built by
// learning L2 routes on the fly:
//   - when a frame comes in, its source MAC address is cached as
//     reachable through the port the frame arrived on;
//   - if the destination MAC address matches the MAC address of one
//     of the switch's ports, the frame is discarded;
//   - if the destination MAC address has a cached route, the frame is
//     forwarded to that port;
//   - otherwise the frame is forwarded to all other ports.
//
// conf must list at least three ports, otherwise an error is returned.
// Every port is created in forwarding mode and, when its own
// MetricLabels.StackName is empty, inherits conf.MetricLabels.StackName.
// If creating any port fails, the ports created so far are closed and
// the wrapped error is returned.
//
// The returned function blocks until the switch has stopped running,
// which happens upon the given ctx being cancelled; by then the ports
// have been closed. If its argument is non-nil, it is called once per
// port, in port order, with that port's receive channel so that any
// frames remaining in the ports can be consumed.
func RunSwitch(ctx context.Context, conf SwitchConfig) (SwitchWaitCloseFunc, error) {
	// create switch
	if len(conf.Ports) < 3 {
		return nil, errors.New("switch will only run with at least three ports")
	}
	ports := make([]EthernetPort, 0, len(conf.Ports))
	for i, portConf := range conf.Ports {
		portConf.ForwardingMode = true
		if portConf.MetricLabels.StackName == "" {
			portConf.MetricLabels.StackName = conf.MetricLabels.StackName
		}
		port, err := NewEthernetPort(ctx, portConf)
		if err != nil {
			// release the ports created so far, newest first
			for j := len(ports) - 1; j >= 0; j-- {
				ports[j].Close()
			}
			return nil, fmt.Errorf("error creating ethernet port number %d: %w", i, err)
		}
		ports = append(ports, port)
	}
	s := &switchImpl{
		conf:  &conf,
		ports: ports,
	}

	// start switching goroutine; s.run closes the ports before returning
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		s.run(ctx)
	}()

	return func(f func(frame <-chan *gplayers.Ethernet)) {
		wg.Wait()
		if f == nil {
			return
		}
		for _, port := range s.ports {
			f(port.Recv())
		}
	}, nil
}
// run starts one forwarding goroutine per port and blocks until all of
// them have returned. A goroutine returns when ctx is cancelled, or when
// its port's receive channel is closed. Before returning, run closes
// every port.
//
// Each received frame is handled as a learning switch would:
//   - the frame's source MAC is mapped to the ingress port;
//   - frames addressed to the MAC of one of the switch's own ports are dropped;
//   - frames whose destination has a learned route go to that port only,
//     unless that port is the ingress port, in which case they are filtered
//     (the destination is on the segment the frame came from);
//   - frames with no learned route are flooded to every port except the
//     ingress port.
//
// Send errors are logged, not propagated.
func (s *switchImpl) run(ctx context.Context) {
	var wg sync.WaitGroup
	defer func() {
		wg.Wait() // wait for all port goroutines before closing their ports
		for _, port := range s.ports {
			port.Close()
		}
	}()

	// forwardingTable maps a MAC endpoint to the index of the port it was
	// last seen on as a source. It is shared by all port goroutines.
	var forwardingTable sync.Map
	storeRoute := func(macAddress gopacket.Endpoint, portNumber int) {
		// sync.Map is optimized for write once, read many times.
		// Because switch routes do not change often, read the current
		// value first and only write if the port number differs.
		oldPortNumber, hasOldRoute := forwardingTable.Load(macAddress)
		if !hasOldRoute || oldPortNumber.(int) != portNumber {
			forwardingTable.Store(macAddress, portNumber)
		}
	}

	// Only read after this point, so concurrent lookups are safe.
	portMACAddresses := make(map[gopacket.Endpoint]struct{}, len(s.ports))
	for _, port := range s.ports {
		portMACAddresses[port.MACAddress()] = struct{}{}
	}

	// send forwards frame from port `from` out of port `to`, logging any
	// error. The log entry is only built on failure, so successful sends
	// do not pay for it.
	send := func(frame *gplayers.Ethernet, from, to int) {
		if err := s.ports[to].Send(ctx, frame); err != nil {
			logrus.
				WithError(err).
				WithField("from_port", from).
				WithField("frame", frame).
				WithField("to_port", to).
				Error("error forwarding frame")
		}
	}

	ctxDone := ctx.Done()
	for i, fromPort := range s.ports {
		// make local copies so the i-th goroutine captures
		// references only to its own (i, fromPort) pair
		i := i
		fromPort := fromPort
		// start port goroutine
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case <-ctxDone:
					return
				case frame, ok := <-fromPort.Recv():
					if !ok {
						// A closed channel would yield nil frames forever;
						// stop instead of dereferencing them.
						return
					}
					if frame == nil {
						continue
					}

					// update forwarding table
					srcMAC := gplayers.NewMACEndpoint(frame.SrcMAC)
					storeRoute(srcMAC, i)

					// discard frames addressed to one of the switch's own ports
					dstMAC := gplayers.NewMACEndpoint(frame.DstMAC)
					if _, isPortMACAddress := portMACAddresses[dstMAC]; isPortMACAddress {
						continue
					}

					// known destination: forward to its port only. If that port
					// is the ingress port, the destination already received the
					// frame on that segment, so reflecting it would duplicate it.
					if dstPort, hasRoute := forwardingTable.Load(dstMAC); hasRoute {
						if j := dstPort.(int); j != i {
							send(frame, i, j)
						}
						continue
					}

					// no route: flood to all other ports
					for j := range s.ports {
						if j != i {
							send(frame, i, j)
						}
					}
				}
			}
		}()
	}
}