# IBMi Jupyter Notebook Extension

In [None]:
import pyodbc, getpass, time, matplotlib
import pandas as pd
import matplotlib.pyplot as plot
from IPython.display import display, HTML, Image, Javascript
from IPython.core.magic import register_cell_magic, needs_local_scope

# Notebook-wide state shared by the cells below. A `global` statement at
# module level has no effect; it only marks these names as shared state.
global _config, _info

# Connection settings; the empty values are replaced by the prompts below.
_config = dict(host='', user='', pwd='')

# SQL status fields with their initial values.
_info = dict(sqlcode=0, sqlstate='0', sqlerror='')

# disable pandas df display limit, trust user to limit result set in query
for _display_option in ('display.max_rows', 'display.max_columns',
                        'display.width', 'display.max_colwidth'):
    pd.set_option(_display_option, None)
del _display_option  # keep the loop variable out of the notebook namespace

# TODO: port? protocol? driver?
# Host and user are stripped of surrounding whitespace; the password is read
# without echo and stored exactly as typed.
_config['host'] = input('Enter host: ').strip()
_config['user'] = input('Enter user: ').strip()
_config['pwd'] = getpass.getpass('Enter password: ')

## Establish Connection

In [None]:
# Open the ODBC connection and bind it to `conn`.
# `conn` is bound to None before the attempt so that a failed connection
# leaves a defined name behind (code that tests `conn is None` then works)
# rather than no name at all.
conn = None
print("Connecting to {} as {} ...".format(_config['host'], _config['user']))
try:
    conn = pyodbc.connect(driver='{IBM i Access ODBC Driver}',
      system=_config['host'], uid=_config['user'], pwd=_config['pwd'])
    print('Successfully connected!')
except pyodbc.Error as e:
    # pyodbc.Error is the base class of the pyodbc exceptions, so failures
    # other than InterfaceError (such as OperationalError) are reported here
    # too instead of escaping as a traceback.
    print('Could not connect :(\n{}'.format(e))

## Define Utilities

In [3]:
# Splitting cell content into SQL statements delimited by ';' (when not in string)
def get_statements(sql):
    """Split *sql* into individual statements on top-level ';' characters.

    A ';' that appears inside a single- or double-quoted section is kept as
    part of the statement. The delimiter itself is dropped; everything else
    (including surrounding whitespace) is preserved verbatim.

    Args:
        sql: text containing zero or more ';'-separated statements.

    Returns:
        A list of statement strings in source order. Text between two
        adjacent delimiters yields an empty string; a trailing empty
        remainder after the last ';' is not included.
    """
    stmts = []
    current = []   # characters of the statement being collected
    quote = ''     # the quote character currently open, or '' when outside
    for c in sql:
        if c in ('\'', '"'):
            if quote == '':
                quote = c      # opening quote
            elif quote == c:
                quote = ''     # matching closing quote
            # a different quote character inside an open quote is plain text
        elif c == ';' and quote == '':
            # only a delimiter outside of quotes ends the statement
            stmts.append(''.join(current))
            current = []
            continue
        current.append(c)
    if current:
        stmts.append(''.join(current))
    return stmts


# flag validation and sanitizing
def get_flag(line):
    """Normalise the magic's argument line to its canonical short flag.

    The line is lower-cased and stripped first. An empty line yields ''.
    Every accepted long or short spelling maps to the short form.

    Raises:
        Exception: if the text does not start with '-' or is not one of the
            accepted spellings.
    """
    aliases = {
        '-c': '-c',     '-csv': '-c',
        '-h': '-h',     '-html': '-h',
        '-j': '-j',     '-json': '-j',
        '-p': '-p',     '-prep': '-p',
        '-l': '-l',     '-tex': '-l',
        '-s': '-s',     '-str': '-s',
        '-x': '-x',     '-xml': '-x',
        '-plb': '-plb', '-plotbar': '-plb',
        '-plp': '-plp', '-plotpie': '-plp',
        '-pll': '-pll', '-plotline': '-pll',
    }
    flag = line.lower().strip()
    if not flag:
        return ''
    # every alias starts with '-', so a single lookup covers both checks
    canonical = aliases.get(flag) if flag.startswith('-') else None
    if canonical is None:
        raise Exception('Invalid flag passed {}'.format(flag))
    return canonical


# convert dataframe to XML ... modified from
# https://stackoverflow.com/questions/18574108/how-do-convert-a-pandas-dataframe-to-xml
def df_to_xml(df):
    """Serialise *df* as '<rows><row><col name="...">value</col>...</row>...</rows>'.

    Column names and cell values are converted with str() and XML-escaped, so
    '&', '<', '>' (and '"' inside the name attribute) cannot break the markup.
    Rows are iterated explicitly, so a frame without rows yields '<rows></rows>'.

    Args:
        df: the pandas DataFrame to serialise.

    Returns:
        The XML document as a single string without line breaks.
    """
    from xml.sax.saxutils import escape

    # Attribute values are wrapped in double quotes, so '"' must be escaped too.
    names = [escape(str(name), {'"': '&quot;'}) for name in df.columns]

    def row_to_xml(values):
        cells = ''.join(
            '<col name="{0}">{1}</col>'.format(name, escape(str(value)))
            for name, value in zip(names, values))
        return '<row>' + cells + '</row>'

    # name=None yields plain tuples, one per row, without the index
    body = ''.join(row_to_xml(values)
                   for values in df.itertuples(index=False, name=None))
    return '<rows>' + body + '</rows>'


# Handle outputting result set dependent on flag passed
def output_rs(rows, cursor, flag):
    """Build a DataFrame from *rows* and render it according to *flag*.

    Column names are taken from the first element of each entry in
    ``cursor.description``.

    Args:
        rows: iterable of row sequences.
        cursor: object exposing ``description``.
        flag: '' for the DataFrame itself, or one of the short flags
            '-c', '-h', '-j', '-l', '-s', '-x', '-plb', '-plp', '-pll'.

    Returns:
        The DataFrame, a rendered string, or whatever the plot helper returns.

    Raises:
        Exception: for any other flag value.
    """
    headers = [column[0] for column in cursor.description]
    frame = pd.DataFrame((tuple(record) for record in rows), columns=headers)
    if flag == '':
        return frame

    # Each entry is evaluated lazily, so only the selected renderer runs.
    renderers = {
        '-c':   lambda: frame.to_csv(index=False),
        '-h':   lambda: frame.to_html(),
        '-j':   lambda: frame.to_json(orient='records'),
        '-l':   lambda: frame.to_latex(index=False),
        '-s':   lambda: frame.to_string(index=False),
        '-x':   lambda: df_to_xml(frame),
        '-plb': lambda: plot_bar(frame, len(frame.columns)),
        '-plp': lambda: plot_pie(frame, len(frame.columns)),
        '-pll': lambda: plot_line(frame, len(frame.columns)),
    }
    render = renderers.get(flag)
    if render is None:
        raise Exception('Invalid flag passed {}'.format(flag))
    return render()


# plot bar graph from dataframe
def plot_bar(df, col_count):
    """Draw a bar chart from *df*; the layout depends on *col_count*.

    * 1 column: the column is plotted against a 1-based row index
      (``df.index`` is shifted in place).
    * 2 columns: first column on x, second column on y.
    * 3 columns: the frame is pivoted (first column -> series, second
      column -> index, third column -> values) and plotted as grouped bars.

    Raises:
        Exception: for any other column count.
    """
    if col_count == 1:
        df.index = df.index + 1
        df.plot(kind='bar')
        # NOTE(review): plot.plot() without data draws nothing; presumably
        # plot.show() was intended (as in plot_pie) - confirm before changing.
        plot.plot()
    elif col_count == 2:
        df.plot(kind='bar', x=df.columns.values[0], y=df.columns.values[1])
        plot.plot()
    elif col_count == 3:
        # the module is imported as `pd`; the name `pandas` is not bound here
        pivoted = pd.pivot_table(df,
                            values=df.columns.values[2],
                            columns=df.columns.values[0],
                            index=df.columns.values[1])
        pivoted.plot.bar()
    else:
        raise Exception('cannot generate bar plot with > 3 column(s)')


# plot pie graph from dataframe
def plot_pie(df, col_count):
    """Draw a pie chart from *df* and show it with a legend on the right.

    * 1 column: the column supplies the slice sizes; ``df.index`` is shifted
      in place to start at 1.
    * 2 columns: the first column supplies the slice labels, the second the
      slice sizes.

    Raises:
        Exception: for any other column count (nothing is plotted).
    """
    if col_count not in (1, 2):
        raise Exception('cannot generate pie plot with > 2 column(s)')

    if col_count == 1:
        df.index = df.index + 1
        df.plot(kind='pie', y=df.columns.values[0], autopct='%.2f')
    else:
        size_column = df.columns.values[1]
        slice_labels = df[df.columns.values[0]].tolist()
        df.plot(kind='pie', y=size_column, labels=slice_labels,
            autopct='%.2f')

    plot.legend(loc='right', bbox_to_anchor=(1.5, 0.50))
    plot.show()


# plot line graph from dataframe
def plot_line(df, col_count):
    """Draw a line chart from *df*; the layout depends on *col_count*.

    * 1 column: the column is plotted against a 1-based row index
      (``df.index`` is shifted in place).
    * 2 columns: first column on x, second column on y.
    * 3 columns: the frame is pivoted (first column -> series, second
      column -> index, third column -> values) and each series is one line.

    Raises:
        Exception: for any other column count.
    """
    if col_count == 1:
        df.index = df.index + 1
        df.plot(kind='line')
    elif col_count == 2:
        df.plot(kind='line', x=df.columns.values[0], y=df.columns.values[1])
    elif col_count == 3:
        # the module is imported as `pd`; the name `pandas` is not bound here
        pivoted = pd.pivot_table(df,
                            values=df.columns.values[2],
                            columns=df.columns.values[0],
                            index=df.columns.values[1])
        pivoted.plot()
    else:
        raise Exception('cannot generate line plot with > 3 column(s)')

## Setup Cell Magic

In [None]:
# invoke with %%ibmi
@needs_local_scope
@register_cell_magic
def ibmi(line, cell, local_ns=None):
    """Cell magic: run the SQL statement(s) in the cell and display the results.

    Args:
        line: text after ``%%ibmi``; an optional output flag (see get_flag).
        cell: cell body; one or more SQL statements separated by ';'.
        local_ns: namespace supplied by IPython to magics marked with
            ``needs_local_scope``. Optional so the function can still be
            called with just ``(line, cell)``.

    Returns:
        None. Progress messages are printed and each non-empty result set is
        rendered through output_rs and shown with display().

    Raises:
        Exception: from get_flag when *line* is not a valid flag. Errors
            raised while executing a statement are printed, not raised.
    """
    # Look `conn` up by name: if the connection cell never bound it, this
    # reports "Not connected." instead of failing with NameError.
    connection = (local_ns or {}).get('conn', globals().get('conn'))
    if connection is None:
        print('Not connected.')
        return None
    flag = get_flag(line)
    stmts = get_statements(cell.strip())
    for i, stmt in enumerate(stmts, start=1):
        # Collapse the statement onto a single line. Braces are doubled so
        # that the str.format call below leaves them untouched.
        sql = ' '.join(x.replace('{', '{{').replace('}', '}}').strip() for x in stmt.split('\n'))
        print('Executing statement {} of {} ...'.format(i, len(stmts)))
        # NOTE(review): every brace was escaped above, so this format call
        # never substitutes a variable; it only turns '{{'/'}}' back into
        # '{'/'}'. Confirm whether {name} substitution was intended.
        sql = sql.format(**globals())
        cursor = connection.cursor()
        try:
            start_time = time.time()
            affected = cursor.execute(sql)
            while True:
                try:
                    rows = cursor.fetchall()
                except pyodbc.ProgrammingError:  # no result set, just an update,insert,create,etc statement
                    print('    Statement executed successfully.\n    {} row(s) affected by statement.'.format(affected.rowcount))
                else:
                    # Rendering happens outside the inner try so that an error
                    # raised while formatting or displaying is not reported
                    # as a successful non-query statement.
                    print('    Fetched {} row(s) in {} second(s)'.format(len(rows), round(time.time()-start_time,3)))
                    if rows:
                        result = output_rs(rows, cursor, flag)
                        if result is not None:
                            display(result)
                if not cursor.nextset():
                    break
        except Exception as e:
            print('Unexpected error occurred\n' + str(e))
        finally:
            cursor.close()
# Drop the plain function name from the notebook namespace; the decorator
# above has already registered the magic.
del ibmi