In [68]:
import sqlalchemy as sa
from sqlalchemy.engine.url import URL
import pandas as pd
import os
from enum import Enum
from typing import Optional

from config import dwh_connection
from data.load import DataLoader

In [8]:
# Redshift credentials are taken from the process environment so they never
# appear in the notebook itself; a variable that is not set yields None.
AWS_REDSHIFT_USERNAME = os.environ.get("AWS_REDSHIFT_USERNAME")
AWS_REDSHIFT_PASSWORD = os.environ.get("AWS_REDSHIFT_PASSWORD")

In [10]:
# Build the SQLAlchemy connection URL for the Redshift cluster.
url = URL.create(
    drivername="redshift+redshift_connector",
    # Credentials come from the environment variables read earlier.
    username=AWS_REDSHIFT_USERNAME,
    password=AWS_REDSHIFT_PASSWORD,
    # Cluster endpoint and target database.
    host="ml-cluster.cao3kphpeedo.us-east-1.redshift.amazonaws.com",
    port=5439,
    database="dev",
)
# Display the URL as the cell result (the rendered form masks the password).
url

redshift+redshift_connector://admin:***@ml-cluster.cao3kphpeedo.us-east-1.redshift.amazonaws.com:5439/dev

In [11]:
# Create the SQLAlchemy engine that the following cells use to open connections.
engine = sa.create_engine(url)

In [21]:
# Sanity check: show the concrete type returned by sa.create_engine.
engine.__class__

sqlalchemy.engine.base.Engine

In [17]:
# Read the complete public.applications table into a DataFrame. The connection
# is opened only for the duration of the query and closed by the `with` block.
# NOTE(review): there is no row limit here, so the whole table is transferred.
with engine.connect() as conn:
    df = pd.read_sql("SELECT * FROM public.applications", con=conn)

In [18]:
# Display the loaded frame (pandas truncates the rendering of large frames).
# NOTE(review): df.head() would keep the stored output smaller.
df

Unnamed: 0,sk_id_curr,target,name_contract_type,code_gender,flag_own_car,flag_own_realty,cnt_children,amt_income_total,amt_credit,amt_annuity,...,flag_document_18,flag_document_19,flag_document_20,flag_document_21,amt_req_credit_bureau_hour,amt_req_credit_bureau_day,amt_req_credit_bureau_week,amt_req_credit_bureau_mon,amt_req_credit_bureau_qrt,amt_req_credit_bureau_year
0,100003.0,0.0,Cash loans,F,N,N,0.0,270000.0,1293502.0,35698.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,100008.0,0.0,Cash loans,M,N,Y,0.0,99000.0,490495.0,27517.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0
2,100010.0,0.0,Cash loans,M,Y,Y,0.0,360000.0,1530000.0,42075.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,100011.0,0.0,Cash loans,F,N,Y,0.0,112500.0,1019610.0,33826.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
4,100014.0,0.0,Cash loans,F,N,Y,1.0,112500.0,652500.0,21177.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
307506,456242.0,0.0,Cash loans,M,Y,Y,0.0,198000.0,1312110.0,52168.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
307507,456245.0,0.0,Cash loans,F,N,Y,3.0,81000.0,269550.0,11871.0,...,0.0,0.0,0.0,0.0,,,,,,
307508,456246.0,0.0,Cash loans,F,N,Y,1.0,94500.0,225000.0,10620.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
307509,456251.0,0.0,Cash loans,M,N,N,0.0,157500.0,254700.0,27558.0,...,0.0,0.0,0.0,0.0,,,,,,


In [69]:
class DatasetTablename(Enum):
    """Logical dataset names (member names) paired with database table names (values)."""

    # NOTE(review): an earlier cell queries `public.applications` directly while
    # this member maps to "application_train" -- confirm which table is intended.
    APPLICATIONS = "application_train"
    BUREAU_BALANCE = "bureau_balance"
    BUREAU = "bureau"
    CREDIT_CARD_BALANCE = "credit_card_balance"
    INSTALLMENTS_PAYMENTS = "installments_payments"
    PREVIOUS_APPLICATIONS = "previous_application"
    CASH_BALANCE = "pos_cash_balance"

    @classmethod
    def from_name(cls, name: str) -> str:
        """Translate a dataset name into its table name.

        The lookup is case-insensitive: ``name`` is upper-cased and matched
        against the member names.

        Args:
            name: Dataset name, e.g. ``"bureau"`` or ``"Cash_Balance"``.

        Returns:
            The table name stored as the matching member's value.

        Raises:
            ValueError: If no member is called ``name.upper()``.
        """
        member = cls.__members__.get(name.upper())
        if member is None:
            raise ValueError(f"No such dataset: {name}")
        return member.value

In [70]:
class SQLDataLoader(DataLoader):
    """Data loader that reads datasets from the warehouse and caches them by name.

    Loaded frames are stored in ``self.datasets_`` (presumably initialised by
    ``DataLoader`` -- it is not defined in this class).
    """

    # Class-level attribute: evaluated once when the class body runs and shared
    # by every instance.
    engine = dwh_connection()

    def load_dataset(self, dataset_name: str, limit: Optional[int] = None, reload=False) -> pd.DataFrame:
        """Return a dataset as a DataFrame, querying the database only when needed.

        Args:
            dataset_name: Dataset name understood by ``DatasetTablename.from_name``
                (case-insensitive). It is also the cache key.
            limit: Maximum number of rows to fetch. ``None`` (the default) fetches
                the whole table.
            reload: When true, query the database again even if the dataset is
                already cached, and replace the cached frame.

        Returns:
            The cached frame, or the freshly loaded one.

        Raises:
            ValueError: If ``dataset_name`` is unknown or ``limit`` is negative.
            TypeError: If ``limit`` is not an integer.

        Note:
            The cache is keyed by ``dataset_name`` only. A frame cached with one
            ``limit`` is returned as-is for later calls with a different
            ``limit`` unless ``reload=True`` is passed.
        """
        import operator

        if dataset_name in self.datasets_ and not reload:
            return self.datasets_[dataset_name]

        # Resolve the table first so an unknown name fails before a connection
        # is opened.
        table_name = DatasetTablename.from_name(dataset_name)

        # No LIMIT clause at all when no limit is requested. The previous code
        # appended the bare word "limit" here, which produced invalid SQL for
        # the default call.
        limit_clause = ""
        if limit is not None:
            # operator.index accepts only true integers, so arbitrary text can
            # never reach the SQL string through this parameter.
            row_cap = operator.index(limit)
            if row_cap < 0:
                raise ValueError(f"limit must be non-negative, got {limit}")
            limit_clause = f" limit {row_cap}"

        query = f"SELECT * FROM public.{table_name}{limit_clause}"
        with self.engine.connect() as conn:
            df = pd.read_sql(query, con=conn)

        self.datasets_[dataset_name] = df
        return df

In [71]:
# Instantiate the SQL-backed loader and display the object as the cell result.
data_io = SQLDataLoader()
data_io

<__main__.SQLDataLoader at 0x7f094047bc40>

In [72]:
# Inspect the loader's cache of loaded datasets (keyed by dataset name).
data_io.datasets_

{}

In [73]:
# Show the engine the loader uses; it is the class attribute set from dwh_connection().
data_io.engine

Engine(redshift+redshift_connector://admin:***@ml-cluster.cao3kphpeedo.us-east-1.redshift.amazonaws.com:5439/dev)

In [74]:
# Fetch 3 rows of the 'cash_balance' dataset, which DatasetTablename maps to
# the pos_cash_balance table.
data_io.load_dataset('cash_balance', limit=3)

Unnamed: 0,sk_id_prev,sk_id_curr,months_balance,cnt_instalment,cnt_instalment_future,name_contract_status,sk_dpd,sk_dpd_def
0,1000001.0,158271.0,-8.0,2.0,0.0,Completed,0.0,0.0
1,1000001.0,158271.0,-9.0,12.0,11.0,Active,0.0,0.0
2,1000001.0,158271.0,-10.0,12.0,12.0,Active,0.0,0.0


In [58]:
# Presumably lists the names of the datasets currently cached; list_loaded is
# not defined in this notebook (it would come from DataLoader) -- TODO confirm.
data_io.list_loaded()

['bureau']

In [59]:
# 'bureau' is already cached and reload is left False, so load_dataset returns
# the cached frame and limit=7 has no effect on it.
data_io.load_dataset('bureau', limit=7)
# Indexing the loader by dataset name is provided outside this notebook
# (presumably by DataLoader); the shape still reflects the cached frame.
data_io['bureau'].shape

(3, 17)

In [62]:
# reload=True bypasses the cache: the table is queried again with limit=7 and
# the cached frame is replaced.
data_io.load_dataset('bureau', limit=7, reload=True)
data_io['bureau'].shape

(7, 17)