Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ipsec: Improve encrypt flush command #28795

Merged
merged 8 commits into from
Nov 6, 2023
7 changes: 5 additions & 2 deletions Documentation/cmdref/cilium-dbg_encrypt_flush.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

168 changes: 166 additions & 2 deletions cilium-dbg/cmd/encrypt_flush.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,22 @@
package cmd

import (
"fmt"
"strconv"
"strings"

"github.com/spf13/cobra"
"github.com/vishvananda/netlink"

"github.com/cilium/cilium/pkg/command"
"github.com/cilium/cilium/pkg/common"
"github.com/cilium/cilium/pkg/common/ipsec"
"github.com/cilium/cilium/pkg/datapath/linux/linux_defaults"
)

const (
spiFlagName = "spi"
nodeIDFlagName = "node-id"
)

var encryptFlushCmd = &cobra.Command{
Expand All @@ -17,12 +28,165 @@ var encryptFlushCmd = &cobra.Command{
Long: "Will cause a short connectivity disruption",
Run: func(cmd *cobra.Command, args []string) {
common.RequireRootPrivilege("cilium encrypt flush")
netlink.XfrmPolicyFlush()
netlink.XfrmStateFlush(netlink.XFRM_PROTO_ESP)
runXFRMFlush()
},
}

var (
spiToFilter uint8
nodeIDToFilter uint16
nodeIDParam string
)

func runXFRMFlush() {
if spiToFilter == 0 && nodeIDParam == "" {
flushEverything()
return
}

if spiToFilter > linux_defaults.IPsecMaxKeyVersion {
Fatalf("Given SPI is too big")
}

if nodeIDParam != "" {
var err error
nodeIDToFilter, err = parseNodeID(nodeIDParam)
if err != nil {
Fatalf("Unable to parse node ID %q: %s", nodeIDParam, err)
}
}

states, err := netlink.XfrmStateList(netlink.FAMILY_ALL)
if err != nil {
Fatalf("Failed to retrieve XFRM states: %s", err)
}
policies, err := netlink.XfrmPolicyList(netlink.FAMILY_ALL)
if err != nil {
Fatalf("Failed to retrieve XFRM policies: %s", err)
}
nbStates := len(states)
nbPolicies := len(policies)

if spiToFilter != 0 {
policies, states = filterXFRMBySPI(policies, states)
}
if nodeIDToFilter != 0 {
policies, states = filterXFRMByNodeID(policies, states)
}

if len(policies) == nbPolicies || len(states) == nbStates {
confirmationMsg := "Running this command will delete all XFRM state and/or policies. " +
"It will lead to transient connectivity disruption and plain-text pod-to-pod traffic."
if !confirmXFRMCleanup(confirmationMsg) {
return
}
}

for _, state := range states {
if err := netlink.XfrmStateDel(&state); err != nil {
Fatalf("Stopped XFRM states deletion due to error: %s", err)
}
}
fmt.Printf("Deleted %d XFRM states.\n", len(states))
for _, pol := range policies {
if err := netlink.XfrmPolicyDel(&pol); err != nil {
Fatalf("Stopped XFRM policies deletion due to error: %s", err)
}
}
fmt.Printf("Deleted %d XFRM policies.\n", len(policies))
}

func parseNodeID(nodeID string) (uint16, error) {
var (
val int64
err error
)

if strings.HasPrefix(nodeID, "0x") {
val, err = strconv.ParseInt(nodeID[2:], 16, 0)
if err != nil {
return 0, err
}
} else {
val, err = strconv.ParseInt(nodeID, 10, 0)
if err != nil {
return 0, err
}
}

if val == 0 {
return 0, fmt.Errorf("0 is not a valid node ID in this context")
}

if val < 0 || val > int64(^uint16(0)) {
return 0, fmt.Errorf("given node ID doesn't fit in uint16")
}
return uint16(val), nil
}

type policyFilter func(netlink.XfrmPolicy) bool
type stateFilter func(netlink.XfrmState) bool

func filterXFRMBySPI(policies []netlink.XfrmPolicy, states []netlink.XfrmState) ([]netlink.XfrmPolicy, []netlink.XfrmState) {
return filterXFRMs(policies, states, func(pol netlink.XfrmPolicy) bool {
return ipsec.GetSPIFromXfrmPolicy(&pol) == spiToFilter
}, func(state netlink.XfrmState) bool {
return state.Spi == int(spiToFilter)
})
}

func filterXFRMByNodeID(policies []netlink.XfrmPolicy, states []netlink.XfrmState) ([]netlink.XfrmPolicy, []netlink.XfrmState) {
return filterXFRMs(policies, states, func(pol netlink.XfrmPolicy) bool {
return ipsec.GetNodeIDFromXfrmMark(pol.Mark) == nodeIDToFilter
}, func(state netlink.XfrmState) bool {
return ipsec.GetNodeIDFromXfrmMark(state.Mark) == nodeIDToFilter
})
}

func filterXFRMs(policies []netlink.XfrmPolicy, states []netlink.XfrmState,
filterPol policyFilter, filterState stateFilter) ([]netlink.XfrmPolicy, []netlink.XfrmState) {
Comment on lines +146 to +147
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For readability, this could be written:

func filterXFRMs(
	policies []netlink.XfrmPolicy, states []netlink.XfrmState,
	filterPol policyFilter, filterState stateFilter
) ([]netlink.XfrmPolicy, []netlink.XfrmState) {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does look a bit better. You could make it a linter :-)

policiesToDel := []netlink.XfrmPolicy{}
for _, pol := range policies {
if filterPol(pol) {
policiesToDel = append(policiesToDel, pol)
}
}

statesToDel := []netlink.XfrmState{}
for _, state := range states {
if filterState(state) {
statesToDel = append(statesToDel, state)
}
}

return policiesToDel, statesToDel
}

func flushEverything() {
confirmationMsg := "Flushing all XFRM states and policies can lead to transient " +
"connectivity interruption and plain-text pod-to-pod traffic."
if !confirmXFRMCleanup(confirmationMsg) {
return
}
netlink.XfrmPolicyFlush()
netlink.XfrmStateFlush(netlink.XFRM_PROTO_ESP)
fmt.Println("All XFRM states and policies have been deleted.")
}

func confirmXFRMCleanup(msg string) bool {
if force {
return true
}
var res string
fmt.Printf("%s Do you want to continue? [y/N] ", msg)
fmt.Scanln(&res)
return res == "y"
}

func init() {
encryptFlushCmd.Flags().BoolVarP(&force, forceFlagName, "f", false, "Skip confirmation")
encryptFlushCmd.Flags().Uint8Var(&spiToFilter, spiFlagName, 0, "Only delete states and policies with this SPI. If multiple filters are used, they all apply")
encryptFlushCmd.Flags().StringVar(&nodeIDParam, nodeIDFlagName, "", "Only delete states and policies with this node ID. Decimal or hexadecimal (0x) format. If multiple filters are used, they all apply")
CncryptCmd.AddCommand(encryptFlushCmd)
command.AddOutputOption(encryptFlushCmd)
}
102 changes: 102 additions & 0 deletions cilium-dbg/cmd/encrypt_flush_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package cmd

import (
"net"
"testing"

"github.com/vishvananda/netlink"
)

func TestFilterXFRMs(t *testing.T) {
policies := []netlink.XfrmPolicy{
{Ifid: 1, Proto: netlink.XFRM_PROTO_ESP, Dst: &net.IPNet{IP: net.ParseIP("192.168.1.0"), Mask: net.CIDRMask(24, 32)}},
{Ifid: 2, Proto: netlink.XFRM_PROTO_AH, Dst: &net.IPNet{IP: net.ParseIP("192.168.1.0"), Mask: net.CIDRMask(24, 32)}},
{Ifid: 3, Proto: netlink.XFRM_PROTO_ESP, Dst: &net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(16, 32)}},
{Ifid: 4, Proto: netlink.XFRM_PROTO_AH, Dst: &net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(16, 32)}},
}
states := []netlink.XfrmState{
{Ifid: 1, Proto: netlink.XFRM_PROTO_ESP, Dst: net.ParseIP("192.168.1.0")},
{Ifid: 2, Proto: netlink.XFRM_PROTO_AH, Dst: net.ParseIP("192.168.1.0")},
{Ifid: 3, Proto: netlink.XFRM_PROTO_ESP, Dst: net.ParseIP("10.0.0.0")},
{Ifid: 4, Proto: netlink.XFRM_PROTO_AH, Dst: net.ParseIP("10.0.0.0")},
}
filterDstPolicy := func(pol netlink.XfrmPolicy) bool {
return pol.Dst.IP.String() == "192.168.1.0"
}
filterDstState := func(state netlink.XfrmState) bool {
return state.Dst.String() == "192.168.1.0"
}
filterProtoPolicy := func(pol netlink.XfrmPolicy) bool {
return pol.Proto == netlink.XFRM_PROTO_ESP
}
filterProtoState := func(state netlink.XfrmState) bool {
return state.Proto == netlink.XFRM_PROTO_ESP
}

// Test that single call to filterXFRMs provides the expected results.
resPolicies, resStates := filterXFRMs(policies, states, filterDstPolicy, filterDstState)
if len(resPolicies) != 2 {
t.Errorf("Expected two policies to be filtered, but got %d", len(resPolicies))
}
if len(resStates) != 2 {
t.Errorf("Expected two states to be filtered, but got %d", len(resStates))
}
if resPolicies[0].Ifid != 1 || resPolicies[1].Ifid != 2 {
t.Errorf("Expected policies with Ifids 1 and 2 to be filtered, but got policies with Ifids %d and %d", resPolicies[0].Ifid, resPolicies[1].Ifid)
}
if resStates[0].Ifid != 1 || resStates[1].Ifid != 2 {
t.Errorf("Expected state with Ifids 1 and 2 to be filtered, but got states with Ifids %d and %d", resStates[0].Ifid, resStates[1].Ifid)
}

// Test that chained calls to filterXFRMs also provide the expected results.
resPolicies, resStates = filterXFRMs(resPolicies, resStates, filterProtoPolicy, filterProtoState)
if len(resPolicies) != 1 {
t.Errorf("Expected one policy to be filtered, but got %d", len(resPolicies))
}
if len(resStates) != 1 {
t.Errorf("Expected one state to be filtered, but got %d", len(resStates))
}
if resPolicies[0].Ifid != 1 {
t.Errorf("Expected policies with Ifid 1 to be filtered, but got policies with Ifid %d", resPolicies[0].Ifid)
}
if resStates[0].Ifid != 1 {
t.Errorf("Expected state with Ifid 1 to be filtered, but got states with Ifid %d", resStates[0].Ifid)
}
}

func TestParseNodeID(t *testing.T) {
tests := []struct {
input string
expected uint16
err bool
}{
{"0x0", 0, true},
{"42", 42, false},
{"0x1a", 26, false},
{"65535", 65535, false},
{"70000", 0, true}, // Too big for uint16
{"invalid", 0, true},
{"0xinvalid", 0, true},
{"0xdeadbeef", 0, true}, // Too big for uint16
}

for _, test := range tests {
result, err := parseNodeID(test.input)
if test.err {
if err == nil {
t.Errorf("Expected error for input %s, but got nil", test.input)
}
} else {
if err != nil {
t.Errorf("Unexpected error for input %s: %v", test.input, err)
}

if result != test.expected {
t.Errorf("For input %s, expected %d, but got %d", test.input, test.expected, result)
}
}
}
}
22 changes: 22 additions & 0 deletions pkg/common/ipsec/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package ipsec

import (
"github.com/vishvananda/netlink"

"github.com/cilium/cilium/pkg/datapath/linux/linux_defaults"
)

const (
Expand Down Expand Up @@ -58,3 +60,23 @@ func CountXfrmPoliciesByDir(states []netlink.XfrmPolicy) (int, int, int) {
}
return nbXfrmIn, nbXfrmOut, nbXfrmFwd
}

func GetSPIFromXfrmPolicy(policy *netlink.XfrmPolicy) uint8 {
if policy.Mark == nil {
return 0
}

return ipSecXfrmMarkGetSPI(policy.Mark.Value)
}

// ipSecXfrmMarkGetSPI extracts from a XfrmMark value the encoded SPI
func ipSecXfrmMarkGetSPI(markValue uint32) uint8 {
return uint8(markValue >> linux_defaults.IPsecXFRMMarkSPIShift & 0xF)
}

func GetNodeIDFromXfrmMark(mark *netlink.XfrmMark) uint16 {
if mark == nil {
return 0
}
return uint16(mark.Value >> 16)
}