diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 651f863d89..4b5f762786 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -37,6 +37,7 @@ jobs: mongodb-version: 4.4 - name: Run tests run: | + pip install mypy python setup.py test mypytest: @@ -59,4 +60,4 @@ jobs: - name: Run mypy run: | mypy --install-types --non-interactive bson gridfs tools pymongo - mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test + mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index --exclude "test/mypy_fails/*.*" test diff --git a/test/mypy_fails/insert_many_dict.py b/test/mypy_fails/insert_many_dict.py new file mode 100644 index 0000000000..6e8acb67b4 --- /dev/null +++ b/test/mypy_fails/insert_many_dict.py @@ -0,0 +1,6 @@ +from pymongo import MongoClient + +client = MongoClient() +client.test.test.insert_many( + {"a": 1} +) # error: Dict entry 0 has incompatible type "str": "int"; expected "Mapping[str, Any]": "int" diff --git a/test/mypy_fails/insert_one_list.py b/test/mypy_fails/insert_one_list.py new file mode 100644 index 0000000000..7a26a3ff79 --- /dev/null +++ b/test/mypy_fails/insert_one_list.py @@ -0,0 +1,6 @@ +from pymongo import MongoClient + +client = MongoClient() +client.test.test.insert_one( + [{}] +) # error: Argument 1 to "insert_one" of "Collection" has incompatible type "List[Dict[, ]]"; expected "Mapping[str, Any]" diff --git a/test/test_bson.py b/test/test_bson.py index f8f587567d..46aa6e5d9a 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -1117,6 +1117,16 @@ def test_int64_pickling(self): ) self.round_trip_pickle(i64, pickled_with_3) + def test_bson_encode_decode(self) -> None: + doc = {"_id": ObjectId()} + encoded = bson.encode(doc) + decoded = bson.decode(encoded) + encoded = bson.encode(decoded) + decoded = bson.decode(encoded) + # Documents returned from decode are mutable. + decoded["new_field"] = 1 + self.assertTrue(decoded["_id"].generation_time) + if __name__ == "__main__": unittest.main() diff --git a/test/test_mypy.py b/test/test_mypy.py new file mode 100644 index 0000000000..0f1498c64b --- /dev/null +++ b/test/test_mypy.py @@ -0,0 +1,125 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that each file in mypy_fails/ actually fails mypy, and test some +sample client code that uses PyMongo typings.""" + +import os +import sys +import unittest +from typing import Any, Dict, Iterable, List + +try: + from mypy import api +except ImportError: + api = None + +from bson.son import SON +from pymongo.collection import Collection +from pymongo.errors import ServerSelectionTimeoutError +from pymongo.mongo_client import MongoClient +from pymongo.operations import InsertOne + +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails") + + +def get_tests() -> Iterable[str]: + for dirpath, _, filenames in os.walk(TEST_PATH): + for filename in filenames: + yield os.path.join(dirpath, filename) + + +class TestMypyFails(unittest.TestCase): + def ensure_mypy_fails(self, filename: str) -> None: + if api is None: + raise unittest.SkipTest("Mypy is not installed") + stdout, stderr, exit_status = api.run([filename]) + self.assertTrue(exit_status, msg=stdout) + + def test_mypy_failures(self) -> None: + for filename in get_tests(): + with self.subTest(filename=filename): + self.ensure_mypy_fails(filename) + + +class TestPymongo(unittest.TestCase): + client: MongoClient + coll: Collection + + @classmethod + def setUpClass(cls) -> None: + cls.client = MongoClient(serverSelectionTimeoutMS=250, directConnection=False) + cls.coll = cls.client.test.test + try: + cls.client.admin.command("ping") + except ServerSelectionTimeoutError as exc: + raise unittest.SkipTest(f"Could not connect to MongoDB: {exc}") + + @classmethod + def tearDownClass(cls) -> None: + cls.client.close() + + def test_insert_find(self) -> None: + doc = {"my": "doc"} + coll2 = self.client.test.test2 + result = self.coll.insert_one(doc) + self.assertEqual(result.inserted_id, doc["_id"]) + retreived = self.coll.find_one({"_id": doc["_id"]}) + if retreived: + # Documents returned from find are mutable. + retreived["new_field"] = 1 + result2 = coll2.insert_one(retreived) + self.assertEqual(result2.inserted_id, result.inserted_id) + + def test_cursor_iterable(self) -> None: + def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + return list(iterable) + + self.coll.insert_one({}) + cursor = self.coll.find() + docs = to_list(cursor) + self.assertTrue(docs) + + def test_bulk_write(self) -> None: + self.coll.insert_one({}) + requests = [InsertOne({})] + result = self.coll.bulk_write(requests) + self.assertTrue(result.acknowledged) + + def test_aggregate_pipeline(self) -> None: + coll3 = self.client.test.test3 + coll3.insert_many( + [ + {"x": 1, "tags": ["dog", "cat"]}, + {"x": 2, "tags": ["cat"]}, + {"x": 2, "tags": ["mouse", "cat", "dog"]}, + {"x": 3, "tags": []}, + ] + ) + + class mydict(Dict[str, Any]): + pass + + result = coll3.aggregate( + [ + mydict({"$unwind": "$tags"}), + {"$group": {"_id": "$tags", "count": {"$sum": 1}}}, + {"$sort": SON([("count", -1), ("_id", -1)])}, + ] + ) + self.assertTrue(len(list(result))) + + +if __name__ == "__main__": + unittest.main()