## SageMaker에서 Redshift 연결하고 데이터 가져오기 (with Python)

In [1]:
from __future__ import print_function
import sys
import base64
import json
from random import shuffle
import random
import datetime
import os

import boto3
from botocore.exceptions import ClientError
import sagemaker
import s3fs
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role() 
region = sagemaker_session.boto_region_name

s3_bucket = sagemaker.Session().default_bucket()  # replace with an existing bucket if needed
s3_prefix = 'redshift-access-test'    # prefix used for all data stored within the bucket


### 1. Postgres Adaptor(psycopg2) 사용하기
- ref : https://www.psycopg.org/docs/

In [3]:
# pip install psycopg2-binary
import psycopg2

- Common utility functions

In [4]:
#uses session manager name to return connection and credential information
def connection_info(db_creds):
    session = boto3.session.Session()
    client = session.client(service_name='secretsmanager')
    get_secret_value_response = client.get_secret_value(SecretId=db_creds)
    if 'SecretString' in get_secret_value_response:
        secret = json.loads(get_secret_value_response['SecretString'])
    else:
        secret = json.loads(base64.b64decode(get_secret_value_response['SecretBinary']))
    return secret

#creates a connection to the cluster
def get_connection(db,db_creds):
    con_params = connection_info(db_creds)
    print("Connection info retrieved from Secrets manager")
    rs_conn=psycopg2.connect(dbname=db, host=con_params['host'], port=con_params['port'], user=con_params['username'], password=con_params['password'])
    cur = rs_conn.cursor()
    cur.execute("set statement_timeout = 1200000;")
    print("Connected to {}".format(db))

    return cur

#Close the connection to the cluster
def close_cursor(cursor):
    print("Connection closed")
    cursor.close()

#submits a query to the cluster
def run_command(cursor, statement):
    cursor.execute(statement)
    res = cursor.fetchall()
    print("Query execution complete")

    return res


In [5]:
db='nyctaxi'
db_creds = 'nyctaxisecret'

#get database connection
cursor = get_connection(db,db_creds)

#run each sql statement
query_str = 'select sysdate;'
query_str = 'select top 3 * from taxischema.nyc_greentaxi;'

result = run_command(cursor, query_str)

close_cursor(cursor) 

Connection info retrieved from Secrets manager
Connected to nyctaxi
Query execution complete
Connection closed


In [6]:
pd.DataFrame(result)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,1,2019-04-01 00:22:52,2019-04-01 00:30:52,N,1,255,112,4,1.3,7.5,0.5,0.5,1.0,0.0,,0.3,9.8,1,1,0.0
1,1,2019-04-01 00:09:21,2019-04-01 00:18:24,N,1,74,147,1,2.7,10.5,0.5,0.5,0.0,0.0,,0.3,11.8,2,1,0.0
2,1,2019-04-01 01:17:39,2019-04-01 01:24:41,N,1,129,129,1,0.7,6.0,0.5,0.5,0.0,0.0,,0.3,7.3,2,1,0.0


### 2. boto3 의 redshift-data API 사용하기

- ref : https://aws.amazon.com/ko/blogs/big-data/using-the-amazon-redshift-data-api-to-interact-from-an-amazon-sagemaker-jupyter-notebook/

In [7]:
import botocore.session as s

In [8]:
bc_session = s.get_session()

session = boto3.Session(
        botocore_session=bc_session,
        region_name=region,
    )

# Setup the client
client_redshift = session.client("redshift-data")
print("Data API client successfully loaded")

Data API client successfully loaded


In [9]:
db='nyctaxi'
db_creds = 'nyctaxisecret'

def connection_info(db_creds):
    session = boto3.session.Session()
    client = session.client(service_name='secretsmanager')
    get_secret_value_response = client.get_secret_value(SecretId=db_creds)
    if 'SecretString' in get_secret_value_response:
        secret = json.loads(get_secret_value_response['SecretString'])
    else:
        secret = json.loads(base64.b64decode(get_secret_value_response['SecretBinary']))

    return secret, get_secret_value_response['ARN']

secrets, secret_arn = connection_info(db_creds)

In [10]:
cluster_id=secrets['dbClusterIdentifier']
print("Cluster_id: " + cluster_id + "\nDB: " + db + "\nSecret ARN: " + secret_arn)

Cluster_id: redshiftcluster-f0csvykgqnfu
DB: nyctaxi
Secret ARN: arn:aws:secretsmanager:us-east-1:308961792850:secret:nyctaxisecret-QHzY9o


In [11]:
client_redshift.list_schemas(
    Database= db, 
    SecretArn= secret_arn, 
    ClusterIdentifier= cluster_id)["Schemas"]

['catalog_history',
 'information_schema',
 'pg_automv',
 'pg_catalog',
 'pg_internal',
 'public',
 'spectrum_schema',
 'taxischema']

In [12]:
client_redshift.list_tables(
    Database= db, 
    SecretArn= secret_arn, 
    SchemaPattern='taxischema',
    ClusterIdentifier= cluster_id)["Tables"]

[{'name': 'nyc_greentaxi', 'schema': 'taxischema', 'type': 'TABLE'}]

- Custom waiter

In [13]:
from botocore.exceptions import WaiterError
from botocore.waiter import WaiterModel
from botocore.waiter import create_waiter_with_client

In [14]:
# Create custom waiter for the Redshift Data API to wait for finish execution of current SQL statement
waiter_name = 'DataAPIExecution'

delay=2
max_attempts=3

#Configure the waiter settings
waiter_config = {
  'version': 2,
  'waiters': {
    'DataAPIExecution': {
      'operation': 'DescribeStatement',
      'delay': delay,
      'maxAttempts': max_attempts,
      'acceptors': [
        {
          "matcher": "path",
          "expected": "FINISHED",
          "argument": "Status",
          "state": "success"
        },
        {
          "matcher": "pathAny",
          "expected": ["PICKED","STARTED","SUBMITTED"],
          "argument": "Status",
          "state": "retry"
        },
        {
          "matcher": "pathAny",
          "expected": ["FAILED","ABORTED"],
          "argument": "Status",
          "state": "failure"
        }
      ],
    },
  },
}

In [15]:
waiter_model = WaiterModel(waiter_config)
custom_waiter = create_waiter_with_client(waiter_name, waiter_model, client_redshift)

- Run query

In [16]:
query_str = 'select sysdate;'
query_str = 'select top 3 * from taxischema.nyc_greentaxi;'

res = client_redshift.execute_statement(Database= db, SecretArn= secret_arn, Sql= query_str, ClusterIdentifier= cluster_id)
print("Redshift Data API execution  started ...")
id=res["Id"]

# Waiter in try block and wait for DATA API to return
try:
    custom_waiter.wait(Id=id)
    print("Done waiting to finish Data API.")
except WaiterError as e:
    print (e)
    
desc=client_redshift.describe_statement(Id=id)
print("Status: " + desc["Status"] + ". Excution time: %d miliseconds" %float(desc["Duration"]/pow(10,6)))

Redshift Data API execution  started ...
Done waiting to finish Data API.
Status: FINISHED. Excution time: 52 miliseconds


In [17]:
res

{'ClusterIdentifier': 'redshiftcluster-f0csvykgqnfu',
 'CreatedAt': datetime.datetime(2022, 7, 16, 1, 35, 8, 185000, tzinfo=tzlocal()),
 'Database': 'nyctaxi',
 'Id': '16e12ee0-15a0-44e2-bfc2-f922401cd5c7',
 'SecretArn': 'arn:aws:secretsmanager:us-east-1:308961792850:secret:nyctaxisecret-QHzY9o',
 'ResponseMetadata': {'RequestId': '290b7079-f67e-4874-957e-72a57f1ec4a6',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '290b7079-f67e-4874-957e-72a57f1ec4a6',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '234',
   'date': 'Sat, 16 Jul 2022 01:35:08 GMT'},
  'RetryAttempts': 0}}

In [18]:
output = client_redshift.get_statement_result(Id=id)

nrows=output["TotalNumRows"]
ncols=len(output["ColumnMetadata"])

print(nrows, ncols)

3 20


In [19]:
col_labels=[]
for i in range(ncols): col_labels.append(output["ColumnMetadata"][i]['label'])
col_labels

['vendorid',
 'lpep_pickup_datetime',
 'lpep_dropoff_datetime',
 'store_and_fwd_flag',
 'ratecodeid',
 'pulocationid',
 'dolocationid',
 'passenger_count',
 'trip_distance',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'ehail_fee',
 'improvement_surcharge',
 'total_amount',
 'payment_type',
 'trip_type',
 'congestion_surcharge']

In [20]:
records = output['Records']
records

[[{'stringValue': '1'},
  {'stringValue': '2019-04-01 00:04:42'},
  {'stringValue': '2019-04-01 00:16:50'},
  {'stringValue': 'N'},
  {'longValue': 1},
  {'longValue': 7},
  {'longValue': 129},
  {'longValue': 1},
  {'stringValue': '3.10'},
  {'stringValue': '12.00'},
  {'stringValue': '0.50'},
  {'stringValue': '0.50'},
  {'stringValue': '0.00'},
  {'stringValue': '0.00'},
  {'stringValue': ''},
  {'stringValue': '0.30'},
  {'stringValue': '13.30'},
  {'stringValue': '2'},
  {'stringValue': '1'},
  {'stringValue': '0.00'}],
 [{'stringValue': '1'},
  {'stringValue': '2019-04-01 00:40:52'},
  {'stringValue': '2019-04-01 01:09:59'},
  {'stringValue': 'N'},
  {'longValue': 1},
  {'longValue': 65},
  {'longValue': 239},
  {'longValue': 1},
  {'stringValue': '8.70'},
  {'stringValue': '28.00'},
  {'stringValue': '3.25'},
  {'stringValue': '0.50'},
  {'stringValue': '6.40'},
  {'stringValue': '0.00'},
  {'stringValue': ''},
  {'stringValue': '0.30'},
  {'stringValue': '38.45'},
  {'stringVal

In [21]:
contents=[]
for i in range(nrows): 
    content=[]
    for j in range(ncols):
        content.append(*records[i][j].values())
    contents.append(content)
contents

[['1',
  '2019-04-01 00:04:42',
  '2019-04-01 00:16:50',
  'N',
  1,
  7,
  129,
  1,
  '3.10',
  '12.00',
  '0.50',
  '0.50',
  '0.00',
  '0.00',
  '',
  '0.30',
  '13.30',
  '2',
  '1',
  '0.00'],
 ['1',
  '2019-04-01 00:40:52',
  '2019-04-01 01:09:59',
  'N',
  1,
  65,
  239,
  1,
  '8.70',
  '28.00',
  '3.25',
  '0.50',
  '6.40',
  '0.00',
  '',
  '0.30',
  '38.45',
  '1',
  '1',
  '2.75'],
 ['1',
  '2019-04-01 00:04:30',
  '2019-04-01 00:07:46',
  'N',
  1,
  223,
  223,
  1,
  '0.90',
  '5.00',
  '0.50',
  '0.50',
  '0.00',
  '0.00',
  '',
  '0.30',
  '6.30',
  '2',
  '1',
  '0.00']]

In [22]:
pd.DataFrame(contents, columns=col_labels)

Unnamed: 0,vendorid,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,ratecodeid,pulocationid,dolocationid,passenger_count,trip_distance,fare_amount,extra,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge
0,1,2019-04-01 00:04:42,2019-04-01 00:16:50,N,1,7,129,1,3.1,12.0,0.5,0.5,0.0,0.0,,0.3,13.3,2,1,0.0
1,1,2019-04-01 00:40:52,2019-04-01 01:09:59,N,1,65,239,1,8.7,28.0,3.25,0.5,6.4,0.0,,0.3,38.45,1,1,2.75
2,1,2019-04-01 00:04:30,2019-04-01 00:07:46,N,1,223,223,1,0.9,5.0,0.5,0.5,0.0,0.0,,0.3,6.3,2,1,0.0
