Skip to content

Commit

Permalink
New dataset, along with linq unit tests for JOINs
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed May 9, 2013
1 parent ddbfadc commit 5c39dd3
Show file tree
Hide file tree
Showing 3 changed files with 5,432 additions and 481 deletions.
185 changes: 90 additions & 95 deletions macropy/macros2/linq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@
from macropy.macros2.linq import macros, sql, generate_schema

engine = create_engine("sqlite://")
for line in open("macros2/linq_test_dataset.sql").read().split(";"):

for line in open("macros2/world.sql").read().split(";"):
engine.execute(line.strip())

db = generate_schema(engine)

def compare_queries(query1, query2):
def compare_queries(query1, query2, post_process=lambda x: x):
res1 = engine.execute(query1).fetchall()
res2 = engine.execute(query2).fetchall()
try:
assert res1 == res2
assert post_process(res1) == post_process(res2)
except Exception, e:
print "FAILURE"
print e
print query1
print res1
print "\n".join(map(str, post_process(res1)))
print query2
print res2
print "\n".join(map(str, post_process(res2)))
raise e

class Tests(unittest.TestCase):
Expand All @@ -29,172 +30,166 @@ class Tests(unittest.TestCase):
http://sqlzoo.net/wiki/Main_Page
"""
def test_basic(self):
# all countries in europe
compare_queries(
"SELECT name FROM bbc WHERE region = 'Europe'",
sql%(x.name for x in db.bbc if x.region == 'Europe')
"SELECT name FROM country WHERE continent = 'Europe'",
sql%(x.name for x in db.country if x.continent == 'Europe')
)
# countries whose area is bigger than 10000000
compare_queries(
"SELECT name, area FROM bbc WHERE area > 10000000",
sql%((x.name, x.area) for x in db.bbc if x.area > 10000000)
"SELECT name, surface_area FROM country WHERE surface_area > 10000000",
sql%((x.name, x.surface_area) for x in db.country if x.surface_area > 10000000)
)

def test_nested(self):
compare_queries(
"""
SELECT name FROM bbc
WHERE population > (
SELECT population FROM bbc
WHERE name='Russia'
)
""",
sql%(
x.name for x in db.bbc if x.population > (
y.population for y in db.bbc if y.name == 'Russia'
)
)
)

# countries on the same continent as India or Iran
compare_queries(
"""
SELECT name, region FROM bbc
WHERE region IN (
SELECT region FROM bbc
SELECT name, continent FROM country
WHERE continent IN (
SELECT continent FROM country
WHERE name IN ('India', 'Iran')
)
""",
sql%(
(x.name, x.region) for x in db.bbc
if x.region in (
y.region for y in db.bbc
(x.name, x.continent) for x in db.country
if x.continent in (
y.continent for y in db.country
if y.name in ['India', 'Iran']
)
)
)

# countries in the same continent as Belize or Belgium
compare_queries(
"""
SELECT w.name, w.region
FROM bbc w
WHERE w.region in (
SELECT z.region
FROM bbc z
SELECT w.name, w.continent
FROM country w
WHERE w.continent in (
SELECT z.continent
FROM country z
WHERE z.name = 'Belize' OR z.name = 'Belgium'
)
""",
sql%(
(c.name, c.region) for c in db.bbc
if c.region in (
x.region for x in db.bbc
(c.name, c.continent) for c in db.country
if c.continent in (
x.continent for x in db.country
if (x.name == 'Belize') | (x.name == 'Belgium')
)
)
)

def test_operators(self):
# countries in europe with a DNP per capita larger than the UK
compare_queries(
"""
SELECT name FROM bbc
WHERE gdp/population > (
SELECT gdp/population FROM bbc
SELECT name FROM country
WHERE gnp/population > (
SELECT gnp/population FROM country
WHERE name = 'United Kingdom'
)
AND region = 'Europe'
AND continent = 'Europe'
""",
sql%(
x.name for x in db.bbc
if x.gdp / x.population > (
y.gdp / y.population for y in db.bbc
x.name for x in db.country
if x.gnp / x.population > (
y.gnp / y.population for y in db.country
if y.name == 'United Kingdom'
)
if (x.region == 'Europe')
if (x.continent == 'Europe')
)
)

def test_aggregate(self):
# the population of the world
compare_queries(
"SELECT SUM(population) FROM bbc",
sql%(func.sum(x.population) for x in db.bbc)
"SELECT SUM(population) FROM country",
sql%(func.sum(x.population) for x in db.country)
)

def test_aliased(self):
# number of countries whose area is at least 1000000
compare_queries(
"select count(*) from bbc where area >= 1000000",
sql%(func.count(x.name) for x in db.bbc if x.area >= 1000000)
"select count(*) from country where surface_area >= 1000000",
sql%(func.count(x.name) for x in db.country if x.surface_area >= 1000000)
)

def test_aliased(self):

# continents whose total population is greater than 100000000
compare_queries(
"""
SELECT DISTINCT(x.region)
FROM bbc x
SELECT DISTINCT(x.continent)
FROM country x
WHERE 100000000 < (
SELECT SUM(w.population)
from bbc w
WHERE w.region = x.region
from country w
WHERE w.continent = x.continent
)
""",
sql%(
func.distinct(x.region) for x in db.bbc
func.distinct(x.continent) for x in db.country
if (
func.sum(w.population) for w in db.bbc
if w.region == x.region
func.sum(w.population) for w in db.country
if w.continent == x.continent
) > 100000000
)
)

def test_query_macro(self):
query = sql%(
func.distinct(x.region) for x in db.bbc
func.distinct(x.continent) for x in db.country
if (
func.sum(w.population) for w in db.bbc
if w.region == x.region
func.sum(w.population) for w in db.country
if w.continent == x.continent
) > 100000000
)
sql_results = engine.execute(query).fetchall()
query_macro_results = query%(
func.distinct(x.region) for x in db.bbc
func.distinct(x.continent) for x in db.country
if (
func.sum(w.population) for w in db.bbc
if w.region == x.region
func.sum(w.population) for w in db.country
if w.continent == x.continent
) > 100000000
)
assert sql_results == query_macro_results


def test_join(self):
# names of all cities in Asia
compare_queries(
"""
SELECT name
FROM movie m
JOIN actor a
JOIN casting c
WHERE m.title = 'Casablanca'
AND m.id = c.movieid
AND a.id = c.actorid
SELECT COUNT(t.name)
FROM country c
JOIN city t
ON (t.country_code = c.code)
WHERE c.continent = 'Asia'
""",
sql%(
a.name
for m in db.movie
for a in db.actor
for c in db.casting
if m.title == 'Casablanca'
if m.id == c.movieid
if a.id == c.actorid
func.count(t.name)
for c in db.country
for t in db.city
if t.country_code == c.code
if c.continent == 'Asia'
)
)

(
"""
SELECT mm.title
FROM movie mm
JOIN actor aa
JOIN casting cc
WHERE mm.id = cc.movieid
AND aa.id = cc.actorid
AND mm.title IN (
SELECT m.title
FROM movie m
JOIN actor a
JOIN casting c
WHERE m.id = c.movieid
AND a.id = c.actorid
AND a.name = 'Julie Andrews'
)
# name and population for each country and city where the city's
# population is more than half the country's
compare_queries(
"""
SELECT t.name, t.population, c.name, c.population
FROM country c
JOIN city t
ON t.country_code = c.code
WHERE t.population > c.population / 2
""",
sql%(
(t.name, t.population, c.name, c.population)
for c in db.country
for t in db.city
if t.country_code == c.code
if t.population > c.population / 2
),
lambda x: sorted(map(str, x))
)
Loading

0 comments on commit 5c39dd3

Please sign in to comment.