Skip to content

Commit

Permalink
[CA-1258] Add a status check to LdapRegistrationDAO and use to check …
Browse files Browse the repository at this point in the history
…openDJ (#519)

* Add a status check to LdapRegistrationDAO and use to check openDJ

* Compiling, tests not passing

* Working StatusServiceSpec tests

* Fix mocks to work successfully for status checks

* [NOT WORKING] Mock registration DAO

* Move to enums, add mockregistrationdao

* Change a string

* Testing

* More testing

* Try scala version different

* Remove version commands

* PR Feedback

* PR Feedback

* Change name to connectionType
  • Loading branch information
s-rubenstein committed May 6, 2021
1 parent c34ff18 commit 80f3066
Show file tree
Hide file tree
Showing 19 changed files with 233 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ object Boot extends IOApp with LazyLogging {
connectionPool.setMaxWaitTimeMillis(30000)
connectionPool.setConnectionPoolName(name)
connectionPool
})(ldapConnection => IO(ldapConnection.close()))
})(ldapConnectionPool => IO(ldapConnectionPool.close()))
}

private[sam] def createAppDependencies(
Expand Down Expand Up @@ -321,7 +321,7 @@ object Boot extends IOApp with LazyLogging {
val policyEvaluatorService = PolicyEvaluatorService(config.emailDomain, resourceTypeMap, accessPolicyDAO, directoryDAO)
val resourceService = new ResourceService(resourceTypeMap, policyEvaluatorService, accessPolicyDAO, directoryDAO, cloudExtensionsInitializer.cloudExtensions, config.emailDomain)
val userService = new UserService(directoryDAO, cloudExtensionsInitializer.cloudExtensions, registrationDAO, config.blockedEmailDomains)
val statusService = new StatusService(directoryDAO, cloudExtensionsInitializer.cloudExtensions, DbReference(DatabaseNames.Read, implicitly), 10 seconds)
val statusService = new StatusService(directoryDAO, registrationDAO, cloudExtensionsInitializer.cloudExtensions, DbReference(DatabaseNames.Read, implicitly), 10 seconds)
val managedGroupService =
new ManagedGroupService(resourceService, policyEvaluatorService, resourceTypeMap, accessPolicyDAO, directoryDAO, cloudExtensionsInitializer.cloudExtensions, config.emailDomain)
val samApplication = SamApplication(userService, resourceService, statusService)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.broadinstitute.dsde.workbench.sam.dataAccess

object ConnectionType extends Enumeration {
type ConnectionType = Value
val LDAP, Postgres = Value
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.model.google.{ServiceAccount, ServiceAccountDisplayName, ServiceAccountSubjectId}
import org.broadinstitute.dsde.workbench.sam._
import org.broadinstitute.dsde.workbench.sam.config.DirectoryConfig
import org.broadinstitute.dsde.workbench.sam.dataAccess.ConnectionType.ConnectionType
import org.broadinstitute.dsde.workbench.sam.schema.JndiSchemaDAO.{Attr, ObjectClass}
import org.broadinstitute.dsde.workbench.sam.util.{LdapSupport, SamRequestContext}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

// use ExecutionContexts.blockingThreadPool for blockingEc
class LdapRegistrationDAO(
Expand All @@ -40,6 +42,8 @@ class LdapRegistrationDAO(
}
}

override def getConnectionType(): ConnectionType = ConnectionType.LDAP

override def createUser(user: WorkbenchUser, samRequestContext: SamRequestContext): IO[WorkbenchUser] = {
val attrs = List(
new Attribute(Attr.email, user.email.value),
Expand Down Expand Up @@ -156,4 +160,17 @@ class LdapRegistrationDAO(

override def setGoogleSubjectId(userId: WorkbenchUserId, googleSubjectId: GoogleSubjectId, samRequestContext: SamRequestContext): IO[Unit] =
executeLdap(IO(ldapConnectionPool.modify(userDn(userId), new Modification(ModificationType.ADD, Attr.googleSubjectId, googleSubjectId.value))), "setGoogleSubjectId", samRequestContext)

override def checkStatus(samRequestContext: SamRequestContext): Boolean = {
val ldapIsHealthy = Try {
ldapConnectionPool.getHealthCheck
val connection = ldapConnectionPool.getConnection
ldapConnectionPool.getHealthCheck.ensureNewConnectionValid(connection)
ldapConnectionPool.releaseConnection(connection)
} match {
case Success(_) => true
case Failure(_) => false
}
ldapIsHealthy
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import cats.effect.{ContextShift, IO, Timer}
import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.model.google.{GoogleProject, ServiceAccount, ServiceAccountSubjectId}
import org.broadinstitute.dsde.workbench.sam._
import org.broadinstitute.dsde.workbench.sam.dataAccess.ConnectionType.ConnectionType
import org.broadinstitute.dsde.workbench.sam.db.SamParameterBinderFactory._
import org.broadinstitute.dsde.workbench.sam.db.SamTypeBinders._
import org.broadinstitute.dsde.workbench.sam.db._
Expand All @@ -17,10 +18,13 @@ import org.broadinstitute.dsde.workbench.sam.util.{DatabaseSupport, SamRequestCo
import org.postgresql.util.PSQLException
import scalikejdbc._

import scala.concurrent.duration.DurationInt
import scala.util.{Failure, Try}

class PostgresDirectoryDAO(protected val writeDbRef: DbReference, protected val readDbRef: DbReference)(implicit val cs: ContextShift[IO], timer: Timer[IO]) extends DirectoryDAO with DatabaseSupport with PostgresGroupDAO {

override def getConnectionType(): ConnectionType = ConnectionType.Postgres

override def createGroup(group: BasicWorkbenchGroup, accessInstructionsOpt: Option[String], samRequestContext: SamRequestContext): IO[BasicWorkbenchGroup] = {
serializableWriteTransaction("createGroup", samRequestContext)({ implicit session =>
val groupId: GroupPK = insertGroup(group)
Expand Down Expand Up @@ -761,4 +765,10 @@ class PostgresDirectoryDAO(protected val writeDbRef: DbReference, protected val
}
})
}

override def checkStatus(samRequestContext: SamRequestContext): Boolean = {
writeDbRef.inLocalTransaction { session =>
session.connection.isValid((2 seconds).toSeconds.intValue())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.broadinstitute.dsde.workbench.sam.dataAccess

import cats.effect.IO
import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.sam.dataAccess.ConnectionType.ConnectionType
import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext

/**
Expand All @@ -11,6 +12,7 @@ import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext
* away from a solution that requires that the Apache proxies query this group, we can remove the RegistrationDAO.
*/
trait RegistrationDAO {
def getConnectionType(): ConnectionType
def createUser(user: WorkbenchUser, samRequestContext: SamRequestContext): IO[WorkbenchUser]
def loadUser(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[Option[WorkbenchUser]]
def deleteUser(userId: WorkbenchUserId, samRequestContext: SamRequestContext): IO[Unit]
Expand All @@ -22,4 +24,5 @@ trait RegistrationDAO {
def deletePetServiceAccount(petServiceAccountId: PetServiceAccountId, samRequestContext: SamRequestContext): IO[Unit]
def updatePetServiceAccount(petServiceAccount: PetServiceAccount, samRequestContext: SamRequestContext): IO[PetServiceAccount]
def setGoogleSubjectId(userId: WorkbenchUserId, googleSubjectId: GoogleSubjectId, samRequestContext: SamRequestContext): IO[Unit]
def checkStatus(samRequestContext: SamRequestContext): Boolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import akka.pattern.ask
import akka.util.Timeout
import cats.effect.IO
import com.typesafe.scalalogging.LazyLogging
import org.broadinstitute.dsde.workbench.model.WorkbenchGroupName
import org.broadinstitute.dsde.workbench.sam.dataAccess.DirectoryDAO
import org.broadinstitute.dsde.workbench.sam.dataAccess.{ConnectionType, DirectoryDAO, RegistrationDAO}
import org.broadinstitute.dsde.workbench.sam.db.DbReference
import org.broadinstitute.dsde.workbench.sam.util.SamRequestContext
import org.broadinstitute.dsde.workbench.util.health.HealthMonitor.GetCurrentStatus
Expand All @@ -17,11 +16,15 @@ import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, Future}

class StatusService(
val directoryDAO: DirectoryDAO,
val cloudExtensions: CloudExtensions,
val dbReference: DbReference,
initialDelay: FiniteDuration = Duration.Zero,
pollInterval: FiniteDuration = 1 minute)(implicit system: ActorSystem, executionContext: ExecutionContext)
val directoryDAO: DirectoryDAO,
// We expect this to be of type LdapRegistrationDAO, because
// the status service specifically cares about checking LDAP's
// status here, not a generic RegistrationDAO
val ldapRegistrationDAO: RegistrationDAO,
val cloudExtensions: CloudExtensions,
val dbReference: DbReference,
initialDelay: FiniteDuration = Duration.Zero,
pollInterval: FiniteDuration = 1 minute)(implicit system: ActorSystem, executionContext: ExecutionContext)
extends LazyLogging {
implicit val askTimeout = Timeout(5 seconds)

Expand All @@ -31,24 +34,30 @@ class StatusService(
def getStatus(): Future[StatusCheckResponse] = (healthMonitor ? GetCurrentStatus).asInstanceOf[Future[StatusCheckResponse]]

private def checkStatus(): Map[Subsystem, Future[SubsystemStatus]] =
cloudExtensions.checkStatus + (OpenDJ -> checkOpenDJ(cloudExtensions.allUsersGroupName).unsafeToFuture()) + (Database -> checkDatabase().unsafeToFuture())
cloudExtensions.checkStatus + (OpenDJ -> checkOpenDJ().unsafeToFuture()) + (Database -> checkDatabase().unsafeToFuture())

private def checkOpenDJ(groupToLoad: WorkbenchGroupName): IO[SubsystemStatus] = {
private def checkOpenDJ(): IO[SubsystemStatus] = IO {
// Since Status calls are ~80% of all Sam calls and are easy to track separately, Status calls are not being traced.
logger.info("checking opendj connection")
directoryDAO.loadGroupEmail(groupToLoad, SamRequestContext(None)).map { // Since Status calls are ~80% of all Sam calls and are easy to track separately, Status calls are not being traced.
case Some(_) => HealthMonitor.OkStatus
case None => HealthMonitor.failedStatus(s"could not find group $groupToLoad in opendj")
if (ldapRegistrationDAO.getConnectionType() != ConnectionType.LDAP) {
HealthMonitor.failedStatus("Connection of RegistrationDAO is not to OpenDJ")
} else {
if (ldapRegistrationDAO.checkStatus(SamRequestContext(None)))
HealthMonitor.OkStatus
else
HealthMonitor.failedStatus(s"LDAP database connection invalid or timed out checking")
}
}

private def checkDatabase(): IO[SubsystemStatus] = IO {
logger.info("checking database connection")
dbReference.inLocalTransaction { session =>
if (session.connection.isValid((2 seconds).toSeconds.intValue())) {
if (directoryDAO.getConnectionType() != ConnectionType.Postgres) {
HealthMonitor.failedStatus("Connection of RegistrationDAO is not to Postgres")
} else {
if (directoryDAO.checkStatus(SamRequestContext(None)))
HealthMonitor.OkStatus
} else {
HealthMonitor.failedStatus("database connection invalid or timed out checking")
}
else
HealthMonitor.failedStatus("Postgres database connection invalid or timed out checking")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.sam.api._
import org.broadinstitute.dsde.workbench.sam.config.AppConfig._
import org.broadinstitute.dsde.workbench.sam.config._
import org.broadinstitute.dsde.workbench.sam.dataAccess.{AccessPolicyDAO, MockAccessPolicyDAO, MockDirectoryDAO}
import org.broadinstitute.dsde.workbench.sam.dataAccess.{AccessPolicyDAO, MockAccessPolicyDAO, MockDirectoryDAO, MockRegistrationDAO}
import org.broadinstitute.dsde.workbench.sam.db.{DatabaseNames, DbReference}
import org.broadinstitute.dsde.workbench.sam.db.tables._
import org.broadinstitute.dsde.workbench.sam.google.{GoogleExtensionRoutes, GoogleExtensions, GoogleGroupSynchronizer, GoogleKeyCache}
Expand Down Expand Up @@ -88,7 +88,7 @@ object TestSupport extends TestSupport {
def genSamDependencies(resourceTypes: Map[ResourceTypeName, ResourceType] = Map.empty, googIamDAO: Option[GoogleIamDAO] = None, googleServicesConfig: GoogleServicesConfig = googleServicesConfig, cloudExtensions: Option[CloudExtensions] = None, googleDirectoryDAO: Option[GoogleDirectoryDAO] = None, policyAccessDAO: Option[AccessPolicyDAO] = None)(implicit system: ActorSystem) = {
val googleDirectoryDAO = new MockGoogleDirectoryDAO()
val directoryDAO = new MockDirectoryDAO()
val registrationDAO = new MockDirectoryDAO()
val registrationDAO = new MockRegistrationDAO()
val googleIamDAO = googIamDAO.getOrElse(new MockGoogleIamDAO())
val policyDAO = policyAccessDAO.getOrElse(new MockAccessPolicyDAO(resourceTypes))
val notificationPubSubDAO = new MockGooglePubSubDAO()
Expand Down Expand Up @@ -121,7 +121,7 @@ object TestSupport extends TestSupport {
val mockResourceService = new ResourceService(resourceTypes, policyEvaluatorService, policyDAO, directoryDAO, googleExt, "example.com")
val mockManagedGroupService = new ManagedGroupService(mockResourceService, policyEvaluatorService, resourceTypes, policyDAO, directoryDAO, googleExt, "example.com")

SamDependencies(mockResourceService, policyEvaluatorService, new UserService(directoryDAO, googleExt, registrationDAO, Seq.empty), new StatusService(directoryDAO, googleExt, dbRef), mockManagedGroupService, directoryDAO, policyDAO, googleExt)
SamDependencies(mockResourceService, policyEvaluatorService, new UserService(directoryDAO, googleExt, registrationDAO, Seq.empty), new StatusService(directoryDAO, registrationDAO, googleExt, dbRef), mockManagedGroupService, directoryDAO, policyDAO, googleExt)
}

def genSamRoutes(samDependencies: SamDependencies, uInfo: UserInfo)(implicit system: ActorSystem, materializer: Materializer): SamRoutes = new SamRoutes(samDependencies.resourceService, samDependencies.userService, samDependencies.statusService, samDependencies.managedGroupService, null, samDependencies.directoryDAO, samDependencies.policyEvaluatorService, LiquibaseConfig("", false))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import akka.http.scaladsl.testkit.ScalatestRouteTest
import org.broadinstitute.dsde.workbench.model.ErrorReportJsonSupport._
import org.broadinstitute.dsde.workbench.model.WorkbenchIdentityJsonSupport._
import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.sam.dataAccess.{MockAccessPolicyDAO, MockDirectoryDAO}
import org.broadinstitute.dsde.workbench.sam.dataAccess.{MockAccessPolicyDAO, MockDirectoryDAO, MockRegistrationDAO}
import org.broadinstitute.dsde.workbench.sam.model.SamJsonSupport._
import org.broadinstitute.dsde.workbench.sam.model._
import org.broadinstitute.dsde.workbench.sam.service.UserService.genRandom
Expand Down Expand Up @@ -49,13 +49,13 @@ class ResourceRoutesSpec extends AnyFlatSpec with Matchers with ScalatestRouteTe
private def createSamRoutes(resourceTypes: Map[ResourceTypeName, ResourceType], userInfo: UserInfo = defaultUserInfo) = {
val accessPolicyDAO = new MockAccessPolicyDAO(resourceTypes)
val directoryDAO = new MockDirectoryDAO()
val registrationDAO = new MockDirectoryDAO()
val registrationDAO = new MockRegistrationDAO()

val emailDomain = "example.com"
val policyEvaluatorService = PolicyEvaluatorService(emailDomain, resourceTypes, accessPolicyDAO, directoryDAO)
val mockResourceService = new ResourceService(resourceTypes, policyEvaluatorService, accessPolicyDAO, directoryDAO, NoExtensions, emailDomain)
val mockUserService = new UserService(directoryDAO, NoExtensions, registrationDAO, Seq.empty)
val mockStatusService = new StatusService(directoryDAO, NoExtensions, TestSupport.dbRef)
val mockStatusService = new StatusService(directoryDAO, registrationDAO, NoExtensions, TestSupport.dbRef)
val mockManagedGroupService = new ManagedGroupService(mockResourceService, policyEvaluatorService, resourceTypes, accessPolicyDAO, directoryDAO, NoExtensions, emailDomain)

mockUserService.createUser(CreateWorkbenchUser(defaultUserInfo.userId, defaultGoogleSubjectId, defaultUserInfo.userEmail, None), samRequestContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.broadinstitute.dsde.workbench.model.ErrorReportJsonSupport._
import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.sam.TestSupport.{configResourceTypes, genGoogleSubjectId}
import org.broadinstitute.dsde.workbench.sam.api.TestSamRoutes.SamResourceActionPatterns
import org.broadinstitute.dsde.workbench.sam.dataAccess.{MockAccessPolicyDAO, MockDirectoryDAO}
import org.broadinstitute.dsde.workbench.sam.dataAccess.{MockAccessPolicyDAO, MockDirectoryDAO, MockRegistrationDAO}
import org.broadinstitute.dsde.workbench.sam.model.RootPrimitiveJsonSupport._
import org.broadinstitute.dsde.workbench.sam.model.SamJsonSupport._
import org.broadinstitute.dsde.workbench.sam.model._
Expand Down Expand Up @@ -45,7 +45,7 @@ class ResourceRoutesV2Spec extends AnyFlatSpec with Matchers with TestSupport wi
userInfo: UserInfo = defaultUserInfo): SamRoutes = {
val accessPolicyDAO = new MockAccessPolicyDAO(resourceTypes)
val directoryDAO = new MockDirectoryDAO()
val registrationDAO = new MockDirectoryDAO()
val registrationDAO = new MockRegistrationDAO()
val emailDomain = "example.com"

val policyEvaluatorService = mock[PolicyEvaluatorService](RETURNS_SMART_NULLS)
Expand All @@ -54,7 +54,7 @@ class ResourceRoutesV2Spec extends AnyFlatSpec with Matchers with TestSupport wi
when(mockResourceService.getResourceType(resourceTypeName)).thenReturn(IO(Option(resourceType)))
}
val mockUserService = new UserService(directoryDAO, NoExtensions, registrationDAO, Seq.empty)
val mockStatusService = new StatusService(directoryDAO, NoExtensions, TestSupport.dbRef)
val mockStatusService = new StatusService(directoryDAO, registrationDAO, NoExtensions, TestSupport.dbRef)
val mockManagedGroupService = new ManagedGroupService(mockResourceService, policyEvaluatorService, resourceTypes, accessPolicyDAO, directoryDAO, NoExtensions, emailDomain)

mockUserService.createUser(CreateWorkbenchUser(defaultUserInfo.userId, genGoogleSubjectId(), defaultUserInfo.userEmail, None), samRequestContext)
Expand Down
Loading

0 comments on commit 80f3066

Please sign in to comment.