Skip to content

Commit

Permalink
make generateFilterRules take machine and peers
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
  • Loading branch information
kradalby committed Jun 21, 2023
1 parent 9c425a1 commit 161243c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
6 changes: 4 additions & 2 deletions hscontrol/policy/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func GenerateFilterRules(
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
}

rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain)
rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
Expand All @@ -152,10 +152,12 @@ func GenerateFilterRules(
// generateFilterRules takes a set of machines and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) generateFilterRules(
machines types.Machines,
machine *types.Machine,
peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
machines := append(peers, *machine)

for index, acl := range pol.ACLs {
if acl.Action != "accept" {
Expand Down
37 changes: 20 additions & 17 deletions hscontrol/policy/acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil)
}
Expand Down Expand Up @@ -230,7 +230,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
pol, err := LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
}
Expand Down Expand Up @@ -310,7 +310,7 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

Expand Down Expand Up @@ -366,7 +366,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

Expand Down Expand Up @@ -401,7 +401,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

Expand All @@ -428,7 +428,7 @@ acls:
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

Expand Down Expand Up @@ -459,7 +459,7 @@ acls:
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)

rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

Expand Down Expand Up @@ -1620,7 +1620,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
pol ACLPolicy
}
type args struct {
machines types.Machines
machine types.Machine
peers types.Machines
stripEmailDomain bool
}
tests := []struct {
Expand Down Expand Up @@ -1651,7 +1652,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
},
},
args: args{
machines: types.Machines{},
machine: types.Machine{},
peers: types.Machines{},
stripEmailDomain: true,
},
want: []tailcfg.FilterRule{
Expand Down Expand Up @@ -1691,14 +1693,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
},
},
args: args{
machines: types.Machines{
{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
},
User: types.User{Name: "mickael"},
machine: types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
},
User: types.User{Name: "mickael"},
},
peers: types.Machines{
{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
Expand Down Expand Up @@ -1739,7 +1741,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.generateFilterRules(
tt.args.machines,
&tt.args.machine,
tt.args.peers,
tt.args.stripEmailDomain,
)
if (err != nil) != tt.wantErr {
Expand Down

0 comments on commit 161243c

Please sign in to comment.