Skip to content

Commit

Permalink
Close Neo4j session and driver
Browse files Browse the repository at this point in the history
  • Loading branch information
fbiville committed Sep 1, 2022
1 parent 64bbdcc commit 94e9369
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 5 deletions.
3 changes: 2 additions & 1 deletion cmd/cmd_circleci.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
log "github.com/sirupsen/logrus"
)

func cmdCircleCI(cmd *flag.FlagSet) {
func cmdCircleCI(cmd *flag.FlagSet) *database.Database {
// Setup common params and parse command.
// We have to do this here because we have a custom flag in the GitHub command.
setupCommonFlags()
Expand All @@ -25,6 +25,7 @@ func cmdCircleCI(cmd *flag.FlagSet) {

cci := circleci.GetCircleCI(db, organization, *circleCICookie, session)
cci.Sync()
return db
}

func validateCircleCIParams() {
Expand Down
3 changes: 2 additions & 1 deletion cmd/cmd_github.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus"
)

func cmdGitHub(cmd *flag.FlagSet) {
func cmdGitHub(cmd *flag.FlagSet) *database.Database {
setupCommonFlags()
// Since we want to support a list this has to be defined here, because Golang's `flag`
// sucks (or more likely, I don't understand how to use it)
Expand Down Expand Up @@ -43,6 +43,7 @@ func cmdGitHub(cmd *flag.FlagSet) {
// Now we can actually call the ingestor
gh := github.GetGitHub(db, *githubRESTURL, *githubGraphQLURL, *githubToken, organization, session)
gh.SyncByIngestorNames(ingestorNames)
return db
}

func validateGitHubParams() {
Expand Down
14 changes: 11 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,29 @@ func main() {
log.Fatalf("Unknown subcommand '%s', see help for more details.", os.Args[1])
}

var db *database.Database
defer func() {
if db != nil {
if err := db.Close(); err != nil {
log.Fatal(err)
}
}
}()
switch cmd.Name() {

case githubCmd.Name():
cmdGitHub(cmd)
db = cmdGitHub(cmd)

case circleCICmd.Name():
cmdCircleCI(cmd)
db = cmdCircleCI(cmd)

case enrichCmd.Name():
setupCommonFlags()
cmd.Parse(os.Args[2:])
validateCommonParams()
initLogging()

db := database.GetDB(neo4jURI, neo4jUser, neo4jPassword)
db = database.GetDB(neo4jURI, neo4jUser, neo4jPassword)
en := enrich.GetEnricher(db, organization)
en.Enrich()

Expand Down
17 changes: 17 additions & 0 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package database

import (
"fmt"
"log"

"github.com/neo4j/neo4j-go-driver/v4/neo4j"
)

type Database struct {
driver neo4j.Driver
session neo4j.Session
}

Expand Down Expand Up @@ -36,3 +38,18 @@ func (d *Database) Run(query string, params map[string]interface{}) neo4j.Result

return records
}

func (d *Database) Close() error {
sessionErr := d.session.Close()
driverErr := d.driver.Close()
if driverErr == nil {
return sessionErr
}
if sessionErr == nil {
return driverErr
}
return fmt.Errorf("Both session and driver could not be closed."+
"\nsession close failed with: %v"+
"\ndriver close failed with: %v\n",
sessionErr, driverErr)
}
4 changes: 4 additions & 0 deletions test/e2e/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package e2e

import (
"fmt"
log "github.com/sirupsen/logrus"
"os"
"testing"

Expand Down Expand Up @@ -90,5 +91,8 @@ func TestMain(m *testing.M) {
gh := github.GetGitHub(db, githubRESTURL, githubGraphQLURL, githubToken, organization, session)
gh.SyncByIngestorNames(ingestors)
exitVal := m.Run()
if err := db.Close(); err != nil {
log.Fatal(err)
}
os.Exit(exitVal)
}

0 comments on commit 94e9369

Please sign in to comment.