---
# LPL Code Library
---
This workbook contains all the Python code written in the course of the _2018 LPL Mantas Rationalization_ engagement. The purpose of the engagement was to improve LPL's Mantas AML transaction-monitoring system. Python was used heavily in two key workstreams:

1. **Data Profiling:** The engagement began with an exploratory analysis of the Mantas database. On its face, the purpose of this exploration was to identify potential data points to be used in _segmentation_, the process of splitting the population into groups based on AML risk and transactional behavior. However, the exercise also proved useful for familiarizing the team with a large, idiosyncratic new database and for providing readily available summaries of the hundreds of fields in it. In previous engagements, this process had involved manually writing and running many SQL queries; in this engagement it was automated with Python, greatly reducing the team's workload and widening the scope of the exercise.
2. **"Below the Line" Testing:** A later portion of the engagement involved prescribing _threshold sets_ for LPL's current AML scenarios. A "threshold set" describes the point beyond which a certain type of transactional behavior appears suspicious and should be investigated for potential financial crime. EY recommended separate thresholds for each of the segments described above, so that the transaction-monitoring system would be dynamic and would take a customer's underlying characteristics into account when deciding whether activity was suspicious. (For example, transactions that would look suspicious for an ordinary retail customer might not be suspicious for an ultra-high-net-worth client.) To confirm the validity of these thresholds, "below the line" (BTL) testing was conducted to assess whether there was a statistically significant amount of suspicious activity *beneath* the recommended thresholds. If there had been, the thresholds would have been lowered further. In this workstream, Python was used to automate the collection and aggregation of random samples large enough to show, with statistical confidence, that there was no such activity, thereby validating the recommended thresholds.

This workbook contains the code, scrubbed of client data, used to accomplish these two tasks, along with detailed commentary and walkthroughs of how the code works.

## Packages

In [1]:
import cx_Oracle #For connecting to the database
import pandas #For working with DataFrames and exporting to .csvs
import numpy as np #For numerical calculations
import time #For testing the speed/performance of our code
import re #For using regular expressions to match text
import functools #For memoization
import random #For setting the seed, extracting random samples
from datetime import timedelta #For dynamic lookback periods

from cx_Oracle import DatabaseError, InterfaceError, OperationalError
#pandas.tslib no longer exists in current pandas; the same exception is available from pandas.errors.
from pandas.errors import OutOfBoundsDatetime as OutOfBoundsDateTime

  


---
## Part 1: Database Interfacing
---

### Connecting to Oracle Database

All code in this engagement required querying an Oracle 11g database. To communicate with this type of database in Python, we used the [cx_Oracle library](https://oracle.github.io/python-cx_Oracle/). It is specific to Oracle databases, but because it follows Python's standard database API (PEP 249), it works much like the Python drivers for other databases such as MySQL or PostgreSQL.

These libraries work by creating a **cursor object** through which we pass queries and other statements. Creating a cursor object requires connecting to the database with the proper credentials. The syntax for these connections can be confusing, so we wrote a function that makes connecting simple for the user and handles basic errors.
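Because every PEP 249 driver shares the same connect, cursor, execute, fetch cycle, the pattern can be tried without an Oracle database. The sketch below uses the standard library's `sqlite3` module purely as a stand-in for `cx_Oracle` (an assumption made so the example runs anywhere); the `accounts` table and its columns are hypothetical.

```python
import sqlite3

# sqlite3 stands in for cx_Oracle here; both follow the DB-API (PEP 249).
con = sqlite3.connect(":memory:")
cur = con.cursor()

# Hypothetical table, used only for illustration.
cur.execute("CREATE TABLE accounts (acct_id TEXT, balance REAL)")
cur.executemany(
    "INSERT INTO accounts VALUES (?, ?)",
    [("A1", 100.0), ("A2", 2500.0), ("A3", None)],
)

# The same cursor object is reused for every statement we send.
cur.execute("SELECT acct_id, balance FROM accounts ORDER BY acct_id")
rows = cur.fetchall()  # a list of tuples, one per row
print(rows)            # [('A1', 100.0), ('A2', 2500.0), ('A3', None)]
con.close()
```

One difference to keep in mind: bind-variable placeholders vary by driver (`?` in sqlite3, `:name` or `:1` in cx_Oracle), even though the cursor methods are the same.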

In [None]:
def Connect(user = 'Insert your username here', 
            pw = 'Insert your password here', 
            host = 'Insert your hostname here', 
            port = 'Insert your port here', 
            db = 'Insert the name of the database here'):
    
    #cx_Oracle accepts a single "user/password@host:port/service" connection string.
    address = '%s/%s@%s:%s/%s' % (user, pw, host, port, db)
    
    #The code will attempt to connect as many as five times before failing.
    #Failures are routine and can happen randomly, but five in a row means something is wrong.
    for attempt in range(5):
        try:
            con = cx_Oracle.connect(address)
            print('Connection Successful!')
            return con.cursor()
        except DatabaseError:
            print('Connection Failed - Trying Again...')

    #Every attempt failed: raise a clear error rather than crashing later on an undefined connection.
    raise ConnectionError('Could not connect to the database after five attempts.')

The function takes five basic arguments (`user`, `pw`, `host`, `port`, and `db`), all of which are self-explanatory and will usually be provided by the client. We also found that connection attempts often fail for no particular reason, especially over a VPN, so the function tries to connect up to five times before giving up.
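The retry loop is a general pattern, so it can be factored into a reusable helper. The following is a minimal sketch, not part of the original library: `with_retries` and `flaky_connect` are hypothetical names, and the fake connection simply fails twice before succeeding. With a real database you would pass `exceptions=(cx_Oracle.DatabaseError,)` and perhaps a nonzero `delay`.

```python
import time


def with_retries(func, attempts=5, exceptions=(Exception,), delay=0.0):
    """Call func() up to `attempts` times, retrying only on the given exceptions.

    Hypothetical helper that generalizes the loop in Connect().
    """
    for attempt in range(1, attempts + 1):
        try:
            return func()
        except exceptions as err:
            print(f"Attempt {attempt} failed: {err!r}")
            if attempt == attempts:
                # Out of attempts: re-raise the last error instead of hiding it.
                raise
            time.sleep(delay)


# Usage with a fake connection that fails twice before succeeding.
calls = {"n": 0}


def flaky_connect():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("VPN hiccup")
    return "connected"


result = with_retries(flaky_connect, exceptions=(ConnectionError,))
print(result)  # connected
```

Re-raising the final exception (rather than returning `None`) means a persistent failure surfaces immediately with its original error message.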

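The function returns a **cursor object** (usually saved as `cur`, as in `cur = Connect(...)`), which becomes the backbone of the SQL query function below (which is, in turn, the backbone of all other code in this library). Later functions such as `Profile_Field` use this `cur` as their default cursor, so it must be created before those cells are run.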

### Sending Queries

Once we have a cursor object, we can send regular SQL queries to the database, just like we would in a regular SQL workbench ([such as Oracle SQL Developer, which we used in this engagement](https://www.oracle.com/database/technologies/appdev/sql-developer.html)) and then get their results in Python, in whichever form is most convenient.

Again, the syntax for doing this in `cx_Oracle` is somewhat verbose and cumbersome, so we wrote a more user-friendly function that simply takes the following arguments:

- **`query`**: A string containing the SQL query to pass to the database. (If it works in SQL Developer, it will work here too.)
- **`cur`**: A cursor object, produced by the `Connect()` function above.
- **`limit`**: The maximum number of records the query returns. The default is 20, to avoid loading millions of rows into memory; pass `False` to return every row. (Many SQL workbenches effectively have a built-in limit, which is why they seem to handle large queries so quickly: if you run `SELECT *` against a big table, they don't return all 10M+ rows at once but fetch them as needed. Our function collects results into a Python list, so without a limit every row would be pulled across the network and held in memory.)
- **`output`**: Tells the function what to return. There are three possible values: `'standard'`, `'df'`, and `'metadata'`. Respectively, they return a list of the result rows (each row a tuple), a pandas DataFrame of the results, and the _metadata_ of the results (i.e. the column names, datatypes, and other information).
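The three output modes and the `limit` behavior can be illustrated with a standard-library-only sketch. It again uses `sqlite3` in place of Oracle, and a list of dictionaries stands in for the pandas DataFrame so there is no third-party dependency; `sql_query_sketch` and the `trxn` table are hypothetical. It relies only on the DB-API members `fetchmany`, `fetchall`, and `description`, which cx_Oracle cursors also provide.

```python
import sqlite3


def sql_query_sketch(query, cur, limit=20, output="standard"):
    """Stdlib-only sketch of the three output modes.

    Assumption: a list of dicts stands in for the pandas DataFrame.
    """
    if output not in ("standard", "df", "metadata"):
        raise ValueError("Try 'standard', 'df', or 'metadata'.")
    cur.execute(query)
    if output == "metadata":
        # One tuple per column; the first element is the column name.
        return cur.description
    # fetchmany() pulls at most `limit` rows; limit=False means fetch everything.
    rows = cur.fetchall() if limit is False else cur.fetchmany(limit)
    if output == "standard":
        return rows
    colnames = [col[0] for col in cur.description]
    return [dict(zip(colnames, row)) for row in rows]


con = sqlite3.connect(":memory:")
cur = con.cursor()
cur.execute("CREATE TABLE trxn (id INTEGER, amt REAL)")
cur.executemany("INSERT INTO trxn VALUES (?, ?)", [(i, i * 10.0) for i in range(100)])

first_rows = sql_query_sketch("SELECT id, amt FROM trxn ORDER BY id", cur, limit=3)
all_rows = sql_query_sketch("SELECT id, amt FROM trxn", cur, limit=False)
records = sql_query_sketch("SELECT id, amt FROM trxn ORDER BY id", cur, limit=2, output="df")
meta = sql_query_sketch("SELECT id, amt FROM trxn", cur, output="metadata")
print(first_rows)  # [(0, 0.0), (1, 10.0), (2, 20.0)]
```

Using `fetchmany` lets the driver stop after `limit` rows instead of materializing the whole result set first.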

**Note:** Depending on what it's being used for, you may want to [**memoize**](https://en.wikipedia.org/wiki/Memoization) this function. A memoized function remembers the results of previous calls: if you call it with arguments it has already seen, it returns the stored result instead of redoing all the work needed to compute it.

On one hand, this can save a lot of time while you're testing or completing repetitive tasks. During the engagement, SQL queries were slow to run, and transferring their results across the VPN was slower still. So, when we could be certain that the underlying data would not change, we often memoized the function.

On the other hand, if you're working with data that you know will change, it doesn't make sense to memoize the function.
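A minimal sketch of what memoization does, using `functools.lru_cache` (the same decorator that appears commented out on `SQL_Query`). The function `slow_lookup` is a hypothetical stand-in for an expensive query that counts how often its body really runs. `cache_clear()` is how you discard stale results if the data does change. Note that `lru_cache` requires hashable arguments: strings work, but passing a list raises `TypeError`.

```python
import functools

call_count = 0


@functools.lru_cache(maxsize=None)
def slow_lookup(query):
    """Stand-in for an expensive query; counts how often it really runs."""
    global call_count
    call_count += 1
    return len(query)  # pretend this is the query result


slow_lookup("SELECT COUNT(*) FROM T")
slow_lookup("SELECT COUNT(*) FROM T")  # served from the cache
print(call_count)                      # 1
print(slow_lookup.cache_info().hits)   # 1

# If the underlying data changes, clear the cache so the next call re-runs.
slow_lookup.cache_clear()
slow_lookup("SELECT COUNT(*) FROM T")
print(call_count)                      # 2
```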

In [None]:
#Un-comment this line if you want to memoize the function.
#@functools.lru_cache(maxsize=None)
def SQL_Query(query, cur, limit = 20, output = 'standard'):
    
    #First, we'll make sure that we got a valid argument for the 'output' variable.
    #If we haven't, we'll raise an error to guide the user to a correct input.
    if output not in ('df', 'standard', 'metadata'):
        raise ValueError("This is not a recognized output - try 'standard', 'df', or 'metadata'.")
    
    #Next, we'll try to execute the query up to five times.
    #(Again, failures are random and routine, but five in a row means something is wrong.)
    for attempt in range(5):
        try:
            cur.execute(query)
            break

        #If the connection has dropped, which happens often, we'll just reconnect to the database.
        #(These are the two types of errors that occur when the connection has dropped.)
        except (InterfaceError, OperationalError):
            print('Query Failed - Reconnecting')
            cur = Connect()
    else:
        #The loop never reached 'break', so every attempt failed.
        raise ConnectionError('Query failed after five attempts.')
    
    #The cursor's description holds one tuple per column; the first element is the column name.
    descriptions = cur.description
    colnames = [field[0] for field in descriptions]
    
    #If we're returning only the metadata, all we need is the descriptions variable.
    if output == 'metadata':
        return descriptions
    
    #If we're returning the actual results of the query, we now need to respect the limit variable.
    #(Cursors can't be sliced, so we use the standard DB-API fetch methods instead.)
    if limit is False:
        #No limit: capture all the results
        results = cur.fetchall()
    else:
        #Otherwise, capture only as many rows as instructed
        results = cur.fetchmany(limit)
    
    #The 'standard' output is just a list of row tuples...
    if output == 'standard':
        return results

    #...whereas the 'df' output turns the rows into a neat DataFrame, complete with column names.
    #(Passing the columns up front also works when the query returns no rows.)
    return pandas.DataFrame(results, columns=colnames)

---
## Part 2: Data Profiling
---

The EY data profiling methodology measures three things: completeness, correctness, and class differentiation. The idea is to confirm that no data is missing and that the data is correct, and then to explore how the data is differentiated, which is done differently depending on the type and distribution of the data. The functions below automate the measurement of completeness and class differentiation. (Measuring correctness is more complicated and can require cross-referencing source systems or testing individual data points.)

Measuring **completeness** is simple: the function calculates what proportion of the data is `NULL`. Measuring **class differentiation** is more complicated and depends on the data. The algorithm we wrote uses the following logic:
- If the field has only a few distinct values (of any datatype), show what percentage of records has each value
- Otherwise, if the data is numerical, compute summary statistics
- Otherwise, if the data is a datetime, compute a date range
- Otherwise (e.g. a string with many values), just report the number of unique values
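The branching above can be written as a small pure function, which makes the rules easy to test without a database. This is a sketch under our own assumptions: datatypes are plain strings rather than cx_Oracle type objects, and `choose_profile_action` is a hypothetical name, not part of the library. Its thresholds mirror the `distinction_threshold` (0.3) and `max_distinct` (100) defaults of `Profile_Field`.

```python
def choose_profile_action(dtype, distinct_values, row_count,
                          distinction_threshold=0.3, max_distinct=100):
    """Mirror the class-differentiation rules as a pure function.

    `dtype` is a plain string ("number", "datetime", "string"), an assumption
    that replaces the cx_Oracle type objects used in the real code.
    """
    if row_count == 0 or distinct_values == 0:
        # COUNT(DISTINCT ...) ignores NULLs, so 0 distinct means all NULL.
        return "all null or empty: nothing to profile"
    too_many = (distinct_values / row_count > distinction_threshold
                or distinct_values > max_distinct)
    if not too_many:
        return "value breakdown"  # share of rows per distinct value
    if dtype == "number":
        return "summary statistics"
    if dtype == "datetime":
        return "date range"
    return "count of unique values only"


print(choose_profile_action("string", 4, 1_000_000))       # value breakdown
print(choose_profile_action("number", 50_000, 1_000_000))  # summary statistics
print(choose_profile_action("number", 3, 1_000))           # value breakdown
```

Note that a field is "too differentiated" if it trips *either* test, so a small table with many distinct values relative to its size is handled the same way as a huge table with more than 100 distinct values.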

### Profile Field
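The first function conducts this analysis on a single field in a single table of a database. It takes the following arguments:
- **`field`**: The name of the field we want to profile
- **`table`**: The name of the table containing that field 
- **`schema`**: The name of the schema containing that table (defaults to 'BUSINESS', the schema most commonly used in the engagement)
- **`distinction_threshold`** and **`max_distinct`**: The cutoffs for "too many" distinct values. A field is treated as too differentiated for a value breakdown if its distinct values make up more than `distinction_threshold` (default 0.3) of all rows, or number more than `max_distinct` (default 100)
- **`cur`**: A cursor object, which defaults to the predefined `cur` returned by the `Connect()` function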

(This function stayed memoized, as the underlying dataset was intentionally left static.)

In [None]:
@functools.lru_cache(maxsize=None)
def Profile_Field(field, 
                  table, 
                  schema='BUSINESS', 
                  distinction_threshold = 0.3, 
                  max_distinct = 100, 
                  cur = cur):
    
    print('Computing Class Diff for %s' % field)
    
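    #Step 1: Establish the size of the table
    #(Oracle only treats a comment as an optimizer hint if it starts with '/*+'; '/*parallel...*/' is ignored.)
    print('Establishing Size for %s' % field)
    row_count_query = 'SELECT /*+ parallel(12) */ COUNT(*) FROM %s.%s' % (schema, table)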
    row_count_results = SQL_Query(row_count_query, cur=cur)
    row_count = row_count_results[0][0]
    
    #Step 2: Find the datatype
    print('Finding MetaData for %s' % field)
    #WHERE 1 = 0 fetches no rows, but the cursor still describes the column's datatype.
    dtype_query = 'SELECT %s FROM %s.%s WHERE 1 = 0' % (field, schema, table)
    dtype_results = SQL_Query(dtype_query, cur=cur, output='metadata') 
    dtype = dtype_results[0][1]
    
    #Step 3: Find NULL Values
    print('Finding NULL for %s' % field)
    null_query = '''
        SELECT /*+ parallel(12) */
            AVG(NULL_COUNT) AS PERCENT_NULL,
            SUM(NULL_COUNT) AS NULL_COUNT,
            1 - AVG(NULL_COUNT) AS PERCENT_COMPLETE
        FROM (
        SELECT 
            CASE WHEN %s IS NULL THEN 1 ELSE 0 END AS NULL_COUNT
            FROM %s.%s
        )
        ''' % (field, schema, table)

    null_query_result = SQL_Query(null_query, cur=cur)
    
    null_dict = {
        'Table' : table,
        'Field' : field,
        'Percent_Null' : null_query_result[0][0],
        'Null_Count' : null_query_result[0][1],
        'Percent_Complete' : null_query_result[0][2]
    }
    
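    #Nothing more to profile if the table is empty or the field is entirely NULL.
    #(Checking for an empty table also avoids dividing by a zero row count in Step 5.)
    if row_count == 0 or null_query_result[0][0] == 1: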
        results_dict = {
                'Table': table,
                'Field': field,
                'Dtype': str(dtype),
                'Unique_Values' : 0
        }
        
        return (dict(results_dict, **null_dict))
   
    #Step 4: Find the number of distinct values, unless it's all null
    print('Finding Distinct Values %s' % field)
    n_distinct_query = 'SELECT /*+ parallel(12) */ COUNT(DISTINCT(%s)) FROM %s.%s' % (field, schema, table)
    n_distinct_results = SQL_Query(n_distinct_query, cur=cur)
    distinct_values = n_distinct_results[0][0]
    
    
    #Step 5: Determine a course of action based on the distinct values, total values, and datatype
    print('Assessing Class Diff for %s' % field)
    
    #First, we'll determine whether or not the field is too differentiated to compute a result (regardless of type).
    #We'll use the previously defined DISTINCTION THRESHOLD and the MAX DISTINCT inputs.
    
    if (distinct_values / row_count) > distinction_threshold or distinct_values > max_distinct:
        
        #If there are too many values, our response will depend on the datatype of the field:
        
        #If the class is numeric, we want to compute summary statistics:
        if dtype == cx_Oracle.NUMBER:
            print('Finding Summary Stats for %s' % field)
            summary_stats_query = \
                    '''
                    SELECT /*+ parallel(12) */
                    MAX(%s) AS Max_Value, 
                    MIN(%s) AS Min_Value,
                    AVG(%s) AS Mean,
                    MEDIAN(%s) AS Median,
                    STDDEV(%s) AS Standard_Deviation,
                    PERCENTILE_CONT(0.85) WITHIN GROUP (ORDER BY %s) "P85",
                    PERCENTILE_CONT(0.977) WITHIN GROUP (ORDER BY %s) "P977"
                    FROM %s.%s''' % (field, field, field, field, field, field, field, schema, table)
            
            #We used to index this manually, but actually making it into a DataFrame is more scalable.
            summary_stats_results = SQL_Query(summary_stats_query, cur=cur, output='df')
            summary_stats_dict = summary_stats_results.to_dict(orient='records')
            
            results_dict = {
                'Table': table,
                'Field': field,
                'Dtype': str(dtype),
                'Unique_Values' : distinct_values,
                'Summary_Stats' : summary_stats_dict
            }
            
        #If the class is a date, we'll do something similar, except we'll only calculate the minimum and maximum date:
        elif dtype == cx_Oracle.DATETIME:
            print('Finding DateRange for %s' % field)
            try:
                date_range_query = \
                        '''
                        SELECT /*+ parallel(12) */
                        MAX(%s) AS Latest_Date, 
                        MIN(%s) AS First_Date
                        FROM %s.%s''' % (field, field, schema, table)

                date_range_results = SQL_Query(date_range_query, cur=cur, output='df')
                date_range_dict = date_range_results.to_dict(orient='records')

                results_dict = {
                    'Table': table,
                    'Field': field,
                    'Dtype': str(dtype),
                    'Unique_Values' : distinct_values,
                    'Date_Range' : date_range_dict
                }
            
            #It's possible that the datetime will be out of Pandas' interpretable range, in which case:
            except OutOfBoundsDateTime:
                print('Found Out-of-Bounds Date Range for %s' % field)
                results_dict = {
                    'Table': table,
                    'Field': field,
                    'Dtype': str(dtype),
                    'Unique_Values' : distinct_values,
                    'Date_Range' : 'Error - Date Out of Range'
                }
        
        #If the class is neither a date nor a number, however, we don't want to compute anything at all:
        else:
            print('Nothing to Compute for %s' % field)
            results_dict = {
                'Table': table,
                'Field': field,
                'Dtype': str(dtype),
                'Unique_Values' : distinct_values,
                'Class_Diff' : 'Not Applicable'
            }
    
    #If there were NOT too many values, we want to now show class differentiation regardless of datatype
    else:
        print('Finding Value Breakdown for %s' % field)
        #(An earlier version divided by a scalar subquery, COUNT(*) / (SELECT COUNT(*) FROM schema.table);
        #this version computes the total once in a WITH clause and joins it in.)
        value_breakdown_query = \
        '''
            WITH COUNT_RECORD AS (
                SELECT /*+ parallel(12)*/ 
                    COUNT(*) num
                FROM
                    %s.%s
            ) SELECT /*+ parallel(12)*/ 
                %s,
                COUNT(*) AS NUM_RECORDS,
                COUNT(*) / max(COUNT_RECORD.num) AS PERCENT_TOTAL
            FROM
                %s.%s,
                COUNT_RECORD
            WHERE
                1 = 1
            GROUP BY
                 %s
        ''' % (schema, table, field, schema, table, field)
        
        try:
            value_breakdown_results = SQL_Query(value_breakdown_query, cur=cur, limit = False, output = 'df')
            value_breakdown_dict = value_breakdown_results.to_dict(orient='records')
            #(See the pandas documentation for more on the 'orient' argument and its other possible values.)
            
        except (TypeError, ValueError) as e:
            value_breakdown_dict = 'Error - NoneType'
        
        #Not a perfect representation of the data, but as good as we can hope for:  
        
        results_dict = {
                'Table': table,
                'Field': field,
                'Dtype': str(dtype),
                'Unique_Values' : distinct_values,
                'Value_Breakdown' : value_breakdown_dict
        }
        
    return (dict(results_dict, **null_dict))
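The completeness numbers from Step 3 come from averaging a 0/1 NULL indicator, so `Percent_Null` is a fraction between 0 and 1 rather than a percentage. The same query shape can be tried against an in-memory SQLite table. In this sketch SQLite stands in for Oracle, the `cust` table and `birth_dt` column are hypothetical, and the inner column is renamed `IS_NULL` for readability.

```python
import sqlite3

con = sqlite3.connect(":memory:")
cur = con.cursor()
# Hypothetical customer table with some missing dates of birth.
cur.execute("CREATE TABLE cust (cust_id TEXT, birth_dt TEXT)")
cur.executemany(
    "INSERT INTO cust VALUES (?, ?)",
    [("C1", "1980-01-01"), ("C2", None), ("C3", "1975-06-30"), ("C4", None)],
)

# Same shape as the Step 3 query: average a 0/1 null indicator.
null_query = """
    SELECT AVG(IS_NULL)     AS PERCENT_NULL,
           SUM(IS_NULL)     AS NULL_COUNT,
           1 - AVG(IS_NULL) AS PERCENT_COMPLETE
    FROM (SELECT CASE WHEN birth_dt IS NULL THEN 1 ELSE 0 END AS IS_NULL
          FROM cust)
"""
percent_null, null_count, percent_complete = cur.execute(null_query).fetchone()
print(percent_null, null_count, percent_complete)  # 0.5 2 0.5

# COUNT(DISTINCT ...) ignores NULLs, so an all-NULL column reports 0 distinct values.
distinct = cur.execute("SELECT COUNT(DISTINCT birth_dt) FROM cust").fetchone()[0]
print(distinct)  # 2
con.close()
```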

### Profile Table
The next function profiles an entire _table_, running `Profile_Field` on each field within the table and collecting the results into a list of dictionaries (one per field). It takes the `table`, `schema`, and `cur` arguments, which work the same way as in `Profile_Field`; the distinction thresholds are left at their defaults.
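`Profile_Table` discovers the column names from cursor metadata rather than from the data itself. A query with `WHERE 1 = 0` is one way to get that metadata without asking the database for any rows. A minimal sketch with `sqlite3` standing in for Oracle; `list_columns` and the `acct` table are hypothetical, and, like the original code, it builds SQL with string formatting, so it should only be used with trusted table names.

```python
import sqlite3


def list_columns(cur, table):
    """Return the column names of `table` from cursor metadata alone.

    WHERE 1 = 0 matches no rows, but the cursor still describes the columns.
    """
    cur.execute(f"SELECT * FROM {table} WHERE 1 = 0")
    return [col[0] for col in cur.description]


con = sqlite3.connect(":memory:")
cur = con.cursor()
cur.execute("CREATE TABLE acct (acct_id TEXT, open_dt TEXT, balance REAL)")

columns = list_columns(cur, "acct")
print(columns)  # ['acct_id', 'open_dt', 'balance']
con.close()
```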

In [None]:
def Profile_Table(table, schema='BUSINESS', cur=cur):
    #Don't forget to record performance:
    start_time = time.perf_counter()
    
    #WHERE 1 = 0 returns no rows, but the cursor metadata still lists every column.
    fields_query = '''SELECT * FROM %s.%s WHERE 1 = 0''' % (schema, table)
    fields_query_results = SQL_Query(fields_query, cur=cur, output='metadata')
    col_names = [result[0] for result in fields_query_results]
    
    full_profile_list = [Profile_Field(column, table=table, schema=schema, cur=cur) for column in col_names]
    
    elapsed = time.perf_counter() - start_time
    print("This took %.1f seconds." % elapsed)
    
    return full_profile_list

### Send Profile to CSV
The last function sends the table's full profile to CSVs. Given the way the data profiling report was organized, there were four CSVs generated for each table:

- A "full profile" which contained basic data on _all_ fields in the table, such as its type, number of unique values, and proportion of null values
- A "summary stats" CSV containing the summary stats computed for all numeric fields
- A "date range" CSV containing the ranges for all datetime fields
- A "class differentiation" CSV containing the breakdowns for categorical fields
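The trickiest of these is the class-differentiation CSV: each value-breakdown record from `Profile_Field` is keyed by the field's own name (alongside `NUM_RECORDS` and `PERCENT_TOTAL`), so the key differs for every field. A standard-library sketch of the flattening step follows; the field name `ACCT_TYPE_CD` and its values are hypothetical, and an in-memory buffer stands in for the CSV file.

```python
import csv
import io

# Hypothetical profile output for one categorical field, shaped like the
# 'Value_Breakdown' records: each is keyed by the field's own name.
profile_output = [
    {"Table": "ACCT", "Field": "ACCT_TYPE_CD",
     "Value_Breakdown": [
         {"ACCT_TYPE_CD": "IRA", "NUM_RECORDS": 30, "PERCENT_TOTAL": 0.3},
         {"ACCT_TYPE_CD": "BRK", "NUM_RECORDS": 70, "PERCENT_TOTAL": 0.7},
     ]},
]

rows = []
for field in profile_output:
    for record in field.get("Value_Breakdown", []):
        row = {"Table": field["Table"], "Field": field["Field"], **record}
        # Rename the field-specific key to a generic 'Value' column.
        row["Value"] = row.pop(field["Field"])
        rows.append(row)

# Sort by field, then most common values first.
rows.sort(key=lambda r: (r["Field"], -r["NUM_RECORDS"]))

buffer = io.StringIO()  # stands in for '<table>_Value_Breakdowns.csv'
fieldnames = ["Table", "Field", "Value", "NUM_RECORDS", "PERCENT_TOTAL"]
writer = csv.DictWriter(buffer, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
print(buffer.getvalue())
```

Renaming to a shared `Value` column is what lets breakdowns from many different fields be stacked into one table.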

In [2]:
def Profile_to_CSV(profile_output, table_name = 'Table'):
    
    def DType_RegEx(oracle_dtype):
        #Turns "<class 'cx_Oracle.NUMBER'>" into "Number".
        #(Newer cx_Oracle versions print types differently, so fall back to the raw string if there's no match.)
        match = re.search(r"<class 'cx_Oracle\.(\w+)'>", oracle_dtype)
        return match.group(1).title() if match else oracle_dtype
    
    #Step 1: Output the "Full Profile"/Master CSV
    profile_df = pandas.DataFrame(profile_output)
    #Take only the relevant fields (copying avoids pandas' SettingWithCopyWarning below)
    profile_df = profile_df[['Table', 'Field', 'Dtype', 'Unique_Values', 'Null_Count', 'Percent_Complete']].copy()
    #Clean the "Dtype" field using the RegEx function
    profile_df['Dtype'] = profile_df['Dtype'].apply(DType_RegEx)
    #Write to CSV
    profile_df.to_csv('%s_Full_Profile.csv' % table_name)
    
    summary_stats_list = []
    date_range_list = []
    value_breakdown_list = []
    
    for field in profile_output:
        try:
            #We need a lightweight dictionary to merge with subsequent dictionaries
            mini_dict = {'Table': field['Table'], 'Field' : field['Field']}

            if 'Summary_Stats' in field:
                adjusted_stats_dict = dict(mini_dict, **field['Summary_Stats'][0])
                summary_stats_list.append(adjusted_stats_dict)

            elif 'Date_Range' in field:
                if field['Date_Range'] == 'Error - Date Out of Range':
                    print('Date Error on Field: %s' % field['Field'])

                else:
                    adjusted_range_dict = dict(mini_dict, **field['Date_Range'][0])
                    date_range_list.append(adjusted_range_dict)

            elif 'Value_Breakdown' in field and field['Percent_Null'] < 1:
                #Remember, the values are dicts, but the keys are different every time!
                for value in field['Value_Breakdown']:
                    adjusted_value_dict = dict(mini_dict, **value)
                    #Each record is keyed by the field's own name, so rename that key to a generic 'Value' column.
                    adjusted_value_dict['Value'] = adjusted_value_dict.pop(field['Field'])
                    value_breakdown_list.append(adjusted_value_dict)
        except TypeError:
            print('Error on %s' % field)
            
    pandas.DataFrame(summary_stats_list).to_csv('%s_Summary_Stats.csv' % table_name)
    pandas.DataFrame(date_range_list).to_csv('%s_Date_Ranges.csv' % table_name)
    #Don't forget to sort the value breakdown (most common values first within each field).
    #(An empty list has no columns to sort by, so skip the file in that case.)
    if value_breakdown_list:
        pandas.DataFrame(value_breakdown_list).sort_values(by=['Field', 'NUM_RECORDS'],
                                ascending=[True, False]).to_csv('%s_Value_Breakdowns.csv' % table_name)

---
## Part 3: BTL Sampling
---

The last portion of the code exports below-the-line samples for AML scenarios. Note that this is the least generalizable part of the code. Everything up to this point would work on any Oracle database in any context and isn't specific to AML or even financial services, but this final function is tailored to the particular structure of the tables and queries created during this engagement. It won't work as-is for other engagements, but its approach is still useful.

The algorithm takes the following arguments:
- **`scenario_name`**: The name of the scenario we're testing. (This is purely for labeling purposes and doesn't affect the code.)
- **`alerts_table`** and **`schema`**: The table holding the scenario's alert data, and the schema that contains it.
- **`lookback`**: The "lookback period" of the scenario in days (i.e. the timeframe over which it monitors customers). The function needs this so that it can export transactions from the relevant timeframe.
- **`scenario_focus`**: The "focus" of the scenario. (Some scenarios monitor individual accounts, while others monitor customers or entire households.) The function currently accepts `'ACCT'` (account) and `'HH'` (household), and every query in the function depends on this variable.
- **`alert_field`** and **`rundate_field`**: The names of the fields in the alerts table that hold the alert count and the run date. Their spelling varied between tables, so the function lets you specify them.
- **`seed`**: Since this function involves _random_ sampling, the seed is configurable so that results can be replicated. (If you don't set the seed, you get a different sample every time.)
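To see why the seed matters, here is a sketch of drawing a reproducible sample, without replacement, from each stratum. It is illustrative only: `sample_strata`, the alert IDs, and the sample sizes are hypothetical, and it uses a private `random.Random(seed)` generator instead of seeding the global `random` module. A private generator keeps other code that calls `random` from shifting the draws.

```python
import random


def sample_strata(population_by_stratum, sizes, seed=1234):
    """Draw a reproducible simple random sample (without replacement) per stratum.

    population_by_stratum: dict of stratum name -> list of alert IDs
    sizes: dict of stratum name -> number of alerts to draw
    Both inputs are hypothetical stand-ins for the BTL alert DataFrames.
    """
    rng = random.Random(seed)  # private generator: same seed, same draws
    return {
        stratum: sorted(rng.sample(ids, sizes[stratum]))
        for stratum, ids in population_by_stratum.items()
    }


population = {
    "SIG": [f"S{i:03d}" for i in range(200)],
    "NON-SIG": [f"N{i:04d}" for i in range(5000)],
}
sizes = {"SIG": 40, "NON-SIG": 50}  # illustrative sizes only

first = sample_strata(population, sizes, seed=1234)
second = sample_strata(population, sizes, seed=1234)
print(first == second)  # True: the same seed reproduces the sample
```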

Given these arguments, the function works like this:
1. Find all of the "below-the-line" alerts (rows where `alert_field` is 0) in the relevant scenario over the past year, split into "SIG" and "NON-SIG" strata according to the segmentation table.
2. For each stratum, compute how many alerts we'd need to investigate to determine with statistical confidence that the rate of suspicious activity below the line is acceptably low. The formula is the standard sample size for estimating a proportion, with a finite-population correction for sampling without replacement (i.e. a normal approximation to the hypergeometric distribution). It rests on several assumptions: the expected rate of suspicious activity, the confidence level, and the margin of error.
3. Take a random sample of that size, without replacement.
4. For all the accounts or households in that sample, query the relevant customer data and all transactions within the lookback period ending on each alert's run date.
5. Export all that data to CSVs.
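The sample-size step can be checked by hand. The formula in `required_size` below is `n = N*Z**2*p*q / (ME**2*(N-1) + Z**2*p*q)`, which is what you get by solving `ME = Z*sqrt(p*q/n * (N-n)/(N-1))` for `n`; the `(N-n)/(N-1)` factor is the variance correction for sampling without replacement. This sketch reuses the code's defaults (5% margin of error, 5% expected rate, Z = 1.64, roughly the 95% one-sided critical value) and rounds up with `math.ceil`, which is our own choice: rounding up guarantees the sample is never smaller than the formula requires.

```python
import math


def required_size(N, ME=0.05, p=0.05, Z=1.64):
    """Sample size for estimating a proportion p within +/- ME at critical value Z,
    with a finite-population correction for a population of N alerts.
    """
    q = 1 - p
    return (N * Z**2 * p * q) / (ME**2 * (N - 1) + Z**2 * p * q)


for N in (100, 1_000, 10_000, 1_000_000):
    n = required_size(N)
    print(N, round(n, 2), math.ceil(n))

# As N grows, n approaches the infinite-population value Z**2 * p * q / ME**2.
print(round(1.64**2 * 0.05 * 0.95 / 0.05**2, 2))  # 51.1
```

The finite-population correction matters most for small strata: with these defaults a stratum of 100 alerts needs 35 samples, while the required size levels off near 51 no matter how large the population gets.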

In [3]:
def BTL_Sample(scenario_name, 
               alerts_table, 
               lookback, 
               scenario_focus, 
               schema, 
               alert_field = 'ALERT_COUNT', 
               rundate_field = 'RUNDATE',
               seed = 1234):

    start_time = time.gmtime()
    
    #Step 1: Determine Population Size of Strata (Using SIG vs. Non-SIG as Strata)
    if scenario_focus not in ('HH','ACCT'):
        raise ValueError('HH and ACCT are the only valid scenario foci.')

    #From the scenario focus, we can infer a few important parts of the query, so that it is made dynamic.
    focus_variables = {
        'HH' : {
            'segmentation_table' : 'BUSINESS.EY_TRXN_SIG_HOUSEHOLD',
            'primary_key' : 'ACCT_GRP_ID'
        },
        'ACCT' : {
            'segmentation_table' : 'BUSINESS.EY_TRXN_SIG_MODEL2',
            'primary_key' : 'ACCT_INTRL_ID'
        }
    }
    
    print('Variables defined.')
    
    strata_pop_query = '''
                        SELECT 
                            CASE WHEN EY.SIG IS NOT NULL THEN 'SIG' ELSE 'NON-SIG' END AS STRATA,
                            COUNT(*) AS FULL_POP
                        FROM
                            {schema}.{alerts_table} AL INNER JOIN
                            {segmentation_table} EY ON AL.{primary_key} = EY.{primary_key}
                        WHERE
                            {alert_field} = 0
                        GROUP BY
                            CASE WHEN EY.SIG IS NOT NULL THEN 'SIG' ELSE 'NON-SIG' END
                        '''.format(schema = schema, 
                                   alerts_table = alerts_table, 
                                   alert_field = alert_field,
                                   primary_key = focus_variables[scenario_focus]['primary_key'],
                                   segmentation_table = focus_variables[scenario_focus]['segmentation_table'],
                                  )
    
    print('Querying strata populations.')

    strata_pop = SQL_Query(cur = cur,
                           query = strata_pop_query,
                           limit = False,
                           output = 'standard'
                          )
    
    #print(strata_pop)
    
    for result in strata_pop:
        print(result[0] + ' - ' + str(result[1]))
    
    print('Results saved.')

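    #Step 3: Sample-size formula for estimating a proportion, with a finite-population correction
    #(a normal approximation to sampling without replacement, i.e. the hypergeometric distribution).
    #ME = margin of error, p = expected rate of suspicious activity,
    #Z = critical value (1.64 is roughly the 95% one-sided / 90% two-sided value).
    def required_size(N, ME=0.05, p=0.05, Z=1.64):
        q = 1 - p
        n = (N * (Z ** 2) * p * q) / ((ME ** 2) * (N - 1) + (Z ** 2) * p * q)
        return n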
    
    #Step 3.5: Based on those populations, determine the necessary sample size 
    #(We'll make this a loop, since we may want to change our definition of Strata in the future)
    samples_master = []

    for result in strata_pop:
        temp = {
            'strata' : result[0],
            'required' : required_size(result[1]),
            'total_pop' : result[1]
        }
        
        print(temp['strata'] + ' needs ' + str(round(temp['required'],2)) + ' samples')
 
        samples_master.append(temp)
        
    #Step 4: For each strata, extract a random set of BTL alerts
    #(Set seed, so that results are replicable)
    random.seed(seed)

    print('Random seed set at ' + str(seed) + '.')
    
    for strata in samples_master:
        print('Grabbing samples for ' + strata['strata'] + '.')
        #First, query the full BTL population of each strata:
        strata_btl_query = '''
        SELECT 
            AL.*
        FROM
            {schema}.{alerts_table} AL INNER JOIN
            {segmentation_table} EY ON AL.{primary_key} = EY.{primary_key}
        WHERE
            AL.{alert_field} = 0 AND
            CASE WHEN EY.SIG IS NOT NULL THEN 'SIG' ELSE 'NON-SIG' END = '{strata}'
        '''.format(schema = schema,
                   alerts_table = alerts_table,
                   alert_field = alert_field,
                   strata = strata['strata'],
                   primary_key = focus_variables[scenario_focus]['primary_key'],
                   segmentation_table = focus_variables[scenario_focus]['segmentation_table']
                  )

        full_BTL_strata_pop = SQL_Query(cur = cur,
                                        query = strata_btl_query,
                                        limit = False,
                                        output = 'df'
                                       )

        #Then, we draw random row positions:
        #We want X distinct positions between 0 and Y - 1, where...
        #X --> The required sample size (rounded up, so we never sample fewer than the formula requires)
        #Y --> The population size
        sample_size = int(np.ceil(strata['required']))
        
        #Sampling with replacement would instead be:
        #rows = np.random.randint(low=0, high=len(full_BTL_strata_pop), size=sample_size)
        
        #Sampling without replacement:
        rows = random.sample(range(len(full_BTL_strata_pop)), sample_size)

        #Then we will use those random positions to pull a sample:
        random_strata_sample = full_BTL_strata_pop.iloc[rows, :]
        
        #Don't forget to order the sample by primary key then rundate
        #(sort_values returns a new DataFrame, so the result must be reassigned)
        random_strata_sample = random_strata_sample.sort_values(by = [focus_variables[scenario_focus]['primary_key'], rundate_field])

        #We'll save all results to memory, in the original dictionary.
        strata['full_BTL_population'] = full_BTL_strata_pop
        strata['random_BTL_sample'] = random_strata_sample
        
    #Step 5: For each of the random samples, extract all transactions from the lookback period
    #(Regardless of whether the transaction was associated with this particular scenario)

    for strata in samples_master:
        print('Grabbing transactions for ' + strata['strata'] + '.')
        
        primary_key = focus_variables[scenario_focus]['primary_key']
        sample = strata['random_BTL_sample']
        
        #One condition per sampled alert: that entity's transactions in the lookback window ending on the rundate.
        #(Lookback period remains a dynamic parameter; an explicit date format avoids depending on NLS_DATE_FORMAT.)
        conditions = []
        for identifier, rundate in zip(sample[primary_key], sample[rundate_field]):
            window_start = (rundate - timedelta(days = lookback)).strftime('%d-%b-%Y')
            window_end = rundate.strftime('%d-%b-%Y')
            conditions.append(
                "({pk} = '{identifier}' AND TRXN_DATE BETWEEN TO_DATE('{start}', 'DD-MON-YYYY') "
                "AND TO_DATE('{end}', 'DD-MON-YYYY'))".format(pk = primary_key, identifier = identifier,
                                                             start = window_start, end = window_end)
            )
        
        #Join the conditions with OR (so there is no trailing 'OR' to trim) + order by date
        trxns_query = '''
        SELECT
            *
        FROM
            BUSINESS.EY_ELIG_TRXN_BTL T
        WHERE
            {conditions}
        ORDER BY 
            {primary_key},
            TRXN_DATE DESC
        '''.format(conditions = '\n            OR '.join(conditions), primary_key = primary_key)
        
        #Send the query
        trxns_results = SQL_Query(cur = cur, 
              query = trxns_query,
              limit = False,
              output = 'df'
             )

        #Add the results to the samples master
        strata['trxns_sample'] = trxns_results
        
    #Step 6: Extract Customer Data for all related customers
    print('Extracting customer data.')
    
    #Need to extract fields relevant to investigators, in user-friendly (non-Mantas) format
    #Query will differ depending on the focus of the scenario:
    
    if scenario_focus == 'HH':
        cust_query_stem = '''
        SELECT 
            DISTINCT --Given the join structure, duplicate records are possible
            HH.ACCT_GRP_ID,
            C.CUST_INTRL_ID,
            C.FULL_NM AS NAME,
            C.BIRTH_DT AS DOB,
            C.CUST_ADD_DT AS DATE_ADDED,
            C.FNCL_PRFL_LAST_UPDT_DT AS LAST_UPDATE,
            C.TAX_ID,
            C.ANNL_INCM_BASE_AM AS EST_ANNUAL_INCOME,
            C.NET_WRTH_BASE_AM AS EST_NET_WORTH,
            C.LQD_NET_WRTH_BASE_AM AS EST_LIQUID_NET_WORTH,
            CASE 
                WHEN C.MRTL_STAT_CD = 'U' THEN 'Unknown'
                WHEN C.MRTL_STAT_CD = 'M' THEN 'Married'
                WHEN C.MRTL_STAT_CD = 'D' THEN 'Divorced'
                WHEN C.MRTL_STAT_CD = 'W' THEN 'Widowed'
                WHEN C.MRTL_STAT_CD = 'S' THEN 'Single' --Code for Single assumed to be 'S'; confirm in the Mantas data dictionary
            END AS MARITAL_STATUS,
            C.MPLYR_NM AS EMPLOYER,
            C.OCPTN_NM AS OCCUPATION,
            CASE WHEN C.RES_CNTRY_CD IS NOT NULL THEN 'Y' ELSE 'N' END AS FOREIGN_RESIDENT_FLAG,
            C.WLTH_SRC_DSCR_TX AS WEALTH_SOURCE,
            C.CUST_CDT_RTNG AS CREDIT_RATING,
            C.CUST_CDT_RTNG_SRC AS CREDIT_RATING_SOURCE
        FROM
            BUSINESS.ACCT_GRP HH INNER JOIN
            BUSINESS.ACCT A ON HH.ACCT_GRP_ID = A.HH_ACCT_GRP_ID INNER JOIN
            BUSINESS.CUST C ON A.PRMRY_CUST_INTRL_ID = C.CUST_INTRL_ID
        WHERE
            HH.ACCT_GRP_ID IN 
        '''
        
        cust_query_end = '''
                ORDER BY
                HH.ACCT_GRP_ID
                '''
        
    elif scenario_focus == 'ACCT':
        cust_query_stem = '''
        SELECT 
            A.ACCT_INTRL_ID,
            C.CUST_INTRL_ID,
            C.FULL_NM AS NAME,
            C.BIRTH_DT AS DOB,
            C.CUST_ADD_DT AS DATE_ADDED,
            C.FNCL_PRFL_LAST_UPDT_DT AS LAST_UPDATE,
            C.TAX_ID,
            C.ANNL_INCM_BASE_AM AS EST_ANNUAL_INCOME,
            C.NET_WRTH_BASE_AM AS EST_NET_WORTH,
            C.LQD_NET_WRTH_BASE_AM AS EST_LIQUID_NET_WORTH,
            CASE 
                WHEN C.MRTL_STAT_CD = 'U' THEN 'Unknown'
                WHEN C.MRTL_STAT_CD = 'M' THEN 'Married'
                WHEN C.MRTL_STAT_CD = 'D' THEN 'Divorced'
                WHEN C.MRTL_STAT_CD = 'W' THEN 'Widowed'
                WHEN C.MRTL_STAT_CD = 'S' THEN 'Single' --Code for Single assumed to be 'S'; confirm in the Mantas data dictionary
            END AS MARITAL_STATUS,
            C.MPLYR_NM AS EMPLOYER,
            C.OCPTN_NM AS OCCUPATION,
            CASE WHEN C.RES_CNTRY_CD IS NOT NULL THEN 'Y' ELSE 'N' END AS FOREIGN_RESIDENT_FLAG,
            C.WLTH_SRC_DSCR_TX AS WEALTH_SOURCE,
            C.CUST_CDT_RTNG AS CREDIT_RATING,
            C.CUST_CDT_RTNG_SRC AS CREDIT_RATING_SOURCE
        FROM
            BUSINESS.ACCT A INNER JOIN
            BUSINESS.CUST C ON A.PRMRY_CUST_INTRL_ID = C.CUST_INTRL_ID
        WHERE
            A.ACCT_INTRL_ID IN
        '''
        
        cust_query_end = '''
                ORDER BY
                A.ACCT_INTRL_ID
                '''
    
    for strata in samples_master:
        #Turn the sampled primary keys into a quoted, comma-separated list for the IN clause.
        #The same key can be sampled on several rundates, so duplicates are dropped first.
        sampled_keys = strata['random_BTL_sample'][focus_variables[scenario_focus]['primary_key']].unique()
        primary_key_filter_clause = '(' + ', '.join("'" + str(key) + "'" for key in sampled_keys) + ')'
        
        #Combine the three components of the query
        full_customer_query = cust_query_stem + primary_key_filter_clause + cust_query_end
                                                       
        #Send query to database
        customer_data = SQL_Query(cur = cur, 
                                  query = full_customer_query,
                                  limit = False,
                                  output = 'df'
                                 )
        
        strata['customer_data'] = customer_data
        
                                                       
    #Step 7: Export the data to CSV files (which can be opened in Excel).
    print('Sending to CSV.')
    
    for strata in samples_master:
        #Add another column to the main page explaining which strata the samples belong to
        strata['random_BTL_sample']['strata'] = strata['strata']

        #To CSV:
        strata['random_BTL_sample'].to_csv(scenario_name + '_' + strata['strata'] + '_BTL_Samples.csv')
        strata['trxns_sample'].to_csv(scenario_name + '_' + strata['strata'] + '_BTL_Samples_Transactions.csv')
        strata['customer_data'].to_csv(scenario_name + '_' + strata['strata'] + '_Customer_Data.csv')
        
    end_time = time.gmtime()
    print("This took " + str(time.mktime(end_time) - time.mktime(start_time)) + " seconds.")
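
#Sketch (hypothetical helpers, not part of the original workflow): reproducible
#sampling without replacement, as in Step 4. Assumes each stratum's population is a
#list of (primary key, rundate) tuples rather than a DataFrame. A dedicated
#random.Random(seed) makes a draw repeatable, which helps if a sample must be rebuilt
#for review; the helper names and the seed parameter are illustrative assumptions.
import random


def draw_sample_positions(population_size, required, seed=None):
    """Return `required` distinct row positions in [0, population_size), sorted."""
    if required > population_size:
        raise ValueError('Required sample size exceeds the stratum population.')
    rng = random.Random(seed)
    #Sampling from a range picks each position at most once (no replacement)
    return sorted(rng.sample(range(population_size), required))


def draw_sample(rows, required, seed=None):
    """Pick `required` rows without replacement, ordered by primary key then rundate."""
    positions = draw_sample_positions(len(rows), required, seed)
    #Tuples sort element by element: primary key first, then rundate
    return sorted(rows[i] for i in positions)

#Usage: draw_sample([('A2', 3), ('A1', 5), ('A1', 1), ('A3', 2)], 2, seed=42)
#returns two distinct tuples in sorted order, and the same pair each time it is called with seed=42.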
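
#Sketch (hypothetical helper, not part of the original workflow): the Step 5
#lookback-window filter as a standalone function. Assumes an Oracle back end, hence
#TO_DATE with an explicit 'DD-MON-RR' mask matching the '%d-%b-%y' strings. Joining
#the conditions with OR avoids trimming a trailing 'OR', and doubling single quotes
#keeps an identifier containing ' from breaking the SQL literal.
from datetime import timedelta


def lookback_where_clause(primary_key, samples, lookback_days):
    """Return one OR-ed condition per (identifier, rundate) pair in `samples`."""
    if not samples:
        raise ValueError('At least one (identifier, rundate) pair is required.')
    conditions = []
    for identifier, rundate in samples:
        #The window runs from lookback_days before the rundate up to the rundate
        window_start = rundate - timedelta(days=lookback_days)
        conditions.append(
            "({pk} = '{ident}' AND TRXN_DATE BETWEEN "
            "TO_DATE('{start}', 'DD-MON-RR') AND TO_DATE('{end}', 'DD-MON-RR'))".format(
                pk=primary_key,
                ident=str(identifier).replace("'", "''"),
                start=window_start.strftime('%d-%b-%y'),
                end=rundate.strftime('%d-%b-%y'),
            )
        )
    return '\nOR '.join(conditions)

#Usage: 'SELECT * FROM BUSINESS.EY_ELIG_TRXN_BTL T WHERE ' + lookback_where_clause(key, pairs, lookback)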
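
#Sketch (hypothetical helper, not part of the original workflow): an IN-list builder
#for the Step 6 customer query. Assuming an Oracle back end, an IN list may hold at
#most 1,000 literals (ORA-01795), so a large stratum sample would make that query
#fail. This de-duplicates the keys, escapes single quotes, and ORs together IN lists
#of at most max_items keys each. To use it, the Step 6 query stems would end at
#WHERE rather than at '... IN'.
def in_list_filter(column, keys, max_items=1000):
    """Return '(column IN (...) OR column IN (...))' covering every distinct key."""
    #dict.fromkeys drops duplicates while keeping first-seen order
    unique_keys = list(dict.fromkeys(str(k) for k in keys))
    if not unique_keys:
        raise ValueError('No keys supplied for the IN filter.')
    quoted = ["'" + k.replace("'", "''") + "'" for k in unique_keys]
    chunks = [quoted[i:i + max_items] for i in range(0, len(quoted), max_items)]
    return '(' + ' OR '.join(
        '{} IN ({})'.format(column, ', '.join(chunk)) for chunk in chunks
    ) + ')'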