Skip to content

Commit

Permalink
Merge pull request #182 from jacobtomlinson/gcp-auth-refresh
Browse files Browse the repository at this point in the history
Wrap googleapiclient compute in class to refresh expired tokens
  • Loading branch information
quasiben committed Nov 20, 2020
2 parents a190601 + bb6d4c9 commit fef4279
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
52 changes: 33 additions & 19 deletions dask_cloudprovider/gcp/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def __init__(
**kwargs,
):

self.compute = authenticate()
self.compute = GCPCompute()

self.config = dask.config.get("cloudprovider.gcp", {})
self.auto_shutdown = (
Expand Down Expand Up @@ -546,24 +546,38 @@ def __init__(
super().__init__(**kwargs)


def authenticate():
if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", False):
compute = googleapiclient.discovery.build("compute", "v1")
else:
import google.auth.credentials # google-auth

path = os.path.join(os.path.expanduser("~"), ".config/gcloud/credentials.db")
if not os.path.exists(path):
raise GCPCredentialsError()
conn = sqlite3.connect(path)
creds_rows = conn.execute("select * from credentials").fetchall()
with tmpfile() as f:
with open(f, "w") as f_:
# take first row
f_.write(creds_rows[0][1])
creds, _ = google.auth.load_credentials_from_file(filename=f)
compute = googleapiclient.discovery.build("compute", "v1", credentials=creds)
return compute
class GCPCompute:
"""Wrapper for the ``googleapiclient`` compute object."""

def __init__(self):
self._compute = self.refresh_client()

def refresh_client(self):
if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", False):
return googleapiclient.discovery.build("compute", "v1")
else:
import google.auth.credentials # google-auth

path = os.path.join(
os.path.expanduser("~"), ".config/gcloud/credentials.db"
)
if not os.path.exists(path):
raise GCPCredentialsError()
conn = sqlite3.connect(path)
creds_rows = conn.execute("select * from credentials").fetchall()
with tmpfile() as f:
with open(f, "w") as f_:
# take first row
f_.write(creds_rows[0][1])
creds, _ = google.auth.load_credentials_from_file(filename=f)
return googleapiclient.discovery.build("compute", "v1", credentials=creds)

def instances(self):
try:
return self._compute.instances()
except Exception: # noqa
self._compute = self.refresh_client()
return self._compute.instances()


# Note: if you have trouble connecting make sure firewall rules in GCP are stetup for 8787,8786,22
4 changes: 2 additions & 2 deletions dask_cloudprovider/gcp/tests/test_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dask
from dask_cloudprovider.gcp.instances import (
GCPCluster,
authenticate,
GCPCompute,
GCPCredentialsError,
)
from dask.distributed import Client
Expand All @@ -12,7 +12,7 @@

def skip_without_credentials():
try:
authenticate()
_ = GCPCompute()
except GCPCredentialsError:
pytest.skip(
"""
Expand Down

0 comments on commit fef4279

Please sign in to comment.