Skip to content

Commit

Permalink
refactor:filter, stub and passthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
shivamsouravjha committed Jan 18, 2024
1 parent f25ff78 commit 8f35c38
Show file tree
Hide file tree
Showing 14 changed files with 234 additions and 97 deletions.
28 changes: 18 additions & 10 deletions cmd/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ func readRecordConfig(configPath string) (*models.Record, error) {
return &doc.Record, nil
}

var filters = models.Filters{}
var filters = models.TestFilter{}

func (t *Record) GetRecordConfig(path *string, proxyPort *uint32, appCmd *string, appContainer, networkName *string, Delay *uint64, buildDelay *time.Duration, passThroughPorts *[]uint, PassThroughHosts *[]string, configPath string) error {
func (t *Record) GetRecordConfig(path *string, proxyPort *uint32, appCmd *string, appContainer, networkName *string, Delay *uint64, buildDelay *time.Duration, passThroughPorts *[]uint, passThrough *[]models.Filters, configPath string) error {
configFilePath := filepath.Join(configPath, "keploy-config.yaml")
if isExist := utils.CheckFileExists(configFilePath); !isExist {
return errFileNotFound
Expand All @@ -56,7 +56,9 @@ func (t *Record) GetRecordConfig(path *string, proxyPort *uint32, appCmd *string
if len(*path) == 0 {
*path = confRecord.Path
}
filters = confRecord.Filters

filters = confRecord.Tests

if *proxyPort == 0 {
*proxyPort = confRecord.ProxyPort
}
Expand All @@ -75,12 +77,18 @@ func (t *Record) GetRecordConfig(path *string, proxyPort *uint32, appCmd *string
if *buildDelay == 30*time.Second && confRecord.BuildDelay != 0 {
*buildDelay = confRecord.BuildDelay
}
*passThrough = append(*passThrough, confRecord.Stubs.Filters...)

if len(*passThroughPorts) == 0 {
*passThroughPorts = confRecord.PassThroughPorts
}
if len(*PassThroughHosts) == 0 {
*PassThroughHosts = confRecord.BypassEndpointsRegistry
for _, filter := range confRecord.Stubs.Filters {
if filter.Port != 0 && filter.Host == "" && filter.Path == "" {
*passThroughPorts = append(*passThroughPorts, filter.Port)
} else {
*passThrough = append(*passThrough, filter)
}
}
}

return nil
}

Expand Down Expand Up @@ -158,9 +166,9 @@ func (r *Record) GetCmd() *cobra.Command {
return err
}

passThroughHosts := []string{}
passThrough := []models.Filters{}

err = r.GetRecordConfig(&path, &proxyPort, &appCmd, &appContainer, &networkName, &delay, &buildDelay, &ports, &passThroughHosts, configPath)
err = r.GetRecordConfig(&path, &proxyPort, &appCmd, &appContainer, &networkName, &delay, &buildDelay, &ports, &passThrough, configPath)
if err != nil {
if err == errFileNotFound {
r.logger.Info("Keploy config not found, using default config")
Expand Down Expand Up @@ -218,7 +226,7 @@ func (r *Record) GetCmd() *cobra.Command {
}

r.logger.Debug("the ports are", zap.Any("ports", ports))
r.recorder.CaptureTraffic(path, proxyPort, appCmd, appContainer, networkName, delay, buildDelay, ports, &filters, enableTele, passThroughHosts)
r.recorder.CaptureTraffic(path, proxyPort, appCmd, appContainer, networkName, delay, buildDelay, ports, &filters, enableTele, passThrough)
return nil
},
}
Expand Down
22 changes: 14 additions & 8 deletions cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func readTestConfig(configPath string) (*models.Test, error) {
return &doc.Test, nil
}

func (t *Test) getTestConfig(path *string, proxyPort *uint32, appCmd *string, tests *map[string][]string, appContainer, networkName *string, Delay *uint64, buildDelay *time.Duration, passThorughPorts *[]uint, apiTimeout *uint64, globalNoise *models.GlobalNoise, testSetNoise *models.TestsetNoise, coverageReportPath *string, withCoverage *bool, configPath string, passThroughHosts *[]string) error {
func (t *Test) getTestConfig(path *string, proxyPort *uint32, appCmd *string, tests *map[string][]string, appContainer, networkName *string, Delay *uint64, buildDelay *time.Duration, passThroughPorts *[]uint, apiTimeout *uint64, globalNoise *models.GlobalNoise, testSetNoise *models.TestsetNoise, coverageReportPath *string, withCoverage *bool, configPath string, passThroughHosts *[]models.Filters) error {
configFilePath := filepath.Join(configPath, "keploy-config.yaml")
if isExist := utils.CheckFileExists(configFilePath); !isExist {
return errFileNotFound
Expand All @@ -60,7 +60,7 @@ func (t *Test) getTestConfig(path *string, proxyPort *uint32, appCmd *string, te
if *appCmd == "" {
*appCmd = confTest.Command
}
for testset, testcases := range confTest.Tests {
for testset, testcases := range confTest.SelectedTests {
if _, ok := (*tests)[testset]; !ok {
(*tests)[testset] = testcases
}
Expand All @@ -77,9 +77,7 @@ func (t *Test) getTestConfig(path *string, proxyPort *uint32, appCmd *string, te
if *buildDelay == 30*time.Second && confTest.BuildDelay != 0 {
*buildDelay = confTest.BuildDelay
}
if len(*passThorughPorts) == 0 {
*passThorughPorts = confTest.PassThroughPorts
}

if len(*coverageReportPath) == 0 {
*coverageReportPath = confTest.CoverageReportPath
}
Expand All @@ -89,9 +87,17 @@ func (t *Test) getTestConfig(path *string, proxyPort *uint32, appCmd *string, te
}
*globalNoise = confTest.GlobalNoise.Global
*testSetNoise = confTest.GlobalNoise.Testsets
if len(*passThroughHosts) == 0 {
*passThroughHosts = confTest.BypassEndpointsRegistry
*passThroughHosts = append(*passThroughHosts, confTest.Stubs.Filters...)
if len(*passThroughPorts) == 0 {
for _, filter := range confTest.Stubs.Filters {
if filter.Port != 0 && filter.Host == "" && filter.Path == "" {
*passThroughPorts = append(*passThroughPorts, filter.Port)
} else {
*passThroughHosts = append(*passThroughHosts, filter)
}
}
}

return nil
}

Expand Down Expand Up @@ -199,7 +205,7 @@ func (t *Test) GetCmd() *cobra.Command {
globalNoise := make(models.GlobalNoise)
testsetNoise := make(models.TestsetNoise)

passThroughHosts := []string{}
passThroughHosts := []models.Filters{}

err = t.getTestConfig(&path, &proxyPort, &appCmd, &tests, &appContainer, &networkName, &delay, &buildDelay, &ports, &apiTimeout, &globalNoise, &testsetNoise, &coverageReportPath, &withCoverage, configPath, &passThroughHosts)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/hooks/connection/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func NewFactory(inactivityThreshold time.Duration, logger *zap.Logger) *Factory
}
}

func (factory *Factory) HandleReadyConnections(db platform.TestCaseDB, ctx context.Context, filters *models.Filters) {
func (factory *Factory) HandleReadyConnections(db platform.TestCaseDB, ctx context.Context, filters *models.TestFilter) {
factory.mutex.Lock()
defer factory.mutex.Unlock()
var trackersToDelete []structs.ConnID
Expand Down Expand Up @@ -96,7 +96,7 @@ func (factory *Factory) GetOrCreate(connectionID structs.ConnID) *Tracker {
return tracker
}

func capture(db platform.TestCaseDB, req *http.Request, resp *http.Response, logger *zap.Logger, ctx context.Context, reqTimeTest time.Time, resTimeTest time.Time, filters *models.Filters) {
func capture(db platform.TestCaseDB, req *http.Request, resp *http.Response, logger *zap.Logger, ctx context.Context, reqTimeTest time.Time, resTimeTest time.Time, filters *models.TestFilter) {
reqBody, err := io.ReadAll(req.Body)
if err != nil {
logger.Error("failed to read the http request body", zap.Error(err))
Expand Down
21 changes: 16 additions & 5 deletions pkg/hooks/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ type Hook struct {

idc clients.InternalDockerClient
configMocks []*models.Mock
passThroughHosts []string
passThroughHosts models.Stubs
sourcePort int
}

func NewHook(db platform.TestCaseDB, mainRoutineId int, logger *zap.Logger) (*Hook, error) {
Expand Down Expand Up @@ -138,12 +139,22 @@ func (h *Hook) GetProxyPort() uint32 {
return h.proxyPort
}

func (h *Hook) GetProxyHost() []string {
func (h *Hook) GetProxyHost() models.Stubs {
return h.passThroughHosts
}

func (h *Hook) SetProxyHosts(passThroughHosts []string) {
h.passThroughHosts = passThroughHosts
func (h *Hook) SetProxyHosts(passThroughHosts []models.Filters) {
h.passThroughHosts = models.Stubs{
Filters: passThroughHosts,
}
}

func (h *Hook) GetSourcePort() int {
return h.sourcePort
}

func (h *Hook) SetSourcePort(sourcePort int) {
h.sourcePort = sourcePort
}

func (h *Hook) AppendMocks(m *models.Mock, ctx context.Context) error {
Expand Down Expand Up @@ -527,7 +538,7 @@ func (h *Hook) Stop(forceStop bool) {
// $BPF_CLANG and $BPF_CFLAGS are set by the Makefile.
//
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS -no-global-types -target $TARGET bpf keploy_ebpf.c -- -I./headers -I./headers/$TARGET
func (h *Hook) LoadHooks(appCmd, appContainer string, pid uint32, ctx context.Context, filters *models.Filters) error {
func (h *Hook) LoadHooks(appCmd, appContainer string, pid uint32, ctx context.Context, filters *models.TestFilter) error {
if err := settings.InitRealTimeOffset(); err != nil {
h.logger.Error("failed to fix the BPF clock", zap.Error(err))
return err
Expand Down
65 changes: 37 additions & 28 deletions pkg/models/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,51 @@ type Config struct {
}

type Record struct {
Path string `json:"path" yaml:"path"`
Command string `json:"command" yaml:"command"`
ProxyPort uint32 `json:"proxyport" yaml:"proxyport"`
ContainerName string `json:"containerName" yaml:"containerName"`
NetworkName string `json:"networkName" yaml:"networkName"`
Delay uint64 `json:"delay" yaml:"delay"`
BuildDelay time.Duration `json:"buildDelay" yaml:"buildDelay"`
PassThroughPorts []uint `json:"passThroughPorts" yaml:"passThroughPorts"`
BypassEndpointsRegistry []string `json:"bypassEndpointsRegistry" yaml:"bypassEndpointsRegistry"`
Filters Filters `json:"filters" yaml:"filters"`
Path string `json:"path" yaml:"path"`
Command string `json:"command" yaml:"command"`
ProxyPort uint32 `json:"proxyport" yaml:"proxyport"`
ContainerName string `json:"containerName" yaml:"containerName"`
NetworkName string `json:"networkName" yaml:"networkName"`
Delay uint64 `json:"delay" yaml:"delay"`
BuildDelay time.Duration `json:"buildDelay" yaml:"buildDelay"`
Tests TestFilter `json:"tests" yaml:"tests"`
Stubs Stubs `json:"stubs" yaml:"stubs"`
}

type TestFilter struct {
Filters []Filters `json:"filters" yaml:"filters"`
}

type Stubs struct {
Filters []Filters `json:"filters" yaml:"filters"`
}
type Filters struct {
ReqHeader []string `json:"req_header" yaml:"req_header"`
URLMethods map[string][]string `json:"urlMethods" yaml:"urlMethods"`
Path string `json:"path" yaml:"path"`
UrlMethods []string `json:"urlMethods" yaml:"urlMethods"`
Host string `json:"host" yaml:"host"`
Headers map[string]string `json:"headers" yaml:"headers"`
Port uint `json:"ports" yaml:"ports"`
}

func (filter *Filters) GetKind() string {
return "filter"
func (tests *TestFilter) GetKind() string {
return "Tests"
}

type Test struct {
Path string `json:"path" yaml:"path"`
Command string `json:"command" yaml:"command"`
ProxyPort uint32 `json:"proxyport" yaml:"proxyport"`
ContainerName string `json:"containerName" yaml:"containerName"`
NetworkName string `json:"networkName" yaml:"networkName"`
Tests map[string][]string `json:"tests" yaml:"tests"`
GlobalNoise Globalnoise `json:"globalNoise" yaml:"globalNoise"`
Delay uint64 `json:"delay" yaml:"delay"`
BuildDelay time.Duration `json:"buildDelay" yaml:"buildDelay"`
ApiTimeout uint64 `json:"apiTimeout" yaml:"apiTimeout"`
PassThroughPorts []uint `json:"passThroughPorts" yaml:"passThroughPorts"`
BypassEndpointsRegistry []string `json:"bypassEndpointsRegistry" yaml:"bypassEndpointsRegistry"`
WithCoverage bool `json:"withCoverage" yaml:"withCoverage"` // boolean to capture the coverage in test
CoverageReportPath string `json:"coverageReportPath" yaml:"coverageReportPath"` // directory path to store the coverage files
Path string `json:"path" yaml:"path"`
Command string `json:"command" yaml:"command"`
ProxyPort uint32 `json:"proxyport" yaml:"proxyport"`
ContainerName string `json:"containerName" yaml:"containerName"`
NetworkName string `json:"networkName" yaml:"networkName"`
SelectedTests map[string][]string `json:"selectedTests" yaml:"selectedTests"`
Tests TestFilter `json:"tests" yaml:"tests"`
Stubs Stubs `json:"stubs" yaml:"stubs"`
GlobalNoise Globalnoise `json:"globalNoise" yaml:"globalNoise"`
Delay uint64 `json:"delay" yaml:"delay"`
BuildDelay time.Duration `json:"buildDelay" yaml:"buildDelay"`
ApiTimeout uint64 `json:"apiTimeout" yaml:"apiTimeout"`
WithCoverage bool `json:"withCoverage" yaml:"withCoverage"` // boolean to capture the coverage in test
CoverageReportPath string `json:"coverageReportPath" yaml:"coverageReportPath"` // directory path to store the coverage files
}

type Globalnoise struct {
Expand Down
62 changes: 41 additions & 21 deletions pkg/platform/yaml/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -129,33 +130,51 @@ func (ys *Yaml) Write(path, fileName string, docRead platform.KindSpecifier) err
return nil
}

func containsMatchingUrl(urlMethods map[string][]string, urlStr string, requestMethod models.Method) bool {
parsedURL, err := url.Parse(urlStr)
func containsMatchingUrl(urlMethods []string, urlStr string, requestUrl string, requestMethod models.Method) bool {
urlMatched := false
parsedURL, err := url.Parse(requestUrl)
if err != nil {
return false
}

// Check for URL path and method
path := parsedURL.Path
if methods, exists := urlMethods[path]; exists {
// Loop through the methods for this path
for _, method := range methods {
// If the request method matches one of the allowed methods, return true
regex, err := regexp.Compile(urlStr)
if err != nil {
return false
}

urlMatch := regex.FindStringSubmatch(parsedURL.Path)

if len(urlMatch) > 0 && len(urlStr) != 0 {
urlMatched = true
}

if len(urlMethods) != 0 {
urlMatched = false
for _, method := range urlMethods {
if string(method) == string(requestMethod) {
return true
urlMatched = true
}
}
// If the request method is not in the allowed methods, return false
return false
}

return false
return urlMatched
}

func hasBannedHeaders(object map[string]string, bannedHeaders []string) bool {
for headerName, _ := range object {
for _, bannedHeader := range bannedHeaders {
if headerName == bannedHeader {
func hasBannedHeaders(object map[string]string, bannedHeaders map[string]string) bool {
for headerName, headerNameValue := range object {
for bannedHeaderName, bannedHeaderValue := range bannedHeaders {
regex, err := regexp.Compile(bannedHeaderValue)
if err != nil {
continue
}
headerNameMatch := regex.FindStringSubmatch(headerNameValue)
regex, err = regexp.Compile(bannedHeaderValue)
if err != nil {
continue
}
headerValueMatch := regex.FindStringSubmatch(headerNameValue)
if len(headerNameMatch) > 0 || len(headerValueMatch) > 0 || headerName == bannedHeaderName || bannedHeaderValue == headerNameValue {
return true
}
}
Expand All @@ -168,18 +187,19 @@ func (ys *Yaml) WriteTestcase(tcRead platform.KindSpecifier, ctx context.Context
if !ok {
return fmt.Errorf("%s failed to read testcase in WriteTestcase", Emoji)
}
filters, ok := filtersRead.(*models.Filters)
testFilters, ok := filtersRead.(*models.TestFilter)

var bypassTestCase = false

if ok {
if containsMatchingUrl(filters.URLMethods, tc.HttpReq.URL, tc.HttpReq.Method) {
bypassTestCase = true
} else if hasBannedHeaders(tc.HttpReq.Header, filters.ReqHeader) {
bypassTestCase = true
for _, testFilter := range testFilters.Filters {
if containsMatchingUrl(testFilter.UrlMethods, testFilter.Path, tc.HttpReq.URL, tc.HttpReq.Method) {
bypassTestCase = true
} else if hasBannedHeaders(tc.HttpReq.Header, testFilter.Headers) {
bypassTestCase = true
}
}
}

if !bypassTestCase {
ys.tele.RecordedTestAndMocks()
ys.mutex.Lock()
Expand Down
Loading

0 comments on commit 8f35c38

Please sign in to comment.