Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ install:
before_script:
- mysql -e 'CREATE DATABASE IF NOT EXISTS test;'
- psql -c 'create database test;' -U postgres
- touch test.db
- touch test.db test1.db
script: python tests/sql.py
after_script: rm -f test.db
jobs:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
package_dir={"": "src"},
packages=["cs50"],
url="https://github.com/cs50/python-cs50",
version="2.4.0"
version="2.4.1"
)
35 changes: 32 additions & 3 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import sqlalchemy
import sqlite3
import sqlparse
import sys
import termcolor
Expand Down Expand Up @@ -32,12 +33,25 @@ def __init__(self, url, **kwargs):
if not os.path.isfile(matches.group(1)):
raise RuntimeError("not a file: {}".format(matches.group(1)))

# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)
# Remember foreign_keys and remove it from kwargs
foreign_keys = kwargs.pop("foreign_keys", False)

# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)

# Enable foreign key constraints
if foreign_keys:
sqlalchemy.event.listen(self.engine, "connect", _connect)
else:

# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)


# Log statements to standard error
logging.basicConfig(level=logging.DEBUG)
self.logger = logging.getLogger("cs50")
disabled = self.logger.disabled

# Test database
try:
Expand All @@ -48,7 +62,7 @@ def __init__(self, url, **kwargs):
e.__cause__ = None
raise e
else:
self.logger.disabled = False
self.logger.disabled = disabled

def _parse(self, e):
"""Parses an exception, returns its message."""
Expand Down Expand Up @@ -133,6 +147,8 @@ def process(value):
return process(value)

# Allow only one statement at a time
# SQLite does not support executing many statements
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
if len(sqlparse.split(text)) > 1:
raise RuntimeError("too many statements at once")

Expand Down Expand Up @@ -211,3 +227,16 @@ def process(value):
else:
self.logger.debug(termcolor.colored(log, "green"))
return ret


# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
def _connect(dbapi_connection, connection_record):
"""Enables foreign key support."""

# Ensure backend is sqlite
if type(dbapi_connection) is sqlite3.Connection:
cursor = dbapi_connection.cursor()

# Respect foreign key constraints by default
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
14 changes: 11 additions & 3 deletions tests/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,26 @@ class SQLiteTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("sqlite:///test.db")
self.db1 = SQL("sqlite:///test1.db", foreign_keys=True)

def setUp(self):
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")

def multi_inserts_enabled(self):
return False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used anywhere.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, we disabled :)

def test_foreign_key_support(self):
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1)

self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None)

if __name__ == "__main__":
suite = unittest.TestSuite([
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
])
logging.getLogger("cs50.sql").disabled = True

logging.getLogger("cs50").disabled = True
sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())