In [1]:
import IPython

import ipykernel.ipkernel
from ipython_genutils.py3compat import builtin_mod, PY3, unicode_type, safe_unicode

from IPython.utils.tokenutil import token_at_cursor, line_at_cursor


try:
    import stringdisplays, markdown, tokens, phrases, magics
except:
    with __import__('importnb').Notebook(): from pidgin import stringdisplays, markdown, tokens, phrases, magics

import types, sys, textwrap, black, mimetypes, collections, html, base64, sys, IPython, trio, ipykernel, trio, importlib, traitlets, jinja2, asyncio, tokenize, itertools, operator, toolz.curried as toolz, tokenize, io

from traitlets import *
from tornado import gen

# Module-level handle to the running IPython shell. The extension and test
# functions in this notebook receive or look up their own shell instead.
ip = IPython.get_ipython()


In [2]:
def expression_tokens(str):
    """Tokenize ``str`` as Python source and return all of its tokens as a list."""
    readline = io.StringIO(str).readline
    token_stream = tokenize.generate_tokens(readline)
    return list(token_stream)

In [3]:
def split_expression(str, *expressions):
    """Split an expression on its top-level semicolons.

    Semicolons are found with :mod:`tokenize`, so a ``;`` inside a string
    literal or a comment does not cause a split.

    Parameters
    ----------
    str : str
        Source text to split (may span several lines).
    *expressions : str
        Optional leading pieces; the split pieces are appended after them.

    Returns
    -------
    tuple of str
        ``expressions`` followed by every segment between semicolons,
        including the (possibly empty) text after the last semicolon.
    """
    # tokenize reports (row, col) relative to each line; record where every
    # line starts so a token position can be turned into a string index.
    line_starts = [0]
    for line in io.StringIO(str):
        line_starts.append(line_starts[-1] + len(line))

    cuts = [
        line_starts[token.start[0] - 1] + token.start[1]
        for token in tokenize.generate_tokens(io.StringIO(str).readline)
        if token.string == ';'
    ]

    start = 0
    for stop in cuts + [len(str)]:
        expressions += str[start:stop],
        start = stop + 1  # skip the semicolon itself
    return expressions

In [4]:
def requote(str, token='"""'):
    """Wrap ``str`` in triple quotes so it can be executed as a string literal.

    The returned source always evaluates back to exactly ``str``. Backslashes
    are doubled and every occurrence of the delimiter's quote character is
    escaped. Text that contains both kinds of triple quote, ends in a quote,
    or contains backslashes (e.g. LaTeX) therefore still yields a valid literal.

    Parameters
    ----------
    str : str
        Text to quote.
    token : str
        Preferred delimiter. ``'''`` is used instead when ``str`` contains
        ``token``, which keeps the output readable.

    Returns
    -------
    str
        Python source for a string literal equal to ``str``.
    """
    if token in str:
        token = "'''"
    quote = token[0]
    # Escape backslashes first so the quote escapes added next are not doubled.
    body = str.replace('\\', '\\\\').replace(quote, '\\' + quote)
    return token + body + token

In [14]:
class PidginShell(IPython.core.interactiveshell.InteractiveShell):
    """Markdown-first shell behaviour.

    ``load_ipython_extension`` binds the methods below onto an existing shell
    instance. Cells are passed through ``markdown.renderer`` before they run,
    the markdown is displayed (and optionally re-rendered as a jinja2
    template), and user expressions are evaluated and displayed inline.
    """
    # Pass cells through ``markdown.renderer`` and display their markdown.
    markdown = Bool(True)
    # Render cells as jinja2 templates against ``user_ns``.
    template = Bool(True)
    # Evaluate user expressions; when False ``user_expressions`` returns ``{}``.
    expressions = Bool(True)
    # NOTE(review): not read by this class; presumably consumed elsewhere — confirm.
    tangle_expressions = Bool(True)

    def user_expressions(self, expressions):
        """Evaluate and display each user expression.

        Parameters
        ----------
        expressions : dict
            Maps a key to expression source; a source may hold several
            ``;``-separated expressions.

        Returns
        -------
        dict
            Maps each key to the result dict returned by
            ``single_user_expression``, or ``{}`` when the ``expressions``
            trait is off.
        """
        if self.expressions:
            return trio.run(self.async_user_expressions, expressions)
        return {}

    async def async_user_expressions(self, expressions):
        """Run ``single_user_expression`` for every entry of ``expressions``.

        Returns a dict mapping each key to its result, in the input's key order.
        """
        results = {}

        async def evaluate(key, expression):
            results[key] = await self.single_user_expression(expression)

        async with trio.open_nursery() as nursery:
            for key, expression in expressions.items():
                nursery.start_soon(evaluate, key, expression)
        # Tasks may complete in any order; report in the caller's key order.
        return {key: results[key] for key in expressions}

    async def single_user_expression(self, code):
        """Evaluate the ``;``-separated expressions in ``code`` and display them.

        Each piece is transformed with ``input_transformer_manager`` and
        evaluated by ``InteractiveShell.user_expressions``. The behaviour is:

        - ``code`` is echoed as markdown.
        - A successful piece whose back-quoted source is already a key in
          ``user_ns`` has that display handle updated in place.
        - Tracebacks of failing pieces are published as plain text.
        - When the last piece succeeded, its value is published under a new
          ``DisplayHandle``. That handle is stored in ``user_ns`` under the
          back-quoted ``code`` and under the back-quoted last piece.

        Returns
        -------
        dict
            The result dict of the last piece, with ``status`` set to
            ``'error'`` when any piece failed.
        """
        expressions = split_expression(code)
        results = IPython.core.interactiveshell.InteractiveShell.user_expressions(self, dict(
            zip(expressions, map(self.input_transformer_manager.transform_cell, expressions))))

        IPython.display.display(IPython.display.Markdown('''`>>> {}`'''.format(code.strip('`'))))

        error_msg = []
        # split_expression always yields at least one piece, so after this loop
        # ``result`` and ``expressions_ns`` describe the last expression.
        for expression, result in results.items():
            expressions_ns = '`{}`'.format(expression.strip())
            if result['status'] == 'error':
                error_msg.extend(result['traceback'])
            elif expressions_ns in self.user_ns:
                display = self.user_ns[expressions_ns]
                IPython.display.publish_display_data(
                    result['data'], update=True, transient={
                        'display_id': display.display_id}
                )

        if error_msg:
            IPython.display.publish_display_data({'text/plain': ''.join(error_msg)})

        if result['status'] == 'ok':
            display = self.user_ns[
                "`{}`".format(code.strip())
            ] = self.user_ns[expressions_ns] = IPython.display.DisplayHandle()
            IPython.display.publish_display_data(result['data'], transient={
                'display_id': display.display_id})
        if error_msg:
            result['status'] = 'error'
        return result

    def run_cell(self, code, store_history=False, silent=False, shell_futures=True, **user_expressions):
        """Execute a cell that may mix markdown and code.

        When the ``markdown`` trait is on and the cell does not start with
        ``%``, the cell is passed through ``markdown.renderer`` and its return
        value is run as the source.

        - If that leaves no source, the cell is run as a string literal (via
          ``requote``). It is rendered with jinja2 first when ``template`` is
          on, and it runs silently without being stored in history.
        - Otherwise the source is run normally. If the cell's first line is
          not blank, its markdown is displayed, then re-rendered with jinja2
          after execution when ``template`` is on.

        Extra keyword arguments are treated as user expressions and are
        evaluated after the cell.

        Returns the ``ExecutionResult`` from ``InteractiveShell.run_cell``.
        """
        self._last_traceback = None
        if self.markdown and not code.lstrip().startswith('%'):
            source = markdown.renderer(code, user_expressions=user_expressions)
        else:
            source = code

        silent = silent or not source
        display = None

        if self.markdown and source:
            if code.strip() and code.splitlines()[0].strip():
                display = IPython.display.display(IPython.display.Markdown(code), display_id=True)

        if not source:
            if self.template:
                code = jinja2.Template(code).render(**self.user_ns)
            result = IPython.core.interactiveshell.InteractiveShell.run_cell(
                self, requote(code), store_history=False, silent=silent, shell_futures=shell_futures)
        else:
            result = IPython.core.interactiveshell.InteractiveShell.run_cell(
                self, source, store_history=store_history, silent=silent, shell_futures=shell_futures)

        if user_expressions:
            IPython.display.display(IPython.display.Markdown('---'))
            self.user_expressions(user_expressions)

        # Re-render the displayed markdown now that the cell has updated user_ns.
        if display and self.template:
            display.update(IPython.display.Markdown(jinja2.Template(code).render(**self.user_ns)))
        return result

In [15]:
# Methods the shell had before the first load, keyed by method name, so that
# unload_ipython_extension can restore them.
original_methods = {}

# PidginShell methods grafted onto the shell, and traits copied onto it.
_PIDGIN_METHODS = ('user_expressions', 'run_cell', 'single_user_expression', 'async_user_expressions')
_PIDGIN_TRAITS = ('markdown', 'template', 'expressions', 'tangle_expressions')


def load_ipython_extension(ip):
    """Graft PidginShell's behaviour onto the running shell ``ip``.

    Each PidginShell method is bound to ``ip`` as an instance attribute, and
    each PidginShell trait's default value is copied onto ``ip``. The shell's
    own methods are recorded on the first load only, so repeated loads still
    restore the true originals.
    """
    for name in _PIDGIN_METHODS:
        existing = getattr(ip, name, None)
        if existing:
            # setdefault: a second load must not record the already-patched method.
            original_methods.setdefault(name, existing)
        setattr(ip, name, types.MethodType(getattr(PidginShell, name), ip))

    for trait in _PIDGIN_TRAITS:
        # Reading a trait off the class returns the trait descriptor, not its value.
        setattr(ip, trait, getattr(PidginShell, trait).default_value)


def unload_ipython_extension(ip):
    """Undo ``load_ipython_extension`` on ``ip``.

    Recorded original methods are put back. Grafted methods the shell did not
    have before, and the copied traits, are removed from the instance.
    """
    for name in _PIDGIN_METHODS:
        if name in original_methods:
            ip.__dict__[name] = original_methods[name]
        else:
            ip.__dict__.pop(name, None)
    for trait in _PIDGIN_TRAITS:
        ip.__dict__.pop(trait, None)
    

In [None]:
    # When executed as a script, run pytest on shell.ipynb (presumably this
    # notebook — confirm) inside IPython so the test_* functions are collected.
    if __name__ == '__main__':
        !ipython -m pytest -- shell.ipynb
        #load_ipython_extension(get_ipython())

In [40]:
    def test_extension():
        """A prose cell fails on a plain shell; with the extension loaded it runs without raising."""
        ip = IPython.get_ipython()
        cell = """This is not code\n\n\n\ta=range(10)"""
        # run_cell reports errors on its ExecutionResult instead of raising, and a
        # bare ``except`` would also swallow a failing assert, so check the result.
        assert not ip.run_cell(cell).success
        load_ipython_extension(ip)
        try:
            ip.run_cell(cell)
        finally:
            # Always restore the shell, even if the patched run_cell raises.
            unload_ipython_extension(ip)
        