Skip to content

Commit

Permalink
Tests for the migration script
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzogil committed Mar 4, 2014
1 parent 5828320 commit a5a90e8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 25 deletions.
4 changes: 4 additions & 0 deletions yithlibraryserver/scripts/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def new_authorized_apps_collection(db):
'response_type': 'code',
}
auth.store_user_authorization(scopes, credentials)
safe_print('Storing authorized app "%s" for user %s' % (
app['client_id'],
get_user_display_name(user),
))

# remove the authorized_apps attribute from all users
db.users.update({}, {'$unset': {'authorized_apps': ''}}, multi=True)
Expand Down
124 changes: 99 additions & 25 deletions yithlibraryserver/scripts/tests/test_migrations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Yith Library Server is a password storage server.
# Copyright (C) 2013 Lorenzo Gil Sanchez <lorenzo.gil.sanchez@gmail.com>
# Copyright (C) 2013-2014 Lorenzo Gil Sanchez <lorenzo.gil.sanchez@gmail.com>
#
# This file is part of Yith Library Server.
#
Expand All @@ -19,65 +19,77 @@
import sys

from yithlibraryserver.compat import StringIO
from yithlibraryserver.oauth2.authorization import Authorizator
from yithlibraryserver.scripts.migrations import migrate
from yithlibraryserver.scripts.testing import ScriptTests


class MigrationsTests(ScriptTests):
class BaseMigrationsTests(ScriptTests):

clean_collections = ('users', 'passwords', 'applications')
def setUp(self):
super(BaseMigrationsTests, self).setUp()
self.old_args = sys.argv[:]
self.old_stdout = sys.stdout

def tearDown(self):
super(BaseMigrationsTests, self).tearDown()
# Restore sys.values
sys.argv = self.old_args
sys.stdout = self.old_stdout

def test_migrate_add_send_email_preference(self):
# Save sys values
old_args = sys.argv[:]
old_stdout = sys.stdout

# Replace sys argv and stdout
class MigrationsTests(BaseMigrationsTests):

def test_no_arguments(self):
sys.argv = []
sys.stdout = StringIO()

# Call migrate with no arguments
result = migrate()
self.assertEqual(result, 2)
stdout = sys.stdout.getvalue()
self.assertEqual(stdout, 'You must provide two arguments. The first one is the config file and the second one is the migration name.\n')

# Call migrate with a config file but no migration name
def test_no_migration_name(self):
sys.argv = ['notused', self.conf_file_path]
sys.stdout = StringIO()
result = migrate()
self.assertEqual(result, 2)
stdout = sys.stdout.getvalue()
self.assertEqual(stdout, 'You must provide two arguments. The first one is the config file and the second one is the migration name.\n')

# Call migrate with a config file and wrong migration name
def test_bad_migration_name(self):
sys.argv = ['notused', self.conf_file_path, 'bad_migration']
sys.stdout = StringIO()
result = migrate()
self.assertEqual(result, 3)
stdout = sys.stdout.getvalue()
self.assertEqual(stdout, 'The migration "bad_migration" does not exist.\n')

# Good call

class AddSendEmailPreferenceTests(BaseMigrationsTests):

clean_collections = ('users', )

def test_no_users(self):
sys.argv = ['notused', self.conf_file_path, 'add_send_email_preference']
sys.stdout = StringIO()
result = migrate()
self.assertEqual(result, None)
stdout = sys.stdout.getvalue()
self.assertEqual(stdout, '')

def test_some_users(self):
# Add some users
u1_id = self.db.users.insert({
'first_name': 'John',
'last_name': 'Doe',
'email': 'john@example.com',
})
'first_name': 'John',
'last_name': 'Doe',
'email': 'john@example.com',
})
self.db.users.insert({
'first_name': 'John2',
'last_name': 'Doe2',
'email': 'john2@example.com',
'send_passwords_periodically': False,
})
'first_name': 'John2',
'last_name': 'Doe2',
'email': 'john2@example.com',
'send_passwords_periodically': False,
})
sys.argv = ['notused', self.conf_file_path, 'add_send_email_preference']
sys.stdout = StringIO()
result = migrate()
Expand All @@ -90,6 +102,68 @@ def test_migrate_add_send_email_preference(self):
user1 = self.db.users.find_one({'_id': u1_id})
self.assertEqual(user1['send_passwords_periodically'], True)

# Restore sys.values
sys.argv = old_args
sys.stdout = old_stdout

class NewAuthorizedAppsCollectionTests(BaseMigrationsTests):

clean_collections = ('users', 'applications', 'authorized_apps')

def test_no_users(self):
sys.argv = ['notused', self.conf_file_path, 'new_authorized_apps_collection']
sys.stdout = StringIO()
result = migrate()
self.assertEqual(result, None)
stdout = sys.stdout.getvalue()
self.assertEqual(stdout, '')

def test_some_users(self):
authorizator = Authorizator(self.db)

app1_id = self.db.applications.insert({
'client_id': 'app1',
'callback_url': 'https://example.com/callback/1',
})
app2_id = self.db.applications.insert({
'client_id': 'app2',
'callback_url': 'https://example.com/callback/2',
})

u1_id = self.db.users.insert({
'first_name': 'John',
'last_name': 'Doe',
'email': 'john@example.com',
'authorized_apps': [app1_id, app2_id],
})
auths = authorizator.get_user_authorizations({'_id': u1_id})
self.assertEqual(auths.count(), 0)
u2_id = self.db.users.insert({
'first_name': 'John2',
'last_name': 'Doe2',
'email': 'john2@example.com',
'send_passwords_periodically': False,
'authorized_apps': [app1_id],
})
auths = authorizator.get_user_authorizations({'_id': u2_id})
self.assertEqual(auths.count(), 0)

sys.argv = ['notused', self.conf_file_path, 'new_authorized_apps_collection']
sys.stdout = StringIO()
result = migrate()
self.assertEqual(result, None)
stdout = sys.stdout.getvalue()
stdout = sys.stdout.getvalue()
expected_output = """Storing authorized app "app1" for user John Doe <john@example.com>
Storing authorized app "app2" for user John Doe <john@example.com>
Storing authorized app "app1" for user John2 Doe2 <john2@example.com>
"""
self.assertEqual(stdout, expected_output)

user1 = self.db.users.find_one({'_id': u1_id})
self.assertNotIn('authorized_apps', user1)
auths = authorizator.get_user_authorizations({'_id': u1_id})
self.assertEqual(auths.count(), 2)

user2 = self.db.users.find_one({'_id': u2_id})
self.assertNotIn('authorized_apps', user2)
auths = authorizator.get_user_authorizations({'_id': u2_id})
self.assertEqual(auths.count(), 1)

0 comments on commit a5a90e8

Please sign in to comment.