Skip to content

Commit

Permalink
Support for using multiple protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
balajiv113 committed Jan 25, 2023
1 parent e1d9f2c commit ab1d2d8
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 54 deletions.
10 changes: 10 additions & 0 deletions pkg/tap/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package tap

import (
"net"
)

type protocolConn struct {
net.Conn
protocolImpl protocol
}
77 changes: 34 additions & 43 deletions pkg/tap/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Switch struct {
maxTransmissionUnit int

nextConnID int
conns map[int]net.Conn
conns map[int]protocolConn
connLock sync.Mutex

cam map[tcpip.LinkAddress]int
Expand All @@ -45,17 +45,14 @@ type Switch struct {
writeLock sync.Mutex

gateway VirtualDevice

protocol protocol
}

func NewSwitch(debug bool, mtu int, protocol types.Protocol) *Switch {
func NewSwitch(debug bool, mtu int) *Switch {
return &Switch{
debug: debug,
maxTransmissionUnit: mtu,
conns: make(map[int]net.Conn),
conns: make(map[int]protocolConn),
cam: make(map[tcpip.LinkAddress]int),
protocol: protocolImplementation(protocol),
}
}

Expand All @@ -73,13 +70,14 @@ func (e *Switch) Connect(ep VirtualDevice) {
e.gateway = ep
}

func (e *Switch) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
func (e *Switch) DeliverNetworkPacket(_ tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
if err := e.tx(pkt); err != nil {
log.Error(err)
}
}

func (e *Switch) Accept(ctx context.Context, conn net.Conn) error {
func (e *Switch) Accept(ctx context.Context, rawConn net.Conn, protocol types.Protocol) error {
conn := protocolConn{Conn: rawConn, protocolImpl: protocolImplementation(protocol)}
log.Infof("new connection from %s to %s", conn.RemoteAddr().String(), conn.LocalAddr().String())
id, failed := e.connect(conn)
if failed {
Expand All @@ -100,7 +98,7 @@ func (e *Switch) Accept(ctx context.Context, conn net.Conn) error {
return nil
}

func (e *Switch) connect(conn net.Conn) (int, bool) {
func (e *Switch) connect(conn protocolConn) (int, bool) {
e.connLock.Lock()
defer e.connLock.Unlock()

Expand All @@ -112,23 +110,10 @@ func (e *Switch) connect(conn net.Conn) (int, bool) {
}

func (e *Switch) tx(pkt stack.PacketBufferPtr) error {
if e.protocol.Stream() {
return e.txStream(pkt, e.protocol.(streamProtocol))
}
return e.txNonStream(pkt)
}

func (e *Switch) txNonStream(pkt stack.PacketBufferPtr) error {
return e.txBuf(pkt, nil)
}

func (e *Switch) txStream(pkt stack.PacketBufferPtr, sProtocol streamProtocol) error {
size := sProtocol.Buf()
sProtocol.Write(size, pkt.Size())
return e.txBuf(pkt, size)
return e.txPkt(pkt)
}

func (e *Switch) txBuf(pkt stack.PacketBufferPtr, size []byte) error {
func (e *Switch) txPkt(pkt stack.PacketBufferPtr) error {
e.writeLock.Lock()
defer e.writeLock.Unlock()

Expand All @@ -151,14 +136,9 @@ func (e *Switch) txBuf(pkt stack.PacketBufferPtr, size []byte) error {
if id == srcID {
continue
}
if len(size) > 0 {
if _, err := conn.Write(size); err != nil {
e.disconnect(id, conn)
return err
}
}
if _, err := conn.Write(buf); err != nil {
e.disconnect(id, conn)

err := e.txBuf(id, conn, buf)
if err != nil {
return err
}

Expand All @@ -173,17 +153,28 @@ func (e *Switch) txBuf(pkt stack.PacketBufferPtr, size []byte) error {
}
e.camLock.RUnlock()
conn := e.conns[id]
if len(size) > 0 {
if _, err := conn.Write(size); err != nil {
e.disconnect(id, conn)
return err
}
err := e.txBuf(id, conn, buf)
if err != nil {
return err
}
if _, err := conn.Write(buf); err != nil {
atomic.AddUint64(&e.Sent, uint64(pkt.Size()))
}
return nil
}

func (e *Switch) txBuf(id int, conn protocolConn, buf []byte) error {
if conn.protocolImpl.Stream() {
size := conn.protocolImpl.(streamProtocol).Buf()
conn.protocolImpl.(streamProtocol).Write(size, len(buf))

if _, err := conn.Write(size); err != nil {
e.disconnect(id, conn)
return err
}
atomic.AddUint64(&e.Sent, uint64(pkt.Size()))
}
if _, err := conn.Write(buf); err != nil {
e.disconnect(id, conn)
return err
}
return nil
}
Expand All @@ -201,9 +192,9 @@ func (e *Switch) disconnect(id int, conn net.Conn) {
delete(e.conns, id)
}

func (e *Switch) rx(ctx context.Context, id int, conn net.Conn) error {
if e.protocol.Stream() {
return e.rxStream(ctx, id, conn, e.protocol.(streamProtocol))
func (e *Switch) rx(ctx context.Context, id int, conn protocolConn) error {
if conn.protocolImpl.Stream() {
return e.rxStream(ctx, id, conn, conn.protocolImpl.(streamProtocol))
}
return e.rxNonStream(ctx, id, conn)
}
Expand Down Expand Up @@ -254,7 +245,7 @@ loop:
return nil
}

func (e *Switch) rxBuf(ctx context.Context, id int, buf []byte) {
func (e *Switch) rxBuf(_ context.Context, id int, buf []byte) {
if e.debug {
packet := gopacket.NewPacket(buf, layers.LayerTypeEthernet, gopacket.Default)
log.Info(packet.String())
Expand Down
12 changes: 6 additions & 6 deletions pkg/types/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ type Configuration struct {
// Allow to assign a pre-defined MAC address to an Hyperkit VM
VpnKitUUIDMacAddresses map[string]string

// Qemu or Hyperkit protocol
// Qemu protocol is 32bits big endian size of the packet, then the packet.
// Hyperkit protocol is handshake, then 16bits little endian size of packet, then the packet.
// Bess protocol transfers bare L2 packets as SOCK_SEQPACKET.
// Protocol to be used. Only for /connect mux
Protocol Protocol
}

type Protocol string

const (
// HyperKitProtocol is handshake, then 16bits little endian size of packet, then the packet.
HyperKitProtocol Protocol = "hyperkit"
QemuProtocol Protocol = "qemu"
BessProtocol Protocol = "bess"
// QemuProtocol is 32bits big endian size of the packet, then the packet.
QemuProtocol Protocol = "qemu"
// BessProtocol transfers bare L2 packets as SOCK_SEQPACKET.
BessProtocol Protocol = "bess"
)

type Zone struct {
Expand Down
3 changes: 2 additions & 1 deletion pkg/virtualnetwork/bess.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package virtualnetwork

import (
"context"
"github.com/containers/gvisor-tap-vsock/pkg/types"
"net"
)

func (n *VirtualNetwork) AcceptBess(ctx context.Context, conn net.Conn) error {
return n.networkSwitch.Accept(ctx, conn)
return n.networkSwitch.Accept(ctx, conn, types.BessProtocol)
}
2 changes: 1 addition & 1 deletion pkg/virtualnetwork/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (n *VirtualNetwork) Mux() *http.ServeMux {
return
}

_ = n.networkSwitch.Accept(context.Background(), conn)
_ = n.networkSwitch.Accept(context.Background(), conn, n.configuration.Protocol)
})
mux.HandleFunc("/tunnel", func(w http.ResponseWriter, r *http.Request) {
ip := r.URL.Query().Get("ip")
Expand Down
3 changes: 2 additions & 1 deletion pkg/virtualnetwork/qemu.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package virtualnetwork

import (
"context"
"github.com/containers/gvisor-tap-vsock/pkg/types"
"net"
)

func (n *VirtualNetwork) AcceptQemu(ctx context.Context, conn net.Conn) error {
return n.networkSwitch.Accept(ctx, conn)
return n.networkSwitch.Accept(ctx, conn, types.QemuProtocol)
}
2 changes: 1 addition & 1 deletion pkg/virtualnetwork/virtualnetwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func New(configuration *types.Configuration) (*VirtualNetwork, error) {
if err != nil {
return nil, errors.Wrap(err, "cannot create tap endpoint")
}
networkSwitch := tap.NewSwitch(configuration.Debug, configuration.MTU, configuration.Protocol)
networkSwitch := tap.NewSwitch(configuration.Debug, configuration.MTU)
tapEndpoint.Connect(networkSwitch)
networkSwitch.Connect(tapEndpoint)

Expand Down
2 changes: 1 addition & 1 deletion pkg/virtualnetwork/vpnkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (n *VirtualNetwork) AcceptVpnKit(conn net.Conn) error {
if err := vpnkitHandshake(conn, n.configuration); err != nil {
log.Error(err)
}
_ = n.networkSwitch.Accept(context.Background(), conn)
_ = n.networkSwitch.Accept(context.Background(), conn, types.HyperKitProtocol)
return nil
}

Expand Down

0 comments on commit ab1d2d8

Please sign in to comment.