Skip to content

Commit

Permalink
refactor: split package main
Browse files Browse the repository at this point in the history
sshportal refactor. Focused on splitting up package main into packages
main, dbmodels, crypto, and bastion.
  • Loading branch information
ahamidullah authored and moul committed Jan 3, 2019
1 parent f220af5 commit fe55cc1
Show file tree
Hide file tree
Showing 15 changed files with 3,095 additions and 395 deletions.
52 changes: 0 additions & 52 deletions config.go

This file was deleted.

87 changes: 1 addition & 86 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
package main // import "moul.io/sshportal"

import (
"fmt"
"log"
"math"
"math/rand"
"net"
"os"
"path"
"time"

"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/moul/ssh"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
)

var (
Expand Down Expand Up @@ -45,7 +39,7 @@ func main() {
if err := ensureLogDirectory(c.String("logs-location")); err != nil {
return err
}
cfg, err := parseServeConfig(c)
cfg, err := parseServerConfig(c)
if err != nil {
return err
}
Expand Down Expand Up @@ -120,82 +114,3 @@ func main() {
log.Fatalf("error: %v", err)
}
}

var defaultChannelHandler ssh.ChannelHandler

func server(c *configServe) (err error) {
var db = (*gorm.DB)(nil)

// try to setup the local DB
if db, err = gorm.Open(c.dbDriver, c.dbURL); err != nil {
return
}
defer func() {
origErr := err
err = db.Close()
if origErr != nil {
err = origErr
}
}()
if err = db.DB().Ping(); err != nil {
return
}
db.LogMode(c.debug)
if err = dbInit(db); err != nil {
return
}

// create TCP listening socket
ln, err := net.Listen("tcp", c.bindAddr)
if err != nil {
return err
}

// configure server
srv := &ssh.Server{
Addr: c.bindAddr,
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
Version: fmt.Sprintf("sshportal-%s", Version),
}

// configure channel handler
defaultSessionHandler := srv.GetChannelHandler("session")
defaultDirectTcpipHandler := srv.GetChannelHandler("direct-tcpip")
defaultChannelHandler = func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
switch newChan.ChannelType() {
case "session":
go defaultSessionHandler(srv, conn, newChan, ctx)
case "direct-tcpip":
go defaultDirectTcpipHandler(srv, conn, newChan, ctx)
default:
if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil {
log.Printf("failed to reject chan: %v", err)
}
}
}
srv.SetChannelHandler("session", nil)
srv.SetChannelHandler("direct-tcpip", nil)
srv.SetChannelHandler("default", channelHandler)

if c.idleTimeout != 0 {
srv.IdleTimeout = c.idleTimeout
// gliderlabs/ssh requires MaxTimeout to be non-zero if we want to use IdleTimeout.
// So, set it to the max value, because we don't want a max timeout.
srv.MaxTimeout = math.MaxInt64
}

for _, opt := range []ssh.Option{
// custom PublicKeyAuth handler
ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)),
ssh.PasswordAuth(passwordAuthHandler(db, c)),
// retrieve sshportal SSH private key from database
privateKeyFromDB(db, c.aesKey),
} {
if err := srv.SetOption(opt); err != nil {
return err
}
}

log.Printf("info: SSH Server accepting connections on %s, idle-timout=%v", c.bindAddr, c.idleTimeout)
return srv.Serve(ln)
}
25 changes: 14 additions & 11 deletions acl.go → pkg/bastion/acl.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package main
package bastion // import "moul.io/sshportal/pkg/bastion"

import "sort"
import (
"moul.io/sshportal/pkg/dbmodels"
"sort"
)

type ByWeight []*ACL
type byWeight []*dbmodels.ACL

func (a ByWeight) Len() int { return len(a) }
func (a ByWeight) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByWeight) Less(i, j int) bool { return a[i].Weight < a[j].Weight }
func (a byWeight) Len() int { return len(a) }
func (a byWeight) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byWeight) Less(i, j int) bool { return a[i].Weight < a[j].Weight }

func CheckACLs(user User, host Host) (string, error) {
func checkACLs(user dbmodels.User, host dbmodels.Host) (string, error) {
// shared ACLs between user and host
aclMap := map[uint]*ACL{}
aclMap := map[uint]*dbmodels.ACL{}
for _, userGroup := range user.Groups {
for _, userGroupACL := range userGroup.ACLs {
for _, hostGroup := range host.Groups {
Expand All @@ -26,15 +29,15 @@ func CheckACLs(user User, host Host) (string, error) {

// deny by default if no shared ACL
if len(aclMap) == 0 {
return string(ACLActionDeny), nil // default action
return string(dbmodels.ACLActionDeny), nil // default action
}

// transform map to slice and sort it
acls := make([]*ACL, 0, len(aclMap))
acls := make([]*dbmodels.ACL, 0, len(aclMap))
for _, acl := range aclMap {
acls = append(acls, acl)
}
sort.Sort(ByWeight(acls))
sort.Sort(byWeight(acls))

return acls[0].Action, nil
}
21 changes: 12 additions & 9 deletions acl_test.go → pkg/bastion/acl_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package bastion // import "moul.io/sshportal/pkg/bastion"

import (
"io/ioutil"
Expand All @@ -7,7 +7,10 @@ import (
"testing"

"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/sqlite"
. "github.com/smartystreets/goconvey/convey"
"moul.io/sshportal/pkg/dbmodels"
)

func TestCheckACLs(t *testing.T) {
Expand All @@ -23,25 +26,25 @@ func TestCheckACLs(t *testing.T) {
db, err := gorm.Open("sqlite3", filepath.Join(tempDir, "sshportal.db"))
So(err, ShouldBeNil)
db.LogMode(false)
So(dbInit(db), ShouldBeNil)
So(DBInit(db), ShouldBeNil)

// create dummy objects
var hostGroup HostGroup
err = HostGroupsByIdentifiers(db, []string{"default"}).First(&hostGroup).Error
var hostGroup dbmodels.HostGroup
err = dbmodels.HostGroupsByIdentifiers(db, []string{"default"}).First(&hostGroup).Error
So(err, ShouldBeNil)
db.Create(&Host{Groups: []*HostGroup{&hostGroup}})
db.Create(&dbmodels.Host{Groups: []*dbmodels.HostGroup{&hostGroup}})

//. load db
var (
hosts []Host
users []User
hosts []dbmodels.Host
users []dbmodels.User
)
db.Preload("Groups").Preload("Groups.ACLs").Find(&hosts)
db.Preload("Groups").Preload("Groups.ACLs").Find(&users)

// test
action, err := CheckACLs(users[0], hosts[0])
action, err := checkACLs(users[0], hosts[0])
So(err, ShouldBeNil)
So(action, ShouldEqual, ACLActionAllow)
So(action, ShouldEqual, dbmodels.ACLActionAllow)
})
}
Loading

0 comments on commit fe55cc1

Please sign in to comment.