diff --git a/main.go b/main.go index 1ff7d14..291d291 100644 --- a/main.go +++ b/main.go @@ -33,15 +33,20 @@ import ( "io/ioutil" "log" "os" + "regexp" + "strings" + "syscall" "time" "github.com/go-sql-driver/mysql" + "golang.org/x/crypto/ssh/terminal" ) var ( dump = flag.String("dump", "", "MySQL dump file") - dsn = flag.String("dsn", "root:root@tcp(0.0.0.0:3306)/", "MySQL Data Source Name") + dsn = flag.String("dsn", "user:password@tcp(0.0.0.0:3306)/", "MySQL Data Source Name") enableSsl = flag.Bool("enable_ssl", false, "Connect to MySQL with SSL") + prompt = flag.Bool("prompt", false, "Prompt for password rather than specifying in the command. Change dsn format to 'user@tcp(0.0.0.0:3306)/'") sslCa = flag.String("ssl_ca", "server-ca.pem", "MySQL Server certificate") sslCert = flag.String("ssl_cert", "client-cert.pem", "MySQL Client PEM cert file") sslKey = flag.String("ssl_key", "client-key.pem", "MySQL Client PEM key file") @@ -134,12 +139,13 @@ func main() { log.Fatalf("no -dump file specified") } + var finalDsn = *dsn if *enableSsl { - rootCertPool := x509.NewCertPool() pem, err := ioutil.ReadFile(*sslCa) if err != nil { log.Fatalln("ioutil.Readline:", err) } + rootCertPool := x509.NewCertPool() if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { log.Fatal("Failed to append CA certificate PEM.") } @@ -149,13 +155,48 @@ func main() { log.Fatalln("tls.LoadX509KeyPair:", err) } clientCert = append(clientCert, certs) - mysql.RegisterTLSConfig("custom", &tls.Config{ + const customTLSName = "custom" + tlserr := mysql.RegisterTLSConfig(customTLSName, &tls.Config{ RootCAs: rootCertPool, Certificates: clientCert, ServerName: *serverName, }) + if tlserr != nil { + log.Fatalln("mysql.RegisterTLSConfig:", tlserr) + } + finalDsn = strings.Join([]string{finalDsn, "?tls=", customTLSName}, "") } - db, err := sql.Open("mysql", *dsn) + + if *prompt { + // DSN strings look like: + // user:password@tcp(0.0.0.0:3306)/ + // With this flag the user can avoid typing their password: + // user@tcp(0.0.0.0:3306)/ + // Save text before ':' and after '@' so we can insert the password + // to create a proper DSN string. + dsnRegex := regexp.MustCompile(`(\w*):?\w*(@.+)`) + matches := dsnRegex.FindStringSubmatch(finalDsn) + if matches == nil { + fmt.Print("Incorrect format for dsn. Usage:\n") + flag.PrintDefaults() + os.Exit(1) + } + + fmt.Print("Enter password: ") + // Don't echo password to screen during input. + password, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + log.Fatalln("Error reading password:", err) + } + // ReadPassword() leaves cursor on the input line, + // so begin output on the next line + fmt.Print("\n") + + // Insert password into the connection string. + finalDsn = strings.Join([]string{matches[1], ":", string(password), matches[2]}, "") + } + + db, err := sql.Open("mysql", finalDsn) if err != nil { log.Fatalln("sql.Open:", err) }