# Jupyter Notebook für Test Cases

In [None]:
%run ../../Setup.ipynb

import unittest
import subprocess
import pymongo
import json
import pickle
from bdcc.database.connection.database_connector import DatabaseConnector, get_matches, save_model, load_model
from sklearn.neural_network import MLPClassifier
from api_functions.api_functions import write_match, delete_match

In [None]:
class TestGetMatches(unittest.TestCase):
    """Tests for ``get_matches``: repeatable results, result sizes and return types."""

    def _assert_repeatable(self, *, expected_len=None, **kwargs):
        """Call ``get_matches(**kwargs)`` twice and assert both results are equal.

        :param expected_len: if given, additionally assert the result has this many entries.
        :param kwargs: forwarded unchanged to ``get_matches``.
        :return: the result of the second call.
        """
        first = get_matches(**kwargs)
        second = get_matches(**kwargs)
        # assertListEqual also fails if either result is not a list
        self.assertListEqual(second, first)
        if expected_len is not None:
            self.assertEqual(len(second), expected_len)
        return second

    def test_get_all_matches(self):
        self._assert_repeatable()

    def test_get_one_match_index_0(self):
        self._assert_repeatable(expected_len=1, num_matches=1)

    def test_get_one_match_index_10(self):
        self._assert_repeatable(expected_len=1, start_id=10, num_matches=1)

    def test_get_10_matches(self):
        self._assert_repeatable(expected_len=10, num_matches=10)

    def test_get_matches_return_type(self):
        result = get_matches(num_matches=1)
        self.assertIsInstance(result, list)
        # report an empty result as a test failure instead of an IndexError below
        self.assertGreaterEqual(len(result), 1, "get_matches returned no matches")
        self.assertIsInstance(result[0], dict)

In [None]:
class TestDatabaseConnector(unittest.TestCase):
    """Tests for ``DatabaseConnector``: configuration, connection handling, CRUD and model storage.

    Every connection opened through ``_connect`` is closed via ``addCleanup``.
    Every test that creates the match ``MATCH_ID`` also registers ``delete_match``
    as a cleanup. A failing assertion therefore never leaves an open client or a
    stray document behind for later tests.
    """

    # id of the throw-away match document used by the CRUD tests
    MATCH_ID = 123

    def setUp(self):
        self.db_con = DatabaseConnector()

    def _connect(self, db_con=None):
        """Connect ``db_con`` (default: ``self.db_con``) and schedule its disconnect.

        :return: whatever ``connect()`` returns.
        """
        if db_con is None:
            db_con = self.db_con
        connected = db_con.connect()
        self.addCleanup(db_con.disconnect)
        return connected

    def _delete_test_match_afterwards(self):
        """Schedule removal of the ``MATCH_ID`` document after the test, pass or fail."""
        # delete_match on an id that does not exist only returns 204, so this is
        # harmless when the test already removed the document itself.
        self.addCleanup(delete_match, self.MATCH_ID)

    @staticmethod
    def _first_stored_model(documents, model_name):
        """Unpickle the first document in ``documents`` whose ``name`` is ``model_name``.

        :raises IndexError: if no document has that name.
        """
        pickled = [doc['model'] for doc in documents if doc['name'] == model_name]
        # NOTE: pickle.loads is only acceptable because the Models collection is trusted test data
        return pickle.loads(pickled[0])

    def test_DataBaseConnector_init(self):
        db_con_temp = DatabaseConnector()
        # url: either host name is accepted as default
        self.assertIn(db_con_temp.url, ['localhost', 'mongodb-ki'])
        db_con_temp.url = '127.0.0.1'
        self.assertEqual(db_con_temp.url, '127.0.0.1')
        # (attribute, expected default, value to assign)
        attributes = (
            ('port', 27018, 27017),
            ('database', 'Data-KI', 'Data'),
            ('collection', 'MatchesDto', 'Matches'),
        )
        for attr, default, new_value in attributes:
            with self.subTest(attribute=attr):
                self.assertEqual(getattr(db_con_temp, attr), default)
                setattr(db_con_temp, attr, new_value)
                self.assertEqual(getattr(db_con_temp, attr), new_value)

    def test_DataBaseConnector_connect(self):
        test_call = self._connect()
        self.assertIsInstance(test_call, type(self.db_con))
        self.assertIsInstance(test_call.client, pymongo.MongoClient)
        self.assertIsInstance(test_call.db, pymongo.database.Database)
        self.assertIsInstance(test_call.connection, type(self.db_con))
        self.assertIsInstance(test_call.active_collection, pymongo.collection.Collection)

    def test_DataBaseConnector_disconnect(self):
        # disconnecting a connector that was never connected must raise
        self.assertRaises(ConnectionError, self.db_con.disconnect)
        test_call = self.db_con.connect()
        try:
            self.assertIsInstance(test_call, type(self.db_con))
        finally:
            # a connected connector must disconnect without raising; this also
            # releases the client opened above
            self.db_con.disconnect()

    def test_DataBaseConnector_getCollectionNames(self):
        self._connect()
        test_call = self.db_con.getCollectionNames()
        self.assertIsInstance(test_call, list)

    def test_DataBaseConnector_getCollection(self):
        self._connect()
        test_call = self.db_con.getCollection('MatchesDto')
        self.assertIsInstance(test_call, pymongo.collection.Collection)

    def test_DataBaseConnector_create(self):
        self._connect()
        self._delete_test_match_afterwards()
        match = {"match_id": self.MATCH_ID}
        self.assertTrue(self.db_con.create(match).acknowledged)
        loaded_match = self.db_con.get(id=self.MATCH_ID)[0]
        self.assertEqual(match['match_id'], loaded_match['match_id'])

    def test_DataBaseConnector_get(self):
        self._delete_test_match_afterwards()
        write_match({"match_id": self.MATCH_ID})
        self._connect()
        test_call_all = self.db_con.get()
        test_call_one = self.db_con.get(id=self.MATCH_ID)
        self.assertIsInstance(test_call_all, pymongo.cursor.Cursor)
        self.assertGreaterEqual(len(list(test_call_all)), 1)
        self.assertIsInstance(test_call_one, pymongo.cursor.Cursor)
        self.assertEqual(len(list(test_call_one)), 1)

    def test_DataBaseConnector_update(self):
        self._connect()
        self._delete_test_match_afterwards()
        match = {"match_id": self.MATCH_ID, "radiant_win": True}
        new_match = {"match_id": self.MATCH_ID, "radiant_win": False}
        write_match(match)
        self.db_con.update(self.MATCH_ID, new_match)
        # keyword form, matching the other get() calls in this class
        updated_match = self.db_con.get(id=self.MATCH_ID)[0]
        self.assertEqual(updated_match['radiant_win'], new_match['radiant_win'])

    def test_DataBaseConnector_remove(self):
        self._connect()
        # safety net in case remove() does not delete the document
        self._delete_test_match_afterwards()
        write_match({"match_id": self.MATCH_ID, "radiant_win": True})
        self.db_con.remove(self.MATCH_ID)
        # indexing an empty cursor raises IndexError
        with self.assertRaises(IndexError):
            self.db_con.get(id=self.MATCH_ID)[0]

    def test_DataBaseConnector_save_model(self):
        db_con = DatabaseConnector(collection="Models")
        self._connect(db_con)
        model_name = "kda"
        model = MLPClassifier()
        pickled_model = pickle.dumps(model)
        self.assertTrue(db_con.save_model(pickled_model, model_name).acknowledged)
        # read the model back from the database
        loaded_model = self._first_stored_model(db_con.get(), model_name)
        self.assertEqual(type(model), type(loaded_model))
        self.assertEqual(model.get_params(), loaded_model.get_params())

    def test_DataBaseConnector_update_model(self):
        db_con = DatabaseConnector(collection="Models")
        self._connect(db_con)
        model_name = "kda"
        db_con.save_model(pickle.dumps(MLPClassifier()), model_name)
        # replace the stored model with a differently configured one
        updated_model = MLPClassifier(activation='tanh')
        db_con.update_model(pickle.dumps(updated_model), model_name)
        # read the model back from the database
        loaded_model = self._first_stored_model(db_con.get(), model_name)
        self.assertEqual(type(updated_model), type(loaded_model))
        self.assertEqual(updated_model.get_params(), loaded_model.get_params())
        

In [None]:
class TestSaveModel(unittest.TestCase):
    """Tests for the module-level ``save_model`` helper."""

    def setUp(self):
        model_name_kda = "kda"
        model_name_no_kda = "no_kda"
        self.model = MLPClassifier()
        save_model(self.model, model_name_kda)
        save_model(self.model, model_name_no_kda)
        # read the stored models back directly from the "Models" collection
        db_con = DatabaseConnector(collection="Models")
        db_con.connect()
        try:
            models = list(db_con.get())
        finally:
            # release the connection even if reading fails
            db_con.disconnect()
        self.loaded_model_kda = self._first_model_named(models, model_name_kda)
        self.loaded_model_no_kda = self._first_model_named(models, model_name_no_kda)

    @staticmethod
    def _first_model_named(documents, model_name):
        """Unpickle the first document in ``documents`` whose ``name`` is ``model_name``.

        :raises IndexError: if no document has that name.
        """
        pickled = [doc['model'] for doc in documents if doc['name'] == model_name]
        # NOTE: pickle.loads is only acceptable because the Models collection is trusted test data
        return pickle.loads(pickled[0])

    def test_save_model_kda(self):
        self.assertEqual(type(self.model), type(self.loaded_model_kda))
        self.assertEqual(self.model.get_params(), self.loaded_model_kda.get_params())

    def test_save_model_no_kda(self):
        self.assertEqual(type(self.model), type(self.loaded_model_no_kda))
        self.assertEqual(self.model.get_params(), self.loaded_model_no_kda.get_params())

    def test_save_model_invalid(self):
        # only "kda" / "no_kda" are expected to be accepted
        with self.assertRaises(NameError):
            save_model(self.model, "invalid_name")

In [None]:
class TestLoadModel(unittest.TestCase):
    """Tests for the module-level ``load_model`` helper."""

    def setUp(self):
        self.model_name_kda = "kda"
        self.model_name_no_kda = "no_kda"
        self.model = MLPClassifier()
        save_model(self.model, self.model_name_kda)
        save_model(self.model, self.model_name_no_kda)

    def test_load_model_kda(self):
        loaded_model_kda = load_model(self.model_name_kda)
        self.assertEqual(type(self.model), type(loaded_model_kda))
        self.assertEqual(self.model.get_params(), loaded_model_kda.get_params())

    def test_load_model_no_kda(self):
        loaded_model_no_kda = load_model(self.model_name_no_kda)
        self.assertEqual(type(self.model), type(loaded_model_no_kda))
        self.assertEqual(self.model.get_params(), loaded_model_no_kda.get_params())

    def test_load_model_fail(self):
        # Must call load_model itself: calling an undefined helper would also raise
        # NameError and make this test pass without exercising load_model at all.
        with self.assertRaises(NameError):
            load_model("invalid_name")

In [None]:
class TestWriteMatch(unittest.TestCase):
    """Tests for ``write_match``: status code on first insert and on a duplicate insert."""

    def setUp(self):
        self.match_id = 123
        self.match = {"match_id": self.match_id}
        self.http_status_created = 201
        self.http_status_conflicted = 409
        # Remove the test match even when an assertion fails. Otherwise a leftover
        # match makes the next "created" check see a conflict instead.
        self.addCleanup(delete_match, self.match_id)

    def test_write_match_success(self):
        self.assertEqual(write_match(self.match), self.http_status_created)

    def test_write_match_duplicate(self):
        self.assertEqual(write_match(self.match), self.http_status_created)
        # writing the same match again must be rejected as a conflict
        self.assertEqual(write_match(self.match), self.http_status_conflicted)

In [None]:
class TestDeleteMatch(unittest.TestCase):
    """Tests for ``delete_match`` status codes for an existing and a missing match."""

    def setUp(self):
        self.match_id = 123
        self.match = {"match_id": self.match_id}
        self.http_status_no_content = 204
        self.http_status_reset_content = 205
        # The missing-match test never deletes the match written below, so always
        # remove it afterwards. Deleting an already-deleted id just returns 204.
        self.addCleanup(delete_match, self.match_id)
        write_match(self.match)

    def test_delete_match_success(self):
        self.assertEqual(delete_match(self.match_id), self.http_status_reset_content)

    def test_delete_match_no_existing_match(self):
        invalid_id = 0
        self.assertEqual(delete_match(invalid_id), self.http_status_no_content)
        

In [None]:
def suite():
    """Build a test suite containing every test case class of this notebook.

    :return: a ``unittest.TestSuite`` with all ``test*`` methods of the test classes.
    """
    # A fresh loader uses the default "test" method prefix, exactly like the former
    # makeSuite(cls, 'test'). makeSuite is deprecated since Python 3.11 and removed in 3.13.
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()
    for case_class in (TestGetMatches, TestDatabaseConnector, TestSaveModel,
                       TestLoadModel, TestWriteMatch, TestDeleteMatch):
        test_suite.addTest(loader.loadTestsFromTestCase(case_class))
    return test_suite

In [None]:
def execute_unittests():
    """Run the tests returned by ``suite()`` with verbose (level 2) output.

    ``argv=['']`` stops unittest from parsing the notebook process's own command-line
    arguments. ``exit=False`` stops it from calling ``sys.exit`` when the run finishes.
    ``defaultTest='suite'`` makes it load its tests by calling ``suite``.
    """
    runner_options = {
        'argv': [''],
        'verbosity': 2,
        'exit': False,
        'defaultTest': 'suite',
    }
    unittest.main(**runner_options)