diff --git a/Dockerfile b/Dockerfile index 1cf558a..3347046 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,8 @@ FROM golang:1.16 as flyutil WORKDIR /go/src/github.com/fly-examples/fly-postgres COPY . . -RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/start ./cmd +RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/flyadmin ./cmd/flyadmin +RUN CGO_ENABLED=0 GOOS=linux go build -v -o /fly/bin/start ./cmd/start FROM postgres:${PG_VERSION} ENV PGDATA=/data/pg_data diff --git a/cmd/flyadmin/main.go b/cmd/flyadmin/main.go new file mode 100644 index 0000000..d59cb53 --- /dev/null +++ b/cmd/flyadmin/main.go @@ -0,0 +1,252 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + + "github.com/jackc/pgx/v4" +) + +type cmd func(pg *pgx.Conn, input map[string]interface{}) (result interface{}, err error) + +func main() { + app := os.Getenv("FLY_APP_NAME") + hostname := fmt.Sprintf("%s.internal:5432", app) + conn, err := openConnection(hostname) + if err != nil { + fmt.Fprintf(os.Stderr, "error connecting to postgres: %s\n", err) + os.Exit(1) + } + defer conn.Close(context.Background()) + + if len(os.Args) == 1 { + fmt.Fprintln(os.Stderr, "subcommand required") + os.Exit(1) + } + + command := os.Args[1] + input := map[string]interface{}{} + + if len(os.Args) > 2 && os.Args[2] != "" { + if err := json.Unmarshal([]byte(os.Args[2]), &input); err != nil { + fmt.Fprintf(os.Stderr, "error decoding json input: %s\n", err) + os.Exit(1) + } + } + + commands := map[string]cmd{ + "database-list": listDatabases, + "database-create": createDatabase, + "database-delete": deleteDatabase, + "user-list": listUsers, + "user-create": createUser, + "user-delete": deleteUser, + "grant-access": grantAccess, + "revoke-access": revokeAccess, + "grant-superuser": grantSuperuser, + "revoke-superuser": revokeSuperuser, + } + + cmd := commands[command] + if cmd == nil { + fmt.Fprintf(os.Stderr, "unknown command '%s'\n", command) + os.Exit(1) + } + + output, err := cmd(conn, input) + resp := response{ + Result: output, + } + if err != nil { + resp.Error = err.Error() + } + + if err := json.NewEncoder(os.Stdout).Encode(resp); err != nil { + fmt.Fprintf(os.Stderr, "error marshaling response '%s'\n", err) + os.Exit(1) + } +} + +type response struct { + Result interface{} `json:"result"` + Error string `json:"error,omitempty"` +} + +func listDatabases(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := ` + SELECT d.datname, + (SELECT array_agg(u.usename::text order by u.usename) + from pg_user u + where has_database_privilege(u.usename, d.datname, 'CONNECT')) as allowed_users + from pg_database d where d.datistemplate = false + order by d.datname; + ` + + rows, err := pg.Query(context.Background(), sql) + if err != nil { + return nil, err + } + defer rows.Close() + + values := []dbInfo{} + + for rows.Next() { + di := dbInfo{} + if err := rows.Scan(&di.Name, &di.Users); err != nil { + return nil, err + } + values = append(values, di) + } + + return values, nil +} + +type userInfo struct { + Username string `json:"username"` + SuperUser bool `json:"superuser"` + Databases []string `json:"databases"` +} + +type dbInfo struct { + Name string `json:"name"` + Users []string `json:"users"` +} + +func listUsers(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := ` + select u.usename, + usesuper as superuser, + (select array_agg(d.datname::text order by d.datname) + from pg_database d + WHERE datistemplate = false + AND has_database_privilege(u.usename, d.datname, 'CONNECT') + ) as allowed_databases + from pg_user u + order by u.usename + ` + + rows, err := pg.Query(context.Background(), sql) + if err != nil { + return nil, err + } + defer rows.Close() + + values := []userInfo{} + + for rows.Next() { + ui := userInfo{} + if err := rows.Scan(&ui.Username, &ui.SuperUser, &ui.Databases); err != nil { + return nil, err + } + values = append(values, ui) + } + + return values, nil +} + +func createUser(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf(`CREATE USER %s WITH LOGIN PASSWORD '%s'`, input["username"], input["password"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + if val, ok := input["superuser"]; ok && val == true { + return grantSuperuser(pg, input) + } + + return true, nil +} + +func deleteUser(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf(`DROP USER IF EXISTS %s`, input["username"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func createDatabase(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf("CREATE DATABASE %s;", input["name"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func deleteDatabase(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf("DROP DATABASE %s;", input["name"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func grantAccess(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", input["database"], input["username"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func revokeAccess(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf("REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s", input["database"], input["username"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func grantSuperuser(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf("ALTER USER %s WITH SUPERUSER;", input["username"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func revokeSuperuser(pg *pgx.Conn, input map[string]interface{}) (interface{}, error) { + sql := fmt.Sprintf("ALTER USER %s WITH NOSUPERUSER;", input["username"]) + + _, err := pg.Exec(context.Background(), sql) + if err != nil { + return false, err + } + + return true, nil +} + +func openConnection(hostname string) (*pgx.Conn, error) { + url := fmt.Sprintf("postgres://%s/postgres", hostname) + conf, err := pgx.ParseConfig(url) + + if err != nil { + return nil, err + } + conf.User = "flypgadmin" + conf.Password = os.Getenv("SU_PASSWORD") + + return pgx.ConnectConfig(context.Background(), conf) +} diff --git a/cmd/main.go b/cmd/start/main.go similarity index 100% rename from cmd/main.go rename to cmd/start/main.go