# Parse the sql to encode and decode only some parts of it

In [1]:
import sqlparse

# Define the SQL statement
sql_statement = "SELECT name, age, email FROM users"

# Parse the SQL statement using sqlparse
parsed_statement = sqlparse.parse(sql_statement)[0]

print(parsed_statement)

# Extract the field names from the SELECT statement
# fields = [str(token) for token in parsed_statement.tokens if token.ttype is sqlparse.tokens.Name]

# # Print the field names
# print(fields)


SELECT name, age, email FROM users


In [2]:
import glob
dir(sqlparse)

['__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 'cli',
 'engine',
 'exceptions',
 'filters',
 'format',
 'formatter',
 'keywords',
 'lexer',
 'parse',
 'parsestream',
 'split',
 'sql',
 'tokens',
 'utils']

In [None]:
raw = 'select * from foo; select * from bar;'
statements = sqlparse.split(raw)
statements


['select * from foo;', 'select * from bar;']

In [5]:
# Format the first statement and print it out:
first = statements[0]
print(sqlparse.format(first, reindent=True, keyword_case='upper'))


SELECT *
FROM foo;


In [6]:
identify = sqlparse.get_identifiers('select order.orderid, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc')[0]
print(identify.tokens)

AttributeError: module 'sqlparse' has no attribute 'get_identifiers'

In [4]:
parsed = sqlparse.parse('select order.orderid, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc')[0]
parsed.tokens

[<DML 'select' at 0x118EAAE60>,
 <Whitespace ' ' at 0x11990D900>,
 <IdentifierList 'order....' at 0x119909F50>,
 <Whitespace ' ' at 0x11990F880>,
 <Keyword 'from' at 0x11990F8E0>,
 <Whitespace ' ' at 0x11990F940>,
 <Identifier 'orders...' at 0x119909CB0>,
 <Whitespace ' ' at 0x11990FAC0>,
 <Keyword 'join' at 0x11990FB20>,
 <Whitespace ' ' at 0x11990FB80>,
 <Identifier 'revenu...' at 0x119909D90>,
 <Whitespace ' ' at 0x11990FD00>,
 <Keyword 'on' at 0x11990FD60>,
 <Whitespace ' ' at 0x11990FDC0>,
 <Comparison 'r.orde...' at 0x119909EE0>,
 <Whitespace ' ' at 0x119934220>,
 <Keyword 'order ...' at 0x119934280>,
 <Whitespace ' ' at 0x1199342E0>,
 <Identifier 'order....' at 0x119909E70>]

In [23]:
# Parsing a SQL statement:
parsed = sqlparse.parse('select order.orderid, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc')[0]
parsed_tokens = parsed.tokens
for i, token in enumerate(parsed_tokens):
    print(f"Token {i}: {token}")

Token 0: select
Token 1:  
Token 2: order.orderid, order.order_name, revenue.order_id, r.product_name, revenue.cost
Token 3:  
Token 4: from
Token 5:  
Token 6: orders o
Token 7:  
Token 8: join
Token 9:  
Token 10: revenue r
Token 11:  
Token 12: on
Token 13:  
Token 14: r.order_id  = r.order_id
Token 15:  
Token 16: order by
Token 17:  
Token 18: order.orderid desc


In [24]:
parsed_tokens = parsed.tokens
token_counts = {}

for i, token in enumerate(parsed_tokens):
    tok_type = type(token).__name__
    if tok_type in token_counts:
        token_counts[tok_type] += 1
    else:
        token_counts[tok_type] = 1

for tok_type, count in token_counts.items():
    print(f"{tok_type}: {count}")

Token: 14
IdentifierList: 1
Identifier: 3
Comparison: 1


In [17]:
#rewrite above function but return a list of fields per type in a data table
def get_fields(parsed):
    parsed_tokens = parsed.tokens
    token_counts = {}
    for i, token in enumerate(parsed_tokens):
        tok_type = type(token).__name__
        if tok_type in token_counts:
            token_counts[tok_type].append(token)
        else:
            token_counts[tok_type] = [token]
    #return a dictionary of types and their corresponding tokens in a data table
    return token_counts


In [19]:
def get_fields(parsed):
    parsed_tokens = parsed.tokens
    token_counts = {}
    for i, token in enumerate(parsed_tokens):
        tok_type = type(token).__name__
        if tok_type == 'Whitespace' or tok_type == 'Keyword':
            continue
        if tok_type in token_counts:
            token_counts[tok_type].append(str(token))
        else:
            token_counts[tok_type] = [str(token)]
    #return a list of fields per type in a data table
    return token_counts


In [61]:
from sqlparse import parse

# Example SQL query
#query = "select order.order_id, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc"
query = "SELECT first_name, last_name FROM employees WHERE salary > 50000"

# Parse the query into tokens
parsed = parse(query)[0]

# Get the fields in the parsed query
fields = get_fields(parsed)

# Print out the types and tokens in the data table
for tok_type, tokens in fields.items():
    print(f"{tok_type}: {tokens}")


Token: ['SELECT', ' ', ' ', 'FROM', ' ', ' ']
IdentifierList: ['first_name, last_name']
Identifier: ['employees']
Where: ['WHERE salary > 50000']


In [None]:
import sqlparse

def get_identifiers(parsed):
    parsed_tokens = parsed.tokens
    identifiers = []
    for token in parsed_tokens:
        tok_type = type(token).__name__
        if tok_type == 'Whitespace':
            continue
        elif tok_type == 'Keyword':
            continue
        elif tok_type == 'Identifier':
            identifier = str(token)
            if '.' in identifier:
                identifiers.extend(identifier.split('.'))
            else:
                identifiers.append(identifier)
        elif tok_type == 'Name':
            identifiers.append(str(token))
    return identifiers

#sql = "SELECT first_name, last_name FROM employees WHERE department='Sales'"
sql = "select order.order_id, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc"

parsed = sqlparse.parse(sql)[0]
identifiers = get_identifiers(parsed)
print(identifiers)



['orders o', 'revenue r', 'order', 'orderid desc']


In [41]:
def get_identifiers(parsed):
    parsed_tokens = parsed.tokens
    identifier_list = []
    for token in parsed_tokens:
        tok_type = type(token).__name__
        if tok_type == 'IdentifierList':
            for identifier in token.get_identifiers():
                if not isinstance(identifier, sqlparse.sql.Identifier):
                    continue
                identifier_name = identifier.get_name()
                if '.' in identifier_name:
                    identifier_name = identifier_name.split('.')[1]
                identifier_list.append(identifier_name)
        elif tok_type == 'Identifier':
            identifier_name = token.get_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_list.append(identifier_name)
    return identifier_list


In [None]:
import sqlparse

def get_identifiers(parsed):
    parsed_tokens = parsed.tokens
    identifiers = []
    for token in parsed_tokens:
        tok_type = type(token).__name__
        if tok_type == 'Whitespace':
            continue
        elif tok_type == 'Keyword':
            continue
        elif tok_type == 'Identifier':
            identifier = str(token)
            if '.' in identifier:
                identifiers.extend(identifier.split('.'))
            else:
                identifiers.append(identifier)
        elif tok_type == 'Name':
            identifiers.append(str(token))
    return identifiers

#sql = "SELECT first_name, last_name FROM employees WHERE department='Sales'"
sql = "select order.order_id, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc"

parsed = sqlparse.parse(sql)[0]
identifiers = get_identifiers(parsed)
print(identifiers)



['orders o', 'revenue r', 'order', 'orderid desc']


In [None]:
import sqlparse

def get_identifiers(parsed):
    parsed_tokens = parsed.tokens
    identifiers = []
    for token in parsed_tokens:
        tok_type = type(token).__name__
        if tok_type == 'Whitespace':
            continue
        elif tok_type == 'Keyword':
            continue
        elif tok_type == 'Identifier':
            identifier = str(token)
            if '.' in identifier:
                identifiers.extend(identifier.split('.'))
            else:
                identifiers.append(identifier)
        elif tok_type == 'Name':
            identifiers.append(str(token))
    return identifiers

#sql = "SELECT first_name, last_name FROM employees WHERE department='Sales'"
sql = "select order.order_id, order.order_name, revenue.order_id, r.product_name, revenue.cost from orders o join revenue r on r.order_id  = r.order_id order by order.orderid desc"

parsed = sqlparse.parse(sql)[0]
identifiers = get_identifiers(parsed)
print(identifiers)



['orders o', 'revenue r', 'order', 'orderid desc']


In [42]:
import sqlparse

sql = "SELECT column1, column2 FROM table1 WHERE column3 = 'value'"
parsed = sqlparse.parse(sql)[0]

identifiers = get_identifiers(parsed)
print(identifiers)


['column1', 'column2', 'table1']


In [50]:
def get_identifier_list(parsed):
    parsed_tokens = parsed.tokens
    identifier_list = []
    for token in parsed_tokens:
        tok_type = type(token).__name__
        if tok_type == 'IdentifierList':
            for identifier in token.get_identifiers():
                if not isinstance(identifier, sqlparse.sql.Identifier):
                    continue
                identifier_name = identifier.get_name()
                if '.' in identifier_name:
                    identifier_name = identifier_name.split('.')[1]
                identifier_list.append(identifier_name)
    return identifier_list


In [52]:
import sqlparse

sql = "SELECT column1, column2 FROM table1 WHERE column3 = 'value'"
parsed = sqlparse.parse(sql)[0]

identifier_list = get_identifier_list(parsed)
print(identifier_list)


['column1', 'column2']


In [72]:
def get_tok_types(parsed):
    parsed_tokens = parsed.tokens
    tok_types = []
    for token in parsed_tokens:
        tok_type = type(token).__name__
        if tok_type in ['Identifier', 'Comparison', 'IdentifierList', 'Where']:
            tok_types.append(tok_type)
    return tok_types


In [73]:
import sqlparse

query = "SELECT first_name, last_name FROM employees WHERE salary > 50000"
parsed_query = sqlparse.parse(query)[0]
tok_types = get_tok_types(parsed_query)
print(tok_types)


['IdentifierList', 'Identifier', 'Where']


In [226]:
import sqlparse

query = "Select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id"
parsed_query = sqlparse.parse(query)[0]
parsed_query.tokens
#print(parsed_query.tokens)
# tok_types = get_identifiers(query)
# print(tok_types)

[<Identifier 'elect ...' at 0x119F04740>,
 <Whitespace ' ' at 0x11AB80280>,
 <Integer '10' at 0x11AB801C0>,
 <Whitespace ' ' at 0x11AB825C0>,
 <IdentifierList 'getdat...' at 0x11AADB4C0>,
 <Whitespace ' ' at 0x11AB81900>,
 <Keyword 'from' at 0x11AB827A0>,
 <Whitespace ' ' at 0x11AB818A0>,
 <Identifier 'orders...' at 0x119F04200>,
 <Whitespace ' ' at 0x11AAB41C0>,
 <Keyword 'join' at 0x11AAB4040>,
 <Whitespace ' ' at 0x11AAB46A0>,
 <Identifier 'revenu...' at 0x119F044A0>,
 <Whitespace ' ' at 0x11AAB4DC0>,
 <Keyword 'on' at 0x11AAB4D60>,
 <Whitespace ' ' at 0x11AAB50C0>,
 <Comparison 'o.orde...' at 0x11AADB680>]

In [235]:
import sqlparse

def extract_columns_and_tables(sql):
    # Parse the SQL statement into tokens
    parsed = sqlparse.parse(sql)[0]

    # Initialize empty lists to store column and table names
    results = []
    tables = set()

    # Loop through all tokens in the parsed SQL statement
    for token in parsed.flatten():
        # If the token is a function, skip it
        if isinstance(token, sqlparse.sql.Function):
            continue
        # If the token is a column or table name, add it to the results list
        elif isinstance(token, sqlparse.sql.Identifier):
            # Check if the previous token is a period, indicating a table alias
            if token.parent and token.parent.value == '.':
                table_name = token.parent.get_previous_sibling().get_real_name()
                column_name = token.get_real_name()
                results.append(table_name + '.' + column_name)
                tables.add(table_name)
            else:
                results.append(token.get_real_name())
        # If the token is a join keyword, get the table name from the next token
        elif token.value.lower() == 'join':
            tables.add(token.next_token.next_token.get_real_name())

    # Return the list of column and table names
    return list(results) + list(tables)


query = "Select top 10 getdate(), o.order_number, client_name, sum(r.net_revenue) from orders o join revenue r on o.order_id = r.order_id"
print(extract_columns_and_tables(query))




AttributeError: 'Token' object has no attribute 'next_token'

In [236]:
import sqlparse

def extract_columns_and_tables(sql):
    # Parse the SQL statement into tokens
    parsed = sqlparse.parse(sql)[0]

    # Initialize empty lists to store column and table names
    results = []
    tables = set()

    # Loop through all tokens in the parsed SQL statement
    for token in parsed.flatten():
        # If the token is a function, skip it
        if isinstance(token, sqlparse.sql.Function):
            continue
        # If the token is a column or table name, add it to the results list
        elif isinstance(token, sqlparse.sql.Identifier):
            # Check if the previous token is a period, indicating a table alias
            if token.parent and token.parent.value == '.':
                table_name = token.parent.get_previous_sibling().get_real_name()
                column_name = token.get_real_name()
                results.append(table_name + '.' + column_name)
                tables.add(table_name)
            else:
                results.append(token.get_real_name())
        # If the token is a join keyword, get the table name from the next token
        elif token.value.lower() == 'join':
            tables.add(token.next_token.next_token.get_real_name())

    # Return the list of column and table names
    return list(results) + list(tables)


query = "Select top 10 getdate(), o.order_number, client_name, sum(r.net_revenue) from orders o join revenue r on o.order_id = r.order_id"
print(extract_columns_and_tables(query))




AttributeError: 'Token' object has no attribute 'next_token'

In [None]:
lst = ['first_name', 'employees', 'last_name', 'salary', 'department']

# Mask string literals in the list
masked_lst = mask_string_literals(lst)
print(masked_lst)  # Output: ["'first_name'", 'employees', "'last_name'", "'salary'", "'department'"]

# Unmask string literals in the list
unmasked_lst = unmask_string_literals(masked_lst)
print(unmasked_lst)  # Output: ['first_name', 'employees', 'last_name', 'salary', 'department']

# # Apply mask to SQL code
# sql = "SELECT first_name FROM employees WHERE department = 'Sales'"
# masked_sql = apply_mask_to_sql(sql)
# print(masked_sql)  # Output: "SELECT first_name FROM employees WHERE department = ''Sales''"


['first_name', 'employees', 'last_name', 'salary', 'department']
['first_name', 'employees', 'last_name', 'salary', 'department']


In [None]:
lst = ['first_name', 'employees', 'last_name', 'salary', 'department']

# Mask string literals in the list
masked_lst = mask_string_literals(lst)
print(masked_lst)  # Output: ["'first_name'", 'employees', "'last_name'", "'salary'", "'department'"]

# Unmask string literals in the list
unmasked_lst = unmask_string_literals(masked_lst)
print(unmasked_lst)  # Output: ['first_name', 'employees', 'last_name', 'salary', 'department']

# # Apply mask to SQL code
# sql = "SELECT first_name FROM employees WHERE department = 'Sales'"
# masked_sql = apply_mask_to_sql(sql)
# print(masked_sql)  # Output: "SELECT first_name FROM employees WHERE department = ''Sales''"


['first_name', 'employees', 'last_name', 'salary', 'department']
['first_name', 'employees', 'last_name', 'salary', 'department']


In [None]:
lst = ['first_name', 'employees', 'last_name', 'salary', 'department']

# Mask string literals in the list
masked_lst = mask_string_literals(lst)
print(masked_lst)  # Output: ["'first_name'", 'employees', "'last_name'", "'salary'", "'department'"]

# Unmask string literals in the list
unmasked_lst = unmask_string_literals(masked_lst)
print(unmasked_lst)  # Output: ['first_name', 'employees', 'last_name', 'salary', 'department']

# # Apply mask to SQL code
# sql = "SELECT first_name FROM employees WHERE department = 'Sales'"
# masked_sql = apply_mask_to_sql(sql)
# print(masked_sql)  # Output: "SELECT first_name FROM employees WHERE department = ''Sales''"


['first_name', 'employees', 'last_name', 'salary', 'department']
['first_name', 'employees', 'last_name', 'salary', 'department']


In [244]:
#Alles zusammen - Working for parsing fields from a sql statement. Decide to do it modularly, not sure why but this works. That is why
#works for where!!!
import sqlparse

def get_where_fields(query):
    parsed_query = sqlparse.parse(query)[0]
    where_clause = None
    for token in parsed_query.tokens:
        if isinstance(token, sqlparse.sql.Where):
            where_clause = token
            break
    if not where_clause:
        return []
    fields = []
    for token in where_clause.tokens:
        if isinstance(token, sqlparse.sql.Comparison):
            left = token.left
            if isinstance(left, sqlparse.sql.Identifier):
                fields.append(left.get_name())
            elif isinstance(left, sqlparse.sql.Function):
                fields.append(left.tokens[0].get_name())
    return fields

# Works for identifiers
def get_identifiers(query):
    parsed_tokens = sqlparse.parse(query)[0]
    identifier_set = set()
    for token in parsed_tokens:
        # If the token is a function, skip it
        if isinstance(token, sqlparse.sql.Function):
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                identifier_name = identifier.get_name()
                if '.' in identifier_name:
                    identifier_name = identifier_name.split('.')[1]
                identifier_set.add(identifier_name)
        elif isinstance(token, sqlparse.sql.Identifier):
            identifier_name = token.get_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
        elif isinstance(token, sqlparse.sql.Comparison):
            identifier_name = token.get_name()
            identifier_set.add(identifier_name)
        elif isinstance(token, sqlparse.sql.Where):
            identifier_name = token.get_name()
            identifier_set.add(identifier_name)
        elif isinstance(token, sqlparse.sql.Function):
            continue
    return list(identifier_set)

import sqlparse
#query = "SELECT first_name, last_name FROM employees WHERE salary > 50000 AND department = 'Sales'"
query ="select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id"
identifiers = get_identifiers(query)
# where_fields = get_where_fields(query)
# list_of_fields = identifiers + where_fields
# print(list_of_fields)
print(identifiers)


['order_number', 'o', None, 'sum', 'top', 'client_name', 'r', 'getdate']


#Finally working for this query.

In [252]:
import sqlparse

def get_identifiers(sql):
    parsed_tokens = sqlparse.parse(sql)[0]
    identifier_set = set()
    
    reserved_words = ['TOP', 'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'GROUP', 'BY', 'HAVING', 'ORDER', 'ASC', 'DESC']

    for token in parsed_tokens.tokens:
        # If the token is a function, skip it
        if isinstance(token, sqlparse.sql.Function) or token.value.upper() in reserved_words:
            continue
        if isinstance(token, sqlparse.sql.Function):
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    identifier_name = identifier.get_real_name()
                    if '.' in identifier_name:
                        identifier_name = identifier_name.split('.')[1]
                    identifier_set.add(identifier_name)
        # If the token is a comparison operator, get the column name
        elif isinstance(token, sqlparse.sql.Comparison):
            identifier_name = token.left.get_real_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
        elif isinstance(token, sqlparse.sql.Identifier):
            identifier_name = token.get_real_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
    return list(identifier_set)



query ="select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id where revenue_order in"
identifiers = get_identifiers(query)
print(identifiers)

['order_number', 'revenue', 'order_id', 'orders', 'client_name']


# Masking and Demasking

In [253]:
import re
import random

def masking(list_of_fields, sql):
    """
    This function takes in a list of words and a SQL string as input and replaces the words in the SQL string with random words.
    """
    # Create a dictionary to store the mapping between original words and masked words
    word_map = {}
    
    # Loop through each word in the list of words
    for word in list_of_fields:
        # Generate a random word to replace the original word
        random_word = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=len(word)))
        
        # Add the mapping to the dictionary
        word_map[word] = random_word
        
        # Replace the original word with the random word in the SQL string
        sql = re.sub(r'\b{}\b'.format(word), random_word, sql)
    
    # Return the masked SQL string and the word map
    return sql, word_map


def demasking(word_map, sql):
    """
    This function takes in a word map and a masked SQL string as input and replaces the masked words with their original words.
    """
    # Loop through each key-value pair in the word map
    for original_word, masked_word in word_map.items():
        # Replace the masked word with the original word in the SQL string
        sql_string = re.sub(r'\b{}\b'.format(masked_word), original_word, sql)
    
    # Return the demasked SQL string
    return sql_string


In [254]:
# Define the list of words to mask
words_to_mask = ['first_name', 'employees', 'last_name', 'salary', 'department']

# Define the SQL string to mask
sql = "SELECT first_name, last_name FROM employees WHERE salary > 50000 AND department = 'Sales'"

# Mask the SQL string
masked_sql_string, word_map = masking(words_to_mask, query)

# Print the masked SQL string and the word map
print("Masked SQL string:", masked_sql_string)
print("Word map:", word_map)

# Demask the SQL string
demasked_sql_string = demasking(word_map, masked_sql_string)

# Print the demasked SQL string
print("Demasked SQL string:", demasked_sql_string)


Masked SQL string: select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id where revenue_order in
Word map: {'first_name': 'shyuckvdgw', 'employees': 'uudyfsqyx', 'last_name': 'rvecpyxir', 'salary': 'yqrvxk', 'department': 'uwlebjbjak'}
Demasked SQL string: select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id where revenue_order in


In [222]:
# Define the list of words to mask
words_to_mask = ['client_name', 'r', 'top', 'order_number', 'sum', 'o', 'getdate']

# Define the SQL string to mask
query = "select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id"

# Mask the SQL string
masked_sql_string, word_map = masking(words_to_mask, query)

# Print the masked SQL string and the word map
print("Masked SQL string:", masked_sql_string)
print("Word map:", word_map)


Masked SQL string: select ciw 10 ufbdnsy(), wjdonbpitghk, ndyykbwdxur, tcv(net_revenue) from orders x join revenue y on x.order_id = y.order_id
Word map: {'client_name': 'ndyykbwdxur', 'r': 'y', 'top': 'ciw', 'order_number': 'wjdonbpitghk', 'sum': 'tcv', 'o': 'x', 'getdate': 'ufbdnsy'}


In [277]:
import sqlparse

def get_identifiers(sql):
    parsed_tokens = sqlparse.parse(sql)[0]
    identifier_set = set()

    reserved_words = ['TOP', 'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'GROUP', 'BY', 'HAVING', 'ORDER', 'ASC', 'DESC']

    def process_identifier(token, identifier_set):
        identifier_name = token.get_real_name()
        if '.' in identifier_name:
            identifier_name = identifier_name.split('.')[1]
        identifier_set.add(identifier_name)

    def process_function_arguments(token, identifier_set):
        if isinstance(token, sqlparse.sql.Identifier):
            process_identifier(token, identifier_set)
        elif isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    process_identifier(identifier, identifier_set)
        elif isinstance(token, sqlparse.sql.Parenthesis):
            for subtoken in token.tokens:
                process_function_arguments(subtoken, identifier_set)

    def add_identifiers_from_function(token, identifier_set):
        for subtoken in token.tokens:
            process_function_arguments(subtoken, identifier_set)

    def process_where(token, identifier_set):
        for subtoken in token.tokens:
            if isinstance(subtoken, sqlparse.sql.Comparison):
                process_identifier(subtoken.left, identifier_set)
            elif isinstance(subtoken, sqlparse.sql.Identifier):
                process_identifier(subtoken, identifier_set)

    for token in parsed_tokens.tokens:
        if isinstance(token, sqlparse.sql.Function):
            add_identifiers_from_function(token, identifier_set)
            continue
        if token.value.upper() in reserved_words:
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    process_identifier(identifier, identifier_set)
        elif isinstance(token, sqlparse.sql.Comparison):
            process_identifier(token.left, identifier_set)
        elif isinstance(token, sqlparse.sql.Where):
            process_where(token, identifier_set)
        elif isinstance(token, sqlparse.sql.Identifier):
            process_identifier(token, identifier_set)
    return list(identifier_set)

query = "select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id where revenue_order in ('3245', '34244',  '4532')"
identifiers = get_identifiers(query)
print(identifiers)


['order_number', 'revenue', 'order_id', 'revenue_order', 'orders', 'client_name']


In [278]:
import re
import random

def masking(identifiers, sql):
    """
    This function takes in a list of identifiers and an SQL query as input, and replaces the identifiers in the SQL query with random words.
    """
    # Create a dictionary to store the mapping between original identifiers and masked words
    word_map = {}
    
    # Loop through each identifier in the list of identifiers
    for identifier in identifiers:
        # Generate a random word to replace the original identifier
        random_word = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=len(identifier)))
        
        # Add the mapping to the dictionary
        word_map[identifier] = random_word
        
        # Replace the original identifier with the random word in the SQL string
        sql = re.sub(r'\b{}\b'.format(identifier), random_word, sql)
    
    # Return the masked SQL string and the word map
    return sql, word_map


masked_sql_wordmap = masking(identifiers, query)
print(masked_sql_wordmap)

# Still missing the Where cluse pick up - I am missing revenue_order field. This needs to be added to identifiers function. Here I am then this is ready. 

("select top 10 getdate(), hzeuedhpnmuv, qimyvibgkze, sum(net_revenue) from lauwkt o join jjiurhy r on o.zggccmub = r.zggccmub where etuanvvjinxdc in ('3245', '34244',  '4532')", {'order_number': 'hzeuedhpnmuv', 'revenue': 'jjiurhy', 'order_id': 'zggccmub', 'revenue_order': 'etuanvvjinxdc', 'orders': 'lauwkt', 'client_name': 'qimyvibgkze'})


In [283]:
def get_identifiers(sql):
    parsed_tokens = sqlparse.parse(sql)[0]
    identifier_set = set()

    reserved_words = ['TOP', 'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'GROUP', 'BY', 'HAVING', 'ORDER', 'ASC', 'DESC']

    def process_identifier(token, identifier_set):
        identifier_name = token.get_real_name()
        if '.' in identifier_name:
            identifier_name = identifier_name.split('.')[1]
        identifier_set.add(identifier_name)

    def process_function_arguments(token, identifier_set):
        if isinstance(token, sqlparse.sql.Identifier):
            process_identifier(token, identifier_set)
        elif isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    process_identifier(identifier, identifier_set)
        elif isinstance(token, sqlparse.sql.Parenthesis):
            for subtoken in token.tokens:
                process_function_arguments(subtoken, identifier_set)

    def add_identifiers_from_function(token, identifier_set):
        for subtoken in token.tokens:
            process_function_arguments(subtoken, identifier_set)

    def process_where(token, identifier_set):
        for subtoken in token.tokens:
            if isinstance(subtoken, sqlparse.sql.Comparison):
                process_identifier(subtoken.left, identifier_set)
            elif isinstance(subtoken, sqlparse.sql.Identifier):
                process_identifier(subtoken, identifier_set)

    for token in parsed_tokens.tokens:
        if isinstance(token, sqlparse.sql.Function):
            add_identifiers_from_function(token, identifier_set)
            continue
        if token.value.upper() in reserved_words:
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    process_identifier(identifier, identifier_set)
        elif isinstance(token, sqlparse.sql.Comparison):
            process_identifier(token.left, identifier_set)
        elif isinstance(token, sqlparse.sql.Where):
            process_where(token, identifier_set)
        elif isinstance(token, sqlparse.sql.Identifier):
            process_identifier(token, identifier_set)
    return list(identifier_set)

def sql_masking(identifiers, sql):
    """
    This function takes in a list of identifiers and an SQL query as input, and replaces the identifiers in the SQL query with random words.
    """
    # Create a dictionary to store the mapping between original identifiers and masked words
    word_map = {}
    
    # Loop through each identifier in the list of identifiers
    for identifier in identifiers:
        # Generate a random word to replace the original identifier
        random_word = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=len(identifier)))
        
        # Add the mapping to the dictionary
        word_map[identifier] = random_word
        
        # Replace the original identifier with the random word in the SQL string
        sql = re.sub(r'\b{}\b'.format(identifier), random_word, sql)
    
    # Return the masked SQL string and the word map
    return sql, word_map

In [284]:
sql="select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id where revenue_number in ('2345', '9908', '6671')"

In [285]:
list_of_fields = get_identifiers(sql)
print(list_of_fields)

['order_number', 'revenue_number', 'revenue', 'order_id', 'orders', 'client_name']


In [286]:
masked_sql, word_map  = sql_masking(list_of_fields, sql)
print(masked_sql)

select top 10 getdate(), rbkvnpvaqzio, ggdcjfoivjv, sum(net_revenue) from mmkuuz o join orekjeg r on o.hpdptrzn = r.hpdptrzn where xqpoososamybqd in ('2345', '9908', '6671')


## OpenAI

In [None]:
def sql_dialectify(to_sql, masked_sql):
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": 'You are an expert SQL developer that is proficient in MS SQL Server, MySQL, Oracle, PostgreSQL, SQLite, Snowflake SQL dialects.'},
            {"role": "system", "content": 'Only return the converted sql code and do not explain the conversion process.'},
            {"role": "system", "content": 'Check for the correctness of the entered SQL code. And make updates if necessary. List the changes succinctly in the chat.'},
            {"role": "system", "content": 'Let''s think step by step.'},
            {"role": "user", "content": f'Detect the dialect of the following SQL code: "{masked_sql}"'},
            {"role": "system", "content": f'Check and fix errors for the top common SQL syntax mistakes for the detected dialect. List updated parts of the following SQL code: "{masked_sql}"'},
            {"role": "user", "content": f'Convert the updated SQL code from detected dialect to "{to_sql}": "\n\n{masked_sql}"'}
        ]
    )
    converted_sql = completion.choices[0].message.content
    return converted_sql

In [None]:
if openai.api_key and st.button("Convert"):
    st.write("Converting the SQL Code...")
    # Convert the SQL dialect using the OpenAI API
    masked_converted_sql = sql_dialectify(to_sql, masked_sql)
    # Display the converted SQL code
    st.text_area("Converted SQL Code", masked_converted_sql)