**DATABASE CONFIGURATION**

In [288]:
from sqlalchemy.engine import URL
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import json
import sys
import random

Load Configuration File

In [289]:
# Parse the JSON settings file; later cells read `config` for the database and RL options.
with open('config.json', 'r') as config_file:
    config = json.loads(config_file.read())

# Echo the parsed settings so the notebook records what it was run with.
print(config)

{'database': {'name': 'BikeStores', 'server': '.\\SQLEXPRESS', 'driver': 'SQL Server', 'sample_query': 'SELECT * FROM production.brands'}, 'rl': {'state_info': {'mark_error': True, 'query_result_length': 10000, 'padding_char': '\x00'}, 'contexts': [{'query': 'SELECT * FROM production.products WHERE brand_id=8 [INPUT]', 'table_filter': ['production.products'], 'column_filter': ['product_name'], 'goal': 'Trek 820 - 2016'}]}}


Set Up Connection to Microsoft SQL Database

In [290]:
# Pull the connection settings out of the 'database' section of the config.
db = config['database']
name, server, driver = db['name'], db['server'], db['driver']

# Build the ODBC connection string (trusted connection, no explicit credentials)
# and pass it to SQLAlchemy through the mssql+pyodbc dialect.
conn_string = f'DRIVER={driver};SERVER={server};DATABASE={name};Trusted_Connection=yes'
conn_url = URL.create('mssql+pyodbc', query={'odbc_connect': conn_string})
engine = create_engine(conn_url)

# Optional smoke test: if the config supplies a sample query, run it and print the rows.
if 'sample_query' in db:
    sample_query = db['sample_query']
    df = pd.read_sql(sample_query, engine)
    print(df)

   brand_id    brand_name
0         1       Electra
1         2          Haro
2         3        Heller
3         4   Pure Cycles
4         5       Ritchey
5         6       Strider
6         7  Sun Bicycles
7         8         Surly
8         9          Trek


**REINFORCEMENT LEARNING**

Define State Configuration

In [291]:
# State-encoding options come from the 'state_info' part of the RL config.
state_info_config = config['rl']['state_info']
query_result_length = state_info_config['query_result_length']
padding_char = state_info_config['padding_char']
mark_error = state_info_config['mark_error']

# Single-character flags that label a query result as an error / not an error.
no_error_char = 'n'
error_char = 'y'

# The network input holds the fixed-length query result, plus one extra slot for the
# error flag when mark_error is enabled.
# TODO: Ensure mark_error information has a higher weighting in the neural net.
nn_input_size = query_result_length + 1 if mark_error else query_result_length

Define RL Contexts and Incrementation

In [292]:
# Start on the first configured RL context. An empty 'contexts' list is not supported:
# the lookup below raises an IndexError in that case.
contexts = config['rl']['contexts']
context_index = 0
context = contexts[context_index]

# Advances to the next RL context, if there is one.
def incr_context():
    """Move the module-level `context` / `context_index` to the next entry of `contexts`.

    Returns:
        True when a following context existed and is now the active one;
        False when the active context was already the last one (nothing changes).
    """
    global context, context_index

    next_index = context_index + 1
    if next_index >= len(contexts):
        return False

    context_index = next_index
    context = contexts[context_index]
    return True

Create SQL Injection Attack Functionality

In [293]:
# TODO: Remove non-MSSQL payloads.
# One payload per line. The context manager closes the file handle as soon as the
# contents have been read, instead of leaving it open for the life of the kernel.
# NOTE: splitting on '\n' means a trailing newline in the file yields an empty final
# payload, which remains a selectable action.
with open('sqli_payloads.txt', 'r') as payload_file:
    payloads = payload_file.read().split('\n')

# Performs an SQL injection attack based on an index from the list of payloads.
def inject_payload(payload_index):
    """Run one injection attempt and translate its outcome into a Q-learning transition.

    The payload at `payload_index` replaces the '[INPUT]' placeholder in the active
    context's query, the query is executed against `engine`, and the outcome (CSV of
    the result set, or the error text) is normalised to a fixed-length string that
    identifies the resulting state.

    Args:
        payload_index: Index into the module-level `payloads` list.

    Returns:
        A tuple `(state_index, reward, episode_over, state_string)`:
        - state_index: position of the state in `states_info` (appended if new).
        - reward: 100 if the goal value was found in a filtered column, otherwise -1.
        - episode_over: True when the goal value was found.
        - state_string: the trimmed/padded result, with the error flag appended when
          `mark_error` is set.

    Side effects:
        When a previously unseen state is produced, it is appended to `states_info`
        and the global `qtable` is rebound to a copy with one extra all-zero row.
    """
    global qtable

    # Finds [INPUT] within the context query configuration and replaces it with the payload.
    query = context['query'].replace('[INPUT]', payloads[payload_index])

    reward = -1
    episode_over = False

    try:
        # Runs SQL injection query.
        df = pd.read_sql(query, engine)
        res = df.to_csv()

        # Check episode termination condition, and if true, apply appropriate reward.
        # TODO: Ensure tables are filtered as the same column name could exist in another table.
        for column in context['column_filter']:
            if column in df and context['goal'] in df[column].values:
                reward = 100
                episode_over = True

        has_error = no_error_char
    except Exception as exc:
        # A failed query is a legitimate observation: record its message as the state.
        # Only Exception is caught so that KeyboardInterrupt / SystemExit still stop a
        # long training run instead of being recorded as an SQL error.
        res = str(exc)
        has_error = error_char

    # Trim resulting string or pad it so that the length is equal to query_result_length.
    if len(res) > query_result_length:
        res = res[:query_result_length]
    else:
        res = res.ljust(query_result_length, padding_char)

    # Add error information if this is set.
    if mark_error:
        res = res + has_error

    # If the state already exists, return its index (single pass over the known states).
    for index, known_state in enumerate(states_info):
        if known_state == res:
            return index, reward, episode_over, res

    # If the state does not exist, create a new state. The states info list as well as
    # the Q-Table must be appended with this new state.
    # After creation of the new state, return it and its index.
    states_info.append(res)
    qtable = np.concatenate((qtable, np.zeros((1, len(payloads)))), axis=0)
    return len(states_info) - 1, reward, episode_over, res

Q-Learning Parameters

In [294]:
# Q-Learning adapted from: https://deeplizard.com/learn/video/ZaILVnqZFCg

# --- Training schedule ---
total_episodes = 100    # number of episodes to train for
max_steps = 99          # cap on the actions taken within one episode

# --- Update rule ---
learning_rate = 0.8     # step size applied to each Q-value update
gamma = 0.95            # discount applied to future reward

# --- Epsilon-greedy exploration ---
max_epsilon = 1.0       # exploration probability at the start
min_epsilon = 0.01      # floor for the exploration probability
decay_rate = 0.005      # exponential decay rate of the exploration probability
epsilon = max_epsilon   # current exploration rate; begins at its maximum

Q-Learning Algorithm

In [295]:
# State 0 is the blank starting state: padding only, flagged as "no error" when
# error marking is enabled.
states_info = [padding_char * query_result_length]
if mark_error:
    states_info[0] = states_info[0] + no_error_char

# Initialise Q-Table with a single row for the starting state; inject_payload adds a
# row each time it encounters a new state.
qtable = np.zeros((1, len(payloads)))

rewards = []
terminating_actions = set()

# Restart exploration at its maximum. epsilon is decayed in place below, so without
# this reset, re-running the cell would train a fresh Q-table with the (nearly greedy)
# epsilon left over from the previous run.
epsilon = max_epsilon

for episode in range(total_episodes):
    # Reset the environment: every episode begins in the blank starting state.
    state = 0
    total_rewards = 0

    for step in range(max_steps):
        exp_exp_tradeoff = random.uniform(0, 1)

        # If this number is greater than epsilon --> exploitation (taking the biggest Q value for this state).
        if exp_exp_tradeoff > epsilon:
            action = np.argmax(qtable[state, :])

        # Else doing a random choice --> exploration
        else:
            action = random.randint(0, len(payloads) - 1)

        # Take the action (a) and observe the outcome state(s') and reward (r).
        # NOTE: inject_payload may rebind the global qtable (adding a row), so qtable
        # must only be read after this call.
        new_state, reward, done, info = inject_payload(action)

        # Update Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)].
        # qtable[new_state,:] : all the actions we can take from new state.
        qtable[state, action] = qtable[state, action] + learning_rate * (reward + gamma * np.max(qtable[new_state, :]) - qtable[state, action])

        total_rewards += reward

        state = new_state

        # Finish episode upon terminal state reached, remembering which payload did it.
        if done:
            terminating_actions.add(action)
            break

    # Reduce epsilon (because we need less and less exploration)
    epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*episode)
    rewards.append(total_rewards)

print('Learned Q-Table:\n', qtable)

Learned Q-Table:
 [[ 0.    0.   -0.96 ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 ...
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 [80.    0.    0.   ...  0.    0.    0.  ]]


In [296]:
# Summarise training: mean reward per episode, the number of Q-table rows (one per
# state), and the payloads that ended an episode by reaching the goal.
average_reward = sum(rewards) / total_episodes
print('Score over time:', str(average_reward))
print('Number of states:', len(qtable))
terminating_payloads = [payloads[action_index] for action_index in terminating_actions]
print('Terminating payloads:\n', terminating_payloads)

Score over time: 88.34
Number of states: 438
Terminating payloads:
 ['OR 1=1', 'OR 1=1-- ']
