Skip to content

Commit

Permalink
Merge pull request #1133 from mbmcgarry/mac_sqlite_flag
Browse files Browse the repository at this point in the history
Mac sqlite flag
  • Loading branch information
gidden committed Apr 24, 2015
2 parents 33964dd + c37a391 commit 3902446
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 132 deletions.
97 changes: 74 additions & 23 deletions tests/helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""A set of tools for use in integration tests."""
import os
import platform
import sqlite3
from hashlib import sha1

import numpy as np
Expand All @@ -26,11 +28,30 @@ def clean_outs():
if os.path.exists(sqliteout):
os.remove(sqliteout)

def tables_exist(db, tables):
"""Checks if hdf5 database contains the specified tables.
def which_outfile():
"""Uses sqlite if platform is Mac, otherwise uses hdf5
"""
return all([t in db.root for t in tables])
return h5out if platform.system() == 'Linux' else sqliteout

def tables_exist(outfile, table_names):
"""Checks if output database contains the specified tables.
"""
if outfile == h5out:
f = tables.open_file(outfile, mode = "r")
res = all([t in f.root for t in table_names])
f.close()
return res
else:
table_names = [t.replace('/', '') for t in table_names]
conn = sqlite3.connect(outfile)
conn.row_factory = sqlite3.Row
cur = conn.cursor()
exc = cur.execute
res = all([bool(exc('SELECT * From sqlite_master WHERE name = ? ', \
(t, )).fetchone()) for t in table_names])
conn.close()
return res

def find_ids(data, data_table, id_table):
"""Finds ids of the specified data located in the specified data_table,
and extracts the corresponding id from the specified id_table.
Expand All @@ -47,6 +68,14 @@ def find_ids(data, data_table, id_table):
ids.append(id_table[i])
return ids


def to_ary(a, k):
if which_outfile() == sqliteout:
return np.array([x[k] for x in a])
else:
return a[k]


def exit_times(agent_id, exit_table):
"""Finds exit times of the specified agent from the exit table.
"""
Expand All @@ -59,48 +88,70 @@ def exit_times(agent_id, exit_table):

return exit_times

def agent_time_series(f, names):
def agent_time_series(names):
"""Return a list of timeseries corresponding to the number of agents in a
Cyclus simulation
Parameters
----------
f : PyTables file
the output file
outfile : the output file (hdf5 or sqlite format)
names : list
the list of agent names
"""
nsteps = f.root.Info.cols.Duration[:][0]
entries = {name: [0] * nsteps for name in names}
exits = {name: [0] * nsteps for name in names}

# Get specific tables and columns
agent_entry = f.get_node("/AgentEntry")[:]
agent_exit = f.get_node("/AgentExit")[:] if hasattr(f.root, 'AgentExit') \
else None

# Find agent ids
agent_ids = agent_entry["AgentId"]
agent_type = agent_entry["Prototype"]
if which_outfile() == h5out :
f = tables.open_file(h5out, mode = "r")
nsteps = f.root.Info.cols.Duration[:][0]
entries = {name: [0] * nsteps for name in names}
exits = {name: [0] * nsteps for name in names}

# Get specific tables and columns
agent_entry = f.get_node("/AgentEntry")[:]
agent_exit = f.get_node("/AgentExit")[:] if \
hasattr(f.root, 'AgentExit') else None

f.close()

else :
conn = sqlite3.connect(sqliteout)
conn.row_factory = sqlite3.Row
cur = conn.cursor()
exc = cur.execute

nsteps = exc('SELECT MIN(Duration) FROM Info').fetchall()[0][0]
entries = {name: [0] * nsteps for name in names}
exits = {name: [0] * nsteps for name in names}

# Get specific tables and columns
agent_entry = exc('SELECT * FROM AgentEntry').fetchall()
agent_exit = exc('SELECT * FROM AgentExit').fetchall() \
if len(exc(
("SELECT * FROM sqlite_master WHERE "
"type='table' AND name='AgentExit'")).fetchall()) > 0 \
else None

conn.close()

# Find agent id
agent_ids = to_ary(agent_entry, "AgentId")
agent_type = to_ary(agent_entry, "Prototype")
agent_ids = {name: find_ids(name, agent_type, agent_ids) for name in names}

# entries per timestep
for name, ids in agent_ids.items():
for id in ids:
idx = np.where(agent_entry['AgentId'] == id)[0][0]
idx = np.where(to_ary(agent_entry,'AgentId') == id)[0]
entries[name][agent_entry[idx]['EnterTime']] += 1

# cumulative entries
entries = {k: [sum(v[:i+1]) for i in range(len(v))] \
for k, v in entries.items()}

if agent_exit is None:
return entries

# entries per timestep
# exits per timestep
for name, ids in agent_ids.items():
for id in ids:
idxs = np.where(agent_exit['AgentId'] == id)[0]
idxs = np.where(to_ary(agent_exit,'AgentId') == id)[0]
if len(idxs) > 0:
exits[name][agent_exit[idxs[0]]['ExitTime']] += 1

Expand Down
5 changes: 3 additions & 2 deletions tests/test_include_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from tools import check_cmd
from helper import tables_exist, find_ids, exit_times, \
h5out, sqliteout, clean_outs
h5out, sqliteout, clean_outs, which_outfile

"""Tests"""
def test_include_recipe():
Expand All @@ -16,7 +16,8 @@ def test_include_recipe():
# Cyclus simulation input for recipe including
sim_input = "./input/include_recipe.xml"
holdsrtn = [1] # needed because nose does not send() to test generator
cmd = ["cyclus", "-o", h5out, "--input-file", sim_input]
outfile = which_outfile()
cmd = ["cyclus", "-o", outfile, "--input-file", sim_input]
yield check_cmd, cmd, '.', holdsrtn
rtn = holdsrtn[0]
if rtn != 0:
Expand Down
27 changes: 14 additions & 13 deletions tests/test_lotka_volterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import hashlib

from tools import check_cmd
from helper import tables_exist, h5out, clean_outs, agent_time_series
from helper import tables_exist, clean_outs, agent_time_series, \
h5out, sqliteout, which_outfile

prey = "Prey"
pred = "Predator"
Expand All @@ -26,15 +27,16 @@ def test_predator_only():
sim_input = "./input/predator.xml"

holdsrtn = [1] # needed because nose does not send() to test generator
cmd = ["cyclus", "-o", h5out, "--input-file", sim_input]
outfile = which_outfile()

cmd = ["cyclus", "-o", outfile, "--input-file", sim_input]
yield check_cmd, cmd, '.', holdsrtn
rtn = holdsrtn[0]

print("Confirming valid Cyclus execution.")
assert_equal(rtn, 0)

output = tables.open_file(h5out, mode = "r")
series = agent_time_series(output, [prey, pred])
series = agent_time_series([prey, pred])
print("Prey:", series[prey], "Predators:", series[pred])

prey_exp = [0 for n in range(10)]
Expand All @@ -43,7 +45,6 @@ def test_predator_only():
assert_equal(series[prey], prey_exp)
assert_equal(series[pred], pred_exp)

output.close()
clean_outs()

def test_prey_only():
Expand All @@ -54,15 +55,16 @@ def test_prey_only():
clean_outs()
sim_input = "./input/prey.xml"
holdsrtn = [1] # needed because nose does not send() to test generator
cmd = ["cyclus", "-o", h5out, "--input-file", sim_input]
outfile = which_outfile()

cmd = ["cyclus", "-o", outfile, "--input-file", sim_input]
yield check_cmd, cmd, '.', holdsrtn
rtn = holdsrtn[0]

print("Confirming valid Cyclus execution.")
assert_equal(rtn, 0)

output = tables.open_file(h5out, mode = "r")
series = agent_time_series(output, [prey, pred])
series = agent_time_series([prey, pred])
print("Prey:", series[prey], "Predators:", series[pred])

prey_exp = [2**n for n in range(10)]
Expand All @@ -71,7 +73,6 @@ def test_prey_only():
assert_equal(series[prey], prey_exp)
assert_equal(series[pred], pred_exp)

output.close()
clean_outs()

def test_lotka_volterra():
Expand All @@ -91,15 +92,16 @@ def test_lotka_volterra():
clean_outs()
sim_input = "./input/lotka_volterra_determ.xml"
holdsrtn = [1] # needed because nose does not send() to test generator
cmd = ["cyclus", "-o", h5out, "--input-file", sim_input]
outfile = which_outfile()

cmd = ["cyclus", "-o", outfile, "--input-file", sim_input]
yield check_cmd, cmd, '.', holdsrtn
rtn = holdsrtn[0]

print("Confirming valid Cyclus execution.")
assert_equal(rtn, 0)

output = tables.open_file(h5out, mode = "r")
series = agent_time_series(output, [prey, pred])
series = agent_time_series([prey, pred])
print("Prey:", series[prey], "Predators:", series[pred])

prey_max = series[prey].index(max(series[prey]))
Expand All @@ -108,7 +110,6 @@ def test_lotka_volterra():

assert_true(prey_max < pred_max)

output.close()
clean_outs()

if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 3902446

Please sign in to comment.