From aba6619273321d20c6d043996a4bdcf32d0f8916 Mon Sep 17 00:00:00 2001 From: Gia Thuan Lam Date: Sat, 30 Aug 2025 01:18:01 +0000 Subject: [PATCH 1/4] Redesign project formatting. --- .bazelignore | 1 + .github/workflows/bazel.yml | 4 +++ BUILD.bazel | 34 ++++++++++++++++++- MODULE.bazel | 19 +++++++++++ REPO.bazel | 2 +- examples/build.gradle.kts | 13 ------- examples/gradle/libs.versions.toml | 1 - .../gradle/wrapper/gradle-wrapper.properties | 2 +- formatter/extensions.bzl => extensions.bzl | 0 formatter/.bazelversion | 1 - formatter/BUILD.bazel | 30 ---------------- formatter/MODULE.bazel | 23 ------------- 12 files changed, 59 insertions(+), 71 deletions(-) create mode 100644 .bazelignore rename formatter/extensions.bzl => extensions.bzl (100%) delete mode 100644 formatter/.bazelversion delete mode 100644 formatter/BUILD.bazel delete mode 100644 formatter/MODULE.bazel diff --git a/.bazelignore b/.bazelignore new file mode 100644 index 00000000..fde6025a --- /dev/null +++ b/.bazelignore @@ -0,0 +1 @@ +bzl-examples \ No newline at end of file diff --git a/.github/workflows/bazel.yml b/.github/workflows/bazel.yml index 877e3122..4d298d2c 100644 --- a/.github/workflows/bazel.yml +++ b/.github/workflows/bazel.yml @@ -26,6 +26,10 @@ jobs: mkdir -p "${GITHUB_WORKSPACE}/bin/" mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel" chmod +x "${GITHUB_WORKSPACE}/bin/bazel" + - name: Lint - Please run `bazelisk run //:format` + run: | + cd "${GITHUB_WORKSPACE}" + "${GITHUB_WORKSPACE}/bin/bazel" run //:format.check - name: Build uses: nick-invision/retry@v3 with: diff --git a/BUILD.bazel b/BUILD.bazel index c8682c27..c03f42c2 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1 +1,33 @@ -# Deliberately empty +load("@aspect_rules_lint//format:defs.bzl", "format_multirun") +load("@rules_java//java:defs.bzl", "java_binary") + +java_binary( + name = "ktfmt", + main_class = "com.facebook.ktfmt.cli.Main", + tags = ["manual"], + runtime_deps = ["@ktfmt//jar"], +) + +java_binary( + name = "java-format", + jvm_flags = [ + "--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED", + "--add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED", + "--add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED", + "--add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED", + "--add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED", + ], + main_class = "com.google.googlejavaformat.java.Main", + tags = ["manual"], + runtime_deps = ["@google-java-format//jar"], +) + +format_multirun( + name = "format", + java = ":java-format", + kotlin = ":ktfmt", + protocol_buffer = "@rules_buf_toolchains//:buf", + starlark = "@buildifier_prebuilt//:buildifier", + tags = ["manual"], + yaml = "@aspect_rules_lint//format:yamlfmt", +) diff --git a/MODULE.bazel b/MODULE.bazel index 0879f774..e357cf91 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -11,6 +11,19 @@ bazel_dep(name = "rules_java", version = "8.15.1") bazel_dep(name = "rules_jvm_external", version = "6.8") bazel_dep(name = "grpc-java", version = "1.71.0") +# Formatter. +bazel_dep(name = "aspect_rules_lint", version = "1.6.0", dev_dependency = True) +bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2", dev_dependency = True) +bazel_dep(name = "rules_buf", version = "0.5.2", dev_dependency = True) + +http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") + +http_jar( + name = "google-java-format", + sha256 = "33068bbbdce1099982ec1171f5e202898eb35f2919cf486141e439fc6e3a4203", + url = "https://github.com/google/google-java-format/releases/download/v1.17.0/google-java-format-1.17.0-all-deps.jar", +) + grpc_kotlin_maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven") grpc_kotlin_maven.install( name = "grpc_kotlin_maven", @@ -34,3 +47,9 @@ grpc_kotlin_maven.install( strict_visibility = True, ) use_repo(grpc_kotlin_maven, "grpc_kotlin_maven") + +install_ktfmt = use_extension("//:extensions.bzl", "install_ktfmt", dev_dependency = True) +use_repo(install_ktfmt, "ktfmt") + +buf = use_extension("@rules_buf//buf:extensions.bzl", "buf", dev_dependency = True) +use_repo(buf, "rules_buf_toolchains") diff --git a/REPO.bazel b/REPO.bazel index e7e86d5e..cd364214 100644 --- a/REPO.bazel +++ b/REPO.bazel @@ -1 +1 @@ -ignore_directories(["bzl-examples", "formatter", "**/bin"]) +ignore_directories(["bzl-examples", "**/bin"]) diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index f879e496..464ef492 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -4,19 +4,6 @@ plugins { alias(libs.plugins.protobuf) apply false alias(libs.plugins.kotlin.jvm) apply false alias(libs.plugins.kotlin.android) apply false - alias(libs.plugins.ktlint) apply false -} - -subprojects { - apply(plugin = "org.jlleitschuh.gradle.ktlint") - - configure { - filter { - exclude { - it.file.path.startsWith(project.layout.buildDirectory.get().dir("generated").toString()) - } - } - } } tasks.create("assemble").dependsOn(":server:installDist") diff --git a/examples/gradle/libs.versions.toml b/examples/gradle/libs.versions.toml index 3bc02112..bf890017 100644 --- a/examples/gradle/libs.versions.toml +++ b/examples/gradle/libs.versions.toml @@ -4,7 +4,6 @@ android-library = { id = "com.android.library", version = "8.3.0" } kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version = "2.1.20" } kotlin-android = { id = "org.jetbrains.kotlin.android", version = "2.1.20" } protobuf = { id = "com.google.protobuf", version = "0.9.5" } -ktlint = { id = "org.jlleitschuh.gradle.ktlint", version = "12.1.0" } palantir-graal = { id = "com.palantir.graal", version = "0.12.0" } jib = { id = "com.google.cloud.tools.jib", version = "3.4.1" } diff --git a/examples/gradle/wrapper/gradle-wrapper.properties b/examples/gradle/wrapper/gradle-wrapper.properties index a80b22ce..2a84e188 100644 --- a/examples/gradle/wrapper/gradle-wrapper.properties +++ b/examples/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-9.0.0-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/formatter/extensions.bzl b/extensions.bzl similarity index 100% rename from formatter/extensions.bzl rename to extensions.bzl diff --git a/formatter/.bazelversion b/formatter/.bazelversion deleted file mode 100644 index 56b6be4e..00000000 --- a/formatter/.bazelversion +++ /dev/null @@ -1 +0,0 @@ -8.3.1 diff --git a/formatter/BUILD.bazel b/formatter/BUILD.bazel deleted file mode 100644 index fb675def..00000000 --- a/formatter/BUILD.bazel +++ /dev/null @@ -1,30 +0,0 @@ -load("@aspect_rules_lint//format:defs.bzl", "format_multirun") -load("@rules_java//java:defs.bzl", "java_binary") - -java_binary( - name = "ktfmt", - main_class = "com.facebook.ktfmt.cli.Main", - runtime_deps = ["@ktfmt//jar"], -) - -java_binary( - name = "java-format", - jvm_flags = [ - "--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED", - "--add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED", - "--add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED", - "--add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED", - "--add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED", - ], - main_class = "com.google.googlejavaformat.java.Main", - runtime_deps = ["@google-java-format//jar"], -) - -format_multirun( - name = "format", - java = ":java-format", - kotlin = ":ktfmt", - protocol_buffer = "@rules_buf_toolchains//:buf", - starlark = "@buildifier_prebuilt//:buildifier", - yaml = "@aspect_rules_lint//format:yamlfmt", -) diff --git a/formatter/MODULE.bazel b/formatter/MODULE.bazel deleted file mode 100644 index c8aad7d1..00000000 --- a/formatter/MODULE.bazel +++ /dev/null @@ -1,23 +0,0 @@ -module( - name = "formatter", - version = "1.0", -) - -bazel_dep(name = "aspect_rules_lint", version = "1.6.0") -bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2") -bazel_dep(name = "rules_buf", version = "0.5.2") -bazel_dep(name = "rules_java", version = "8.15.1") - -http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") - -http_jar( - name = "google-java-format", - sha256 = "33068bbbdce1099982ec1171f5e202898eb35f2919cf486141e439fc6e3a4203", - url = "https://github.com/google/google-java-format/releases/download/v1.17.0/google-java-format-1.17.0-all-deps.jar", -) - -install_ktfmt = use_extension("//:extensions.bzl", "install_ktfmt") -use_repo(install_ktfmt, "ktfmt") - -buf = use_extension("@rules_buf//buf:extensions.bzl", "buf") -use_repo(buf, "rules_buf_toolchains") From ff328a9431abb3666139dd90b47f9d5ec74d07c4 Mon Sep 17 00:00:00 2001 From: Gia Thuan Lam Date: Sat, 30 Aug 2025 16:17:50 +0000 Subject: [PATCH 2/4] Create a Google-style ktfmt wrapper. --- BUILD.bazel | 10 +++++++++- MODULE.bazel | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/BUILD.bazel b/BUILD.bazel index c03f42c2..7783b91c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,5 +1,6 @@ load("@aspect_rules_lint//format:defs.bzl", "format_multirun") load("@rules_java//java:defs.bzl", "java_binary") +load("@rules_multirun//:defs.bzl", "command") java_binary( name = "ktfmt", @@ -8,6 +9,13 @@ java_binary( runtime_deps = ["@ktfmt//jar"], ) +command( + name = "ktfmt_wrapper", + arguments = ["--google-style"], + command = ":ktfmt", + tags = ["manual"], +) + java_binary( name = "java-format", jvm_flags = [ @@ -25,7 +33,7 @@ java_binary( format_multirun( name = "format", java = ":java-format", - kotlin = ":ktfmt", + kotlin = ":ktfmt_wrapper", protocol_buffer = "@rules_buf_toolchains//:buf", starlark = "@buildifier_prebuilt//:buildifier", tags = ["manual"], diff --git a/MODULE.bazel b/MODULE.bazel index e357cf91..911b0da7 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -15,6 +15,7 @@ bazel_dep(name = "grpc-java", version = "1.71.0") bazel_dep(name = "aspect_rules_lint", version = "1.6.0", dev_dependency = True) bazel_dep(name = "buildifier_prebuilt", version = "8.2.0.2", dev_dependency = True) bazel_dep(name = "rules_buf", version = "0.5.2", dev_dependency = True) +bazel_dep(name = "rules_multirun", version = "0.13.0", dev_dependency = True) http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") From 35a1f200bbb4145ab2e14df8c7ff66668f32e217 Mon Sep 17 00:00:00 2001 From: Gia Thuan Lam Date: Sat, 30 Aug 2025 05:01:48 +0000 Subject: [PATCH 3/4] Update contributing doc. --- CONTRIBUTING.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 468fe6e6..4189f928 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,9 +34,8 @@ How to get your contributions merged smoothly and quickly. - Provide a good **PR description** as a record of **what** change is being made and **why** it was made. Link to a GitHub issue if it exists. -- Don't fix code style and formatting unless you are already changing that line - to address an issue. PRs with irrelevant changes won't be merged. If you do - want to fix formatting or style, do that in a separate PR. +- Run `bazelisk run //:format` to fix code style and formatting before committing. + PRs with invalid format won't be merged. If you do want to support formatting for a new language, do that in a separate commit before adding another commit that reformats all relevant files. - If you are adding a new file, make sure it has the copyright message template at the top as a comment. You can copy over the message from an existing file From 35864b1df681b98c66e0e5138700797f7b6cdd6a Mon Sep 17 00:00:00 2001 From: Gia Thuan Lam Date: Sat, 30 Aug 2025 16:48:25 +0000 Subject: [PATCH 4/4] Reformat grpc-kotlin files. --- .github/dependabot.yml | 2 - .github/workflows/bazel.yml | 88 +- .github/workflows/gradle.yml | 46 +- .github/workflows/release.yml | 6 +- MODULE.bazel | 4 +- build.gradle.kts | 231 ++-- .../io/grpc/examples/bzlmod/FibonacciTest.kt | 28 +- compiler/build.gradle.kts | 130 +- .../java/io/grpc/kotlin/generator/BUILD.bazel | 2 +- .../grpc/kotlin/generator/GeneratorRunner.kt | 27 +- .../generator/GrpcClientStubGenerator.kt | 224 ++-- .../generator/GrpcCoroutineServerGenerator.kt | 206 ++-- .../generator/ProtoFileCodeGenerator.kt | 59 +- .../kotlin/generator/ServiceCodeGenerator.kt | 23 +- .../kotlin/generator/ServiceNameGenerator.kt | 2 +- .../generator/TopLevelConstantsGenerator.kt | 25 +- .../protoc/AbstractGeneratorRunner.kt | 42 +- .../grpc/kotlin/generator/protoc/BUILD.bazel | 2 +- .../generator/protoc/ClassSimpleName.kt | 14 +- .../kotlin/generator/protoc/CodeGenerators.kt | 27 +- .../kotlin/generator/protoc/ConstantName.kt | 3 +- .../kotlin/generator/protoc/Declarations.kt | 3 +- .../kotlin/generator/protoc/DescriptorUtil.kt | 74 +- .../generator/protoc/GeneratorConfig.kt | 6 +- .../generator/protoc/JavaPackagePolicy.kt | 4 +- .../generator/protoc/MemberSimpleName.kt | 6 +- .../kotlin/generator/protoc/ProtoFieldName.kt | 24 +- .../generator/protoc/ProtoMethodName.kt | 4 +- .../io/grpc/kotlin/generator/protoc/Scope.kt | 20 +- .../protoc/testing/DeclarationsSubject.kt | 22 +- .../protoc/testing/FileSpecSubject.kt | 6 +- .../protoc/testing/FunSpecSubject.kt | 6 +- .../protoc/testing/TypeSpecSubject.kt | 6 +- .../protoc/util/graph/TopologicalSortGraph.kt | 18 +- .../generator/protoc/util/sort/BUILD.bazel | 2 +- .../protoc/util/sort/PartialOrdering.kt | 17 +- .../protoc/util/sort/TopologicalSort.kt | 51 +- .../generator/protoc/ClassSimpleNameTest.kt | 9 +- .../generator/protoc/DeclarationsTest.kt | 105 +- .../generator/protoc/DescriptorUtilTest.kt | 2 +- .../generator/protoc/GeneratorConfigTest.kt | 134 +- .../generator/protoc/JavaPackagePolicyTest.kt | 4 +- .../generator/protoc/MemberSimpleNameTest.kt | 3 +- .../protoc/OptionalProto3FieldTest.kt | 1 - .../protoc/ProtoEnumValueNameTest.kt | 3 +- .../generator/protoc/ProtoFieldNameTest.kt | 11 +- .../generator/protoc/ProtoFileNameTest.kt | 10 +- .../generator/protoc/ProtoMethodNameTest.kt | 8 +- .../grpc/kotlin/generator/protoc/ScopeTest.kt | 8 +- .../protoc/testing/DeclarationsSubjectTest.kt | 42 +- .../protoc/testing/FileSpecSubjectTest.kt | 14 +- .../protoc/testing/FunSpecSubjectTest.kt | 19 +- .../protoc/testing/TypeSpecSubjectTest.kt | 20 +- .../test/proto/helloworld/helloworld.proto | 4 +- compiler/src/test/proto/testing/BUILD.bazel | 2 +- .../has_explicit_outer_class_name.proto | 2 +- .../rpc_name_contains_underscore.proto | 7 +- .../service_name_conflicts_with_file.proto | 2 +- .../proto/testing/test_proto3_optional.proto | 2 - examples/android/build.gradle.kts | 64 +- .../grpc/examples/helloworld/MainActivity.kt | 101 +- examples/build.gradle.kts | 11 +- examples/client/build.gradle.kts | 72 +- .../io/grpc/examples/animals/AnimalsClient.kt | 84 +- .../io/grpc/examples/animals/BUILD.bazel | 2 +- .../examples/helloworld/HelloWorldClient.kt | 33 +- .../io/grpc/examples/routeguide/BUILD.bazel | 2 +- .../examples/routeguide/RouteGuideClient.kt | 232 ++-- examples/native-client/build.gradle.kts | 37 +- .../examples/helloworld/HelloWorldClient.kt | 33 +- examples/protos/build.gradle.kts | 8 +- .../io/grpc/examples/animals/BUILD.bazel | 4 +- .../proto/io/grpc/examples/animals/dog.proto | 11 +- .../proto/io/grpc/examples/animals/pig.proto | 11 +- .../io/grpc/examples/animals/sheep.proto | 11 +- .../io/grpc/examples/helloworld/BUILD.bazel | 2 +- .../examples/helloworld/hello_world.proto | 8 +- .../io/grpc/examples/routeguide/BUILD.bazel | 4 +- .../examples/routeguide/route_guide.proto | 2 +- examples/server/build.gradle.kts | 106 +- .../io/grpc/examples/animals/AnimalsServer.kt | 83 +- .../io/grpc/examples/animals/BUILD.bazel | 2 +- .../examples/helloworld/HelloWorldServer.kt | 64 +- .../io/grpc/examples/routeguide/BUILD.bazel | 2 +- .../examples/routeguide/RouteGuideServer.kt | 149 ++- .../examples/animals/AnimalsServerTest.kt | 44 +- .../helloworld/HelloWorldServerTest.kt | 26 +- .../routeguide/RouteGuideServerTest.kt | 47 +- examples/settings.gradle.kts | 35 +- examples/stub-android/build.gradle.kts | 80 +- examples/stub-lite/build.gradle.kts | 70 +- examples/stub/build.gradle.kts | 60 +- .../io/grpc/examples/routeguide/BUILD.bazel | 4 +- .../io/grpc/examples/routeguide/Database.kt | 14 +- .../io/grpc/examples/routeguide/Points.kt | 34 +- integration_testing/build.gradle.kts | 97 +- .../kotlin/io/grpc/kotlin/ExamplesTest.kt | 159 +-- interop_testing/build.gradle.kts | 88 +- .../integration/AbstractInteropTest.kt | 1083 +++++++---------- .../testing/integration/Http2TestCases.java | 12 +- .../grpc/testing/integration/TestCases.java | 15 +- .../testing/integration/TestServiceClient.kt | 81 +- .../testing/integration/TestServiceImpl.kt | 54 +- .../integration/TestServiceServer.java | 10 +- .../io/grpc/testing/integration/Util.java | 20 +- .../src/main/proto/grpc/testing/empty.proto | 2 +- .../src/main/proto/grpc/testing/test.proto | 15 +- .../java/io/grpc/stub/StubConfigTest.java | 20 +- .../testing/integration/Http2OkHttpTest.java | 70 +- .../testing/integration/TestCasesTest.java | 4 +- kt_jvm_grpc.bzl | 4 +- settings.gradle.kts | 10 +- stub/build.gradle.kts | 175 ++- .../kotlin/AbstractCoroutineServerImpl.kt | 2 +- .../io/grpc/kotlin/AbstractCoroutineStub.kt | 8 +- .../main/java/io/grpc/kotlin/ClientCalls.kt | 135 +- .../CoroutineContextServerInterceptor.kt | 33 +- .../java/io/grpc/kotlin/GrpcContextElement.kt | 8 +- stub/src/main/java/io/grpc/kotlin/Helpers.kt | 23 +- .../src/main/java/io/grpc/kotlin/Readiness.kt | 15 +- .../main/java/io/grpc/kotlin/ServerCalls.kt | 179 ++- .../java/io/grpc/kotlin/AbstractCallsTest.kt | 37 +- .../java/io/grpc/kotlin/ClientCallsTest.kt | 572 ++++----- .../CoroutineContextServerInterceptorTest.kt | 131 +- .../java/io/grpc/kotlin/FlowControlTest.kt | 127 +- .../java/io/grpc/kotlin/GeneratedCodeTest.kt | 283 ++--- .../io/grpc/kotlin/GrpcContextElementTest.kt | 5 +- .../java/io/grpc/kotlin/ServerCallsTest.kt | 891 +++++++------- .../test/proto/helloworld/helloworld.proto | 2 +- 129 files changed, 3638 insertions(+), 4156 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 4dc7df02..ff2c238f 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -4,12 +4,10 @@ updates: directory: "/" schedule: interval: "daily" - - package-ecosystem: "gradle" directory: "/examples" schedule: interval: "daily" - - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/.github/workflows/bazel.yml b/.github/workflows/bazel.yml index 4d298d2c..c3907c63 100644 --- a/.github/workflows/bazel.yml +++ b/.github/workflows/bazel.yml @@ -9,48 +9,48 @@ jobs: bazel: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 - - name: Mount bazel cache - uses: actions/cache@v4 - with: - path: "/home/runner/.cache/bazel" - key: bazel - - name: Set up JDK 17 - uses: actions/setup-java@v5 - with: - java-version: '17' - distribution: 'temurin' - - name: Install bazelisk - run: | - curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.26.0/bazelisk-linux-amd64" - mkdir -p "${GITHUB_WORKSPACE}/bin/" - mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel" - chmod +x "${GITHUB_WORKSPACE}/bin/bazel" - - name: Lint - Please run `bazelisk run //:format` - run: | - cd "${GITHUB_WORKSPACE}" - "${GITHUB_WORKSPACE}/bin/bazel" run //:format.check - - name: Build - uses: nick-invision/retry@v3 - with: - timeout_minutes: 10 - max_attempts: 3 - command: | + - uses: actions/checkout@v5 + - name: Mount bazel cache + uses: actions/cache@v4 + with: + path: "/home/runner/.cache/bazel" + key: bazel + - name: Set up JDK 17 + uses: actions/setup-java@v5 + with: + java-version: '17' + distribution: 'temurin' + - name: Install bazelisk + run: | + curl -LO "https://github.com/bazelbuild/bazelisk/releases/download/v1.26.0/bazelisk-linux-amd64" + mkdir -p "${GITHUB_WORKSPACE}/bin/" + mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel" + chmod +x "${GITHUB_WORKSPACE}/bin/bazel" + - name: Lint - Please run `bazelisk run //:format` + run: | cd "${GITHUB_WORKSPACE}" - "${GITHUB_WORKSPACE}/bin/bazel" build //... - - name: Test - uses: nick-invision/retry@v3 - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - cd "${GITHUB_WORKSPACE}" - "${GITHUB_WORKSPACE}/bin/bazel" test //... - - name: Test bzl-examples/bzlmod - uses: nick-invision/retry@v3 - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - cd "${GITHUB_WORKSPACE}/bzl-examples/bzlmod" - "${GITHUB_WORKSPACE}/bin/bazel" test //... + "${GITHUB_WORKSPACE}/bin/bazel" run //:format.check + - name: Build + uses: nick-invision/retry@v3 + with: + timeout_minutes: 10 + max_attempts: 3 + command: | + cd "${GITHUB_WORKSPACE}" + "${GITHUB_WORKSPACE}/bin/bazel" build //... + - name: Test + uses: nick-invision/retry@v3 + with: + timeout_minutes: 10 + max_attempts: 3 + command: | + cd "${GITHUB_WORKSPACE}" + "${GITHUB_WORKSPACE}/bin/bazel" test //... + - name: Test bzl-examples/bzlmod + uses: nick-invision/retry@v3 + with: + timeout_minutes: 10 + max_attempts: 3 + command: | + cd "${GITHUB_WORKSPACE}/bzl-examples/bzlmod" + "${GITHUB_WORKSPACE}/bin/bazel" test //... diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index a5448a88..c2c18829 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -14,28 +14,24 @@ jobs: os: [ubuntu-latest, macos-13] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v5 - - - name: Set up JDK 17 - uses: actions/setup-java@v5 - with: - java-version: '17' - distribution: 'temurin' - - - name: Set up Gradle - uses: gradle/actions/setup-gradle@v4 - - - name: Test on Mac - if: matrix.os == 'macos-13' - run: | - brew install docker colima - colima start --network-address - export TESTCONTAINERS_DOCKER_SOCKET_OVERRIDE=/var/run/docker.sock - export TESTCONTAINERS_HOST_OVERRIDE=$(colima ls -j | jq -r '.address') - export DOCKER_HOST="unix://${HOME}/.colima/default/docker.sock" - ./gradlew test - - - name: Test on Ubuntu - if: matrix.os == 'ubuntu-latest' - run: | - ./gradlew test + - uses: actions/checkout@v5 + - name: Set up JDK 17 + uses: actions/setup-java@v5 + with: + java-version: '17' + distribution: 'temurin' + - name: Set up Gradle + uses: gradle/actions/setup-gradle@v4 + - name: Test on Mac + if: matrix.os == 'macos-13' + run: | + brew install docker colima + colima start --network-address + export TESTCONTAINERS_DOCKER_SOCKET_OVERRIDE=/var/run/docker.sock + export TESTCONTAINERS_HOST_OVERRIDE=$(colima ls -j | jq -r '.address') + export DOCKER_HOST="unix://${HOME}/.colima/default/docker.sock" + ./gradlew test + - name: Test on Ubuntu + if: matrix.os == 'ubuntu-latest' + run: | + ./gradlew test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d18b64e2..5d5515f2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,22 +4,18 @@ on: tags: - v** workflow_dispatch: - jobs: release: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 - - name: Set up JDK 17 uses: actions/setup-java@v5 with: java-version: '17' distribution: 'temurin' - - name: Set up Gradle uses: gradle/actions/setup-gradle@v4 - - name: release env: SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} @@ -28,5 +24,5 @@ jobs: GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} run: | # todo: verify version is same as tag - + ./gradlew publishToSonatype closeAndReleaseSonatypeStagingRepository diff --git a/MODULE.bazel b/MODULE.bazel index 911b0da7..6ef25d0d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -37,14 +37,14 @@ grpc_kotlin_maven.install( "com.google.protobuf:protobuf-java:4.30.2", "com.google.protobuf:protobuf-kotlin:4.30.2", "com.google.guava:guava:33.3.1-android", - "com.squareup:kotlinpoet:1.14.2", # Max version without causing compiler errors. + "com.squareup:kotlinpoet:1.14.2", # Max version without causing compiler errors. "junit:junit:4.13.2", "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.1", "org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm:1.10.1", ], fetch_sources = False, - lock_file = "//:grpc_kotlin_maven_install.json", generate_compat_repositories = True, + lock_file = "//:grpc_kotlin_maven_install.json", strict_visibility = True, ) use_repo(grpc_kotlin_maven, "grpc_kotlin_maven") diff --git a/build.gradle.kts b/build.gradle.kts index 494a09fb..99fa6980 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -3,148 +3,143 @@ import org.gradle.api.tasks.testing.logging.TestLogEvent import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { - alias(libs.plugins.kotlin.jvm) apply false - alias(libs.plugins.protobuf) apply false - alias(libs.plugins.test.retry) - alias(libs.plugins.publish.plugin) - alias(libs.plugins.qoomon.git.versioning) + alias(libs.plugins.kotlin.jvm) apply false + alias(libs.plugins.protobuf) apply false + alias(libs.plugins.test.retry) + alias(libs.plugins.publish.plugin) + alias(libs.plugins.qoomon.git.versioning) } group = "io.grpc" gitVersioning.apply { - refs { - tag("v(?.*)") { - version = "\${ref.version}" - } - } + refs { tag("v(?.*)") { version = "\${ref.version}" } } - rev { - version = "\${commit}" - } + rev { version = "\${commit}" } } subprojects { - - apply { - plugin("java") - plugin("org.jetbrains.kotlin.jvm") - plugin("com.google.protobuf") - plugin("org.gradle.test-retry") - plugin("maven-publish") - plugin("signing") + apply { + plugin("java") + plugin("org.jetbrains.kotlin.jvm") + plugin("com.google.protobuf") + plugin("org.gradle.test-retry") + plugin("maven-publish") + plugin("signing") + } + + // gradle-nexus/publish-plugin needs these here too + group = rootProject.group + version = rootProject.version + + tasks.withType { + sourceCompatibility = JavaVersion.VERSION_17.toString() + targetCompatibility = JavaVersion.VERSION_17.toString() + } + + tasks.withType { + kotlinOptions { + freeCompilerArgs = listOf("-Xjsr305=strict") + jvmTarget = JavaVersion.VERSION_17.toString() } + } + + tasks.withType { + testLogging { + // set options for log level LIFECYCLE + events = + setOf( + TestLogEvent.FAILED, + TestLogEvent.PASSED, + TestLogEvent.SKIPPED, + TestLogEvent.STANDARD_OUT + ) - // gradle-nexus/publish-plugin needs these here too - group = rootProject.group - version = rootProject.version - - tasks.withType { - sourceCompatibility = JavaVersion.VERSION_17.toString() - targetCompatibility = JavaVersion.VERSION_17.toString() + exceptionFormat = TestExceptionFormat.FULL + showStandardStreams = true + showExceptions = true + showCauses = true + showStackTraces = true + + // set options for log level DEBUG and INFO + debug { + events = + setOf( + TestLogEvent.STARTED, + TestLogEvent.FAILED, + TestLogEvent.PASSED, + TestLogEvent.SKIPPED, + TestLogEvent.STANDARD_ERROR, + TestLogEvent.STANDARD_OUT + ) + + exceptionFormat = TestExceptionFormat.FULL + } + + info.events = debug.events + info.exceptionFormat = debug.exceptionFormat } - tasks.withType { - kotlinOptions { - freeCompilerArgs = listOf("-Xjsr305=strict") - jvmTarget = JavaVersion.VERSION_17.toString() - } - } + retry { maxRetries = 10 } - tasks.withType { - testLogging { - // set options for log level LIFECYCLE - events = setOf( - TestLogEvent.FAILED, - TestLogEvent.PASSED, - TestLogEvent.SKIPPED, - TestLogEvent.STANDARD_OUT - ) - - exceptionFormat = TestExceptionFormat.FULL - showStandardStreams = true - showExceptions = true - showCauses = true - showStackTraces = true - - // set options for log level DEBUG and INFO - debug { - events = setOf( - TestLogEvent.STARTED, - TestLogEvent.FAILED, - TestLogEvent.PASSED, - TestLogEvent.SKIPPED, - TestLogEvent.STANDARD_ERROR, - TestLogEvent.STANDARD_OUT - ) - - exceptionFormat = TestExceptionFormat.FULL - } - - info.events = debug.events - info.exceptionFormat = debug.exceptionFormat + afterSuite( + KotlinClosure2({ desc: TestDescriptor, result: TestResult -> + if (desc.parent == null) { // will match the outermost suite + println( + "Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)" + ) } - - retry { - maxRetries = 10 + }) + ) + } + + extensions.getByType().repositories { + maven { url = uri(rootProject.layout.buildDirectory.dir("maven-repo")) } + } + + extensions.getByType().publications { + create("maven") { + pom { + url.set("https://github.com/grpc/grpc-kotlin") + + scm { + connection.set("scm:git:https://github.com/grpc/grpc-kotlin.git") + developerConnection.set("scm:git:git@github.com:grpc/grpc-kotlin.git") + url.set("https://github.com/grpc/grpc-kotlin") } - afterSuite( - KotlinClosure2({ desc: TestDescriptor, result: TestResult -> - if (desc.parent == null) { // will match the outermost suite - println("Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)") - } - }) - ) - } - - extensions.getByType().repositories { - maven { - url = uri(rootProject.layout.buildDirectory.dir("maven-repo")) + licenses { + license { + name.set("Apache 2.0") + url.set("https://opensource.org/licenses/Apache-2.0") + } } - } - extensions.getByType().publications { - create("maven") { - pom { - url.set("https://github.com/grpc/grpc-kotlin") - - scm { - connection.set("scm:git:https://github.com/grpc/grpc-kotlin.git") - developerConnection.set("scm:git:git@github.com:grpc/grpc-kotlin.git") - url.set("https://github.com/grpc/grpc-kotlin") - } - - licenses { - license { - name.set("Apache 2.0") - url.set("https://opensource.org/licenses/Apache-2.0") - } - } - - developers { - developer { - id.set("grpc.io") - name.set("gRPC Contributors") - email.set("grpc-io@googlegroups.com") - url.set("https://grpc.io/") - organization.set("gRPC Authors") - organizationUrl.set("https://www.google.com") - } - } - } + developers { + developer { + id.set("grpc.io") + name.set("gRPC Contributors") + email.set("grpc-io@googlegroups.com") + url.set("https://grpc.io/") + organization.set("gRPC Authors") + organizationUrl.set("https://www.google.com") + } } + } } + } - extensions.getByType().sign(extensions.getByType().publications.named("maven").get()) - extensions.getByType().useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) + extensions + .getByType() + .sign(extensions.getByType().publications.named("maven").get()) + extensions + .getByType() + .useInMemoryPgpKeys(System.getenv("GPG_PRIVATE_KEY"), System.getenv("GPG_PASSPHRASE")) - tasks.withType { - onlyIf { System.getenv("GPG_PRIVATE_KEY") != null } - } + tasks.withType { onlyIf { System.getenv("GPG_PRIVATE_KEY") != null } } } nexusPublishing.repositories.sonatype { - username.set(System.getenv("SONATYPE_USERNAME")) - password.set(System.getenv("SONATYPE_PASSWORD")) + username.set(System.getenv("SONATYPE_USERNAME")) + password.set(System.getenv("SONATYPE_PASSWORD")) } diff --git a/bzl-examples/bzlmod/javatests/io/grpc/examples/bzlmod/FibonacciTest.kt b/bzl-examples/bzlmod/javatests/io/grpc/examples/bzlmod/FibonacciTest.kt index 4582db61..faae46ba 100644 --- a/bzl-examples/bzlmod/javatests/io/grpc/examples/bzlmod/FibonacciTest.kt +++ b/bzl-examples/bzlmod/javatests/io/grpc/examples/bzlmod/FibonacciTest.kt @@ -23,27 +23,29 @@ class FibonacciTest { val serverName = InProcessServerBuilder.generateName() grpcCleanup.register( - InProcessServerBuilder.forName(serverName) - .directExecutor() - .addService(Fibonacci()) - .build() - .start()) + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(Fibonacci()) + .build() + .start() + ) stub = - FibonacciServiceGrpcKt.FibonacciServiceCoroutineStub( - grpcCleanup.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build())) + FibonacciServiceGrpcKt.FibonacciServiceCoroutineStub( + grpcCleanup.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()) + ) } @Test fun query_succeeds() { runBlocking { val response = - stub.query( - queryRequest { - nth = 20 - mod = 1000000007 - }) + stub.query( + queryRequest { + nth = 20 + mod = 1000000007 + } + ) assertThat(response).isEqualTo(queryResponse { nthFibonacci = 6765 }) } } diff --git a/compiler/build.gradle.kts b/compiler/build.gradle.kts index 5f777d41..708d7ef3 100644 --- a/compiler/build.gradle.kts +++ b/compiler/build.gradle.kts @@ -1,97 +1,81 @@ import com.google.protobuf.gradle.* -plugins { - application -} +plugins { application } -application { - mainClass.set("io.grpc.kotlin.generator.GeneratorRunner") -} +application { mainClass.set("io.grpc.kotlin.generator.GeneratorRunner") } java { - withSourcesJar() - toolchain { - languageVersion = JavaLanguageVersion.of(17) - } + withSourcesJar() + toolchain { languageVersion = JavaLanguageVersion.of(17) } } dependencies { - // Kotlin and Java - implementation(libs.kotlinx.coroutines.core) - - // Grpc and Protobuf - implementation(project(":stub")) - implementation(libs.grpc.protobuf) - - // Misc - implementation(kotlin("reflect")) - implementation(libs.kotlinpoet) - implementation(libs.truth) - - // Testing - testImplementation(libs.junit) - testImplementation(libs.guava) - testImplementation(libs.jimfs) - testImplementation(libs.protobuf.gradle.plugin) - testImplementation(libs.protobuf.java) - testImplementation(libs.junit.jupiter.engine) - testImplementation(libs.mockito.core) + // Kotlin and Java + implementation(libs.kotlinx.coroutines.core) + + // Grpc and Protobuf + implementation(project(":stub")) + implementation(libs.grpc.protobuf) + + // Misc + implementation(kotlin("reflect")) + implementation(libs.kotlinpoet) + implementation(libs.truth) + + // Testing + testImplementation(libs.junit) + testImplementation(libs.guava) + testImplementation(libs.jimfs) + testImplementation(libs.protobuf.gradle.plugin) + testImplementation(libs.protobuf.java) + testImplementation(libs.junit.jupiter.engine) + testImplementation(libs.mockito.core) } tasks.jar { - manifest { - attributes["Main-Class"] = application.mainClass.get() - } + manifest { attributes["Main-Class"] = application.mainClass.get() } - from(sourceSets.main.get().output) + from(sourceSets.main.get().output) - dependsOn(configurations.runtimeClasspath) + dependsOn(configurations.runtimeClasspath) - from({ - configurations.runtimeClasspath.get().filter { it.name.endsWith("jar") }.map { zipTree(it) } - }) + from({ + configurations.runtimeClasspath.get().filter { it.name.endsWith("jar") }.map { zipTree(it) } + }) - duplicatesStrategy = DuplicatesStrategy.INCLUDE + duplicatesStrategy = DuplicatesStrategy.INCLUDE } publishing { - publications { - named("maven") { - pom { - name.set("gRPC Kotlin Compiler") - artifactId = "protoc-gen-grpc-kotlin" - description.set("gRPC Kotlin protoc compiler plugin") - } - - artifact(tasks.jar) { - classifier = "jdk8" - } - } + publications { + named("maven") { + pom { + name.set("gRPC Kotlin Compiler") + artifactId = "protoc-gen-grpc-kotlin" + description.set("gRPC Kotlin protoc compiler plugin") + } + + artifact(tasks.jar) { classifier = "jdk8" } } + } } protobuf { - protoc { - artifact = libs.protoc.asProvider().get().toString() - } - plugins { - id("grpc") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - id("grpckt") { - path = tasks.jar.get().archiveFile.get().asFile.absolutePath - } - } - generateProtoTasks { - all().forEach { - if (it.name.startsWith("generateTestProto")) { - it.dependsOn("jar") - } - - it.plugins { - id("grpc") - id("grpckt") - } - } + protoc { artifact = libs.protoc.asProvider().get().toString() } + plugins { + id("grpc") { artifact = libs.protoc.gen.grpc.java.get().toString() } + id("grpckt") { path = tasks.jar.get().archiveFile.get().asFile.absolutePath } + } + generateProtoTasks { + all().forEach { + if (it.name.startsWith("generateTestProto")) { + it.dependsOn("jar") + } + + it.plugins { + id("grpc") + id("grpckt") + } } + } } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/BUILD.bazel b/compiler/src/main/java/io/grpc/kotlin/generator/BUILD.bazel index f6dddca3..bfaebafa 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/BUILD.bazel +++ b/compiler/src/main/java/io/grpc/kotlin/generator/BUILD.bazel @@ -14,11 +14,11 @@ kt_jvm_library( "//compiler/src/main/java/io/grpc/kotlin/generator/protoc", "//stub/src/main/java/io/grpc/kotlin:context", "//stub/src/main/java/io/grpc/kotlin:stub", - "@protobuf//:protobuf_java", "@grpc-java//core", "@grpc_kotlin_maven//:com_google_guava_guava", "@grpc_kotlin_maven//:com_squareup_kotlinpoet", "@grpc_kotlin_maven//:org_jetbrains_kotlinx_kotlinx_coroutines_core", + "@protobuf//:protobuf_java", ], ) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/GeneratorRunner.kt b/compiler/src/main/java/io/grpc/kotlin/generator/GeneratorRunner.kt index 9bf0f813..0100a461 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/GeneratorRunner.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/GeneratorRunner.kt @@ -23,22 +23,23 @@ import io.grpc.kotlin.generator.protoc.GeneratorConfig import io.grpc.kotlin.generator.protoc.JavaPackagePolicy /** Main runner for code generation for Kotlin gRPC APIs. */ -object GeneratorRunner: AbstractGeneratorRunner() { - @JvmStatic - fun main(args: Array) = super.doMain(args) +object GeneratorRunner : AbstractGeneratorRunner() { + @JvmStatic fun main(args: Array) = super.doMain(args) private val config = GeneratorConfig(JavaPackagePolicy.OPEN_SOURCE, false) - val generator = ProtoFileCodeGenerator( - generators = listOf( - ::ServiceNameGenerator, - ::GrpcClientStubGenerator, - ::GrpcCoroutineServerGenerator, - ::TopLevelConstantsGenerator - ), - config = config, - topLevelSuffix = "GrpcKt" - ) + val generator = + ProtoFileCodeGenerator( + generators = + listOf( + ::ServiceNameGenerator, + ::GrpcClientStubGenerator, + ::GrpcCoroutineServerGenerator, + ::TopLevelConstantsGenerator + ), + config = config, + topLevelSuffix = "GrpcKt" + ) override fun generateCodeForFile(file: FileDescriptor): List = listOfNotNull(generator.generateCodeForFile(file)) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt index 2ff6265b..32748c6b 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt @@ -31,6 +31,8 @@ import com.squareup.kotlinpoet.TypeVariableName import com.squareup.kotlinpoet.asClassName import com.squareup.kotlinpoet.asTypeName import io.grpc.CallOptions +import io.grpc.Channel as GrpcChannel +import io.grpc.Metadata as GrpcMetadata import io.grpc.MethodDescriptor.MethodType import io.grpc.Status import io.grpc.StatusException @@ -48,12 +50,8 @@ import io.grpc.kotlin.generator.protoc.methodName import io.grpc.kotlin.generator.protoc.of import io.grpc.kotlin.generator.protoc.serviceName import kotlinx.coroutines.flow.Flow -import io.grpc.Channel as GrpcChannel -import io.grpc.Metadata as GrpcMetadata -/** - * Logic for generating gRPC stubs for Kotlin. - */ +/** Logic for generating gRPC stubs for Kotlin. */ @VisibleForTesting class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(config) { companion object { @@ -63,16 +61,16 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co private val GRPC_CHANNEL_PARAMETER_NAME = MemberSimpleName("channel") private val CALL_OPTIONS_PARAMETER_NAME = MemberSimpleName("callOptions") - private val HEADERS_PARAMETER: ParameterSpec = ParameterSpec - .builder("headers", GrpcMetadata::class) - .defaultValue("%T()", GrpcMetadata::class) - .build() + private val HEADERS_PARAMETER: ParameterSpec = + ParameterSpec.builder("headers", GrpcMetadata::class) + .defaultValue("%T()", GrpcMetadata::class) + .build() val GRPC_CHANNEL_PARAMETER = ParameterSpec.of(GRPC_CHANNEL_PARAMETER_NAME, GrpcChannel::class) - val CALL_OPTIONS_PARAMETER = ParameterSpec - .builder(MemberSimpleName("callOptions"), CallOptions::class) - .defaultValue("%M", CallOptions::class.member("DEFAULT")) - .build() + val CALL_OPTIONS_PARAMETER = + ParameterSpec.builder(MemberSimpleName("callOptions"), CallOptions::class) + .defaultValue("%M", CallOptions::class.member("DEFAULT")) + .build() private val FLOW = Flow::class.asClassName() @@ -81,19 +79,21 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co private val SERVER_STREAMING_RPC_HELPER = ClientCalls::class.member("serverStreamingRpc") private val BIDI_STREAMING_RPC_HELPER = ClientCalls::class.member("bidiStreamingRpc") - private val RPC_HELPER = mapOf( - MethodType.UNARY to UNARY_RPC_HELPER, - MethodType.CLIENT_STREAMING to CLIENT_STREAMING_RPC_HELPER, - MethodType.SERVER_STREAMING to SERVER_STREAMING_RPC_HELPER, - MethodType.BIDI_STREAMING to BIDI_STREAMING_RPC_HELPER - ) + private val RPC_HELPER = + mapOf( + MethodType.UNARY to UNARY_RPC_HELPER, + MethodType.CLIENT_STREAMING to CLIENT_STREAMING_RPC_HELPER, + MethodType.SERVER_STREAMING to SERVER_STREAMING_RPC_HELPER, + MethodType.BIDI_STREAMING to BIDI_STREAMING_RPC_HELPER + ) private val MethodDescriptor.type: MethodType - get() = if (isClientStreaming) { - if (isServerStreaming) MethodType.BIDI_STREAMING else MethodType.CLIENT_STREAMING - } else { - if (isServerStreaming) MethodType.SERVER_STREAMING else MethodType.UNARY - } + get() = + if (isClientStreaming) { + if (isServerStreaming) MethodType.BIDI_STREAMING else MethodType.CLIENT_STREAMING + } else { + if (isServerStreaming) MethodType.SERVER_STREAMING else MethodType.UNARY + } } override fun generate(service: ServiceDescriptor): Declarations = declarations { @@ -108,29 +108,26 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co // which we don't want. val stubSelfReference: TypeName = TypeVariableName(stubName.toString()) - val builder = TypeSpec - .classBuilder(stubName) - .superclass(AbstractCoroutineStub::class.asTypeName().parameterizedBy(stubSelfReference)) - .addKdoc( - "A stub for issuing RPCs to a(n) %L service as suspending coroutines.", - service.fullName - ) - .addAnnotation( - AnnotationSpec.builder(StubFor::class) - .addMember("%T::class", service.grpcClass) - .build() - ) - .primaryConstructor( - FunSpec - .constructorBuilder() - .addParameter(GRPC_CHANNEL_PARAMETER) - .addParameter(CALL_OPTIONS_PARAMETER) - .addAnnotation(JvmOverloads::class) - .build() - ) - .addSuperclassConstructorParameter("%N", GRPC_CHANNEL_PARAMETER) - .addSuperclassConstructorParameter("%N", CALL_OPTIONS_PARAMETER) - .addFunction(buildFun(stubSelfReference)) + val builder = + TypeSpec.classBuilder(stubName) + .superclass(AbstractCoroutineStub::class.asTypeName().parameterizedBy(stubSelfReference)) + .addKdoc( + "A stub for issuing RPCs to a(n) %L service as suspending coroutines.", + service.fullName + ) + .addAnnotation( + AnnotationSpec.builder(StubFor::class).addMember("%T::class", service.grpcClass).build() + ) + .primaryConstructor( + FunSpec.constructorBuilder() + .addParameter(GRPC_CHANNEL_PARAMETER) + .addParameter(CALL_OPTIONS_PARAMETER) + .addAnnotation(JvmOverloads::class) + .build() + ) + .addSuperclassConstructorParameter("%N", GRPC_CHANNEL_PARAMETER) + .addSuperclassConstructorParameter("%N", CALL_OPTIONS_PARAMETER) + .addFunction(buildFun(stubSelfReference)) for (method in service.methods) { builder.addFunction(generateRpcStub(method)) @@ -142,66 +139,63 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co * Outputs a `FunSpec` of an override of `AbstractCoroutineStub.build` for this particular stub. */ private fun buildFun(stubName: TypeName): FunSpec { - return FunSpec - .builder("build") + return FunSpec.builder("build") .returns(stubName) .addModifiers(KModifier.OVERRIDE) .addParameter(GRPC_CHANNEL_PARAMETER) .addParameter(ParameterSpec.of(CALL_OPTIONS_PARAMETER_NAME, CallOptions::class)) - .addStatement( - "return %T(%N, %N)", - stubName, - GRPC_CHANNEL_PARAMETER, - CALL_OPTIONS_PARAMETER - ) + .addStatement("return %T(%N, %N)", stubName, GRPC_CHANNEL_PARAMETER, CALL_OPTIONS_PARAMETER) .build() } @VisibleForTesting - fun generateRpcStub(method: MethodDescriptor): FunSpec = with(config) { - val name = method.methodName.toMemberSimpleName() - val requestType = method.inputType.messageClass() - val parameter = if (method.isClientStreaming) { - ParameterSpec.of(STREAMING_PARAMETER_NAME, FLOW.parameterizedBy(requestType)) - } else { - ParameterSpec.of(UNARY_PARAMETER_NAME, requestType) - } + fun generateRpcStub(method: MethodDescriptor): FunSpec = + with(config) { + val name = method.methodName.toMemberSimpleName() + val requestType = method.inputType.messageClass() + val parameter = + if (method.isClientStreaming) { + ParameterSpec.of(STREAMING_PARAMETER_NAME, FLOW.parameterizedBy(requestType)) + } else { + ParameterSpec.of(UNARY_PARAMETER_NAME, requestType) + } - val responseType = method.outputType.messageClass() + val responseType = method.outputType.messageClass() - val returnType = - if (method.isServerStreaming) FLOW.parameterizedBy(responseType) else responseType + val returnType = + if (method.isServerStreaming) FLOW.parameterizedBy(responseType) else responseType - val helperMethod = RPC_HELPER[method.type] ?: throw IllegalArgumentException() + val helperMethod = RPC_HELPER[method.type] ?: throw IllegalArgumentException() - val funSpecBuilder = - funSpecBuilder(name) - .addParameter(parameter) - .addParameter(HEADERS_PARAMETER) - .returns(returnType) - .addKdoc(rpcStubKDoc(method, parameter)) + val funSpecBuilder = + funSpecBuilder(name) + .addParameter(parameter) + .addParameter(HEADERS_PARAMETER) + .returns(returnType) + .addKdoc(rpcStubKDoc(method, parameter)) - if (method.options.deprecated) { - funSpecBuilder.addAnnotation( - AnnotationSpec.builder(Deprecated::class) - .addMember("%S", "The underlying service method is marked deprecated.") - .build() - ) - } + if (method.options.deprecated) { + funSpecBuilder.addAnnotation( + AnnotationSpec.builder(Deprecated::class) + .addMember("%S", "The underlying service method is marked deprecated.") + .build() + ) + } - val codeBlockMap = mapOf( - "helperMethod" to helperMethod, - "methodDescriptor" to method.descriptorCode, - "parameter" to parameter, - "headers" to HEADERS_PARAMETER - ) + val codeBlockMap = + mapOf( + "helperMethod" to helperMethod, + "methodDescriptor" to method.descriptorCode, + "parameter" to parameter, + "headers" to HEADERS_PARAMETER + ) - if (!method.isServerStreaming) { - funSpecBuilder.addModifiers(KModifier.SUSPEND) - } + if (!method.isServerStreaming) { + funSpecBuilder.addModifiers(KModifier.SUSPEND) + } - funSpecBuilder.addNamedCode( - """ + funSpecBuilder.addNamedCode( + """ return %helperMethod:M( channel, %methodDescriptor:L, @@ -209,23 +203,22 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co callOptions, %headers:N ) - """.trimIndent(), - codeBlockMap - ) - return funSpecBuilder.build() - } + """ + .trimIndent(), + codeBlockMap + ) + return funSpecBuilder.build() + } - private fun rpcStubKDoc( - method: MethodDescriptor, - parameter: ParameterSpec - ): CodeBlock { - val kDocBindings = mapOf( - "parameter" to parameter, - "flow" to Flow::class, - "status" to Status::class, - "statusException" to StatusException::class, - "headers" to HEADERS_PARAMETER - ) + private fun rpcStubKDoc(method: MethodDescriptor, parameter: ParameterSpec): CodeBlock { + val kDocBindings = + mapOf( + "parameter" to parameter, + "flow" to Flow::class, + "status" to Status::class, + "statusException" to StatusException::class, + "headers" to HEADERS_PARAMETER + ) val kDocComponents = mutableListOf() @@ -237,14 +230,16 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co [`Status.OK`][%status:T], and fails by throwing a [%statusException:T] otherwise. If collecting the flow downstream fails exceptionally (including via cancellation), the RPC is cancelled with that exception as a cause. - """.trimIndent() + """ + .trimIndent() } else { """ Executes this RPC and returns the response message, suspending until the RPC completes with [`Status.OK`][%status:T]. If the RPC completes with another status, a corresponding [%statusException:T] is thrown. If this coroutine is cancelled, the RPC is also cancelled with the corresponding exception as a cause. - """.trimIndent() + """ + .trimIndent() } ) @@ -258,7 +253,8 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co `%parameter:N` is cancelled. If the collection of `%parameter:N` completes exceptionally for any other reason, then the collection of the [%flow:T] of responses completes exceptionally for the same reason and the RPC is cancelled with that reason. - """.trimIndent() + """ + .trimIndent() ) } MethodType.CLIENT_STREAMING -> { @@ -269,7 +265,8 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co will be cancelled. If the collection of requests completes exceptionally for any other reason, the RPC will be cancelled for that reason and this method will throw that exception. - """.trimIndent() + """ + .trimIndent() ) } else -> {} @@ -294,9 +291,6 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co "@return The single response from the server." } ) - return CodeBlock - .builder() - .addNamed(kDocComponents.joinToString("\n\n"), kDocBindings) - .build() + return CodeBlock.builder().addNamed(kDocComponents.joinToString("\n\n"), kDocBindings).build() } } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/GrpcCoroutineServerGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/GrpcCoroutineServerGenerator.kt index 659a78f0..1c16d257 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/GrpcCoroutineServerGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/GrpcCoroutineServerGenerator.kt @@ -44,15 +44,13 @@ import io.grpc.kotlin.generator.protoc.declarations import io.grpc.kotlin.generator.protoc.methodName import io.grpc.kotlin.generator.protoc.of import io.grpc.kotlin.generator.protoc.serviceName -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.flow.Flow import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.flow.Flow -/** - * Generator for abstract classes of the form `MyServiceCoroutineImplBase`. - */ -class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerator(config) { +/** Generator for abstract classes of the form `MyServiceCoroutineImplBase`. */ +class GrpcCoroutineServerGenerator(config: GeneratorConfig) : ServiceCodeGenerator(config) { companion object { private const val IMPL_BASE_SUFFIX = "CoroutineImplBase" @@ -61,8 +59,7 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato private val STREAMING_REQUEST_NAME: MemberSimpleName = MemberSimpleName("requests") private val coroutineContextParameter: ParameterSpec = - ParameterSpec - .builder("coroutineContext", CoroutineContext::class) + ParameterSpec.builder("coroutineContext", CoroutineContext::class) .defaultValue("%T", EmptyCoroutineContext::class) .build() @@ -77,8 +74,7 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato private val BIDI_STREAMING_SMD: MemberName = ServerCalls::class.member("bidiStreamingServerMethodDefinition") - private val UNIMPLEMENTED_STATUS: MemberName = - Status::class.member("UNIMPLEMENTED") + private val UNIMPLEMENTED_STATUS: MemberName = Status::class.member("UNIMPLEMENTED") } override fun generate(service: ServiceDescriptor): Declarations = declarations { @@ -89,36 +85,37 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato val serviceImplClassName = service.serviceName.toClassSimpleName().withSuffix(IMPL_BASE_SUFFIX) val stubs: List = service.methods.map { serviceMethodStub(it) } - val implBuilder = TypeSpec - .classBuilder(serviceImplClassName) - .addModifiers(KModifier.ABSTRACT) - .addKdoc( - """ + val implBuilder = + TypeSpec.classBuilder(serviceImplClassName) + .addModifiers(KModifier.ABSTRACT) + .addKdoc( + """ Skeletal implementation of the %L service based on Kotlin coroutines. - """.trimIndent(), - service.fullName - ) - .primaryConstructor( - FunSpec.constructorBuilder() - .addParameter(coroutineContextParameter) - .build() - ) - .superclass(AbstractCoroutineServerImpl::class) - .addSuperclassConstructorParameter("%N", coroutineContextParameter) + """ + .trimIndent(), + service.fullName + ) + .primaryConstructor( + FunSpec.constructorBuilder().addParameter(coroutineContextParameter).build() + ) + .superclass(AbstractCoroutineServerImpl::class) + .addSuperclassConstructorParameter("%N", coroutineContextParameter) var serverServiceDefinitionBuilder = CodeBlock.of("%M(%M())", SERVER_SERVICE_DEFINITION_BUILDER_FACTORY, service.grpcDescriptor) for (stub in stubs) { implBuilder.addFunction(stub.methodSpec) - serverServiceDefinitionBuilder = CodeBlock.of( - """ + serverServiceDefinitionBuilder = + CodeBlock.of( + """ %L .addMethod(%L) - """.trimIndent(), - serverServiceDefinitionBuilder, - stub.serverMethodDef - ) + """ + .trimIndent(), + serverServiceDefinitionBuilder, + stub.serverMethodDef + ) } implBuilder.addFunction( @@ -136,88 +133,90 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato data class MethodImplStub( val methodSpec: FunSpec, /** - * A [CodeBlock] that computes a [ServerMethodDefinition] based on an implementation of - * the function described in [methodSpec]. + * A [CodeBlock] that computes a [ServerMethodDefinition] based on an implementation of the + * function described in [methodSpec]. */ val serverMethodDef: CodeBlock ) @VisibleForTesting - fun serviceMethodStub(method: MethodDescriptor): MethodImplStub = with(config) { - val requestType = method.inputType.messageClass() - val requestParam = if (method.isClientStreaming) { - ParameterSpec.of(STREAMING_REQUEST_NAME, FLOW.parameterizedBy(requestType)) - } else { - ParameterSpec.of(UNARY_REQUEST_NAME, requestType) - } + fun serviceMethodStub(method: MethodDescriptor): MethodImplStub = + with(config) { + val requestType = method.inputType.messageClass() + val requestParam = + if (method.isClientStreaming) { + ParameterSpec.of(STREAMING_REQUEST_NAME, FLOW.parameterizedBy(requestType)) + } else { + ParameterSpec.of(UNARY_REQUEST_NAME, requestType) + } - val methodSpecBuilder = FunSpec.builder(method.methodName.toMemberSimpleName()) - .addModifiers(KModifier.OPEN) - .addParameter(requestParam) - .addStatement( - "throw %T(%M.withDescription(%S))", - StatusException::class, - UNIMPLEMENTED_STATUS, - "Method ${method.fullName} is unimplemented" - ) + val methodSpecBuilder = + FunSpec.builder(method.methodName.toMemberSimpleName()) + .addModifiers(KModifier.OPEN) + .addParameter(requestParam) + .addStatement( + "throw %T(%M.withDescription(%S))", + StatusException::class, + UNIMPLEMENTED_STATUS, + "Method ${method.fullName} is unimplemented" + ) - if (method.options.deprecated) { - methodSpecBuilder.addAnnotation( - AnnotationSpec.builder(Deprecated::class) - .addMember("%S", "The underlying service method is marked deprecated.") - .build() - ) - } + if (method.options.deprecated) { + methodSpecBuilder.addAnnotation( + AnnotationSpec.builder(Deprecated::class) + .addMember("%S", "The underlying service method is marked deprecated.") + .build() + ) + } - val responseType = method.outputType.messageClass() - if (method.isServerStreaming) { - methodSpecBuilder.returns(FLOW.parameterizedBy(responseType)) - } else { - methodSpecBuilder.returns(responseType) - methodSpecBuilder.addModifiers(KModifier.SUSPEND) - } + val responseType = method.outputType.messageClass() + if (method.isServerStreaming) { + methodSpecBuilder.returns(FLOW.parameterizedBy(responseType)) + } else { + methodSpecBuilder.returns(responseType) + methodSpecBuilder.addModifiers(KModifier.SUSPEND) + } - methodSpecBuilder.addKdoc(stubKDoc(method, requestParam)) + methodSpecBuilder.addKdoc(stubKDoc(method, requestParam)) - val methodSpec = methodSpecBuilder.build() + val methodSpec = methodSpecBuilder.build() - val smdFactory = if (method.isServerStreaming) { - if (method.isClientStreaming) BIDI_STREAMING_SMD else SERVER_STREAMING_SMD - } else { - if (method.isClientStreaming) CLIENT_STREAMING_SMD else UNARY_SMD - } + val smdFactory = + if (method.isServerStreaming) { + if (method.isClientStreaming) BIDI_STREAMING_SMD else SERVER_STREAMING_SMD + } else { + if (method.isClientStreaming) CLIENT_STREAMING_SMD else UNARY_SMD + } - val serverMethodDef = - CodeBlock.of( - """ + val serverMethodDef = + CodeBlock.of( + """ %M( context = this.context, descriptor = %L, implementation = ::%N ) - """.trimIndent(), - smdFactory, - method.descriptorCode, - methodSpec - ) - - MethodImplStub(methodSpec, serverMethodDef) - } + """ + .trimIndent(), + smdFactory, + method.descriptorCode, + methodSpec + ) - private fun stubKDoc( - method: MethodDescriptor, - requestParam: ParameterSpec - ): CodeBlock { - val kDocBindings = mapOf( - "requestParam" to requestParam, - "methodName" to method.fullName, - "flow" to FLOW, - "status" to Status::class, - "statusException" to StatusException::class, - "cancellationException" to CancellationException::class, - "illegalStateException" to IllegalStateException::class - ) + MethodImplStub(methodSpec, serverMethodDef) + } + private fun stubKDoc(method: MethodDescriptor, requestParam: ParameterSpec): CodeBlock { + val kDocBindings = + mapOf( + "requestParam" to requestParam, + "methodName" to method.fullName, + "flow" to FLOW, + "status" to Status::class, + "statusException" to StatusException::class, + "cancellationException" to CancellationException::class, + "illegalStateException" to IllegalStateException::class + ) val kDocSections = mutableListOf() @@ -230,7 +229,8 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato [%cancellationException:T], the RPC will fail with status `Status.CANCELLED`. If creating or collecting the returned flow fails for any other reason, the RPC will fail with `Status.UNKNOWN` with the exception as a cause. - """.trimIndent() + """ + .trimIndent() ) } else { kDocSections.add("Returns the response to an RPC for %methodName:L.") @@ -240,7 +240,8 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato [%status:T]. If this method fails with a [%cancellationException:T], the RPC will fail with status `Status.CANCELLED`. If this method fails for any other reason, the RPC will fail with `Status.UNKNOWN` with the exception as a cause. - """.trimIndent() + """ + .trimIndent() ) } @@ -250,19 +251,18 @@ class GrpcCoroutineServerGenerator(config: GeneratorConfig): ServiceCodeGenerato @param %requestParam:N A [%flow:T] of requests from the client. This flow can be collected only once and throws [%illegalStateException:T] on attempts to collect it more than once. - """.trimIndent() + """ + .trimIndent() ) } else { kDocSections.add( """ @param %requestParam:N The request from the client. - """.trimIndent() + """ + .trimIndent() ) } - return CodeBlock - .builder() - .addNamed(kDocSections.joinToString("\n\n"), kDocBindings) - .build() + return CodeBlock.builder().addNamed(kDocSections.joinToString("\n\n"), kDocBindings).build() } -} \ No newline at end of file +} diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/ProtoFileCodeGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/ProtoFileCodeGenerator.kt index 9c6d291f..6b18bd35 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/ProtoFileCodeGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/ProtoFileCodeGenerator.kt @@ -38,40 +38,43 @@ class ProtoFileCodeGenerator( private val generators = generators.map { it(config) } - fun generateCodeForFile(fileDescriptor: FileDescriptor): FileSpec? = with(config) { - val outerTypeName = fileDescriptor.outerClassSimpleName.withSuffix(topLevelSuffix) + fun generateCodeForFile(fileDescriptor: FileDescriptor): FileSpec? = + with(config) { + val outerTypeName = fileDescriptor.outerClassSimpleName.withSuffix(topLevelSuffix) - var wroteAnything = false - val fileBuilder = FileSpec.builder(javaPackage(fileDescriptor), outerTypeName) + var wroteAnything = false + val fileBuilder = FileSpec.builder(javaPackage(fileDescriptor), outerTypeName) - for (service in fileDescriptor.services) { - val serviceDecls = declarations { - for (generator in generators) { - merge(generator.generate(service)) + for (service in fileDescriptor.services) { + val serviceDecls = declarations { + for (generator in generators) { + merge(generator.generate(service)) + } } - } - if (serviceDecls.hasEnclosingScopeDeclarations) { - wroteAnything = true - val serviceObjectBuilder = - TypeSpec - .objectBuilder(service.serviceName.toClassSimpleName().withSuffix(topLevelSuffix)) - .addKdoc( - """ + if (serviceDecls.hasEnclosingScopeDeclarations) { + wroteAnything = true + val serviceObjectBuilder = + TypeSpec.objectBuilder( + service.serviceName.toClassSimpleName().withSuffix(topLevelSuffix) + ) + .addKdoc( + """ Holder for Kotlin coroutine-based client and server APIs for %L. - """.trimIndent(), - service.fullName - ) - serviceDecls.writeToEnclosingType(serviceObjectBuilder) - fileBuilder.addType(serviceObjectBuilder.build()) - } + """ + .trimIndent(), + service.fullName + ) + serviceDecls.writeToEnclosingType(serviceObjectBuilder) + fileBuilder.addType(serviceObjectBuilder.build()) + } - if (serviceDecls.hasTopLevelDeclarations) { - wroteAnything = true - serviceDecls.writeOnlyTopLevel(fileBuilder) + if (serviceDecls.hasTopLevelDeclarations) { + wroteAnything = true + serviceDecls.writeOnlyTopLevel(fileBuilder) + } } - } - return if (wroteAnything) fileBuilder.build() else null - } + return if (wroteAnything) fileBuilder.build() else null + } } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/ServiceCodeGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/ServiceCodeGenerator.kt index fee21396..52a11691 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/ServiceCodeGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/ServiceCodeGenerator.kt @@ -39,7 +39,7 @@ abstract class ServiceCodeGenerator(protected val config: GeneratorConfig) { operator fun plus(other: ServiceCodeGenerator): ServiceCodeGenerator { val me = this - return object: ServiceCodeGenerator(config) { + return object : ServiceCodeGenerator(config) { override fun generate(service: ServiceDescriptor): Declarations = declarations { merge(me.generate(service)) merge(other.generate(service)) @@ -49,11 +49,11 @@ abstract class ServiceCodeGenerator(protected val config: GeneratorConfig) { /** Gets the fully qualified name of the Java class generated by gRPC. */ protected val ServiceDescriptor.grpcClass: ClassName - get() = with(config) { - javaPackage(file).nestedClass( - serviceName.toClassSimpleName().withSuffix(GRPC_CLASS_NAME_SUFFIX) - ) - } + get() = + with(config) { + javaPackage(file) + .nestedClass(serviceName.toClassSimpleName().withSuffix(GRPC_CLASS_NAME_SUFFIX)) + } /** Gets the name of the function that gets the [io.grpc.ServiceDescriptor]. */ protected val ServiceDescriptor.grpcDescriptor: MemberName @@ -61,9 +61,10 @@ abstract class ServiceCodeGenerator(protected val config: GeneratorConfig) { /** Gets the name of the function that gets the [io.grpc.MethodDescriptor]. */ protected val MethodDescriptor.descriptorCode: CodeBlock - get() = CodeBlock.of( - "%T.%L()", - service.grpcClass, - methodName.toMemberSimpleName().withPrefix("get").withSuffix("Method") - ) + get() = + CodeBlock.of( + "%T.%L()", + service.grpcClass, + methodName.toMemberSimpleName().withPrefix("get").withSuffix("Method") + ) } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/ServiceNameGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/ServiceNameGenerator.kt index 2f6f53cd..32ea7e97 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/ServiceNameGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/ServiceNameGenerator.kt @@ -17,4 +17,4 @@ class ServiceNameGenerator(config: GeneratorConfig) : ServiceCodeGenerator(confi ) } } -} \ No newline at end of file +} diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/TopLevelConstantsGenerator.kt b/compiler/src/main/java/io/grpc/kotlin/generator/TopLevelConstantsGenerator.kt index d46aa0a8..daa00165 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/TopLevelConstantsGenerator.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/TopLevelConstantsGenerator.kt @@ -13,10 +13,8 @@ import io.grpc.kotlin.generator.protoc.builder import io.grpc.kotlin.generator.protoc.declarations import io.grpc.kotlin.generator.protoc.methodName -/** - * Generates top-level properties for the service descriptor and method descriptors. - */ -class TopLevelConstantsGenerator(config: GeneratorConfig): ServiceCodeGenerator(config) { +/** Generates top-level properties for the service descriptor and method descriptors. */ +class TopLevelConstantsGenerator(config: GeneratorConfig) : ServiceCodeGenerator(config) { override fun generate(service: Descriptors.ServiceDescriptor): Declarations = declarations { addProperty( PropertySpec.builder("serviceDescriptor", ServiceDescriptor::class) @@ -32,23 +30,18 @@ class TopLevelConstantsGenerator(config: GeneratorConfig): ServiceCodeGenerator( with(config) { for (method in service.methods) { addProperty( - PropertySpec - .builder( + PropertySpec.builder( method.methodName.toMemberSimpleName().withSuffix("Method"), - MethodDescriptor::class.asTypeName().parameterizedBy( - method.inputType.messageClass(), - method.outputType.messageClass() - ) + MethodDescriptor::class.asTypeName() + .parameterizedBy(method.inputType.messageClass(), method.outputType.messageClass()) ) .getter( FunSpec.getterBuilder() .addAnnotation(JvmStatic::class) - .addStatement("return %T.%L()", + .addStatement( + "return %T.%L()", service.grpcClass, - method.methodName - .toMemberSimpleName() - .withPrefix("get") - .withSuffix("Method") + method.methodName.toMemberSimpleName().withPrefix("get").withSuffix("Method") ) .build() ) @@ -57,4 +50,4 @@ class TopLevelConstantsGenerator(config: GeneratorConfig): ServiceCodeGenerator( } } } -} \ No newline at end of file +} diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/AbstractGeneratorRunner.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/AbstractGeneratorRunner.kt index 3795ac4d..866ba94e 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/AbstractGeneratorRunner.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/AbstractGeneratorRunner.kt @@ -35,26 +35,27 @@ abstract class AbstractGeneratorRunner { @VisibleForTesting fun mainAsProtocPlugin(input: InputStream, output: OutputStream) { - val generatorRequest = try { - input.buffered().use { - PluginProtos.CodeGeneratorRequest.parseFrom(it) - } - } catch (failure: Exception) { - throw IOException( - """ + val generatorRequest = + try { + input.buffered().use { PluginProtos.CodeGeneratorRequest.parseFrom(it) } + } catch (failure: Exception) { + throw IOException( + """ Attempted to run proto extension generator as protoc plugin, but could not read CodeGeneratorRequest. - """.trimIndent(), - failure - ) - } + """ + .trimIndent(), + failure + ) + } output.buffered().use { CodeGenerators.codeGeneratorResponse { - val descriptorMap = CodeGenerators.descriptorMap(generatorRequest.protoFileList) - generatorRequest.filesToGenerate - .map(descriptorMap::getValue) // compiled descriptors to generate code for - .flatMap(::generateCodeForFile) // generated extensions - }.writeTo(it) + val descriptorMap = CodeGenerators.descriptorMap(generatorRequest.protoFileList) + generatorRequest.filesToGenerate + .map(descriptorMap::getValue) // compiled descriptors to generate code for + .flatMap(::generateCodeForFile) // generated extensions + } + .writeTo(it) } } @@ -68,9 +69,10 @@ abstract class AbstractGeneratorRunner { val fileNameToDescriptorSet = inTransitiveClosure.associateWith { readFileDescriptorSet(fs.getPath(it)) } - val descriptorMap = CodeGenerators.descriptorMapFromUnsorted( - fileNameToDescriptorSet.values.flatMap { it.fileList } - ) + val descriptorMap = + CodeGenerators.descriptorMapFromUnsorted( + fileNameToDescriptorSet.values.flatMap { it.fileList } + ) toGenerateExtensionsFor .asSequence() @@ -92,4 +94,4 @@ abstract class AbstractGeneratorRunner { private fun readFileDescriptorSet(path: Path): FileDescriptorSet = Files.newInputStream(path).buffered().use { FileDescriptorSet.parseFrom(it) } -} \ No newline at end of file +} diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/BUILD.bazel b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/BUILD.bazel index c47d153a..4eaba52f 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/BUILD.bazel +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/BUILD.bazel @@ -11,8 +11,8 @@ kt_jvm_library( srcs = glob(["*.kt"]), deps = [ "//compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/graph", - "@protobuf//:protobuf_java", "@grpc_kotlin_maven//:com_google_guava_guava", "@grpc_kotlin_maven//:com_squareup_kotlinpoet", + "@protobuf//:protobuf_java", ], ) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ClassSimpleName.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ClassSimpleName.kt index bb2f56e0..e4c5ca2f 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ClassSimpleName.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ClassSimpleName.kt @@ -21,8 +21,8 @@ import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.TypeSpec /** - * Represents a simple (unqualified, unnested) name of a Kotlin/Java class, interface, or enum, - * in UpperCamelCase. + * Represents a simple (unqualified, unnested) name of a Kotlin/Java class, interface, or enum, in + * UpperCamelCase. */ data class ClassSimpleName(val name: String) : CharSequence by name { /** Returns this class name with a suffix. */ @@ -46,14 +46,12 @@ data class ClassSimpleName(val name: String) : CharSequence by name { } /** Create a builder for a class with the specified simple name. */ -fun TypeSpec.Companion.classBuilder( - simpleName: ClassSimpleName -): TypeSpec.Builder = classBuilder(simpleName.name) +fun TypeSpec.Companion.classBuilder(simpleName: ClassSimpleName): TypeSpec.Builder = + classBuilder(simpleName.name) /** Create a builder for an object with the specified simple name. */ -fun TypeSpec.Companion.objectBuilder( - simpleName: ClassSimpleName -): TypeSpec.Builder = objectBuilder(simpleName.name) +fun TypeSpec.Companion.objectBuilder(simpleName: ClassSimpleName): TypeSpec.Builder = + objectBuilder(simpleName.name) /** Given a fully qualified class name, get the fully qualified name of a nested class inside it. */ fun ClassName.nestedClass(classSimpleName: ClassSimpleName): ClassName = diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/CodeGenerators.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/CodeGenerators.kt index cebc5953..04f44971 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/CodeGenerators.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/CodeGenerators.kt @@ -47,10 +47,7 @@ internal object CodeGenerators { val byFileName = protoFileList.associateBy { it.fileName } val depGraph = - GraphBuilder - .directed() - .expectedNodeCount(protoFileList.size) - .build() + GraphBuilder.directed().expectedNodeCount(protoFileList.size).build() byFileName.keys.forEach { depGraph.addNode(it) } @@ -61,21 +58,27 @@ internal object CodeGenerators { } return descriptorMap( - TopologicalSortGraph.topologicalOrdering(depGraph) - .map { byFileName.getValue(it) } + TopologicalSortGraph.topologicalOrdering(depGraph).map { byFileName.getValue(it) } ) } fun toCodeGeneratorResponseFile(fileSpec: FileSpec): PluginProtos.CodeGeneratorResponse.File = - PluginProtos.CodeGeneratorResponse.File.newBuilder().also { - it.name = fileSpec.path.toString() - it.content = fileSpec.toString() - }.build() + PluginProtos.CodeGeneratorResponse.File.newBuilder() + .also { + it.name = fileSpec.path.toString() + it.content = fileSpec.toString() + } + .build() - inline fun codeGeneratorResponse(build: () -> List): PluginProtos.CodeGeneratorResponse { + inline fun codeGeneratorResponse( + build: () -> List + ): PluginProtos.CodeGeneratorResponse { val builder = PluginProtos.CodeGeneratorResponse.newBuilder() try { - builder.setSupportedFeatures(PluginProtos.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL_VALUE.toLong()) + builder + .setSupportedFeatures( + PluginProtos.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL_VALUE.toLong() + ) .addAllFile(build().map { toCodeGeneratorResponseFile(it) }) } catch (failure: Exception) { builder.error = Throwables.getStackTraceAsString(failure) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ConstantName.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ConstantName.kt index 2de3e7f0..cbe5815d 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ConstantName.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ConstantName.kt @@ -25,5 +25,4 @@ data class ConstantName(val name: String) : CharSequence by name { } /** Returns the fully qualified name of this constant, as a member of the specified class. */ -fun ClassName.member(constantName: ConstantName): MemberName = - MemberName(this, constantName.name) +fun ClassName.member(constantName: ConstantName): MemberName = MemberName(this, constantName.name) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Declarations.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Declarations.kt index 47290780..3558a223 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Declarations.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Declarations.kt @@ -29,7 +29,8 @@ inline fun declarations(callback: Declarations.Builder.() -> Unit): Declarations * An immutable set of declarations, some of which may be always at the top level, and some of which * may be in some containing class or namespace (which may be a type or a file). */ -class Declarations private constructor( +class Declarations +private constructor( private val atTopLevel: List Unit>, private val atEnclosing: List ) { diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/DescriptorUtil.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/DescriptorUtil.kt index c5d33580..add1cb52 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/DescriptorUtil.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/DescriptorUtil.kt @@ -17,8 +17,6 @@ package io.grpc.kotlin.generator.protoc import com.google.common.base.Ascii -import com.google.common.base.CaseFormat.LOWER_UNDERSCORE -import com.google.common.base.CaseFormat.UPPER_CAMEL import com.google.protobuf.DescriptorProtos.DescriptorProto import com.google.protobuf.DescriptorProtos.EnumDescriptorProto import com.google.protobuf.DescriptorProtos.FileDescriptorProto @@ -38,29 +36,35 @@ import io.grpc.kotlin.generator.protoc.TypeNames.BYTE_STRING import io.grpc.kotlin.generator.protoc.TypeNames.STRING private val JavaType.scalarType: TypeName - get() = when (this) { - JavaType.BOOLEAN -> BOOLEAN - JavaType.INT -> INT - JavaType.LONG -> LONG - JavaType.FLOAT -> FLOAT - JavaType.DOUBLE -> DOUBLE - JavaType.STRING -> STRING - JavaType.BYTE_STRING -> BYTE_STRING - else -> throw IllegalArgumentException("Not a scalar type") - } + get() = + when (this) { + JavaType.BOOLEAN -> BOOLEAN + JavaType.INT -> INT + JavaType.LONG -> LONG + JavaType.FLOAT -> FLOAT + JavaType.DOUBLE -> DOUBLE + JavaType.STRING -> STRING + JavaType.BYTE_STRING -> BYTE_STRING + else -> throw IllegalArgumentException("Not a scalar type") + } /** - * Returns the fully qualified Kotlin type representing values of this field, assuming that it - * is a scalar field (defined at https://developers.google.com/protocol-buffers/docs/proto3#scalar). + * Returns the fully qualified Kotlin type representing values of this field, assuming that it is a + * scalar field (defined at https://developers.google.com/protocol-buffers/docs/proto3#scalar). */ val FieldDescriptor.scalarType: TypeName get() = javaType.scalarType private val JavaType.isJavaPrimitive: Boolean - get() = when (this) { - JavaType.BOOLEAN, JavaType.INT, JavaType.LONG, JavaType.FLOAT, JavaType.DOUBLE -> true - else -> false - } + get() = + when (this) { + JavaType.BOOLEAN, + JavaType.INT, + JavaType.LONG, + JavaType.FLOAT, + JavaType.DOUBLE -> true + else -> false + } /** True if the Java type representing the contents of this field is a primitive. */ val FieldDescriptor.isJavaPrimitive: Boolean @@ -89,8 +93,8 @@ val Descriptor.messageClassSimpleName: ClassSimpleName get() = toProto().messageClassSimpleName /** - * Returns the simple name of the Java class that represents a message type, given its descriptor - * in proto form. + * Returns the simple name of the Java class that represents a message type, given its descriptor in + * proto form. */ val DescriptorProto.messageClassSimpleName: ClassSimpleName get() = simpleName.toClassSimpleName() @@ -100,8 +104,8 @@ val EnumDescriptor.enumClassSimpleName: ClassSimpleName get() = toProto().enumClassSimpleName /** - * Returns the name of the Java class representing the proto enum type, given its descriptor - * in proto form. + * Returns the name of the Java class representing the proto enum type, given its descriptor in + * proto form. */ val EnumDescriptorProto.enumClassSimpleName: ClassSimpleName get() = simpleName.toClassSimpleName() @@ -110,19 +114,21 @@ val FileDescriptor.outerClassSimpleName: ClassSimpleName get() = toProto().outerClassSimpleName private val FileDescriptorProto.explicitOuterClassSimpleName: ClassSimpleName? - get() = when (val name = options.javaOuterClassname) { - "" -> null - else -> ClassSimpleName(name) - } + get() = + when (val name = options.javaOuterClassname) { + "" -> null + else -> ClassSimpleName(name) + } /** The simple name of the outer class of a proto file. */ val FileDescriptorProto.outerClassSimpleName: ClassSimpleName get() { - explicitOuterClassSimpleName?.let { return it } + explicitOuterClassSimpleName?.let { + return it + } - val defaultOuterClassName = ClassSimpleName( - fileName.name.replace("-", "_").underscoresToCamel() - ) + val defaultOuterClassName = + ClassSimpleName(fileName.name.replace("-", "_").underscoresToCamel()) val foundDuplicate = enumTypeList.any { it.enumClassSimpleName == defaultOuterClassName } || @@ -141,14 +147,10 @@ private fun String.underscoresToCamel(): String { var capNextLetter = true for ((i, ch) in this.withIndex()) { if (ch in 'a'..'z') { - builder.append( - if (capNextLetter) Ascii.toUpperCase(ch) else ch - ) + builder.append(if (capNextLetter) Ascii.toUpperCase(ch) else ch) capNextLetter = false } else if (ch in 'A'..'Z') { - builder.append( - if (i == 0 && !capNextLetter) Ascii.toLowerCase(ch) else ch - ) + builder.append(if (i == 0 && !capNextLetter) Ascii.toLowerCase(ch) else ch) capNextLetter = false } else if (ch in '0'..'9') { builder.append(ch) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/GeneratorConfig.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/GeneratorConfig.kt index 48769ea9..9d0b556c 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/GeneratorConfig.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/GeneratorConfig.kt @@ -40,12 +40,10 @@ data class GeneratorConfig( FunSpec.builder(name).addModifiers(*inlineModifiers) /** Generates a [FunSpec.Builder] for a getter with appropriate modifiers. */ - fun getterBuilder(): FunSpec.Builder = - FunSpec.getterBuilder().addModifiers(*inlineModifiers) + fun getterBuilder(): FunSpec.Builder = FunSpec.getterBuilder().addModifiers(*inlineModifiers) /** Generates a [FunSpec.Builder] for a setter with appropriate modifiers. */ - fun setterBuilder(): FunSpec.Builder = - FunSpec.setterBuilder().addModifiers(*inlineModifiers) + fun setterBuilder(): FunSpec.Builder = FunSpec.setterBuilder().addModifiers(*inlineModifiers) /** Returns the package associated with Java APIs for protos in the specified file. */ fun javaPackage(fileDescriptor: FileDescriptor): PackageScope = diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicy.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicy.kt index 21adde86..e9574a2c 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicy.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicy.kt @@ -18,9 +18,7 @@ package io.grpc.kotlin.generator.protoc import com.google.protobuf.DescriptorProtos.FileDescriptorProto -/** - * Describes a policy for converting proto message types to Java classes in the correct package. - */ +/** Describes a policy for converting proto message types to Java classes in the correct package. */ enum class JavaPackagePolicy { OPEN_SOURCE { override fun javaPackage(fileProto: FileDescriptorProto): PackageScope { diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/MemberSimpleName.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/MemberSimpleName.kt index a22b3336..1a922ddf 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/MemberSimpleName.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/MemberSimpleName.kt @@ -40,6 +40,7 @@ data class MemberSimpleName(val name: String) : CharSequence by name { } fun withSuffix(suffix: String): MemberSimpleName = MemberSimpleName(name + suffix) + fun withPrefix(prefix: String): MemberSimpleName = MemberSimpleName(prefix + CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, name)) @@ -99,9 +100,8 @@ fun PropertySpec.Companion.builder( ): PropertySpec.Builder = builder(simpleName.name, type, *modifiers) /** Create a builder for a function with the specified simple name. */ -fun FunSpec.Companion.builder( - simpleName: MemberSimpleName -): FunSpec.Builder = builder(simpleName.name) +fun FunSpec.Companion.builder(simpleName: MemberSimpleName): FunSpec.Builder = + builder(simpleName.name) /** Create a fully qualified [MemberName] in this class with the specified name. */ fun ClassName.member(memberSimpleName: MemberSimpleName): MemberName = member(memberSimpleName.name) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoFieldName.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoFieldName.kt index a0e3286c..60f7b099 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoFieldName.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoFieldName.kt @@ -26,14 +26,16 @@ import com.google.protobuf.Descriptors.OneofDescriptor data class ProtoFieldName(private val name: String) : CharSequence by name { companion object { // based on compiler/java/internal/helpers.cc - private val SPECIAL_CASES = setOf( - ProtoFieldName("class"), - ProtoFieldName("cached_size"), - ProtoFieldName("serialized_size") - ) + private val SPECIAL_CASES = + setOf( + ProtoFieldName("class"), + ProtoFieldName("cached_size"), + ProtoFieldName("serialized_size") + ) private val LETTER = CharMatcher.inRange('A', 'Z').or(CharMatcher.inRange('a', 'z')) private val DIGIT = CharMatcher.inRange('0', '9') + private operator fun CharMatcher.contains(c: Char) = matches(c) } @@ -45,13 +47,13 @@ data class ProtoFieldName(private val name: String) : CharSequence by name { val finalCamelCaseName = StringBuilder(name.length) for (word in nameComponents) { if (finalCamelCaseName.isEmpty()) { - finalCamelCaseName.append( - Ascii.toLowerCase(word[0]) - ).append(upperCaseAfterNumeric(word), 1, word.length) + finalCamelCaseName + .append(Ascii.toLowerCase(word[0])) + .append(upperCaseAfterNumeric(word), 1, word.length) } else { - finalCamelCaseName.append( - Ascii.toUpperCase(word[0]) - ).append(upperCaseAfterNumeric(word), 1, word.length) + finalCamelCaseName + .append(Ascii.toUpperCase(word[0])) + .append(upperCaseAfterNumeric(word), 1, word.length) } } return MemberSimpleName(finalCamelCaseName.toString()) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoMethodName.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoMethodName.kt index 0ed4ee1b..bf2757d7 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoMethodName.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoMethodName.kt @@ -35,9 +35,7 @@ data class ProtoMethodName(val name: String) : CharSequence by name { } private fun handleSpecialCharacters(name: MemberSimpleName): MemberSimpleName { - return name.split("_") - .map(::MemberSimpleName) - .reduce { acc, simpleName -> acc + simpleName } + return name.split("_").map(::MemberSimpleName).reduce { acc, simpleName -> acc + simpleName } } override fun toString() = name diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Scope.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Scope.kt index 89174a6e..16b73feb 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Scope.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/Scope.kt @@ -20,7 +20,7 @@ import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec /** - * Describes a location classes can be nested in, such as a package or another class. This can + * Describes a location classes can be nested in, such as a package or another class. This can * convert a [ClassSimpleName] to a fully qualified [ClassName]. */ sealed class Scope { @@ -29,25 +29,17 @@ sealed class Scope { fun nestedScope(simpleName: ClassSimpleName): Scope = ClassScope(nestedClass(simpleName)) } -/** - * The unqualified, top-level scope. - */ +/** The unqualified, top-level scope. */ object UnqualifiedScope : Scope() { - override fun nestedClass(simpleName: ClassSimpleName): ClassName = - ClassName("", simpleName.name) + override fun nestedClass(simpleName: ClassSimpleName): ClassName = ClassName("", simpleName.name) } -/** - * The scope of a package. - */ +/** The scope of a package. */ data class PackageScope(val pkg: String) : Scope() { - override fun nestedClass(simpleName: ClassSimpleName): ClassName = - ClassName(pkg, simpleName.name) + override fun nestedClass(simpleName: ClassSimpleName): ClassName = ClassName(pkg, simpleName.name) } -/** - * The scope of a fully qualified class. - */ +/** The scope of a fully qualified class. */ class ClassScope(private val className: ClassName) : Scope() { override fun nestedClass(simpleName: ClassSimpleName): ClassName = className.nestedClass(simpleName) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubject.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubject.kt index eaf02581..1110d517 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubject.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubject.kt @@ -30,31 +30,23 @@ fun assertThat(declarations: Declarations): DeclarationsSubject = assertAbout(declarationsSubjectFactory).that(declarations) /** A Truth subject for [Declarations]. */ -class DeclarationsSubject( - failureMetadata: FailureMetadata, - private val actual: Declarations? -) : Subject(failureMetadata, actual) { +class DeclarationsSubject(failureMetadata: FailureMetadata, private val actual: Declarations?) : + Subject(failureMetadata, actual) { fun generatesTopLevel(indentedCode: String) { val actualCode = - FileSpec.builder("", "MyDeclarations.kt") - .apply { actual?.writeOnlyTopLevel(this) } - .build() + FileSpec.builder("", "MyDeclarations.kt").apply { actual?.writeOnlyTopLevel(this) }.build() check("topLevel").about(fileSpecs).that(actualCode).generates(indentedCode) } fun generatesEnclosed(indentedCode: String) { val actualCode = - FileSpec.builder("", "MyDeclarations.kt") - .apply { actual?.writeToEnclosingFile(this) } - .build() + FileSpec.builder("", "MyDeclarations.kt").apply { actual?.writeToEnclosingFile(this) }.build() check("enclosed").about(fileSpecs).that(actualCode).generates(indentedCode) } fun generatesNoTopLevelMembers() { val actualCode = - FileSpec.builder("", "MyDeclarations.kt") - .apply { actual?.writeOnlyTopLevel(this) } - .build() + FileSpec.builder("", "MyDeclarations.kt").apply { actual?.writeOnlyTopLevel(this) }.build() check("topLevel") .withMessage("top level declarations: %s", actualCode) .that(actual?.hasTopLevelDeclarations) @@ -63,9 +55,7 @@ class DeclarationsSubject( fun generatesNoEnclosedMembers() { val actualCode = - FileSpec.builder("", "MyDeclarations.kt") - .apply { actual?.writeToEnclosingFile(this) } - .build() + FileSpec.builder("", "MyDeclarations.kt").apply { actual?.writeToEnclosingFile(this) }.build() check("enclosed") .withMessage("enclosed declarations: %s", actualCode) .that(actual?.hasEnclosingScopeDeclarations) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubject.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubject.kt index af65122b..ccc79487 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubject.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubject.kt @@ -27,10 +27,8 @@ val fileSpecs: Subject.Factory = Subject.Factory(::Fi fun assertThat(fileSpec: FileSpec): FileSpecSubject = assertAbout(fileSpecs).that(fileSpec) /** A Truth subject for [FileSpec]. */ -class FileSpecSubject( - failureMetadata: FailureMetadata, - private val actual: FileSpec? -) : Subject(failureMetadata, actual) { +class FileSpecSubject(failureMetadata: FailureMetadata, private val actual: FileSpec?) : + Subject(failureMetadata, actual) { fun generates(indentedCode: String) { val expectedCode = indentedCode.trimIndent() val actualCode = actual.toString().trim().lines().joinToString("\n") { it.trimEnd() } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubject.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubject.kt index 9018a5c8..57da7f4f 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubject.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubject.kt @@ -27,10 +27,8 @@ val funSpecs: Subject.Factory = Subject.Factory(::FunSp fun assertThat(funSpec: FunSpec): FunSpecSubject = Truth.assertAbout(funSpecs).that(funSpec) /** A Truth subject for [FunSpec]. */ -class FunSpecSubject( - failureMetadata: FailureMetadata, - private val actual: FunSpec? -) : Subject(failureMetadata, actual) { +class FunSpecSubject(failureMetadata: FailureMetadata, private val actual: FunSpec?) : + Subject(failureMetadata, actual) { fun generates(indentedCode: String) { val expectedCode = indentedCode.trimIndent() val actualCode = actual.toString().trim() diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubject.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubject.kt index 6988b73f..497d1e6a 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubject.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubject.kt @@ -27,10 +27,8 @@ val typeSpecs: Subject.Factory = Subject.Factory(::Ty fun assertThat(typeSpec: TypeSpec): TypeSpecSubject = assertAbout(typeSpecs).that(typeSpec) /** A Truth subject for [TypeSpec]. */ -class TypeSpecSubject( - failureMetadata: FailureMetadata, - private val actual: TypeSpec? -) : Subject(failureMetadata, actual) { +class TypeSpecSubject(failureMetadata: FailureMetadata, private val actual: TypeSpec?) : + Subject(failureMetadata, actual) { fun generates(indentedCode: String) { val expectedCode = indentedCode.trimIndent() val actualCode = actual.toString().trim() diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/graph/TopologicalSortGraph.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/graph/TopologicalSortGraph.kt index 7b047a24..380a0d52 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/graph/TopologicalSortGraph.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/graph/TopologicalSortGraph.kt @@ -23,13 +23,13 @@ import io.grpc.kotlin.generator.protoc.util.sort.TopologicalSort.sortLexicograph @Beta object TopologicalSortGraph { - fun topologicalOrdering(graph: Graph): List { - checkArgument(graph.isDirected, "Cannot get topological ordering of an undirected graph.") - val partialOrdering: PartialOrdering = object : PartialOrdering { - override fun getPredecessors(element: N): Set = element?.let { - graph.predecessors(it) - } ?: emptySet() - } - return sortLexicographicallyLeast(graph.nodes(), partialOrdering) - } + fun topologicalOrdering(graph: Graph): List { + checkArgument(graph.isDirected, "Cannot get topological ordering of an undirected graph.") + val partialOrdering: PartialOrdering = + object : PartialOrdering { + override fun getPredecessors(element: N): Set = + element?.let { graph.predecessors(it) } ?: emptySet() + } + return sortLexicographicallyLeast(graph.nodes(), partialOrdering) + } } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/BUILD.bazel b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/BUILD.bazel index bb5e52cb..3396114e 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/BUILD.bazel +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/BUILD.bazel @@ -12,4 +12,4 @@ kt_jvm_library( "PartialOrdering.kt", "TopologicalSort.kt", ], -) \ No newline at end of file +) diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/PartialOrdering.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/PartialOrdering.kt index 4b5eca32..0b305f4b 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/PartialOrdering.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/PartialOrdering.kt @@ -22,13 +22,12 @@ package io.grpc.kotlin.generator.protoc.util.sort * @author Okhtay Ilghami (okhtay@google.com) */ interface PartialOrdering { - /** - * Returns nodes that are considered "less than" `element` for purposes of a [ ]. Transitive predecessors do not need to be included. - * - * - * For example, if `getPredecessors(a)` includes `b` and `getPredecessors(b)` - * includes `c`, it is not necessary to include `c` in `getPredecessors(a)`. - * `c` is not a "direct" predecessor of `a`. - */ - fun getPredecessors(element: T): Set + /** + * Returns nodes that are considered "less than" `element` for purposes of a [ ]. Transitive + * predecessors do not need to be included. + * + * For example, if `getPredecessors(a)` includes `b` and `getPredecessors(b)` includes `c`, it is + * not necessary to include `c` in `getPredecessors(a)`. `c` is not a "direct" predecessor of `a`. + */ + fun getPredecessors(element: T): Set } diff --git a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/TopologicalSort.kt b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/TopologicalSort.kt index 9ed1cf40..bfa76451 100644 --- a/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/TopologicalSort.kt +++ b/compiler/src/main/java/io/grpc/kotlin/generator/protoc/util/sort/TopologicalSort.kt @@ -23,8 +23,8 @@ import java.util.PriorityQueue * described in TAOCP Section 2.2.3, with a little bit of extra state to provide a stable sort. The * constructed ordering is guaranteed to be deterministic. * - * - * The elements to be sorted should implement the standard [Object.hashCode] and [ ][Object.equals] methods. + * The elements to be sorted should implement the standard [Object.hashCode] and [ ][Object.equals] + * methods. */ object TopologicalSort { /** @@ -33,9 +33,9 @@ object TopologicalSort { * ordering" can be achieved by first lexicographically sorting the input list according to your * own criteria and then calling this method. * - * * A high-level sketch of toplogical sort from Wikipedia: * (http://en.wikipedia.org/wiki/Topological_sorting) + * * ``` * L ← Empty list that will contain the sorted elements * S ← Set of all nodes with no incoming edges @@ -52,15 +52,14 @@ object TopologicalSort { * return L (a topologically sorted order) * ``` * - * - * We extend the basic algorithm to traverse `S` in a particular order based on the - * original order of the elements to enforce a deterministic result (lexicographically based on - * the order of elements in the original input list). + * We extend the basic algorithm to traverse `S` in a particular order based on the original order + * of the elements to enforce a deterministic result (lexicographically based on the order of + * elements in the original input list). * * @param elements a mutable list of elements to be sorted. * @param order the partial order between elements. * @throws CyclicalGraphException if the graph is cyclical or any predecessor is not present in - * the input list. + * the input list. */ fun sortLexicographicallyLeast(elements: Collection, order: PartialOrdering): List { val internalElements = internalizeElements(elements, order) @@ -82,7 +81,9 @@ object TopologicalSort { if (sortedElements.size != elements.size) { val elementsInCycle = internalElements.filter { it.predecessorCount > 0 }.map { it.element } throw CyclicalGraphException( - "Cyclical graphs can not be topologically sorted.", elementsInCycle) + "Cyclical graphs can not be topologically sorted.", + elementsInCycle + ) } return sortedElements.map { it.element } } @@ -96,13 +97,13 @@ object TopologicalSort { * @return a list of [InternalElement]s initialized with dependency structure. */ private fun internalizeElements( - elements: Iterable, order: PartialOrdering + elements: Iterable, + order: PartialOrdering ): List> { val internalElements: MutableList> = mutableListOf() // Subtle: due to the potential for duplicates in elements, we need to map every element to a // list of the corresponding InternalElements. - val internalElementsByValue: MutableMap>> = - mutableMapOf() + val internalElementsByValue: MutableMap>> = mutableMapOf() for ((index, element) in elements.withIndex()) { val internalElement = InternalElement(element, index) internalElements.add(internalElement) @@ -136,9 +137,9 @@ object TopologicalSort { class CyclicalGraphException( message: String, // not parameterized because exceptions can't be parameterized /** - * A list of the elements that are part of the cycle, as well as elements that are - * greater than the elements in the cycle, according to the partial ordering. The elements in - * this list are not in a meaningful order. + * A list of the elements that are part of the cycle, as well as elements that are greater than + * the elements in the cycle, according to the partial ordering. The elements in this list are + * not in a meaningful order. */ val elementsInCycle: List<*> ) : RuntimeException(message) @@ -146,23 +147,19 @@ object TopologicalSort { /** * To bundle an element with a mutable structure of the dependency graph. * + * Each [InternalElement] counts how many predecessors it has left. Rather than keep a list of + * predecessors, we reverse the relation so that it's easy to navigate to the successors when an + * [InternalElement] is selected for sorting. * - * Each [InternalElement] counts how many predecessors it has left. Rather than keep a - * list of predecessors, we reverse the relation so that it's easy to navigate to the successors - * when an [InternalElement] is selected for sorting. - * - * - * This maintains a `originalIndex` to allow a "stable" sort based on the original - * position in the input list. + * This maintains a `originalIndex` to allow a "stable" sort based on the original position in the + * input list. */ - private data class InternalElement( - val element: T, - val originalIndex: Int - ) : Comparable> { + private data class InternalElement(val element: T, val originalIndex: Int) : + Comparable> { val successors: MutableList> = mutableListOf() var predecessorCount = 0 + override operator fun compareTo(other: InternalElement): Int = originalIndex.compareTo(other.originalIndex) - } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ClassSimpleNameTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ClassSimpleNameTest.kt index fcf20cb9..7d17d1b4 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ClassSimpleNameTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ClassSimpleNameTest.kt @@ -29,8 +29,7 @@ import org.junit.runners.JUnit4 class ClassSimpleNameTest { @Test fun withSuffix() { - assertThat(ClassSimpleName("FooBar").withSuffix("Baz")) - .isEqualTo(ClassSimpleName("FooBarBaz")) + assertThat(ClassSimpleName("FooBar").withSuffix("Baz")).isEqualTo(ClassSimpleName("FooBarBaz")) } @Test @@ -45,11 +44,9 @@ class ClassSimpleNameTest { fun asMemberWithPrefix() { val simpleName = ClassSimpleName("SimpleName") - assertThat(simpleName.asMemberWithPrefix("get")) - .isEqualTo(MemberSimpleName("getSimpleName")) + assertThat(simpleName.asMemberWithPrefix("get")).isEqualTo(MemberSimpleName("getSimpleName")) - assertThat(simpleName.asMemberWithPrefix("")) - .isEqualTo(MemberSimpleName("simpleName")) + assertThat(simpleName.asMemberWithPrefix("")).isEqualTo(MemberSimpleName("simpleName")) } @Test diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DeclarationsTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DeclarationsTest.kt index 359458a5..84a1f856 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DeclarationsTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DeclarationsTest.kt @@ -35,12 +35,7 @@ class DeclarationsTest { private val type = TypeSpec.objectBuilder("MyObject").build() private inline fun someFile(block: FileSpec.Builder.() -> Unit): String { - return FileSpec - .builder("com.foo.bar", "SomeFile.kt") - .apply(block) - .build() - .toString() - .trim() + return FileSpec.builder("com.foo.bar", "SomeFile.kt").apply(block).build().toString().trim() } @Test @@ -49,12 +44,9 @@ class DeclarationsTest { addTopLevelProperty(property) addFunction(function) } - assertThat( - someFile { - decls.writeAllAtTopLevel(this) - } - ).isEqualTo( - """ + assertThat(someFile { decls.writeAllAtTopLevel(this) }) + .isEqualTo( + """ package com.foo.bar import kotlin.Int @@ -63,8 +55,9 @@ class DeclarationsTest { public fun someFunction() { } - """.trimIndent() - ) + """ + .trimIndent() + ) } @Test @@ -73,27 +66,21 @@ class DeclarationsTest { addTopLevelProperty(property) addFunction(function) } - assertThat( - someFile { - decls.writeOnlyTopLevel(this) - } - ).isEqualTo( - """ + assertThat(someFile { decls.writeOnlyTopLevel(this) }) + .isEqualTo( + """ package com.foo.bar import kotlin.Int public val someProperty: Int - """.trimIndent() - ) + """ + .trimIndent() + ) } private fun FileSpec.Builder.writeToSomeEnclosingObject(decls: Declarations) { - addType( - TypeSpec.objectBuilder("SomeObject") - .apply { decls.writeToEnclosingType(this) } - .build() - ) + addType(TypeSpec.objectBuilder("SomeObject").apply { decls.writeToEnclosingType(this) }.build()) } @Test @@ -102,85 +89,59 @@ class DeclarationsTest { addTopLevelProperty(property) addFunction(function) } - assertThat( - someFile { - writeToSomeEnclosingObject(decls) - } - ).isEqualTo( - """ + assertThat(someFile { writeToSomeEnclosingObject(decls) }) + .isEqualTo( + """ package com.foo.bar public object SomeObject { public fun someFunction() { } } - """.trimIndent() - ) + """ + .trimIndent() + ) } @Test fun hasTopLevel() { - assertThat( - declarations { - addTopLevelProperty(property) - }.hasTopLevelDeclarations - ).isTrue() - assertThat( - declarations { - addProperty(property) - }.hasTopLevelDeclarations - ).isFalse() + assertThat(declarations { addTopLevelProperty(property) }.hasTopLevelDeclarations).isTrue() + assertThat(declarations { addProperty(property) }.hasTopLevelDeclarations).isFalse() } @Test fun hasEnclosingScopeDeclarations() { - assertThat( - declarations { - addTopLevelProperty(property) - }.hasEnclosingScopeDeclarations - ).isFalse() - assertThat( - declarations { - addProperty(property) - }.hasEnclosingScopeDeclarations - ).isTrue() + assertThat(declarations { addTopLevelProperty(property) }.hasEnclosingScopeDeclarations) + .isFalse() + assertThat(declarations { addProperty(property) }.hasEnclosingScopeDeclarations).isTrue() } @Test fun addTopLevelProperty() { - val decls = declarations { - addTopLevelProperty(property) - } - assertThat(decls).generatesTopLevel( - """ + val decls = declarations { addTopLevelProperty(property) } + assertThat(decls) + .generatesTopLevel(""" import kotlin.Int public val someProperty: Int - """ - ) + """) assertThat(decls).generatesNoEnclosedMembers() } @Test fun addTopLevelFunction() { - val decls = declarations { - addTopLevelFunction(function) - } + val decls = declarations { addTopLevelFunction(function) } assertThat(decls) - .generatesTopLevel( - """ + .generatesTopLevel(""" public fun someFunction() { } - """ - ) + """) assertThat(decls).generatesNoEnclosedMembers() } @Test fun addTopLevelType() { - val decls = declarations { - addTopLevelType(type) - } + val decls = declarations { addTopLevelType(type) } assertThat(decls).generatesTopLevel("public object MyObject") assertThat(decls).generatesNoEnclosedMembers() } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DescriptorUtilTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DescriptorUtilTest.kt index 7ea833be..62dacf92 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DescriptorUtilTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/DescriptorUtilTest.kt @@ -17,13 +17,13 @@ package io.grpc.kotlin.generator.protoc import com.google.common.truth.Truth.assertThat -import io.grpc.testing.ProtoFileWithHyphen import io.grpc.kotlin.generator.protoc.testproto.Example3 import io.grpc.kotlin.generator.protoc.testproto.Example3.ExampleEnum import io.grpc.kotlin.generator.protoc.testproto.Example3.ExampleMessage import io.grpc.kotlin.generator.protoc.testproto.HasNestedClassNameConflictOuterClass import io.grpc.kotlin.generator.protoc.testproto.HasOuterClassNameConflictOuterClass import io.grpc.kotlin.generator.protoc.testproto.MyExplicitOuterClassName +import io.grpc.testing.ProtoFileWithHyphen import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/GeneratorConfigTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/GeneratorConfigTest.kt index e359f711..5dc3bf61 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/GeneratorConfigTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/GeneratorConfigTest.kt @@ -28,11 +28,11 @@ import io.grpc.kotlin.generator.protoc.testproto.HasOuterClassNameConflictOuterC import io.grpc.kotlin.generator.protoc.testproto.MyExplicitOuterClassName import io.grpc.testing.ServiceNameConflictsWithFileOuterClass import io.grpc.testing.ServiceTOuterClass +import kotlin.reflect.KClass import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import testing.ImplicitJavaPackage -import kotlin.reflect.KClass /** Tests for [GeneratorConfig]. */ @RunWith(JUnit4::class) @@ -42,9 +42,7 @@ class GeneratorConfigTest { private fun generateFile(block: Declarations.Builder.() -> Unit): String { return FileSpec.builder("com.google", "FooBar.kt") - .apply { - declarations(block).writeAllAtTopLevel(this) - } + .apply { declarations(block).writeAllAtTopLevel(this) } .build() .toString() .trim() @@ -53,32 +51,28 @@ class GeneratorConfigTest { @Test fun funSpecBuilder() { with(GeneratorConfig(javaPackagePolicy, aggressiveInlining = false)) { - assertThat( - generateFile { - addFunction(funSpecBuilder(MemberSimpleName("fooBar")).build()) - } - ).isEqualTo( - """ + assertThat(generateFile { addFunction(funSpecBuilder(MemberSimpleName("fooBar")).build()) }) + .isEqualTo( + """ package com.google public fun fooBar() { } - """.trimIndent() - ) + """ + .trimIndent() + ) } with(GeneratorConfig(javaPackagePolicy, aggressiveInlining = true)) { - assertThat( - generateFile { - addFunction(funSpecBuilder(MemberSimpleName("fooBar")).build()) - } - ).isEqualTo( - """ + assertThat(generateFile { addFunction(funSpecBuilder(MemberSimpleName("fooBar")).build()) }) + .isEqualTo( + """ package com.google public inline fun fooBar() { } - """.trimIndent() - ) + """ + .trimIndent() + ) } } @@ -86,44 +80,48 @@ class GeneratorConfigTest { fun getterBuilder() { with(GeneratorConfig(javaPackagePolicy, aggressiveInlining = false)) { assertThat( - generateFile { - addProperty( - PropertySpec.builder("someProp", INT) - .getter(getterBuilder().addStatement("return 1").build()) - .build() - ) - } - ).isEqualTo( - """ + generateFile { + addProperty( + PropertySpec.builder("someProp", INT) + .getter(getterBuilder().addStatement("return 1").build()) + .build() + ) + } + ) + .isEqualTo( + """ package com.google import kotlin.Int public val someProp: Int get() = 1 - """.trimIndent() - ) + """ + .trimIndent() + ) } with(GeneratorConfig(javaPackagePolicy, aggressiveInlining = true)) { assertThat( - generateFile { - addProperty( - PropertySpec.builder("someProp", INT) - .getter(getterBuilder().addStatement("return 1").build()) - .build() - ) - } - ).isEqualTo( - """ + generateFile { + addProperty( + PropertySpec.builder("someProp", INT) + .getter(getterBuilder().addStatement("return 1").build()) + .build() + ) + } + ) + .isEqualTo( + """ package com.google import kotlin.Int public inline val someProp: Int get() = 1 - """.trimIndent() - ) + """ + .trimIndent() + ) } } @@ -132,16 +130,17 @@ class GeneratorConfigTest { val param = ParameterSpec.builder("newValue", INT).build() with(GeneratorConfig(javaPackagePolicy, aggressiveInlining = false)) { assertThat( - generateFile { - addProperty( - PropertySpec.builder("someProp", INT) - .mutable(true) - .setter(setterBuilder().addParameter(param).build()) - .build() - ) - } - ).isEqualTo( - """ + generateFile { + addProperty( + PropertySpec.builder("someProp", INT) + .mutable(true) + .setter(setterBuilder().addParameter(param).build()) + .build() + ) + } + ) + .isEqualTo( + """ package com.google import kotlin.Int @@ -149,22 +148,24 @@ class GeneratorConfigTest { public var someProp: Int set(newValue) { } - """.trimIndent() - ) + """ + .trimIndent() + ) } with(GeneratorConfig(javaPackagePolicy, aggressiveInlining = true)) { assertThat( - generateFile { - addProperty( - PropertySpec.builder("someProp", INT) - .mutable(true) - .setter(setterBuilder().addParameter(param).build()) - .build() - ) - } - ).isEqualTo( - """ + generateFile { + addProperty( + PropertySpec.builder("someProp", INT) + .mutable(true) + .setter(setterBuilder().addParameter(param).build()) + .build() + ) + } + ) + .isEqualTo( + """ package com.google import kotlin.Int @@ -172,8 +173,9 @@ class GeneratorConfigTest { public var someProp: Int inline set(newValue) { } - """.trimIndent() - ) + """ + .trimIndent() + ) } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicyTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicyTest.kt index d414c6be..1afd4468 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicyTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/JavaPackagePolicyTest.kt @@ -46,9 +46,7 @@ class JavaPackagePolicyTest { fun implicitJavaPackageGoogle() { with(JavaPackagePolicy.OPEN_SOURCE) { assertThat(javaPackage(ImplicitJavaPackage.getDescriptor().toProto())) - .isEqualTo( - PackageScope("testing") - ) + .isEqualTo(PackageScope("testing")) } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/MemberSimpleNameTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/MemberSimpleNameTest.kt index 803241c0..29050ccb 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/MemberSimpleNameTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/MemberSimpleNameTest.kt @@ -28,8 +28,7 @@ class MemberSimpleNameTest { fun withPrefix() { assertThat(MemberSimpleName("myField").withPrefix("get")) .isEqualTo(MemberSimpleName("getMyField")) - assertThat(MemberSimpleName("field").withPrefix("get")) - .isEqualTo(MemberSimpleName("getField")) + assertThat(MemberSimpleName("field").withPrefix("get")).isEqualTo(MemberSimpleName("getField")) } @Test diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/OptionalProto3FieldTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/OptionalProto3FieldTest.kt index b44f4ca5..612e97ae 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/OptionalProto3FieldTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/OptionalProto3FieldTest.kt @@ -21,5 +21,4 @@ class OptionalProto3FieldTest { assertThat(TestProto3Optional.OptionalProto3::class.java.getMethod("hasOptionalField")) .isNotNull() } - } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoEnumValueNameTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoEnumValueNameTest.kt index 2ae506d2..d01c9b49 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoEnumValueNameTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoEnumValueNameTest.kt @@ -27,8 +27,7 @@ import org.junit.runners.JUnit4 class ProtoEnumValueNameTest { @Test fun asConstantName() { - assertThat(ProtoEnumValueName("FOO_BAR").asConstantName) - .isEqualTo(ConstantName("FOO_BAR")) + assertThat(ProtoEnumValueName("FOO_BAR").asConstantName).isEqualTo(ConstantName("FOO_BAR")) } @Test diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFieldNameTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFieldNameTest.kt index 307bbb15..af8a302a 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFieldNameTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFieldNameTest.kt @@ -59,19 +59,14 @@ class ProtoFieldNameTest { fun fieldDescriptorName() { val fieldDescriptor = Example3.ExampleMessage.getDescriptor().findFieldByName("string_oneof_option") - assertThat(fieldDescriptor.fieldName) - .isEqualTo(ProtoFieldName("string_oneof_option")) + assertThat(fieldDescriptor.fieldName).isEqualTo(ProtoFieldName("string_oneof_option")) } @Test fun oneofName() { val oneofDescriptor = - Example3.ExampleMessage - .getDescriptor() - .oneofs - .find { it.name == "my_oneof" }!! - assertThat(oneofDescriptor.oneofName) - .isEqualTo(ProtoFieldName("my_oneof")) + Example3.ExampleMessage.getDescriptor().oneofs.find { it.name == "my_oneof" }!! + assertThat(oneofDescriptor.oneofName).isEqualTo(ProtoFieldName("my_oneof")) } @Test diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFileNameTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFileNameTest.kt index 765f0d29..fe845b50 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFileNameTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoFileNameTest.kt @@ -27,17 +27,11 @@ import org.junit.runners.JUnit4 class ProtoFileNameTest { @Test fun fileName() { - assertThat(Example3.getDescriptor().fileName) - .isEqualTo( - ProtoFileName( - "testing/example3.proto" - ) - ) + assertThat(Example3.getDescriptor().fileName).isEqualTo(ProtoFileName("testing/example3.proto")) } @Test fun name() { - assertThat(ProtoFileName("foo/bar/baz/quux.proto").name) - .isEqualTo("quux") + assertThat(ProtoFileName("foo/bar/baz/quux.proto").name).isEqualTo("quux") } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoMethodNameTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoMethodNameTest.kt index 8a9c18e7..4e4cde81 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoMethodNameTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ProtoMethodNameTest.kt @@ -9,20 +9,20 @@ import org.junit.runners.JUnit4 @RunWith(JUnit4::class) class ProtoMethodNameTest { @Test - fun toMemberSimpleNameWithSingleUnderscore(){ + fun toMemberSimpleNameWithSingleUnderscore() { assertThat(ProtoMethodName("say_hello").toMemberSimpleName()) .isEqualTo(MemberSimpleName("sayHello")) } @Test - fun toMemberSimpleNameWithMultipleUnderscores(){ + fun toMemberSimpleNameWithMultipleUnderscores() { assertThat(ProtoMethodName("say_hello_again").toMemberSimpleName()) .isEqualTo(MemberSimpleName("sayHelloAgain")) } @Test - fun toMemberSimpleNameWithRecommendedNamingStyle(){ + fun toMemberSimpleNameWithRecommendedNamingStyle() { assertThat(ProtoMethodName("SayHello").toMemberSimpleName()) .isEqualTo(MemberSimpleName("sayHello")) } -} \ No newline at end of file +} diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ScopeTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ScopeTest.kt index 53d8d776..ac6a3581 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ScopeTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/ScopeTest.kt @@ -39,11 +39,7 @@ class ScopeTest { @Test fun classScope() { - val name = - ClassScope(ClassName("com.foo.bar", "Baz")) - .nestedClass(ClassSimpleName("Quux")) - assertThat(name).isEqualTo( - ClassName("com.foo.bar", "Baz", "Quux") - ) + val name = ClassScope(ClassName("com.foo.bar", "Baz")).nestedClass(ClassSimpleName("Quux")) + assertThat(name).isEqualTo(ClassName("com.foo.bar", "Baz", "Quux")) } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubjectTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubjectTest.kt index 9c830cb8..bb3a797a 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubjectTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/DeclarationsSubjectTest.kt @@ -36,69 +36,51 @@ class DeclarationsSubjectTest { @Test fun generatesTopLevel() { - assertThat(decls).generatesTopLevel( - """ + assertThat(decls) + .generatesTopLevel(""" import kotlin.Int public val topLevel: Int - """ - ) + """) } @Test fun generatesEnclosed() { - assertThat(decls).generatesEnclosed( - """ + assertThat(decls) + .generatesEnclosed(""" import kotlin.Int public val enclosed: Int - """ - ) + """) } @Test fun generatesTopLevelFailure() { - expectFailureAbout( - declarationsSubjectFactory - ) { it.that(decls).generatesTopLevel("") } + expectFailureAbout(declarationsSubjectFactory) { it.that(decls).generatesTopLevel("") } } @Test fun generatesEnclosedFailure() { - expectFailureAbout( - declarationsSubjectFactory - ) { it.that(decls).generatesEnclosed("") } + expectFailureAbout(declarationsSubjectFactory) { it.that(decls).generatesEnclosed("") } } @Test fun generatesNoTopLevel() { - assertThat( - declarations { - addProperty(enclosedProperty) - } - ).generatesNoTopLevelMembers() + assertThat(declarations { addProperty(enclosedProperty) }).generatesNoTopLevelMembers() } @Test fun generatesNoEnclosed() { - assertThat( - declarations { - addTopLevelProperty(topLevelProperty) - } - ).generatesNoEnclosedMembers() + assertThat(declarations { addTopLevelProperty(topLevelProperty) }).generatesNoEnclosedMembers() } @Test fun generatesNoTopLevelFailure() { - expectFailureAbout( - declarationsSubjectFactory - ) { it.that(decls).generatesNoTopLevelMembers() } + expectFailureAbout(declarationsSubjectFactory) { it.that(decls).generatesNoTopLevelMembers() } } @Test fun generatesNoEnclosedFailure() { - expectFailureAbout( - declarationsSubjectFactory - ) { it.that(decls).generatesNoEnclosedMembers() } + expectFailureAbout(declarationsSubjectFactory) { it.that(decls).generatesNoEnclosedMembers() } } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubjectTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubjectTest.kt index 4d8a81ce..1c336fde 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubjectTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FileSpecSubjectTest.kt @@ -34,22 +34,16 @@ class FileSpecSubjectTest { @Test fun generates() { - assertThat(fileSpec).generates( - """ + assertThat(fileSpec).generates(""" import kotlin.Int public val bar: Int - """ - ) + """) } @Test fun generatesFailure() { - expectFailureAbout( - fileSpecs - ) { it.that(fileSpec).generates("") } - expectFailureAbout( - fileSpecs - ) { it.that(fileSpec).generates("object Foo") } + expectFailureAbout(fileSpecs) { it.that(fileSpec).generates("") } + expectFailureAbout(fileSpecs) { it.that(fileSpec).generates("object Foo") } } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubjectTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubjectTest.kt index 98ce95d6..08530aff 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubjectTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/FunSpecSubjectTest.kt @@ -17,7 +17,6 @@ package io.grpc.kotlin.generator.protoc.testing import com.google.common.truth.ExpectFailure -import com.google.common.truth.ExpectFailure.SimpleSubjectBuilderCallback import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.INT import com.squareup.kotlinpoet.ParameterSpec @@ -30,29 +29,23 @@ import org.junit.runners.JUnit4 @RunWith(JUnit4::class) class FunSpecSubjectTest { private val funSpec = - FunSpec - .builder("foo") + FunSpec.builder("foo") .addParameter(ParameterSpec.builder("bar", INT).build()) .returns(String::class.asTypeName()) .build() @Test fun generates() { - assertThat(funSpec).generates( - """ + assertThat(funSpec) + .generates(""" public fun foo(bar: kotlin.Int): kotlin.String { } - """ - ) + """) } @Test fun generatesFailure() { - ExpectFailure.expectFailureAbout( - funSpecs - ) { it.that(funSpec).generates("") } - ExpectFailure.expectFailureAbout( - funSpecs - ) { it.that(funSpec).generates("fun bar") } + ExpectFailure.expectFailureAbout(funSpecs) { it.that(funSpec).generates("") } + ExpectFailure.expectFailureAbout(funSpecs) { it.that(funSpec).generates("fun bar") } } } diff --git a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubjectTest.kt b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubjectTest.kt index c1e56dd4..76e2308a 100644 --- a/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubjectTest.kt +++ b/compiler/src/test/java/io/grpc/kotlin/generator/protoc/testing/TypeSpecSubjectTest.kt @@ -28,29 +28,21 @@ import org.junit.runners.JUnit4 @RunWith(JUnit4::class) class TypeSpecSubjectTest { private val typeSpec = - TypeSpec - .objectBuilder("Foo") - .addProperty(PropertySpec.builder("bar", INT).build()) - .build() + TypeSpec.objectBuilder("Foo").addProperty(PropertySpec.builder("bar", INT).build()).build() @Test fun generates() { - assertThat(typeSpec).generates( - """ + assertThat(typeSpec) + .generates(""" public object Foo { public val bar: kotlin.Int } - """ - ) + """) } @Test fun generatesFailure() { - expectFailureAbout( - typeSpecs - ) { it.that(typeSpec).generates("") } - expectFailureAbout( - typeSpecs - ) { it.that(typeSpec).generates("public object Foo") } + expectFailureAbout(typeSpecs) { it.that(typeSpec).generates("") } + expectFailureAbout(typeSpecs) { it.that(typeSpec).generates("public object Foo") } } } diff --git a/compiler/src/test/proto/helloworld/helloworld.proto b/compiler/src/test/proto/helloworld/helloworld.proto index f5eaf3dd..821abeeb 100644 --- a/compiler/src/test/proto/helloworld/helloworld.proto +++ b/compiler/src/test/proto/helloworld/helloworld.proto @@ -16,8 +16,8 @@ syntax = "proto3"; package helloworld; option java_multiple_files = true; -option java_package = "io.grpc.examples.helloworld"; option java_outer_classname = "HelloWorldProto"; +option java_package = "io.grpc.examples.helloworld"; // The greeting service definition. service Greeter { @@ -46,4 +46,4 @@ message MultiHelloRequest { // The response message containing the greetings message HelloReply { string message = 1; -} \ No newline at end of file +} diff --git a/compiler/src/test/proto/testing/BUILD.bazel b/compiler/src/test/proto/testing/BUILD.bazel index f9490b04..fa14f0a7 100644 --- a/compiler/src/test/proto/testing/BUILD.bazel +++ b/compiler/src/test/proto/testing/BUILD.bazel @@ -1,5 +1,5 @@ -load("@protobuf//bazel:proto_library.bzl", "proto_library") load("@protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@protobuf//bazel:proto_library.bzl", "proto_library") licenses(["notice"]) diff --git a/compiler/src/test/proto/testing/has_explicit_outer_class_name.proto b/compiler/src/test/proto/testing/has_explicit_outer_class_name.proto index e8976754..f9a6590b 100644 --- a/compiler/src/test/proto/testing/has_explicit_outer_class_name.proto +++ b/compiler/src/test/proto/testing/has_explicit_outer_class_name.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package testing; -option java_package = "io.grpc.kotlin.generator.protoc.testproto"; option java_outer_classname = "MyExplicitOuterClassName"; +option java_package = "io.grpc.kotlin.generator.protoc.testproto"; message HasExplicitOuterClassName {} diff --git a/compiler/src/test/proto/testing/rpc_name_contains_underscore.proto b/compiler/src/test/proto/testing/rpc_name_contains_underscore.proto index 196e091b..57981186 100644 --- a/compiler/src/test/proto/testing/rpc_name_contains_underscore.proto +++ b/compiler/src/test/proto/testing/rpc_name_contains_underscore.proto @@ -1,9 +1,10 @@ syntax = "proto3"; package io.grpc.testing.underscore; + option java_multiple_files = true; service NameContainsUnderscore { - rpc say_hello (HelloRequest) returns (HelloReply); - rpc say_hello_again (HelloRequest) returns (HelloReply); + rpc say_hello(HelloRequest) returns (HelloReply); + rpc say_hello_again(HelloRequest) returns (HelloReply); } message HelloRequest {} -message HelloReply {} \ No newline at end of file +message HelloReply {} diff --git a/compiler/src/test/proto/testing/service_name_conflicts_with_file.proto b/compiler/src/test/proto/testing/service_name_conflicts_with_file.proto index c3c06ce9..ada616c4 100644 --- a/compiler/src/test/proto/testing/service_name_conflicts_with_file.proto +++ b/compiler/src/test/proto/testing/service_name_conflicts_with_file.proto @@ -29,4 +29,4 @@ message HelloRequest { // The response message containing the greetings message HelloReply { string message = 1; -} \ No newline at end of file +} diff --git a/compiler/src/test/proto/testing/test_proto3_optional.proto b/compiler/src/test/proto/testing/test_proto3_optional.proto index 4328abf1..229def33 100644 --- a/compiler/src/test/proto/testing/test_proto3_optional.proto +++ b/compiler/src/test/proto/testing/test_proto3_optional.proto @@ -5,7 +5,5 @@ package io.grpc.testing; // --experimental_allow_proto3_optional or the filename (or a directory name) of the proto // file contains the string 'test_proto3_optional' message OptionalProto3 { - optional string optional_field = 1; - } diff --git a/examples/android/build.gradle.kts b/examples/android/build.gradle.kts index 3a9263a7..071d174b 100644 --- a/examples/android/build.gradle.kts +++ b/examples/android/build.gradle.kts @@ -1,49 +1,43 @@ plugins { - alias(libs.plugins.android.application) - alias(libs.plugins.kotlin.android) + alias(libs.plugins.android.application) + alias(libs.plugins.kotlin.android) } dependencies { - implementation(project(":stub-android")) + implementation(project(":stub-android")) - implementation(libs.androidx.activity.compose) - implementation(libs.androidx.compose.foundation.layout) - implementation(libs.androidx.compose.material) - implementation(libs.androidx.compose.runtime) - implementation(libs.androidx.compose.ui) + implementation(libs.androidx.activity.compose) + implementation(libs.androidx.compose.foundation.layout) + implementation(libs.androidx.compose.material) + implementation(libs.androidx.compose.runtime) + implementation(libs.androidx.compose.ui) - runtimeOnly(libs.grpc.okhttp) + runtimeOnly(libs.grpc.okhttp) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } android { - compileSdk = 34 - buildToolsVersion = "34.0.0" - namespace = "io.grpc.examples.helloworld" - - defaultConfig { - applicationId = "io.grpc.examples.hello" - minSdk = 26 - targetSdk = 34 - versionCode = 1 - versionName = "1.0" - - val serverUrl: String? by project - if (serverUrl != null) { - resValue("string", "server_url", serverUrl!!) - } else { - resValue("string", "server_url", "http://10.0.2.2:50051/") - } + compileSdk = 34 + buildToolsVersion = "34.0.0" + namespace = "io.grpc.examples.helloworld" + + defaultConfig { + applicationId = "io.grpc.examples.hello" + minSdk = 26 + targetSdk = 34 + versionCode = 1 + versionName = "1.0" + + val serverUrl: String? by project + if (serverUrl != null) { + resValue("string", "server_url", serverUrl!!) + } else { + resValue("string", "server_url", "http://10.0.2.2:50051/") } + } - buildFeatures { - compose = true - } + buildFeatures { compose = true } - composeOptions { - kotlinCompilerExtensionVersion = libs.androidx.compose.compiler.get().version - } + composeOptions { kotlinCompilerExtensionVersion = libs.androidx.compose.compiler.get().version } } diff --git a/examples/android/src/main/kotlin/io/grpc/examples/helloworld/MainActivity.kt b/examples/android/src/main/kotlin/io/grpc/examples/helloworld/MainActivity.kt index c111eb99..02fef1e9 100644 --- a/examples/android/src/main/kotlin/io/grpc/examples/helloworld/MainActivity.kt +++ b/examples/android/src/main/kotlin/io/grpc/examples/helloworld/MainActivity.kt @@ -24,84 +24,83 @@ import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.input.TextFieldValue import androidx.compose.ui.unit.dp import io.grpc.ManagedChannelBuilder +import java.io.Closeable import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.asExecutor import kotlinx.coroutines.launch -import java.io.Closeable class MainActivity : ComponentActivity() { - private val uri by lazy { Uri.parse(resources.getString(R.string.server_url)) } - private val greeterService by lazy { GreeterRCP(uri) } + private val uri by lazy { Uri.parse(resources.getString(R.string.server_url)) } + private val greeterService by lazy { GreeterRCP(uri) } - override fun onCreate(savedInstanceState: Bundle?) { - super.onCreate(savedInstanceState) + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) - setContent { - Surface(color = MaterialTheme.colors.background) { - Greeter(greeterService) - } - } - } + setContent { Surface(color = MaterialTheme.colors.background) { Greeter(greeterService) } } + } - override fun onDestroy() { - super.onDestroy() - greeterService.close() - } + override fun onDestroy() { + super.onDestroy() + greeterService.close() + } } class GreeterRCP(uri: Uri) : Closeable { - val responseState = mutableStateOf("") + val responseState = mutableStateOf("") - private val channel = let { - println("Connecting to ${uri.host}:${uri.port}") + private val channel = let { + println("Connecting to ${uri.host}:${uri.port}") - val builder = ManagedChannelBuilder.forAddress(uri.host, uri.port) - if (uri.scheme == "https") { - builder.useTransportSecurity() - } else { - builder.usePlaintext() - } - - builder.executor(Dispatchers.IO.asExecutor()).build() + val builder = ManagedChannelBuilder.forAddress(uri.host, uri.port) + if (uri.scheme == "https") { + builder.useTransportSecurity() + } else { + builder.usePlaintext() } - private val greeter = GreeterGrpcKt.GreeterCoroutineStub(channel) - - suspend fun sayHello(name: String) { - try { - val request = helloRequest { this.name = name } - val response = greeter.sayHello(request) - responseState.value = response.message - } catch (e: Exception) { - responseState.value = e.message ?: "Unknown Error" - e.printStackTrace() - } - } + builder.executor(Dispatchers.IO.asExecutor()).build() + } - override fun close() { - channel.shutdownNow() + private val greeter = GreeterGrpcKt.GreeterCoroutineStub(channel) + + suspend fun sayHello(name: String) { + try { + val request = helloRequest { this.name = name } + val response = greeter.sayHello(request) + responseState.value = response.message + } catch (e: Exception) { + responseState.value = e.message ?: "Unknown Error" + e.printStackTrace() } + } + + override fun close() { + channel.shutdownNow() + } } @Composable fun Greeter(greeterRCP: GreeterRCP) { - val scope = rememberCoroutineScope() + val scope = rememberCoroutineScope() - val nameState = remember { mutableStateOf(TextFieldValue()) } + val nameState = remember { mutableStateOf(TextFieldValue()) } - Column(Modifier.fillMaxWidth().fillMaxHeight(), Arrangement.Top, Alignment.CenterHorizontally) { - Text(stringResource(R.string.name_hint), modifier = Modifier.padding(top = 10.dp)) - OutlinedTextField(nameState.value, { nameState.value = it }) + Column(Modifier.fillMaxWidth().fillMaxHeight(), Arrangement.Top, Alignment.CenterHorizontally) { + Text(stringResource(R.string.name_hint), modifier = Modifier.padding(top = 10.dp)) + OutlinedTextField(nameState.value, { nameState.value = it }) - Button({ scope.launch { greeterRCP.sayHello(nameState.value.text) } }, Modifier.padding(10.dp)) { - Text(stringResource(R.string.send_request)) + Button( + { scope.launch { greeterRCP.sayHello(nameState.value.text) } }, + Modifier.padding(10.dp) + ) { + Text(stringResource(R.string.send_request)) } - if (greeterRCP.responseState.value.isNotEmpty()) { - Text(stringResource(R.string.server_response), modifier = Modifier.padding(top = 10.dp)) - Text(greeterRCP.responseState.value) - } + if (greeterRCP.responseState.value.isNotEmpty()) { + Text(stringResource(R.string.server_response), modifier = Modifier.padding(top = 10.dp)) + Text(greeterRCP.responseState.value) } + } } diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 464ef492..8c858484 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -1,10 +1,9 @@ plugins { - alias(libs.plugins.android.application) apply false - alias(libs.plugins.android.library) apply false - alias(libs.plugins.protobuf) apply false - alias(libs.plugins.kotlin.jvm) apply false - alias(libs.plugins.kotlin.android) apply false + alias(libs.plugins.android.application) apply false + alias(libs.plugins.android.library) apply false + alias(libs.plugins.protobuf) apply false + alias(libs.plugins.kotlin.jvm) apply false + alias(libs.plugins.kotlin.android) apply false } tasks.create("assemble").dependsOn(":server:installDist") - diff --git a/examples/client/build.gradle.kts b/examples/client/build.gradle.kts index 2be41339..a982ad5a 100644 --- a/examples/client/build.gradle.kts +++ b/examples/client/build.gradle.kts @@ -1,61 +1,59 @@ plugins { - application - alias(libs.plugins.kotlin.jvm) + application + alias(libs.plugins.kotlin.jvm) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } dependencies { - implementation(project(":stub")) - runtimeOnly(libs.grpc.netty) + implementation(project(":stub")) + runtimeOnly(libs.grpc.netty) } tasks.register("HelloWorldClient") { - dependsOn("classes") - classpath = sourceSets["main"].runtimeClasspath - mainClass.set("io.grpc.examples.helloworld.HelloWorldClientKt") + dependsOn("classes") + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("io.grpc.examples.helloworld.HelloWorldClientKt") } tasks.register("RouteGuideClient") { - dependsOn("classes") - classpath = sourceSets["main"].runtimeClasspath - mainClass.set("io.grpc.examples.routeguide.RouteGuideClientKt") + dependsOn("classes") + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("io.grpc.examples.routeguide.RouteGuideClientKt") } tasks.register("AnimalsClient") { - dependsOn("classes") - classpath = sourceSets["main"].runtimeClasspath - mainClass.set("io.grpc.examples.animals.AnimalsClientKt") + dependsOn("classes") + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("io.grpc.examples.animals.AnimalsClientKt") } val helloWorldClientStartScripts = - tasks.register("helloWorldClientStartScripts") { - mainClass.set("io.grpc.examples.helloworld.HelloWorldClientKt") - applicationName = "hello-world-client" - outputDir = tasks.named("startScripts").get().outputDir - classpath = tasks.named("startScripts").get().classpath - } + tasks.register("helloWorldClientStartScripts") { + mainClass.set("io.grpc.examples.helloworld.HelloWorldClientKt") + applicationName = "hello-world-client" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("startScripts").get().classpath + } val routeGuideClientStartScripts = - tasks.register("routeGuideClientStartScripts") { - mainClass.set("io.grpc.examples.routeguide.RouteGuideClientKt") - applicationName = "route-guide-client" - outputDir = tasks.named("startScripts").get().outputDir - classpath = tasks.named("startScripts").get().classpath - } + tasks.register("routeGuideClientStartScripts") { + mainClass.set("io.grpc.examples.routeguide.RouteGuideClientKt") + applicationName = "route-guide-client" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("startScripts").get().classpath + } val animalsClientStartScripts = - tasks.register("animalsClientStartScripts") { - mainClass.set("io.grpc.examples.animals.AnimalsClientKt") - applicationName = "animals-client" - outputDir = tasks.named("startScripts").get().outputDir - classpath = tasks.named("startScripts").get().classpath - } + tasks.register("animalsClientStartScripts") { + mainClass.set("io.grpc.examples.animals.AnimalsClientKt") + applicationName = "animals-client" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("startScripts").get().classpath + } tasks.named("startScripts") { - dependsOn(helloWorldClientStartScripts) - dependsOn(routeGuideClientStartScripts) - dependsOn(animalsClientStartScripts) + dependsOn(helloWorldClientStartScripts) + dependsOn(routeGuideClientStartScripts) + dependsOn(animalsClientStartScripts) } diff --git a/examples/client/src/main/kotlin/io/grpc/examples/animals/AnimalsClient.kt b/examples/client/src/main/kotlin/io/grpc/examples/animals/AnimalsClient.kt index 0855f8e6..99e8a2e4 100644 --- a/examples/client/src/main/kotlin/io/grpc/examples/animals/AnimalsClient.kt +++ b/examples/client/src/main/kotlin/io/grpc/examples/animals/AnimalsClient.kt @@ -22,59 +22,59 @@ import java.io.Closeable import java.util.concurrent.TimeUnit class AnimalsClient(private val channel: ManagedChannel) : Closeable { - private val dogStub: DogGrpcKt.DogCoroutineStub by lazy { DogGrpcKt.DogCoroutineStub(channel) } - private val pigStub: PigGrpcKt.PigCoroutineStub by lazy { PigGrpcKt.PigCoroutineStub(channel) } - private val sheepStub: SheepGrpcKt.SheepCoroutineStub by lazy { SheepGrpcKt.SheepCoroutineStub(channel) } + private val dogStub: DogGrpcKt.DogCoroutineStub by lazy { DogGrpcKt.DogCoroutineStub(channel) } + private val pigStub: PigGrpcKt.PigCoroutineStub by lazy { PigGrpcKt.PigCoroutineStub(channel) } + private val sheepStub: SheepGrpcKt.SheepCoroutineStub by lazy { + SheepGrpcKt.SheepCoroutineStub(channel) + } - suspend fun bark() { - val request = barkRequest {} - val response = dogStub.bark(request) - println("Received: ${response.message}") - } + suspend fun bark() { + val request = barkRequest {} + val response = dogStub.bark(request) + println("Received: ${response.message}") + } - suspend fun oink() { - val request = oinkRequest {} - val response = pigStub.oink(request) - println("Received: ${response.message}") - } + suspend fun oink() { + val request = oinkRequest {} + val response = pigStub.oink(request) + println("Received: ${response.message}") + } - suspend fun baa() { - val request = baaRequest {} - val response = sheepStub.baa(request) - println("Received: ${response.message}") - } + suspend fun baa() { + val request = baaRequest {} + val response = sheepStub.baa(request) + println("Received: ${response.message}") + } - override fun close() { - channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) - } + override fun close() { + channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) + } } -/** - * Talk to the animals. Fluent in dog, pig and sheep. - */ +/** Talk to the animals. Fluent in dog, pig and sheep. */ suspend fun main(args: Array) { - val usage = "usage: animals_client [{dog|pig|sheep} ...]" + val usage = "usage: animals_client [{dog|pig|sheep} ...]" - if (args.isEmpty()) { - println("No animals specified.") - println(usage) - } + if (args.isEmpty()) { + println("No animals specified.") + println(usage) + } - val port = 50051 + val port = 50051 - val channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build() + val channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build() - val client = AnimalsClient(channel) + val client = AnimalsClient(channel) - args.forEach { - when (it) { - "dog" -> client.bark() - "pig" -> client.oink() - "sheep" -> client.baa() - else -> { - println("Unknown animal type: \"$it\". Try \"dog\", \"pig\" or \"sheep\".") - println(usage) - } - } + args.forEach { + when (it) { + "dog" -> client.bark() + "pig" -> client.oink() + "sheep" -> client.baa() + else -> { + println("Unknown animal type: \"$it\". Try \"dog\", \"pig\" or \"sheep\".") + println(usage) + } } + } } diff --git a/examples/client/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel b/examples/client/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel index 1cde758a..1d7bd218 100644 --- a/examples/client/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel +++ b/examples/client/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel @@ -11,7 +11,7 @@ kt_jvm_binary( deps = [ "//examples/protos/src/main/proto/io/grpc/examples/animals:animals_kt_grpc", "//examples/protos/src/main/proto/io/grpc/examples/animals:animals_kt_proto", - "@protobuf//:protobuf_java_util", "@grpc-java//netty", + "@protobuf//:protobuf_java_util", ], ) diff --git a/examples/client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt b/examples/client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt index 2242584c..3c6766d4 100644 --- a/examples/client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt +++ b/examples/client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt @@ -23,30 +23,27 @@ import java.io.Closeable import java.util.concurrent.TimeUnit class HelloWorldClient(private val channel: ManagedChannel) : Closeable { - private val stub: GreeterCoroutineStub = GreeterCoroutineStub(channel) + private val stub: GreeterCoroutineStub = GreeterCoroutineStub(channel) - suspend fun greet(name: String) { - val request = helloRequest { this.name = name } - val response = stub.sayHello(request) - println("Received: ${response.message}") - } + suspend fun greet(name: String) { + val request = helloRequest { this.name = name } + val response = stub.sayHello(request) + println("Received: ${response.message}") + } - override fun close() { - channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) - } + override fun close() { + channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) + } } -/** - * Greeter, uses first argument as name to greet if present; - * greets "world" otherwise. - */ +/** Greeter, uses first argument as name to greet if present; greets "world" otherwise. */ suspend fun main(args: Array) { - val port = System.getenv("PORT")?.toInt() ?: 50051 + val port = System.getenv("PORT")?.toInt() ?: 50051 - val channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build() + val channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build() - val client = HelloWorldClient(channel) + val client = HelloWorldClient(channel) - val user = args.singleOrNull() ?: "world" - client.greet(user) + val user = args.singleOrNull() ?: "world" + client.greet(user) } diff --git a/examples/client/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel b/examples/client/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel index 915bd1af..23fb86c9 100644 --- a/examples/client/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel +++ b/examples/client/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel @@ -13,7 +13,7 @@ kt_jvm_binary( "//examples/protos/src/main/proto/io/grpc/examples/routeguide:route_guide_kt_grpc", "//examples/protos/src/main/proto/io/grpc/examples/routeguide:route_guide_kt_proto", "//examples/stub/src/main/kotlin/io/grpc/examples/routeguide:route_guide_stub", - "@protobuf//:protobuf_java_util", "@grpc-java//netty", + "@protobuf//:protobuf_java_util", ], ) diff --git a/examples/client/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideClient.kt b/examples/client/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideClient.kt index 76f76f28..78740184 100644 --- a/examples/client/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideClient.kt +++ b/examples/client/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideClient.kt @@ -19,142 +19,136 @@ package io.grpc.examples.routeguide import io.grpc.ManagedChannel import io.grpc.ManagedChannelBuilder import io.grpc.examples.routeguide.RouteGuideGrpcKt.RouteGuideCoroutineStub -import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.collect -import kotlinx.coroutines.flow.flow import java.io.Closeable import java.util.concurrent.TimeUnit import kotlin.random.Random import kotlin.random.nextLong +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.flow class RouteGuideClient(private val channel: ManagedChannel) : Closeable { - private val random = Random(314159) - private val stub = RouteGuideCoroutineStub(channel) - - override fun close() { - channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) + private val random = Random(314159) + private val stub = RouteGuideCoroutineStub(channel) + + override fun close() { + channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) + } + + suspend fun getFeature( + latitude: Int, + longitude: Int, + ) { + println("*** GetFeature: lat=$latitude lon=$longitude") + + val request = point(latitude, longitude) + val feature = stub.getFeature(request) + + if (feature.exists()) { + println("Found feature called \"${feature.name}\" at ${feature.location.toStr()}") + } else { + println("Found no feature at ${request.toStr()}") } - - suspend fun getFeature( - latitude: Int, - longitude: Int, - ) { - println("*** GetFeature: lat=$latitude lon=$longitude") - - val request = point(latitude, longitude) - val feature = stub.getFeature(request) - - if (feature.exists()) { - println("Found feature called \"${feature.name}\" at ${feature.location.toStr()}") - } else { - println("Found no feature at ${request.toStr()}") - } + } + + suspend fun listFeatures( + lowLat: Int, + lowLon: Int, + hiLat: Int, + hiLon: Int, + ) { + println("*** ListFeatures: lowLat=$lowLat lowLon=$lowLon hiLat=$hiLat liLon=$hiLon") + + val request = rectangle { + lo = point(lowLat, lowLon) + hi = point(hiLat, hiLon) } - - suspend fun listFeatures( - lowLat: Int, - lowLon: Int, - hiLat: Int, - hiLon: Int, - ) { - println("*** ListFeatures: lowLat=$lowLat lowLon=$lowLon hiLat=$hiLat liLon=$hiLon") - - val request = - rectangle { - lo = point(lowLat, lowLon) - hi = point(hiLat, hiLon) - } - var i = 1 - stub.listFeatures(request).collect { feature -> - println("Result #${i++}: $feature") - } + var i = 1 + stub.listFeatures(request).collect { feature -> println("Result #${i++}: $feature") } + } + + suspend fun recordRoute(points: Flow) { + println("*** RecordRoute") + val summary = stub.recordRoute(points) + println("Finished trip with ${summary.pointCount} points.") + println("Passed ${summary.featureCount} features.") + println("Travelled ${summary.distance} meters.") + val duration = summary.elapsedTime.seconds + println("It took $duration seconds.") + } + + fun generateRoutePoints( + features: List, + numPoints: Int, + ): Flow = flow { + for (i in 1..numPoints) { + val feature = features.random(random) + println("Visiting point ${feature.location.toStr()}") + emit(feature.location) + delay(timeMillis = random.nextLong(500L..1500L)) } + } - suspend fun recordRoute(points: Flow) { - println("*** RecordRoute") - val summary = stub.recordRoute(points) - println("Finished trip with ${summary.pointCount} points.") - println("Passed ${summary.featureCount} features.") - println("Travelled ${summary.distance} meters.") - val duration = summary.elapsedTime.seconds - println("It took $duration seconds.") + suspend fun routeChat() { + println("*** RouteChat") + val requests = generateOutgoingNotes() + stub.routeChat(requests).collect { note -> + println("Got message \"${note.message}\" at ${note.location.toStr()}") } - - fun generateRoutePoints( - features: List, - numPoints: Int, - ): Flow = - flow { - for (i in 1..numPoints) { - val feature = features.random(random) - println("Visiting point ${feature.location.toStr()}") - emit(feature.location) - delay(timeMillis = random.nextLong(500L..1500L)) - } - } - - suspend fun routeChat() { - println("*** RouteChat") - val requests = generateOutgoingNotes() - stub.routeChat(requests).collect { note -> - println("Got message \"${note.message}\" at ${note.location.toStr()}") - } - println("Finished RouteChat") + println("Finished RouteChat") + } + + private fun generateOutgoingNotes(): Flow = flow { + val notes = + listOf( + routeNote { + message = "First message" + location = point(0, 0) + }, + routeNote { + message = "Second message" + location = point(0, 0) + }, + routeNote { + message = "Third message" + location = point(10000000, 0) + }, + routeNote { + message = "Fourth message" + location = point(10000000, 10000000) + }, + routeNote { + message = "Last message" + location = point(0, 0) + }, + ) + for (note in notes) { + println("Sending message \"${note.message}\" at ${note.location.toStr()}") + emit(note) + delay(500) } - - private fun generateOutgoingNotes(): Flow = - flow { - val notes = - listOf( - routeNote { - message = "First message" - location = point(0, 0) - }, - routeNote { - message = "Second message" - location = point(0, 0) - }, - routeNote { - message = "Third message" - location = point(10000000, 0) - }, - routeNote { - message = "Fourth message" - location = point(10000000, 10000000) - }, - routeNote { - message = "Last message" - location = point(0, 0) - }, - ) - for (note in notes) { - println("Sending message \"${note.message}\" at ${note.location.toStr()}") - emit(note) - delay(500) - } - } + } } suspend fun main() { - val features = Database.features() + val features = Database.features() - val channel = ManagedChannelBuilder.forAddress("localhost", 8980).usePlaintext().build() + val channel = ManagedChannelBuilder.forAddress("localhost", 8980).usePlaintext().build() - RouteGuideClient(channel).use { - it.getFeature(409146138, -746188906) - it.getFeature(0, 0) - it.listFeatures(400000000, -750000000, 420000000, -730000000) - it.recordRoute(it.generateRoutePoints(features, 10)) - it.routeChat() - } + RouteGuideClient(channel).use { + it.getFeature(409146138, -746188906) + it.getFeature(0, 0) + it.listFeatures(400000000, -750000000, 420000000, -730000000) + it.recordRoute(it.generateRoutePoints(features, 10)) + it.routeChat() + } } private fun point( - lat: Int, - lon: Int, -): Point = - point { - latitude = lat - longitude = lon - } + lat: Int, + lon: Int, +): Point = point { + latitude = lat + longitude = lon +} diff --git a/examples/native-client/build.gradle.kts b/examples/native-client/build.gradle.kts index 63756854..9af65568 100644 --- a/examples/native-client/build.gradle.kts +++ b/examples/native-client/build.gradle.kts @@ -1,34 +1,31 @@ plugins { - application - alias(libs.plugins.kotlin.jvm) - alias(libs.plugins.palantir.graal) + application + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.palantir.graal) } -kotlin { - jvmToolchain(11) -} +kotlin { jvmToolchain(11) } dependencies { - implementation(project(":stub-lite")) + implementation(project(":stub-lite")) - runtimeOnly(libs.grpc.okhttp) + runtimeOnly(libs.grpc.okhttp) } -application { - mainClass.set("io.grpc.examples.helloworld.HelloWorldClientKt") -} +application { mainClass.set("io.grpc.examples.helloworld.HelloWorldClientKt") } // todo: add graalvm-config-create task // ./gradlew :native-client:install -// JAVA_HOME=~/.gradle/caches/com.palantir.graal/22.3.3/11/graalvm-ce-java11-22.3.3 JAVA_OPTS=-agentlib:native-image-agent=config-output-dir=native-client/src/main/resources/META-INF/native-image native-client/build/install/native-client/bin/native-client +// JAVA_HOME=~/.gradle/caches/com.palantir.graal/22.3.3/11/graalvm-ce-java11-22.3.3 +// JAVA_OPTS=-agentlib:native-image-agent=config-output-dir=native-client/src/main/resources/META-INF/native-image native-client/build/install/native-client/bin/native-client graal { - graalVersion("22.3.3") - javaVersion("11") - mainClass(application.mainClass.get()) - outputName("hello-world") - option("--verbose") - option("--no-fallback") - option("-H:+ReportExceptionStackTraces") - option("-H:+PrintClassInitialization") + graalVersion("22.3.3") + javaVersion("11") + mainClass(application.mainClass.get()) + outputName("hello-world") + option("--verbose") + option("--no-fallback") + option("-H:+ReportExceptionStackTraces") + option("-H:+PrintClassInitialization") } diff --git a/examples/native-client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt b/examples/native-client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt index 2242584c..3c6766d4 100644 --- a/examples/native-client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt +++ b/examples/native-client/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldClient.kt @@ -23,30 +23,27 @@ import java.io.Closeable import java.util.concurrent.TimeUnit class HelloWorldClient(private val channel: ManagedChannel) : Closeable { - private val stub: GreeterCoroutineStub = GreeterCoroutineStub(channel) + private val stub: GreeterCoroutineStub = GreeterCoroutineStub(channel) - suspend fun greet(name: String) { - val request = helloRequest { this.name = name } - val response = stub.sayHello(request) - println("Received: ${response.message}") - } + suspend fun greet(name: String) { + val request = helloRequest { this.name = name } + val response = stub.sayHello(request) + println("Received: ${response.message}") + } - override fun close() { - channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) - } + override fun close() { + channel.shutdown().awaitTermination(5, TimeUnit.SECONDS) + } } -/** - * Greeter, uses first argument as name to greet if present; - * greets "world" otherwise. - */ +/** Greeter, uses first argument as name to greet if present; greets "world" otherwise. */ suspend fun main(args: Array) { - val port = System.getenv("PORT")?.toInt() ?: 50051 + val port = System.getenv("PORT")?.toInt() ?: 50051 - val channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build() + val channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build() - val client = HelloWorldClient(channel) + val client = HelloWorldClient(channel) - val user = args.singleOrNull() ?: "world" - client.greet(user) + val user = args.singleOrNull() ?: "world" + client.greet(user) } diff --git a/examples/protos/build.gradle.kts b/examples/protos/build.gradle.kts index a42e252c..ba3dc853 100644 --- a/examples/protos/build.gradle.kts +++ b/examples/protos/build.gradle.kts @@ -2,10 +2,6 @@ // Note: We use the java-library plugin to get the protos into the artifact for this subproject // because there doesn't seem to be an better way. -plugins { - `java-library` -} +plugins { `java-library` } -java { - sourceSets.getByName("main").resources.srcDir("src/main/proto") -} +java { sourceSets.getByName("main").resources.srcDir("src/main/proto") } diff --git a/examples/protos/src/main/proto/io/grpc/examples/animals/BUILD.bazel b/examples/protos/src/main/proto/io/grpc/examples/animals/BUILD.bazel index b31f3559..6ed44cf8 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/animals/BUILD.bazel +++ b/examples/protos/src/main/proto/io/grpc/examples/animals/BUILD.bazel @@ -1,7 +1,7 @@ -load("@protobuf//bazel:proto_library.bzl", "proto_library") -load("//:kt_jvm_grpc.bzl", "kt_jvm_grpc_library", "kt_jvm_proto_library") load("@protobuf//bazel:java_lite_proto_library.bzl", "java_lite_proto_library") load("@protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@protobuf//bazel:proto_library.bzl", "proto_library") +load("//:kt_jvm_grpc.bzl", "kt_jvm_grpc_library", "kt_jvm_proto_library") licenses(["notice"]) diff --git a/examples/protos/src/main/proto/io/grpc/examples/animals/dog.proto b/examples/protos/src/main/proto/io/grpc/examples/animals/dog.proto index 70999a27..a857a6a9 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/animals/dog.proto +++ b/examples/protos/src/main/proto/io/grpc/examples/animals/dog.proto @@ -13,18 +13,17 @@ // limitations under the License. syntax = "proto3"; +package animals; + option java_multiple_files = true; -option java_package = "io.grpc.examples.animals"; option java_outer_classname = "DogProto"; - -package animals; +option java_package = "io.grpc.examples.animals"; service Dog { - rpc Bark (BarkRequest) returns (BarkReply) {} + rpc Bark(BarkRequest) returns (BarkReply) {} } -message BarkRequest { -} +message BarkRequest {} message BarkReply { string message = 1; diff --git a/examples/protos/src/main/proto/io/grpc/examples/animals/pig.proto b/examples/protos/src/main/proto/io/grpc/examples/animals/pig.proto index 2ed4027d..62d984af 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/animals/pig.proto +++ b/examples/protos/src/main/proto/io/grpc/examples/animals/pig.proto @@ -13,18 +13,17 @@ // limitations under the License. syntax = "proto3"; +package animals; + option java_multiple_files = true; -option java_package = "io.grpc.examples.animals"; option java_outer_classname = "PigProto"; - -package animals; +option java_package = "io.grpc.examples.animals"; service Pig { - rpc Oink (OinkRequest) returns (OinkReply) {} + rpc Oink(OinkRequest) returns (OinkReply) {} } -message OinkRequest { -} +message OinkRequest {} message OinkReply { string message = 1; diff --git a/examples/protos/src/main/proto/io/grpc/examples/animals/sheep.proto b/examples/protos/src/main/proto/io/grpc/examples/animals/sheep.proto index 4e649fdd..5c2c047f 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/animals/sheep.proto +++ b/examples/protos/src/main/proto/io/grpc/examples/animals/sheep.proto @@ -13,18 +13,17 @@ // limitations under the License. syntax = "proto3"; +package animals; + option java_multiple_files = true; -option java_package = "io.grpc.examples.animals"; option java_outer_classname = "SheepProto"; - -package animals; +option java_package = "io.grpc.examples.animals"; service Sheep { - rpc Baa (BaaRequest) returns (BaaReply) {} + rpc Baa(BaaRequest) returns (BaaReply) {} } -message BaaRequest { -} +message BaaRequest {} message BaaReply { string message = 1; diff --git a/examples/protos/src/main/proto/io/grpc/examples/helloworld/BUILD.bazel b/examples/protos/src/main/proto/io/grpc/examples/helloworld/BUILD.bazel index 769ac774..831b7710 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/helloworld/BUILD.bazel +++ b/examples/protos/src/main/proto/io/grpc/examples/helloworld/BUILD.bazel @@ -1,6 +1,6 @@ +load("@protobuf//bazel:java_lite_proto_library.bzl", "java_lite_proto_library") load("@protobuf//bazel:proto_library.bzl", "proto_library") load("//:kt_jvm_grpc.bzl", "kt_jvm_grpc_library", "kt_jvm_proto_library") -load("@protobuf//bazel:java_lite_proto_library.bzl", "java_lite_proto_library") licenses(["notice"]) diff --git a/examples/protos/src/main/proto/io/grpc/examples/helloworld/hello_world.proto b/examples/protos/src/main/proto/io/grpc/examples/helloworld/hello_world.proto index 40316ff7..ad0115c8 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/helloworld/hello_world.proto +++ b/examples/protos/src/main/proto/io/grpc/examples/helloworld/hello_world.proto @@ -13,16 +13,16 @@ // limitations under the License. syntax = "proto3"; +package helloworld; + option java_multiple_files = true; -option java_package = "io.grpc.examples.helloworld"; option java_outer_classname = "HelloWorldProto"; - -package helloworld; +option java_package = "io.grpc.examples.helloworld"; // The greeting service definition. service Greeter { // Sends a greeting - rpc SayHello (HelloRequest) returns (HelloReply) {} + rpc SayHello(HelloRequest) returns (HelloReply) {} } // The request message containing the user's name. diff --git a/examples/protos/src/main/proto/io/grpc/examples/routeguide/BUILD.bazel b/examples/protos/src/main/proto/io/grpc/examples/routeguide/BUILD.bazel index 469ed872..7a71aa9e 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/routeguide/BUILD.bazel +++ b/examples/protos/src/main/proto/io/grpc/examples/routeguide/BUILD.bazel @@ -1,7 +1,7 @@ -load("@protobuf//bazel:proto_library.bzl", "proto_library") -load("//:kt_jvm_grpc.bzl", "kt_jvm_grpc_library", "kt_jvm_proto_library") load("@protobuf//bazel:java_lite_proto_library.bzl", "java_lite_proto_library") load("@protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@protobuf//bazel:proto_library.bzl", "proto_library") +load("//:kt_jvm_grpc.bzl", "kt_jvm_grpc_library", "kt_jvm_proto_library") licenses(["notice"]) diff --git a/examples/protos/src/main/proto/io/grpc/examples/routeguide/route_guide.proto b/examples/protos/src/main/proto/io/grpc/examples/routeguide/route_guide.proto index 3d69733c..cd039658 100644 --- a/examples/protos/src/main/proto/io/grpc/examples/routeguide/route_guide.proto +++ b/examples/protos/src/main/proto/io/grpc/examples/routeguide/route_guide.proto @@ -112,4 +112,4 @@ message RouteSummary { // The duration of the traversal. google.protobuf.Duration elapsed_time = 4; -} \ No newline at end of file +} diff --git a/examples/server/build.gradle.kts b/examples/server/build.gradle.kts index 31ea7e07..f96b7dd7 100644 --- a/examples/server/build.gradle.kts +++ b/examples/server/build.gradle.kts @@ -1,87 +1,81 @@ plugins { - application - alias(libs.plugins.kotlin.jvm) - alias(libs.plugins.jib) + application + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.jib) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } dependencies { - implementation(project(":stub")) + implementation(project(":stub")) - runtimeOnly(libs.grpc.netty) + runtimeOnly(libs.grpc.netty) - testImplementation(libs.kotlin.test.junit) - testImplementation(libs.grpc.testing) + testImplementation(libs.kotlin.test.junit) + testImplementation(libs.grpc.testing) } tasks.register("HelloWorldServer") { - dependsOn("classes") - classpath = sourceSets["main"].runtimeClasspath - mainClass.set("io.grpc.examples.helloworld.HelloWorldServerKt") + dependsOn("classes") + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("io.grpc.examples.helloworld.HelloWorldServerKt") } tasks.register("RouteGuideServer") { - dependsOn("classes") - classpath = sourceSets["main"].runtimeClasspath - mainClass.set("io.grpc.examples.routeguide.RouteGuideServerKt") + dependsOn("classes") + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("io.grpc.examples.routeguide.RouteGuideServerKt") } tasks.register("AnimalsServer") { - dependsOn("classes") - classpath = sourceSets["main"].runtimeClasspath - mainClass.set("io.grpc.examples.animals.AnimalsServerKt") + dependsOn("classes") + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("io.grpc.examples.animals.AnimalsServerKt") } val helloWorldServerStartScripts = - tasks.register("helloWorldServerStartScripts") { - mainClass.set("io.grpc.examples.helloworld.HelloWorldServerKt") - applicationName = "hello-world-server" - outputDir = tasks.named("startScripts").get().outputDir - classpath = tasks.named("startScripts").get().classpath - } + tasks.register("helloWorldServerStartScripts") { + mainClass.set("io.grpc.examples.helloworld.HelloWorldServerKt") + applicationName = "hello-world-server" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("startScripts").get().classpath + } val routeGuideServerStartScripts = - tasks.register("routeGuideServerStartScripts") { - mainClass.set("io.grpc.examples.routeguide.RouteGuideServerKt") - applicationName = "route-guide-server" - outputDir = tasks.named("startScripts").get().outputDir - classpath = tasks.named("startScripts").get().classpath - } + tasks.register("routeGuideServerStartScripts") { + mainClass.set("io.grpc.examples.routeguide.RouteGuideServerKt") + applicationName = "route-guide-server" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("startScripts").get().classpath + } val animalsServerStartScripts = - tasks.register("animalsServerStartScripts") { - mainClass.set("io.grpc.examples.animals.AnimalsServerKt") - applicationName = "animals-server" - outputDir = tasks.named("startScripts").get().outputDir - classpath = tasks.named("startScripts").get().classpath - } + tasks.register("animalsServerStartScripts") { + mainClass.set("io.grpc.examples.animals.AnimalsServerKt") + applicationName = "animals-server" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("startScripts").get().classpath + } tasks.named("startScripts") { - dependsOn(helloWorldServerStartScripts) - dependsOn(routeGuideServerStartScripts) - dependsOn(animalsServerStartScripts) + dependsOn(helloWorldServerStartScripts) + dependsOn(routeGuideServerStartScripts) + dependsOn(animalsServerStartScripts) } tasks.withType { - useJUnit() + useJUnit() - testLogging { - events = - setOf( - org.gradle.api.tasks.testing.logging.TestLogEvent.PASSED, - org.gradle.api.tasks.testing.logging.TestLogEvent.SKIPPED, - org.gradle.api.tasks.testing.logging.TestLogEvent.FAILED, - ) - exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL - showStandardStreams = true - } + testLogging { + events = + setOf( + org.gradle.api.tasks.testing.logging.TestLogEvent.PASSED, + org.gradle.api.tasks.testing.logging.TestLogEvent.SKIPPED, + org.gradle.api.tasks.testing.logging.TestLogEvent.FAILED, + ) + exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL + showStandardStreams = true + } } -jib { - container { - mainClass = "io.grpc.examples.helloworld.HelloWorldServerKt" - } -} +jib { container { mainClass = "io.grpc.examples.helloworld.HelloWorldServerKt" } } diff --git a/examples/server/src/main/kotlin/io/grpc/examples/animals/AnimalsServer.kt b/examples/server/src/main/kotlin/io/grpc/examples/animals/AnimalsServer.kt index a3e690b9..17260711 100644 --- a/examples/server/src/main/kotlin/io/grpc/examples/animals/AnimalsServer.kt +++ b/examples/server/src/main/kotlin/io/grpc/examples/animals/AnimalsServer.kt @@ -20,59 +20,50 @@ import io.grpc.Server import io.grpc.ServerBuilder class AnimalsServer constructor(private val port: Int) { - val server: Server = - ServerBuilder - .forPort(port) - .addService(DogService()) - .addService(PigService()) - .addService(SheepService()) - .build() + val server: Server = + ServerBuilder.forPort(port) + .addService(DogService()) + .addService(PigService()) + .addService(SheepService()) + .build() - fun start() { - server.start() - println("Server started, listening on $port") - Runtime.getRuntime().addShutdownHook( - Thread { - println("*** shutting down gRPC server since JVM is shutting down") - this@AnimalsServer.stop() - println("*** server shut down") - }, - ) - } + fun start() { + server.start() + println("Server started, listening on $port") + Runtime.getRuntime() + .addShutdownHook( + Thread { + println("*** shutting down gRPC server since JVM is shutting down") + this@AnimalsServer.stop() + println("*** server shut down") + }, + ) + } - private fun stop() { - server.shutdown() - } + private fun stop() { + server.shutdown() + } - fun blockUntilShutdown() { - server.awaitTermination() - } + fun blockUntilShutdown() { + server.awaitTermination() + } - internal class DogService : DogGrpcKt.DogCoroutineImplBase() { - override suspend fun bark(request: BarkRequest) = - barkReply { - message = "Bark!" - } - } + internal class DogService : DogGrpcKt.DogCoroutineImplBase() { + override suspend fun bark(request: BarkRequest) = barkReply { message = "Bark!" } + } - internal class PigService : PigGrpcKt.PigCoroutineImplBase() { - override suspend fun oink(request: OinkRequest) = - oinkReply { - message = "Oink!" - } - } + internal class PigService : PigGrpcKt.PigCoroutineImplBase() { + override suspend fun oink(request: OinkRequest) = oinkReply { message = "Oink!" } + } - internal class SheepService : SheepGrpcKt.SheepCoroutineImplBase() { - override suspend fun baa(request: BaaRequest) = - baaReply { - message = "Baa!" - } - } + internal class SheepService : SheepGrpcKt.SheepCoroutineImplBase() { + override suspend fun baa(request: BaaRequest) = baaReply { message = "Baa!" } + } } fun main() { - val port = 50051 - val server = AnimalsServer(port) - server.start() - server.blockUntilShutdown() + val port = 50051 + val server = AnimalsServer(port) + server.start() + server.blockUntilShutdown() } diff --git a/examples/server/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel b/examples/server/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel index 1f28ee0b..6ffc492f 100644 --- a/examples/server/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel +++ b/examples/server/src/main/kotlin/io/grpc/examples/animals/BUILD.bazel @@ -11,7 +11,7 @@ kt_jvm_binary( deps = [ "//examples/protos/src/main/proto/io/grpc/examples/animals:animals_kt_grpc", "//examples/protos/src/main/proto/io/grpc/examples/animals:animals_kt_proto", - "@protobuf//:protobuf_java_util", "@grpc-java//netty", + "@protobuf//:protobuf_java_util", ], ) diff --git a/examples/server/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldServer.kt b/examples/server/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldServer.kt index 96d33472..d75356a5 100644 --- a/examples/server/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldServer.kt +++ b/examples/server/src/main/kotlin/io/grpc/examples/helloworld/HelloWorldServer.kt @@ -20,43 +20,39 @@ import io.grpc.Server import io.grpc.ServerBuilder class HelloWorldServer(private val port: Int) { - val server: Server = - ServerBuilder - .forPort(port) - .addService(HelloWorldService()) - .build() - - fun start() { - server.start() - println("Server started, listening on $port") - Runtime.getRuntime().addShutdownHook( - Thread { - println("*** shutting down gRPC server since JVM is shutting down") - this@HelloWorldServer.stop() - println("*** server shut down") - }, - ) - } - - private fun stop() { - server.shutdown() - } + val server: Server = ServerBuilder.forPort(port).addService(HelloWorldService()).build() - fun blockUntilShutdown() { - server.awaitTermination() - } - - internal class HelloWorldService : GreeterGrpcKt.GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest) = - helloReply { - message = "Hello ${request.name}" - } + fun start() { + server.start() + println("Server started, listening on $port") + Runtime.getRuntime() + .addShutdownHook( + Thread { + println("*** shutting down gRPC server since JVM is shutting down") + this@HelloWorldServer.stop() + println("*** server shut down") + }, + ) + } + + private fun stop() { + server.shutdown() + } + + fun blockUntilShutdown() { + server.awaitTermination() + } + + internal class HelloWorldService : GreeterGrpcKt.GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest) = helloReply { + message = "Hello ${request.name}" } + } } fun main() { - val port = System.getenv("PORT")?.toInt() ?: 50051 - val server = HelloWorldServer(port) - server.start() - server.blockUntilShutdown() + val port = System.getenv("PORT")?.toInt() ?: 50051 + val server = HelloWorldServer(port) + server.start() + server.blockUntilShutdown() } diff --git a/examples/server/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel b/examples/server/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel index ee2e683c..e333eb04 100644 --- a/examples/server/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel +++ b/examples/server/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel @@ -13,7 +13,7 @@ kt_jvm_binary( "//examples/protos/src/main/proto/io/grpc/examples/routeguide:route_guide_kt_grpc", "//examples/protos/src/main/proto/io/grpc/examples/routeguide:route_guide_kt_proto", "//examples/stub/src/main/kotlin/io/grpc/examples/routeguide:route_guide_stub", - "@protobuf//:protobuf_java_util", "@grpc-java//netty", + "@protobuf//:protobuf_java_util", ], ) diff --git a/examples/server/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideServer.kt b/examples/server/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideServer.kt index 0eccc13a..2cfa7ed7 100644 --- a/examples/server/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideServer.kt +++ b/examples/server/src/main/kotlin/io/grpc/examples/routeguide/RouteGuideServer.kt @@ -21,99 +21,98 @@ import com.google.common.base.Ticker import com.google.protobuf.util.Durations import io.grpc.Server import io.grpc.ServerBuilder +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.TimeUnit import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.flow -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.TimeUnit -/** - * Kotlin adaptation of RouteGuideServer from the Java gRPC example. - */ +/** Kotlin adaptation of RouteGuideServer from the Java gRPC example. */ class RouteGuideServer( - val port: Int, - val features: Collection = Database.features(), - val server: Server = ServerBuilder.forPort(port).addService(RouteGuideService(features)).build(), + val port: Int, + val features: Collection = Database.features(), + val server: Server = ServerBuilder.forPort(port).addService(RouteGuideService(features)).build(), ) { - fun start() { - server.start() - println("Server started, listening on $port") - Runtime.getRuntime().addShutdownHook( - Thread { - println("*** shutting down gRPC server since JVM is shutting down") - this@RouteGuideServer.stop() - println("*** server shut down") - }, - ) - } + fun start() { + server.start() + println("Server started, listening on $port") + Runtime.getRuntime() + .addShutdownHook( + Thread { + println("*** shutting down gRPC server since JVM is shutting down") + this@RouteGuideServer.stop() + println("*** server shut down") + }, + ) + } - fun stop() { - server.shutdown() - } + fun stop() { + server.shutdown() + } - fun blockUntilShutdown() { - server.awaitTermination() - } + fun blockUntilShutdown() { + server.awaitTermination() + } - internal class RouteGuideService( - private val features: Collection, - private val ticker: Ticker = Ticker.systemTicker(), - ) : RouteGuideGrpcKt.RouteGuideCoroutineImplBase() { - private val routeNotes = ConcurrentHashMap>() + internal class RouteGuideService( + private val features: Collection, + private val ticker: Ticker = Ticker.systemTicker(), + ) : RouteGuideGrpcKt.RouteGuideCoroutineImplBase() { + private val routeNotes = ConcurrentHashMap>() - override suspend fun getFeature(request: Point): Feature = - // No feature was found, return an unnamed feature. - features.find { it.location == request } ?: feature { location = request } + override suspend fun getFeature(request: Point): Feature = + // No feature was found, return an unnamed feature. + features.find { it.location == request } ?: feature { location = request } - override fun listFeatures(request: Rectangle): Flow = features.asFlow().filter { it.exists() && it.location in request } + override fun listFeatures(request: Rectangle): Flow = + features.asFlow().filter { it.exists() && it.location in request } - override suspend fun recordRoute(requests: Flow): RouteSummary { - var pointCount = 0 - var featureCount = 0 - var distance = 0 - var previous: Point? = null - val stopwatch = Stopwatch.createStarted(ticker) - requests.collect { request -> - pointCount++ - if (getFeature(request).exists()) { - featureCount++ - } - val prev = previous - if (prev != null) { - distance += prev distanceTo request - } - previous = request - } - return routeSummary { - this.pointCount = pointCount - this.featureCount = featureCount - this.distance = distance - this.elapsedTime = Durations.fromMicros(stopwatch.elapsed(TimeUnit.MICROSECONDS)) - } + override suspend fun recordRoute(requests: Flow): RouteSummary { + var pointCount = 0 + var featureCount = 0 + var distance = 0 + var previous: Point? = null + val stopwatch = Stopwatch.createStarted(ticker) + requests.collect { request -> + pointCount++ + if (getFeature(request).exists()) { + featureCount++ + } + val prev = previous + if (prev != null) { + distance += prev distanceTo request } + previous = request + } + return routeSummary { + this.pointCount = pointCount + this.featureCount = featureCount + this.distance = distance + this.elapsedTime = Durations.fromMicros(stopwatch.elapsed(TimeUnit.MICROSECONDS)) + } + } - override fun routeChat(requests: Flow): Flow = - flow { - requests.collect { note -> - val notes: MutableList = - routeNotes.computeIfAbsent(note.location) { - Collections.synchronizedList(mutableListOf()) - } - for (prevNote in notes.toTypedArray()) { // thread-safe snapshot - emit(prevNote) - } - notes += note - } - } + override fun routeChat(requests: Flow): Flow = flow { + requests.collect { note -> + val notes: MutableList = + routeNotes.computeIfAbsent(note.location) { + Collections.synchronizedList(mutableListOf()) + } + for (prevNote in notes.toTypedArray()) { // thread-safe snapshot + emit(prevNote) + } + notes += note + } } + } } fun main() { - val port = 8980 - val server = RouteGuideServer(port) - server.start() - server.blockUntilShutdown() + val port = 8980 + val server = RouteGuideServer(port) + server.start() + server.blockUntilShutdown() } diff --git a/examples/server/src/test/kotlin/io/grpc/examples/animals/AnimalsServerTest.kt b/examples/server/src/test/kotlin/io/grpc/examples/animals/AnimalsServerTest.kt index 73cfe71c..7434bde1 100644 --- a/examples/server/src/test/kotlin/io/grpc/examples/animals/AnimalsServerTest.kt +++ b/examples/server/src/test/kotlin/io/grpc/examples/animals/AnimalsServerTest.kt @@ -17,35 +17,33 @@ package io.grpc.examples.animals import io.grpc.testing.GrpcServerRule -import kotlinx.coroutines.runBlocking -import org.junit.Rule import kotlin.test.Test import kotlin.test.assertEquals +import kotlinx.coroutines.runBlocking +import org.junit.Rule class AnimalsServerTest { - @get:Rule - val grpcServerRule: GrpcServerRule = GrpcServerRule().directExecutor() + @get:Rule val grpcServerRule: GrpcServerRule = GrpcServerRule().directExecutor() - @Test - fun animals() = - runBlocking { - val dogService = AnimalsServer.DogService() - val pigService = AnimalsServer.PigService() - val sheepService = AnimalsServer.SheepService() - grpcServerRule.serviceRegistry.addService(dogService) - grpcServerRule.serviceRegistry.addService(pigService) - grpcServerRule.serviceRegistry.addService(sheepService) + @Test + fun animals() = runBlocking { + val dogService = AnimalsServer.DogService() + val pigService = AnimalsServer.PigService() + val sheepService = AnimalsServer.SheepService() + grpcServerRule.serviceRegistry.addService(dogService) + grpcServerRule.serviceRegistry.addService(pigService) + grpcServerRule.serviceRegistry.addService(sheepService) - val dogStub = DogGrpcKt.DogCoroutineStub(grpcServerRule.channel) - val dogBark = dogStub.bark(barkRequest { }) - assertEquals("Bark!", dogBark.message) + val dogStub = DogGrpcKt.DogCoroutineStub(grpcServerRule.channel) + val dogBark = dogStub.bark(barkRequest {}) + assertEquals("Bark!", dogBark.message) - val pigStub = PigGrpcKt.PigCoroutineStub(grpcServerRule.channel) - val pigOink = pigStub.oink(oinkRequest { }) - assertEquals("Oink!", pigOink.message) + val pigStub = PigGrpcKt.PigCoroutineStub(grpcServerRule.channel) + val pigOink = pigStub.oink(oinkRequest {}) + assertEquals("Oink!", pigOink.message) - val sheepStub = SheepGrpcKt.SheepCoroutineStub(grpcServerRule.channel) - val sheepBaa = sheepStub.baa(baaRequest { }) - assertEquals("Baa!", sheepBaa.message) - } + val sheepStub = SheepGrpcKt.SheepCoroutineStub(grpcServerRule.channel) + val sheepBaa = sheepStub.baa(baaRequest {}) + assertEquals("Baa!", sheepBaa.message) + } } diff --git a/examples/server/src/test/kotlin/io/grpc/examples/helloworld/HelloWorldServerTest.kt b/examples/server/src/test/kotlin/io/grpc/examples/helloworld/HelloWorldServerTest.kt index 3e3768cf..b96f795d 100644 --- a/examples/server/src/test/kotlin/io/grpc/examples/helloworld/HelloWorldServerTest.kt +++ b/examples/server/src/test/kotlin/io/grpc/examples/helloworld/HelloWorldServerTest.kt @@ -17,26 +17,24 @@ package io.grpc.examples.helloworld import io.grpc.testing.GrpcServerRule -import kotlinx.coroutines.runBlocking -import org.junit.Rule import kotlin.test.Test import kotlin.test.assertEquals +import kotlinx.coroutines.runBlocking +import org.junit.Rule class HelloWorldServerTest { - @get:Rule - val grpcServerRule: GrpcServerRule = GrpcServerRule().directExecutor() + @get:Rule val grpcServerRule: GrpcServerRule = GrpcServerRule().directExecutor() - @Test - fun sayHello() = - runBlocking { - val service = HelloWorldServer.HelloWorldService() - grpcServerRule.serviceRegistry.addService(service) + @Test + fun sayHello() = runBlocking { + val service = HelloWorldServer.HelloWorldService() + grpcServerRule.serviceRegistry.addService(service) - val stub = GreeterGrpcKt.GreeterCoroutineStub(grpcServerRule.channel) - val testName = "test name" + val stub = GreeterGrpcKt.GreeterCoroutineStub(grpcServerRule.channel) + val testName = "test name" - val reply = stub.sayHello(helloRequest { name = testName }) + val reply = stub.sayHello(helloRequest { name = testName }) - assertEquals("Hello $testName", reply.message) - } + assertEquals("Hello $testName", reply.message) + } } diff --git a/examples/server/src/test/kotlin/io/grpc/examples/routeguide/RouteGuideServerTest.kt b/examples/server/src/test/kotlin/io/grpc/examples/routeguide/RouteGuideServerTest.kt index bd78ecad..b0247026 100644 --- a/examples/server/src/test/kotlin/io/grpc/examples/routeguide/RouteGuideServerTest.kt +++ b/examples/server/src/test/kotlin/io/grpc/examples/routeguide/RouteGuideServerTest.kt @@ -17,39 +17,34 @@ package io.grpc.examples.routeguide import io.grpc.testing.GrpcServerRule +import kotlin.test.Test +import kotlin.test.assertEquals import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Rule -import kotlin.test.Test -import kotlin.test.assertEquals class RouteGuideServerTest { - @get:Rule - val grpcServerRule: GrpcServerRule = GrpcServerRule().directExecutor() + @get:Rule val grpcServerRule: GrpcServerRule = GrpcServerRule().directExecutor() - @Test - fun listFeatures() = - runBlocking { - val service = RouteGuideServer.RouteGuideService(Database.features()) - grpcServerRule.serviceRegistry.addService(service) + @Test + fun listFeatures() = runBlocking { + val service = RouteGuideServer.RouteGuideService(Database.features()) + grpcServerRule.serviceRegistry.addService(service) - val stub = RouteGuideGrpcKt.RouteGuideCoroutineStub(grpcServerRule.channel) + val stub = RouteGuideGrpcKt.RouteGuideCoroutineStub(grpcServerRule.channel) - val rectangle = - rectangle { - lo = - point { - latitude = 407838351 - longitude = -746143763 - } - hi = - point { - latitude = 407838351 - longitude = -746143763 - } - } + val rectangle = rectangle { + lo = point { + latitude = 407838351 + longitude = -746143763 + } + hi = point { + latitude = 407838351 + longitude = -746143763 + } + } - val features = stub.listFeatures(rectangle).toList() - assertEquals("Patriots Path, Mendham, NJ 07945, USA", features.first().name) - } + val features = stub.listFeatures(rectangle).toList() + assertEquals("Patriots Path, Mendham, NJ 07945, USA", features.first().name) + } } diff --git a/examples/settings.gradle.kts b/examples/settings.gradle.kts index bf6f3ac6..26986eb3 100644 --- a/examples/settings.gradle.kts +++ b/examples/settings.gradle.kts @@ -2,26 +2,33 @@ rootProject.name = "grpc-kotlin-examples" // when running the assemble task, ignore the android & graalvm related subprojects if (startParameter.taskRequests.find { it.args.contains("assemble") } == null) { - include("protos", "stub", "stub-lite", "client", "native-client", "server", "stub-android", "android") + include( + "protos", + "stub", + "stub-lite", + "client", + "native-client", + "server", + "stub-android", + "android" + ) } else { - include("protos", "stub", "server") + include("protos", "stub", "server") } pluginManagement { - repositories { - gradlePluginPortal() - google() - } + repositories { + gradlePluginPortal() + google() + } } dependencyResolutionManagement { - @Suppress("UnstableApiUsage") - repositories { - mavenCentral() - google() - } + @Suppress("UnstableApiUsage") + repositories { + mavenCentral() + google() + } } -plugins { - id("org.gradle.toolchains.foojay-resolver-convention") version "0.8.0" -} +plugins { id("org.gradle.toolchains.foojay-resolver-convention") version "0.8.0" } diff --git a/examples/stub-android/build.gradle.kts b/examples/stub-android/build.gradle.kts index 96ef8667..c128fd21 100644 --- a/examples/stub-android/build.gradle.kts +++ b/examples/stub-android/build.gradle.kts @@ -1,69 +1,47 @@ plugins { - alias(libs.plugins.android.library) - alias(libs.plugins.kotlin.android) - alias(libs.plugins.protobuf) + alias(libs.plugins.android.library) + alias(libs.plugins.kotlin.android) + alias(libs.plugins.protobuf) } dependencies { - protobuf(project(":protos")) + protobuf(project(":protos")) - api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.coroutines.core) - api(libs.grpc.stub) - api(libs.grpc.protobuf.lite) - api(libs.grpc.kotlin.stub) - api(libs.protobuf.kotlin.lite) + api(libs.grpc.stub) + api(libs.grpc.protobuf.lite) + api(libs.grpc.kotlin.stub) + api(libs.protobuf.kotlin.lite) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } android { - compileSdk = 34 - buildToolsVersion = "34.0.0" - namespace = "io.grpc.examples.stublite" + compileSdk = 34 + buildToolsVersion = "34.0.0" + namespace = "io.grpc.examples.stublite" } tasks.withType().all { - kotlinOptions { - freeCompilerArgs = listOf("-opt-in=kotlin.RequiresOptIn") - } + kotlinOptions { freeCompilerArgs = listOf("-opt-in=kotlin.RequiresOptIn") } } protobuf { - protoc { - artifact = libs.protoc.asProvider().get().toString() - } - plugins { - create("java") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - create("grpc") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - create("grpckt") { - artifact = libs.protoc.gen.grpc.kotlin.get().toString() + ":jdk8@jar" - } - } - generateProtoTasks { - all().forEach { - it.plugins { - create("java") { - option("lite") - } - create("grpc") { - option("lite") - } - create("grpckt") { - option("lite") - } - } - it.builtins { - create("kotlin") { - option("lite") - } - } - } + protoc { artifact = libs.protoc.asProvider().get().toString() } + plugins { + create("java") { artifact = libs.protoc.gen.grpc.java.get().toString() } + create("grpc") { artifact = libs.protoc.gen.grpc.java.get().toString() } + create("grpckt") { artifact = libs.protoc.gen.grpc.kotlin.get().toString() + ":jdk8@jar" } + } + generateProtoTasks { + all().forEach { + it.plugins { + create("java") { option("lite") } + create("grpc") { option("lite") } + create("grpckt") { option("lite") } + } + it.builtins { create("kotlin") { option("lite") } } } + } } diff --git a/examples/stub-lite/build.gradle.kts b/examples/stub-lite/build.gradle.kts index 2b787dd9..61d562c4 100644 --- a/examples/stub-lite/build.gradle.kts +++ b/examples/stub-lite/build.gradle.kts @@ -1,59 +1,41 @@ plugins { - alias(libs.plugins.kotlin.jvm) - alias(libs.plugins.protobuf) + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.protobuf) } dependencies { - protobuf(project(":protos")) + protobuf(project(":protos")) - api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.coroutines.core) - api(libs.grpc.stub) - api(libs.grpc.protobuf.lite) - api(libs.grpc.kotlin.stub) - api(libs.protobuf.kotlin.lite) + api(libs.grpc.stub) + api(libs.grpc.protobuf.lite) + api(libs.grpc.kotlin.stub) + api(libs.protobuf.kotlin.lite) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } tasks.withType().all { - kotlinOptions { - freeCompilerArgs = listOf("-opt-in=kotlin.RequiresOptIn") - } + kotlinOptions { freeCompilerArgs = listOf("-opt-in=kotlin.RequiresOptIn") } } protobuf { - protoc { - artifact = libs.protoc.asProvider().get().toString() - } - plugins { - create("grpc") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - create("grpckt") { - artifact = libs.protoc.gen.grpc.kotlin.get().toString() + ":jdk8@jar" - } - } - generateProtoTasks { - all().forEach { - it.builtins { - named("java") { - option("lite") - } - create("kotlin") { - option("lite") - } - } - it.plugins { - create("grpc") { - option("lite") - } - create("grpckt") { - option("lite") - } - } - } + protoc { artifact = libs.protoc.asProvider().get().toString() } + plugins { + create("grpc") { artifact = libs.protoc.gen.grpc.java.get().toString() } + create("grpckt") { artifact = libs.protoc.gen.grpc.kotlin.get().toString() + ":jdk8@jar" } + } + generateProtoTasks { + all().forEach { + it.builtins { + named("java") { option("lite") } + create("kotlin") { option("lite") } + } + it.plugins { + create("grpc") { option("lite") } + create("grpckt") { option("lite") } + } } + } } diff --git a/examples/stub/build.gradle.kts b/examples/stub/build.gradle.kts index b9677c80..00e9b773 100644 --- a/examples/stub/build.gradle.kts +++ b/examples/stub/build.gradle.kts @@ -1,51 +1,39 @@ plugins { - alias(libs.plugins.kotlin.jvm) - alias(libs.plugins.protobuf) + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.protobuf) } dependencies { - protobuf(project(":protos")) + protobuf(project(":protos")) - api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.coroutines.core) - api(libs.grpc.stub) - api(libs.grpc.protobuf) - api(libs.protobuf.java.util) - api(libs.protobuf.kotlin) - api(libs.grpc.kotlin.stub) + api(libs.grpc.stub) + api(libs.grpc.protobuf) + api(libs.protobuf.java.util) + api(libs.protobuf.kotlin) + api(libs.grpc.kotlin.stub) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } tasks.withType().all { - kotlinOptions { - freeCompilerArgs = listOf("-opt-in=kotlin.RequiresOptIn") - } + kotlinOptions { freeCompilerArgs = listOf("-opt-in=kotlin.RequiresOptIn") } } protobuf { - protoc { - artifact = libs.protoc.asProvider().get().toString() - } - plugins { - create("grpc") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - create("grpckt") { - artifact = libs.protoc.gen.grpc.kotlin.get().toString() + ":jdk8@jar" - } - } - generateProtoTasks { - all().forEach { - it.plugins { - create("grpc") - create("grpckt") - } - it.builtins { - create("kotlin") - } - } + protoc { artifact = libs.protoc.asProvider().get().toString() } + plugins { + create("grpc") { artifact = libs.protoc.gen.grpc.java.get().toString() } + create("grpckt") { artifact = libs.protoc.gen.grpc.kotlin.get().toString() + ":jdk8@jar" } + } + generateProtoTasks { + all().forEach { + it.plugins { + create("grpc") + create("grpckt") + } + it.builtins { create("kotlin") } } + } } diff --git a/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel b/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel index 3c671f13..0bddc196 100644 --- a/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel +++ b/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/BUILD.bazel @@ -8,11 +8,11 @@ kt_jvm_library( name = "route_guide_stub", srcs = [ "Database.kt", - "Points.kt" + "Points.kt", ], deps = [ "//examples/protos/src/main/proto/io/grpc/examples/routeguide:route_guide_java_proto", "//examples/protos/src/main/proto/io/grpc/examples/routeguide:route_guide_kt_grpc", - "@protobuf//:protobuf_java_util" + "@protobuf//:protobuf_java_util", ], ) diff --git a/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Database.kt b/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Database.kt index 0e218129..fab302b7 100644 --- a/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Database.kt +++ b/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Database.kt @@ -19,11 +19,11 @@ package io.grpc.examples.routeguide import com.google.protobuf.util.JsonFormat object Database { - fun features(): List { - return javaClass.getResourceAsStream("route_guide_db.json")?.use { - val featureDatabaseBuilder = FeatureDatabase.newBuilder() - JsonFormat.parser().merge(it.reader(), featureDatabaseBuilder) - featureDatabaseBuilder.build().featureList - } ?: emptyList() - } + fun features(): List { + return javaClass.getResourceAsStream("route_guide_db.json")?.use { + val featureDatabaseBuilder = FeatureDatabase.newBuilder() + JsonFormat.parser().merge(it.reader(), featureDatabaseBuilder) + featureDatabaseBuilder.build().featureList + } ?: emptyList() + } } diff --git a/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Points.kt b/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Points.kt index eecfdffb..98b939fa 100644 --- a/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Points.kt +++ b/examples/stub/src/main/kotlin/io/grpc/examples/routeguide/Points.kt @@ -28,33 +28,33 @@ private const val EARTH_RADIUS_IN_M = 6371000 private fun Int.toRadians() = Math.toRadians(toDouble()) infix fun Point.distanceTo(other: Point): Int { - val lat1 = latitude.toRadians() - val long1 = longitude.toRadians() - val lat2 = other.latitude.toRadians() - val long2 = other.latitude.toRadians() + val lat1 = latitude.toRadians() + val long1 = longitude.toRadians() + val lat2 = other.latitude.toRadians() + val long2 = other.latitude.toRadians() - val dLat = lat2 - lat1 - val dLong = long2 - long1 + val dLat = lat2 - lat1 + val dLong = long2 - long1 - val a = sin(dLat / 2).pow(2) + cos(lat1) * cos(lat2) * sin(dLong / 2).pow(2) - val c = 2 * atan2(sqrt(a), sqrt(1 - a)) - return (EARTH_RADIUS_IN_M * c).roundToInt() + val a = sin(dLat / 2).pow(2) + cos(lat1) * cos(lat2) * sin(dLong / 2).pow(2) + val c = 2 * atan2(sqrt(a), sqrt(1 - a)) + return (EARTH_RADIUS_IN_M * c).roundToInt() } operator fun Rectangle.contains(p: Point): Boolean { - val lowLong = minOf(lo.longitude, hi.longitude) - val hiLong = maxOf(lo.longitude, hi.longitude) - val lowLat = minOf(lo.latitude, hi.latitude) - val hiLat = maxOf(lo.latitude, hi.latitude) - return p.longitude in lowLong..hiLong && p.latitude in lowLat..hiLat + val lowLong = minOf(lo.longitude, hi.longitude) + val hiLong = maxOf(lo.longitude, hi.longitude) + val lowLat = minOf(lo.latitude, hi.latitude) + val hiLat = maxOf(lo.latitude, hi.latitude) + return p.longitude in lowLong..hiLong && p.latitude in lowLat..hiLat } private fun Int.normalizeCoordinate(): Double = this / 1.0e7 fun Point.toStr(): String { - val lat = latitude.normalizeCoordinate() - val long = longitude.normalizeCoordinate() - return "$lat, $long" + val lat = latitude.normalizeCoordinate() + val long = longitude.normalizeCoordinate() + return "$lat, $long" } fun Feature.exists(): Boolean = name.isNotEmpty() diff --git a/integration_testing/build.gradle.kts b/integration_testing/build.gradle.kts index ce79c62f..6545a27f 100644 --- a/integration_testing/build.gradle.kts +++ b/integration_testing/build.gradle.kts @@ -1,59 +1,60 @@ import org.gradle.api.tasks.testing.logging.TestExceptionFormat import org.gradle.api.tasks.testing.logging.TestLogEvent -plugins { - alias(libs.plugins.kotlin.jvm) -} +plugins { alias(libs.plugins.kotlin.jvm) } -kotlin { - jvmToolchain(17) -} +kotlin { jvmToolchain(17) } dependencies { - testImplementation(libs.testcontainers) - testImplementation(libs.gradle.test.kit) - testImplementation(libs.gradle.tooling.api) - testImplementation(libs.commons.io) - testImplementation(libs.junit.jupiter) - testImplementation(libs.slf4j.simple) - testRuntimeOnly(libs.junit.platform.launcher) + testImplementation(libs.testcontainers) + testImplementation(libs.gradle.test.kit) + testImplementation(libs.gradle.tooling.api) + testImplementation(libs.commons.io) + testImplementation(libs.junit.jupiter) + testImplementation(libs.slf4j.simple) + testRuntimeOnly(libs.junit.platform.launcher) } tasks.named("test") { - val examplesDir = File(rootDir, "examples") - inputs.dir(examplesDir) - dependsOn(":compiler:publishAllPublicationsToMavenRepository") - dependsOn(":stub:publishAllPublicationsToMavenRepository") - - useJUnitPlatform() - - testLogging { - showStandardStreams = true - exceptionFormat = TestExceptionFormat.FULL - events(TestLogEvent.STANDARD_OUT, TestLogEvent.STANDARD_ERROR, TestLogEvent.STARTED, TestLogEvent.PASSED, TestLogEvent.SKIPPED, TestLogEvent.FAILED) - } - - retry { - maxRetries = 1 - maxFailures = 1 - } - - systemProperties["grpc-kotlin-version"] = project.version - systemProperties["examples-dir"] = examplesDir - systemProperties["test-repo"] = (publishing.repositories.getByName("maven") as MavenArtifactRepository).url - - /* - val properties = Properties() - if (rootProject.file("local.properties").exists()) { - properties.load(rootProject.file("local.properties").inputStream()) - environment("ANDROID_HOME", properties.getProperty("sdk.dir")) - } - */ - - // todo: cleanup copyExamples.destinationDir or move copy to tests -} - -tasks.withType { - enabled = false + val examplesDir = File(rootDir, "examples") + inputs.dir(examplesDir) + dependsOn(":compiler:publishAllPublicationsToMavenRepository") + dependsOn(":stub:publishAllPublicationsToMavenRepository") + + useJUnitPlatform() + + testLogging { + showStandardStreams = true + exceptionFormat = TestExceptionFormat.FULL + events( + TestLogEvent.STANDARD_OUT, + TestLogEvent.STANDARD_ERROR, + TestLogEvent.STARTED, + TestLogEvent.PASSED, + TestLogEvent.SKIPPED, + TestLogEvent.FAILED + ) + } + + retry { + maxRetries = 1 + maxFailures = 1 + } + + systemProperties["grpc-kotlin-version"] = project.version + systemProperties["examples-dir"] = examplesDir + systemProperties["test-repo"] = + (publishing.repositories.getByName("maven") as MavenArtifactRepository).url + + /* + val properties = Properties() + if (rootProject.file("local.properties").exists()) { + properties.load(rootProject.file("local.properties").inputStream()) + environment("ANDROID_HOME", properties.getProperty("sdk.dir")) + } + */ + + // todo: cleanup copyExamples.destinationDir or move copy to tests } +tasks.withType { enabled = false } diff --git a/integration_testing/src/test/kotlin/io/grpc/kotlin/ExamplesTest.kt b/integration_testing/src/test/kotlin/io/grpc/kotlin/ExamplesTest.kt index cb0b2bbb..5cb1d063 100644 --- a/integration_testing/src/test/kotlin/io/grpc/kotlin/ExamplesTest.kt +++ b/integration_testing/src/test/kotlin/io/grpc/kotlin/ExamplesTest.kt @@ -1,5 +1,11 @@ package io.grpc.kotlin +import java.io.File +import java.net.URI +import java.nio.file.Path +import java.util.Properties +import kotlin.io.path.div +import kotlin.io.path.inputStream import org.apache.commons.io.FileUtils import org.gradle.testkit.runner.GradleRunner import org.junit.jupiter.api.Assertions.assertTrue @@ -9,97 +15,94 @@ import org.slf4j.LoggerFactory import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.output.Slf4jLogConsumer import org.testcontainers.containers.wait.strategy.Wait -import java.io.File -import java.net.URI -import java.nio.file.Path -import java.util.Properties -import kotlin.io.path.div -import kotlin.io.path.inputStream - class ExamplesTest { - private val logger = LoggerFactory.getLogger(ExamplesTest::class.java) + private val logger = LoggerFactory.getLogger(ExamplesTest::class.java) - // todo: add test to verify jdk8 usage - @Test - fun server_client(@TempDir tempDir: Path) { - val grpcKotlinVersion = System.getProperty("grpc-kotlin-version") - val examplesDir = System.getProperty("examples-dir") - val testRepo = System.getProperty("test-repo") + // todo: add test to verify jdk8 usage + @Test + fun server_client(@TempDir tempDir: Path) { + val grpcKotlinVersion = System.getProperty("grpc-kotlin-version") + val examplesDir = System.getProperty("examples-dir") + val testRepo = System.getProperty("test-repo") - FileUtils.copyDirectory(File(examplesDir), tempDir.toFile()) + FileUtils.copyDirectory(File(examplesDir), tempDir.toFile()) - val libsVersionsToml = File(tempDir.toFile(), "gradle/libs.versions.toml") + val libsVersionsToml = File(tempDir.toFile(), "gradle/libs.versions.toml") - val versionRegex = Regex("""version = "(.*)"""") + val versionRegex = Regex("""version = "(.*)"""") - val libsVersionsTomlNewLines = libsVersionsToml.readLines().map { line -> - if (line.contains("grpc-kotlin-stub") || line.contains("protoc-gen-grpc-kotlin")) { - line.replace(versionRegex, """version = "$grpcKotlinVersion"""") - } - else { - line - } + val libsVersionsTomlNewLines = + libsVersionsToml.readLines().map { line -> + if (line.contains("grpc-kotlin-stub") || line.contains("protoc-gen-grpc-kotlin")) { + line.replace(versionRegex, """version = "$grpcKotlinVersion"""") + } else { + line } + } - libsVersionsToml.writeText(libsVersionsTomlNewLines.joinToString("\n")) + libsVersionsToml.writeText(libsVersionsTomlNewLines.joinToString("\n")) - val settingsGradle = File(tempDir.toFile(), "settings.gradle.kts") - val settingsGradleNewLines = settingsGradle.readLines().map { line -> - if (line.contains("mavenCentral()")) { - """ + val settingsGradle = File(tempDir.toFile(), "settings.gradle.kts") + val settingsGradleNewLines = + settingsGradle.readLines().map { line -> + if (line.contains("mavenCentral()")) { + """ mavenCentral() maven(uri("$testRepo")) """ - } - else { - line - } + } else { + line } - settingsGradle.writeText(settingsGradleNewLines.joinToString("\n")) - - val gradleWrapperProperties = Properties() - gradleWrapperProperties.load((tempDir / "gradle/wrapper/gradle-wrapper.properties").inputStream()) - val distributionUrl = URI.create(gradleWrapperProperties.getProperty("distributionUrl")) - - val dependencyResult = GradleRunner.create() - .withProjectDir(tempDir.toFile()) - .withArguments(":stub:dependencies") - .withGradleDistribution(distributionUrl) - .build() - - assertTrue(dependencyResult.output.contains("io.grpc:grpc-kotlin-stub:$grpcKotlinVersion")) - - GradleRunner.create() - .withProjectDir(tempDir.toFile()) - .withArguments(":client:build") - .withGradleDistribution(distributionUrl) - .build() - - GradleRunner.create() - .withProjectDir(tempDir.toFile()) - .withArguments(":server:jibDockerBuild", "--image=grpc-kotlin-examples-server") - .withGradleDistribution(distributionUrl) - .build() - - val logConsumer = Slf4jLogConsumer(logger) - - val container = GenericContainer("grpc-kotlin-examples-server") - .withExposedPorts(50051) - .waitingFor(Wait.forListeningPort()) - - container.start() - container.followOutput(logConsumer) - - val clientResult = GradleRunner.create() - .withProjectDir(tempDir.toFile()) - .withEnvironment(mapOf("PORT" to container.firstMappedPort.toString())) - .withArguments(":client:HelloWorldClient") - .withGradleDistribution(distributionUrl) - .build() - - assertTrue(clientResult.output.contains("Received: Hello world")) - } - + } + settingsGradle.writeText(settingsGradleNewLines.joinToString("\n")) + + val gradleWrapperProperties = Properties() + gradleWrapperProperties.load( + (tempDir / "gradle/wrapper/gradle-wrapper.properties").inputStream() + ) + val distributionUrl = URI.create(gradleWrapperProperties.getProperty("distributionUrl")) + + val dependencyResult = + GradleRunner.create() + .withProjectDir(tempDir.toFile()) + .withArguments(":stub:dependencies") + .withGradleDistribution(distributionUrl) + .build() + + assertTrue(dependencyResult.output.contains("io.grpc:grpc-kotlin-stub:$grpcKotlinVersion")) + + GradleRunner.create() + .withProjectDir(tempDir.toFile()) + .withArguments(":client:build") + .withGradleDistribution(distributionUrl) + .build() + + GradleRunner.create() + .withProjectDir(tempDir.toFile()) + .withArguments(":server:jibDockerBuild", "--image=grpc-kotlin-examples-server") + .withGradleDistribution(distributionUrl) + .build() + + val logConsumer = Slf4jLogConsumer(logger) + + val container = + GenericContainer("grpc-kotlin-examples-server") + .withExposedPorts(50051) + .waitingFor(Wait.forListeningPort()) + + container.start() + container.followOutput(logConsumer) + + val clientResult = + GradleRunner.create() + .withProjectDir(tempDir.toFile()) + .withEnvironment(mapOf("PORT" to container.firstMappedPort.toString())) + .withArguments(":client:HelloWorldClient") + .withGradleDistribution(distributionUrl) + .build() + + assertTrue(clientResult.output.contains("Received: Hello world")) + } } diff --git a/interop_testing/build.gradle.kts b/interop_testing/build.gradle.kts index d2c33adc..be893739 100644 --- a/interop_testing/build.gradle.kts +++ b/interop_testing/build.gradle.kts @@ -1,78 +1,70 @@ import com.google.protobuf.gradle.* -plugins { - application -} +plugins { application } dependencies { - implementation(kotlin("test")) - implementation(libs.kotlinx.coroutines.core) + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.core) - implementation(project(":stub")) + implementation(project(":stub")) - implementation(libs.grpc.protobuf) - implementation(libs.grpc.protobuf.lite) - implementation(libs.grpc.auth) - implementation(libs.grpc.alts) - implementation(libs.grpc.netty) - implementation(libs.grpc.okhttp) - implementation(libs.grpc.testing) + implementation(libs.grpc.protobuf) + implementation(libs.grpc.protobuf.lite) + implementation(libs.grpc.auth) + implementation(libs.grpc.alts) + implementation(libs.grpc.netty) + implementation(libs.grpc.okhttp) + implementation(libs.grpc.testing) - implementation(libs.protobuf.java) + implementation(libs.protobuf.java) - implementation(libs.truth) + implementation(libs.truth) - testImplementation(libs.mockito.core) - testImplementation(libs.okhttp) { - because("transitive dep for grpc-okhttp") - } + testImplementation(libs.mockito.core) + testImplementation(libs.okhttp) { because("transitive dep for grpc-okhttp") } } protobuf { - protoc { - artifact = libs.protoc.asProvider().get().toString() - } - plugins { - id("grpc") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - id("grpckt") { - path = project(":compiler").tasks.jar.get().archiveFile.get().asFile.absolutePath - } + protoc { artifact = libs.protoc.asProvider().get().toString() } + plugins { + id("grpc") { artifact = libs.protoc.gen.grpc.java.get().toString() } + id("grpckt") { + path = project(":compiler").tasks.jar.get().archiveFile.get().asFile.absolutePath } - generateProtoTasks { - all().forEach { - if (it.name.startsWith("generateTestProto") || it.name.startsWith("generateProto")) { - it.dependsOn(":compiler:jar") - } + } + generateProtoTasks { + all().forEach { + if (it.name.startsWith("generateTestProto") || it.name.startsWith("generateProto")) { + it.dependsOn(":compiler:jar") + } - it.plugins { - id("grpc") - id("grpckt") - } - } + it.plugins { + id("grpc") + id("grpckt") + } } + } } -val testServiceClientStartScripts = tasks.register("testServiceClientStartScripts") { +val testServiceClientStartScripts = + tasks.register("testServiceClientStartScripts") { mainClass.set("io.grpc.testing.integration.TestServiceClient") applicationName = "test-service-client" outputDir = tasks.named("startScripts").get().outputDir classpath = tasks.named("startScripts").get().classpath -} + } -val testServiceServerStartScripts = tasks.register("testServiceServerStartScripts") { +val testServiceServerStartScripts = + tasks.register("testServiceServerStartScripts") { mainClass.set("io.grpc.testing.integration.TestServiceServer") applicationName = "test-service-server" outputDir = tasks.named("startScripts").get().outputDir classpath = tasks.named("startScripts").get().classpath -} + } tasks.named("startScripts") { - dependsOn(testServiceClientStartScripts) - dependsOn(testServiceServerStartScripts) + dependsOn(testServiceClientStartScripts) + dependsOn(testServiceServerStartScripts) } -tasks.withType { - enabled = false -} +tasks.withType { enabled = false } diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.kt b/interop_testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.kt index 44e6993f..e29b6f2a 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.kt +++ b/interop_testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.kt @@ -61,6 +61,21 @@ import io.grpc.testing.integration.Messages.StreamingInputCallRequest import io.grpc.testing.integration.Messages.StreamingInputCallResponse import io.grpc.testing.integration.Messages.StreamingOutputCallRequest import io.grpc.testing.integration.Messages.StreamingOutputCallResponse +import java.io.IOException +import java.io.InputStream +import java.net.SocketAddress +import java.util.Arrays +import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.Executors +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.ScheduledExecutorService +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference +import java.util.logging.Level +import java.util.logging.Logger +import java.util.regex.Pattern +import kotlin.math.max +import kotlin.test.assertFailsWith import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -94,33 +109,16 @@ import org.junit.Test import org.junit.rules.DisableOnDebug import org.junit.rules.TestRule import org.junit.rules.Timeout -import java.io.IOException -import java.io.InputStream -import java.net.SocketAddress -import java.util.Arrays -import java.util.concurrent.ArrayBlockingQueue -import java.util.concurrent.Executors -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.ScheduledExecutorService -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicReference -import java.util.logging.Level -import java.util.logging.Logger -import java.util.regex.Pattern -import kotlin.math.max -import kotlin.test.assertFailsWith /** * Abstract base class for all GRPC transport tests. * - * - * New tests should avoid using Mockito to support running on AppEngine. + * New tests should avoid using Mockito to support running on AppEngine. */ @ExperimentalCoroutinesApi @FlowPreview abstract class AbstractInteropTest { - @get:Rule - val globalTimeout: TestRule + @get:Rule val globalTimeout: TestRule private val serverCallCapture = AtomicReference>() private val clientCallCapture = AtomicReference>() private val requestHeadersCapture = AtomicReference() @@ -129,14 +127,11 @@ abstract class AbstractInteropTest { private var server: Server? = null private val serverStreamTracers = LinkedBlockingQueue() - private class ServerStreamTracerInfo internal constructor( - val fullMethodName: String, - val tracer: InteropServerStreamTracer - ) { + private class ServerStreamTracerInfo + internal constructor(val fullMethodName: String, val tracer: InteropServerStreamTracer) { class InteropServerStreamTracer : TestServerStreamTracer() { - @Volatile - var contextCapture: Context? = null + @Volatile var contextCapture: Context? = null override fun filterContext(context: Context): Context { contextCapture = context @@ -165,25 +160,22 @@ abstract class AbstractInteropTest { } val executor = Executors.newScheduledThreadPool(2) testServiceExecutor = executor - val allInterceptors: List = ImmutableList.builder() - .add(recordServerCallInterceptor(serverCallCapture)) - .add(TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture)) - .add(recordContextInterceptor(contextCapture)) - .addAll(TestServiceImpl.interceptors) - .build() + val allInterceptors: List = + ImmutableList.builder() + .add(recordServerCallInterceptor(serverCallCapture)) + .add(TestUtils.recordRequestHeadersInterceptor(requestHeadersCapture)) + .add(recordContextInterceptor(contextCapture)) + .addAll(TestServiceImpl.interceptors) + .build() builder - .addService( - ServerInterceptors.intercept( - TestServiceImpl(executor), - allInterceptors - ) - ) + .addService(ServerInterceptors.intercept(TestServiceImpl(executor), allInterceptors)) .addStreamTracerFactory(serverStreamTracerFactory) - server = try { - builder.build().start() - } catch (ex: IOException) { - throw RuntimeException(ex) - } + server = + try { + builder.build().start() + } catch (ex: IOException) { + throw RuntimeException(ex) + } } private fun stopServer() { @@ -207,12 +199,10 @@ abstract class AbstractInteropTest { protected lateinit var stub: TestServiceGrpcKt.TestServiceCoroutineStub // to be deleted when subclasses are ready to migrate - @JvmField - var blockingStub: TestServiceGrpc.TestServiceBlockingStub? = null + @JvmField var blockingStub: TestServiceGrpc.TestServiceBlockingStub? = null // to be deleted when subclasses are ready to migrate - @JvmField - var asyncStub: TestServiceGrpc.TestServiceStub? = null + @JvmField var asyncStub: TestServiceGrpc.TestServiceStub? = null private val clientStreamTracers = LinkedBlockingQueue() private val clientStreamTracerFactory: ClientStreamTracer.Factory = @@ -226,21 +216,18 @@ abstract class AbstractInteropTest { return tracer } } - private val tracerSetupInterceptor: ClientInterceptor = object : ClientInterceptor { - override fun interceptCall( - method: MethodDescriptor, - callOptions: CallOptions, - next: Channel - ): ClientCall { - return next.newCall( - method, callOptions.withStreamTracerFactory(clientStreamTracerFactory) - ) + private val tracerSetupInterceptor: ClientInterceptor = + object : ClientInterceptor { + override fun interceptCall( + method: MethodDescriptor, + callOptions: CallOptions, + next: Channel + ): ClientCall { + return next.newCall(method, callOptions.withStreamTracerFactory(clientStreamTracerFactory)) + } } - } - /** - * Must be called by the subclass setup method if overridden. - */ + /** Must be called by the subclass setup method if overridden. */ @Before fun setUp() { startServer() @@ -256,7 +243,7 @@ abstract class AbstractInteropTest { requestHeadersCapture.set(null) } - /** Clean up. */ + /** Clean up. */ @After open fun tearDown() { channel.shutdownNow() @@ -276,38 +263,30 @@ abstract class AbstractInteropTest { get() = null /** - * Returns the server builder used to create server for each test run. Return `null` if - * it shouldn't start a server in the same process. + * Returns the server builder used to create server for each test run. Return `null` if it + * shouldn't start a server in the same process. */ protected open val serverBuilder: ServerBuilder<*>? get() = null @Test fun emptyUnary() { - runBlocking { - assertEquals(EMPTY, stub.emptyCall(EMPTY)) - } + runBlocking { assertEquals(EMPTY, stub.emptyCall(EMPTY)) } } @Test fun largeUnary() { assumeEnoughMemory() - val request = SimpleRequest.newBuilder() - .setResponseSize(314159) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(271828))) - ) - .build() - val goldenResponse = SimpleResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(314159))) - ) - .build() - runBlocking { - assertResponse(goldenResponse, stub.unaryCall(request)) - } + val request = + SimpleRequest.newBuilder() + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() + runBlocking { assertResponse(goldenResponse, stub.unaryCall(request)) } assertStatsTrace( "grpc.testing.TestService/UnaryCall", Status.Code.OK, @@ -322,19 +301,22 @@ abstract class AbstractInteropTest { */ fun clientCompressedUnary(probe: Boolean) { assumeEnoughMemory() - val expectCompressedRequest = SimpleRequest.newBuilder() - .setExpectCompressed(BoolValue.newBuilder().setValue(true)) - .setResponseSize(314159) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) - .build() - val expectUncompressedRequest = SimpleRequest.newBuilder() - .setExpectCompressed(BoolValue.newBuilder().setValue(false)) - .setResponseSize(314159) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) - .build() - val goldenResponse = SimpleResponse.newBuilder() - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) - .build() + val expectCompressedRequest = + SimpleRequest.newBuilder() + .setExpectCompressed(BoolValue.newBuilder().setValue(true)) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() + val expectUncompressedRequest = + SimpleRequest.newBuilder() + .setExpectCompressed(BoolValue.newBuilder().setValue(false)) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() if (probe) { // Send a non-compressed message with expectCompress=true. Servers supporting this test case // should return INVALID_ARGUMENT. @@ -350,16 +332,21 @@ abstract class AbstractInteropTest { } runBlocking { assertResponse( - goldenResponse, stub.withCompression("gzip").unaryCall(expectCompressedRequest) + goldenResponse, + stub.withCompression("gzip").unaryCall(expectCompressedRequest) ) assertStatsTrace( "grpc.testing.TestService/UnaryCall", - Status.Code.OK, setOf(expectCompressedRequest), setOf(goldenResponse) + Status.Code.OK, + setOf(expectCompressedRequest), + setOf(goldenResponse) ) assertResponse(goldenResponse, stub.unaryCall(expectUncompressedRequest)) assertStatsTrace( "grpc.testing.TestService/UnaryCall", - Status.Code.OK, setOf(expectUncompressedRequest), setOf(goldenResponse) + Status.Code.OK, + setOf(expectUncompressedRequest), + setOf(goldenResponse) ) } } @@ -373,42 +360,48 @@ abstract class AbstractInteropTest { @Test fun serverCompressedUnary() { assumeEnoughMemory() - val responseShouldBeCompressed = SimpleRequest.newBuilder() - .setResponseCompressed(BoolValue.newBuilder().setValue(true)) - .setResponseSize(314159) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) - .build() - val responseShouldBeUncompressed = SimpleRequest.newBuilder() - .setResponseCompressed(BoolValue.newBuilder().setValue(false)) - .setResponseSize(314159) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) - .build() - val goldenResponse = SimpleResponse.newBuilder() - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) - .build() + val responseShouldBeCompressed = + SimpleRequest.newBuilder() + .setResponseCompressed(BoolValue.newBuilder().setValue(true)) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() + val responseShouldBeUncompressed = + SimpleRequest.newBuilder() + .setResponseCompressed(BoolValue.newBuilder().setValue(false)) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() runBlocking { assertResponse(goldenResponse, stub.unaryCall(responseShouldBeCompressed)) assertStatsTrace( "grpc.testing.TestService/UnaryCall", - Status.Code.OK, setOf(responseShouldBeCompressed), setOf(goldenResponse) + Status.Code.OK, + setOf(responseShouldBeCompressed), + setOf(goldenResponse) ) assertResponse(goldenResponse, stub.unaryCall(responseShouldBeUncompressed)) assertStatsTrace( "grpc.testing.TestService/UnaryCall", - Status.Code.OK, setOf(responseShouldBeUncompressed), setOf(goldenResponse) + Status.Code.OK, + setOf(responseShouldBeUncompressed), + setOf(goldenResponse) ) } } - /** - * Assuming "pick_first" policy is used, tests that all requests are sent to the same server. - */ + /** Assuming "pick_first" policy is used, tests that all requests are sent to the same server. */ fun pickFirstUnary() { - val request = SimpleRequest.newBuilder() - .setResponseSize(1) - .setFillServerId(true) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(1)))) - .build() + val request = + SimpleRequest.newBuilder() + .setResponseSize(1) + .setFillServerId(true) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(1)))) + .build() runBlocking { val firstResponse = stub.unaryCall(request) // Increase the chance of all servers are connected, in case the channel should be doing @@ -423,108 +416,70 @@ abstract class AbstractInteropTest { @Test fun serverStreaming() { - val request = StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(31415) - ) - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(9) - ) - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(2653) - ) - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(58979) - ) - .build() - val goldenResponses = Arrays.asList( - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(31415))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(9))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(2653))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(58979))) - ) + val request = + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(31415)) + .addResponseParameters(ResponseParameters.newBuilder().setSize(9)) + .addResponseParameters(ResponseParameters.newBuilder().setSize(2653)) + .addResponseParameters(ResponseParameters.newBuilder().setSize(58979)) .build() - ) - runBlocking { - assertResponses(goldenResponses, stub.streamingOutputCall(request).toList()) - } + val goldenResponses = + Arrays.asList( + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(31415)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(9)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(2653)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(58979)))) + .build() + ) + runBlocking { assertResponses(goldenResponses, stub.streamingOutputCall(request).toList()) } } @Test fun clientStreaming() { - val requests = Arrays.asList( - StreamingInputCallRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(27182))) - ) - .build(), - StreamingInputCallRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(8))) - ) - .build(), - StreamingInputCallRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(1828))) - ) - .build(), - StreamingInputCallRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(45904))) - ) - .build() - ) - val goldenResponse = StreamingInputCallResponse.newBuilder() - .setAggregatedPayloadSize(74922) - .build() - val response = runBlocking { - stub.streamingInputCall(requests.asFlow()) - } + val requests = + Arrays.asList( + StreamingInputCallRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(27182)))) + .build(), + StreamingInputCallRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(8)))) + .build(), + StreamingInputCallRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(1828)))) + .build(), + StreamingInputCallRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(45904)))) + .build() + ) + val goldenResponse = + StreamingInputCallResponse.newBuilder().setAggregatedPayloadSize(74922).build() + val response = runBlocking { stub.streamingInputCall(requests.asFlow()) } assertEquals(goldenResponse, response) } - /** - * Unsupported. - */ + /** Unsupported. */ open fun clientCompressedStreaming(probe: Boolean) { - val expectCompressedRequest = StreamingInputCallRequest.newBuilder() - .setExpectCompressed(BoolValue.newBuilder().setValue(true)) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(27182)))) - .build() + val expectCompressedRequest = + StreamingInputCallRequest.newBuilder() + .setExpectCompressed(BoolValue.newBuilder().setValue(true)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(27182)))) + .build() if (probe) { runBlocking { - val ex = assertFailsWith { - // Send a non-compressed message with expectCompress=true. Servers supporting this test - // case should return INVALID_ARGUMENT. - stub.streamingInputCall(flowOf(expectCompressedRequest)) - } + val ex = + assertFailsWith { + // Send a non-compressed message with expectCompress=true. Servers supporting this test + // case should return INVALID_ARGUMENT. + stub.streamingInputCall(flowOf(expectCompressedRequest)) + } assertEquals(Status.INVALID_ARGUMENT.code, ex.status.code) } } @@ -538,101 +493,67 @@ abstract class AbstractInteropTest { * cannot itself verify that the response was compressed. */ fun serverCompressedStreaming() { - val request = StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setCompressed(BoolValue.newBuilder().setValue(true)) - .setSize(31415) - ) - .addResponseParameters( - ResponseParameters.newBuilder() - .setCompressed(BoolValue.newBuilder().setValue(false)) - .setSize(92653) - ) - .build() - val goldenResponses = Arrays.asList( - StreamingOutputCallResponse.newBuilder() - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(31415)))) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(92653)))) - .build() - ) - runBlocking { - assertResponses(goldenResponses, stub.streamingOutputCall(request).toList()) - } - } - - @Test - fun pingPong() { - val requests = Arrays.asList( + val request = StreamingOutputCallRequest.newBuilder() .addResponseParameters( ResponseParameters.newBuilder() + .setCompressed(BoolValue.newBuilder().setValue(true)) .setSize(31415) ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(27182))) - ) - .build(), - StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(9) - ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(8))) - ) - .build(), - StreamingOutputCallRequest.newBuilder() .addResponseParameters( ResponseParameters.newBuilder() - .setSize(2653) - ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(1828))) - ) - .build(), - StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(58979) - ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(45904))) - ) - .build() - ) - val goldenResponses = listOf( - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(31415))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(9))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(2653))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(58979))) + .setCompressed(BoolValue.newBuilder().setValue(false)) + .setSize(92653) ) .build() - ) + val goldenResponses = + Arrays.asList( + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(31415)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(92653)))) + .build() + ) + runBlocking { assertResponses(goldenResponses, stub.streamingOutputCall(request).toList()) } + } + + @Test + fun pingPong() { + val requests = + Arrays.asList( + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(31415)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(27182)))) + .build(), + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(9)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(8)))) + .build(), + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(2653)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(1828)))) + .build(), + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(58979)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(45904)))) + .build() + ) + val goldenResponses = + listOf( + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(31415)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(9)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(2653)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(58979)))) + .build() + ) runBlocking { // TODO: per-element timeout assertResponses(goldenResponses, stub.fullDuplexCall(requests.asFlow()).toList()) @@ -641,23 +562,13 @@ abstract class AbstractInteropTest { @Test fun emptyStream() { - runBlocking { - assertResponses(listOf(), stub.fullDuplexCall(flowOf()).toList()) - } + runBlocking { assertResponses(listOf(), stub.fullDuplexCall(flowOf()).toList()) } } @Test fun cancelAfterBegin() { class MyEx : Exception() - runBlocking { - assertFailsWith { - stub.streamingInputCall( - flow { - throw MyEx() - } - ) - } - } + runBlocking { assertFailsWith { stub.streamingInputCall(flow { throw MyEx() }) } } } @Test @@ -677,9 +588,7 @@ abstract class AbstractInteropTest { val request = streamingOutputBuilder.build() val numRequests = 10 val responses = runBlocking { - stub.fullDuplexCall( - (1..numRequests).asFlow().map { request } - ).toList() + stub.fullDuplexCall((1..numRequests).asFlow().map { request }).toList() } assertEquals(responseSizes.size * numRequests, responses.size) for ((ix, response) in responses.withIndex()) { @@ -720,22 +629,20 @@ abstract class AbstractInteropTest { @Test fun serverStreamingShouldBeFlowControlled() { - val request = StreamingOutputCallRequest.newBuilder() - .addResponseParameters(ResponseParameters.newBuilder().setSize(100000)) - .addResponseParameters(ResponseParameters.newBuilder().setSize(100001)) - .build() - val goldenResponses = Arrays.asList( - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(100000))) - ).build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(100001))) - ).build() - ) + val request = + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(100000)) + .addResponseParameters(ResponseParameters.newBuilder().setSize(100001)) + .build() + val goldenResponses = + Arrays.asList( + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(100000)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(100001)))) + .build() + ) val start = System.nanoTime() // TODO(lowasser): change this to a Channel @@ -744,6 +651,7 @@ abstract class AbstractInteropTest { call.start( object : ClientCall.Listener() { override fun onHeaders(headers: Metadata) {} + override fun onMessage(message: StreamingOutputCallResponse) { queue.add(message) } @@ -781,39 +689,31 @@ abstract class AbstractInteropTest { @Test fun veryLargeRequest() { assumeEnoughMemory() - val request = SimpleRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(unaryPayloadLength()))) - ) - .setResponseSize(10) - .build() - val goldenResponse = SimpleResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(10))) - ) - .build() - runBlocking { - assertResponse(goldenResponse, stub.unaryCall(request)) - } + val request = + SimpleRequest.newBuilder() + .setPayload( + Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(unaryPayloadLength()))) + ) + .setResponseSize(10) + .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(10)))) + .build() + runBlocking { assertResponse(goldenResponse, stub.unaryCall(request)) } } @Test fun veryLargeResponse() { assumeEnoughMemory() - val request = SimpleRequest.newBuilder() - .setResponseSize(unaryPayloadLength()) - .build() - val goldenResponse = SimpleResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(unaryPayloadLength()))) - ) - .build() - runBlocking { - assertResponse(goldenResponse, stub.unaryCall(request)) - } + val request = SimpleRequest.newBuilder().setResponseSize(unaryPayloadLength()).build() + val goldenResponse = + SimpleResponse.newBuilder() + .setPayload( + Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(unaryPayloadLength()))) + ) + .build() + runBlocking { assertResponse(goldenResponse, stub.unaryCall(request)) } } @Test @@ -827,10 +727,11 @@ abstract class AbstractInteropTest { // .. and expect it to be echoed back in trailers val trailersCapture = AtomicReference() val headersCapture = AtomicReference() - stub = stub.withInterceptors(MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)) - runBlocking { - assertNotNull(stub.emptyCall(EMPTY)) - } + stub = + stub.withInterceptors( + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture) + ) + runBlocking { assertNotNull(stub.emptyCall(EMPTY)) } // Assert that our side channel object is echoed back in both headers and trailers assertEquals(contextValue, headersCapture.get().get(Util.METADATA_KEY)) assertEquals(contextValue, trailersCapture.get().get(Util.METADATA_KEY)) @@ -847,7 +748,10 @@ abstract class AbstractInteropTest { // .. and expect it to be echoed back in trailers val trailersCapture = AtomicReference() val headersCapture = AtomicReference() - stub = stub.withInterceptors(MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + stub = + stub.withInterceptors( + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture) + ) val responseSizes = listOf(50, 100, 150, 200) val streamingOutputBuilder = StreamingOutputCallRequest.newBuilder() for (size in responseSizes) { @@ -876,10 +780,7 @@ abstract class AbstractInteropTest { .withDeadlineAfter(10, TimeUnit.SECONDS) .streamingOutputCall( StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setIntervalUs(0) - ) + .addResponseParameters(ResponseParameters.newBuilder().setIntervalUs(0)) .build() ) .first() @@ -891,12 +792,12 @@ abstract class AbstractInteropTest { runBlocking { // warm up the channel and JVM stub.emptyCall(EmptyProtos.Empty.getDefaultInstance()) - val request = StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setIntervalUs(TimeUnit.SECONDS.toMicros(20).toInt()) - ) - .build() + val request = + StreamingOutputCallRequest.newBuilder() + .addResponseParameters( + ResponseParameters.newBuilder().setIntervalUs(TimeUnit.SECONDS.toMicros(20).toInt()) + ) + .build() try { stub.withDeadlineAfter(100, TimeUnit.MILLISECONDS).streamingOutputCall(request).first() fail("Expected deadline to be exceeded") @@ -921,21 +822,18 @@ abstract class AbstractInteropTest { runBlocking { // warm up the channel and JVM stub.emptyCall(EmptyProtos.Empty.getDefaultInstance()) - val responseParameters = ResponseParameters.newBuilder() - .setSize(1) - .setIntervalUs(10000) - val request = StreamingOutputCallRequest.newBuilder() - .addResponseParameters(responseParameters) - .addResponseParameters(responseParameters) - .addResponseParameters(responseParameters) - .addResponseParameters(responseParameters) - .build() - val statusEx = assertFailsWith { - stub - .withDeadlineAfter(30, TimeUnit.MILLISECONDS) - .streamingOutputCall(request) - .collect() - } + val responseParameters = ResponseParameters.newBuilder().setSize(1).setIntervalUs(10000) + val request = + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(responseParameters) + .addResponseParameters(responseParameters) + .addResponseParameters(responseParameters) + .addResponseParameters(responseParameters) + .build() + val statusEx = + assertFailsWith { + stub.withDeadlineAfter(30, TimeUnit.MILLISECONDS).streamingOutputCall(request).collect() + } assertEquals(Status.DEADLINE_EXCEEDED.code, statusEx.status.code) assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK) } @@ -978,58 +876,33 @@ abstract class AbstractInteropTest { @Test fun gracefulShutdown() { runBlocking { - val requests = listOf( - StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(3) - ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(2))) - ) - .build(), - StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(1) - ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(7))) - ) - .build(), - StreamingOutputCallRequest.newBuilder() - .addResponseParameters( - ResponseParameters.newBuilder() - .setSize(4) - ) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(1))) - ) - .build() - ) - val goldenResponses = listOf( - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(3))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(1))) - ) - .build(), - StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(4))) - ) - .build() - ) + val requests = + listOf( + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(3)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(2)))) + .build(), + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(1)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(7)))) + .build(), + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(4)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(1)))) + .build() + ) + val goldenResponses = + listOf( + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(3)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(1)))) + .build(), + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(4)))) + .build() + ) val requestChannel = kotlinx.coroutines.channels.Channel() @@ -1041,7 +914,8 @@ abstract class AbstractInteropTest { channel.shutdown() requestChannel.send(requests[1]) assertResponse(goldenResponses[1], responses.receive()) - // The previous ping-pong could have raced with the shutdown, but this one certainly shouldn't. + // The previous ping-pong could have raced with the shutdown, but this one certainly + // shouldn't. requestChannel.send(requests[2]) assertResponse(goldenResponses[2], responses.receive()) assertFalse(responses.isClosedForReceive) @@ -1054,32 +928,32 @@ abstract class AbstractInteropTest { fun customMetadata() { val responseSize = 314159 val requestSize = 271828 - val request = SimpleRequest.newBuilder() - .setResponseSize(responseSize) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(requestSize))) - ) - .build() - val streamingRequest = StreamingOutputCallRequest.newBuilder() - .addResponseParameters(ResponseParameters.newBuilder().setSize(responseSize)) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(requestSize)))) - .build() - val goldenResponse = SimpleResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(responseSize))) - ) - .build() - val goldenStreamingResponse = StreamingOutputCallResponse.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(responseSize))) - ) - .build() + val request = + SimpleRequest.newBuilder() + .setResponseSize(responseSize) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(requestSize)))) + .build() + val streamingRequest = + StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder().setSize(responseSize)) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(requestSize)))) + .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(responseSize)))) + .build() + val goldenStreamingResponse = + StreamingOutputCallResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(responseSize)))) + .build() val trailingBytes = byteArrayOf( - 0xa.toByte(), 0xb.toByte(), 0xa.toByte(), 0xb.toByte(), 0xa.toByte(), 0xb.toByte() + 0xa.toByte(), + 0xb.toByte(), + 0xa.toByte(), + 0xb.toByte(), + 0xa.toByte(), + 0xb.toByte() ) // Test UnaryCall var metadata = Metadata() @@ -1089,7 +963,10 @@ abstract class AbstractInteropTest { theStub = theStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)) var headersCapture = AtomicReference() var trailersCapture = AtomicReference() - theStub = theStub.withInterceptors(MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + theStub = + theStub.withInterceptors( + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture) + ) val response = runBlocking { theStub.unaryCall(request) } assertResponse(goldenResponse, response) assertEquals( @@ -1100,7 +977,10 @@ abstract class AbstractInteropTest { Arrays.equals(trailingBytes, trailersCapture.get().get(Util.ECHO_TRAILING_METADATA_KEY)) ) assertStatsTrace( - "grpc.testing.TestService/UnaryCall", Status.Code.OK, setOf(request), setOf(goldenResponse) + "grpc.testing.TestService/UnaryCall", + Status.Code.OK, + setOf(request), + setOf(goldenResponse) ) // Test FullDuplexCall metadata = Metadata() @@ -1108,12 +988,19 @@ abstract class AbstractInteropTest { metadata.put(Util.ECHO_TRAILING_METADATA_KEY, trailingBytes) var theStreamingStub = stub - theStreamingStub = theStreamingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)) + theStreamingStub = + theStreamingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)) headersCapture = AtomicReference() trailersCapture = AtomicReference() - theStreamingStub = theStreamingStub.withInterceptors(MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + theStreamingStub = + theStreamingStub.withInterceptors( + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture) + ) runBlocking { - assertResponse(goldenStreamingResponse, theStreamingStub.fullDuplexCall(flowOf(streamingRequest)).single()) + assertResponse( + goldenStreamingResponse, + theStreamingStub.fullDuplexCall(flowOf(streamingRequest)).single() + ) } assertEquals( "test_initial_metadata_value", @@ -1135,16 +1022,11 @@ abstract class AbstractInteropTest { runBlocking { val errorCode = 2 val errorMessage = "test status message" - val responseStatus = EchoStatus.newBuilder() - .setCode(errorCode) - .setMessage(errorMessage) - .build() - val simpleRequest = SimpleRequest.newBuilder() - .setResponseStatus(responseStatus) - .build() - val streamingRequest = StreamingOutputCallRequest.newBuilder() - .setResponseStatus(responseStatus) - .build() + val responseStatus = + EchoStatus.newBuilder().setCode(errorCode).setMessage(errorMessage).build() + val simpleRequest = SimpleRequest.newBuilder().setResponseStatus(responseStatus).build() + val streamingRequest = + StreamingOutputCallRequest.newBuilder().setResponseStatus(responseStatus).build() // Test UnaryCall try { stub.unaryCall(simpleRequest) @@ -1155,9 +1037,9 @@ abstract class AbstractInteropTest { } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.UNKNOWN) // Test FullDuplexCall - val status = assertFailsWith { - stub.fullDuplexCall(flowOf(streamingRequest)).collect() - }.status + val status = + assertFailsWith { stub.fullDuplexCall(flowOf(streamingRequest)).collect() } + .status assertEquals(Status.UNKNOWN.code, status.code) assertEquals(errorMessage, status.description) assertStatsTrace("grpc.testing.TestService/FullDuplexCall", Status.Code.UNKNOWN) @@ -1168,14 +1050,12 @@ abstract class AbstractInteropTest { fun specialStatusMessage() { val errorCode = 2 val errorMessage = "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n" - val simpleRequest = SimpleRequest.newBuilder() - .setResponseStatus( - EchoStatus.newBuilder() - .setCode(errorCode) - .setMessage(errorMessage) - .build() - ) - .build() + val simpleRequest = + SimpleRequest.newBuilder() + .setResponseStatus( + EchoStatus.newBuilder().setCode(errorCode).setMessage(errorMessage).build() + ) + .build() runBlocking { try { stub.unaryCall(simpleRequest) @@ -1188,13 +1068,14 @@ abstract class AbstractInteropTest { assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.UNKNOWN) } - /** Sends an rpc to an unimplemented method within TestService. */ + /** Sends an rpc to an unimplemented method within TestService. */ @Test fun unimplementedMethod() { runBlocking { - val ex = assertFailsWith { - stub.unimplementedCall(EmptyProtos.Empty.getDefaultInstance()) - } + val ex = + assertFailsWith { + stub.unimplementedCall(EmptyProtos.Empty.getDefaultInstance()) + } assertEquals(Status.UNIMPLEMENTED.code, ex.status.code) assertClientStatsTrace( "grpc.testing.TestService/UnimplementedCall", @@ -1203,16 +1084,17 @@ abstract class AbstractInteropTest { } } - /** Sends an rpc to an unimplemented service on the server. */ + /** Sends an rpc to an unimplemented service on the server. */ @Test fun unimplementedService() { val stub = UnimplementedServiceGrpcKt.UnimplementedServiceCoroutineStub(channel) .withInterceptors(tracerSetupInterceptor) runBlocking { - val ex = assertFailsWith { - stub.unimplementedCall(EmptyProtos.Empty.getDefaultInstance()) - } + val ex = + assertFailsWith { + stub.unimplementedCall(EmptyProtos.Empty.getDefaultInstance()) + } assertEquals(Status.UNIMPLEMENTED.code, ex.status.code) } assertStatsTrace( @@ -1221,21 +1103,22 @@ abstract class AbstractInteropTest { ) } - /** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. */ + /** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. */ @Test fun timeoutOnSleepingServer() { val stub = stub.withDeadlineAfter(1, TimeUnit.MILLISECONDS) - val request = StreamingOutputCallRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(27182))) - ) - .build() + val request = + StreamingOutputCallRequest.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(27182)))) + .build() runBlocking { val caught = CompletableDeferred() val responses = stub.fullDuplexCall(flowOf(request)).catch { caught.complete(it) }.toList() assertThat(responses).isEmpty() - assertEquals(Status.DEADLINE_EXCEEDED.code, (caught.getCompleted() as StatusException).status.code) + assertEquals( + Status.DEADLINE_EXCEEDED.code, + (caught.getCompleted() as StatusException).status.code + ) } } @@ -1249,57 +1132,48 @@ abstract class AbstractInteropTest { assertNotNull(obtainLocalClientAddr()) } - /** Sends a large unary rpc with service account credentials. */ + /** Sends a large unary rpc with service account credentials. */ fun serviceAccountCreds(jsonKey: String, credentialsStream: InputStream?, authScope: String) { // cast to ServiceAccountCredentials to double-check the right type of object was created. var credentials: GoogleCredentials = GoogleCredentials.fromStream(credentialsStream) as ServiceAccountCredentials credentials = credentials.createScoped(listOf(authScope)) val stub = this.stub.withCallCredentials(MoreCallCredentials.from(credentials)) - val request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .setResponseSize(314159) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(271828))) - ) - .build() + val request = + SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() val response = runBlocking { stub.unaryCall(request) } assertFalse(response.username.isEmpty()) - assertTrue( - "Received username: " + response.username, - jsonKey.contains(response.username) - ) + assertTrue("Received username: " + response.username, jsonKey.contains(response.username)) assertFalse(response.oauthScope.isEmpty()) assertTrue( "Received oauth scope: " + response.oauthScope, authScope.contains(response.oauthScope) ) - val goldenResponse = SimpleResponse.newBuilder() - .setOauthScope(response.oauthScope) - .setUsername(response.username) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(314159))) - ) - .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setOauthScope(response.oauthScope) + .setUsername(response.username) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() assertResponse(goldenResponse, response) } - /** Sends a large unary rpc with compute engine credentials. */ + /** Sends a large unary rpc with compute engine credentials. */ fun computeEngineCreds(serviceAccount: String?, oauthScope: String) { val credentials = ComputeEngineCredentials.create() val stub = stub.withCallCredentials(MoreCallCredentials.from(credentials)) - val request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .setResponseSize(314159) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(271828))) - ) - .build() + val request = + SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() val response = runBlocking { stub.unaryCall(request) } assertEquals(serviceAccount, response.username) assertFalse(response.oauthScope.isEmpty()) @@ -1307,52 +1181,44 @@ abstract class AbstractInteropTest { "Received oauth scope: " + response.oauthScope, oauthScope.contains(response.oauthScope) ) - val goldenResponse = SimpleResponse.newBuilder() - .setOauthScope(response.oauthScope) - .setUsername(response.username) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(314159))) - ) - .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setOauthScope(response.oauthScope) + .setUsername(response.username) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() assertResponse(goldenResponse, response) } - /** Sends an unary rpc with ComputeEngineChannelBuilder. */ + /** Sends an unary rpc with ComputeEngineChannelBuilder. */ fun computeEngineChannelCredentials( defaultServiceAccount: String, computeEngineStub: TestServiceGrpcKt.TestServiceCoroutineStub ) = runBlocking { - val request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setResponseSize(314159) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(271828))) - ) - .build() + val request = + SimpleRequest.newBuilder() + .setFillUsername(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() val response = computeEngineStub.unaryCall(request) assertEquals(defaultServiceAccount, response.username) - val goldenResponse = SimpleResponse.newBuilder() - .setUsername(defaultServiceAccount) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(314159))) - ) - .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setUsername(defaultServiceAccount) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() assertResponse(goldenResponse, response) } - /** Test JWT-based auth. */ + /** Test JWT-based auth. */ fun jwtTokenCreds(serviceAccountJson: InputStream?) { - val request = SimpleRequest.newBuilder() - .setResponseSize(314159) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(271828))) - ) - .setFillUsername(true) - .build() + val request = + SimpleRequest.newBuilder() + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .setFillUsername(true) + .build() val credentials = GoogleCredentials.fromStream(serviceAccountJson) as ServiceAccountCredentials val response = runBlocking { stub.withCallCredentials(MoreCallCredentials.from(credentials)).unaryCall(request) @@ -1361,24 +1227,18 @@ abstract class AbstractInteropTest { assertEquals(314159, response.payload.body.size().toLong()) } - /** Sends a unary rpc with raw oauth2 access token credentials. */ + /** Sends a unary rpc with raw oauth2 access token credentials. */ fun oauth2AuthToken(jsonKey: String, credentialsStream: InputStream, authScope: String) { var utilCredentials = GoogleCredentials.fromStream(credentialsStream) utilCredentials = utilCredentials.createScoped(listOf(authScope)) val accessToken = utilCredentials.refreshAccessToken() val credentials = OAuth2Credentials.create(accessToken) - val request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .build() + val request = SimpleRequest.newBuilder().setFillUsername(true).setFillOauthScope(true).build() val response = runBlocking { stub.withCallCredentials(MoreCallCredentials.from(credentials)).unaryCall(request) } assertFalse(response.username.isEmpty()) - assertTrue( - "Received username: " + response.username, - jsonKey.contains(response.username) - ) + assertTrue("Received username: " + response.username, jsonKey.contains(response.username)) assertFalse(response.oauthScope.isEmpty()) assertTrue( "Received oauth scope: " + response.oauthScope, @@ -1386,7 +1246,7 @@ abstract class AbstractInteropTest { ) } - /** Sends a unary rpc with "per rpc" raw oauth2 access token credentials. */ + /** Sends a unary rpc with "per rpc" raw oauth2 access token credentials. */ fun perRpcCreds(jsonKey: String, credentialsStream: InputStream, oauthScope: String) { // In gRpc Java, we don't have per Rpc credentials, user can use an intercepted stub only once // for that purpose. @@ -1394,32 +1254,28 @@ abstract class AbstractInteropTest { oauth2AuthToken(jsonKey, credentialsStream, oauthScope) } - /** Sends an unary rpc with "google default credentials". */ + /** Sends an unary rpc with "google default credentials". */ fun googleDefaultCredentials( defaultServiceAccount: String, googleDefaultStub: TestServiceGrpcKt.TestServiceCoroutineStub ) = runBlocking { - val request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setResponseSize(314159) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(271828))) - ) - .build() + val request = + SimpleRequest.newBuilder() + .setFillUsername(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(271828)))) + .build() val response = googleDefaultStub.unaryCall(request) assertEquals(defaultServiceAccount, response.username) - val goldenResponse = SimpleResponse.newBuilder() - .setUsername(defaultServiceAccount) - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFrom(ByteArray(314159))) - ) - .build() + val goldenResponse = + SimpleResponse.newBuilder() + .setUsername(defaultServiceAccount) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(ByteArray(314159)))) + .build() assertResponse(goldenResponse, response) } - /** Helper for getting remote address from [io.grpc.ClientCall.getAttributes] */ + /** Helper for getting remote address from [io.grpc.ClientCall.getAttributes] */ private fun obtainRemoteServerAddr(): SocketAddress? { return runBlocking { stub @@ -1430,7 +1286,7 @@ abstract class AbstractInteropTest { } } - /** Helper for getting local address from [io.grpc.ClientCall.getAttributes] */ + /** Helper for getting local address from [io.grpc.ClientCall.getAttributes] */ private fun obtainLocalClientAddr(): SocketAddress? { return runBlocking { stub @@ -1502,7 +1358,7 @@ abstract class AbstractInteropTest { assertEquals(method, tracerInfo!!.fullMethodName) assertNotNull(tracerInfo.tracer.contextCapture) // On the server, streamClosed() may be called after the client receives the final status. -// So we use a timeout. + // So we use a timeout. try { assertTrue(tracerInfo.tracer.await(1, TimeUnit.SECONDS)) } catch (e: InterruptedException) { @@ -1514,9 +1370,7 @@ abstract class AbstractInteropTest { } } - /** - * Check information recorded by tracers. - */ + /** Check information recorded by tracers. */ private fun checkTracers( tracer: TestStreamTracer, sentMessages: Collection, @@ -1526,9 +1380,8 @@ abstract class AbstractInteropTest { var seqNo = 0 for (msg in sentMessages) { assertThat(tracer.nextOutboundEvent()).isEqualTo(String.format("outboundMessage(%d)", seqNo)) - assertThat(tracer.nextOutboundEvent()).matches( - String.format("outboundMessageSent\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo) - ) + assertThat(tracer.nextOutboundEvent()) + .matches(String.format("outboundMessageSent\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo)) seqNo++ uncompressedSentSize += msg.serializedSize.toLong() } @@ -1537,9 +1390,8 @@ abstract class AbstractInteropTest { seqNo = 0 for (msg in receivedMessages) { assertThat(tracer.nextInboundEvent()).isEqualTo(String.format("inboundMessage(%d)", seqNo)) - assertThat(tracer.nextInboundEvent()).matches( - String.format("inboundMessageRead\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo) - ) + assertThat(tracer.nextInboundEvent()) + .matches(String.format("inboundMessageRead\\(%d, -?[0-9]+, -?[0-9]+\\)", seqNo)) uncompressedReceivedSize += msg.serializedSize.toLong() seqNo++ } @@ -1588,13 +1440,12 @@ abstract class AbstractInteropTest { companion object { private val logger = Logger.getLogger(AbstractInteropTest::class.java.name) - /** Must be at least [.unaryPayloadLength], plus some to account for encoding overhead. */ + /** Must be at least [.unaryPayloadLength], plus some to account for encoding overhead. */ const val MAX_MESSAGE_SIZE = 16 * 1024 * 1024 - @JvmField - protected val EMPTY = EmptyProtos.Empty.getDefaultInstance() + @JvmField protected val EMPTY = EmptyProtos.Empty.getDefaultInstance() /** - * Some tests run on memory constrained environments. Rather than OOM, just give up. 64 is + * Some tests run on memory constrained environments. Rather than OOM, just give up. 64 is * chosen as a maximum amount of memory a large test would need. */ private fun assumeEnoughMemory() { @@ -1608,8 +1459,7 @@ abstract class AbstractInteropTest { } /** - * Captures the request attributes. Useful for testing ServerCalls. - * [ServerCall.getAttributes] + * Captures the request attributes. Useful for testing ServerCalls. [ServerCall.getAttributes] */ private fun recordServerCallInterceptor( serverCallCapture: AtomicReference> @@ -1627,8 +1477,7 @@ abstract class AbstractInteropTest { } /** - * Captures the request attributes. Useful for testing ClientCalls. - * [ClientCall.getAttributes] + * Captures the request attributes. Useful for testing ClientCalls. [ClientCall.getAttributes] */ private fun recordClientCallInterceptor( clientCallCapture: AtomicReference> @@ -1659,9 +1508,7 @@ abstract class AbstractInteropTest { } } - /** - * Constructor for tests. - */ + /** Constructor for tests. */ init { var timeout: TestRule = Timeout.seconds(60) try { diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java b/interop_testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java index b064ee74..e051588d 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java +++ b/interop_testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java @@ -18,9 +18,7 @@ import com.google.common.base.Preconditions; -/** - * Enum of HTTP/2 interop test cases. - */ +/** Enum of HTTP/2 interop test cases. */ public enum Http2TestCases { RST_AFTER_HEADER("server resets stream after sending header"), RST_AFTER_DATA("server resets stream after sending data"), @@ -35,16 +33,14 @@ public enum Http2TestCases { this.description = description; } - /** - * Returns a description of the test case. - */ + /** Returns a description of the test case. */ public String description() { return description; } /** - * Returns the {@link Http2TestCases} matching the string {@code s}. The - * matching is case insensitive. + * Returns the {@link Http2TestCases} matching the string {@code s}. The matching is case + * insensitive. */ public static Http2TestCases fromString(String s) { Preconditions.checkNotNull(s, "s"); diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/TestCases.java b/interop_testing/src/main/java/io/grpc/testing/integration/TestCases.java index 2d1648e1..fda15c53 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/TestCases.java +++ b/interop_testing/src/main/java/io/grpc/testing/integration/TestCases.java @@ -18,9 +18,7 @@ import com.google.common.base.Preconditions; -/** - * Enum of interop test cases. - */ +/** Enum of interop test cases. */ public enum TestCases { EMPTY_UNARY("empty (zero bytes) request and response"), CACHEABLE_UNARY("cacheable unary rpc sent using GET"), @@ -43,8 +41,7 @@ public enum TestCases { JWT_TOKEN_CREDS("JWT-based auth"), OAUTH2_AUTH_TOKEN("raw oauth2 access token auth"), PER_RPC_CREDS("per rpc raw oauth2 access token auth"), - GOOGLE_DEFAULT_CREDENTIALS( - "google default credentials, i.e. GoogleManagedChannel based auth"), + GOOGLE_DEFAULT_CREDENTIALS("google default credentials, i.e. GoogleManagedChannel based auth"), CUSTOM_METADATA("unary and full duplex calls with metadata"), STATUS_CODE_AND_MESSAGE("request error code and message"), SPECIAL_STATUS_MESSAGE("special characters in status message"), @@ -62,16 +59,14 @@ public enum TestCases { this.description = description; } - /** - * Returns a description of the test case. - */ + /** Returns a description of the test case. */ public String description() { return description; } /** - * Returns the {@link TestCases} matching the string {@code s}. The - * matching is done case insensitive. + * Returns the {@link TestCases} matching the string {@code s}. The matching is done case + * insensitive. */ public static TestCases fromString(String s) { Preconditions.checkNotNull(s, "s"); diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceClient.kt b/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceClient.kt index 557103d5..365d4e95 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceClient.kt +++ b/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceClient.kt @@ -30,16 +30,16 @@ import io.grpc.okhttp.OkHttpChannelBuilder import io.grpc.okhttp.internal.Platform import io.grpc.testing.integration.TestServiceGrpcKt.TestServiceCoroutineStub import io.netty.handler.ssl.SslContext -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.FlowPreview import java.io.File import java.io.FileInputStream import java.util.concurrent.TimeUnit import kotlin.system.exitProcess +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.FlowPreview /** - * Application that starts a client for the [TestServiceGrpc.TestServiceImplBase] and runs - * through a series of tests. + * Application that starts a client for the [TestServiceGrpc.TestServiceImplBase] and runs through a + * series of tests. */ @ExperimentalCoroutinesApi @FlowPreview @@ -62,7 +62,7 @@ class TestServiceClient { @VisibleForTesting fun parseArgs(args: Array) { var usage = false - argsLoop@for (arg in args) { + argsLoop@ for (arg in args) { if (!arg.startsWith("--")) { System.err.println("All arguments must start with '--': $arg") usage = true @@ -136,7 +136,8 @@ class TestServiceClient { | --default_service_account Email of GCE default service account. Default ${c.defaultServiceAccount} | --service_account_key_file Path to service account json key file.${c.serviceAccountKeyFile} | --oauth_scope Scope for OAuth tokens. Default ${c.oauthScope} - """.trimMargin() + """ + .trimMargin() ) exitProcess(1) } @@ -231,8 +232,10 @@ class TestServiceClient { private inner class Tester : AbstractInteropTest() { override fun createChannel(): ManagedChannel { when (customCredentialsType) { - "google_default_credentials" -> return GoogleDefaultChannelBuilder.forAddress(serverHost, serverPort).build() - "compute_engine_channel_creds" -> return ComputeEngineChannelBuilder.forAddress(serverHost, serverPort).build() + "google_default_credentials" -> + return GoogleDefaultChannelBuilder.forAddress(serverHost, serverPort).build() + "compute_engine_channel_creds" -> + return ComputeEngineChannelBuilder.forAddress(serverHost, serverPort).build() } if (useAlts) { return AltsChannelBuilder.forAddress(serverHost, serverPort).build() @@ -244,16 +247,17 @@ class TestServiceClient { sslContext = GrpcSslContexts.forClient().trustManager(TestUtils.loadCert("ca.pem")).build() } - val nettyBuilder = NettyChannelBuilder.forAddress(serverHost, serverPort) - .flowControlWindow(65 * 1024) - .negotiationType( - when { - useTls -> NegotiationType.TLS - useH2cUpgrade -> NegotiationType.PLAINTEXT_UPGRADE - else -> NegotiationType.PLAINTEXT - } - ) - .sslContext(sslContext) + val nettyBuilder = + NettyChannelBuilder.forAddress(serverHost, serverPort) + .flowControlWindow(65 * 1024) + .negotiationType( + when { + useTls -> NegotiationType.TLS + useH2cUpgrade -> NegotiationType.PLAINTEXT_UPGRADE + else -> NegotiationType.PLAINTEXT + } + ) + .sslContext(sslContext) if (serverHostOverride != null) { nettyBuilder.overrideAuthority(serverHostOverride) } @@ -261,14 +265,15 @@ class TestServiceClient { } else { val okBuilder = OkHttpChannelBuilder.forAddress(serverHost, serverPort) if (serverHostOverride != null) { // Force the hostname to match the cert the server uses. - okBuilder.overrideAuthority( - Util.authorityFromHostAndPort(serverHostOverride, serverPort)) + okBuilder.overrideAuthority(Util.authorityFromHostAndPort(serverHostOverride, serverPort)) } if (useTls) { if (useTestCa) { - val factory = TestUtils.newSslSocketFactoryForCa( - Platform.get().provider, TestUtils.loadCert("ca.pem") - ) + val factory = + TestUtils.newSslSocketFactoryForCa( + Platform.get().provider, + TestUtils.loadCert("ca.pem") + ) okBuilder.sslSocketFactory(factory) } } else { @@ -283,9 +288,7 @@ class TestServiceClient { companion object { private val UTF_8 = Charsets.UTF_8 - /** - * The main application allowing this client to be launched from the command line. - */ + /** The main application allowing this client to be launched from the command line. */ @Throws(Exception::class) @JvmStatic fun main(args: Array) { // Let Netty or OkHttp use Conscrypt if it is available. @@ -293,16 +296,19 @@ class TestServiceClient { val client = TestServiceClient() client.parseArgs(args) client.setUp() - Runtime.getRuntime().addShutdownHook(object : Thread() { - override fun run() { - println("Shutting down") - try { - client.tearDown() - } catch (e: Exception) { - e.printStackTrace() + Runtime.getRuntime() + .addShutdownHook( + object : Thread() { + override fun run() { + println("Shutting down") + try { + client.tearDown() + } catch (e: Exception) { + e.printStackTrace() + } + } } - } - }) + ) try { client.run() } finally { @@ -315,10 +321,7 @@ class TestServiceClient { val builder = StringBuilder() for (testCase in TestCases.values()) { val strTestcase = testCase.name.lowercase() - builder.append("\n ") - .append(strTestcase) - .append(": ") - .append(testCase.description()) + builder.append("\n ").append(strTestcase).append(": ").append(testCase.description()) } return builder.toString() } diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.kt b/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.kt index 66d4df75..9c796d9d 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.kt +++ b/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.kt @@ -45,9 +45,8 @@ import kotlinx.coroutines.flow.toList */ @ExperimentalCoroutinesApi @FlowPreview // most of these methods are graduating imminently but that has not yet landed -class TestServiceImpl( - executor: Executor -) : TestServiceGrpcKt.TestServiceCoroutineImplBase(executor.asCoroutineDispatcher()) { +class TestServiceImpl(executor: Executor) : + TestServiceGrpcKt.TestServiceCoroutineImplBase(executor.asCoroutineDispatcher()) { private val random = Random() private val compressableBuffer: ByteString = ByteString.copyFrom(ByteArray(1024)) @@ -56,13 +55,11 @@ class TestServiceImpl( override suspend fun unaryCall(request: Messages.SimpleRequest): Messages.SimpleResponse { if (request.hasResponseStatus()) { - throw Status - .fromCodeValue(request.responseStatus.code) + throw Status.fromCodeValue(request.responseStatus.code) .withDescription(request.responseStatus.message) .asException() } - return Messages.SimpleResponse - .newBuilder() + return Messages.SimpleResponse.newBuilder() .apply { if (request.responseSize != 0) { val offset = random.nextInt(compressableBuffer.size()) @@ -80,11 +77,8 @@ class TestServiceImpl( for (params in request.responseParametersList) { delay(timeMillis = TimeUnit.MICROSECONDS.toMillis(params.intervalUs.toLong())) emit( - Messages.StreamingOutputCallResponse - .newBuilder() - .apply { - payload = generatePayload(compressableBuffer, offset, params.size) - } + Messages.StreamingOutputCallResponse.newBuilder() + .apply { payload = generatePayload(compressableBuffer, offset, params.size) } .build() ) offset += params.size @@ -96,11 +90,8 @@ class TestServiceImpl( override suspend fun streamingInputCall( requests: Flow ): Messages.StreamingInputCallResponse = - Messages.StreamingInputCallResponse - .newBuilder() - .apply { - aggregatedPayloadSize = requests.map { it.payload.body.size() }.sum() - } + Messages.StreamingInputCallResponse.newBuilder() + .apply { aggregatedPayloadSize = requests.map { it.payload.body.size() }.sum() } .build() override fun fullDuplexCall( @@ -108,8 +99,7 @@ class TestServiceImpl( ): Flow = requests.flatMapConcat { if (it.hasResponseStatus()) { - throw Status - .fromCodeValue(it.responseStatus.code) + throw Status.fromCodeValue(it.responseStatus.code) .withDescription(it.responseStatus.message) .asException() } @@ -118,21 +108,21 @@ class TestServiceImpl( override fun halfDuplexCall( requests: Flow - ): Flow = - flow { - val requestList = requests.toList() - emitAll(requestList.asFlow().flatMapConcat { streamingOutputCall(it) }) - } + ): Flow = flow { + val requestList = requests.toList() + emitAll(requestList.asFlow().flatMapConcat { streamingOutputCall(it) }) + } companion object { - /** Returns interceptors necessary for full service implementation. */ + /** Returns interceptors necessary for full service implementation. */ @get:JvmStatic @get:JvmName("interceptors") - val interceptors = listOf( - echoRequestHeadersInterceptor(Util.METADATA_KEY), - echoRequestMetadataInHeaders(Util.ECHO_INITIAL_METADATA_KEY), - echoRequestMetadataInTrailers(Util.ECHO_TRAILING_METADATA_KEY) - ) + val interceptors = + listOf( + echoRequestHeadersInterceptor(Util.METADATA_KEY), + echoRequestMetadataInHeaders(Util.ECHO_INITIAL_METADATA_KEY), + echoRequestMetadataInTrailers(Util.ECHO_TRAILING_METADATA_KEY) + ) suspend fun Flow.sum() = fold(0) { a, b -> a + b } @@ -157,8 +147,8 @@ class TestServiceImpl( } /** - * Echo the request headers from a client into response headers and trailers. Useful for - * testing end-to-end metadata propagation. + * Echo the request headers from a client into response headers and trailers. Useful for testing + * end-to-end metadata propagation. */ private fun echoRequestHeadersInterceptor(vararg keys: Metadata.Key<*>): ServerInterceptor { val keySet: Set> = keys.toSet() diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index f1d4d161..153d511f 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop_testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -117,11 +117,13 @@ void parseArgs(String[] args) { System.out.println( "Usage: [ARGS...]" + "\n" - + "\n --port=PORT Port to connect to. Default " + s.port - + "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls + + "\n --port=PORT Port to connect to. Default " + + s.port + + "\n --use_tls=true|false Whether to use TLS. Default " + + s.useTls + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." - + "\n Default " + s.useAlts - ); + + "\n Default " + + s.useAlts); System.exit(1); } } diff --git a/interop_testing/src/main/java/io/grpc/testing/integration/Util.java b/interop_testing/src/main/java/io/grpc/testing/integration/Util.java index 6ea6af19..866afa00 100644 --- a/interop_testing/src/main/java/io/grpc/testing/integration/Util.java +++ b/interop_testing/src/main/java/io/grpc/testing/integration/Util.java @@ -24,23 +24,19 @@ import java.util.List; import org.junit.Assert; -/** - * Utility methods to support integration testing. - */ +/** Utility methods to support integration testing. */ public class Util { public static final Metadata.Key METADATA_KEY = Metadata.Key.of( "grpc.testing.SimpleContext" + Metadata.BINARY_HEADER_SUFFIX, ProtoLiteUtils.metadataMarshaller(Messages.SimpleContext.getDefaultInstance())); - public static final Metadata.Key ECHO_INITIAL_METADATA_KEY - = Metadata.Key.of("x-grpc-test-echo-initial", Metadata.ASCII_STRING_MARSHALLER); - public static final Metadata.Key ECHO_TRAILING_METADATA_KEY - = Metadata.Key.of("x-grpc-test-echo-trailing-bin", Metadata.BINARY_BYTE_MARSHALLER); + public static final Metadata.Key ECHO_INITIAL_METADATA_KEY = + Metadata.Key.of("x-grpc-test-echo-initial", Metadata.ASCII_STRING_MARSHALLER); + public static final Metadata.Key ECHO_TRAILING_METADATA_KEY = + Metadata.Key.of("x-grpc-test-echo-trailing-bin", Metadata.BINARY_BYTE_MARSHALLER); - /** - * Combine a host and port into an authority string. - */ + /** Combine a host and port into an authority string. */ public static String authorityFromHostAndPort(String host, int port) { try { return new URI(null, null, host, port, null, null, null).getAuthority(); @@ -66,8 +62,8 @@ public static void assertEquals(MessageLite expected, MessageLite actual) { } /** Assert that two lists of messages are equal, producing a useful message if not. */ - public static void assertEquals(List expected, - List actual) { + public static void assertEquals( + List expected, List actual) { if (expected == null || actual == null) { Assert.assertEquals(expected, actual); } else if (expected.size() != actual.size()) { diff --git a/interop_testing/src/main/proto/grpc/testing/empty.proto b/interop_testing/src/main/proto/grpc/testing/empty.proto index bd626abe..32346b17 100644 --- a/interop_testing/src/main/proto/grpc/testing/empty.proto +++ b/interop_testing/src/main/proto/grpc/testing/empty.proto @@ -15,8 +15,8 @@ syntax = "proto2"; package grpc.testing; -option java_package = "io.grpc.testing.integration"; option java_outer_classname = "EmptyProtos"; +option java_package = "io.grpc.testing.integration"; // An empty message that you can re-use to avoid defining duplicated empty // messages in your project. A typical example is to use it as argument or the diff --git a/interop_testing/src/main/proto/grpc/testing/test.proto b/interop_testing/src/main/proto/grpc/testing/test.proto index 38845330..3388c8b7 100644 --- a/interop_testing/src/main/proto/grpc/testing/test.proto +++ b/interop_testing/src/main/proto/grpc/testing/test.proto @@ -38,26 +38,22 @@ service TestService { // One request followed by a sequence of responses (streamed download). // The server returns the payload with client desired type and sizes. - rpc StreamingOutputCall(StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); + rpc StreamingOutputCall(StreamingOutputCallRequest) returns (stream StreamingOutputCallResponse); // A sequence of requests followed by one response (streamed upload). // The server returns the aggregated size of client payload as the result. - rpc StreamingInputCall(stream StreamingInputCallRequest) - returns (StreamingInputCallResponse); + rpc StreamingInputCall(stream StreamingInputCallRequest) returns (StreamingInputCallResponse); // A sequence of requests with each request served by the server immediately. // As one request could lead to multiple responses, this interface // demonstrates the idea of full duplexing. - rpc FullDuplexCall(stream StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); + rpc FullDuplexCall(stream StreamingOutputCallRequest) returns (stream StreamingOutputCallResponse); // A sequence of requests followed by a sequence of responses. // The server buffers all the client requests and then serves them in order. A // stream of responses are returned to the client when the server starts with // first request. - rpc HalfDuplexCall(stream StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); + rpc HalfDuplexCall(stream StreamingOutputCallRequest) returns (stream StreamingOutputCallResponse); // The test server will not implement this method. It will be used // to test the behavior when clients call unimplemented methods. @@ -80,6 +76,5 @@ service ReconnectService { // A service used to obtain stats for verifying LB behavior. service LoadBalancerStatsService { // Gets the backend distribution for RPCs sent by a test client. - rpc GetClientStats(LoadBalancerStatsRequest) - returns (LoadBalancerStatsResponse) {} + rpc GetClientStats(LoadBalancerStatsRequest) returns (LoadBalancerStatsResponse) {} } diff --git a/interop_testing/src/test/java/io/grpc/stub/StubConfigTest.java b/interop_testing/src/test/java/io/grpc/stub/StubConfigTest.java index aa918b4f..5d4a6566 100644 --- a/interop_testing/src/test/java/io/grpc/stub/StubConfigTest.java +++ b/interop_testing/src/test/java/io/grpc/stub/StubConfigTest.java @@ -36,25 +36,19 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -/** - * Tests for stub reconfiguration. - */ +/** Tests for stub reconfiguration. */ @RunWith(JUnit4.class) public class StubConfigTest { - @Mock - private Channel channel; + @Mock private Channel channel; - @Mock - private StreamObserver responseObserver; + @Mock private StreamObserver responseObserver; - @Mock - private ClientCall call; + @Mock private ClientCall call; - /** - * Sets up mocks. - */ - @Before public void setUp() { + /** Sets up mocks. */ + @Before + public void setUp() { MockitoAnnotations.openMocks(this); when(channel.newCall( diff --git a/interop_testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java b/interop_testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java index f65d6efe..8fee86cc 100644 --- a/interop_testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java +++ b/interop_testing/src/test/java/io/grpc/testing/integration/Http2OkHttpTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertTrue; import com.google.common.base.Throwables; +import com.squareup.okhttp.ConnectionSpec; import io.grpc.ManagedChannel; import io.grpc.ServerBuilder; import io.grpc.internal.testing.StreamRecorder; @@ -41,17 +42,13 @@ import java.util.Arrays; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLPeerUnverifiedException; -import com.squareup.okhttp.ConnectionSpec; - import org.jetbrains.annotations.NotNull; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Integration tests for GRPC over Http2 using the OkHttp framework. - */ +/** Integration tests for GRPC over Http2 using the OkHttp framework. */ @RunWith(JUnit4.class) public class Http2OkHttpTest extends AbstractInteropTest { @@ -74,10 +71,12 @@ protected ServerBuilder getServerBuilder() { // are forced to use Jetty ALPN for Netty instead of OpenSSL. sslProvider = SslProvider.JDK; } - SslContextBuilder contextBuilder = SslContextBuilder - .forServer(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")); + SslContextBuilder contextBuilder = + SslContextBuilder.forServer( + TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")); GrpcSslContexts.configure(contextBuilder, sslProvider); - Iterable ciphers = Arrays.asList(SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites()); + Iterable ciphers = + Arrays.asList(SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites()); contextBuilder.ciphers(ciphers, SupportedCipherSuiteFilter.INSTANCE); return NettyServerBuilder.forPort(0) .flowControlWindow(65 * 1024) @@ -96,14 +95,15 @@ protected ManagedChannel createChannel() { private OkHttpChannelBuilder createChannelBuilder() { int port = ((InetSocketAddress) getListenAddress()).getPort(); - OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port) - .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) - .connectionSpec(ConnectionSpec.MODERN_TLS) - .overrideAuthority(Util.authorityFromHostAndPort( - TestUtils.TEST_SERVER_HOST, port)); + OkHttpChannelBuilder builder = + OkHttpChannelBuilder.forAddress("localhost", port) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .connectionSpec(ConnectionSpec.MODERN_TLS) + .overrideAuthority(Util.authorityFromHostAndPort(TestUtils.TEST_SERVER_HOST, port)); try { - builder.sslSocketFactory(TestUtils.newSslSocketFactoryForCa(Platform.get().getProvider(), - TestUtils.loadCert("ca.pem"))); + builder.sslSocketFactory( + TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem"))); } catch (Exception e) { throw new RuntimeException(e); } @@ -113,8 +113,7 @@ private OkHttpChannelBuilder createChannelBuilder() { @Test public void receivedDataForFinishedStream() throws Exception { Messages.ResponseParameters.Builder responseParameters = - Messages.ResponseParameters.newBuilder() - .setSize(1); + Messages.ResponseParameters.newBuilder().setSize(1); Messages.StreamingOutputCallRequest.Builder requestBuilder = Messages.StreamingOutputCallRequest.newBuilder(); for (int i = 0; i < 1000; i++) { @@ -137,12 +136,11 @@ public void receivedDataForFinishedStream() throws Exception { @Test public void wrongHostNameFailHostnameVerification() throws Exception { int port = ((InetSocketAddress) getListenAddress()).getPort(); - ManagedChannel channel = createChannelBuilder() - .overrideAuthority(Util.authorityFromHostAndPort( - BAD_HOSTNAME, port)) - .build(); - TestServiceGrpc.TestServiceBlockingStub blockingStub = - TestServiceGrpc.newBlockingStub(channel); + ManagedChannel channel = + createChannelBuilder() + .overrideAuthority(Util.authorityFromHostAndPort(BAD_HOSTNAME, port)) + .build(); + TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel); Throwable actualThrown = null; try { @@ -160,13 +158,12 @@ public void wrongHostNameFailHostnameVerification() throws Exception { @Test public void hostnameVerifierWithBadHostname() throws Exception { int port = ((InetSocketAddress) getListenAddress()).getPort(); - ManagedChannel channel = createChannelBuilder() - .overrideAuthority(Util.authorityFromHostAndPort( - BAD_HOSTNAME, port)) - .hostnameVerifier((hostname, session) -> true) - .build(); - TestServiceGrpc.TestServiceBlockingStub blockingStub = - TestServiceGrpc.newBlockingStub(channel); + ManagedChannel channel = + createChannelBuilder() + .overrideAuthority(Util.authorityFromHostAndPort(BAD_HOSTNAME, port)) + .hostnameVerifier((hostname, session) -> true) + .build(); + TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel); blockingStub.emptyCall(Empty.getDefaultInstance()); @@ -176,13 +173,12 @@ public void hostnameVerifierWithBadHostname() throws Exception { @Test public void hostnameVerifierWithCorrectHostname() throws Exception { int port = ((InetSocketAddress) getListenAddress()).getPort(); - ManagedChannel channel = createChannelBuilder() - .overrideAuthority(Util.authorityFromHostAndPort( - TestUtils.TEST_SERVER_HOST, port)) - .hostnameVerifier((hostname, session) -> false) - .build(); - TestServiceGrpc.TestServiceBlockingStub blockingStub = - TestServiceGrpc.newBlockingStub(channel); + ManagedChannel channel = + createChannelBuilder() + .overrideAuthority(Util.authorityFromHostAndPort(TestUtils.TEST_SERVER_HOST, port)) + .hostnameVerifier((hostname, session) -> false) + .build(); + TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel); Throwable actualThrown = null; try { diff --git a/interop_testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java b/interop_testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java index 4e511c1c..6531220d 100644 --- a/interop_testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java +++ b/interop_testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java @@ -25,9 +25,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Unit tests for {@link TestCases}. - */ +/** Unit tests for {@link TestCases}. */ @RunWith(JUnit4.class) public class TestCasesTest { diff --git a/kt_jvm_grpc.bzl b/kt_jvm_grpc.bzl index 35d1fc42..a811a24a 100644 --- a/kt_jvm_grpc.bzl +++ b/kt_jvm_grpc.bzl @@ -1,9 +1,9 @@ -load("@rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") load("@grpc-java//:java_grpc_library.bzl", "java_grpc_library") -load("@protobuf//bazel/common:proto_info.bzl", "ProtoInfo") load("@protobuf//bazel:java_lite_proto_library.bzl", "java_lite_proto_library") load("@protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@protobuf//bazel/common:proto_info.bzl", "ProtoInfo") load("@rules_java//java:defs.bzl", "JavaInfo") +load("@rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") def _invoke_generator(ctx, proto_dep, output_dir): direct_descriptor_set = depset([proto_dep[ProtoInfo].direct_descriptor_set]) diff --git a/settings.gradle.kts b/settings.gradle.kts index fb918628..547c19ae 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -3,9 +3,9 @@ rootProject.name = "grpc-kotlin" include("stub", "compiler", "interop_testing", "integration_testing") dependencyResolutionManagement { - @Suppress("UnstableApiUsage") - repositories { - mavenCentral() - maven("https://repo.gradle.org/gradle/libs-releases") - } + @Suppress("UnstableApiUsage") + repositories { + mavenCentral() + maven("https://repo.gradle.org/gradle/libs-releases") + } } diff --git a/stub/build.gradle.kts b/stub/build.gradle.kts index b422aeb4..6e00b59a 100644 --- a/stub/build.gradle.kts +++ b/stub/build.gradle.kts @@ -1,124 +1,115 @@ import com.google.protobuf.gradle.* -plugins { - alias(libs.plugins.dokka) -} +plugins { alias(libs.plugins.dokka) } repositories { - google() - mavenCentral() + google() + mavenCentral() - // for Dokka - maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven") + // for Dokka + maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven") } dependencies { - // Kotlin - implementation(kotlin("stdlib")) - implementation(libs.kotlinx.coroutines.core.jvm) - - // Grpc - api(libs.grpc.stub) - - // Java - api(libs.javax.annotation.api) - - // Testing - testImplementation(libs.junit) - testImplementation(libs.junit.jupiter.engine) - testImplementation(libs.truth.proto.extension) - testImplementation(libs.grpc.protobuf) - testImplementation(libs.grpc.testing) - testImplementation(libs.grpc.inprocess) + // Kotlin + implementation(kotlin("stdlib")) + implementation(libs.kotlinx.coroutines.core.jvm) + + // Grpc + api(libs.grpc.stub) + + // Java + api(libs.javax.annotation.api) + + // Testing + testImplementation(libs.junit) + testImplementation(libs.junit.jupiter.engine) + testImplementation(libs.truth.proto.extension) + testImplementation(libs.grpc.protobuf) + testImplementation(libs.grpc.testing) + testImplementation(libs.grpc.inprocess) } java { - withSourcesJar() - toolchain { - languageVersion = JavaLanguageVersion.of(17) - } + withSourcesJar() + toolchain { languageVersion = JavaLanguageVersion.of(17) } } protobuf { - protoc { - artifact = libs.protoc.asProvider().get().toString() + protoc { artifact = libs.protoc.asProvider().get().toString() } + plugins { + id("grpc") { artifact = libs.protoc.gen.grpc.java.get().toString() } + id("grpckt") { + path = project(":compiler").tasks.jar.get().archiveFile.get().asFile.absolutePath } - plugins { - id("grpc") { - artifact = libs.protoc.gen.grpc.java.get().toString() - } - id("grpckt") { - path = project(":compiler").tasks.jar.get().archiveFile.get().asFile.absolutePath - } - } - generateProtoTasks { - all().forEach { - if (it.name.startsWith("generateTestProto")) { - it.dependsOn(":compiler:jar") - } - - it.plugins { - id("grpc") - id("grpckt") - } - } + } + generateProtoTasks { + all().forEach { + if (it.name.startsWith("generateTestProto")) { + it.dependsOn(":compiler:jar") + } + + it.plugins { + id("grpc") + id("grpckt") + } } + } } dokka { - dokkaSourceSets.main { - reportUndocumented = true - - sourceLink { - localDirectory.set(file("src/main/java")) - remoteUrl("https://github.com/grpc/grpc-kotlin/blob/master/stub/src/main/java") - remoteLineSuffix.set("#L") - } - - externalDocumentationLinks.register("grpc-java-docs") { - url("https://grpc.github.io/grpc-java/javadoc/") - packageListUrl("https://grpc.github.io/grpc-java/javadoc/element-list") - } - externalDocumentationLinks.register("kotlinx.coroutines-docs") { - url("https://kotlinlang.org/api/kotlinx.coroutines/") - } - - perPackageOption { - matchingRegex.set("io.grpc.testing.*") - suppress.set(true) - } - - perPackageOption { - matchingRegex.set("io.grpc.kotlin.generator.*") - suppress.set(true) - } + dokkaSourceSets.main { + reportUndocumented = true + + sourceLink { + localDirectory.set(file("src/main/java")) + remoteUrl("https://github.com/grpc/grpc-kotlin/blob/master/stub/src/main/java") + remoteLineSuffix.set("#L") } + + externalDocumentationLinks.register("grpc-java-docs") { + url("https://grpc.github.io/grpc-java/javadoc/") + packageListUrl("https://grpc.github.io/grpc-java/javadoc/element-list") + } + externalDocumentationLinks.register("kotlinx.coroutines-docs") { + url("https://kotlinlang.org/api/kotlinx.coroutines/") + } + + perPackageOption { + matchingRegex.set("io.grpc.testing.*") + suppress.set(true) + } + + perPackageOption { + matchingRegex.set("io.grpc.kotlin.generator.*") + suppress.set(true) + } + } } -val javadocJar by tasks.registering(Jar::class) { +val javadocJar by + tasks.registering(Jar::class) { dependsOn(tasks.dokkaGenerate) archiveClassifier.set("javadoc") duplicatesStrategy = DuplicatesStrategy.EXCLUDE includeEmptyDirs = false from(layout.buildDirectory.dir("dokka/html")) -} + } -tasks.named("sourcesJar") { - exclude("**/*.bazel") -} +tasks.named("sourcesJar") { exclude("**/*.bazel") } publishing { - publications { - named("maven") { - from(components["java"]) - - artifact(javadocJar) - - pom { - name.set("gRPC Kotlin Stub") - artifactId = "grpc-kotlin-stub" - description.set("Kotlin-based stubs for gRPC services") - } - } + publications { + named("maven") { + from(components["java"]) + + artifact(javadocJar) + + pom { + name.set("gRPC Kotlin Stub") + artifactId = "grpc-kotlin-stub" + description.set("Kotlin-based stubs for gRPC services") + } } + } } diff --git a/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineServerImpl.kt b/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineServerImpl.kt index 30bfb478..28c91087 100644 --- a/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineServerImpl.kt +++ b/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineServerImpl.kt @@ -21,7 +21,7 @@ import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext /** - * Skeleton implementation of a coroutine-based gRPC server implementation. Intended to be + * Skeleton implementation of a coroutine-based gRPC server implementation. Intended to be * subclassed by generated code. */ abstract class AbstractCoroutineServerImpl( diff --git a/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineStub.kt b/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineStub.kt index b96c2760..9041b155 100644 --- a/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineStub.kt +++ b/stub/src/main/java/io/grpc/kotlin/AbstractCoroutineStub.kt @@ -19,16 +19,12 @@ package io.grpc.kotlin import io.grpc.CallOptions import io.grpc.Channel import io.grpc.stub.AbstractStub -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.withContext -import kotlin.coroutines.CoroutineContext -import kotlin.coroutines.EmptyCoroutineContext /** * A skeleton implementation of a coroutine-based client stub, suitable for extension by generated * client stubs. */ -abstract class AbstractCoroutineStub>( +abstract class AbstractCoroutineStub>( channel: Channel, callOptions: CallOptions = CallOptions.DEFAULT -): AbstractStub(channel, callOptions) +) : AbstractStub(channel, callOptions) diff --git a/stub/src/main/java/io/grpc/kotlin/ClientCalls.kt b/stub/src/main/java/io/grpc/kotlin/ClientCalls.kt index 3dce6967..5e3e02b5 100644 --- a/stub/src/main/java/io/grpc/kotlin/ClientCalls.kt +++ b/stub/src/main/java/io/grpc/kotlin/ClientCalls.kt @@ -17,7 +17,9 @@ package io.grpc.kotlin import io.grpc.CallOptions +import io.grpc.Channel as GrpcChannel import io.grpc.ClientCall +import io.grpc.Metadata as GrpcMetadata import io.grpc.MethodDescriptor import io.grpc.Status import kotlinx.coroutines.CancellationException @@ -32,17 +34,13 @@ import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow import kotlinx.coroutines.launch import kotlinx.coroutines.withContext -import io.grpc.Channel as GrpcChannel -import io.grpc.Metadata as GrpcMetadata /** - * Helpers for gRPC clients implemented in Kotlin. Can be used directly, but intended to be used + * Helpers for gRPC clients implemented in Kotlin. Can be used directly, but intended to be used * from generated Kotlin APIs. */ object ClientCalls { - /** - * Launches a unary RPC on the specified channel, suspending until the result is received. - */ + /** Launches a unary RPC on the specified channel, suspending until the result is received. */ suspend fun unaryRpc( channel: GrpcChannel, method: MethodDescriptor, @@ -54,12 +52,13 @@ object ClientCalls { "Expected a unary RPC method, but got $method" } return rpcImpl( - channel = channel, - method = method, - callOptions = callOptions, - headers = headers, - request = Request.Unary(request) - ).singleOrStatus("request", method) + channel = channel, + method = method, + callOptions = callOptions, + headers = headers, + request = Request.Unary(request) + ) + .singleOrStatus("request", method) } /** @@ -73,12 +72,9 @@ object ClientCalls { method: MethodDescriptor, callOptions: CallOptions = CallOptions.DEFAULT, headers: suspend () -> GrpcMetadata = { GrpcMetadata() } - ): suspend (RequestT) -> ResponseT = - { unaryRpc(channel, method, it, callOptions, headers()) } + ): suspend (RequestT) -> ResponseT = { unaryRpc(channel, method, it, callOptions, headers()) } - /** - * Returns a [Flow] which launches the specified server-streaming RPC and emits the responses. - */ + /** Returns a [Flow] which launches the specified server-streaming RPC and emits the responses. */ fun serverStreamingRpc( channel: GrpcChannel, method: MethodDescriptor, @@ -110,20 +106,12 @@ object ClientCalls { callOptions: CallOptions = CallOptions.DEFAULT, headers: suspend () -> GrpcMetadata = { GrpcMetadata() } ): (RequestT) -> Flow = { - flow { - serverStreamingRpc( - channel, - method, - it, - callOptions, - headers() - ).collect { emit(it) } - } + flow { serverStreamingRpc(channel, method, it, callOptions, headers()).collect { emit(it) } } } /** * Launches a client-streaming RPC on the specified channel, suspending until the server returns - * the result. The caller is expected to provide a [Flow] of requests. + * the result. The caller is expected to provide a [Flow] of requests. */ suspend fun clientStreamingRpc( channel: GrpcChannel, @@ -136,12 +124,13 @@ object ClientCalls { "Expected a server streaming RPC method, but got $method" } return rpcImpl( - channel = channel, - method = method, - callOptions = callOptions, - headers = headers, - request = Request.Flowing(requests) - ).singleOrStatus("response", method) + channel = channel, + method = method, + callOptions = callOptions, + headers = headers, + request = Request.Flowing(requests) + ) + .singleOrStatus("response", method) } /** @@ -155,24 +144,17 @@ object ClientCalls { method: MethodDescriptor, callOptions: CallOptions = CallOptions.DEFAULT, headers: suspend () -> GrpcMetadata = { GrpcMetadata() } - ): suspend (Flow) -> ResponseT = - { - clientStreamingRpc( - channel, - method, - it, - callOptions, - headers() - ) - } + ): suspend (Flow) -> ResponseT = { + clientStreamingRpc(channel, method, it, callOptions, headers()) + } /** * Returns a [Flow] which launches the specified bidirectional-streaming RPC, collecting the * requests flow, sending them to the server, and emitting the responses. * - * Cancelling collection of the flow cancels the RPC upstream and collection of the requests. - * For example, if `responses.take(2).toList()` is executed, the RPC will be cancelled after - * the first two responses are returned. + * Cancelling collection of the flow cancels the RPC upstream and collection of the requests. For + * example, if `responses.take(2).toList()` is executed, the RPC will be cancelled after the first + * two responses are returned. */ fun bidiStreamingRpc( channel: GrpcChannel, @@ -204,46 +186,28 @@ object ClientCalls { method: MethodDescriptor, callOptions: CallOptions = CallOptions.DEFAULT, headers: suspend () -> GrpcMetadata = { GrpcMetadata() } - ): (Flow) -> Flow = - { - flow { - bidiStreamingRpc( - channel, - method, - it, - callOptions, - headers() - ).collect { emit(it) } - } - } + ): (Flow) -> Flow = { + flow { bidiStreamingRpc(channel, method, it, callOptions, headers()).collect { emit(it) } } + } /** The client's request(s). */ private sealed class Request { /** * Send the request(s) to the ClientCall, with `readiness` indicating calls to `onReady` from - * the listener. Returns when sending the requests is done, either because all the requests - * were sent (in which case `null` is returned) or because the requests channel was closed - * with an exception (in which case the exception is returned). + * the listener. Returns when sending the requests is done, either because all the requests were + * sent (in which case `null` is returned) or because the requests channel was closed with an + * exception (in which case the exception is returned). */ - abstract suspend fun sendTo( - clientCall: ClientCall, - readiness: Readiness - ) + abstract suspend fun sendTo(clientCall: ClientCall, readiness: Readiness) class Unary(private val request: RequestT) : Request() { - override suspend fun sendTo( - clientCall: ClientCall, - readiness: Readiness - ) { + override suspend fun sendTo(clientCall: ClientCall, readiness: Readiness) { clientCall.sendMessage(request) } } class Flowing(private val requestFlow: Flow) : Request() { - override suspend fun sendTo( - clientCall: ClientCall, - readiness: Readiness - ) { + override suspend fun sendTo(clientCall: ClientCall, readiness: Readiness) { readiness.suspendUntilReady() requestFlow.collect { request -> clientCall.sendMessage(request) @@ -260,10 +224,10 @@ object ClientCalls { } /** - * Returns a [Flow] that, when collected, issues the specified RPC with the specified request - * on the specified channel, and emits the responses. This is intended to be the root - * implementation of the client side of all Kotlin coroutine-based RPCs, with non-streaming - * implementations simply emitting or receiving a single message in the appropriate direction. + * Returns a [Flow] that, when collected, issues the specified RPC with the specified request on + * the specified channel, and emits the responses. This is intended to be the root implementation + * of the client side of all Kotlin coroutine-based RPCs, with non-streaming implementations + * simply emitting or receiving a single message in the appropriate direction. */ private fun rpcImpl( channel: GrpcChannel, @@ -309,15 +273,16 @@ object ClientCalls { headers.copy() ) - val sender = launch(CoroutineName("SendMessage worker for ${method.fullMethodName}")) { - try { - request.sendTo(clientCall, readiness) - clientCall.halfClose() - } catch (ex: Exception) { - clientCall.cancel("Collection of requests completed exceptionally", ex) - throw ex // propagate failure upward + val sender = + launch(CoroutineName("SendMessage worker for ${method.fullMethodName}")) { + try { + request.sendTo(clientCall, readiness) + clientCall.halfClose() + } catch (ex: Exception) { + clientCall.cancel("Collection of requests completed exceptionally", ex) + throw ex // propagate failure upward + } } - } try { clientCall.request(1) diff --git a/stub/src/main/java/io/grpc/kotlin/CoroutineContextServerInterceptor.kt b/stub/src/main/java/io/grpc/kotlin/CoroutineContextServerInterceptor.kt index f81063eb..e82b0a6d 100644 --- a/stub/src/main/java/io/grpc/kotlin/CoroutineContextServerInterceptor.kt +++ b/stub/src/main/java/io/grpc/kotlin/CoroutineContextServerInterceptor.kt @@ -1,5 +1,6 @@ package io.grpc.kotlin +import io.grpc.Context as GrpcContext import io.grpc.Metadata import io.grpc.ServerCall import io.grpc.ServerCallHandler @@ -7,21 +8,22 @@ import io.grpc.ServerInterceptor import io.grpc.StatusException import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext -import io.grpc.Context as GrpcContext /** * A [ServerInterceptor] subtype that can install elements in the [CoroutineContext] where server - * logic is executed. These elements are applied "after" the - * [AbstractCoroutineServerImpl.context]; that is, the interceptor overrides the server's context. + * logic is executed. These elements are applied "after" the [AbstractCoroutineServerImpl.context]; + * that is, the interceptor overrides the server's context. */ abstract class CoroutineContextServerInterceptor : ServerInterceptor { companion object { // This is deliberately kept visibility-restricted; it's intentional that the only way to affect // the CoroutineContext is to extend CoroutineContextServerInterceptor. - internal val COROUTINE_CONTEXT_KEY : GrpcContext.Key = + internal val COROUTINE_CONTEXT_KEY: GrpcContext.Key = GrpcContext.keyWithDefault("grpc-kotlin-coroutine-context", EmptyCoroutineContext) - private fun GrpcContext.extendCoroutineContext(coroutineContext: CoroutineContext): GrpcContext { + private fun GrpcContext.extendCoroutineContext( + coroutineContext: CoroutineContext + ): GrpcContext { val oldCoroutineContext: CoroutineContext = COROUTINE_CONTEXT_KEY[this] val newCoroutineContext = oldCoroutineContext + coroutineContext return withValue(COROUTINE_CONTEXT_KEY, newCoroutineContext) @@ -30,13 +32,13 @@ abstract class CoroutineContextServerInterceptor : ServerInterceptor { /** * Override this function to return a [CoroutineContext] in which to execute [call] and [headers]. - * The returned [CoroutineContext] will override any corresponding context elements in the - * server object. + * The returned [CoroutineContext] will override any corresponding context elements in the server + * object. * * This function will be called each time a [call] is executed. * * @throws StatusException if the call should be closed with the [Status][io.grpc.Status] in the - * exception and further processing suppressed + * exception and further processing suppressed */ abstract fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext @@ -54,14 +56,15 @@ abstract class CoroutineContextServerInterceptor : ServerInterceptor { headers: Metadata, next: ServerCallHandler ): ServerCall.Listener { - val coroutineContext = try { - coroutineContext(call, headers) - } catch (e: StatusException) { - call.close(e.status, e.trailers ?: Metadata()) - throw e - } + val coroutineContext = + try { + coroutineContext(call, headers) + } catch (e: StatusException) { + call.close(e.status, e.trailers ?: Metadata()) + throw e + } return withGrpcContext(GrpcContext.current().extendCoroutineContext(coroutineContext)) { next.startCall(call, headers) } } -} \ No newline at end of file +} diff --git a/stub/src/main/java/io/grpc/kotlin/GrpcContextElement.kt b/stub/src/main/java/io/grpc/kotlin/GrpcContextElement.kt index c47c28f2..0a66a58b 100644 --- a/stub/src/main/java/io/grpc/kotlin/GrpcContextElement.kt +++ b/stub/src/main/java/io/grpc/kotlin/GrpcContextElement.kt @@ -16,13 +16,13 @@ package io.grpc.kotlin -import kotlinx.coroutines.ThreadContextElement -import kotlin.coroutines.CoroutineContext import io.grpc.Context as GrpcContext +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.ThreadContextElement /** - * A [CoroutineContext] that propagates an associated [io.grpc.Context] to coroutines run using - * that context, regardless of thread. + * A [CoroutineContext] that propagates an associated [io.grpc.Context] to coroutines run using that + * context, regardless of thread. */ class GrpcContextElement(private val grpcContext: GrpcContext) : ThreadContextElement { companion object Key : CoroutineContext.Key { diff --git a/stub/src/main/java/io/grpc/kotlin/Helpers.kt b/stub/src/main/java/io/grpc/kotlin/Helpers.kt index 50f586d3..bc0e2728 100644 --- a/stub/src/main/java/io/grpc/kotlin/Helpers.kt +++ b/stub/src/main/java/io/grpc/kotlin/Helpers.kt @@ -29,20 +29,16 @@ import kotlinx.coroutines.flow.single import kotlinx.coroutines.runBlocking /** - * Extracts the value of a [Deferred] known to be completed, or throws its exception if it was - * not completed successfully. (Non-experimental variant of `getDone`.) + * Extracts the value of a [Deferred] known to be completed, or throws its exception if it was not + * completed successfully. (Non-experimental variant of `getDone`.) */ internal val Deferred.doneValue: T get() { check(isCompleted) { "doneValue should only be called on completed Deferred values" } - return runBlocking(Dispatchers.Unconfined) { - await() - } + return runBlocking(Dispatchers.Unconfined) { await() } } -/** - * Cancels a [Job] with a cause and suspends until the job completes/is finished cancelling. - */ +/** Cancels a [Job] with a cause and suspends until the job completes/is finished cancelling. */ internal suspend fun Job.cancelAndJoin(message: String, cause: Exception? = null) { cancel(message, cause) join() @@ -54,10 +50,7 @@ internal suspend fun Job.cancelAndJoin(message: String, cause: Exception? = null * The purpose of this function is to enable the one element to get processed before we have * confirmation that the input flow is done. */ -internal fun Flow.singleOrStatusFlow( - expected: String, - descriptor: Any -): Flow = flow { +internal fun Flow.singleOrStatusFlow(expected: String, descriptor: Any): Flow = flow { var found = false collect { if (!found) { @@ -80,7 +73,5 @@ internal fun Flow.singleOrStatusFlow( * Returns the one and only element of this flow, and throws a [StatusException] if there is not * exactly one element. */ -internal suspend fun Flow.singleOrStatus( - expected: String, - descriptor: Any -): T = singleOrStatusFlow(expected, descriptor).single() +internal suspend fun Flow.singleOrStatus(expected: String, descriptor: Any): T = + singleOrStatusFlow(expected, descriptor).single() diff --git a/stub/src/main/java/io/grpc/kotlin/Readiness.kt b/stub/src/main/java/io/grpc/kotlin/Readiness.kt index 90d619e7..117cdcb1 100644 --- a/stub/src/main/java/io/grpc/kotlin/Readiness.kt +++ b/stub/src/main/java/io/grpc/kotlin/Readiness.kt @@ -19,21 +19,18 @@ package io.grpc.kotlin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.onFailure -/** - * A simple helper allowing a notification of "ready" to be broadcast, and waited for. - */ -internal class Readiness( - private val isReallyReady: () -> Boolean -) { +/** A simple helper allowing a notification of "ready" to be broadcast, and waited for. */ +internal class Readiness(private val isReallyReady: () -> Boolean) { // A CONFLATED channel never suspends to send, and two notifications of readiness are equivalent // to one private val channel = Channel(Channel.CONFLATED) fun onReady() { channel.trySend(Unit).onFailure { e -> - throw e ?: AssertionError( - "Should be impossible; a CONFLATED channel should never return false on offer" - ) + throw e + ?: AssertionError( + "Should be impossible; a CONFLATED channel should never return false on offer" + ) } } diff --git a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt index 92f54007..097e4379 100644 --- a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt +++ b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt @@ -16,6 +16,7 @@ package io.grpc.kotlin +import io.grpc.Metadata as GrpcMetadata import io.grpc.MethodDescriptor import io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING import io.grpc.MethodDescriptor.MethodType.CLIENT_STREAMING @@ -27,6 +28,8 @@ import io.grpc.ServerMethodDefinition import io.grpc.Status import io.grpc.StatusException import io.grpc.StatusRuntimeException +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.cancel @@ -36,28 +39,23 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch -import java.util.concurrent.atomic.AtomicBoolean -import kotlin.coroutines.CoroutineContext -import io.grpc.Metadata as GrpcMetadata -/** - * Helpers for implementing a gRPC server based on a Kotlin coroutine implementation. - */ +/** Helpers for implementing a gRPC server based on a Kotlin coroutine implementation. */ object ServerCalls { /** * Creates a [ServerMethodDefinition] that implements the specified unary RPC method by running * the specified implementation and associated implementation details within a per-RPC * [CoroutineScope] generated with the specified [CoroutineContext]. * - * When the RPC is received, this method definition will pass the request from the client - * to [implementation], and send the response back to the client when it is returned. + * When the RPC is received, this method definition will pass the request from the client to + * [implementation], and send the response back to the client when it is returned. * * If [implementation] fails with a [StatusException], the RPC will fail with the corresponding - * [Status]. If [implementation] fails with a [CancellationException], the RPC will fail - * with [Status.CANCELLED]. If [implementation] fails for any other reason, the RPC will - * fail with [Status.UNKNOWN] with the exception as a cause. If a cancellation is received - * from the client before [implementation] is complete, the coroutine will be cancelled and the - * RPC will fail with [Status.CANCELLED]. + * [Status]. If [implementation] fails with a [CancellationException], the RPC will fail with + * [Status.CANCELLED]. If [implementation] fails for any other reason, the RPC will fail with + * [Status.UNKNOWN] with the exception as a cause. If a cancellation is received from the client + * before [implementation] is complete, the coroutine will be cancelled and the RPC will fail with + * [Status.CANCELLED]. * * @param context The context of the scopes the RPC implementation will run in * @param descriptor The descriptor of the method being implemented @@ -68,13 +66,9 @@ object ServerCalls { descriptor: MethodDescriptor, implementation: suspend (request: RequestT) -> ResponseT ): ServerMethodDefinition { - require(descriptor.type == UNARY) { - "Expected a unary method descriptor but got $descriptor" - } + require(descriptor.type == UNARY) { "Expected a unary method descriptor but got $descriptor" } return serverMethodDefinition(context, descriptor) { requests -> - requests - .singleOrStatusFlow("request", descriptor) - .map { implementation(it) } + requests.singleOrStatusFlow("request", descriptor).map { implementation(it) } } } @@ -84,11 +78,11 @@ object ServerCalls { * [CoroutineScope] generated with the specified [CoroutineContext]. * * When the RPC is received, this method definition will pass a [Flow] of requests from the client - * to [implementation], and send the response back to the client when it is returned. - * Exceptions are handled as in [unaryServerMethodDefinition]. Additionally, attempts to collect - * the requests flow more than once will throw an [IllegalStateException], and if [implementation] - * cancels collection of the requests flow, further requests from the client will be ignored - * (and no backpressure will be applied). + * to [implementation], and send the response back to the client when it is returned. Exceptions + * are handled as in [unaryServerMethodDefinition]. Additionally, attempts to collect the requests + * flow more than once will throw an [IllegalStateException], and if [implementation] cancels + * collection of the requests flow, further requests from the client will be ignored (and no + * backpressure will be applied). * * @param context The context of the scopes the RPC implementation will run in * @param descriptor The descriptor of the method being implemented @@ -113,13 +107,13 @@ object ServerCalls { /** * Creates a [ServerMethodDefinition] that implements the specified server-streaming RPC method by * running the specified implementation and associated implementation details within a per-RPC - * [CoroutineScope] generated with the specified [CoroutineContext]. When the RPC is received, + * [CoroutineScope] generated with the specified [CoroutineContext]. When the RPC is received, * this method definition will collect the flow returned by [implementation] and send the emitted * values back to the client. * - * When the RPC is received, this method definition will pass the request from the client - * to [implementation], and collect the returned [Flow], sending responses to the client as they - * are emitted. Exceptions and cancellation are handled as in [unaryServerMethodDefinition]. + * When the RPC is received, this method definition will pass the request from the client to + * [implementation], and collect the returned [Flow], sending responses to the client as they are + * emitted. Exceptions and cancellation are handled as in [unaryServerMethodDefinition]. * * @param context The context of the scopes the RPC implementation will run in * @param descriptor The descriptor of the method being implemented @@ -135,11 +129,9 @@ object ServerCalls { } return serverMethodDefinition(context, descriptor) { requests -> flow { - requests - .singleOrStatusFlow("request", descriptor) - .collect { req -> - implementation(req).collect { resp -> emit(resp) } - } + requests.singleOrStatusFlow("request", descriptor).collect { req -> + implementation(req).collect { resp -> emit(resp) } + } } } } @@ -153,8 +145,8 @@ object ServerCalls { * to [implementation], and collect the returned [Flow], sending responses to the client as they * are emitted. * - * Exceptions and cancellation are handled as in [clientStreamingServerMethodDefinition] and as - * in [serverStreamingServerMethodDefinition]. + * Exceptions and cancellation are handled as in [clientStreamingServerMethodDefinition] and as in + * [serverStreamingServerMethodDefinition]. * * @param context The context of the scopes the RPC implementation will run in * @param descriptor The descriptor of the method being implemented @@ -181,10 +173,7 @@ object ServerCalls { descriptor: MethodDescriptor, implementation: (Flow) -> Flow ): ServerMethodDefinition = - ServerMethodDefinition.create( - descriptor, - serverCallHandler(context, implementation) - ) + ServerMethodDefinition.create(descriptor, serverCallHandler(context, implementation)) /** * Returns a [ServerCallHandler] that implements an RPC method by running the specified @@ -193,16 +182,15 @@ object ServerCalls { private fun serverCallHandler( context: CoroutineContext, implementation: (Flow) -> Flow - ): ServerCallHandler = - ServerCallHandler { - call, _ -> serverCallListener( - context - + CoroutineContextServerInterceptor.COROUTINE_CONTEXT_KEY.get() - + GrpcContextElement.current(), - call, - implementation - ) - } + ): ServerCallHandler = ServerCallHandler { call, _ -> + serverCallListener( + context + + CoroutineContextServerInterceptor.COROUTINE_CONTEXT_KEY.get() + + GrpcContextElement.current(), + call, + implementation + ) + } private fun serverCallListener( context: CoroutineContext, @@ -214,55 +202,61 @@ object ServerCalls { val requestsStarted = AtomicBoolean(false) // enforces read-once - val requests = flow { - check(requestsStarted.compareAndSet(false, true)) { - "requests flow can only be collected once" - } - - call.request(1) - try { - for (request in requestsChannel) { - emit(request) - call.request(1) + val requests = + flow { + check(requestsStarted.compareAndSet(false, true)) { + "requests flow can only be collected once" } - } catch (e: Exception) { - requestsChannel.cancel( - CancellationException("Exception thrown while collecting requests", e) - ) - call.request(1) // make sure we don't cause backpressure - throw e - } - } - val rpcJob = CoroutineScope(context).launch { - var headersSent = false - val failure = runCatching { - implementation(requests).collect { - // once we have a response message, check if we've sent headers yet - if not, do so - if (!headersSent) { - call.sendHeaders(GrpcMetadata()) - headersSent = true + call.request(1) + try { + for (request in requestsChannel) { + emit(request) + call.request(1) } - readiness.suspendUntilReady() - call.sendMessage(it) + } catch (e: Exception) { + requestsChannel.cancel( + CancellationException("Exception thrown while collecting requests", e) + ) + call.request(1) // make sure we don't cause backpressure + throw e } - }.exceptionOrNull() - // check headers again once we're done collecting the response flow - if we received - // no elements or threw an exception, then we wouldn't have sent them - if (failure == null && !headersSent) { - call.sendHeaders(GrpcMetadata()) } - val closeStatus = when (failure) { - null -> Status.OK - is CancellationException -> Status.CANCELLED.withCause(failure) - is StatusException, is StatusRuntimeException -> Status.fromThrowable(failure) - else -> Status.fromThrowable(failure).withCause(failure) + + val rpcJob = + CoroutineScope(context).launch { + var headersSent = false + val failure = + runCatching { + implementation(requests).collect { + // once we have a response message, check if we've sent headers yet - if not, do so + if (!headersSent) { + call.sendHeaders(GrpcMetadata()) + headersSent = true + } + readiness.suspendUntilReady() + call.sendMessage(it) + } + } + .exceptionOrNull() + // check headers again once we're done collecting the response flow - if we received + // no elements or threw an exception, then we wouldn't have sent them + if (failure == null && !headersSent) { + call.sendHeaders(GrpcMetadata()) + } + val closeStatus = + when (failure) { + null -> Status.OK + is CancellationException -> Status.CANCELLED.withCause(failure) + is StatusException, + is StatusRuntimeException -> Status.fromThrowable(failure) + else -> Status.fromThrowable(failure).withCause(failure) + } + val trailers = failure?.let { Status.trailersFromThrowable(it) } ?: GrpcMetadata() + call.close(closeStatus, trailers) } - val trailers = failure?.let { Status.trailersFromThrowable(it) } ?: GrpcMetadata() - call.close(closeStatus, trailers) - } - return object: ServerCall.Listener() { + return object : ServerCall.Listener() { var isReceiving = true override fun onCancel() { @@ -275,8 +269,7 @@ object ServerCalls { isReceiving = result.isSuccess result.onFailure { ex -> if (ex !is CancellationException) { - throw Status.INTERNAL - .withDescription( + throw Status.INTERNAL.withDescription( "onMessage should never be called when requestsChannel is unready" ) .withCause(ex) diff --git a/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt index bb34474d..25a9e50f 100644 --- a/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt @@ -55,12 +55,14 @@ import org.junit.rules.Timeout abstract class AbstractCallsTest { companion object { fun helloRequest(name: String): HelloRequest = HelloRequest.newBuilder().setName(name).build() - fun helloReply(message: String): HelloReply = HelloReply.newBuilder().setMessage(message).build() + + fun helloReply(message: String): HelloReply = + HelloReply.newBuilder().setMessage(message).build() + fun multiHelloRequest(vararg name: String): MultiHelloRequest = MultiHelloRequest.newBuilder().addAllName(name.asList()).build() - val sayHelloMethod: MethodDescriptor = - GreeterGrpc.getSayHelloMethod() + val sayHelloMethod: MethodDescriptor = GreeterGrpc.getSayHelloMethod() val clientStreamingSayHelloMethod: MethodDescriptor = GreeterGrpc.getClientStreamSayHelloMethod() val serverStreamingSayHelloMethod: MethodDescriptor = @@ -69,9 +71,7 @@ abstract class AbstractCallsTest { GreeterGrpc.getBidiStreamSayHelloMethod() val greeterService: ServiceDescriptor = GreeterGrpc.getServiceDescriptor() - fun CoroutineScope.produce( - block: suspend SendChannel.() -> Unit - ): ReceiveChannel { + fun CoroutineScope.produce(block: suspend SendChannel.() -> Unit): ReceiveChannel { val channel = Channel() launch { channel.block() @@ -98,19 +98,16 @@ abstract class AbstractCallsTest { } fun whenContextIsCancelled(onCancelled: () -> Unit) { - Context.current().withCancellation().addListener( - Context.CancellationListener { onCancelled() }, - MoreExecutors.directExecutor() - ) + Context.current() + .withCancellation() + .addListener(Context.CancellationListener { onCancelled() }, MoreExecutors.directExecutor()) } } - @get:Rule - val timeout: Timeout = Timeout.seconds(10) + @get:Rule val timeout: Timeout = Timeout.seconds(10) // We want the coroutines timeout to come first, because it comes with useful debug logs. - @get:Rule - val grpcCleanup = GrpcCleanupRule().setTimeout(11, TimeUnit.SECONDS) + @get:Rule val grpcCleanup = GrpcCleanupRule().setTimeout(11, TimeUnit.SECONDS) lateinit var channel: ManagedChannel @@ -133,9 +130,7 @@ abstract class AbstractCallsTest { } } - inline fun assertThrows( - callback: () -> Unit - ): E { + inline fun assertThrows(callback: () -> Unit): E { var ex: Exception? = null try { callback() @@ -169,8 +164,7 @@ abstract class AbstractCallsTest { ) return grpcCleanup.register( - InProcessChannelBuilder - .forName(serverName) + InProcessChannelBuilder.forName(serverName) .enableRetry() .defaultServiceConfig(serviceConfig) .run { this as io.grpc.ManagedChannelBuilder<*> } // workaround b/123879662 @@ -199,10 +193,7 @@ abstract class AbstractCallsTest { config: Map = emptyMap(), vararg interceptors: ServerInterceptor ): ManagedChannel { - return makeChannel( - ServerInterceptors.intercept(serverServiceDefinition, *interceptors), - config - ) + return makeChannel(ServerInterceptors.intercept(serverServiceDefinition, *interceptors), config) } fun runBlocking(block: suspend CoroutineScope.() -> R): Unit = diff --git a/stub/src/test/java/io/grpc/kotlin/ClientCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/ClientCallsTest.kt index 4e9862bc..87c24275 100644 --- a/stub/src/test/java/io/grpc/kotlin/ClientCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/ClientCallsTest.kt @@ -18,12 +18,10 @@ package io.grpc.kotlin import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.common.util.concurrent.MoreExecutors.directExecutor import io.grpc.CallOptions import io.grpc.ClientCall import io.grpc.ClientInterceptor import io.grpc.ClientInterceptors -import io.grpc.Context import io.grpc.ForwardingClientCall import io.grpc.Metadata import io.grpc.MethodDescriptor @@ -35,8 +33,8 @@ import io.grpc.examples.helloworld.HelloRequest import io.grpc.examples.helloworld.MultiHelloRequest import io.grpc.stub.StreamObserver import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.Job import kotlinx.coroutines.async @@ -48,75 +46,75 @@ import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.produceIn import kotlinx.coroutines.flow.single -import kotlinx.coroutines.flow.take import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.yield import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -import java.util.concurrent.atomic.AtomicInteger /** Tests for [ClientCalls]. */ @RunWith(JUnit4::class) -class ClientCallsTest: AbstractCallsTest() { +class ClientCallsTest : AbstractCallsTest() { - /** - * Verifies that a simple unary RPC successfully returns results to a suspend function. - */ + /** Verifies that a simple unary RPC successfully returns results to a suspend function. */ @Test fun simpleUnary() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { - responseObserver.onNext(helloReply("Hello, ${request.name}")) - responseObserver.onCompleted() + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { + responseObserver.onNext(helloReply("Hello, ${request.name}")) + responseObserver.onCompleted() + } } - } channel = makeChannel(serverImpl) assertThat( - ClientCalls.unaryRpc( - channel = channel, - callOptions = CallOptions.DEFAULT, - method = sayHelloMethod, - request = helloRequest("Cindy") + ClientCalls.unaryRpc( + channel = channel, + callOptions = CallOptions.DEFAULT, + method = sayHelloMethod, + request = helloRequest("Cindy") + ) ) - ).isEqualTo(helloReply("Hello, Cindy")) + .isEqualTo(helloReply("Hello, Cindy")) assertThat( - ClientCalls.unaryRpc( - channel = channel, - callOptions = CallOptions.DEFAULT, - method = sayHelloMethod, - request = helloRequest("Jeff") + ClientCalls.unaryRpc( + channel = channel, + callOptions = CallOptions.DEFAULT, + method = sayHelloMethod, + request = helloRequest("Jeff") + ) ) - ).isEqualTo(helloReply("Hello, Jeff")) + .isEqualTo(helloReply("Hello, Jeff")) } /** - * Verify that a unary RPC that does not respond within a timeout specified by [CallOptions] - * fails on the client with a DEADLINE_EXCEEDED and is cancelled on the server. + * Verify that a unary RPC that does not respond within a timeout specified by [CallOptions] fails + * on the client with a DEADLINE_EXCEEDED and is cancelled on the server. */ @Test fun unaryServerDoesNotRespondGrpcTimeout() = runBlocking { val serverCancelled = Job() - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { - whenContextIsCancelled { serverCancelled.complete() } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { + whenContextIsCancelled { serverCancelled.complete() } + } } - } channel = makeChannel(serverImpl) - val ex = assertThrows { - ClientCalls.unaryRpc( - channel = channel, - callOptions = CallOptions.DEFAULT.withDeadlineAfter(200, TimeUnit.MILLISECONDS), - method = sayHelloMethod, - request = helloRequest("Jeff") - ) - } + val ex = + assertThrows { + ClientCalls.unaryRpc( + channel = channel, + callOptions = CallOptions.DEFAULT.withDeadlineAfter(200, TimeUnit.MILLISECONDS), + method = sayHelloMethod, + request = helloRequest("Jeff") + ) + } assertThat(ex.status.code).isEqualTo(Status.Code.DEADLINE_EXCEEDED) serverCancelled.join() } @@ -124,13 +122,14 @@ class ClientCallsTest: AbstractCallsTest() { /** Verify that a server that sends two responses to a unary RPC causes an exception. */ @Test fun unaryTooManyResponses() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { - responseObserver.onNext(helloReply("Hello, ${request.name}")) - responseObserver.onNext(helloReply("It's nice to meet you, ${request.name}")) - responseObserver.onCompleted() + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { + responseObserver.onNext(helloReply("Hello, ${request.name}")) + responseObserver.onNext(helloReply("It's nice to meet you, ${request.name}")) + responseObserver.onCompleted() + } } - } channel = makeChannel(serverImpl) @@ -149,11 +148,12 @@ class ClientCallsTest: AbstractCallsTest() { /** Verify that a server that sends zero responses to a unary RPC causes an exception. */ @Test fun unaryNoResponses() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { - responseObserver.onCompleted() + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { + responseObserver.onCompleted() + } } - } channel = makeChannel(serverImpl) @@ -179,12 +179,13 @@ class ClientCallsTest: AbstractCallsTest() { val serverReceived = Job() val serverCancelled = Job() - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { - serverReceived.complete() - whenContextIsCancelled { serverCancelled.complete() } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { + serverReceived.complete() + whenContextIsCancelled { serverCancelled.complete() } + } } - } channel = makeChannel(serverImpl) @@ -203,22 +204,24 @@ class ClientCallsTest: AbstractCallsTest() { @Test fun unaryServerExceptionPropagated() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { - throw IllegalArgumentException("No hello for you!") + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun sayHello(request: HelloRequest, responseObserver: StreamObserver) { + throw IllegalArgumentException("No hello for you!") + } } - } channel = makeChannel(serverImpl) - val ex = assertThrows { - ClientCalls.unaryRpc( - channel = channel, - callOptions = CallOptions.DEFAULT, - method = sayHelloMethod, - request = helloRequest("Cindy") - ) - } + val ex = + assertThrows { + ClientCalls.unaryRpc( + channel = channel, + callOptions = CallOptions.DEFAULT, + method = sayHelloMethod, + request = helloRequest("Cindy") + ) + } assertThat(ex.status.code).isEqualTo(Status.Code.UNKNOWN) } @@ -258,56 +261,64 @@ class ClientCallsTest: AbstractCallsTest() { @Test fun simpleServerStreamingRpc() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun serverStreamSayHello( - request: MultiHelloRequest, - responseObserver: StreamObserver - ) { - for (name in request.nameList) { - responseObserver.onNext(helloReply("Hello, $name")) + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun serverStreamSayHello( + request: MultiHelloRequest, + responseObserver: StreamObserver + ) { + for (name in request.nameList) { + responseObserver.onNext(helloReply("Hello, $name")) + } + responseObserver.onCompleted() } - responseObserver.onCompleted() } - } channel = makeChannel(serverImpl) - val rpc = ClientCalls.serverStreamingRpc( - channel = channel, - method = serverStreamingSayHelloMethod, - request = multiHelloRequest("Cindy", "Jeff", "Aki") - ) + val rpc = + ClientCalls.serverStreamingRpc( + channel = channel, + method = serverStreamingSayHelloMethod, + request = multiHelloRequest("Cindy", "Jeff", "Aki") + ) - assertThat(rpc.toList()).containsExactly( - helloReply("Hello, Cindy"), helloReply("Hello, Jeff"), helloReply("Hello, Aki") - ).inOrder() + assertThat(rpc.toList()) + .containsExactly( + helloReply("Hello, Cindy"), + helloReply("Hello, Jeff"), + helloReply("Hello, Aki") + ) + .inOrder() } @Test fun serverStreamingRpcCancellation() = runBlocking { val serverCancelled = Job() val serverReceived = Job() - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun serverStreamSayHello( - request: MultiHelloRequest, - responseObserver: StreamObserver - ) { - whenContextIsCancelled { serverCancelled.complete() } - serverReceived.complete() - for (name in request.nameList) { - responseObserver.onNext(helloReply("Hello, $name")) + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun serverStreamSayHello( + request: MultiHelloRequest, + responseObserver: StreamObserver + ) { + whenContextIsCancelled { serverCancelled.complete() } + serverReceived.complete() + for (name in request.nameList) { + responseObserver.onNext(helloReply("Hello, $name")) + } + responseObserver.onCompleted() } - responseObserver.onCompleted() } - } channel = makeChannel(serverImpl) - val rpc = ClientCalls.serverStreamingRpc( - channel = channel, - method = serverStreamingSayHelloMethod, - request = multiHelloRequest("Tim", "Jim", "Pym") - ) + val rpc = + ClientCalls.serverStreamingRpc( + channel = channel, + method = serverStreamingSayHelloMethod, + request = multiHelloRequest("Tim", "Jim", "Pym") + ) assertThrows { rpc.collect { serverReceived.join() @@ -319,76 +330,76 @@ class ClientCallsTest: AbstractCallsTest() { @Test fun simpleClientStreamingRpc() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun clientStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - private val names = mutableListOf() - - override fun onNext(value: HelloRequest) { - names += value.name - } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun clientStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + private val names = mutableListOf() + + override fun onNext(value: HelloRequest) { + names += value.name + } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - responseObserver.onNext( - helloReply(names.joinToString(prefix = "Hello, ", separator = ", ")) - ) - responseObserver.onCompleted() + override fun onCompleted() { + responseObserver.onNext( + helloReply(names.joinToString(prefix = "Hello, ", separator = ", ")) + ) + responseObserver.onCompleted() + } } } } - } channel = makeChannel(serverImpl) - val requests = flowOf( - helloRequest("Tim"), - helloRequest("Jim") - ) + val requests = flowOf(helloRequest("Tim"), helloRequest("Jim")) assertThat( - ClientCalls.clientStreamingRpc( - channel = channel, - method = clientStreamingSayHelloMethod, - requests = requests + ClientCalls.clientStreamingRpc( + channel = channel, + method = clientStreamingSayHelloMethod, + requests = requests + ) ) - ).isEqualTo(helloReply("Hello, Tim, Jim")) + .isEqualTo(helloReply("Hello, Tim, Jim")) } @FlowPreview @Test fun clientStreamingRpcReturnsEarly() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun clientStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - private val names = mutableListOf() - private var isComplete = false - - override fun onNext(value: HelloRequest) { - names += value.name - if (names.size >= 2 && !isComplete) { - onCompleted() + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun clientStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + private val names = mutableListOf() + private var isComplete = false + + override fun onNext(value: HelloRequest) { + names += value.name + if (names.size >= 2 && !isComplete) { + onCompleted() + } } - } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - if (!isComplete) { - responseObserver.onNext( - helloReply(names.joinToString(prefix = "Hello, ", separator = ", ")) - ) - responseObserver.onCompleted() - isComplete = true + override fun onCompleted() { + if (!isComplete) { + responseObserver.onNext( + helloReply(names.joinToString(prefix = "Hello, ", separator = ", ")) + ) + responseObserver.onCompleted() + isComplete = true + } } } } } - } channel = makeChannel(serverImpl) @@ -415,28 +426,29 @@ class ClientCallsTest: AbstractCallsTest() { @FlowPreview @Test fun clientStreamingRpcCancelled() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun clientStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - private val names = mutableListOf() - - override fun onNext(value: HelloRequest) { - names += value.name - } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun clientStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + private val names = mutableListOf() + + override fun onNext(value: HelloRequest) { + names += value.name + } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - responseObserver.onNext( - helloReply(names.joinToString(prefix = "Hello, ", separator = ", ")) - ) - responseObserver.onCompleted() + override fun onCompleted() { + responseObserver.onNext( + helloReply(names.joinToString(prefix = "Hello, ", separator = ", ")) + ) + responseObserver.onCompleted() + } } } } - } channel = makeChannel(serverImpl) @@ -451,40 +463,41 @@ class ClientCallsTest: AbstractCallsTest() { requests.send(helloRequest("Tim")) response.cancel() response.join() - assertThrows { - requests.send(helloRequest("John")) - } + assertThrows { requests.send(helloRequest("John")) } } @FlowPreview @Test fun simpleBidiStreamingRpc() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun bidiStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - override fun onNext(value: HelloRequest) { - responseObserver.onNext(helloReply("Hello, ${value.name}")) - } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun bidiStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + override fun onNext(value: HelloRequest) { + responseObserver.onNext(helloReply("Hello, ${value.name}")) + } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - responseObserver.onCompleted() + override fun onCompleted() { + responseObserver.onCompleted() + } } } } - } channel = makeChannel(serverImpl) val requests = Channel() - val rpc = ClientCalls.bidiStreamingRpc( - channel = channel, - method = bidiStreamingSayHelloMethod, - requests = requests.consumeAsFlow() - ).produceIn(this) + val rpc = + ClientCalls.bidiStreamingRpc( + channel = channel, + method = bidiStreamingSayHelloMethod, + requests = requests.consumeAsFlow() + ) + .produceIn(this) requests.send(helloRequest("Tim")) assertThat(rpc.receive()).isEqualTo(helloReply("Hello, Tim")) requests.send(helloRequest("Jim")) @@ -496,38 +509,41 @@ class ClientCallsTest: AbstractCallsTest() { @FlowPreview @Test fun bidiStreamingRpcReturnsEarly() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun bidiStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - private var responseCount = 0 - - override fun onNext(value: HelloRequest) { - responseCount++ - responseObserver.onNext(helloReply("Hello, ${value.name}")) - if (responseCount >= 2) { - onCompleted() + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun bidiStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + private var responseCount = 0 + + override fun onNext(value: HelloRequest) { + responseCount++ + responseObserver.onNext(helloReply("Hello, ${value.name}")) + if (responseCount >= 2) { + onCompleted() + } } - } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - responseObserver.onCompleted() + override fun onCompleted() { + responseObserver.onCompleted() + } } } } - } channel = makeChannel(serverImpl) val requests = Channel() - val rpc = ClientCalls.bidiStreamingRpc( - channel = channel, - method = bidiStreamingSayHelloMethod, - requests = requests.consumeAsFlow() - ).produceIn(this) + val rpc = + ClientCalls.bidiStreamingRpc( + channel = channel, + method = bidiStreamingSayHelloMethod, + requests = requests.consumeAsFlow() + ) + .produceIn(this) requests.send(helloRequest("Tim")) assertThat(rpc.receive()).isEqualTo(helloReply("Hello, Tim")) requests.send(helloRequest("Jim")) @@ -544,72 +560,73 @@ class ClientCallsTest: AbstractCallsTest() { @Test fun bidiStreamingRpcRequestsFail() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun bidiStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - override fun onNext(value: HelloRequest) { - responseObserver.onNext(helloReply("Hello, ${value.name}")) - } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun bidiStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + override fun onNext(value: HelloRequest) { + responseObserver.onNext(helloReply("Hello, ${value.name}")) + } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - responseObserver.onCompleted() + override fun onCompleted() { + responseObserver.onCompleted() + } } } } - } channel = makeChannel(serverImpl) - val responses = ClientCalls.bidiStreamingRpc( - channel = channel, - method = bidiStreamingSayHelloMethod, - requests = flow { - throw MyException() - } - ) + val responses = + ClientCalls.bidiStreamingRpc( + channel = channel, + method = bidiStreamingSayHelloMethod, + requests = flow { throw MyException() } + ) - assertThrows { - responses.collect() - } + assertThrows { responses.collect() } } - private class MyException: Exception() + private class MyException : Exception() @Test fun bidiStreamingRpcCollectsRequestsEachTime() = runBlocking { - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun bidiStreamSayHello( - responseObserver: StreamObserver - ): StreamObserver { - return object : StreamObserver { - override fun onNext(value: HelloRequest) { - responseObserver.onNext(helloReply("Hello, ${value.name}")) - } + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun bidiStreamSayHello( + responseObserver: StreamObserver + ): StreamObserver { + return object : StreamObserver { + override fun onNext(value: HelloRequest) { + responseObserver.onNext(helloReply("Hello, ${value.name}")) + } - override fun onError(t: Throwable) = throw t + override fun onError(t: Throwable) = throw t - override fun onCompleted() { - responseObserver.onCompleted() + override fun onCompleted() { + responseObserver.onCompleted() + } } } } - } channel = makeChannel(serverImpl) val requestsEvaluations = AtomicInteger() - val requests = flow { - requestsEvaluations.incrementAndGet() - emit(helloRequest("Sunstone")) - } + val requests = + flow { + requestsEvaluations.incrementAndGet() + emit(helloRequest("Sunstone")) + } - val responses = ClientCalls.bidiStreamingRpc( - channel = channel, - method = bidiStreamingSayHelloMethod, - requests = requests - ) + val responses = + ClientCalls.bidiStreamingRpc( + channel = channel, + method = bidiStreamingSayHelloMethod, + requests = requests + ) assertThat(responses.single()).isEqualTo(helloReply("Hello, Sunstone")) assertThat(responses.single()).isEqualTo(helloReply("Hello, Sunstone")) @@ -618,37 +635,40 @@ class ClientCallsTest: AbstractCallsTest() { @Test fun metadataCopied() = runBlocking { - val metadataKey: Metadata.Key = Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER) - val serverImpl = object : GreeterGrpc.GreeterImplBase() { - override fun serverStreamSayHello( - request: MultiHelloRequest, - responseObserver: StreamObserver - ) { - responseObserver.onNext(helloReply("hello!")) - responseObserver.onCompleted() + val metadataKey: Metadata.Key = + Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER) + val serverImpl = + object : GreeterGrpc.GreeterImplBase() { + override fun serverStreamSayHello( + request: MultiHelloRequest, + responseObserver: StreamObserver + ) { + responseObserver.onNext(helloReply("hello!")) + responseObserver.onCompleted() + } } - } // Verify that the metadata is copied anew for each collection of the flow, with an interceptor // that checks that it hasn't run before. - val interceptor = object : ClientInterceptor { - override fun interceptCall( - method: MethodDescriptor?, - callOptions: CallOptions, - next: io.grpc.Channel - ): ClientCall { - val call: ClientCall = next.newCall(method, callOptions) - return object : ForwardingClientCall() { - override fun start(responseListener: Listener, headers: Metadata) { - check(!headers.containsKey(metadataKey)) - headers.put(metadataKey, "value") - super.start(responseListener, headers) - } + val interceptor = + object : ClientInterceptor { + override fun interceptCall( + method: MethodDescriptor?, + callOptions: CallOptions, + next: io.grpc.Channel + ): ClientCall { + val call: ClientCall = next.newCall(method, callOptions) + return object : ForwardingClientCall() { + override fun start(responseListener: Listener, headers: Metadata) { + check(!headers.containsKey(metadataKey)) + headers.put(metadataKey, "value") + super.start(responseListener, headers) + } - override fun delegate(): ClientCall = call + override fun delegate(): ClientCall = call + } } } - } val channel = ClientInterceptors.intercept(makeChannel(serverImpl), interceptor) val flow = ClientCalls.serverStreamingRpc(channel, serverStreamingSayHelloMethod, multiHelloRequest()) diff --git a/stub/src/test/java/io/grpc/kotlin/CoroutineContextServerInterceptorTest.kt b/stub/src/test/java/io/grpc/kotlin/CoroutineContextServerInterceptorTest.kt index dc90e8ed..1bc94837 100644 --- a/stub/src/test/java/io/grpc/kotlin/CoroutineContextServerInterceptorTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/CoroutineContextServerInterceptorTest.kt @@ -1,137 +1,132 @@ package io.grpc.kotlin import com.google.common.truth.Truth.assertThat +import io.grpc.Metadata as GrpcMetadata import io.grpc.ServerCall import io.grpc.ServerInterceptors import io.grpc.Status import io.grpc.StatusException -import io.grpc.StatusRuntimeException import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.coroutineContext -import io.grpc.Metadata as GrpcMetadata +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 /** Tests for [CoroutineContextServerInterceptor]. */ @RunWith(JUnit4::class) class CoroutineContextServerInterceptorTest : AbstractCallsTest() { class ArbitraryContextElement(val message: String = "") : CoroutineContext.Element { companion object Key : CoroutineContext.Key + override val key: CoroutineContext.Key<*> get() = Key } - class HelloReplyWithContextMessage( - message: String? = null - ) : GreeterCoroutineImplBase( - message?.let { ArbitraryContextElement(it) } ?: EmptyCoroutineContext - ) { + class HelloReplyWithContextMessage(message: String? = null) : + GreeterCoroutineImplBase( + message?.let { ArbitraryContextElement(it) } ?: EmptyCoroutineContext + ) { override suspend fun sayHello(request: HelloRequest): HelloReply = helloReply(coroutineContext[ArbitraryContextElement]!!.message) } @Test fun injectContext() { - val interceptor = object : CoroutineContextServerInterceptor() { - override fun coroutineContext( - call: ServerCall<*, *>, - headers: GrpcMetadata - ): CoroutineContext = ArbitraryContextElement("success") - } + val interceptor = + object : CoroutineContextServerInterceptor() { + override fun coroutineContext( + call: ServerCall<*, *>, + headers: GrpcMetadata + ): CoroutineContext = ArbitraryContextElement("success") + } val channel = makeChannel(HelloReplyWithContextMessage(), interceptor) val client = GreeterCoroutineStub(channel) - runBlocking { - assertThat(client.sayHello(helloRequest("")).message).isEqualTo("success") - } + runBlocking { assertThat(client.sayHello(helloRequest("")).message).isEqualTo("success") } } @Test fun conflictingInterceptorsInnermostWins() { - val interceptor1 = object : CoroutineContextServerInterceptor() { - override fun coroutineContext( - call: ServerCall<*, *>, - headers: GrpcMetadata - ): CoroutineContext = ArbitraryContextElement("first") - } - val interceptor2 = object : CoroutineContextServerInterceptor() { - override fun coroutineContext( - call: ServerCall<*, *>, - headers: GrpcMetadata - ): CoroutineContext = ArbitraryContextElement("second") - } + val interceptor1 = + object : CoroutineContextServerInterceptor() { + override fun coroutineContext( + call: ServerCall<*, *>, + headers: GrpcMetadata + ): CoroutineContext = ArbitraryContextElement("first") + } + val interceptor2 = + object : CoroutineContextServerInterceptor() { + override fun coroutineContext( + call: ServerCall<*, *>, + headers: GrpcMetadata + ): CoroutineContext = ArbitraryContextElement("second") + } - val channel = makeChannel( - ServerInterceptors.intercept( + val channel = + makeChannel( ServerInterceptors.intercept( - HelloReplyWithContextMessage(), - interceptor2 - ), - interceptor1 + ServerInterceptors.intercept(HelloReplyWithContextMessage(), interceptor2), + interceptor1 + ) ) - ) val client = GreeterCoroutineStub(channel) - runBlocking { - assertThat(client.sayHello(helloRequest("")).message).isEqualTo("second") - } + runBlocking { assertThat(client.sayHello(helloRequest("")).message).isEqualTo("second") } } @Test fun interceptorContextTakesPriority() { - val interceptor = object : CoroutineContextServerInterceptor() { - override fun coroutineContext( - call: ServerCall<*, *>, - headers: GrpcMetadata - ): CoroutineContext = ArbitraryContextElement("interceptor") - } + val interceptor = + object : CoroutineContextServerInterceptor() { + override fun coroutineContext( + call: ServerCall<*, *>, + headers: GrpcMetadata + ): CoroutineContext = ArbitraryContextElement("interceptor") + } val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor) val client = GreeterCoroutineStub(channel) - runBlocking { - assertThat(client.sayHello(helloRequest("")).message).isEqualTo("interceptor") - } + runBlocking { assertThat(client.sayHello(helloRequest("")).message).isEqualTo("interceptor") } } @Test fun statusExceptionThrownFromCoroutineContextClosesCall() { - val interceptor = object : CoroutineContextServerInterceptor() { - override fun coroutineContext( - call: ServerCall<*, *>, - headers: GrpcMetadata - ): CoroutineContext { - throw StatusException(Status.INTERNAL.withDescription("An error")) + val interceptor = + object : CoroutineContextServerInterceptor() { + override fun coroutineContext( + call: ServerCall<*, *>, + headers: GrpcMetadata + ): CoroutineContext { + throw StatusException(Status.INTERNAL.withDescription("An error")) + } } - } val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor) val client = GreeterCoroutineStub(channel) - runBlocking { - assertThrows { client.sayHello(helloRequest("")) } - } + runBlocking { assertThrows { client.sayHello(helloRequest("")) } } } @Test fun retainsTrailersFromStatusExceptionThrownFromCoroutineContext() { val aMetadataKey = GrpcMetadata.Key.of("a-metadata-key", GrpcMetadata.ASCII_STRING_MARSHALLER) - val interceptor = object : CoroutineContextServerInterceptor() { - override fun coroutineContext( - call: ServerCall<*, *>, - headers: GrpcMetadata - ): CoroutineContext { - val trailers = GrpcMetadata().apply { put(aMetadataKey, "A value") } - throw StatusException(Status.INTERNAL, trailers) + val interceptor = + object : CoroutineContextServerInterceptor() { + override fun coroutineContext( + call: ServerCall<*, *>, + headers: GrpcMetadata + ): CoroutineContext { + val trailers = GrpcMetadata().apply { put(aMetadataKey, "A value") } + throw StatusException(Status.INTERNAL, trailers) + } } - } val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor) val client = GreeterCoroutineStub(channel) diff --git a/stub/src/test/java/io/grpc/kotlin/FlowControlTest.kt b/stub/src/test/java/io/grpc/kotlin/FlowControlTest.kt index e0f2ea7b..0f3346a6 100644 --- a/stub/src/test/java/io/grpc/kotlin/FlowControlTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/FlowControlTest.kt @@ -44,28 +44,28 @@ class FlowControlTest : AbstractCallsTest() { val context = CoroutineName("server context") private fun Flow.produceUnbuffered(scope: CoroutineScope): ReceiveChannel { - return scope.produce { - collect { send(it) } - } + return scope.produce { collect { send(it) } } } @FlowPreview @Test fun bidiPingPongFlowControl() = runBlocking { - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context = context, - descriptor = bidiStreamingSayHelloMethod, - implementation = { requests -> requests.map { helloReply("Hello, ${it.name}") } } + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition( + context = context, + descriptor = bidiStreamingSayHelloMethod, + implementation = { requests -> requests.map { helloReply("Hello, ${it.name}") } } + ) ) - ) val requests = Channel() val responses = ClientCalls.bidiStreamingRpc( - channel = channel, - requests = requests.consumeAsFlow(), - method = bidiStreamingSayHelloMethod - ).produceUnbuffered(this) + channel = channel, + requests = requests.consumeAsFlow(), + method = bidiStreamingSayHelloMethod + ) + .produceUnbuffered(this) requests.send(helloRequest("Garnet")) requests.send(helloRequest("Amethyst")) val third = launch { requests.send(helloRequest("Steven")) } @@ -80,21 +80,24 @@ class FlowControlTest : AbstractCallsTest() { @FlowPreview @Test fun bidiPingPongFlowControlExpandedServerBuffer() = runBlocking { - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context = context, - descriptor = bidiStreamingSayHelloMethod, - implementation = { - requests -> requests.buffer(Channel.RENDEZVOUS).map { helloReply("Hello, ${it.name}") } - } + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition( + context = context, + descriptor = bidiStreamingSayHelloMethod, + implementation = { requests -> + requests.buffer(Channel.RENDEZVOUS).map { helloReply("Hello, ${it.name}") } + } + ) ) - ) val requests = Channel() - val responses = ClientCalls.bidiStreamingRpc( - channel = channel, - requests = requests.consumeAsFlow(), - method = bidiStreamingSayHelloMethod - ).produceUnbuffered(this) + val responses = + ClientCalls.bidiStreamingRpc( + channel = channel, + requests = requests.consumeAsFlow(), + method = bidiStreamingSayHelloMethod + ) + .produceUnbuffered(this) requests.send(helloRequest("Garnet")) requests.send(helloRequest("Amethyst")) requests.send(helloRequest("Pearl")) @@ -109,7 +112,7 @@ class FlowControlTest : AbstractCallsTest() { @FlowPreview @Test fun bidiPingPongFlowControlServerDrawsMultipleRequests() = runBlocking { - fun Flow.pairOff(): Flow> = flow { + fun Flow.pairOff(): Flow> = flow { var odd: T? = null collect { val o = odd @@ -122,21 +125,24 @@ class FlowControlTest : AbstractCallsTest() { } } - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context = context, - descriptor = bidiStreamingSayHelloMethod, - implementation = { requests -> - requests.pairOff().map { (a, b) -> helloReply("Hello, ${a.name} and ${b.name}") } - } + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition( + context = context, + descriptor = bidiStreamingSayHelloMethod, + implementation = { requests -> + requests.pairOff().map { (a, b) -> helloReply("Hello, ${a.name} and ${b.name}") } + } + ) ) - ) val requests = Channel() - val responses = ClientCalls.bidiStreamingRpc( - channel = channel, - requests = requests.consumeAsFlow(), - method = bidiStreamingSayHelloMethod - ).produceUnbuffered(this) + val responses = + ClientCalls.bidiStreamingRpc( + channel = channel, + requests = requests.consumeAsFlow(), + method = bidiStreamingSayHelloMethod + ) + .produceUnbuffered(this) requests.send(helloRequest("Garnet")) requests.send(helloRequest("Amethyst")) requests.send(helloRequest("Pearl")) @@ -148,33 +154,38 @@ class FlowControlTest : AbstractCallsTest() { fourth.join() // pulling one element allows the cycle to advance requests.send(helloRequest("Rainbow 2.0")) requests.close() - assertThat(responses.toList()).containsExactly( - helloReply("Hello, Pearl and Steven"), helloReply("Hello, Onion and Rainbow 2.0") - ) + assertThat(responses.toList()) + .containsExactly( + helloReply("Hello, Pearl and Steven"), + helloReply("Hello, Onion and Rainbow 2.0") + ) } @ExperimentalCoroutinesApi // transform @FlowPreview @Test fun bidiPingPongFlowControlServerSendsMultipleResponses() = runBlocking { - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context = context, - descriptor = bidiStreamingSayHelloMethod, - implementation = { requests -> - requests.transform { - emit(helloReply("Hello, ${it.name}")) - emit(helloReply("Goodbye, ${it.name}")) + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition( + context = context, + descriptor = bidiStreamingSayHelloMethod, + implementation = { requests -> + requests.transform { + emit(helloReply("Hello, ${it.name}")) + emit(helloReply("Goodbye, ${it.name}")) + } } - } + ) ) - ) val requests = Channel() - val responses = ClientCalls.bidiStreamingRpc( - channel = channel, - requests = requests.consumeAsFlow(), - method = bidiStreamingSayHelloMethod - ).produceUnbuffered(this) + val responses = + ClientCalls.bidiStreamingRpc( + channel = channel, + requests = requests.consumeAsFlow(), + method = bidiStreamingSayHelloMethod + ) + .produceUnbuffered(this) requests.send(helloRequest("Garnet")) val second = launch { requests.send(helloRequest("Pearl")) } delay(200) // wait for everything to work its way through the system diff --git a/stub/src/test/java/io/grpc/kotlin/GeneratedCodeTest.kt b/stub/src/test/java/io/grpc/kotlin/GeneratedCodeTest.kt index 9db08ccd..a84e656f 100644 --- a/stub/src/test/java/io/grpc/kotlin/GeneratedCodeTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/GeneratedCodeTest.kt @@ -18,6 +18,7 @@ package io.grpc.kotlin import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import io.grpc.Metadata as GrpcMetadata import io.grpc.ServerCall import io.grpc.ServerCallHandler import io.grpc.ServerInterceptor @@ -30,7 +31,9 @@ import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.examples.helloworld.MultiHelloRequest +import java.util.concurrent.TimeUnit import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.Job @@ -51,9 +54,6 @@ import kotlinx.coroutines.launch import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -import java.util.concurrent.TimeUnit -import kotlinx.coroutines.CompletableDeferred -import io.grpc.Metadata as GrpcMetadata @RunWith(JUnit4::class) class GeneratedCodeTest : AbstractCallsTest() { @@ -64,19 +64,17 @@ class GeneratedCodeTest : AbstractCallsTest() { @Test fun simpleUnary() { - val server = object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - return HelloReply.newBuilder() - .setMessage("Hello, ${request.name}!" ) - .build() + val server = + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + return HelloReply.newBuilder().setMessage("Hello, ${request.name}!").build() + } } - } val channel = makeChannel(server) val stub = GreeterCoroutineStub(channel) runBlocking { - assertThat(stub.sayHello(helloRequest("Steven"))) - .isEqualTo(helloReply("Hello, Steven!")) + assertThat(stub.sayHello(helloRequest("Steven"))).isEqualTo(helloReply("Hello, Steven!")) } } @@ -84,19 +82,18 @@ class GeneratedCodeTest : AbstractCallsTest() { fun unaryServerDoesNotRespondGrpcTimeout() = runBlocking { val serverCancelled = Job() - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - suspendUntilCancelled { - serverCancelled.complete() + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + suspendUntilCancelled { serverCancelled.complete() } + } } - } - }) + ) val stub = GreeterCoroutineStub(channel).withDeadlineAfter(100, TimeUnit.MILLISECONDS) - val ex = assertThrows { - stub.sayHello(helloRequest("Topaz")) - } + val ex = assertThrows { stub.sayHello(helloRequest("Topaz")) } assertThat(ex.status.code).isEqualTo(Status.Code.DEADLINE_EXCEEDED) serverCancelled.join() } @@ -105,14 +102,15 @@ class GeneratedCodeTest : AbstractCallsTest() { fun unaryClientCancellation() { val helloReceived = Job() val helloCancelled = Job() - val helloChannel = makeChannel(object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - helloReceived.complete() - suspendUntilCancelled { - helloCancelled.complete() + val helloChannel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + helloReceived.complete() + suspendUntilCancelled { helloCancelled.complete() } + } } - } - }) + ) val helloStub = GreeterCoroutineStub(helloChannel) runBlocking { @@ -128,18 +126,17 @@ class GeneratedCodeTest : AbstractCallsTest() { @Test fun unaryMethodThrowsStatusException() = runBlocking { - val channel = makeChannel( - object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - throw StatusException(Status.PERMISSION_DENIED) + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + throw StatusException(Status.PERMISSION_DENIED) + } } - } - ) + ) val stub = GreeterCoroutineStub(channel) - val ex = assertThrows { - stub.sayHello(helloRequest("Peridot")) - } + val ex = assertThrows { stub.sayHello(helloRequest("Peridot")) } assertThat(ex.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) } @@ -147,27 +144,27 @@ class GeneratedCodeTest : AbstractCallsTest() { fun metadataPassedThrough() = runBlocking { val key = GrpcMetadata.Key.of("key", GrpcMetadata.ASCII_STRING_MARSHALLER) - val server = object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - return HelloReply.newBuilder() - .setMessage("Hello, ${request.name}!" ) - .build() + val server = + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + return HelloReply.newBuilder().setMessage("Hello, ${request.name}!").build() + } } - } val receivedMetadata = CompletableDeferred() - val channel = makeChannel( - server, - object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - headers: GrpcMetadata, - next: ServerCallHandler - ): ServerCall.Listener { - receivedMetadata.complete(headers[key]!!) - return next.startCall(call, headers) + val channel = + makeChannel( + server, + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + headers: GrpcMetadata, + next: ServerCallHandler + ): ServerCall.Listener { + receivedMetadata.complete(headers[key]!!) + return next.startCall(call, headers) + } } - } - ) + ) val stub = GreeterCoroutineStub(channel) val meta = GrpcMetadata() meta.put(key, "Pink Diamond") @@ -177,38 +174,37 @@ class GeneratedCodeTest : AbstractCallsTest() { @Test fun unaryMethodThrowsException() = runBlocking { - val channel = makeChannel( - object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - throw IllegalArgumentException() + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + throw IllegalArgumentException() + } } - } - ) + ) val stub = GreeterCoroutineStub(channel) - val ex = assertThrows { - stub.sayHello(helloRequest("Peridot")) - } + val ex = assertThrows { stub.sayHello(helloRequest("Peridot")) } assertThat(ex.status.code).isEqualTo(Status.Code.UNKNOWN) } @Test fun simpleClientStreamingRpc() = runBlocking { - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override suspend fun clientStreamSayHello(requests: Flow): HelloReply { - return HelloReply.newBuilder() - .setMessage( - requests.toList() - .joinToString(prefix = "Hello, ", separator = ", ") { it.name } - ).build() - } - }) + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun clientStreamSayHello(requests: Flow): HelloReply { + return HelloReply.newBuilder() + .setMessage( + requests.toList().joinToString(prefix = "Hello, ", separator = ", ") { it.name } + ) + .build() + } + } + ) val stub = GreeterCoroutineStub(channel) - val requests = flowOf( - helloRequest("Peridot"), - helloRequest("Lapis") - ) + val requests = flowOf(helloRequest("Peridot"), helloRequest("Lapis")) val response = async { stub.clientStreamSayHello(requests) } assertThat(response.await()).isEqualTo(helloReply("Hello, Peridot, Lapis")) } @@ -218,56 +214,59 @@ class GeneratedCodeTest : AbstractCallsTest() { fun clientStreamingRpcCancellation() = runBlocking { val serverReceived = Job() val serverCancelled = Job() - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override suspend fun clientStreamSayHello(requests: Flow): HelloReply { - requests.collect { - serverReceived.complete() - suspendUntilCancelled { serverCancelled.complete() } + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun clientStreamSayHello(requests: Flow): HelloReply { + requests.collect { + serverReceived.complete() + suspendUntilCancelled { serverCancelled.complete() } + } + throw AssertionError("unreachable") + } } - throw AssertionError("unreachable") - } - }) + ) val stub = GreeterCoroutineStub(channel) val requests = Channel() - val response = async { - stub.clientStreamSayHello(requests.consumeAsFlow()) - } + val response = async { stub.clientStreamSayHello(requests.consumeAsFlow()) } requests.send(helloRequest("Aquamarine")) serverReceived.join() response.cancel() serverCancelled.join() - assertThrows { - requests.send(helloRequest("John")) - } + assertThrows { requests.send(helloRequest("John")) } } @Test fun clientStreamingRpcThrowsStatusException() = runBlocking { - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override suspend fun clientStreamSayHello(requests: Flow): HelloReply { - throw StatusException(Status.PERMISSION_DENIED) - } - }) + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override suspend fun clientStreamSayHello(requests: Flow): HelloReply { + throw StatusException(Status.PERMISSION_DENIED) + } + } + ) val stub = GreeterCoroutineStub(channel) - val ex = assertThrows { - stub.clientStreamSayHello(flowOf()) - } + val ex = assertThrows { stub.clientStreamSayHello(flowOf()) } assertThat(ex.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) } @Test fun simpleServerStreamingRpc() = runBlocking { - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override fun serverStreamSayHello(request: MultiHelloRequest): Flow { - return request.nameList.asFlow().map { helloReply("Hello, $it") } - } - }) + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override fun serverStreamSayHello(request: MultiHelloRequest): Flow { + return request.nameList.asFlow().map { helloReply("Hello, $it") } + } + } + ) - val responses = GreeterCoroutineStub(channel).serverStreamSayHello( - multiHelloRequest("Garnet", "Amethyst", "Pearl") - ) + val responses = + GreeterCoroutineStub(channel) + .serverStreamSayHello(multiHelloRequest("Garnet", "Amethyst", "Pearl")) assertThat(responses.toList()) .containsExactly( @@ -284,20 +283,22 @@ class GeneratedCodeTest : AbstractCallsTest() { val serverCancelled = Job() val serverReceived = Job() - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override fun serverStreamSayHello(request: MultiHelloRequest): Flow { - return flow { - serverReceived.complete() - suspendUntilCancelled { - serverCancelled.complete() + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override fun serverStreamSayHello(request: MultiHelloRequest): Flow { + return flow { + serverReceived.complete() + suspendUntilCancelled { serverCancelled.complete() } + } } } - } - }) + ) - val response = GreeterCoroutineStub(channel).serverStreamSayHello( - multiHelloRequest("Topaz", "Aquamarine") - ).produceIn(this) + val response = + GreeterCoroutineStub(channel) + .serverStreamSayHello(multiHelloRequest("Topaz", "Aquamarine")) + .produceIn(this) serverReceived.join() response.cancel() serverCancelled.join() @@ -306,14 +307,18 @@ class GeneratedCodeTest : AbstractCallsTest() { @FlowPreview @Test fun bidiPingPong() = runBlocking { - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override fun bidiStreamSayHello(requests: Flow): Flow { - return requests.map { helloReply("Hello, ${it.name}") } - } - }) + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override fun bidiStreamSayHello(requests: Flow): Flow { + return requests.map { helloReply("Hello, ${it.name}") } + } + } + ) val requests = Channel() - val responses = GreeterCoroutineStub(channel).bidiStreamSayHello(requests.consumeAsFlow()).produceIn(this) + val responses = + GreeterCoroutineStub(channel).bidiStreamSayHello(requests.consumeAsFlow()).produceIn(this) requests.send(helloRequest("Steven")) assertThat(responses.receive()).isEqualTo(helloReply("Hello, Steven")) @@ -327,11 +332,14 @@ class GeneratedCodeTest : AbstractCallsTest() { @FlowPreview @Test fun bidiStreamingRpcReturnsEarly() = runBlocking { - val channel = makeChannel(object : GreeterCoroutineImplBase() { - override fun bidiStreamSayHello(requests: Flow): Flow { - return requests.take(2).map { helloReply("Hello, ${it.name}") } - } - }) + val channel = + makeChannel( + object : GreeterCoroutineImplBase() { + override fun bidiStreamSayHello(requests: Flow): Flow { + return requests.take(2).map { helloReply("Hello, ${it.name}") } + } + } + ) val stub = GreeterCoroutineStub(channel) val requests = Channel() @@ -350,20 +358,19 @@ class GeneratedCodeTest : AbstractCallsTest() { fun serverScopeCancelledDuringRpc() = runBlocking { val serverJob = Job() val serverReceived = Job() - val channel = makeChannel( - object : GreeterCoroutineImplBase(serverJob) { - override suspend fun sayHello(request: HelloRequest): HelloReply { - serverReceived.complete() - suspendUntilCancelled { /* do nothing */ } + val channel = + makeChannel( + object : GreeterCoroutineImplBase(serverJob) { + override suspend fun sayHello(request: HelloRequest): HelloReply { + serverReceived.complete() + suspendUntilCancelled { /* do nothing */} + } } - } - ) + ) val stub = GreeterCoroutineStub(channel) val test = launch { - val ex = assertThrows { - stub.sayHello(helloRequest("Greg")) - } + val ex = assertThrows { stub.sayHello(helloRequest("Greg")) } assertThat(ex.status.code).isEqualTo(Status.Code.CANCELLED) } serverReceived.join() diff --git a/stub/src/test/java/io/grpc/kotlin/GrpcContextElementTest.kt b/stub/src/test/java/io/grpc/kotlin/GrpcContextElementTest.kt index 29156641..1e498174 100644 --- a/stub/src/test/java/io/grpc/kotlin/GrpcContextElementTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/GrpcContextElementTest.kt @@ -18,12 +18,12 @@ package io.grpc.kotlin import com.google.common.truth.Truth.assertThat import io.grpc.Context +import java.util.concurrent.Executors import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.runBlocking import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -import java.util.concurrent.Executors @RunWith(JUnit4::class) class GrpcContextElementTest { @@ -33,7 +33,8 @@ class GrpcContextElementTest { fun testContextPropagation() { val testGrpcContext = Context.ROOT.withValue(testKey, "testValue") val coroutineContext = - Executors.newSingleThreadExecutor().asCoroutineDispatcher() + GrpcContextElement(testGrpcContext) + Executors.newSingleThreadExecutor().asCoroutineDispatcher() + + GrpcContextElement(testGrpcContext) runBlocking(coroutineContext) { val currentTestKey = testKey.get() // gets from the implicit current gRPC context diff --git a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt index 9ed61a66..15f643e9 100644 --- a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt @@ -20,10 +20,13 @@ import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import io.grpc.* import io.grpc.examples.helloworld.GreeterGrpc +import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase +import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest -import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub -import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase +import java.util.concurrent.Executors +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -32,8 +35,8 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.async import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.channels.toList +import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.delay import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.buffer @@ -51,9 +54,6 @@ import kotlinx.coroutines.withContext import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -import java.util.concurrent.Executors -import kotlin.coroutines.CoroutineContext -import kotlin.coroutines.EmptyCoroutineContext @ExperimentalCoroutinesApi @RunWith(JUnit4::class) @@ -62,11 +62,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun simpleUnaryMethod() = runBlocking { - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { request -> - helloReply("Hello, ${request.name}") - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { request -> + helloReply("Hello, ${request.name}") + } + ) val stub = GreeterGrpc.newBlockingStub(channel) assertThat(stub.sayHello(helloRequest("Steven"))).isEqualTo(helloReply("Hello, Steven")) @@ -77,12 +78,13 @@ class ServerCallsTest : AbstractCallsTest() { fun unaryMethodCancellationPropagatedToServer() = runBlocking { val requestReceived = Job() val cancelled = Job() - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - requestReceived.complete() - suspendUntilCancelled { cancelled.complete() } - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + requestReceived.complete() + suspendUntilCancelled { cancelled.complete() } + } + ) val stub = GreeterGrpc.newFutureStub(channel) val future = stub.sayHello(helloRequest("Garnet")) @@ -95,15 +97,14 @@ class ServerCallsTest : AbstractCallsTest() { fun unaryMethodCancellationContextWithJobPropagatedToServer() = runBlocking { val completable = CompletableDeferred() val requestReceived = Job() - val channel = makeChannel( - // Note that we use runBlocking's context here - ServerCalls.unaryServerMethodDefinition(coroutineContext, sayHelloMethod) { - requestReceived.complete() - suspendUntilCancelled { - completable.complete(42) + val channel = + makeChannel( + // Note that we use runBlocking's context here + ServerCalls.unaryServerMethodDefinition(coroutineContext, sayHelloMethod) { + requestReceived.complete() + suspendUntilCancelled { completable.complete(42) } } - } - ) + ) val stub = GreeterGrpc.newFutureStub(channel) val future = stub.sayHello(helloRequest("Garnet")) @@ -115,25 +116,29 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun unaryRequestHandledWithoutWaitingForHalfClose() = runBlocking { val processingStarted = Job() - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - processingStarted.complete() - helloReply("Hello!") - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + processingStarted.complete() + helloReply("Hello!") + } + ) val clientCall = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) val response = CompletableDeferred() val closeStatus = CompletableDeferred() - clientCall.start(object: ClientCall.Listener() { - override fun onMessage(message: HelloReply) { - response.complete(message) - } + clientCall.start( + object : ClientCall.Listener() { + override fun onMessage(message: HelloReply) { + response.complete(message) + } - override fun onClose(status: Status, trailers: Metadata) { - closeStatus.complete(status) - } - }, Metadata()) + override fun onClose(status: Status, trailers: Metadata) { + closeStatus.complete(status) + } + }, + Metadata() + ) clientCall.sendMessage(helloRequest("")) clientCall.request(1) processingStarted.join() @@ -146,11 +151,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun unaryMethodReceivedTooManyRequests() = runBlocking { - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - helloReply("Hello, ${it.name}") - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + helloReply("Hello, ${it.name}") + } + ) val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -174,13 +180,14 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun unaryMethodFailedWithStatusWithTrailers() = runBlocking { val key: Metadata.Key = Metadata.Key.of("key", Metadata.ASCII_STRING_MARSHALLER) - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - val trailers = Metadata() - trailers.put(key, "value") - throw StatusException(Status.DATA_LOSS, trailers) - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + val trailers = Metadata() + trailers.put(key, "value") + throw StatusException(Status.DATA_LOSS, trailers) + } + ) val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) val closeTrailers = CompletableDeferred() @@ -203,11 +210,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun unaryMethodReceivedNoRequests() = runBlocking { - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - helloReply("Hello, ${it.name}") - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + helloReply("Hello, ${it.name}") + } + ) val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -228,16 +236,15 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun unaryMethodThrowsStatusException() = runBlocking { - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - throw StatusException(Status.OUT_OF_RANGE) - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + throw StatusException(Status.OUT_OF_RANGE) + } + ) val stub = GreeterGrpc.newBlockingStub(channel) - val ex = assertThrows { - stub.sayHello(helloRequest("Peridot")) - } + val ex = assertThrows { stub.sayHello(helloRequest("Peridot")) } assertThat(ex.status.code).isEqualTo(Status.Code.OUT_OF_RANGE) } @@ -245,55 +252,53 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun unaryMethodThrowsException() = runBlocking { - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - throw MyException() - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { throw MyException() } + ) val stub = GreeterGrpc.newBlockingStub(channel) - val ex = assertThrows { - stub.sayHello(helloRequest("Lapis Lazuli")) - } + val ex = assertThrows { stub.sayHello(helloRequest("Lapis Lazuli")) } assertThat(ex.status.code).isEqualTo(Status.Code.UNKNOWN) } @Test fun simpleServerStreaming() = runBlocking { - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { - it.nameList.asFlow().map { helloReply("Hello, $it") } - } - ) + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + it.nameList.asFlow().map { helloReply("Hello, $it") } + } + ) - val responses = ClientCalls.serverStreamingRpc( - channel, - serverStreamingSayHelloMethod, - multiHelloRequest("Garnet", "Amethyst", "Pearl") - ) + val responses = + ClientCalls.serverStreamingRpc( + channel, + serverStreamingSayHelloMethod, + multiHelloRequest("Garnet", "Amethyst", "Pearl") + ) assertThat(responses.toList()) .containsExactly( helloReply("Hello, Garnet"), helloReply("Hello, Amethyst"), helloReply("Hello, Pearl") - ).inOrder() + ) + .inOrder() } @Test fun serverStreamingCancellationPropagatedToServer() = runBlocking { val requestReceived = Job() val cancelled = Job() - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition( - context, - serverStreamingSayHelloMethod - ) { - flow { - requestReceived.complete() - suspendUntilCancelled { cancelled.complete() } + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + flow { + requestReceived.complete() + suspendUntilCancelled { cancelled.complete() } + } } - } - ) + ) val call = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -315,12 +320,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun serverStreamingThrowsStatusException() = runBlocking { - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition( - context, - serverStreamingSayHelloMethod - ) { flow { throw StatusException(Status.OUT_OF_RANGE) } } - ) + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + flow { throw StatusException(Status.OUT_OF_RANGE) } + } + ) val call = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -342,28 +347,33 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun serverStreamingHandledWithoutWaitingForHalfClose() = runBlocking { val processingStarted = Job() - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { - request -> flow { - processingStarted.complete() - for (name in request.nameList) { - emit(helloReply("Hello, $name")) + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + request -> + flow { + processingStarted.complete() + for (name in request.nameList) { + emit(helloReply("Hello, $name")) + } } } - } - ) + ) val clientCall = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT) val responseChannel = Channel() - clientCall.start(object: ClientCall.Listener() { - override fun onMessage(message: HelloReply) { - responseChannel.trySendBlocking(message) - } + clientCall.start( + object : ClientCall.Listener() { + override fun onMessage(message: HelloReply) { + responseChannel.trySendBlocking(message) + } - override fun onClose(status: Status, trailers: Metadata) { - responseChannel.close() - } - }, Metadata()) + override fun onClose(status: Status, trailers: Metadata) { + responseChannel.close() + } + }, + Metadata() + ) clientCall.sendMessage(multiHelloRequest("Ruby", "Sapphire")) clientCall.request(2) processingStarted.join() @@ -377,12 +387,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun serverStreamingThrowsException() = runBlocking { - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition( - context, - serverStreamingSayHelloMethod - ) { throw MyException() } - ) + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + throw MyException() + } + ) val call = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -404,54 +414,44 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun simpleClientStreaming() = runBlocking { - val channel = makeChannel( - ServerCalls.clientStreamingServerMethodDefinition( - context, - clientStreamingSayHelloMethod - ) { requests -> - helloReply(requests.toList().joinToString(separator = ", ", prefix = "Hello, ") { it.name }) - } - ) + val channel = + makeChannel( + ServerCalls.clientStreamingServerMethodDefinition(context, clientStreamingSayHelloMethod) { + requests -> + helloReply( + requests.toList().joinToString(separator = ", ", prefix = "Hello, ") { it.name } + ) + } + ) - val requestChannel = flowOf( - helloRequest("Ruby"), - helloRequest("Sapphire") - ) + val requestChannel = flowOf(helloRequest("Ruby"), helloRequest("Sapphire")) assertThat( - ClientCalls.clientStreamingRpc( - channel, - clientStreamingSayHelloMethod, - requestChannel + ClientCalls.clientStreamingRpc(channel, clientStreamingSayHelloMethod, requestChannel) ) - ).isEqualTo(helloReply("Hello, Ruby, Sapphire")) + .isEqualTo(helloReply("Hello, Ruby, Sapphire")) } @ExperimentalCoroutinesApi // take @Test fun clientStreamingDoesntWaitForAllRequests() = runBlocking { - val channel = makeChannel( - ServerCalls.clientStreamingServerMethodDefinition( - context, - clientStreamingSayHelloMethod - ) { requests -> - val (req1, req2) = requests.take(2).toList() - helloReply("Hello, ${req1.name} and ${req2.name}") - } - ) + val channel = + makeChannel( + ServerCalls.clientStreamingServerMethodDefinition(context, clientStreamingSayHelloMethod) { + requests -> + val (req1, req2) = requests.take(2).toList() + helloReply("Hello, ${req1.name} and ${req2.name}") + } + ) - val requests = flowOf( - helloRequest("Peridot"), - helloRequest("Lapis"), - helloRequest("Jasper"), - helloRequest("Aquamarine") - ) - assertThat( - ClientCalls.clientStreamingRpc( - channel, - clientStreamingSayHelloMethod, - requests + val requests = + flowOf( + helloRequest("Peridot"), + helloRequest("Lapis"), + helloRequest("Jasper"), + helloRequest("Aquamarine") ) - ).isEqualTo(helloReply("Hello, Peridot and Lapis")) + assertThat(ClientCalls.clientStreamingRpc(channel, clientStreamingSayHelloMethod, requests)) + .isEqualTo(helloReply("Hello, Peridot and Lapis")) } @ExperimentalCoroutinesApi // take @@ -459,16 +459,15 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun clientStreamingWhenRequestsCancelledNoBackpressure() = runBlocking { val barrier = Job() - val channel = makeChannel( - ServerCalls.clientStreamingServerMethodDefinition( - context, - clientStreamingSayHelloMethod - ) { requests -> - val (req1, req2) = requests.take(2).toList() - barrier.join() - helloReply("Hello, ${req1.name} and ${req2.name}") - } - ) + val channel = + makeChannel( + ServerCalls.clientStreamingServerMethodDefinition(context, clientStreamingSayHelloMethod) { + requests -> + val (req1, req2) = requests.take(2).toList() + barrier.join() + helloReply("Hello, ${req1.name} and ${req2.name}") + } + ) val requestChannel = Channel() val response = async { @@ -491,18 +490,16 @@ class ServerCallsTest : AbstractCallsTest() { fun clientStreamingCancellationPropagatedToServer() = runBlocking { val requestReceived = Job() val cancelled = Job() - val channel = makeChannel( - ServerCalls.clientStreamingServerMethodDefinition( - context, - clientStreamingSayHelloMethod - ) { - it.collect { - requestReceived.complete() - suspendUntilCancelled { cancelled.complete() } + val channel = + makeChannel( + ServerCalls.clientStreamingServerMethodDefinition(context, clientStreamingSayHelloMethod) { + it.collect { + requestReceived.complete() + suspendUntilCancelled { cancelled.complete() } + } + helloReply("Impossible?") } - helloReply("Impossible?") - } - ) + ) val call = channel.newCall(clientStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -523,12 +520,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun clientStreamingThrowsStatusException() = runBlocking { - val channel = makeChannel( - ServerCalls.clientStreamingServerMethodDefinition( - context, - clientStreamingSayHelloMethod - ) { throw StatusException(Status.INVALID_ARGUMENT) } - ) + val channel = + makeChannel( + ServerCalls.clientStreamingServerMethodDefinition(context, clientStreamingSayHelloMethod) { + throw StatusException(Status.INVALID_ARGUMENT) + } + ) val call = channel.newCall(clientStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -546,14 +543,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun clientStreamingThrowsException() = runBlocking { - val channel = makeChannel( - ServerCalls.clientStreamingServerMethodDefinition( - context, - clientStreamingSayHelloMethod - ) { - throw MyException() - } - ) + val channel = + makeChannel( + ServerCalls.clientStreamingServerMethodDefinition(context, clientStreamingSayHelloMethod) { + throw MyException() + } + ) val call = channel.newCall(clientStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -573,11 +568,15 @@ class ServerCallsTest : AbstractCallsTest() { @FlowPreview @Test fun simpleBidiStreamingPingPong() = runBlocking { - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition(context, bidiStreamingSayHelloMethod) { - requests -> requests.map { helloReply("Hello, ${it.name}") }.onCompletion { emit(helloReply("Goodbye")) } - } - ) + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition(context, bidiStreamingSayHelloMethod) { + requests -> + requests + .map { helloReply("Hello, ${it.name}") } + .onCompletion { emit(helloReply("Goodbye")) } + } + ) val requests = Channel() val responses = @@ -597,19 +596,18 @@ class ServerCallsTest : AbstractCallsTest() { fun bidiStreamingCancellationPropagatedToServer() = runBlocking { val requestReceived = Job() val cancelled = Job() - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context, - bidiStreamingSayHelloMethod - ) { requests -> - flow { - requests.collect { - requestReceived.complete() - suspendUntilCancelled { cancelled.complete() } + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition(context, bidiStreamingSayHelloMethod) { + requests -> + flow { + requests.collect { + requestReceived.complete() + suspendUntilCancelled { cancelled.complete() } + } } } - } - ) + ) val call = channel.newCall(bidiStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -630,12 +628,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun bidiStreamingThrowsStatusException() = runBlocking { - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context, - bidiStreamingSayHelloMethod - ) { flow { throw StatusException(Status.INVALID_ARGUMENT) } } - ) + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition(context, bidiStreamingSayHelloMethod) { + flow { throw StatusException(Status.INVALID_ARGUMENT) } + } + ) val call = channel.newCall(bidiStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -653,12 +651,12 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun bidiStreamingThrowsException() = runBlocking { - val channel = makeChannel( - ServerCalls.bidiStreamingServerMethodDefinition( - context, - bidiStreamingSayHelloMethod - ) { throw MyException() } - ) + val channel = + makeChannel( + ServerCalls.bidiStreamingServerMethodDefinition(context, bidiStreamingSayHelloMethod) { + throw MyException() + } + ) val call = channel.newCall(bidiStreamingSayHelloMethod, CallOptions.DEFAULT) val closeStatus = CompletableDeferred() @@ -684,24 +682,21 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun rejectNonClientStreamingMethod() = runBlocking { assertThrows { - ServerCalls - .clientStreamingServerMethodDefinition(context, sayHelloMethod) { TODO() } + ServerCalls.clientStreamingServerMethodDefinition(context, sayHelloMethod) { TODO() } } } @Test fun rejectNonServerStreamingMethod() = runBlocking { assertThrows { - ServerCalls - .serverStreamingServerMethodDefinition(context, sayHelloMethod) { TODO() } + ServerCalls.serverStreamingServerMethodDefinition(context, sayHelloMethod) { TODO() } } } @Test fun rejectNonBidiStreamingMethod() = runBlocking { assertThrows { - ServerCalls - .bidiStreamingServerMethodDefinition(context, sayHelloMethod) { TODO() } + ServerCalls.bidiStreamingServerMethodDefinition(context, sayHelloMethod) { TODO() } } } @@ -712,32 +707,29 @@ class ServerCallsTest : AbstractCallsTest() { val contextKey = Context.key("testKey") val contextToInject = Context.ROOT.withValue(contextKey, "testValue") - val interceptor = object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - headers: Metadata, - next: ServerCallHandler - ): ServerCall.Listener { - return Contexts.interceptCall( - contextToInject, - call, - headers, - next - ) + val interceptor = + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + headers: Metadata, + next: ServerCallHandler + ): ServerCall.Listener { + return Contexts.interceptCall(contextToInject, call, headers, next) + } } - } - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - withContext(differentThreadContext) { - // Run this in a definitely different thread, just to verify context propagation - // is WAI. - assertThat(contextKey.get(Context.current())).isEqualTo("testValue") - helloReply("Hello, ${it.name}") - } - }, - interceptor - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + withContext(differentThreadContext) { + // Run this in a definitely different thread, just to verify context propagation + // is WAI. + assertThat(contextKey.get(Context.current())).isEqualTo("testValue") + helloReply("Hello, ${it.name}") + } + }, + interceptor + ) val stub = GreeterGrpc.newBlockingStub(channel) assertThat(stub.sayHello(helloRequest("Peridot"))).isEqualTo(helloReply("Hello, Peridot")) @@ -749,33 +741,31 @@ class ServerCallsTest : AbstractCallsTest() { fun serverStreamingFlowControl() = runBlocking { val receiveFirstMessage = Job() val receivedFirstMessage = Job() - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition( - EmptyCoroutineContext, - serverStreamingSayHelloMethod - ) { - channelFlow { - send(helloReply("1st")) - send(helloReply("2nd")) - val thirdSend = launch { - send(helloReply("3rd")) - } - delay(200) - assertThat(thirdSend.isCompleted).isFalse() - receiveFirstMessage.complete() - receivedFirstMessage.join() - thirdSend.join() - }.buffer(Channel.RENDEZVOUS) - } - ) + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition( + EmptyCoroutineContext, + serverStreamingSayHelloMethod + ) { + channelFlow { + send(helloReply("1st")) + send(helloReply("2nd")) + val thirdSend = launch { send(helloReply("3rd")) } + delay(200) + assertThat(thirdSend.isCompleted).isFalse() + receiveFirstMessage.complete() + receivedFirstMessage.join() + thirdSend.join() + } + .buffer(Channel.RENDEZVOUS) + } + ) - val responses = produce { - ClientCalls.serverStreamingRpc( - channel, - serverStreamingSayHelloMethod, - multiHelloRequest() - ).collect { send(it) } - } + val responses = + produce { + ClientCalls.serverStreamingRpc(channel, serverStreamingSayHelloMethod, multiHelloRequest()) + .collect { send(it) } + } receiveFirstMessage.join() assertThat(responses.receive()).isEqualTo(helloReply("1st")) receivedFirstMessage.complete() @@ -785,44 +775,42 @@ class ServerCallsTest : AbstractCallsTest() { @Test fun contextPreservation() = runBlocking { val contextKey = Context.key("foo") - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition( - context, - sayHelloMethod - ) { - assertThat(contextKey.get()).isEqualTo("bar") - helloReply("Hello!") - }, - object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - headers: Metadata, - next: ServerCallHandler - ): ServerCall.Listener = - Contexts.interceptCall( - Context.current().withValue(contextKey, "bar"), - call, - headers, - next - ) - } - ) - assertThat( - ClientCalls.unaryRpc(channel, sayHelloMethod, helloRequest("")) - ).isEqualTo(helloReply("Hello!")) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + assertThat(contextKey.get()).isEqualTo("bar") + helloReply("Hello!") + }, + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + headers: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + Contexts.interceptCall( + Context.current().withValue(contextKey, "bar"), + call, + headers, + next + ) + } + ) + assertThat(ClientCalls.unaryRpc(channel, sayHelloMethod, helloRequest(""))) + .isEqualTo(helloReply("Hello!")) } @Test fun serverCallListenerDefersHeaders() = runBlocking { val requestReceived = Job() val responseReleased = Job() - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - requestReceived.complete() - responseReleased.join() - helloReply("Hello, ${it.name}") - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + requestReceived.complete() + responseReleased.join() + helloReply("Hello, ${it.name}") + } + ) val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) @@ -865,13 +853,14 @@ class ServerCallsTest : AbstractCallsTest() { fun serverCallListenerDefersHeadersOnException() = runBlocking { val requestReceived = Job() val responseReleased = Job() - val channel = makeChannel( - ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { - requestReceived.complete() - responseReleased.join() - throw StatusException(Status.INTERNAL.withDescription("no response frames")) - } - ) + val channel = + makeChannel( + ServerCalls.unaryServerMethodDefinition(context, sayHelloMethod) { + requestReceived.complete() + responseReleased.join() + throw StatusException(Status.INTERNAL.withDescription("no response frames")) + } + ) val call = channel.newCall(sayHelloMethod, CallOptions.DEFAULT) @@ -911,14 +900,15 @@ class ServerCallsTest : AbstractCallsTest() { fun serverCallListenerDefersHeadersOnEmptyStream() = runBlocking { val requestReceived = Job() val responseReleased = Job() - val channel = makeChannel( - ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { - flow { - requestReceived.complete() - responseReleased.join() + val channel = + makeChannel( + ServerCalls.serverStreamingServerMethodDefinition(context, serverStreamingSayHelloMethod) { + flow { + requestReceived.complete() + responseReleased.join() + } } - } - ) + ) val call = channel.newCall(serverStreamingSayHelloMethod, CallOptions.DEFAULT) @@ -957,15 +947,16 @@ class ServerCallsTest : AbstractCallsTest() { runBlocking { val retryCount = 5 val config = getRetryingServiceConfig(retryCount.toDouble()) - val coroutinesServer = object : GreeterCoroutineImplBase() { - var count = 0 - private set - - override suspend fun sayHello(request: HelloRequest): HelloReply { - count++ - throw StatusRuntimeException(Status.UNKNOWN) + val coroutinesServer = + object : GreeterCoroutineImplBase() { + var count = 0 + private set + + override suspend fun sayHello(request: HelloRequest): HelloReply { + count++ + throw StatusRuntimeException(Status.UNKNOWN) + } } - } val channel = makeChannel(coroutinesServer.bindService(), config) @@ -979,18 +970,11 @@ class ServerCallsTest : AbstractCallsTest() { } } - private fun getRetryingServiceConfig( - retryCount: Double - ): Map { + private fun getRetryingServiceConfig(retryCount: Double): Map { val config = hashMapOf() val name = mutableListOf>() - name.add( - mapOf( - "service" to "helloworld.Greeter", - "method" to "SayHello" - ) - ) + name.add(mapOf("service" to "helloworld.Greeter", "method" to "SayHello")) val retryPolicy = hashMapOf() retryPolicy["maxAttempts"] = retryCount @@ -1016,43 +1000,43 @@ class ServerCallsTest : AbstractCallsTest() { fun testPropagateStackTraceForStatusException() = runBlocking { val thrownStatusCause = CompletableDeferred() - val serverImpl = object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - internalServerCall() - } + val serverImpl = + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + internalServerCall() + } - private fun internalServerCall(): Nothing { - val exception = Exception("causal exception") - thrownStatusCause.complete(exception) - throw Status.INTERNAL.withCause(exception).asException() + private fun internalServerCall(): Nothing { + val exception = Exception("causal exception") + thrownStatusCause.complete(exception) + throw Status.INTERNAL.withCause(exception).asException() + } } - } val receivedStatusCause = CompletableDeferred() - val interceptor = object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - requestHeaders: Metadata, - next: ServerCallHandler - ): ServerCall.Listener = - next.startCall( - object : ForwardingServerCall.SimpleForwardingServerCall(call) { - override fun close(status: Status, trailers: Metadata) { - receivedStatusCause.complete(status.cause) - super.close(status, trailers) - } - }, - requestHeaders - ) - } + val interceptor = + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + requestHeaders: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + next.startCall( + object : ForwardingServerCall.SimpleForwardingServerCall(call) { + override fun close(status: Status, trailers: Metadata) { + receivedStatusCause.complete(status.cause) + super.close(status, trailers) + } + }, + requestHeaders + ) + } val channel = makeChannel(serverImpl, interceptor) val stub = GreeterGrpc.newBlockingStub(channel) - val clientException = assertThrows { - stub.sayHello(helloRequest("")) - } + val clientException = assertThrows { stub.sayHello(helloRequest("")) } // the exception should not propagate to the client assertThat(clientException.cause).isNull() @@ -1068,43 +1052,43 @@ class ServerCallsTest : AbstractCallsTest() { fun testPropagateStackTraceForStatusRuntimeException() = runBlocking { val thrownStatusCause = CompletableDeferred() - val serverImpl = object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - internalServerCall() - } + val serverImpl = + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + internalServerCall() + } - private fun internalServerCall(): Nothing { - val exception = Exception("causal exception") - thrownStatusCause.complete(exception) - throw Status.INTERNAL.withCause(exception).asRuntimeException() + private fun internalServerCall(): Nothing { + val exception = Exception("causal exception") + thrownStatusCause.complete(exception) + throw Status.INTERNAL.withCause(exception).asRuntimeException() + } } - } val receivedStatusCause = CompletableDeferred() - val interceptor = object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - requestHeaders: Metadata, - next: ServerCallHandler - ): ServerCall.Listener = - next.startCall( - object : ForwardingServerCall.SimpleForwardingServerCall(call) { - override fun close(status: Status, trailers: Metadata) { - receivedStatusCause.complete(status.cause) - super.close(status, trailers) - } - }, - requestHeaders - ) - } + val interceptor = + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + requestHeaders: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + next.startCall( + object : ForwardingServerCall.SimpleForwardingServerCall(call) { + override fun close(status: Status, trailers: Metadata) { + receivedStatusCause.complete(status.cause) + super.close(status, trailers) + } + }, + requestHeaders + ) + } val channel = makeChannel(serverImpl, interceptor) val stub = GreeterGrpc.newBlockingStub(channel) - val clientException = assertThrows { - stub.sayHello(helloRequest("")) - } + val clientException = assertThrows { stub.sayHello(helloRequest("")) } // the exception should not propagate to the client assertThat(clientException.cause).isNull() @@ -1120,43 +1104,43 @@ class ServerCallsTest : AbstractCallsTest() { fun testPropagateStackTraceForNonStatusException() = runBlocking { val thrownStatusCause = CompletableDeferred() - val serverImpl = object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - internalServerCall() - } + val serverImpl = + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + internalServerCall() + } - private fun internalServerCall(): Nothing { - val exception = Exception("causal exception") - thrownStatusCause.complete(exception) - throw exception + private fun internalServerCall(): Nothing { + val exception = Exception("causal exception") + thrownStatusCause.complete(exception) + throw exception + } } - } val receivedStatusCause = CompletableDeferred() - val interceptor = object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - requestHeaders: Metadata, - next: ServerCallHandler - ): ServerCall.Listener = - next.startCall( - object : ForwardingServerCall.SimpleForwardingServerCall(call) { - override fun close(status: Status, trailers: Metadata) { - receivedStatusCause.complete(status.cause) - super.close(status, trailers) - } - }, - requestHeaders - ) - } + val interceptor = + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + requestHeaders: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + next.startCall( + object : ForwardingServerCall.SimpleForwardingServerCall(call) { + override fun close(status: Status, trailers: Metadata) { + receivedStatusCause.complete(status.cause) + super.close(status, trailers) + } + }, + requestHeaders + ) + } val channel = makeChannel(serverImpl, interceptor) val stub = GreeterGrpc.newBlockingStub(channel) - val clientException = assertThrows { - stub.sayHello(helloRequest("")) - } + val clientException = assertThrows { stub.sayHello(helloRequest("")) } // the exception should not propagate to the client assertThat(clientException.cause).isNull() @@ -1172,43 +1156,43 @@ class ServerCallsTest : AbstractCallsTest() { fun testPropagateStackTraceForNonStatusExceptionWithStatusExceptionCause() = runBlocking { val thrownStatusCause = CompletableDeferred() - val serverImpl = object : GreeterCoroutineImplBase() { - override suspend fun sayHello(request: HelloRequest): HelloReply { - internalServerCall() - } + val serverImpl = + object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + internalServerCall() + } - private fun internalServerCall(): Nothing { - val exception = Exception("causal exception", Status.INTERNAL.asException()) - thrownStatusCause.complete(exception) - throw exception + private fun internalServerCall(): Nothing { + val exception = Exception("causal exception", Status.INTERNAL.asException()) + thrownStatusCause.complete(exception) + throw exception + } } - } val receivedStatusCause = CompletableDeferred() - val interceptor = object : ServerInterceptor { - override fun interceptCall( - call: ServerCall, - requestHeaders: Metadata, - next: ServerCallHandler - ): ServerCall.Listener = - next.startCall( - object : ForwardingServerCall.SimpleForwardingServerCall(call) { - override fun close(status: Status, trailers: Metadata) { - receivedStatusCause.complete(status.cause) - super.close(status, trailers) - } - }, - requestHeaders - ) - } + val interceptor = + object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + requestHeaders: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + next.startCall( + object : ForwardingServerCall.SimpleForwardingServerCall(call) { + override fun close(status: Status, trailers: Metadata) { + receivedStatusCause.complete(status.cause) + super.close(status, trailers) + } + }, + requestHeaders + ) + } val channel = makeChannel(serverImpl, interceptor) val stub = GreeterGrpc.newBlockingStub(channel) - val clientException = assertThrows { - stub.sayHello(helloRequest("")) - } + val clientException = assertThrows { stub.sayHello(helloRequest("")) } // the exception should not propagate to the client assertThat(clientException.cause).isNull() @@ -1219,5 +1203,4 @@ class ServerCallsTest : AbstractCallsTest() { assertThat(statusCause).isEqualTo(thrownStatusCause.await()) assertThat(statusCause!!.stackTraceToString()).contains("internalServerCall") } - } diff --git a/stub/src/test/proto/helloworld/helloworld.proto b/stub/src/test/proto/helloworld/helloworld.proto index 97c19f83..821abeeb 100644 --- a/stub/src/test/proto/helloworld/helloworld.proto +++ b/stub/src/test/proto/helloworld/helloworld.proto @@ -16,8 +16,8 @@ syntax = "proto3"; package helloworld; option java_multiple_files = true; -option java_package = "io.grpc.examples.helloworld"; option java_outer_classname = "HelloWorldProto"; +option java_package = "io.grpc.examples.helloworld"; // The greeting service definition. service Greeter {