Skip to content

Commit

Permalink
Clean up API, fix duplicate notification on child changed in watchChi…
Browse files Browse the repository at this point in the history
…ldrenWithData, add unit tests
  • Loading branch information
John Corwin committed Jul 6, 2010
1 parent 183b019 commit 896117f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 48 deletions.
2 changes: 1 addition & 1 deletion project/build.properties
Expand Up @@ -3,7 +3,7 @@
project.organization=com.twitter
project.name=zookeeper-client
scala.version=2.7.7
project.version=1.3
project.version=1.4
sbt.version=0.7.3
def.scala.version=2.7.7
build.scala.versions=2.7.7
Expand Down
2 changes: 1 addition & 1 deletion project/build/ZookeeperClientProject.scala
Expand Up @@ -11,7 +11,7 @@ class ZookeeperClientProject(info: ProjectInfo) extends StandardProject(info) {
val apache = "apache" at "http://people.apache.org/repo/m2-ibiblio-rsync-repository/"

// dependencies
val specs = "org.scala-tools.testing" % "specs" % "1.6.2"
val specs = "org.scala-tools.testing" % "specs" % "1.6.2.1" % "test"
val vscaladoc = "org.scala-tools" % "vscaladoc" % "1.1-md-3"
val markdownj = "markdownj" % "markdownj" % "1.0.2b4-0.3.0"
val slf4jApi = "org.slf4j" % "slf4j-api" % "1.5.8"
Expand Down
87 changes: 43 additions & 44 deletions src/main/scala/com/twitter/zookeeper/ZooKeeperClient.scala
Expand Up @@ -12,20 +12,16 @@ import net.lag.logging.Logger
import net.lag.configgy.ConfigMap
import java.util.concurrent.CountDownLatch

// Watch helpers
class ZKWatch(watch : WatchedEvent => Unit) extends Watcher {
override def process(event : WatchedEvent) { watch(event) }
}

object ZKWatch {
def apply(watch : WatchedEvent => Unit) = { new ZKWatch(watch) }
}

class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, watcher: Option[ZKWatch]) extends Watcher {
class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, watcher: WatchedEvent => Unit) extends Watcher {
private val log = Logger.get
private val connectionLatch = new CountDownLatch(1)
private val zk = new ZooKeeper(servers, sessionTimeout, this)
connectionLatch.await()
private val zk = connect()

private def connect() = {
val zkClient = new ZooKeeper(servers, sessionTimeout, this)
connectionLatch.await()
zkClient
}

def process(event : WatchedEvent) {
log.info("Zookeeper event: %s".format(event))
Expand All @@ -35,43 +31,21 @@ class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, w
}
case _ =>
}
watcher.map(w => w.process(event))
watcher(event)
}

def getHandle : ZooKeeper = zk

def this(servers: String, sessionTimeout: Int, basePath : String, watcher: ZKWatch) = {
this(servers, sessionTimeout, basePath, Some(watcher))
}

def this(servers: String, sessionTimeout: Int, basePath : String) = {
this(servers, sessionTimeout, basePath, None)
}

def this(config: ConfigMap, watcher: Option[ZKWatch]) = {
def this(config: ConfigMap, watcher: WatchedEvent => Unit) = {
this(config.getString("zookeeper-client.hostlist").get, // Must be set. No sensible default.
config.getInt("zookeeper-client.session-timeout", 3000),
config.getString("base-path", ""),
watcher)
}

def this(config: ConfigMap, watcher: ZKWatch) = {
this(config, Some(watcher))
}

def this(config: ConfigMap) = {
this(config, None)
}

def this(servers: String, watcher: Option[ZKWatch]) =
def this(servers: String, watcher: WatchedEvent => Unit) =
this(servers, 3000, "", watcher)

def this(servers: String, watcher: ZKWatch) =
this(servers, Some(watcher))

def this(servers: String) =
this(servers, None)

/**
* Given a string representing a path, return each subpath
* Ex. subPaths("/a/b/c", "/") == ["/a", "/a/b", "/a/b/c"]
Expand Down Expand Up @@ -119,16 +93,32 @@ class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, w
zk.getData(makeNodePath(path), false, null)
}

def set(path: String, data: Array[Byte]) {
zk.setData(makeNodePath(path), data, -1)
}

def delete(path: String) {
zk.delete(makeNodePath(path), -1)
}

/**
* Delete a node along with all of its children
*/
def deleteRecursive(path : String) {
val children = getChildren(path)
for (node <- children) {
deleteRecursive(path + '/' + node)
}
delete(path)
}

/**
* Watches a node. When the node's data is changed, onDataChanged will be called with the
* new data value as a byte array. If the node is deleted, onDataChanged will be called with
* None and will track the node's re-creation with an existence watch.
*/
def watchNode(node : String, onDataChanged : Option[Array[Byte]] => Unit) {
log.debug("Watching node %s", node)
val path = makeNodePath(node)
def updateData {
try {
Expand All @@ -140,20 +130,22 @@ class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, w
}
}
}

def deletedData {
onDataChanged(None)
if (zk.exists(path, dataGetter) != null) {
// Node was re-created by the time we called zk.exist
updateData
}
}
def dataGetter : ZKWatch = ZKWatch {
event =>
def dataGetter = new Watcher {
def process(event : WatchedEvent) {
if (event.getType == EventType.NodeDataChanged || event.getType == EventType.NodeCreated) {
updateData
} else if (event.getType == EventType.NodeDeleted) {
deletedData
}
}
}
updateData
}
Expand All @@ -165,16 +157,20 @@ class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, w
*/
def watchChildren(node : String, updateChildren : Seq[String] => Unit) {
val path = makeNodePath(node)
val childWatcher : ZKWatch =
ZKWatch {event =>
val childWatcher = new Watcher {
def process(event : WatchedEvent) {
if (event.getType == EventType.NodeChildrenChanged ||
event.getType == EventType.NodeCreated)
watchChildren(node, updateChildren)}
event.getType == EventType.NodeCreated) {
watchChildren(node, updateChildren)
}
}
}
try {
val children = zk.getChildren(path, childWatcher)
updateChildren(children)
} catch {
case e:KeeperException => {
// Node was deleted -- fire a watch on node re-creation
log.warning("Failed to read node %s: %s", path, e)
updateChildren(List())
zk.exists(path, childWatcher)
Expand Down Expand Up @@ -221,14 +217,17 @@ class ZooKeeperClient(servers: String, sessionTimeout: Int, basePath : String, w
watchMap.synchronized {
// remove deleted children from the watch map
for (child <- removedChildren) {
log.ifDebug {"Node %s: child %s removed".format(node, child)}
watchMap -= child
}
// add new children to the watch map
for (child <- addedChildren) {
// node is added via nodeChanged callback
log.ifDebug {"Node %s: child %s added".format(node, child)}
watchNode("%s/%s".format(node, child), nodeChanged(child))
}
}
for (child <- removedChildren ++ addedChildren) {
for (child <- removedChildren) {
notifier.map(f => f(child))
}
}
Expand Down
79 changes: 77 additions & 2 deletions src/test/scala/com/twitter/zookeeper/ZooKeeperClientSpec.scala
Expand Up @@ -8,14 +8,14 @@ import org.apache.zookeeper.KeeperException.NoNodeException
import org.apache.zookeeper.data.{ACL, Id}
import org.specs._
import net.lag.configgy.Configgy
import scala.collection.mutable

class ZookeeperClientSpec extends Specification {
"ZookeeperClient" should {
Configgy.configure("src/main/resources/config.conf")

val watcher = ZKWatch((a: WatchedEvent) => {})
val configMap = Configgy.config
val zkClient = new ZooKeeperClient(configMap, watcher)
val zkClient = new ZooKeeperClient(configMap, (event : WatchedEvent) => {})

doBefore {
// we need to be sure that a ZooKeeper server is running in order to test
Expand Down Expand Up @@ -57,5 +57,80 @@ class ZookeeperClientSpec extends Specification {
zkClient.create("/foo", data, createMode) mustEqual "/foo"
zkClient.delete("/foo")
}

"watch a node" in {
val data: Array[Byte] = Array(0x63)
val node = "/datanode"
val createMode = EPHEMERAL
var watchCount = 0
def watcher(data : Option[Array[Byte]]) {
watchCount += 1
}
zkClient.create(node, data, createMode)
zkClient.watchNode(node, watcher)
Thread.sleep(50L)
watchCount mustEqual 1
zkClient.delete("/datanode")
}

"watch a tree of nodes" in {
var children : Seq[String] = List()
var watchCount = 0
def watcher(nodes : Seq[String]) {
watchCount += 1
children = nodes
}
zkClient.createPath("/tree/a")
zkClient.createPath("/tree/b")
zkClient.watchChildren("/tree", watcher)
children.size mustEqual 2
children must containAll(List("a", "b"))
watchCount mustEqual 1
zkClient.createPath("/tree/c")
Thread.sleep(50L)
children.size mustEqual 3
children must containAll(List("a", "b", "c"))
watchCount mustEqual 2
zkClient.delete("/tree/a")
Thread.sleep(50L)
children.size mustEqual 2
children must containAll(List("b", "c"))
watchCount mustEqual 3
zkClient.deleteRecursive("/tree")
}

"watch a tree of nodes with data" in {
def mkNode(node : String) {
zkClient.create("/root/" + node, node.getBytes, CreateMode.EPHEMERAL)
}
var children : mutable.Map[String,String] = mutable.Map()
var watchCount = 0
def notifier(child : String) {
watchCount += 1
if (children.contains(child)) {
children(child) mustEqual child
}
}
zkClient.createPath("/root")
mkNode("a")
mkNode("b")
zkClient.watchChildrenWithData("/root", children,
{(b : Array[Byte]) => new String(b)}, notifier)
children.size mustEqual 2
children.keySet must containAll(List("a", "b"))
watchCount mustEqual 2
mkNode("c")
Thread.sleep(50L)
children.size mustEqual 3
children.keySet must containAll(List("a", "b", "c"))
watchCount mustEqual 3
zkClient.delete("/root/a")
Thread.sleep(50L)
children.size mustEqual 2
children.keySet must containAll(List("b", "c"))
watchCount mustEqual 4
zkClient.deleteRecursive("/root")
}

}
}

0 comments on commit 896117f

Please sign in to comment.