**DATABASE CONFIGURATION**

In [65]:
from sqlalchemy.engine import URL
from sqlalchemy import create_engine
import pandas as pd
import json
import sys

Load Configuration File

In [66]:
# Read the notebook's settings (the 'database' and 'rl' sections used by the
# cells below) from config.json in the current working directory.
# JSON text is UTF-8 by specification, so decode explicitly instead of relying
# on the platform's locale encoding (e.g. cp1252 on Windows).
with open('config.json', 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)
print(config)

{'database': {'name': 'BikeStores', 'server': '.\\SQLEXPRESS', 'driver': 'SQL Server', 'sample_query': 'SELECT * FROM production.brands'}, 'rl': {'state_info': {'mark_error': True, 'query_result_length': 10000, 'padding_char': '\x00'}, 'contexts': [{'query': 'SELECT * FROM production.products WHERE brand_id=8 [INPUT]', 'table_filter': ['production.products'], 'column_filter': ['product_name'], 'goal': 'Trek 820 - 2016'}]}}


Set Up Connection to Microsoft SQL Server Database

In [67]:
# Pull the connection parameters out of the 'database' section of the config.
db = config['database']
name, server, driver = db['name'], db['server'], db['driver']

# Connect to SQL database using the above parameters: build a raw ODBC
# connection string (Trusted_Connection=yes, i.e. integrated Windows auth)
# and hand it to SQLAlchemy's mssql+pyodbc dialect via the 'odbc_connect' key.
conn_string = f'DRIVER={driver};SERVER={server};DATABASE={name};Trusted_Connection=yes'
conn_url = URL.create('mssql+pyodbc', query={'odbc_connect': conn_string})
engine = create_engine(conn_url)

# Optional smoke test: when the config supplies 'sample_query', run it through
# the new engine and print the resulting DataFrame.
if 'sample_query' in db:
    sample_query = db['sample_query']
    df = pd.read_sql(sample_query, engine)
    print(df)

   brand_id    brand_name
0         1       Electra
1         2          Haro
2         3        Heller
3         4   Pure Cycles
4         5       Ritchey
5         6       Strider
6         7  Sun Bicycles
7         8         Surly
8         9          Trek


**REINFORCEMENT LEARNING**

In [68]:
# Settings that shape the observation the RL agent receives for each query.
state_info = config['rl']['state_info']
mark_error = state_info['mark_error']
query_result_length = state_info['query_result_length']
padding_char = state_info['padding_char']

# Neural-net input width: query_result_length slots (presumably one per
# character of the trimmed/padded query result), plus one extra slot for the
# error flag when mark_error is enabled.
# TODO: Ensure mark_error information has a higher weighting in the neural net.
nn_input_size = query_result_length + (1 if mark_error else 0)

Define RL Contexts and a Helper to Advance Between Them

In [69]:
# Load the list of RL contexts and start at the first one.
# An empty list raises IndexError on contexts[0], and a missing 'contexts' key
# raises KeyError, so at least one context must be configured.
contexts = config['rl']['contexts']
context, context_index = contexts[0], 0

# Advance to the next RL context (module-level state).
def incr_context():
    """Move to the next RL context in ``contexts``, if there is one.

    Increments the module-level ``context_index`` and rebinds ``context`` to
    ``contexts[context_index]``.

    Returns:
        True if a new context was assigned; False if the current context is
        already the last one, in which case nothing is changed.
    """
    global context, context_index
    next_index = context_index + 1
    if next_index >= len(contexts):
        return False
    context_index = next_index
    context = contexts[next_index]
    return True

In [84]:
# Load the SQL-injection payload list, one payload per line. inject_payload()
# indexes into this list, so the line order in the file defines the indices.
# The file is read inside a `with` block so the handle is closed promptly.
# NOTE(review): split('\n') keeps a trailing '' when the file ends with a
# newline, so an empty payload becomes the last entry — confirm that's intended.
# TODO: Remove non-MSSQL payloads.
with open('sqli_payloads.txt', 'r') as payload_file:
    payloads = payload_file.read().split('\n')

def inject_payload(payload_index):
    """Run the current context's query with one payload substituted in.

    The ``[INPUT]`` placeholder in ``context['query']`` is replaced verbatim by
    ``payloads[payload_index]`` and the resulting SQL is executed against
    ``engine``. The substitution is intentionally NOT parameterized: the
    payloads are SQL-injection strings and injecting them is the point.

    Args:
        payload_index: Index into the module-level ``payloads`` list.

    Returns:
        A string of exactly ``query_result_length`` characters: the query
        result rendered with ``DataFrame.to_csv()``, or the exception message
        if the query failed, trimmed or right-padded with ``padding_char``.
        When ``mark_error`` is true, a ``(result, has_error)`` tuple is
        returned instead.

    Raises:
        IndexError: if ``payload_index`` is out of range for ``payloads``.
    """
    payload = payloads[payload_index]
    query = context['query'].replace('[INPUT]', payload)

    try:
        res = pd.read_sql(query, engine).to_csv()
        has_error = False
    except Exception as exc:
        # A failing query (e.g. a malformed payload) is an expected outcome and
        # its message becomes the observation. Catch Exception rather than
        # using a bare except so KeyboardInterrupt / SystemExit still stop the
        # notebook instead of being recorded as a query error.
        res = str(exc)
        has_error = True

    # Trim, then pad, so the result is always exactly query_result_length long.
    res = res[:query_result_length].ljust(query_result_length, padding_char)

    if mark_error:
        return res, has_error
    return res

10000 False
