Skip to content

Commit

Permalink
parse the URI and convert it to whatever format we want (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
lolopinto committed Apr 6, 2021
1 parent 9ba3878 commit ccb6a1c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 61 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/go_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,11 @@ jobs:
- name: setup db
run: cd test_setup && go run .
env:
# format of this string is different for python
# if the logic here changes and we need to access golang, we'd need to rewrite this
# not sure what changed here but now psycopg2 needed
# also need an AUTO_SCHEMA_DB_CONNECTION_STRING
DB_CONNECTION_STRING: 'postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable'
AUTO_SCHEMA_DB_CONNECTION_STRING: 'postgresql+psycopg2://postgres:postgres@localhost:5432/postgres?sslmode=disable'
DB_CONNECTION_STRING: 'postgres://postgres:postgres@localhost:5432/postgres'

- name: Test
run: go test ./... -short
env:
DB_CONNECTION_STRING: 'postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable'
DB_CONNECTION_STRING: 'postgres://postgres:postgres@localhost:5432/postgres'
POSTGRES_PASSWORD: 'postgres'
POSTGRES_USER: 'postgres'
153 changes: 103 additions & 50 deletions ent/config/config.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package config

import (
"fmt"
"io/ioutil"
"log"
"os"
"strconv"
"strings"

"github.com/joho/godotenv"
"github.com/lib/pq"
"github.com/lolopinto/ent/internal/util"
"github.com/pkg/errors"
"gopkg.in/yaml.v2"
)

Expand All @@ -17,63 +20,73 @@ type Config struct {
}

type DBConfig struct {
// depending on what we have return what's needed?
connection string
autoSchemaConnection string
rawDBInfo *RawDbInfo
Dialect string `yaml:"dialect"`
Database string `yaml:"database"`
User string `yaml:"user"`
Password string `yaml:"password"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Pool int `yaml:"pool"`
SslMode string `yaml:"sslmode"`
}

func (db *DBConfig) GetConnectionStr() string {
if db.connection != "" {
return db.connection
}

// Todo probably throw here?
if db.rawDBInfo == nil {
panic("no connection string or db ")
}

return db.rawDBInfo.GetConnectionStr("postgres", true)
return db.getConnectionStr("postgres", true)
}

func (db *DBConfig) GetSQLAlchemyDatabaseURIgo() string {
if db.autoSchemaConnection != "" {
return db.autoSchemaConnection
}
if db.connection != "" {
return db.connection
}
// postgres only for now as above. specific driver also
// no ssl mode
return db.rawDBInfo.GetConnectionStr("postgresql+psycopg2", false)
return db.getConnectionStr("postgresql+psycopg2", false)
}

type RawDbInfo struct {
Dialect string `yaml:"dialect"`
Database string `yaml:"database"`
User string `yaml:"user"`
Password string `yaml:"password"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Pool int `yaml:"pool"`
SslMode string `yaml:"sslmode"`
func (r *DBConfig) setDbName(val string) {
r.Database = val
}

func (r *DBConfig) setUser(val string) {
r.User = val
}

func (r *DBConfig) setPassword(val string) {
r.Password = val
}

func (r *DBConfig) setHost(val string) {
r.Host = val
}

func (r *DBConfig) setPort(val string) {
port, err := strconv.Atoi(val)
if err != nil {
panic(err)
}
r.Port = port
}

func (r *DBConfig) setSSLMode(val string) {
r.SslMode = val
}

func (dbData *RawDbInfo) GetConnectionStr(driver string, sslmode bool) string {
format := "{driver}://{user}:{password}@{host}:{port}/{dbname}"
func (dbData *DBConfig) getConnectionStr(driver string, sslmode bool) string {
format := "{driver}://{user}:{password}@{host}/{dbname}"
parts := []string{
"{driver}", driver,
"{user}", dbData.User,
"{password}", dbData.Password,
"{host}", dbData.Host,
"{port}", strconv.Itoa(dbData.Port),
"{dbname}", dbData.Database,
}
if dbData.Port != 0 {
format = "{driver}://{user}:{password}@{host}:{port}/{dbname}"
parts = append(parts, "{port}", strconv.Itoa(dbData.Port))
}

if sslmode {
format = format + "?sslmode={sslmode}"
parts = append(parts, []string{
parts = append(parts,
"{sslmode}", dbData.SslMode,
}...)
)
}
r := strings.NewReplacer(parts...)

Expand Down Expand Up @@ -106,27 +119,69 @@ func GetConnectionStr() string {
return cfg.DB.GetConnectionStr()
}

func ResetConfig(rdbi *RawDbInfo) {
func ResetConfig(rdbi *DBConfig) {
cfg = &Config{
DB: &DBConfig{
rawDBInfo: rdbi,
},
DB: rdbi,
}
}

func loadDBConfig() *DBConfig {
func parseConnectionString() (*DBConfig, error) {
// DB_CONNECTION_STRING trumps file
conn := util.GetEnv("DB_CONNECTION_STRING", "")
autoSchemaConn := util.GetEnv("AUTO_SCHEMA_DB_CONNECTION_STRING", "")
if conn != "" {
return &DBConfig{
connection: conn,
autoSchemaConnection: autoSchemaConn,

if conn == "" {
return nil, nil
}

url, err := pq.ParseURL(conn)
if err != nil {
return nil, errors.Wrap(err, "error parsing url")
}
parts := strings.Split(url, " ")

r := &DBConfig{
// only postgres supported for now
Dialect: "postgres",
SslMode: "disable",
}
m := map[string]func(string){
"dbname": r.setDbName,
"host": r.setHost,
"user": r.setUser,
"password": r.setPassword,
"port": r.setPort,
"sslmode": r.setSSLMode,
}

for _, part := range parts {

parts2 := strings.Split(part, "=")
if len(parts2) != 2 {
log.Fatal("invalid 2")
}

fn := m[parts2[0]]
if fn == nil {
return nil, fmt.Errorf("invalid key %s in url", parts2[0])
}

fn(parts2[1])
}
return r, nil
}

func loadDBConfig() *DBConfig {
// DB_CONNECTION_STRING trumps file
dbInfo, err := parseConnectionString()
if err != nil {
panic(err)
}
if dbInfo != nil {
return dbInfo
}

path := util.GetEnv("PATH_TO_DB_FILE", "config/database.yml")
_, err := os.Stat(path)
_, err = os.Stat(path)
if err != nil {
log.Fatalf("no way to get db config :%v", err)
}
Expand All @@ -136,7 +191,7 @@ func loadDBConfig() *DBConfig {
log.Fatalf("could not read yml file to load db: %v", err)
}

var dbData RawDbInfo
var dbData DBConfig
err = yaml.Unmarshal(b, &dbData)
if err != nil {
log.Fatal(err)
Expand All @@ -151,7 +206,5 @@ func loadDBConfig() *DBConfig {
log.Fatalf("invalid database configuration")
}

return &DBConfig{
rawDBInfo: &dbData,
}
return &dbData
}
2 changes: 1 addition & 1 deletion ent/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func CloseDB() error {

// TODO this obviously needs to be cleaned up
// used by tests
func ResetDB(db2 *sqlx.DB, rdbi *config.RawDbInfo) error {
func ResetDB(db2 *sqlx.DB, rdbi *config.DBConfig) error {
dbMutex.Lock()
defer dbMutex.Unlock()

Expand Down
6 changes: 3 additions & 3 deletions internal/testingutils/test_db/test_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type TestDB struct {
}

func (tdb *TestDB) BeforeAll() error {
dbInfo := config.RawDbInfo{
dbInfo := config.DBConfig{
Dialect: "postgres",
Port: 5432,
User: os.Getenv("POSTGRES_USER"),
Expand All @@ -29,7 +29,7 @@ func (tdb *TestDB) BeforeAll() error {
SslMode: "disable",
}

db, err := sqlx.Open("postgres", dbInfo.GetConnectionStr("postgres", true))
db, err := sqlx.Open("postgres", dbInfo.GetConnectionStr())
if err != nil {
return err
}
Expand All @@ -47,7 +47,7 @@ func (tdb *TestDB) BeforeAll() error {

dbInfo.Database = tdb.dbName

privatedb, err := sqlx.Open("postgres", dbInfo.GetConnectionStr("postgres", true))
privatedb, err := sqlx.Open("postgres", dbInfo.GetConnectionStr())
if err != nil {
return err
}
Expand Down

0 comments on commit ccb6a1c

Please sign in to comment.