Skip to content
This repository was archived by the owner on Jul 6, 2023. It is now read-only.

Commit 5abf2bc

Browse files
fix: enable self signed jwt for grpc (#81)
PiperOrigin-RevId: 386504689 Source-Link: googleapis/googleapis@762094a Source-Link: https://github.com/googleapis/googleapis-gen/commit/6bfc480e1a161d5de121c2bcc3745885d33b265a
1 parent 71c7a5c commit 5abf2bc

File tree

8 files changed

+88
-44
lines changed

8 files changed

+88
-44
lines changed

google/cloud/workflows/executions_v1/services/executions/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,10 @@ def __init__(
363363
client_cert_source_for_mtls=client_cert_source_func,
364364
quota_project_id=client_options.quota_project_id,
365365
client_info=client_info,
366+
always_use_jwt_access=(
367+
Transport == type(self).get_transport_class("grpc")
368+
or Transport == type(self).get_transport_class("grpc_asyncio")
369+
),
366370
)
367371

368372
def list_executions(

google/cloud/workflows/executions_v1beta/services/executions/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ def __init__(
364364
client_cert_source_for_mtls=client_cert_source_func,
365365
quota_project_id=client_options.quota_project_id,
366366
client_info=client_info,
367+
always_use_jwt_access=(
368+
Transport == type(self).get_transport_class("grpc")
369+
or Transport == type(self).get_transport_class("grpc_asyncio")
370+
),
367371
)
368372

369373
def list_executions(

google/cloud/workflows_v1/services/workflows/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,10 @@ def __init__(
350350
client_cert_source_for_mtls=client_cert_source_func,
351351
quota_project_id=client_options.quota_project_id,
352352
client_info=client_info,
353+
always_use_jwt_access=(
354+
Transport == type(self).get_transport_class("grpc")
355+
or Transport == type(self).get_transport_class("grpc_asyncio")
356+
),
353357
)
354358

355359
def list_workflows(

google/cloud/workflows_v1beta/services/workflows/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,10 @@ def __init__(
350350
client_cert_source_for_mtls=client_cert_source_func,
351351
quota_project_id=client_options.quota_project_id,
352352
client_info=client_info,
353+
always_use_jwt_access=(
354+
Transport == type(self).get_transport_class("grpc")
355+
or Transport == type(self).get_transport_class("grpc_asyncio")
356+
),
353357
)
354358

355359
def list_workflows(

tests/unit/gapic/executions_v1/test_executions.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -116,24 +116,14 @@ def test_executions_client_from_service_account_info(client_class):
116116
assert client.transport._host == "workflowexecutions.googleapis.com:443"
117117

118118

119-
@pytest.mark.parametrize("client_class", [ExecutionsClient, ExecutionsAsyncClient,])
120-
def test_executions_client_service_account_always_use_jwt(client_class):
121-
with mock.patch.object(
122-
service_account.Credentials, "with_always_use_jwt_access", create=True
123-
) as use_jwt:
124-
creds = service_account.Credentials(None, None, None)
125-
client = client_class(credentials=creds)
126-
use_jwt.assert_not_called()
127-
128-
129119
@pytest.mark.parametrize(
130120
"transport_class,transport_name",
131121
[
132122
(transports.ExecutionsGrpcTransport, "grpc"),
133123
(transports.ExecutionsGrpcAsyncIOTransport, "grpc_asyncio"),
134124
],
135125
)
136-
def test_executions_client_service_account_always_use_jwt_true(
126+
def test_executions_client_service_account_always_use_jwt(
137127
transport_class, transport_name
138128
):
139129
with mock.patch.object(
@@ -143,6 +133,13 @@ def test_executions_client_service_account_always_use_jwt_true(
143133
transport = transport_class(credentials=creds, always_use_jwt_access=True)
144134
use_jwt.assert_called_once_with(True)
145135

136+
with mock.patch.object(
137+
service_account.Credentials, "with_always_use_jwt_access", create=True
138+
) as use_jwt:
139+
creds = service_account.Credentials(None, None, None)
140+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
141+
use_jwt.assert_not_called()
142+
146143

147144
@pytest.mark.parametrize("client_class", [ExecutionsClient, ExecutionsAsyncClient,])
148145
def test_executions_client_from_service_account_file(client_class):
@@ -219,6 +216,7 @@ def test_executions_client_client_options(
219216
client_cert_source_for_mtls=None,
220217
quota_project_id=None,
221218
client_info=transports.base.DEFAULT_CLIENT_INFO,
219+
always_use_jwt_access=True,
222220
)
223221

224222
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -235,6 +233,7 @@ def test_executions_client_client_options(
235233
client_cert_source_for_mtls=None,
236234
quota_project_id=None,
237235
client_info=transports.base.DEFAULT_CLIENT_INFO,
236+
always_use_jwt_access=True,
238237
)
239238

240239
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -251,6 +250,7 @@ def test_executions_client_client_options(
251250
client_cert_source_for_mtls=None,
252251
quota_project_id=None,
253252
client_info=transports.base.DEFAULT_CLIENT_INFO,
253+
always_use_jwt_access=True,
254254
)
255255

256256
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -279,6 +279,7 @@ def test_executions_client_client_options(
279279
client_cert_source_for_mtls=None,
280280
quota_project_id="octopus",
281281
client_info=transports.base.DEFAULT_CLIENT_INFO,
282+
always_use_jwt_access=True,
282283
)
283284

284285

@@ -343,6 +344,7 @@ def test_executions_client_mtls_env_auto(
343344
client_cert_source_for_mtls=expected_client_cert_source,
344345
quota_project_id=None,
345346
client_info=transports.base.DEFAULT_CLIENT_INFO,
347+
always_use_jwt_access=True,
346348
)
347349

348350
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -376,6 +378,7 @@ def test_executions_client_mtls_env_auto(
376378
client_cert_source_for_mtls=expected_client_cert_source,
377379
quota_project_id=None,
378380
client_info=transports.base.DEFAULT_CLIENT_INFO,
381+
always_use_jwt_access=True,
379382
)
380383

381384
# Check the case client_cert_source and ADC client cert are not provided.
@@ -397,6 +400,7 @@ def test_executions_client_mtls_env_auto(
397400
client_cert_source_for_mtls=None,
398401
quota_project_id=None,
399402
client_info=transports.base.DEFAULT_CLIENT_INFO,
403+
always_use_jwt_access=True,
400404
)
401405

402406

@@ -427,6 +431,7 @@ def test_executions_client_client_options_scopes(
427431
client_cert_source_for_mtls=None,
428432
quota_project_id=None,
429433
client_info=transports.base.DEFAULT_CLIENT_INFO,
434+
always_use_jwt_access=True,
430435
)
431436

432437

@@ -457,6 +462,7 @@ def test_executions_client_client_options_credentials_file(
457462
client_cert_source_for_mtls=None,
458463
quota_project_id=None,
459464
client_info=transports.base.DEFAULT_CLIENT_INFO,
465+
always_use_jwt_access=True,
460466
)
461467

462468

@@ -474,6 +480,7 @@ def test_executions_client_client_options_from_dict():
474480
client_cert_source_for_mtls=None,
475481
quota_project_id=None,
476482
client_info=transports.base.DEFAULT_CLIENT_INFO,
483+
always_use_jwt_access=True,
477484
)
478485

479486

tests/unit/gapic/executions_v1beta/test_executions.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -118,24 +118,14 @@ def test_executions_client_from_service_account_info(client_class):
118118
assert client.transport._host == "workflowexecutions.googleapis.com:443"
119119

120120

121-
@pytest.mark.parametrize("client_class", [ExecutionsClient, ExecutionsAsyncClient,])
122-
def test_executions_client_service_account_always_use_jwt(client_class):
123-
with mock.patch.object(
124-
service_account.Credentials, "with_always_use_jwt_access", create=True
125-
) as use_jwt:
126-
creds = service_account.Credentials(None, None, None)
127-
client = client_class(credentials=creds)
128-
use_jwt.assert_not_called()
129-
130-
131121
@pytest.mark.parametrize(
132122
"transport_class,transport_name",
133123
[
134124
(transports.ExecutionsGrpcTransport, "grpc"),
135125
(transports.ExecutionsGrpcAsyncIOTransport, "grpc_asyncio"),
136126
],
137127
)
138-
def test_executions_client_service_account_always_use_jwt_true(
128+
def test_executions_client_service_account_always_use_jwt(
139129
transport_class, transport_name
140130
):
141131
with mock.patch.object(
@@ -145,6 +135,13 @@ def test_executions_client_service_account_always_use_jwt_true(
145135
transport = transport_class(credentials=creds, always_use_jwt_access=True)
146136
use_jwt.assert_called_once_with(True)
147137

138+
with mock.patch.object(
139+
service_account.Credentials, "with_always_use_jwt_access", create=True
140+
) as use_jwt:
141+
creds = service_account.Credentials(None, None, None)
142+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
143+
use_jwt.assert_not_called()
144+
148145

149146
@pytest.mark.parametrize("client_class", [ExecutionsClient, ExecutionsAsyncClient,])
150147
def test_executions_client_from_service_account_file(client_class):
@@ -221,6 +218,7 @@ def test_executions_client_client_options(
221218
client_cert_source_for_mtls=None,
222219
quota_project_id=None,
223220
client_info=transports.base.DEFAULT_CLIENT_INFO,
221+
always_use_jwt_access=True,
224222
)
225223

226224
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -237,6 +235,7 @@ def test_executions_client_client_options(
237235
client_cert_source_for_mtls=None,
238236
quota_project_id=None,
239237
client_info=transports.base.DEFAULT_CLIENT_INFO,
238+
always_use_jwt_access=True,
240239
)
241240

242241
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -253,6 +252,7 @@ def test_executions_client_client_options(
253252
client_cert_source_for_mtls=None,
254253
quota_project_id=None,
255254
client_info=transports.base.DEFAULT_CLIENT_INFO,
255+
always_use_jwt_access=True,
256256
)
257257

258258
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -281,6 +281,7 @@ def test_executions_client_client_options(
281281
client_cert_source_for_mtls=None,
282282
quota_project_id="octopus",
283283
client_info=transports.base.DEFAULT_CLIENT_INFO,
284+
always_use_jwt_access=True,
284285
)
285286

286287

@@ -345,6 +346,7 @@ def test_executions_client_mtls_env_auto(
345346
client_cert_source_for_mtls=expected_client_cert_source,
346347
quota_project_id=None,
347348
client_info=transports.base.DEFAULT_CLIENT_INFO,
349+
always_use_jwt_access=True,
348350
)
349351

350352
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -378,6 +380,7 @@ def test_executions_client_mtls_env_auto(
378380
client_cert_source_for_mtls=expected_client_cert_source,
379381
quota_project_id=None,
380382
client_info=transports.base.DEFAULT_CLIENT_INFO,
383+
always_use_jwt_access=True,
381384
)
382385

383386
# Check the case client_cert_source and ADC client cert are not provided.
@@ -399,6 +402,7 @@ def test_executions_client_mtls_env_auto(
399402
client_cert_source_for_mtls=None,
400403
quota_project_id=None,
401404
client_info=transports.base.DEFAULT_CLIENT_INFO,
405+
always_use_jwt_access=True,
402406
)
403407

404408

@@ -429,6 +433,7 @@ def test_executions_client_client_options_scopes(
429433
client_cert_source_for_mtls=None,
430434
quota_project_id=None,
431435
client_info=transports.base.DEFAULT_CLIENT_INFO,
436+
always_use_jwt_access=True,
432437
)
433438

434439

@@ -459,6 +464,7 @@ def test_executions_client_client_options_credentials_file(
459464
client_cert_source_for_mtls=None,
460465
quota_project_id=None,
461466
client_info=transports.base.DEFAULT_CLIENT_INFO,
467+
always_use_jwt_access=True,
462468
)
463469

464470

@@ -476,6 +482,7 @@ def test_executions_client_client_options_from_dict():
476482
client_cert_source_for_mtls=None,
477483
quota_project_id=None,
478484
client_info=transports.base.DEFAULT_CLIENT_INFO,
485+
always_use_jwt_access=True,
479486
)
480487

481488

0 commit comments

Comments
 (0)