Permalink
Browse files

New dataset, along with linq unit tests for JOINs

  • Loading branch information...
1 parent ddbfadc commit 5c39dd3b9334a4fdff79f9fe1ff9d5bd0c71d010 @lihaoyi committed May 9, 2013
Showing with 5,432 additions and 481 deletions.
  1. +90 −95 macropy/macros2/linq_test.py
  2. +0 −386 macropy/macros2/linq_test_dataset.sql
  3. +5,342 −0 macropy/macros2/world.sql
@@ -4,23 +4,24 @@
from macropy.macros2.linq import macros, sql, generate_schema
engine = create_engine("sqlite://")
-for line in open("macros2/linq_test_dataset.sql").read().split(";"):
+
+for line in open("macros2/world.sql").read().split(";"):
engine.execute(line.strip())
db = generate_schema(engine)
-def compare_queries(query1, query2):
+def compare_queries(query1, query2, post_process=lambda x: x):
res1 = engine.execute(query1).fetchall()
res2 = engine.execute(query2).fetchall()
try:
- assert res1 == res2
+ assert post_process(res1) == post_process(res2)
except Exception, e:
print "FAILURE"
print e
print query1
- print res1
+ print "\n".join(map(str, post_process(res1)))
print query2
- print res2
+ print "\n".join(map(str, post_process(res2)))
raise e
class Tests(unittest.TestCase):
@@ -29,172 +30,166 @@ class Tests(unittest.TestCase):
http://sqlzoo.net/wiki/Main_Page
"""
def test_basic(self):
+ # all countries in europe
compare_queries(
- "SELECT name FROM bbc WHERE region = 'Europe'",
- sql%(x.name for x in db.bbc if x.region == 'Europe')
+ "SELECT name FROM country WHERE continent = 'Europe'",
+ sql%(x.name for x in db.country if x.continent == 'Europe')
)
+ # countries whose area is bigger than 10000000
compare_queries(
- "SELECT name, area FROM bbc WHERE area > 10000000",
- sql%((x.name, x.area) for x in db.bbc if x.area > 10000000)
+ "SELECT name, surface_area FROM country WHERE surface_area > 10000000",
+ sql%((x.name, x.surface_area) for x in db.country if x.surface_area > 10000000)
)
def test_nested(self):
- compare_queries(
- """
- SELECT name FROM bbc
- WHERE population > (
- SELECT population FROM bbc
- WHERE name='Russia'
- )
- """,
- sql%(
- x.name for x in db.bbc if x.population > (
- y.population for y in db.bbc if y.name == 'Russia'
- )
- )
- )
+ # countries on the same continent as India or Iran
compare_queries(
"""
- SELECT name, region FROM bbc
- WHERE region IN (
- SELECT region FROM bbc
+ SELECT name, continent FROM country
+ WHERE continent IN (
+ SELECT continent FROM country
WHERE name IN ('India', 'Iran')
)
""",
sql%(
- (x.name, x.region) for x in db.bbc
- if x.region in (
- y.region for y in db.bbc
+ (x.name, x.continent) for x in db.country
+ if x.continent in (
+ y.continent for y in db.country
if y.name in ['India', 'Iran']
)
)
)
+
+ # countries in the same continent as Belize or Belgium
compare_queries(
"""
- SELECT w.name, w.region
- FROM bbc w
- WHERE w.region in (
- SELECT z.region
- FROM bbc z
+ SELECT w.name, w.continent
+ FROM country w
+ WHERE w.continent in (
+ SELECT z.continent
+ FROM country z
WHERE z.name = 'Belize' OR z.name = 'Belgium'
)
""",
sql%(
- (c.name, c.region) for c in db.bbc
- if c.region in (
- x.region for x in db.bbc
+ (c.name, c.continent) for c in db.country
+ if c.continent in (
+ x.continent for x in db.country
if (x.name == 'Belize') | (x.name == 'Belgium')
)
)
)
def test_operators(self):
+ # countries in europe with a GNP per capita larger than the UK
compare_queries(
"""
- SELECT name FROM bbc
- WHERE gdp/population > (
- SELECT gdp/population FROM bbc
+ SELECT name FROM country
+ WHERE gnp/population > (
+ SELECT gnp/population FROM country
WHERE name = 'United Kingdom'
)
- AND region = 'Europe'
+ AND continent = 'Europe'
""",
sql%(
- x.name for x in db.bbc
- if x.gdp / x.population > (
- y.gdp / y.population for y in db.bbc
+ x.name for x in db.country
+ if x.gnp / x.population > (
+ y.gnp / y.population for y in db.country
if y.name == 'United Kingdom'
)
- if (x.region == 'Europe')
+ if (x.continent == 'Europe')
)
)
def test_aggregate(self):
+ # the population of the world
compare_queries(
- "SELECT SUM(population) FROM bbc",
- sql%(func.sum(x.population) for x in db.bbc)
+ "SELECT SUM(population) FROM country",
+ sql%(func.sum(x.population) for x in db.country)
)
-
- def test_aliased(self):
+ # number of countries whose area is at least 1000000
compare_queries(
- "select count(*) from bbc where area >= 1000000",
- sql%(func.count(x.name) for x in db.bbc if x.area >= 1000000)
+ "select count(*) from country where surface_area >= 1000000",
+ sql%(func.count(x.name) for x in db.country if x.surface_area >= 1000000)
)
+
+ def test_aliased(self):
+
+ # continents whose total population is greater than 100000000
compare_queries(
"""
- SELECT DISTINCT(x.region)
- FROM bbc x
+ SELECT DISTINCT(x.continent)
+ FROM country x
WHERE 100000000 < (
SELECT SUM(w.population)
- from bbc w
- WHERE w.region = x.region
+ from country w
+ WHERE w.continent = x.continent
)
""",
sql%(
- func.distinct(x.region) for x in db.bbc
+ func.distinct(x.continent) for x in db.country
if (
- func.sum(w.population) for w in db.bbc
- if w.region == x.region
+ func.sum(w.population) for w in db.country
+ if w.continent == x.continent
) > 100000000
)
)
def test_query_macro(self):
query = sql%(
- func.distinct(x.region) for x in db.bbc
+ func.distinct(x.continent) for x in db.country
if (
- func.sum(w.population) for w in db.bbc
- if w.region == x.region
+ func.sum(w.population) for w in db.country
+ if w.continent == x.continent
) > 100000000
)
sql_results = engine.execute(query).fetchall()
query_macro_results = query%(
- func.distinct(x.region) for x in db.bbc
+ func.distinct(x.continent) for x in db.country
if (
- func.sum(w.population) for w in db.bbc
- if w.region == x.region
+ func.sum(w.population) for w in db.country
+ if w.continent == x.continent
) > 100000000
)
assert sql_results == query_macro_results
+
def test_join(self):
+ # number of cities in Asia
compare_queries(
"""
- SELECT name
- FROM movie m
- JOIN actor a
- JOIN casting c
- WHERE m.title = 'Casablanca'
- AND m.id = c.movieid
- AND a.id = c.actorid
+ SELECT COUNT(t.name)
+ FROM country c
+ JOIN city t
+ ON (t.country_code = c.code)
+ WHERE c.continent = 'Asia'
""",
sql%(
- a.name
- for m in db.movie
- for a in db.actor
- for c in db.casting
- if m.title == 'Casablanca'
- if m.id == c.movieid
- if a.id == c.actorid
+ func.count(t.name)
+ for c in db.country
+ for t in db.city
+ if t.country_code == c.code
+ if c.continent == 'Asia'
)
)
- (
- """
- SELECT mm.title
- FROM movie mm
- JOIN actor aa
- JOIN casting cc
- WHERE mm.id = cc.movieid
- AND aa.id = cc.actorid
- AND mm.title IN (
- SELECT m.title
- FROM movie m
- JOIN actor a
- JOIN casting c
- WHERE m.id = c.movieid
- AND a.id = c.actorid
- AND a.name = 'Julie Andrews'
- )
+ # name and population for each country and city where the city's
+ # population is more than half the country's
+ compare_queries(
"""
+ SELECT t.name, t.population, c.name, c.population
+ FROM country c
+ JOIN city t
+ ON t.country_code = c.code
+ WHERE t.population > c.population / 2
+ """,
+ sql%(
+ (t.name, t.population, c.name, c.population)
+ for c in db.country
+ for t in db.city
+ if t.country_code == c.code
+ if t.population > c.population / 2
+ ),
+ lambda x: sorted(map(str, x))
)
Oops, something went wrong.

0 comments on commit 5c39dd3

Please sign in to comment.