
Move on fs2
pomadchin committed May 1, 2018
1 parent 45d352f commit ff71b95
Showing 14 changed files with 178 additions and 194 deletions.
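
The diff replaces the scalaz-stream machinery (Process, Task, channel, nondeterminism.njoin) with fs2 0.10 and cats-effect IO across the Accumulo, Cassandra, and S3 readers and writers. A minimal, self-contained sketch of the translation pattern, assuming fs2-core 0.10.x and cats-effect on the classpath; the names Fs2MigrationSketch, runAll, inputs and work are illustrative, not code from this commit:

import cats.effect.IO
import java.util.concurrent.Executors
import scala.concurrent.ExecutionContext

object Fs2MigrationSketch {
  // scalaz-stream -> fs2 0.10 correspondence used throughout the commit:
  //   Process.unfold(iter)(step)                      -> fs2.Stream.unfold(iter)(step)
  //   channel.lift { a => Task(work(a))(pool) }       -> { a: A => fs2.Stream.eval(IO(work(a))) }
  //   nondeterminism.njoin(maxOpen, maxQueued)(inner) -> inner.join(maxOpen)
  //   stream.run.unsafePerformSync                    -> stream.compile.toVector.unsafeRunSync()
  def runAll[A, B](inputs: Iterator[A], threads: Int)(work: A => B): Vector[B] = {
    val pool = Executors.newFixedThreadPool(threads)
    // fs2 0.10's `join` needs an implicit ExecutionContext; Effect[IO] comes from cats-effect.
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

    // Unfold the iterator into a stream, then wrap each element's work in a one-element inner stream.
    val sources: fs2.Stream[IO, A] = fs2.Stream.unfold(inputs) { iter =>
      if (iter.hasNext) Some((iter.next(), iter)) else None
    }
    val results = (sources map { a => fs2.Stream.eval(IO(work(a))) }).join(threads)

    // Run up to `threads` inner streams concurrently and collect everything.
    try results.compile.toVector.unsafeRunSync()
    finally pool.shutdown()
  }
}
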
@@ -21,17 +21,17 @@ import geotrellis.spark.io.avro.codecs.KeyValueRecordCodec
import geotrellis.spark.io.avro.{AvroEncoder, AvroRecordCodec}
import geotrellis.spark.{Boundable, KeyBounds}

import scalaz.std.vector._
import scalaz.concurrent.{Strategy, Task}
import scalaz.stream.{Process, channel, nondeterminism, tee}
import org.apache.accumulo.core.data.{Range => AccumuloRange}
import org.apache.accumulo.core.security.Authorizations
import org.apache.avro.Schema
import org.apache.hadoop.io.Text
import com.typesafe.config.ConfigFactory
import cats.effect.IO

import scala.concurrent.ExecutionContext
import scala.collection.JavaConversions._
import scala.reflect.ClassTag

import java.util.concurrent.Executors

object AccumuloCollectionReader {
@@ -52,34 +52,37 @@ object AccumuloCollectionReader {
val ranges = queryKeyBounds.flatMap(decomposeBounds).toIterator

val pool = Executors.newFixedThreadPool(threads)
implicit val ec = ExecutionContext.fromExecutor(pool)

val range: Process[Task, AccumuloRange] = Process.unfold(ranges) { iter =>
val range: fs2.Stream[IO, AccumuloRange] = fs2.Stream.unfold(ranges) { iter =>
if (iter.hasNext) Some(iter.next(), iter)
else None
}

val readChannel = channel.lift { (range: AccumuloRange) => Task {
val read = { (range: AccumuloRange) => fs2.Stream eval IO {
val scanner = instance.connector.createScanner(table, new Authorizations())
scanner.setRange(range)
scanner.fetchColumnFamily(columnFamily)
val result = scanner.iterator
.map({ case entry =>
AvroEncoder.fromBinary(writerSchema.getOrElse(codec.schema), entry.getValue.get)(codec) })
.flatMap({ pairs: Vector[(K, V)] =>
if(filterIndexOnly) pairs
else pairs.filter { pair => includeKey(pair._1) } })
.toVector
scanner.close()
result
}(pool) }

val read = range.tee(readChannel)(tee.zipApply).map(Process.eval)

val result =
scanner
.iterator
.map({ entry => AvroEncoder.fromBinary(writerSchema.getOrElse(codec.schema), entry.getValue.get)(codec) })
.flatMap({ pairs: Vector[(K, V)] =>
if (filterIndexOnly) pairs
else pairs.filter { pair => includeKey(pair._1) }
}).toVector
scanner.close()
result
}
}

try {
nondeterminism
.njoin(maxOpen = threads, maxQueued = threads) { read }(Strategy.Executor(pool))
.runFoldMap(identity).unsafePerformSync: Seq[(K, V)]
(range map read)
.join(threads)
.compile
.toVector
.map(_.flatten)
.unsafeRunSync
} finally pool.shutdown()
}
}
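
Each inner stream above emits the whole Vector produced by one scanner, so compile.toVector yields a Vector of Vectors (one per Accumulo range) and the trailing .map(_.flatten) collapses it into the flat Seq[(K, V)] the reader returns. A tiny stand-alone illustration of that shape; the FlattenSketch object and its values are made up:

import cats.effect.IO

object FlattenSketch extends App {
  // One Vector per "range", mirroring the per-scanner results above.
  val perRange: fs2.Stream[IO, Vector[Int]] = fs2.Stream(Vector(1, 2), Vector(3)).covary[IO]
  // compile.toVector => Vector(Vector(1, 2), Vector(3)); flatten => Vector(1, 2, 3)
  println(perRange.compile.toVector.map(_.flatten).unsafeRunSync())
}
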
@@ -27,9 +27,9 @@ import org.apache.accumulo.core.data.{Key, Mutation, Value}
import org.apache.accumulo.core.client.mapreduce.AccumuloFileOutputFormat
import org.apache.accumulo.core.client.BatchWriterConfig
import com.typesafe.config.ConfigFactory
import scalaz.concurrent.{Strategy, Task}
import scalaz.stream._
import cats.effect.IO

import scala.concurrent.ExecutionContext
import java.util.UUID
import java.util.concurrent.Executors

@@ -119,13 +119,13 @@ case class SocketWriteStrategy(
if(partition.nonEmpty) {
val poolSize = kwThreads.value
val pool = Executors.newFixedThreadPool(poolSize)
implicit val ec = ExecutionContext.fromExecutor(pool)
val config = serializeWrapper.value
val writer = instance.connector.createBatchWriter(table, config)

try {

val mutations: Process[Task, Mutation] =
Process.unfold(partition){ iter =>
val mutations: fs2.Stream[IO, Mutation] =
fs2.Stream.unfold(partition){ iter =>
if (iter.hasNext) {
val (key, value) = iter.next()
val mutation = new Mutation(key.getRow)
@@ -136,16 +136,9 @@
}
}

val writeChannel =
channel.lift { (mutation: Mutation) =>
Task { writer.addMutation(mutation) } (pool)
}

val writes = mutations.tee(writeChannel)(tee.zipApply).map(Process.eval)
val write = { (mutation: Mutation) => fs2.Stream eval IO { writer.addMutation(mutation) } }

val joined =
nondeterminism.njoin(maxOpen = poolSize, maxQueued = poolSize)(writes)(Strategy.Executor(pool))
joined.run.unsafePerformSync
(mutations map write).join(threads).compile.toVector.unsafeRunSync()
} finally {
writer.close(); pool.shutdown()
}
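
The joined write stream above only produces Unit values, so the Vector collected by compile.toVector is discarded. With the same fs2 0.10 API the stream could equally be drained; a hedged one-line alternative for the final call, reusing the mutations and write names from the diff:

// Run the writes for their effects only, without buffering the Unit results.
(mutations map write).join(threads).compile.drain.unsafeRunSync()
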
@@ -27,22 +27,17 @@ import com.datastax.driver.core.querybuilder.QueryBuilder
import com.datastax.driver.core.querybuilder.QueryBuilder.{eq => eqs}
import com.datastax.driver.core.ResultSet
import com.datastax.driver.core.schemabuilder.SchemaBuilder

import cats.effect.IO
import org.apache.avro.Schema
import org.apache.spark.rdd.RDD

import com.typesafe.config.ConfigFactory

import scalaz.concurrent.{Strategy, Task}
import scalaz.stream.{Process, nondeterminism}

import java.nio.ByteBuffer
import java.util.concurrent.Executors

import scala.collection.JavaConversions._

import java.math.BigInteger

import scala.collection.JavaConversions._
import scala.concurrent.ExecutionContext

object CassandraRDDWriter {
final val DefaultThreadCount =
@@ -114,18 +114,19 @@ object CassandraRDDWriter {
val readStatement = session.prepare(readQuery)
val writeStatement = session.prepare(writeQuery)

val rows: Process[Task, (BigInt, Vector[(K,V)])] =
Process.unfold(partition)({ iter =>
val rows: fs2.Stream[IO, (BigInt, Vector[(K,V)])] =
fs2.Stream.unfold(partition)({ iter =>
if (iter.hasNext) {
val record = iter.next()
Some((record._1, record._2.toVector), iter)
} else None
})

val pool = Executors.newFixedThreadPool(threads)
implicit val ec = ExecutionContext.fromExecutor(pool)

def elaborateRow(row: (BigInt, Vector[(K,V)])): Process[Task, (BigInt, Vector[(K,V)])] = {
Process eval Task ({
def elaborateRow(row: (BigInt, Vector[(K,V)])): fs2.Stream[IO, (BigInt, Vector[(K,V)])] = {
fs2.Stream eval IO ({
val (key, kvs1) = row
val kvs2 =
if (mergeFunc.nonEmpty) {
@@ -138,43 +134,43 @@ object CassandraRDDWriter {
val kvs = mergeFunc match {
case Some(fn) =>
(kvs2 ++ kvs1)
.groupBy({ case (k,v) => k })
.groupBy({ case (k, v) => k })
.map({ case (k, kvs) =>
val vs = kvs.map({ case (k,v) => v }).toSeq
val vs = kvs.map({ case (_, v) => v })
val v: V = vs.tail.foldLeft(vs.head)(fn)
(k, v) })
.toVector
case None => kvs1
}
(key, kvs)
})(pool)
})
}

def rowToBytes(row: (BigInt, Vector[(K,V)])): Process[Task, (BigInt, ByteBuffer)] = {
Process eval Task({
def rowToBytes(row: (BigInt, Vector[(K,V)])): fs2.Stream[IO, (BigInt, ByteBuffer)] = {
fs2.Stream eval IO ({
val (key, kvs) = row
val bytes = ByteBuffer.wrap(AvroEncoder.toBinary(kvs)(codec))
(key, bytes)
})(pool)
})
}

def retire(row: (BigInt, ByteBuffer)): Process[Task, ResultSet] = {
def retire(row: (BigInt, ByteBuffer)): fs2.Stream[IO, ResultSet] = {
val (id, value) = row
Process eval Task({
fs2.Stream eval IO ({
session.execute(writeStatement.bind(id: BigInteger, value))
})(pool)
})
}

val results = nondeterminism.njoin(maxOpen = threads, maxQueued = threads) {
rows flatMap elaborateRow flatMap rowToBytes map retire
}(Strategy.Executor(pool)) onComplete {
Process eval Task {
session.closeAsync()
session.getCluster.closeAsync()
}(pool)
}
val results = (rows flatMap elaborateRow flatMap rowToBytes map retire)
.join(threads)
.onComplete {
fs2.Stream eval IO {
session.closeAsync()
session.getCluster.closeAsync()
}
}

results.run.unsafePerformSync
results.compile.toVector.unsafeRunSync()
pool.shutdown()
}
}
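
The onComplete call above mirrors the scalaz-stream version it replaces: in fs2 0.10 it appends the cleanup stream after the main stream terminates, whether it finished normally or failed, and re-raises any error only after the cleanup has run. A small stand-alone sketch of that behaviour, with hypothetical writes and cleanup streams:

import cats.effect.IO

object OnCompleteSketch extends App {
  val writes: fs2.Stream[IO, Int]  = fs2.Stream.eval(IO { println("write"); 1 })
  val cleanup: fs2.Stream[IO, Int] = fs2.Stream.eval(IO { println("close session"); 0 })

  // Cleanup runs after `writes` ends, normally or with an error.
  println(writes.onComplete(cleanup).compile.toVector.unsafeRunSync()) // Vector(1, 0)
}
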
3 changes: 2 additions & 1 deletion project/Dependencies.scala
@@ -40,7 +40,8 @@ object Dependencies {

val awsSdkS3 = "com.amazonaws" % "aws-java-sdk-s3" % "1.11.143"

val scalazStream = "org.scalaz.stream" %% "scalaz-stream" % "0.8.6a"
val fs2Core = "co.fs2" %% "fs2-core" % "0.10.3"
val fs2Io = "co.fs2" %% "fs2-io" % "0.10.3"

val sparkCore = "org.apache.spark" %% "spark-core" % Version.spark
val hadoopClient = "org.apache.hadoop" % "hadoop-client" % Version.hadoop
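
For a build that consumes these modules directly, the swap amounts to replacing the scalaz-stream coordinate with the two fs2 ones; a hypothetical build.sbt fragment (GeoTrellis itself wires them through the Dependencies object above):

// build.sbt (illustrative): scalaz-stream is dropped in favour of the two fs2 modules.
libraryDependencies ++= Seq(
  "co.fs2" %% "fs2-core" % "0.10.3",
  "co.fs2" %% "fs2-io"   % "0.10.3"
)
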
11 changes: 1 addition & 10 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3RDDReader.scala
@@ -23,22 +23,13 @@ import geotrellis.spark.io.index.{IndexRanges, MergeQueue}
import geotrellis.spark.io.avro.{AvroEncoder, AvroRecordCodec}
import geotrellis.spark.util.KryoWrapper

import scalaz.concurrent.{Strategy, Task}
import scalaz.std.vector._
import scalaz.stream.{Process, nondeterminism}

import com.typesafe.config.ConfigFactory
import com.amazonaws.services.s3.model.AmazonS3Exception

import org.apache.avro.Schema
import org.apache.commons.io.IOUtils
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import com.typesafe.config.ConfigFactory

import java.util.concurrent.Executors


trait S3RDDReader {
final val DefaultThreadCount =
ConfigFactory.load().getThreads("geotrellis.s3.threads.rdd.read")
44 changes: 22 additions & 22 deletions s3/src/main/scala/geotrellis/spark/io/s3/S3RDDWriter.scala
@@ -21,20 +21,17 @@ import geotrellis.spark.io.avro._
import geotrellis.spark.io.avro.codecs.KeyValueRecordCodec
import geotrellis.spark.util.KryoWrapper

import cats.effect.IO
import com.amazonaws.services.s3.model.{AmazonS3Exception, ObjectMetadata, PutObjectRequest, PutObjectResult}

import org.apache.avro.Schema
import org.apache.commons.io.IOUtils
import org.apache.spark.rdd.RDD

import com.typesafe.config.ConfigFactory

import scalaz.concurrent.{Strategy, Task}
import scalaz.stream.{Process, nondeterminism}

import java.io.ByteArrayInputStream
import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext
import scala.reflect._


@@ -88,9 +85,10 @@ trait S3RDDWriter {
val schema = kwWriterSchema.value.getOrElse(_recordCodec.schema)

val pool = Executors.newFixedThreadPool(threads)
implicit val ec = ExecutionContext.fromExecutor(pool)

val rows: Process[Task, (String, Vector[(K,V)])] =
Process.unfold(partition)({ iter =>
val rows: fs2.Stream[IO, (String, Vector[(K,V)])] =
fs2.Stream.unfold(partition)({ iter =>
if (iter.hasNext) {
val record = iter.next()
val key = record._1
Expand All @@ -99,8 +97,8 @@ trait S3RDDWriter {
} else None
})

def elaborateRow(row: (String, Vector[(K,V)])): Process[Task, (String, Vector[(K,V)])] = {
Process eval Task({
def elaborateRow(row: (String, Vector[(K,V)])): fs2.Stream[IO, (String, Vector[(K,V)])] = {
fs2.Stream eval IO ({
val (key, kvs1) = row
val kvs2: Vector[(K,V)] =
if (mergeFunc.nonEmpty) {
Expand All @@ -115,43 +113,45 @@ trait S3RDDWriter {
mergeFunc match {
case Some(fn) =>
(kvs2 ++ kvs1)
.groupBy({ case (k,v) => k })
.groupBy({ case (k, _) => k })
.map({ case (k, kvs) =>
val vs = kvs.map({ case (k,v) => v }).toSeq
val vs = kvs.map({ case (_, v) => v })
val v: V = vs.tail.foldLeft(vs.head)(fn)
(k, v) })
.toVector
case None => kvs1
}
(key, kvs)
})(pool)
})
}

def rowToRequest(row: (String, Vector[(K,V)])): Process[Task, PutObjectRequest] = {
Process eval Task({
def rowToRequest(row: (String, Vector[(K,V)])): fs2.Stream[IO, PutObjectRequest] = {
fs2.Stream eval IO ({
val (key, kvs) = row
val bytes = AvroEncoder.toBinary(kvs)(_codec)
val metadata = new ObjectMetadata()
metadata.setContentLength(bytes.length)
val is = new ByteArrayInputStream(bytes)
putObjectModifier(new PutObjectRequest(bucket, key, is, metadata))
})(pool)
})
}

def retire(request: PutObjectRequest): Process[Task, PutObjectResult] = {
Process eval Task({
def retire(request: PutObjectRequest): fs2.Stream[IO, PutObjectResult] = {
fs2.Stream eval IO ({
request.getInputStream.reset() // reset in case of retransmission to avoid 400 error
s3client.putObject(request)
})(pool).retryEBO {
}).retryEBO {
case e: AmazonS3Exception if e.getStatusCode == 503 => true
case _ => false
}
}

val results = nondeterminism.njoin(maxOpen = threads, maxQueued = threads) {
rows flatMap elaborateRow flatMap rowToRequest map retire
}(Strategy.Executor(pool))
results.run.unsafePerformSync
(rows flatMap elaborateRow flatMap rowToRequest map retire)
.join(threads)
.compile
.toVector
.unsafeRunSync()

pool.shutdown()
}
}
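
retryEBO above is GeoTrellis' own exponential-backoff helper and is untouched by this commit; its implementation is not shown here. For readers unfamiliar with the idea, a generic sketch of exponential backoff over a cats-effect IO action (a hypothetical retryWithBackoff, not the GeoTrellis implementation):

import cats.effect.IO

object RetrySketch {
  // Retry `action` while `retriable` accepts the failure, doubling the delay each time.
  def retryWithBackoff[A](action: IO[A], delayMillis: Long, maxRetries: Int)
                         (retriable: Throwable => Boolean): IO[A] =
    action.attempt.flatMap {
      case Right(a) => IO.pure(a)
      case Left(err) if maxRetries > 0 && retriable(err) =>
        // A blocking sleep keeps the sketch dependency-free; a real implementation
        // would schedule the retry instead of blocking a thread.
        IO(Thread.sleep(delayMillis)).flatMap { _ =>
          retryWithBackoff(action, delayMillis * 2, maxRetries - 1)(retriable)
        }
      case Left(err) => IO.raiseError(err)
    }
}

Something like retryWithBackoff(IO(s3client.putObject(request)), 100L, 5) { case e: AmazonS3Exception => e.getStatusCode == 503; case _ => false } would then express the 503-only retry that retire delegates to retryEBO.
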
