Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
kostaskougios committed Dec 12, 2023
1 parent 6ea30c0 commit cda0b98
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import endtoend.tests.helidon.{TestsHelidonFunctions, TestsHelidonFunctionsCalle
import functions.fibers.FiberExecutor
import functions.helidon.transport.HelidonWsTransport
import functions.helidon.ws.ServerWsListener
import functions.helidon.ws.transport.ClientServerWsListener
import functions.helidon.ws.transport.exceptions.RemoteFunctionFailedException
import functions.model.Serializer
import functions.model.Serializer.{Avro, Json}
Expand All @@ -16,21 +17,22 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers.*

import java.net.URI
import scala.util.Using

class EndToEndHelidonWsSuite extends AnyFunSuite:
def withServer[R](f: (WebServer, CountingHelidonFunctionsImpl) => R): R =
val impl = new CountingHelidonFunctionsImpl
val invokeMap = TestsHelidonFunctionsReceiverFactory.invokerMap(impl)
val listener = new ServerWsListener(invokeMap)

val wsB = WsRouting.builder().endpoint("/ws-test", listener)
val server = WebServer.builder
.port(0)
.addRouting(wsB)
.build
.start
try f(server, impl)
finally server.stop()
FiberExecutor.withFiberExecutor: executor =>
val impl = new CountingHelidonFunctionsImpl
val invokeMap = TestsHelidonFunctionsReceiverFactory.invokerMap(impl)
Using.resource(ServerWsListener(invokeMap, executor)): listener =>
val wsB = WsRouting.builder().endpoint("/ws-test", listener)
val server = WebServer.builder
.port(0)
.addRouting(wsB)
.build
.start
try f(server, impl)
finally server.stop()

def withTransport[R](serverPort: Int, serializer: Serializer)(f: TestsHelidonFunctions => R): R =
FiberExecutor.withFiberExecutor: executor =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ class InOutMessageProtocol(invokerMap: InvokerMap, myId: Int = Random.nextInt())
case (c4, i) =>
(c4.toRawCoordinates, i)

def listener(buffer: BufferData) =
def listener(buffer: BufferData): Either[BufferData, RfWsResponse] =
buffer.read() match
case 100 =>
// a call to a function
serverListener(buffer)
Left(serverListener(buffer))
case 200 =>
// return value of a call
Right(clientListener(buffer))
case x => throw new IllegalStateException(s"invalid data received : $x")

def serverListener(buffer: BufferData): BufferData =
private def serverListener(buffer: BufferData): BufferData =
val receiverId = buffer.readInt32()
val corId = buffer.readLong()
val coordsLength = buffer.readUnsignedInt32()
Expand All @@ -37,6 +40,7 @@ class InOutMessageProtocol(invokerMap: InvokerMap, myId: Int = Random.nextInt())
val f = im(coordinates4.toRawCoordinates)
val response = f(ReceiverInput(data, arg))
val buf = BufferData.growing(response.length + 12)
buf.write(200)
buf.writeInt32(receiverId)
buf.write(0)
buf.write(longToBytes(corId))
Expand All @@ -50,13 +54,14 @@ class InOutMessageProtocol(invokerMap: InvokerMap, myId: Int = Random.nextInt())
w.close()
val data = bos.toByteArray
val buf = BufferData.growing(data.length + 12)
buf.write(200)
buf.writeInt32(receiverId)
buf.write(1)
buf.write(longToBytes(corId))
buf.write(data)
buf

def clientListener(buffer: BufferData): RfWsResponse =
private def clientListener(buffer: BufferData): RfWsResponse =
val receivedId = buffer.readInt32()
if receivedId != myId then throw new IllegalStateException(s"Received an invalid client id : $receivedId , it should be my id of $myId")
val result = buffer.read()
Expand All @@ -66,6 +71,7 @@ class InOutMessageProtocol(invokerMap: InvokerMap, myId: Int = Random.nextInt())

def clientTransport(corId: Long, data: Array[Byte], argsData: Array[Byte], coordsData: Array[Byte]): BufferData =
val buf = BufferData.growing(data.length + argsData.length + coordsData.length + 32)
buf.write(100) // a call
buf.writeInt32(myId)
buf.write(longToBytes(corId))
buf.writeUnsignedInt32(coordsData.length)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package functions.helidon.ws.transport

import functions.fibers.FiberExecutor
import ClientWsListener.PoisonPill
import ClientServerWsListener.PoisonPill
import functions.helidon.ws.InOutMessageProtocol
import functions.helidon.ws.model.RfWsResponse
import functions.helidon.ws.transport.exceptions.RemoteFunctionFailedException
Expand All @@ -13,7 +13,7 @@ import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit}
import scala.annotation.tailrec
import scala.util.Using.Releasable

class ClientWsListener(protocol: InOutMessageProtocol, fiberExecutor: FiberExecutor, sendResponseTimeoutInMillis: Long) extends WsListener:
class ClientServerWsListener(protocol: InOutMessageProtocol, fiberExecutor: FiberExecutor, sendResponseTimeoutInMillis: Long) extends WsListener:
private val toSend = new LinkedBlockingQueue[BufferData](64)
private val latchMap = collection.concurrent.TrieMap.empty[Long, CountDownLatch]
private val dataMap = collection.concurrent.TrieMap.empty[Long, (Int, Array[Byte])]
Expand All @@ -38,13 +38,18 @@ class ClientWsListener(protocol: InOutMessageProtocol, fiberExecutor: FiberExecu

override def onMessage(session: WsSession, buffer: BufferData, last: Boolean): Unit =
try
val RfWsResponse(result, corId, data) = protocol.clientListener(buffer)
latchMap.get(corId) match
case Some(latch) =>
dataMap.put(corId, (result, data))
latch.countDown()
case None =>
println(s"Correlation id missing: $corId , received data ignored.")
protocol.listener(buffer) match
case Left(out) =>
// act as a server: do a call
session.send(out, true)
case Right(RfWsResponse(result, corId, data)) =>
// act as a client: respond to a call
latchMap.get(corId) match
case Some(latch) =>
dataMap.put(corId, (result, data))
latch.countDown()
case None =>
println(s"Correlation id missing: $corId , received data ignored.")
catch case t: Throwable => t.printStackTrace()

override def onOpen(session: WsSession): Unit =
Expand All @@ -61,7 +66,7 @@ class ClientWsListener(protocol: InOutMessageProtocol, fiberExecutor: FiberExecu
toSend.put(PoisonPill)
for latch <- latchMap.values do latch.countDown()

object ClientWsListener:
object ClientServerWsListener:
private val PoisonPill = BufferData.create("*poisonpill*")

given Releasable[ClientWsListener] = _.close()
given Releasable[ClientServerWsListener] = _.close()
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package functions.helidon.transport

import functions.fibers.FiberExecutor
import functions.helidon.ws.InOutMessageProtocol
import functions.helidon.ws.transport.ClientWsListener
import functions.helidon.ws.transport.ClientServerWsListener
import functions.model.TransportInput
import io.helidon.websocket.WsListener

Expand All @@ -11,7 +11,7 @@ import scala.util.Using.Releasable

class HelidonWsTransport(fiberExecutor: FiberExecutor, sendResponseTimeoutInMillis: Long):
private val protocol = new InOutMessageProtocol(Map.empty)
private val wsListener = new ClientWsListener(protocol, fiberExecutor, sendResponseTimeoutInMillis)
private val wsListener = new ClientServerWsListener(protocol, fiberExecutor, sendResponseTimeoutInMillis)
private val correlationId = new AtomicLong(0)

def clientWsListener: WsListener = wsListener
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
package functions.helidon.ws

import functions.fibers.FiberExecutor
import functions.helidon.ws.transport.ClientServerWsListener
import functions.model.InvokerMap
import io.helidon.common.buffers.BufferData
import io.helidon.websocket.{WsListener, WsSession}

class ServerWsListener(invokerMap: InvokerMap) extends WsListener:

private val protocol = new InOutMessageProtocol(invokerMap)

override def onMessage(session: WsSession, buffer: BufferData, last: Boolean): Unit =
val out = protocol.serverListener(buffer)
session.send(out, true)
object ServerWsListener:
def apply(invokerMap: InvokerMap, fiberExecutor: FiberExecutor, sendTimeoutInMillis: Long = 4000) =
new ClientServerWsListener(new InOutMessageProtocol(invokerMap), fiberExecutor, sendTimeoutInMillis)

0 comments on commit cda0b98

Please sign in to comment.