88python binary_search_tree_recursive.py
99"""
1010import unittest
11+ from typing import Iterator , Optional
1112
1213
1314class Node :
14- def __init__ (self , label : int , parent ) :
15+ def __init__ (self , label : int , parent : Optional [ "Node" ]) -> None :
1516 self .label = label
1617 self .parent = parent
17- self .left = None
18- self .right = None
18+ self .left : Optional [ Node ] = None
19+ self .right : Optional [ Node ] = None
1920
2021
2122class BinarySearchTree :
22- def __init__ (self ):
23- self .root = None
23+ def __init__ (self ) -> None :
24+ self .root : Optional [ Node ] = None
2425
25- def empty (self ):
26+ def empty (self ) -> None :
2627 """
2728 Empties the tree
2829
@@ -46,7 +47,7 @@ def is_empty(self) -> bool:
4647 """
4748 return self .root is None
4849
49- def put (self , label : int ):
50+ def put (self , label : int ) -> None :
5051 """
5152 Put a new node in the tree
5253
@@ -65,7 +66,9 @@ def put(self, label: int):
6566 """
6667 self .root = self ._put (self .root , label )
6768
68- def _put (self , node : Node , label : int , parent : Node = None ) -> Node :
69+ def _put (
70+ self , node : Optional [Node ], label : int , parent : Optional [Node ] = None
71+ ) -> Node :
6972 if node is None :
7073 node = Node (label , parent )
7174 else :
@@ -95,7 +98,7 @@ def search(self, label: int) -> Node:
9598 """
9699 return self ._search (self .root , label )
97100
98- def _search (self , node : Node , label : int ) -> Node :
101+ def _search (self , node : Optional [ Node ] , label : int ) -> Node :
99102 if node is None :
100103 raise Exception (f"Node with label { label } does not exist" )
101104 else :
@@ -106,7 +109,7 @@ def _search(self, node: Node, label: int) -> Node:
106109
107110 return node
108111
109- def remove (self , label : int ):
112+ def remove (self , label : int ) -> None :
110113 """
111114 Removes a node in the tree
112115
@@ -122,22 +125,22 @@ def remove(self, label: int):
122125 Exception: Node with label 3 does not exist
123126 """
124127 node = self .search (label )
125- if not node .right and not node .left :
126- self ._reassign_nodes (node , None )
127- elif not node .right and node .left :
128- self ._reassign_nodes (node , node .left )
129- elif node .right and not node .left :
130- self ._reassign_nodes (node , node .right )
131- else :
128+ if node .right and node .left :
132129 lowest_node = self ._get_lowest_node (node .right )
133130 lowest_node .left = node .left
134131 lowest_node .right = node .right
135132 node .left .parent = lowest_node
136133 if node .right :
137134 node .right .parent = lowest_node
138135 self ._reassign_nodes (node , lowest_node )
136+ elif not node .right and node .left :
137+ self ._reassign_nodes (node , node .left )
138+ elif node .right and not node .left :
139+ self ._reassign_nodes (node , node .right )
140+ else :
141+ self ._reassign_nodes (node , None )
139142
140- def _reassign_nodes (self , node : Node , new_children : Node ) :
143+ def _reassign_nodes (self , node : Node , new_children : Optional [ Node ]) -> None :
141144 if new_children :
142145 new_children .parent = node .parent
143146
@@ -192,7 +195,7 @@ def get_max_label(self) -> int:
192195 >>> t.get_max_label()
193196 10
194197 """
195- if self .is_empty () :
198+ if self .root is None :
196199 raise Exception ("Binary search tree is empty" )
197200
198201 node = self .root
@@ -216,7 +219,7 @@ def get_min_label(self) -> int:
216219 >>> t.get_min_label()
217220 8
218221 """
219- if self .is_empty () :
222+ if self .root is None :
220223 raise Exception ("Binary search tree is empty" )
221224
222225 node = self .root
@@ -225,7 +228,7 @@ def get_min_label(self) -> int:
225228
226229 return node .label
227230
228- def inorder_traversal (self ) -> list :
231+ def inorder_traversal (self ) -> Iterator [ Node ] :
229232 """
230233 Return the inorder traversal of the tree
231234
@@ -241,13 +244,13 @@ def inorder_traversal(self) -> list:
241244 """
242245 return self ._inorder_traversal (self .root )
243246
244- def _inorder_traversal (self , node : Node ) -> list :
247+ def _inorder_traversal (self , node : Optional [ Node ] ) -> Iterator [ Node ] :
245248 if node is not None :
246249 yield from self ._inorder_traversal (node .left )
247250 yield node
248251 yield from self ._inorder_traversal (node .right )
249252
250- def preorder_traversal (self ) -> list :
253+ def preorder_traversal (self ) -> Iterator [ Node ] :
251254 """
252255 Return the preorder traversal of the tree
253256
@@ -263,7 +266,7 @@ def preorder_traversal(self) -> list:
263266 """
264267 return self ._preorder_traversal (self .root )
265268
266- def _preorder_traversal (self , node : Node ) -> list :
269+ def _preorder_traversal (self , node : Optional [ Node ] ) -> Iterator [ Node ] :
267270 if node is not None :
268271 yield node
269272 yield from self ._preorder_traversal (node .left )
@@ -272,7 +275,7 @@ def _preorder_traversal(self, node: Node) -> list:
272275
273276class BinarySearchTreeTest (unittest .TestCase ):
274277 @staticmethod
275- def _get_binary_search_tree ():
278+ def _get_binary_search_tree () -> BinarySearchTree :
276279 r"""
277280 8
278281 / \
@@ -298,14 +301,15 @@ def _get_binary_search_tree():
298301
299302 return t
300303
301- def test_put (self ):
304+ def test_put (self ) -> None :
302305 t = BinarySearchTree ()
303306 assert t .is_empty ()
304307
305308 t .put (8 )
306309 r"""
307310 8
308311 """
312+ assert t .root is not None
309313 assert t .root .parent is None
310314 assert t .root .label == 8
311315
@@ -315,6 +319,7 @@ def test_put(self):
315319 \
316320 10
317321 """
322+ assert t .root .right is not None
318323 assert t .root .right .parent == t .root
319324 assert t .root .right .label == 10
320325
@@ -324,6 +329,7 @@ def test_put(self):
324329 / \
325330 3 10
326331 """
332+ assert t .root .left is not None
327333 assert t .root .left .parent == t .root
328334 assert t .root .left .label == 3
329335
@@ -335,6 +341,7 @@ def test_put(self):
335341 \
336342 6
337343 """
344+ assert t .root .left .right is not None
338345 assert t .root .left .right .parent == t .root .left
339346 assert t .root .left .right .label == 6
340347
@@ -346,13 +353,14 @@ def test_put(self):
346353 / \
347354 1 6
348355 """
356+ assert t .root .left .left is not None
349357 assert t .root .left .left .parent == t .root .left
350358 assert t .root .left .left .label == 1
351359
352360 with self .assertRaises (Exception ):
353361 t .put (1 )
354362
355- def test_search (self ):
363+ def test_search (self ) -> None :
356364 t = self ._get_binary_search_tree ()
357365
358366 node = t .search (6 )
@@ -364,7 +372,7 @@ def test_search(self):
364372 with self .assertRaises (Exception ):
365373 t .search (2 )
366374
367- def test_remove (self ):
375+ def test_remove (self ) -> None :
368376 t = self ._get_binary_search_tree ()
369377
370378 t .remove (13 )
@@ -379,6 +387,9 @@ def test_remove(self):
379387 \
380388 5
381389 """
390+ assert t .root is not None
391+ assert t .root .right is not None
392+ assert t .root .right .right is not None
382393 assert t .root .right .right .right is None
383394 assert t .root .right .right .left is None
384395
@@ -394,6 +405,9 @@ def test_remove(self):
394405 \
395406 5
396407 """
408+ assert t .root .left is not None
409+ assert t .root .left .right is not None
410+ assert t .root .left .right .left is not None
397411 assert t .root .left .right .right is None
398412 assert t .root .left .right .left .label == 4
399413
@@ -407,6 +421,8 @@ def test_remove(self):
407421 \
408422 5
409423 """
424+ assert t .root .left .left is not None
425+ assert t .root .left .right .right is not None
410426 assert t .root .left .left .label == 1
411427 assert t .root .left .right .label == 4
412428 assert t .root .left .right .right .label == 5
@@ -422,6 +438,7 @@ def test_remove(self):
422438 / \ \
423439 1 5 14
424440 """
441+ assert t .root is not None
425442 assert t .root .left .label == 4
426443 assert t .root .left .right .label == 5
427444 assert t .root .left .left .label == 1
@@ -437,13 +454,15 @@ def test_remove(self):
437454 / \
438455 1 14
439456 """
457+ assert t .root .left is not None
458+ assert t .root .left .left is not None
440459 assert t .root .left .label == 5
441460 assert t .root .left .right is None
442461 assert t .root .left .left .label == 1
443462 assert t .root .left .parent == t .root
444463 assert t .root .left .left .parent == t .root .left
445464
446- def test_remove_2 (self ):
465+ def test_remove_2 (self ) -> None :
447466 t = self ._get_binary_search_tree ()
448467
449468 t .remove (3 )
@@ -456,6 +475,12 @@ def test_remove_2(self):
456475 / \ /
457476 5 7 13
458477 """
478+ assert t .root is not None
479+ assert t .root .left is not None
480+ assert t .root .left .left is not None
481+ assert t .root .left .right is not None
482+ assert t .root .left .right .left is not None
483+ assert t .root .left .right .right is not None
459484 assert t .root .left .label == 4
460485 assert t .root .left .right .label == 6
461486 assert t .root .left .left .label == 1
@@ -466,25 +491,25 @@ def test_remove_2(self):
466491 assert t .root .left .left .parent == t .root .left
467492 assert t .root .left .right .left .parent == t .root .left .right
468493
469- def test_empty (self ):
494+ def test_empty (self ) -> None :
470495 t = self ._get_binary_search_tree ()
471496 t .empty ()
472497 assert t .root is None
473498
474- def test_is_empty (self ):
499+ def test_is_empty (self ) -> None :
475500 t = self ._get_binary_search_tree ()
476501 assert not t .is_empty ()
477502
478503 t .empty ()
479504 assert t .is_empty ()
480505
481- def test_exists (self ):
506+ def test_exists (self ) -> None :
482507 t = self ._get_binary_search_tree ()
483508
484509 assert t .exists (6 )
485510 assert not t .exists (- 1 )
486511
487- def test_get_max_label (self ):
512+ def test_get_max_label (self ) -> None :
488513 t = self ._get_binary_search_tree ()
489514
490515 assert t .get_max_label () == 14
@@ -493,7 +518,7 @@ def test_get_max_label(self):
493518 with self .assertRaises (Exception ):
494519 t .get_max_label ()
495520
496- def test_get_min_label (self ):
521+ def test_get_min_label (self ) -> None :
497522 t = self ._get_binary_search_tree ()
498523
499524 assert t .get_min_label () == 1
@@ -502,20 +527,20 @@ def test_get_min_label(self):
502527 with self .assertRaises (Exception ):
503528 t .get_min_label ()
504529
505- def test_inorder_traversal (self ):
530+ def test_inorder_traversal (self ) -> None :
506531 t = self ._get_binary_search_tree ()
507532
508533 inorder_traversal_nodes = [i .label for i in t .inorder_traversal ()]
509534 assert inorder_traversal_nodes == [1 , 3 , 4 , 5 , 6 , 7 , 8 , 10 , 13 , 14 ]
510535
511- def test_preorder_traversal (self ):
536+ def test_preorder_traversal (self ) -> None :
512537 t = self ._get_binary_search_tree ()
513538
514539 preorder_traversal_nodes = [i .label for i in t .preorder_traversal ()]
515540 assert preorder_traversal_nodes == [8 , 3 , 1 , 6 , 4 , 5 , 7 , 10 , 14 , 13 ]
516541
517542
518- def binary_search_tree_example ():
543+ def binary_search_tree_example () -> None :
519544 r"""
520545 Example
521546 8
0 commit comments