Skip to content

Commit

Permalink
feat: This commit contains the test and implementation of the wildcar…
Browse files Browse the repository at this point in the history
…d host handling in `Includes`

I thought about the best way to do it, and came up with the decision to **not** duplicate code.
This required the breakdown of the `parse` function, and the creation of a function that would handle **ONLY** the extraction of host data.
Without default values and such. It has a boolean parameter to return either only the virtual or actual hosts.
  • Loading branch information
userwiths committed Feb 16, 2024
1 parent c993924 commit cc4c292
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 16 deletions.
53 changes: 37 additions & 16 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,16 @@ func ParseFS(fsys fs.FS, path string) ([]*SSHHost, error) {
return parse(string(content), path)
}

// parses an openssh config file
func parse(input string, path string) ([]*SSHHost, error) {
sshConfigs := []*SSHHost{}
var next item
// Can be used to get only virtual or non-virtual hosts.
// Has uses no default values, they are handled in `parse` function.
func extractHosts(input string, path string, onlyVirtual bool) ([]*SSHHost, error){
var returnHosts []*SSHHost
var sshHost *SSHHost
var wildcardHosts []*SSHHost
var onlyIncludes bool = !strings.Contains(input, "Host ") && strings.Contains(input, "Include ");
var next item

lexer := lex(input)
Loop:
Loop:
for {
token := lexer.nextItem()

Expand All @@ -164,10 +164,10 @@ Loop:
switch token.typ {
case itemHost:
if sshHost != nil {
if containsWildcard(sshHost) {
wildcardHosts = append(wildcardHosts, sshHost)
} else {
sshConfigs = append(sshConfigs, sshHost)
if containsWildcard(sshHost) && onlyVirtual {
returnHosts = append(returnHosts, sshHost)
} else if !onlyVirtual && !containsWildcard(sshHost) {
returnHosts = append(returnHosts, sshHost)
}
}

Expand Down Expand Up @@ -256,12 +256,16 @@ Loop:
}

for _, f := range files {
includeSshConfigs, err := Parse(f)
fInput, err := ioutil.ReadFile(f)
if err != nil {
return nil, err
}
includeSshConfigs, err := extractHosts(string(fInput), f, onlyVirtual)
if err != nil {
return nil, err
}

sshConfigs = append(sshConfigs, includeSshConfigs...)
returnHosts = append(returnHosts, includeSshConfigs...)
}
case itemCiphers:
next = lexer.nextItem()
Expand All @@ -279,17 +283,31 @@ Loop:
return nil, fmt.Errorf("%s at pos %d", token.val, token.pos)
case itemEOF:
if sshHost != nil {
if containsWildcard(sshHost) {
wildcardHosts = append(wildcardHosts, sshHost)
} else {
sshConfigs = append(sshConfigs, sshHost)
if containsWildcard(sshHost) && onlyVirtual {
returnHosts = append(returnHosts, sshHost)
} else if !onlyVirtual && !containsWildcard(sshHost) {
returnHosts = append(returnHosts, sshHost)
}
}
break Loop
default:
// continue onwards
}
}
return returnHosts, nil
}

// parses an openssh config file
func parse(input string, path string) ([]*SSHHost, error) {
sshConfigs, err := extractHosts(input, path, false)
if err != nil {
return nil, err
}
wildcardHosts, err := extractHosts(input, path, true)
if err != nil {
return nil, err
}

if len(wildcardHosts) > 0 {
err := error(nil)
sshConfigs, err = applyWildcardRules(wildcardHosts, sshConfigs)
Expand Down Expand Up @@ -340,6 +358,9 @@ func matchWildcardHost(wildcardHost *SSHHost, host *SSHHost) bool {
}

func containsWildcard(host *SSHHost) bool {
if host == nil {
return false
}
for _, h := range host.Host {
if strings.Contains(h, "*") {
return true
Expand Down
44 changes: 44 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1197,4 +1197,48 @@ func TestWildcardMacs(t *testing.T) {
t.Errorf("unexpected error parsing config: %s", err.Error())
}
compare(t, expected, parsed)
}

func TestWildcardInclude(t *testing.T) {
config := `Include ./b.conf
Host special*
User special
Port 3333
Host not-special
HostName not-special.com
`
configB := `Host special
HostName special1.com
Host not-*
User not-special
Port 4444`
tmpdir := t.TempDir()
f, err := os.Create(tmpdir + "/b.conf")
if err != nil {
t.Errorf("unable to create file: %s", err.Error())
}
defer f.Close()
_, err = f.WriteString(configB)
if err != nil {
t.Errorf("unable to write to file: %s", err.Error())
}
expected := []*SSHHost{
{
Host: []string{"special"},
User: "special",
Port: 3333,
HostName: "special1.com",
}, {
Host: []string{"not-special"},
HostName: "not-special.com",
User: "not-special",
Port: 4444,
},
}

parsed, err := parse(config, tmpdir + "/config")
if err != nil {
t.Errorf("unexpected error parsing config: %s", err.Error())
}
compare(t, expected, parsed)
}

0 comments on commit cc4c292

Please sign in to comment.