Skip to content

Commit

Permalink
Merge branch 'feat/change-initial-ipblock-29' into dev
Browse files Browse the repository at this point in the history
Closes #29.
  • Loading branch information
cad committed Aug 31, 2017
2 parents 788c694 + 87dfa70 commit c247248
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 73 deletions.
3 changes: 2 additions & 1 deletion api/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ func (s *VPNService) Init(ctx context.Context, req *pb.VPNInitRequest) (*pb.VPNI
case pb.VPNProto_NOPREF:
proto = ovpm.UDPProto
}
if err := ovpm.Init(req.Hostname, req.Port, proto); err != nil {

if err := ovpm.Init(req.Hostname, req.Port, proto, req.IPBlock); err != nil {
logrus.Errorf("server can not be created: %v", err)
}
return &pb.VPNInitResponse{}, nil
Expand Down
17 changes: 15 additions & 2 deletions cmd/ovpm/vpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"

"github.com/Sirupsen/logrus"
"github.com/asaskevich/govalidator"
"github.com/cad/ovpm"
"github.com/cad/ovpm/pb"
"github.com/olekukonko/tablewriter"
Expand Down Expand Up @@ -60,12 +61,16 @@ var vpnInitCommand = cli.Command{
Name: "tcp, t",
Usage: "use TCP for vpn protocol, instead of UDP",
},
cli.StringFlag{
Name: "net, n",
Usage: fmt.Sprintf("VPN network to give clients IP addresses from, in the CIDR form (default: %s)", ovpm.DefaultVPNNetwork),
},
},
Action: func(c *cli.Context) error {
action = "vpn:init"
hostname := c.String("hostname")
if hostname == "" {
logrus.Errorf("'hostname' is needed")
logrus.Errorf("'hostname' is required")
fmt.Println(cli.ShowSubcommandHelp(c))
os.Exit(1)

Expand All @@ -83,6 +88,14 @@ var vpnInitCommand = cli.Command{
proto = pb.VPNProto_TCP
}

ipblock := c.String("net")
if ipblock != "" && !govalidator.IsCIDR(ipblock) {
fmt.Println("--net takes an ip network in the CIDR form. e.g. 10.9.0.0/24")
fmt.Println()
fmt.Println(cli.ShowSubcommandHelp(c))
os.Exit(1)
}

conn := getConn(c.GlobalString("daemon-port"))
defer conn.Close()
vpnSvc := pb.NewVPNServiceClient(conn)
Expand All @@ -102,7 +115,7 @@ var vpnInitCommand = cli.Command{
okayResponses := []string{"y", "Y", "yes", "Yes", "YES"}
nokayResponses := []string{"n", "N", "no", "No", "NO"}
if stringInSlice(response, okayResponses) {
if _, err := vpnSvc.Init(context.Background(), &pb.VPNInitRequest{Hostname: hostname, Port: port, Protopref: proto}); err != nil {
if _, err := vpnSvc.Init(context.Background(), &pb.VPNInitRequest{Hostname: hostname, Port: port, Protopref: proto, IPBlock: ipblock}); err != nil {
logrus.Errorf("server can not be initialized: %v", err)
os.Exit(1)
return err
Expand Down
6 changes: 3 additions & 3 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ const (
// DefaultVPNProto is the default OpenVPN protocol to use.
DefaultVPNProto = UDPProto

// DefaultVPNNetwork is the default OpenVPN network to use.
DefaultVPNNetwork = "10.9.0.0/24"

etcBasePath = "/etc/ovpm/"
varBasePath = "/var/db/ovpm/"

Expand All @@ -23,9 +26,6 @@ const (
_DefaultCAKeyPath = varBasePath + "ca.key"
_DefaultDHParamsPath = varBasePath + "dh4096.pem"
_DefaultCRLPath = varBasePath + "crl.pem"

_DefaultServerNetwork = "10.9.0.0"
_DefaultServerNetMask = "255.255.255.0"
)

// Testing is used to determine wether we are testing or running normally.
Expand Down
14 changes: 7 additions & 7 deletions net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ func TestVPNCreateNewNetwork(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
Init("localhost", "", UDPProto)
Init("localhost", "", UDPProto, "")

// Prepare:
// Test:
Expand Down Expand Up @@ -56,7 +56,7 @@ func TestVPNDeleteNetwork(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
Init("localhost", "", UDPProto)
Init("localhost", "", UDPProto, "")

// Prepare:
// Test:
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestVPNGetNetwork(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
Init("localhost", "", UDPProto)
Init("localhost", "", UDPProto, "")

// Prepare:
// Test:
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestVPNGetAllNetworks(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
Init("localhost", "", UDPProto)
Init("localhost", "", UDPProto, "")

// Prepare:
// Test:
Expand Down Expand Up @@ -175,7 +175,7 @@ func TestNetAssociate(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
Init("localhost", "", UDPProto)
Init("localhost", "", UDPProto, "")

// Prepare:
// Test:
Expand Down Expand Up @@ -213,7 +213,7 @@ func TestNetDissociate(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
err := Init("localhost", "", UDPProto)
err := Init("localhost", "", UDPProto, "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestNetGetAssociatedUsers(t *testing.T) {
setupTestCase()
SetupDB("sqlite3", ":memory:")
defer CeaseDB()
Init("localhost", "", UDPProto)
Init("localhost", "", UDPProto, "")

// Prepare:
// Test:
Expand Down
63 changes: 36 additions & 27 deletions pb/vpn.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pb/vpn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ message VPNInitRequest {
string Hostname = 1;
string Port = 2;
VPNProto Protopref = 3;
string IPBlock = 4;
}

service VPNService {
Expand Down
3 changes: 3 additions & 0 deletions pb/vpn.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
},
"Protopref": {
"$ref": "#/definitions/pbVPNProto"
},
"IPBlock": {
"type": "string"
}
}
},
Expand Down
14 changes: 11 additions & 3 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,12 @@ func (u *DBUser) GetCreatedAt() string {
func (u *DBUser) getIP() net.IP {
users := getNonStaticHostUsers()
staticHostIDs := getStaticHostIDs()
mask := net.IPMask(net.ParseIP(_DefaultServerNetMask).To4())
network := net.ParseIP(_DefaultServerNetwork).To4().Mask(mask)
server, err := GetServerInstance()
if err != nil {
logrus.Panicf("can not get server instance: %v", err)
}
mask := net.IPMask(net.ParseIP(server.Mask).To4())
network := net.ParseIP(server.Net).To4().Mask(mask)

// If the user has static ip address, return it immediately.
if u.HostID != 0 {
Expand Down Expand Up @@ -335,7 +339,11 @@ func (u *DBUser) getIP() net.IP {

// GetIPNet returns user's vpn ip network. (e.g. 192.168.0.1/24)
func (u *DBUser) GetIPNet() string {
mask := net.IPMask(net.ParseIP(_DefaultServerNetMask).To4())
server, err := GetServerInstance()
if err != nil {
logrus.Panicf("can not get user ipnet: %v", err)
}
mask := net.IPMask(net.ParseIP(server.Mask).To4())

ipn := net.IPNet{
IP: u.getIP(),
Expand Down
20 changes: 10 additions & 10 deletions user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestCreateNewUser(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")
server, _ := ovpm.GetServerInstance()

// Prepare:
Expand Down Expand Up @@ -89,7 +89,7 @@ func TestUserUpdate(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:
username := "testUser"
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestUserPasswordCorrect(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:
initialPassword := "g00dp@ssW0rd9"
Expand All @@ -144,7 +144,7 @@ func TestUserPasswordReset(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:
initialPassword := "g00dp@ssW0rd9"
Expand All @@ -171,7 +171,7 @@ func TestUserDelete(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:
username := "testUser"
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestUserGet(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:
username := "testUser"
Expand All @@ -233,7 +233,7 @@ func TestUserGetAll(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")
count := 5

// Prepare:
Expand Down Expand Up @@ -271,14 +271,14 @@ func TestUserRenew(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:
user, _ := ovpm.CreateNewUser("user", "1234", false, 0)

// Test:
// Re initialize the server.
ovpm.Init("example.com", "3333", ovpm.UDPProto) // This causes implicit Renew() on every user in the system.
ovpm.Init("example.com", "3333", ovpm.UDPProto, "") // This causes implicit Renew() on every user in the system.

// Fetch user back.
fetchedUser, _ := ovpm.GetUser(user.GetUsername())
Expand All @@ -293,7 +293,7 @@ func TestUserIPAllocator(t *testing.T) {
// Initialize:
ovpm.SetupDB("sqlite3", ":memory:")
defer ovpm.CeaseDB()
ovpm.Init("localhost", "", ovpm.UDPProto)
ovpm.Init("localhost", "", ovpm.UDPProto, "")

// Prepare:

Expand Down

0 comments on commit c247248

Please sign in to comment.