Skip to content

Commit

Permalink
[raz] Raz Client for ADLS to submit proper requests to getback SAS to…
Browse files Browse the repository at this point in the history
…ken (#2362)

- Update existing S3 client to support ADLS
- Updated UTs

- Currently, it sends the read request to get ADLS SAS Token, we need more info for mapping with request methods to read/write/delete ops for the token
  • Loading branch information
Harshg999 committed Jul 26, 2021
1 parent 048b8b8 commit 1da17c4
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 52 deletions.
93 changes: 62 additions & 31 deletions desktop/core/src/desktop/lib/raz/raz_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, raz_url, raz_token, username, service='s3', service_name='cm_
self.raz_token = raz_token
self.username = username
self.service = service

if self.service == 'adls':
self.service_params = {
'endpoint_prefix': 'adls',
Expand All @@ -88,6 +89,7 @@ def __init__(self, raz_url, raz_token, username, service='s3', service_name='cm_
'service_name': 's3',
'serviceType': 's3'
}

self.service_name = service_name
self.cluster_name = cluster_name
self.requestid = str(uuid.uuid4())
Expand All @@ -100,51 +102,34 @@ def check_access(self, method, url, params=None, headers=None):
params = params if params is not None else {}
headers = headers if headers is not None else {}

allparams = [raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in url_params.items()]
allparams.extend([raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in params.items()])
headers = [raz_signer.StringStringMapProto(key=key, value=val) for key, val in headers.items()]
endpoint = "%s://%s" % (path.scheme, path.netloc)
resource_path = path.path.lstrip("/")

LOG.debug(
"Preparing sign request with http_method: {%s}, headers: {%s}, parameters: {%s}, endpoint: {%s}, resource_path: {%s}" %
(method, headers, allparams, endpoint, resource_path)
)
raz_req = raz_signer.SignRequestProto(
endpoint_prefix=self.service_params['endpoint_prefix'],
service_name=self.service_params['service_name'],
endpoint=endpoint,
http_method=method,
headers=headers,
parameters=allparams,
resource_path=resource_path,
time_offset=0
)
raz_req_serialized = raz_req.SerializeToString()
signed_request = base64.b64encode(raz_req_serialized)

request_data = {
"requestId": self.requestid,
"serviceType": self.service_params['serviceType'],
"serviceName": self.service_name,
"user": self.username,
"userGroups": [],
"accessTime": "",
"clientIpAddress": "",
"clientType": "",
"clusterName": self.cluster_name,
"clusterType": "",
"sessionId": "",
"context": {
"S3_SIGN_REQUEST": signed_request
}
"accessTime": "",
"context": {}
}
headers = {"Content-Type":"application/json", "Accept-Encoding":"gzip,deflate"}
raz_url = "%s/api/authz/s3/access?delegation=%s" % (self.raz_url, self.raz_token)
LOG.debug('Raz url: %s' % raz_url)
request_headers = {"Content-Type": "application/json"}
raz_url = "%s/api/authz/%s/access?delegation=%s" % (self.raz_url, self.service, self.raz_token)

LOG.debug("Sending access check headers: {%s} request_data: {%s}" % (headers, request_data))
raz_req = requests.post(raz_url, headers=headers, json=request_data, verify=False)
if self.service == 'adls':
self._make_adls_request(request_data, path, resource_path)
elif self.service == 's3':
self._make_s3_request(request_data, request_headers, method, params, headers, url_params, endpoint, resource_path)

LOG.debug('Raz url: %s' % raz_url)
LOG.debug("Sending access check headers: {%s} request_data: {%s}" % (request_headers, request_data))
raz_req = requests.post(raz_url, headers=request_headers, json=request_data, verify=False)

signed_response_result = None
signed_response = None
Expand All @@ -164,21 +149,67 @@ def check_access(self, method, url, params=None, headers=None):
if result == "ALLOWED":
LOG.debug('Received allowed response %s' % raz_req.json())
signed_response_data = raz_req.json()["operResult"]["additionalInfo"]

if self.service == 'adls':
LOG.debug("Received SAS %s" % signed_response_data["ADLS_DSAS"])
return {'token': signed_response_data["ADLS_DSAS"]}
else:
signed_response_result = signed_response_data["S3_SIGN_RESPONSE"]

if signed_response_result:
if signed_response_result is not None:
raz_response_proto = raz_signer.SignResponseProto()
signed_response = raz_response_proto.FromString(base64.b64decode(signed_response_result))
LOG.debug("Received signed Response %s" % signed_response)

# Signed headers "only"
if signed_response:
if signed_response is not None:
return dict([(i.key, i.value) for i in signed_response.signer_generated_headers])

def _make_adls_request(self, request_data, path, resource_path):
storage_account = path.netloc.split('.')[0]
container, relative_path = resource_path.split('/', 1)

request_data.update({
"clientType": "adls",
"operation": {
"resource": {
"storageaccount": storage_account,
"container": container,
"relativepath": relative_path,
},
"resourceOwner": "",
"action": "read",
"accessTypes":["read"]
}
})

def _make_s3_request(self, request_data, request_headers, method, params, headers, url_params, endpoint, resource_path):

allparams = [raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in url_params.items()]
allparams.extend([raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in params.items()])
headers = [raz_signer.StringStringMapProto(key=key, value=val) for key, val in headers.items()]

LOG.debug(
"Preparing sign request with http_method: {%s}, headers: {%s}, parameters: {%s}, endpoint: {%s}, resource_path: {%s}" %
(method, headers, allparams, endpoint, resource_path)
)
raz_req = raz_signer.SignRequestProto(
endpoint_prefix=self.service_params['endpoint_prefix'],
service_name=self.service_params['service_name'],
endpoint=endpoint,
http_method=method,
headers=headers,
parameters=allparams,
resource_path=resource_path,
time_offset=0
)
raz_req_serialized = raz_req.SerializeToString()
signed_request = base64.b64encode(raz_req_serialized)

request_headers["Accept-Encoding"] = {"gzip,deflate"}
request_data["context"] = {
"S3_SIGN_REQUEST": signed_request
}

def get_raz_client(raz_url, username, auth='kerberos', service='s3', service_name='cm_s3', cluster_name='myCluster'):
if not username:
Expand Down
104 changes: 83 additions & 21 deletions desktop/core/src/desktop/lib/raz/raz_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ def test_renew_delegation_token(self):

t = token.renew_delegation_token(user=self.username)

assert_equal(
assert_equal(t,
'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv?AWSAccessKeyId=AKIA23E77ZX2HVY76YGL&'
'Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304',
t
'Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304'
)

with patch('desktop.lib.raz.raz_client.requests.put') as requests_put:
Expand All @@ -100,36 +99,103 @@ class RazClientTest(unittest.TestCase):
def setUp(self):
self.username = 'gethue'
self.raz_url = 'https://raz.gethue.com:8080'
self.resource_url = 'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv'

def test_get_raz_client(self):
self.s3_path = 'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv'
self.adls_path = 'https://gethuestorageaccount.blob.core.windows.net/demo-gethue-container/demo-dir1/customer.csv'

def test_get_raz_client_adls(self):
with patch('desktop.lib.raz.raz_client.RazToken') as RazToken:
with patch('desktop.lib.raz.raz_client.requests_kerberos.HTTPKerberosAuth') as HTTPKerberosAuth:
client = get_raz_client(
raz_url=self.raz_url,
username=self.username,
auth='kerberos',
service='s3',
service_name='gethue_s3',
service='adls',
service_name='gethue_adls',
cluster_name='gethueCluster'
)

assert_true(isinstance(client, RazClient))

HTTPKerberosAuth.assert_called()
assert_equal(
client.raz_url, self.raz_url
assert_equal(client.raz_url, self.raz_url)
assert_equal(client.service_name, 'gethue_adls')
assert_equal(client.cluster_name, 'gethueCluster')

def test_check_access_adls(self):
with patch('desktop.lib.raz.raz_client.requests.post') as requests_post:
with patch('desktop.lib.raz.raz_client.uuid.uuid4') as uuid:
raz_token = "mock_RAZ_token"

requests_post.return_value = Mock(
json=Mock(return_value=
{
'operResult': {
'result': 'ALLOWED',
'additionalInfo': {
"ADLS_DSAS": "nulltenantIdnullnullbnullALLOWEDnullnull1.05nSlN7t/QiPJ1OFlCruTEPLibFbAhEYYj5wbJuaeQqs="
}
}
}
)
)
assert_equal(
client.service_name, 'gethue_s3'
uuid.return_value = 'mock_request_id'

client = RazClient(self.raz_url, raz_token, username=self.username, service="adls", service_name="adls", cluster_name="cl1")

resp = client.check_access(method='GET', url=self.adls_path)

requests_post.assert_called_with(
"https://raz.gethue.com:8080/api/authz/adls/access?delegation=" + raz_token,
headers={"Content-Type": "application/json"},
json={
'requestId': 'mock_request_id',
'serviceType': 'adls',
'serviceName': 'adls',
'user': 'gethue',
'userGroups': [],
'clientIpAddress': '',
'clientType': 'adls',
'clusterName': 'cl1',
'clusterType': '',
'sessionId': '',
'accessTime': '',
'context': {},
'operation': {
'resource': {
'storageaccount': 'gethuestorageaccount',
'container': 'demo-gethue-container',
'relativepath': 'demo-dir1/customer.csv'
},
'resourceOwner': '',
'action': 'read',
'accessTypes': ['read']
}
},
verify=False
)
assert_equal(
client.cluster_name, 'gethueCluster'
assert_equal(resp['token'], "nulltenantIdnullnullbnullALLOWEDnullnull1.05nSlN7t/QiPJ1OFlCruTEPLibFbAhEYYj5wbJuaeQqs=")

def test_get_raz_client_s3(self):
with patch('desktop.lib.raz.raz_client.RazToken') as RazToken:
with patch('desktop.lib.raz.raz_client.requests_kerberos.HTTPKerberosAuth') as HTTPKerberosAuth:
client = get_raz_client(
raz_url=self.raz_url,
username=self.username,
auth='kerberos',
service='s3',
service_name='gethue_s3',
cluster_name='gethueCluster'
)

assert_true(isinstance(client, RazClient))

HTTPKerberosAuth.assert_called()
assert_equal(client.raz_url, self.raz_url)
assert_equal(client.service_name, 'gethue_s3')
assert_equal(client.cluster_name, 'gethueCluster')

def test_check_access(self):
def test_check_access_s3(self):
raz_token = Mock()

client = RazClient(self.raz_url, raz_token, username=self.username)
Expand Down Expand Up @@ -162,11 +228,7 @@ def test_check_access(self):
)
)

resp = client.check_access(method='GET', url=self.resource_url)
resp = client.check_access(method='GET', url=self.s3_path)

assert_true(
resp
)
assert_equal(
resp['AWSAccessKeyId'], 'AKIA23E77ZX2HVY76YGL'
)
assert_true(resp)
assert_equal(resp['AWSAccessKeyId'], 'AKIA23E77ZX2HVY76YGL')

0 comments on commit 1da17c4

Please sign in to comment.