# Demand Prediction Integrated Schema Team Demo Notebook# 

### Import Integrated Schema related code ### 

In [1]:
import Datalog_Parsing as dp
from Mediator import Mediator
md = Mediator()

### Sample call to unfold a datalog query ###

In [2]:
input_datalog = ('''Ans ( nodeid, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol,
    p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating,
    p12m_numreviews, p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) :-
 mlfeatures ( nodeid, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol,
     p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating,
     p12m_numreviews, p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) ,
     nodeid in (15, 45, 121)
''')

md.unfold_datalog(input_datalog)[0]

u'Ans(nodeid,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-S1.mv_ml_features(nodeid,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),S2.mlview(nodeid,yr,mn,pm_avgsntp,p3m_avgsntp,p12m_avgsntp),nodeid in (15,45,121)'

# Query Validation Examples

### Correct number of body atoms

In [3]:
# validation on number of body atoms fails
query = '''Ans (nodeid, sales, vol) :- 
    sales_agg_mn (nodeid, mn, sales, vol, _, _, rank_vol, EXTRA),
    mn=12, rank_vol<=3'''

md.unfold_datalog(query)[0]

Exception: wrong number of atoms in sales_agg_mn

In [4]:
# same query with correct number of body atoms unfolds
query = '''Ans (nodeid, sales, vol) :- 
    sales_agg_mn (nodeid, mn, sales, vol, _, _, rank_vol),
    mn=12, rank_vol<=3'''

md.unfold_datalog(query)[0]

u'Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3'

### Head atoms must appear in body

In [5]:
# validation fails - head attribute not in body
query = '''Ans (nodeid, sales, vol, EXTRA) :- 
    sales_agg_mn (nodeid, mn, sales, vol, _, _, rank_vol),
    mn=12, rank_vol<=3'''

md.unfold_datalog(query)[0]

Exception: head attribute EXTRA is not in datalog body

In [6]:
# same query without extra head attribute unfolds
query = '''Ans (nodeid, sales, vol) :- 
    sales_agg_mn (nodeid, mn, sales, vol, _, _, rank_vol),
    mn=12, rank_vol<=3'''

md.unfold_datalog(query)[0]

u'Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3'

### Condition atoms must appear in body

In [7]:
# validation fails - rank_vol condition not in body
query = '''Ans (nodeid, sales, vol) :- 
    sales_agg_mn (nodeid, mn, sales, vol, _, _, _ ),     
    mn=12, rank_vol<=3'''

md.unfold_datalog(query)[0]

Exception: condition attribute rank_vol is not in datalog body

In [8]:
# same query with rank_vol condition in body runs
query = '''Ans (nodeid, sales, vol) :- 
    sales_agg_mn (nodeid, mn, sales, vol, _, _, rank_vol ),     
    mn=12, rank_vol<=3'''

md.unfold_datalog(query)[0]

u'Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3'

# Query Optimization Example

In [9]:
# query on mlfeatures unfolds to S1.mv_ml_features and S2.mlview
query = '''Ans (nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, 
    p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, 
    p12m_numreviews, p12m_avgrating) :- 
mlfeatures (nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, 
    p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, 
    p12m_numreviews, p12m_avgrating, _, _, SENTIMENT_ATTRIBUTE)'''

md.unfold_datalog(query)[0]

u'Ans(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating):-S1.mv_ml_features(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),S2.mlview(nodeId,yr,mn,_,_,SENTIMENT_ATTRIBUTE)'

In [10]:
# query on mlfeatures without sentiment attributes unfolds to S1.mv_ml_features
query = '''Ans (nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, 
    p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, 
    p12m_numreviews, p12m_avgrating) :- 
mlfeatures (nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, 
    p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, 
    p12m_numreviews, p12m_avgrating, _, _, _)'''

md.unfold_datalog(query)[0]

u'Ans(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating):-S1.mv_ml_features(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating)'

# Unit Testing

### Import Unit testing code ###

In [11]:
# set the file name (required)
__file__ = 'test_mediator.ipynb'

# add ipython magics
import ipytest.magics

import pytest

### Unit tests for unfolding queries ###

In [12]:
%%run_pytest[clean] 

def test_qry_train_nodeids():
    input_datalog = ('''Ans ( nodeid, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol,
    p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating,
    p12m_numreviews, p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) :-
 mlfeatures ( nodeid, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol,
     p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating,
     p12m_numreviews, p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) ,
     nodeid in (15, 45, 121)
''')

    output_datalog = Ans_1 = ('''Ans(nodeid,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-S1.mv_ml_features(nodeid,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),S2.mlview(nodeid,yr,mn,pm_avgsntp,p3m_avgsntp,p12m_avgsntp),nodeid in (15,45,121)''')    
    assert md.unfold_datalog(input_datalog)[0] == output_datalog



def test_dec_sales():
    input_datalog = ('''Ans ( nodeid, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales,
    p3m_vol, p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews,
    p3m_avgrating, p12m_numreviews, p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) :-
 mlfeatures ( nodeid, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales,
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews,
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) ,
     nodeid in (15, 45, 121),
     mn=12,
     yr=2015
''')

    output_datalog = ('''Ans(nodeid,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-S1.mv_ml_features(nodeid,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),S2.mlview(nodeid,yr,mn,pm_avgsntp,p3m_avgsntp,p12m_avgsntp),mn=12,yr=2015,nodeid in (15,45,121)''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_top_3_sales_xmas():
    input_datalog = ('''Ans ( nodeid, sales) :- sales_agg_mn(nodeid, mn, sales, _, _, rank_sales, _) ,
    mn=12,
    rank_sales<=3
''')

    output_datalog = ('''Ans(nodeid,sales):-S1.sales_agg_mn(nodeid,mn,sales,_,_,rank_sales,_),mn=12,rank_sales<=3''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_top_3_vol_last_yr():
    input_datalog = ('''Ans ( nodeid, sales,vol) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol) ,
    mn=12,
    rank_vol<=3
''')

    output_datalog = ('''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_top_3_sales_last_yr():
    input_datalog = ('''Ans ( nodeid, sales) :- sales_agg_yr(nodeid, yr, sales, _ , _ , rank_sales, _ ) ,
    yr = 2016,
    rank_sales<=3
''')

    output_datalog = ('''Ans(nodeid,sales):-S1.sales_agg_yr(nodeid,yr,sales,_,_,rank_sales,_),yr=2016,rank_sales<=3''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_too_many_atoms():
    input_datalog = ('''Ans ( nodeid, sales,vol) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol,_) ,
            mn=12,
            rank_vol<=3
        ''')

    output_datalog = (
        '''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')

    with pytest.raises(Exception) as e:
        assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_too_few_atoms():
    input_datalog = ('''Ans ( nodeid, sales,vol) :- sales_agg_mn(nodeid, mn, sales, vol, rank_vol) ,
            mn=12,
            rank_vol<=3
        ''')

    output_datalog = (
            '''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')

    with pytest.raises(Exception) as e:
        assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_head_att_not_in_body():
    input_datalog = ('''Ans ( nodeid, sales, vol, EXTRA_FIELD) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol) ,
            mn=12,
            rank_vol<=3
        ''')

    output_datalog = (
            '''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')

    with pytest.raises(Exception) as e:
        assert md.unfold_datalog(input_datalog)[0] == output_datalog        

        
def test_cond_att_not_in_body():
    input_datalog = ('''Ans ( nodeid, sales,vol,_,_,_) :- sales_agg_mn(nodeid, mn, sales, vol,_,_,_) , 
    mn=12, rank_vol<=3''')

    output_datalog = ('''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')    
    
    with pytest.raises(Exception) as e:
        md.unfold_datalog(input_datalog)[0] == output_datalog
    assert ('condition attribute' in str(e.value) and 'not in datalog body' in str(e.value))

        
def test_optimize_query():
    input_datalog = ('''Ans ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, p12m_avgrating) :-
 mlfeatures ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, p12m_avgrating, _, _, _) , nodeId in (15, 45, 121)''')
        
    output_datalog = ('''Ans(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating):-S1.mv_ml_features(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),nodeId in (15,45,121)''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_top_3_sales_xmas_multi_step():
    input_datalog = ('''Step1 (nodeId, sales):-order_by(sales_agg_mn (nodeId, mn, sales, _, _, _, _), [sales], [d]), mn=12.
Ans(nodeId, sales):-top(3, Step1(nodeId, sales))''')
        
    output_datalog = ('''sales_agg_mn_orderby_1(nodeId,mn,sales,_,_,_,_):-S1.sales_agg_mn(nodeId,mn,sales,_,_,_,_).Step1(nodeId,sales):-orderby(sales_agg_mn_orderby_1(nodeId,mn,sales,_,_,_,_),[sales],[d]),mn=12.Ans(nodeId,sales):-top(3,Step1(nodeId,sales))''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_top_3_sales_last_yr_multi_step():
    input_datalog = ('''Step1 (nodeId, sales) :- order_by(sales_agg_yr (nodeId, yr, sales,_,_,_,_), [sales], [d]), yr=2016.
Ans (nodeId, sales):-top(3, Step1(nodeId, sales))''')
        
    output_datalog = ('''sales_agg_yr_orderby_1(nodeId,yr,sales,_,_,_,_):-S1.sales_agg_yr(nodeId,yr,sales,_,_,_,_).Step1(nodeId,sales):-orderby(sales_agg_yr_orderby_1(nodeId,yr,sales,_,_,_,_),[sales],[d]),yr=2016.Ans(nodeId,sales):-top(3,Step1(nodeId,sales))''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_ml_features_orderby_sales():
    input_datalog = ('''Ans ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, 
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, 
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) :- 
order_by(mlfeatures ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, 
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, 
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp), [sales], [d]), yr=2016.''')
        
    output_datalog = ('''mlfeatures_orderby_1(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-S1.mv_ml_features(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),S2.mlview(nodeId,yr,mn,pm_avgsntp,p3m_avgsntp,p12m_avgsntp).Ans(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-orderby(mlfeatures_orderby_1(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp),[sales],[d]),yr=2016''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_ml_features_top3_sales_2016_multi_step():
    input_datalog = ('''Step1 ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, 
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, 
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp ) :- 
order_by(mlfeatures ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, 
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, 
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp), [sales], [d]), yr=2016.
Ans ( nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales,
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, 
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp) :- 
     top(3, Step1(nodeId, yr, mn, sales, vol, pm_sales, pm_vol, p3m_sales, p3m_vol, p12m_sales, 
     p12m_vol, pm_numreviews, pm_avgrating, p3m_numreviews, p3m_avgrating, p12m_numreviews, 
     p12m_avgrating, pm_avgsntp, p3m_avgsntp, p12m_avgsntp))''')
        
    output_datalog = ('''mlfeatures_orderby_1(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-S1.mv_ml_features(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating),S2.mlview(nodeId,yr,mn,pm_avgsntp,p3m_avgsntp,p12m_avgsntp).Step1(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-orderby(mlfeatures_orderby_1(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp),[sales],[d]),yr=2016.Ans(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp):-top(3,Step1(nodeId,yr,mn,sales,vol,pm_sales,pm_vol,p3m_sales,p3m_vol,p12m_sales,p12m_vol,pm_numreviews,pm_avgrating,p3m_numreviews,p3m_avgrating,p12m_numreviews,p12m_avgrating,pm_avgsntp,p3m_avgsntp,p12m_avgsntp))''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_group_by():
    input_datalog = ('''Ans ( nodeId, yr, mn, total_sales) :-
 group_by(sales_agg_yrmn ( nodeId, yr, mn, sales, _, _, _, _) , [nodeId, yr, mn], total_sales=sum(sales))
 ''')

    output_datalog = ('''sales_agg_yrmn_groupby_1(nodeId,yr,mn,sales,_,_,_,_):-S1.sales_agg_yrmn(nodeId,yr,mn,sales,_,_).Ans(nodeId,yr,mn,total_sales):-groupby(sales_agg_yrmn_groupby_1(nodeId,yr,mn,sales,_,_,_,_),[nodeId,yr,mn],total_sales=sum(sales))''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_order_by():
    input_datalog = ('''Ans ( nodeId, yr, mn, sales) :-
 order_by(sales_agg_yrmn ( nodeId, yr, mn, sales, _, _, _, _) , [sales], [d])
 ''')

    output_datalog = ('''sales_agg_yrmn_orderby_1(nodeId,yr,mn,sales,_,_,_,_):-S1.sales_agg_yrmn(nodeId,yr,mn,sales,_,_).Ans(nodeId,yr,mn,sales):-orderby(sales_agg_yrmn_orderby_1(nodeId,yr,mn,sales,_,_,_,_),[sales],[d])''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog


def test_top_n():
    input_datalog = ('''Ans ( nodeId, yr, mn, sales) :-
 top(3, sales_agg_yrmn ( nodeId, yr, mn, sales, _, _, _, _))
 ''')

    output_datalog = ('''sales_agg_yrmn_topn_1(nodeId,yr,mn,sales,_,_,_,_):-S1.sales_agg_yrmn(nodeId,yr,mn,sales,_,_).Ans(nodeId,yr,mn,sales):-top(3,sales_agg_yrmn_topn_1(nodeId,yr,mn,sales,_,_,_,_))''')
    assert md.unfold_datalog(input_datalog)[0] == output_datalog

platform darwin -- Python 2.7.13, pytest-3.3.0, py-1.5.2, pluggy-0.6.0
rootdir: /Users/joshwilson/Documents/DSE/2017FA-DSE203/dse203-demand-pred/final-demo/SCHEMA, inifile:
plugins: hypothesis-3.38.5
collected 17 items

test_mediator.py .................                                                                       [100%]



### Confirm correct input datalog fails

In [13]:
%%run_pytest[clean] -qq

def test_too_many_atoms():
    input_datalog = ('''Ans ( nodeid, sales, vol) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol) ,
            mn=12,
            rank_vol<=3
        ''')

    output_datalog = (
        '''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')

    with pytest.raises(Exception) as e:
        assert md.unfold_datalog(input_datalog)[0] == output_datalog

        
def test_head_att_not_in_body():
    input_datalog = ('''Ans ( nodeid, sales, vol) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol) ,
            mn=12,
            rank_vol<=3
        ''')

    output_datalog = (
            '''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')

    with pytest.raises(Exception) as e:
        md.unfold_datalog(input_datalog)[0] == output_datalog
    assert ('head attribute' in str(e.value) and 'not in datalog body' in str(e.value))

    
def test_cond_att_not_in_body():
    input_datalog = ('''Ans ( nodeid, sales,vol,_,_,_) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol) , 
    mn=12, rank_vol<=3''')

    output_datalog = ('''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')    
    
    with pytest.raises(Exception) as e:
        md.unfold_datalog(input_datalog)[0] == output_datalog
    assert ('condition attribute' in str(e.value) and 'not in datalog body' in str(e.value))


..............FFF                                                                                        [100%]
_____________________________________________ test_too_many_atoms ______________________________________________

    def test_too_many_atoms():
        input_datalog = ('''Ans ( nodeid, sales, vol) :- sales_agg_mn(nodeid, mn, sales, vol, _, _, rank_vol) ,
                mn=12,
                rank_vol<=3
            ''')
    
        output_datalog = (
            '''Ans(nodeid,sales,vol):-S1.sales_agg_mn(nodeid,mn,sales,vol,_,_,rank_vol),mn=12,rank_vol<=3''')
    
        with pytest.raises(Exception) as e:
>           assert md.unfold_datalog(input_datalog)[0] == output_datalog
E           Failed: DID NOT RAISE <type 'exceptions.Exception'>

<ipython-input-13-cc6d9e16307c>:12: Failed
__________________________________________ test_head_att_not_in_body ___________________________________________

    def test_head_att_not_in_body():
        input_datalog = ('''Ans ( nodeid