Skip to content

Commit

Permalink
Allow to select the database name for every collection.
Browse files Browse the repository at this point in the history
Usefull if you want to allow people to switch databases.
  • Loading branch information
ilijamt committed Jan 11, 2017
1 parent 2a2547d commit 91ba6ba
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
7 changes: 5 additions & 2 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func (v *ValidationError) Error() string {

type Collection struct {
Name string
Database string
Context *Context
Connection *Connection
}
Expand All @@ -78,12 +79,14 @@ func (d DocumentNotFoundError) Error() string {
return "Document not found"
}

// Collection ...
func (c *Collection) Collection() *mgo.Collection {
return c.Connection.Session.DB(c.Connection.Config.Database).C(c.Name)
return c.Connection.Session.DB(c.Database).C(c.Name)
}

// CollectionOnSession ...
func (c *Collection) collectionOnSession(sess *mgo.Session) *mgo.Collection {
return sess.DB(c.Connection.Config.Database).C(c.Name)
return sess.DB(c.Database).C(c.Name)
}

func (c *Collection) PreSave(doc Document) error {
Expand Down
15 changes: 10 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func (m *Connection) Connect() (err error) {
}

session, err := mgo.DialWithInfo(m.Config.DialInfo)

if err != nil {
return err
}
Expand All @@ -71,12 +70,18 @@ func (m *Connection) Connect() (err error) {
return nil
}

func (m *Connection) Collection(name string) *Collection {

// Just create a new instance - it's cheap and only has name
// CollectionFromDatabase ...
func (m *Connection) CollectionFromDatabase(name string, database string) *Collection {
// Just create a new instance - it's cheap and only has name and a database name
return &Collection{
Connection: m,
Name: name,
Context: m.Context,
Database: database,
Name: name,
}
}

// Collection ...
func (m *Connection) Collection(name string) *Collection {
return m.CollectionFromDatabase(name, m.Config.Database)
}
24 changes: 22 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,29 @@ func TestRetrieveCollection(t *testing.T) {
Convey("should be able to retrieve a collection instance from a connection", t, func() {
conn := getConnection()
defer conn.Session.Close()
col := conn.Collection("tests")

col := conn.Collection("tests");
So(col.Name, ShouldEqual, "tests")
So(col.Connection, ShouldEqual, conn)

So(col.Context.Get("foo"), ShouldEqual, "bar")

So(conn.Config.Database, ShouldEqual, col.Database)
})
Convey("should be able to retrieve a collection instance from a connection with different databases", t, func() {
conn := getConnection()
defer conn.Session.Close()

col1 := conn.CollectionFromDatabase("tests", "test1");
So(col1.Name, ShouldEqual, "tests")
So(col1.Connection, ShouldEqual, conn)
So(col1.Database, ShouldEqual, "test1")

col2 := conn.CollectionFromDatabase("tests", "test2");
So(col2.Name, ShouldEqual, "tests")
So(col2.Connection, ShouldEqual, conn)
So(col2.Database, ShouldEqual, "test2")

So(col2.Connection, ShouldEqual, col1.Connection)
So(col1.Database, ShouldNotEqual, col2.Database)
})
}
2 changes: 1 addition & 1 deletion resultSet.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (r *ResultSet) Paginate(perPage, page int) (*PaginationInfo, error) {
// Get count on a different session to avoid blocking
sess := r.Collection.Connection.Session.Copy()

count, err := sess.DB(r.Collection.Connection.Config.Database).C(r.Collection.Name).Find(r.Params).Count()
count, err := sess.DB(r.Collection.Database).C(r.Collection.Name).Find(r.Params).Count()
sess.Close()

if err != nil {
Expand Down

0 comments on commit 91ba6ba

Please sign in to comment.