Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit 0a69423

Browse files
fix: enable self signed jwt for grpc (#18)
PiperOrigin-RevId: 386504689 Source-Link: googleapis/googleapis@762094a Source-Link: https://github.com/googleapis/googleapis-gen/commit/6bfc480e1a161d5de121c2bcc3745885d33b265a
1 parent 70386c8 commit 0a69423

File tree

12 files changed

+132
-76
lines changed

12 files changed

+132
-76
lines changed

google/cloud/dataflow_v1beta3/services/flex_templates_service/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def __init__(
332332
client_cert_source_for_mtls=client_cert_source_func,
333333
quota_project_id=client_options.quota_project_id,
334334
client_info=client_info,
335+
always_use_jwt_access=(
336+
Transport == type(self).get_transport_class("grpc")
337+
or Transport == type(self).get_transport_class("grpc_asyncio")
338+
),
335339
)
336340

337341
def launch_flex_template(

google/cloud/dataflow_v1beta3/services/jobs_v1_beta3/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,10 @@ def __init__(
333333
client_cert_source_for_mtls=client_cert_source_func,
334334
quota_project_id=client_options.quota_project_id,
335335
client_info=client_info,
336+
always_use_jwt_access=(
337+
Transport == type(self).get_transport_class("grpc")
338+
or Transport == type(self).get_transport_class("grpc_asyncio")
339+
),
336340
)
337341

338342
def create_job(

google/cloud/dataflow_v1beta3/services/messages_v1_beta3/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,10 @@ def __init__(
330330
client_cert_source_for_mtls=client_cert_source_func,
331331
quota_project_id=client_options.quota_project_id,
332332
client_info=client_info,
333+
always_use_jwt_access=(
334+
Transport == type(self).get_transport_class("grpc")
335+
or Transport == type(self).get_transport_class("grpc_asyncio")
336+
),
333337
)
334338

335339
def list_job_messages(

google/cloud/dataflow_v1beta3/services/metrics_v1_beta3/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,10 @@ def __init__(
331331
client_cert_source_for_mtls=client_cert_source_func,
332332
quota_project_id=client_options.quota_project_id,
333333
client_info=client_info,
334+
always_use_jwt_access=(
335+
Transport == type(self).get_transport_class("grpc")
336+
or Transport == type(self).get_transport_class("grpc_asyncio")
337+
),
334338
)
335339

336340
def get_job_metrics(

google/cloud/dataflow_v1beta3/services/snapshots_v1_beta3/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,10 @@ def __init__(
331331
client_cert_source_for_mtls=client_cert_source_func,
332332
quota_project_id=client_options.quota_project_id,
333333
client_info=client_info,
334+
always_use_jwt_access=(
335+
Transport == type(self).get_transport_class("grpc")
336+
or Transport == type(self).get_transport_class("grpc_asyncio")
337+
),
334338
)
335339

336340
def get_snapshot(

google/cloud/dataflow_v1beta3/services/templates_service/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,10 @@ def __init__(
333333
client_cert_source_for_mtls=client_cert_source_func,
334334
quota_project_id=client_options.quota_project_id,
335335
client_info=client_info,
336+
always_use_jwt_access=(
337+
Transport == type(self).get_transport_class("grpc")
338+
or Transport == type(self).get_transport_class("grpc_asyncio")
339+
),
336340
)
337341

338342
def create_job_from_template(

tests/unit/gapic/dataflow_v1beta3/test_flex_templates_service.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -124,26 +124,14 @@ def test_flex_templates_service_client_from_service_account_info(client_class):
124124
assert client.transport._host == "dataflow.googleapis.com:443"
125125

126126

127-
@pytest.mark.parametrize(
128-
"client_class", [FlexTemplatesServiceClient, FlexTemplatesServiceAsyncClient,]
129-
)
130-
def test_flex_templates_service_client_service_account_always_use_jwt(client_class):
131-
with mock.patch.object(
132-
service_account.Credentials, "with_always_use_jwt_access", create=True
133-
) as use_jwt:
134-
creds = service_account.Credentials(None, None, None)
135-
client = client_class(credentials=creds)
136-
use_jwt.assert_not_called()
137-
138-
139127
@pytest.mark.parametrize(
140128
"transport_class,transport_name",
141129
[
142130
(transports.FlexTemplatesServiceGrpcTransport, "grpc"),
143131
(transports.FlexTemplatesServiceGrpcAsyncIOTransport, "grpc_asyncio"),
144132
],
145133
)
146-
def test_flex_templates_service_client_service_account_always_use_jwt_true(
134+
def test_flex_templates_service_client_service_account_always_use_jwt(
147135
transport_class, transport_name
148136
):
149137
with mock.patch.object(
@@ -153,6 +141,13 @@ def test_flex_templates_service_client_service_account_always_use_jwt_true(
153141
transport = transport_class(credentials=creds, always_use_jwt_access=True)
154142
use_jwt.assert_called_once_with(True)
155143

144+
with mock.patch.object(
145+
service_account.Credentials, "with_always_use_jwt_access", create=True
146+
) as use_jwt:
147+
creds = service_account.Credentials(None, None, None)
148+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
149+
use_jwt.assert_not_called()
150+
156151

157152
@pytest.mark.parametrize(
158153
"client_class", [FlexTemplatesServiceClient, FlexTemplatesServiceAsyncClient,]
@@ -237,6 +232,7 @@ def test_flex_templates_service_client_client_options(
237232
client_cert_source_for_mtls=None,
238233
quota_project_id=None,
239234
client_info=transports.base.DEFAULT_CLIENT_INFO,
235+
always_use_jwt_access=True,
240236
)
241237

242238
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -253,6 +249,7 @@ def test_flex_templates_service_client_client_options(
253249
client_cert_source_for_mtls=None,
254250
quota_project_id=None,
255251
client_info=transports.base.DEFAULT_CLIENT_INFO,
252+
always_use_jwt_access=True,
256253
)
257254

258255
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -269,6 +266,7 @@ def test_flex_templates_service_client_client_options(
269266
client_cert_source_for_mtls=None,
270267
quota_project_id=None,
271268
client_info=transports.base.DEFAULT_CLIENT_INFO,
269+
always_use_jwt_access=True,
272270
)
273271

274272
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -297,6 +295,7 @@ def test_flex_templates_service_client_client_options(
297295
client_cert_source_for_mtls=None,
298296
quota_project_id="octopus",
299297
client_info=transports.base.DEFAULT_CLIENT_INFO,
298+
always_use_jwt_access=True,
300299
)
301300

302301

@@ -373,6 +372,7 @@ def test_flex_templates_service_client_mtls_env_auto(
373372
client_cert_source_for_mtls=expected_client_cert_source,
374373
quota_project_id=None,
375374
client_info=transports.base.DEFAULT_CLIENT_INFO,
375+
always_use_jwt_access=True,
376376
)
377377

378378
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -406,6 +406,7 @@ def test_flex_templates_service_client_mtls_env_auto(
406406
client_cert_source_for_mtls=expected_client_cert_source,
407407
quota_project_id=None,
408408
client_info=transports.base.DEFAULT_CLIENT_INFO,
409+
always_use_jwt_access=True,
409410
)
410411

411412
# Check the case client_cert_source and ADC client cert are not provided.
@@ -427,6 +428,7 @@ def test_flex_templates_service_client_mtls_env_auto(
427428
client_cert_source_for_mtls=None,
428429
quota_project_id=None,
429430
client_info=transports.base.DEFAULT_CLIENT_INFO,
431+
always_use_jwt_access=True,
430432
)
431433

432434

@@ -461,6 +463,7 @@ def test_flex_templates_service_client_client_options_scopes(
461463
client_cert_source_for_mtls=None,
462464
quota_project_id=None,
463465
client_info=transports.base.DEFAULT_CLIENT_INFO,
466+
always_use_jwt_access=True,
464467
)
465468

466469

@@ -495,6 +498,7 @@ def test_flex_templates_service_client_client_options_credentials_file(
495498
client_cert_source_for_mtls=None,
496499
quota_project_id=None,
497500
client_info=transports.base.DEFAULT_CLIENT_INFO,
501+
always_use_jwt_access=True,
498502
)
499503

500504

@@ -514,6 +518,7 @@ def test_flex_templates_service_client_client_options_from_dict():
514518
client_cert_source_for_mtls=None,
515519
quota_project_id=None,
516520
client_info=transports.base.DEFAULT_CLIENT_INFO,
521+
always_use_jwt_access=True,
517522
)
518523

519524

tests/unit/gapic/dataflow_v1beta3/test_jobs_v1_beta3.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -119,24 +119,14 @@ def test_jobs_v1_beta3_client_from_service_account_info(client_class):
119119
assert client.transport._host == "dataflow.googleapis.com:443"
120120

121121

122-
@pytest.mark.parametrize("client_class", [JobsV1Beta3Client, JobsV1Beta3AsyncClient,])
123-
def test_jobs_v1_beta3_client_service_account_always_use_jwt(client_class):
124-
with mock.patch.object(
125-
service_account.Credentials, "with_always_use_jwt_access", create=True
126-
) as use_jwt:
127-
creds = service_account.Credentials(None, None, None)
128-
client = client_class(credentials=creds)
129-
use_jwt.assert_not_called()
130-
131-
132122
@pytest.mark.parametrize(
133123
"transport_class,transport_name",
134124
[
135125
(transports.JobsV1Beta3GrpcTransport, "grpc"),
136126
(transports.JobsV1Beta3GrpcAsyncIOTransport, "grpc_asyncio"),
137127
],
138128
)
139-
def test_jobs_v1_beta3_client_service_account_always_use_jwt_true(
129+
def test_jobs_v1_beta3_client_service_account_always_use_jwt(
140130
transport_class, transport_name
141131
):
142132
with mock.patch.object(
@@ -146,6 +136,13 @@ def test_jobs_v1_beta3_client_service_account_always_use_jwt_true(
146136
transport = transport_class(credentials=creds, always_use_jwt_access=True)
147137
use_jwt.assert_called_once_with(True)
148138

139+
with mock.patch.object(
140+
service_account.Credentials, "with_always_use_jwt_access", create=True
141+
) as use_jwt:
142+
creds = service_account.Credentials(None, None, None)
143+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
144+
use_jwt.assert_not_called()
145+
149146

150147
@pytest.mark.parametrize("client_class", [JobsV1Beta3Client, JobsV1Beta3AsyncClient,])
151148
def test_jobs_v1_beta3_client_from_service_account_file(client_class):
@@ -222,6 +219,7 @@ def test_jobs_v1_beta3_client_client_options(
222219
client_cert_source_for_mtls=None,
223220
quota_project_id=None,
224221
client_info=transports.base.DEFAULT_CLIENT_INFO,
222+
always_use_jwt_access=True,
225223
)
226224

227225
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -238,6 +236,7 @@ def test_jobs_v1_beta3_client_client_options(
238236
client_cert_source_for_mtls=None,
239237
quota_project_id=None,
240238
client_info=transports.base.DEFAULT_CLIENT_INFO,
239+
always_use_jwt_access=True,
241240
)
242241

243242
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -254,6 +253,7 @@ def test_jobs_v1_beta3_client_client_options(
254253
client_cert_source_for_mtls=None,
255254
quota_project_id=None,
256255
client_info=transports.base.DEFAULT_CLIENT_INFO,
256+
always_use_jwt_access=True,
257257
)
258258

259259
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -282,6 +282,7 @@ def test_jobs_v1_beta3_client_client_options(
282282
client_cert_source_for_mtls=None,
283283
quota_project_id="octopus",
284284
client_info=transports.base.DEFAULT_CLIENT_INFO,
285+
always_use_jwt_access=True,
285286
)
286287

287288

@@ -346,6 +347,7 @@ def test_jobs_v1_beta3_client_mtls_env_auto(
346347
client_cert_source_for_mtls=expected_client_cert_source,
347348
quota_project_id=None,
348349
client_info=transports.base.DEFAULT_CLIENT_INFO,
350+
always_use_jwt_access=True,
349351
)
350352

351353
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -379,6 +381,7 @@ def test_jobs_v1_beta3_client_mtls_env_auto(
379381
client_cert_source_for_mtls=expected_client_cert_source,
380382
quota_project_id=None,
381383
client_info=transports.base.DEFAULT_CLIENT_INFO,
384+
always_use_jwt_access=True,
382385
)
383386

384387
# Check the case client_cert_source and ADC client cert are not provided.
@@ -400,6 +403,7 @@ def test_jobs_v1_beta3_client_mtls_env_auto(
400403
client_cert_source_for_mtls=None,
401404
quota_project_id=None,
402405
client_info=transports.base.DEFAULT_CLIENT_INFO,
406+
always_use_jwt_access=True,
403407
)
404408

405409

@@ -430,6 +434,7 @@ def test_jobs_v1_beta3_client_client_options_scopes(
430434
client_cert_source_for_mtls=None,
431435
quota_project_id=None,
432436
client_info=transports.base.DEFAULT_CLIENT_INFO,
437+
always_use_jwt_access=True,
433438
)
434439

435440

@@ -460,6 +465,7 @@ def test_jobs_v1_beta3_client_client_options_credentials_file(
460465
client_cert_source_for_mtls=None,
461466
quota_project_id=None,
462467
client_info=transports.base.DEFAULT_CLIENT_INFO,
468+
always_use_jwt_access=True,
463469
)
464470

465471

@@ -477,6 +483,7 @@ def test_jobs_v1_beta3_client_client_options_from_dict():
477483
client_cert_source_for_mtls=None,
478484
quota_project_id=None,
479485
client_info=transports.base.DEFAULT_CLIENT_INFO,
486+
always_use_jwt_access=True,
480487
)
481488

482489

0 commit comments

Comments
 (0)