In [7]:
from sqlalchemy import create_engine
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm import sessionmaker
from sqlalchemy import inspect, Table, MetaData
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import select
from sqlalchemy.orm import scoped_session
import pandas as pd
import sys
import os
import yaml
import datetime
from contextlib import contextmanager
from IPython.display import display, HTML

In [8]:
@contextmanager
def get_session_context(sessionmaker_factory):
    """Provide a transactional session scope for a ``with`` block.

    Opens a new session from ``sessionmaker_factory``. If the ``with`` body
    finishes normally, the session is committed. If the body (or the commit
    itself) raises an ``Exception``, the session is rolled back and the
    exception is re-raised. The session is always closed on exit.

    Parameters
    ----------
    sessionmaker_factory : `sqlalchemy.orm.session.sessionmaker`
        Callable (normally a `sessionmaker` instance) that returns a new
        session each time it is called.

    Yields
    ------
    `sqlalchemy.orm.session.Session`
        Session used to manage ORM operations inside the ``with`` block.
    """
    # Use a plain Session. Wrapping the factory in a fresh scoped_session
    # registry on every call gave no thread-local reuse (the registry was
    # discarded at exit) and only added a proxy layer.
    session = sessionmaker_factory()
    try:
        yield session
        # Reached only if the with-body did not raise.
        session.commit()
    except Exception:
        # Covers failures in the body and in commit() itself.
        session.rollback()
        raise
    finally:
        session.close()

In [9]:
# Connection URL for the mtrain test database. Set the MTRAIN_DB_URL
# environment variable to point elsewhere (or to supply real credentials)
# instead of editing them into the notebook; the fallback is the shared dev DB.
DB_URL = os.environ.get('MTRAIN_DB_URL', 'postgresql://postgres:postgres@devdb2:5432/mtrain_test')
e = create_engine(DB_URL)
_SessionmakerFactory = sessionmaker(bind=e)

In [12]:
# Dump old work: reflect and drop EVERY table currently in the database.
meta = MetaData(e)
meta.reflect()
meta.drop_all()

# Add back ORM:
from orm import TrainingStage, Transition, BehavorialTraining, BehavioralStageGraph
BehavioralStageGraph.metadata.create_all(e)

# Directory holding the YAML config files ('.', the notebook's working dir).
CONFIG_DIR = os.path.dirname('./')


def _load_yaml(file_name):
    """Load ``file_name`` from CONFIG_DIR as YAML, closing the file afterwards."""
    with open(os.path.join(CONFIG_DIR, file_name), 'r') as f:
        # safe_load: yaml.load() without an explicit Loader can construct
        # arbitrary Python objects and is deprecated/removed in newer PyYAML.
        return yaml.safe_load(f)


def _get_stage_id(training_stage, regimen):
    """Return the id of the single TrainingStage row with this name AND regimen."""
    # Chain .where(): Python's `and` between two SQLAlchemy expressions does
    # not produce SQL AND -- it silently keeps only one of the conditions.
    q = (select([TrainingStage])
         .where(TrainingStage.training_stage == training_stage)
         .where(TrainingStage.regimen == regimen))
    with get_session_context(_SessionmakerFactory) as session:
        result = [row['id'] for row in session.execute(q)]
    assert len(result) == 1, (training_stage, regimen, result)
    return result[0]


def _add_stages(stages, regimen):
    """Insert one TrainingStage per ``{name: {'script': ..., 'parameters': ...}}`` entry."""
    for training_stage, stage in stages.items():
        script = [stage['script'], None, None]
        curr_stage = TrainingStage(training_stage=training_stage, script=script,
                                   parameters=stage['parameters'], regimen=regimen)
        with get_session_context(_SessionmakerFactory) as session:
            session.add(curr_stage)


def _add_transitions(transitions, regimen):
    """Insert Transitions, resolving 'source'/'dest' stage names to ids within ``regimen``."""
    for transition in transitions:
        # Work on a copy so the caller's config keeps its stage names.
        resolved = dict(transition)
        for key in ('source', 'dest'):
            resolved[key] = _get_stage_id(transition[key], regimen)
        curr_transition = Transition(**resolved)
        with get_session_context(_SessionmakerFactory) as session:
            session.add(curr_transition)


# Populate ORM with the default regimen (stages first: transitions look up their ids):
_add_stages(_load_yaml('training_stages.yml'), regimen='default')
_add_transitions(_load_yaml('transitions.yml'), regimen='default')

# Minimal two-stage regimen used for testing:
TEST_TRAINING_STAGES = {
    'stage_0': {
        'script': None,
        'parameters': None
    },
    'stage_1': {
        'script': None,
        'parameters': None
    },
}

TEST_TRANSITIONS = [
    {
        'dest': 'stage_1',
        'source': 'stage_0',
        'trigger': 'progress',
        'conditions': []
    },
]

_add_stages(TEST_TRAINING_STAGES, regimen='test')
_add_transitions(TEST_TRANSITIONS, regimen='test')

# Populate ORM with training results for two example mice:
for mouse_id in [1234, 5678]:
    for stage in ['1_AutoRewards', 'static_full_field_gratings', 'natural_images']:
        stage_id = _get_stage_id(stage, 'default')
        curr_training_session = BehavorialTraining(mouse_id=mouse_id, training_stage=stage_id, regimen='default', input_date=datetime.datetime.now().isoformat())
        with get_session_context(_SessionmakerFactory) as session:
            session.add(curr_training_session)

In [13]:
# Snapshot the whole Transition table as a DataFrame (SQL compiled for PostgreSQL).
transition_sql = select([Transition]).compile(dialect=postgresql.dialect())
with get_session_context(_SessionmakerFactory) as session:
    x = pd.read_sql(transition_sql, session.connection())
display(x)

Unnamed: 0,id,trigger,source,dest,conditions
0,1,progress,2,3,autorewards_complete
1,2,progress,3,1,"[two_out_of_three_aint_bad, yesterday_was_good]"
2,3,progress,1,9,"[two_out_of_three_aint_bad, yesterday_was_good]"
3,4,progress,9,4,three_complete
4,5,progress,10,11,[]


In [14]:
# Snapshot the whole TrainingStage table as a DataFrame (SQL compiled for PostgreSQL).
stage_sql = select([TrainingStage]).compile(dialect=postgresql.dialect())
with get_session_context(_SessionmakerFactory) as session:
    y = pd.read_sql(stage_sql, session.connection())
display(y)

Unnamed: 0,id,training_stage,regimen,script,parameters
0,1,static_full_field_gratings_flash_500ms,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'blank_duration_rang..."
1,2,1_AutoRewards,default,"[DoC_SummerPilot.py, None, None]","{u'rewardvol': 0.007, u'delta_minimum': 2.25, ..."
2,3,static_full_field_gratings,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'blank_duration_rang..."
3,4,natural_images_drop_reward,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'pos': [0, 0], u'bla..."
4,5,natural_images_ophys_session_A,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'pos': [0, 0], u'bla..."
5,6,natural_images_ophys_session_B,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'pos': [0, 0], u'bla..."
6,7,natural_images_ophys_session_C,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'pos': [0, 0], u'bla..."
7,8,natural_images_ophys_session_D,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'pos': [0, 0], u'bla..."
8,9,natural_images,default,"[DoC_SummerPilot.py, None, None]","{u'delta_minimum': 2.25, u'pos': [0, 0], u'bla..."
9,10,stage_0,test,"[None, None, None]",


In [15]:
# Snapshot the whole BehavorialTraining table as a DataFrame (SQL compiled for PostgreSQL).
training_sql = select([BehavorialTraining]).compile(dialect=postgresql.dialect())
with get_session_context(_SessionmakerFactory) as session:
    z = pd.read_sql(training_sql, session.connection())
display(z)

Unnamed: 0,id,mouse_id,training_stage,regimen,input_date
0,1,1234,2,default,2018-01-30 14:31:30.315323
1,2,1234,3,default,2018-01-30 14:31:30.327096
2,3,1234,9,default,2018-01-30 14:31:30.337584
3,4,5678,2,default,2018-01-30 14:31:30.344864
4,5,5678,3,default,2018-01-30 14:31:30.355545
5,6,5678,9,default,2018-01-30 14:31:30.362868
