From 94e9369853f07b516dd5c065623ca2d964cfb07d Mon Sep 17 00:00:00 2001 From: Florent Biville Date: Thu, 1 Sep 2022 12:16:38 +0200 Subject: [PATCH] Close Neo4j session and driver Fixes #50 --- cmd/cmd_circleci.go | 3 ++- cmd/cmd_github.go | 3 ++- cmd/main.go | 14 +++++++++++--- pkg/database/database.go | 17 +++++++++++++++++ test/e2e/main_test.go | 4 ++++ 5 files changed, 36 insertions(+), 5 deletions(-) diff --git a/cmd/cmd_circleci.go b/cmd/cmd_circleci.go index dfcb582..ae97813 100644 --- a/cmd/cmd_circleci.go +++ b/cmd/cmd_circleci.go @@ -10,7 +10,7 @@ import ( log "github.com/sirupsen/logrus" ) -func cmdCircleCI(cmd *flag.FlagSet) { +func cmdCircleCI(cmd *flag.FlagSet) *database.Database { // Setup common params and parse command. // We have to do this here because we have a custom flag in the GitHub command. setupCommonFlags() @@ -25,6 +25,7 @@ func cmdCircleCI(cmd *flag.FlagSet) { cci := circleci.GetCircleCI(db, organization, *circleCICookie, session) cci.Sync() + return db } func validateCircleCIParams() { diff --git a/cmd/cmd_github.go b/cmd/cmd_github.go index e940074..08131c9 100644 --- a/cmd/cmd_github.go +++ b/cmd/cmd_github.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" ) -func cmdGitHub(cmd *flag.FlagSet) { +func cmdGitHub(cmd *flag.FlagSet) *database.Database { setupCommonFlags() // Since we want to support a list this has to be defined here, because Golang's `flag` // sucks (or more likely, I don't understand how to use it) @@ -43,6 +43,7 @@ func cmdGitHub(cmd *flag.FlagSet) { // Now we can actually call the ingestor gh := github.GetGitHub(db, *githubRESTURL, *githubGraphQLURL, *githubToken, organization, session) gh.SyncByIngestorNames(ingestorNames) + return db } func validateGitHubParams() { diff --git a/cmd/main.go b/cmd/main.go index 19222e1..f200c3d 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -56,13 +56,21 @@ func main() { log.Fatalf("Unknown subcommand '%s', see help for more details.", os.Args[1]) } + var db *database.Database + defer func() { + if db != nil { + if err := db.Close(); err != nil { + log.Fatal(err) + } + } + }() switch cmd.Name() { case githubCmd.Name(): - cmdGitHub(cmd) + db = cmdGitHub(cmd) case circleCICmd.Name(): - cmdCircleCI(cmd) + db = cmdCircleCI(cmd) case enrichCmd.Name(): setupCommonFlags() @@ -70,7 +78,7 @@ func main() { validateCommonParams() initLogging() - db := database.GetDB(neo4jURI, neo4jUser, neo4jPassword) + db = database.GetDB(neo4jURI, neo4jUser, neo4jPassword) en := enrich.GetEnricher(db, organization) en.Enrich() diff --git a/pkg/database/database.go b/pkg/database/database.go index 79ef3a9..b51463f 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -1,12 +1,14 @@ package database import ( + "fmt" "log" "github.com/neo4j/neo4j-go-driver/v4/neo4j" ) type Database struct { + driver neo4j.Driver session neo4j.Session } @@ -36,3 +38,18 @@ func (d *Database) Run(query string, params map[string]interface{}) neo4j.Result return records } + +func (d *Database) Close() error { + sessionErr := d.session.Close() + driverErr := d.driver.Close() + if driverErr == nil { + return sessionErr + } + if sessionErr == nil { + return driverErr + } + return fmt.Errorf("Both session and driver could not be closed."+ + "\nsession close failed with: %v"+ + "\ndriver close failed with: %v\n", + sessionErr, driverErr) +} diff --git a/test/e2e/main_test.go b/test/e2e/main_test.go index 3ff0639..a832d3e 100644 --- a/test/e2e/main_test.go +++ b/test/e2e/main_test.go @@ -2,6 +2,7 @@ package e2e import ( "fmt" + log "github.com/sirupsen/logrus" "os" "testing" @@ -90,5 +91,8 @@ func TestMain(m *testing.M) { gh := github.GetGitHub(db, githubRESTURL, githubGraphQLURL, githubToken, organization, session) gh.SyncByIngestorNames(ingestors) exitVal := m.Run() + if err := db.Close(); err != nil { + log.Fatal(err) + } os.Exit(exitVal) }