In [7]:
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML


def is_subselect(parsed):
    if not parsed.is_group:
        return False
    for item in parsed.tokens:
        if item.ttype is DML and item.value.upper() == 'SELECT':
            return True
    return False


def extract_from_part(parsed):
    from_seen = False
    for item in parsed.tokens:
        if from_seen:
            if is_subselect(item):
                yield from extract_from_part(item)
            elif item.ttype is Keyword:
                return
            else:
                yield item
        elif item.ttype is Keyword and item.value.upper() == 'FROM':
            from_seen = True


def extract_table_identifiers(token_stream):
    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                yield identifier.get_name()
        elif isinstance(item, Identifier):
            yield item.get_name()
        # It's a bug to check for Keyword here, but in the example
        # above some tables names are identified as keywords...
        elif item.ttype is Keyword:
            yield item.value


def extract_tables(sql):
    stream = extract_from_part(sqlparse.parse(sql)[0])
    return list(extract_table_identifiers(stream))


if __name__ == '__main__':
    sql = """
    select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders join revenue r on o.order_id = r.order_id
    """

    tables = ', '.join(extract_tables(sql))
    print('Tables: {}'.format(tables))

Tables: orders


In [None]:
query ="select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id"


In [4]:
# Example for retrieving column definitions from a CREATE statement
# using low-level functions.

import sqlparse


def extract_definitions(token_list):
    # assumes that token_list is a parenthesis
    definitions = []
    tmp = []
    par_level = 0
    for token in token_list.flatten():
        if token.is_whitespace:
            continue
        elif token.match(sqlparse.tokens.Punctuation, '('):
            par_level += 1
            continue
        if token.match(sqlparse.tokens.Punctuation, ')'):
            if par_level == 0:
                break
            else:
                par_level += 1
        elif token.match(sqlparse.tokens.Punctuation, ','):
            if tmp:
                definitions.append(tmp)
            tmp = []
        else:
            tmp.append(token)
    if tmp:
        definitions.append(tmp)
    return definitions


if __name__ == '__main__':
    SQL = """CREATE TABLE foo (
             id integer primary key,
             title varchar(200) not null,
             description text);"""

    parsed = sqlparse.parse(SQL)[0]

    # extract the parenthesis which holds column definitions
    _, par = parsed.token_next_by(i=sqlparse.sql.Parenthesis)
    columns = extract_definitions(par)

    for column in columns:
        print('NAME: {name!s:12} DEFINITION: {definition}'.format(
            name=column[0], definition=' '.join(str(t) for t in column[1:])))

NAME: id           DEFINITION: integer primary key
NAME: title        DEFINITION: varchar 200 not null
NAME: description  DEFINITION: text


In [10]:
import sqlparse

def get_identifiers(query):
    parsed_tokens = sqlparse.parse(query)[0]
    identifier_set = set()
    for token in parsed_tokens.tokens:
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                identifier_name = identifier.get_real_name()
                if '.' in identifier_name:
                    identifier_name = identifier_name.split('.')[1]
                identifier_set.add(identifier_name)
        # If the token is an identifier, add it to the set
        elif isinstance(token, sqlparse.sql.Identifier):
            identifier_name = token.get_real_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
        # If the token is a comparison operator, get the column name
        elif isinstance(token, sqlparse.sql.Comparison):
            identifier_name = token.left.get_real_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
        # If the token is a column, add it to the set
        elif isinstance(token, sqlparse.sql.Function):
            for sub_token in token.tokens:
                if isinstance(sub_token, sqlparse.sql.Identifier):
                    identifier_name = sub_token.get_real_name()
                    if '.' in identifier_name:
                        identifier_name = identifier_name.split('.')[1]
                    identifier_set.add(identifier_name)
    return list(identifier_set)



# OMG finally works. A bit specific to the query though

In [23]:
import sqlparse

def get_identifiers(query):
    parsed_tokens = sqlparse.parse(query)[0]
    identifier_set = set()
    
    reserved_words = ['TOP', 'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'GROUP', 'BY', 'HAVING', 'ORDER', 'ASC', 'DESC']

    for token in parsed_tokens.tokens:
        # If the token is a function, skip it
        if isinstance(token, sqlparse.sql.Function) or token.value.upper() in reserved_words:
            continue
        if isinstance(token, sqlparse.sql.Function):
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                if isinstance(identifier, sqlparse.sql.Identifier):
                    identifier_name = identifier.get_real_name()
                    if '.' in identifier_name:
                        identifier_name = identifier_name.split('.')[1]
                    identifier_set.add(identifier_name)
        # If the token is a comparison operator, get the column name
        elif isinstance(token, sqlparse.sql.Comparison):
            identifier_name = token.left.get_real_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
        elif isinstance(token, sqlparse.sql.Identifier):
            identifier_name = token.get_real_name()
            if '.' in identifier_name:
                identifier_name = identifier_name.split('.')[1]
            identifier_set.add(identifier_name)
    return list(identifier_set)


In [24]:
query ="select top 10 getdate(), order_number, client_name, sum(net_revenue) from orders o join revenue r on o.order_id = r.order_id"
print(get_identifiers(query))

['revenue', 'order_id', 'orders', 'client_name', 'order_number']
