diff --git a/catalogs/ad/catalog.go b/catalogs/ad/catalog.go index fa96797..c489ede 100644 --- a/catalogs/ad/catalog.go +++ b/catalogs/ad/catalog.go @@ -107,8 +107,8 @@ func (c *Catalog) GetByDN(dn string) (*ldap.Entry, error) { return c.cl.SearchEntry(sr) } -// CheckConn Check connection to catalog. -func (c *Catalog) CheckConn(cfg *catalogs.Config) error { +// CheckConnection Check connection to AD. +func CheckConnection(cfg *catalogs.Config) error { _, err := ldapconn.NewClient(&ldapconn.Config{ Host: cfg.Host, Port: cfg.Port, diff --git a/cmd/login.go b/cmd/login.go index 219829f..c3c4bc2 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -1,10 +1,13 @@ package cmd import ( + "cataloger/catalogs/ad" "encoding/base64" "errors" "fmt" "os" + "os/user" + "strings" "github.com/howeyc/gopass" "github.com/manifoldco/promptui" @@ -18,18 +21,19 @@ var ( Use: "login", Short: "Login to catalog", Run: func(cmd *cobra.Command, args []string) { - // Create config in not exists - if viper.ConfigFileUsed() == "" { - askUser() - file := cfgFilename + "." + cfgExtention - folder := "./" + cfgFolder - // osuser, err := user.Current() - // if err == nil { - // folder = osuser.HomeDir + "/" + cfgFolder - // } - path := folder + "/" + file + askUser() + + file := cfgFilename + "." + cfgExtention + folder := "./" + cfgFolder + osuser, err := user.Current() + if err == nil { + folder = osuser.HomeDir + "/" + cfgFolder + } + path := folder + "/" + file + // Create config in not exists + if viper.ConfigFileUsed() == "" { // Create config folder if _, err := os.Stat(folder); os.IsNotExist(err) { if err := os.Mkdir(folder, 0700); err != nil { @@ -44,11 +48,30 @@ var ( if err := f.Close(); err != nil { log.Fatal(err) } - viper.SetConfigType(cfgExtention) - if err := viper.WriteConfig(); err != nil { + } + // Update config file + viper.SetConfigType(cfgExtention) + if err := viper.WriteConfig(); err != nil { + log.Fatal(err) + } + + log.Debugf("Trying connect to catalog. Source: %s", source) + // Decode password + d, err := base64Decode(viper.GetString("auth.bind_pass")) + if err != nil { + log.Fatal(err) + } + viper.Set("auth.bind_pass", d) + switch source { + case "ad": + if err := ad.CheckConnection(createConfig()); err != nil { log.Fatal(err) } + log.Debugf("Successfully connected to %s:%s", viper.GetString("server.host"), viper.GetString("server.port")) + default: + log.Fatalf("Unknown source type: %s", source) } + log.Info("Login successfull") }, } @@ -59,36 +82,50 @@ func init() { } func askUser() { - if viper.GetString("server.host") == "" { - promptString("Host", "server.host") + if viper.ConfigFileUsed() != "" { + log.Warnf("Already using config file - '%s'. Further answers will overwrite it.", viper.ConfigFileUsed()) } - if viper.GetInt("server.port") == 0 { - fmt.Print("Port: ") - port := 0 - fmt.Scanln(&port) - if port == 0 { - log.Fatal("Catalog host port can't be 0") - } - viper.Set("server.port", port) - } + var resp string - if viper.GetString("auth.bind_dn") == "" { - promptString("BindDN", "auth.bind_dn") - } + resp = promptString("Host", viper.GetString("server.host")) + viper.Set("server.host", resp) - if viper.GetString("auth.bind_pass") == "" { - fmt.Printf("BindDN password: ") - pass, _ := gopass.GetPasswdMasked() - viper.Set("auth.bind_pass", base64Encode(string(pass))) + resp = promptString("Port", viper.GetString("server.port")) + viper.Set("server.port", resp) + + resp = promptString("Use SSL", viper.GetString("server.ssl")) + switch strings.ToLower(resp) { + case "true": + viper.Set("server.ssl", true) + case "false": + viper.Set("server.ssl", false) + default: + log.Fatal("Unexpected value. Expecting 'true' or 'false'") } - if viper.GetString("params.search_base") == "" { - promptString("Search base:", "params.search_base") + resp = promptString("Insecure SSL", viper.GetString("server.insecure")) + switch strings.ToLower(resp) { + case "true": + viper.Set("server.insecure", true) + case "false": + viper.Set("server.insecure", false) + default: + log.Fatal("Unexpected value. Expecting 'true' or 'false'") } + + resp = promptString("BindDN", viper.GetString("auth.bind_dn")) + viper.Set("auth.bind_dn", resp) + + fmt.Printf("BindDN password: ") + pass, _ := gopass.GetPasswdMasked() + viper.Set("auth.bind_pass", base64Encode(string(pass))) + + resp = promptString("Search base", viper.GetString("params.search_base")) + viper.Set("params.search_base", resp) } -func promptString(label string, param string) { +func promptString(label string, defVal string) string { validate := func(input string) error { if input == "" { return errors.New("Empty param") @@ -108,11 +145,17 @@ func promptString(label string, param string) { Validate: validate, Templates: templates, } + + if defVal != "" { + prompt.Default = defVal + } + result, err := prompt.Run() if err != nil { log.Fatalf("prompt failed %s", err) } - viper.Set(param, result) + + return result } func base64Encode(str string) string {