@@ -51,6 +51,8 @@ def rotate_left(self) -> RedBlackTree:
5151 """
5252 parent = self .parent
5353 right = self .right
54+ if right is None :
55+ return self
5456 self .right = right .left
5557 if self .right :
5658 self .right .parent = self
@@ -69,6 +71,8 @@ def rotate_right(self) -> RedBlackTree:
6971 returns the new root to this subtree.
7072 Performing one rotation can be done in O(1).
7173 """
74+ if self .left is None :
75+ return self
7276 parent = self .parent
7377 left = self .left
7478 self .left = left .right
@@ -123,23 +127,30 @@ def _insert_repair(self) -> None:
123127 if color (uncle ) == 0 :
124128 if self .is_left () and self .parent .is_right ():
125129 self .parent .rotate_right ()
126- self .right ._insert_repair ()
130+ if self .right :
131+ self .right ._insert_repair ()
127132 elif self .is_right () and self .parent .is_left ():
128133 self .parent .rotate_left ()
129- self .left ._insert_repair ()
134+ if self .left :
135+ self .left ._insert_repair ()
130136 elif self .is_left ():
131- self .grandparent .rotate_right ()
132- self .parent .color = 0
133- self .parent .right .color = 1
137+ if self .grandparent :
138+ self .grandparent .rotate_right ()
139+ self .parent .color = 0
140+ if self .parent .right :
141+ self .parent .right .color = 1
134142 else :
135- self .grandparent .rotate_left ()
136- self .parent .color = 0
137- self .parent .left .color = 1
143+ if self .grandparent :
144+ self .grandparent .rotate_left ()
145+ self .parent .color = 0
146+ if self .parent .left :
147+ self .parent .left .color = 1
138148 else :
139149 self .parent .color = 0
140- uncle .color = 0
141- self .grandparent .color = 1
142- self .grandparent ._insert_repair ()
150+ if uncle and self .grandparent :
151+ uncle .color = 0
152+ self .grandparent .color = 1
153+ self .grandparent ._insert_repair ()
143154
144155 def remove (self , label : int ) -> RedBlackTree :
145156 """Remove label from this tree."""
@@ -149,8 +160,9 @@ def remove(self, label: int) -> RedBlackTree:
149160 # so we replace this node with the greatest one less than
150161 # it and remove that.
151162 value = self .left .get_max ()
152- self .label = value
153- self .left .remove (value )
163+ if value is not None :
164+ self .label = value
165+ self .left .remove (value )
154166 else :
155167 # This node has at most one non-None child, so we don't
156168 # need to replace
@@ -160,10 +172,11 @@ def remove(self, label: int) -> RedBlackTree:
160172 # The only way this happens to a node with one child
161173 # is if both children are None leaves.
162174 # We can just remove this node and call it a day.
163- if self .is_left ():
164- self .parent .left = None
165- else :
166- self .parent .right = None
175+ if self .parent :
176+ if self .is_left ():
177+ self .parent .left = None
178+ else :
179+ self .parent .right = None
167180 else :
168181 # The node is black
169182 if child is None :
@@ -188,7 +201,7 @@ def remove(self, label: int) -> RedBlackTree:
188201 self .left .parent = self
189202 if self .right :
190203 self .right .parent = self
191- elif self .label > label :
204+ elif self .label is not None and self . label > label :
192205 if self .left :
193206 self .left .remove (label )
194207 else :
@@ -198,6 +211,13 @@ def remove(self, label: int) -> RedBlackTree:
198211
199212 def _remove_repair (self ) -> None :
200213 """Repair the coloring of the tree that may have been messed up."""
214+ if (
215+ self .parent is None
216+ or self .sibling is None
217+ or self .parent .sibling is None
218+ or self .grandparent is None
219+ ):
220+ return
201221 if color (self .sibling ) == 1 :
202222 self .sibling .color = 0
203223 self .parent .color = 1
@@ -231,7 +251,8 @@ def _remove_repair(self) -> None:
231251 ):
232252 self .sibling .rotate_right ()
233253 self .sibling .color = 0
234- self .sibling .right .color = 1
254+ if self .sibling .right :
255+ self .sibling .right .color = 1
235256 if (
236257 self .is_right ()
237258 and color (self .sibling ) == 0
@@ -240,7 +261,8 @@ def _remove_repair(self) -> None:
240261 ):
241262 self .sibling .rotate_left ()
242263 self .sibling .color = 0
243- self .sibling .left .color = 1
264+ if self .sibling .left :
265+ self .sibling .left .color = 1
244266 if (
245267 self .is_left ()
246268 and color (self .sibling ) == 0
@@ -275,29 +297,25 @@ def check_color_properties(self) -> bool:
275297 """
276298 # I assume property 1 to hold because there is nothing that can
277299 # make the color be anything other than 0 or 1.
278-
279300 # Property 2
280301 if self .color :
281302 # The root was red
282303 print ("Property 2" )
283304 return False
284-
285305 # Property 3 does not need to be checked, because None is assumed
286306 # to be black and is all the leaves.
287-
288307 # Property 4
289308 if not self .check_coloring ():
290309 print ("Property 4" )
291310 return False
292-
293311 # Property 5
294312 if self .black_height () is None :
295313 print ("Property 5" )
296314 return False
297315 # All properties were met
298316 return True
299317
300- def check_coloring (self ) -> None :
318+ def check_coloring (self ) -> bool :
301319 """A helper function to recursively check Property 4 of a
302320 Red-Black Tree. See check_color_properties for more info.
303321 """
@@ -310,12 +328,12 @@ def check_coloring(self) -> None:
310328 return False
311329 return True
312330
313- def black_height (self ) -> int :
331+ def black_height (self ) -> int | None :
314332 """Returns the number of black nodes from this node to the
315333 leaves of the tree, or None if there isn't one such value (the
316334 tree is color incorrectly).
317335 """
318- if self is None :
336+ if self is None or self . left is None or self . right is None :
319337 # If we're already at a leaf, there is no path
320338 return 1
321339 left = RedBlackTree .black_height (self .left )
@@ -332,21 +350,21 @@ def black_height(self) -> int:
332350
333351 # Here are functions which are general to all binary search trees
334352
335- def __contains__ (self , label ) -> bool :
353+ def __contains__ (self , label : int ) -> bool :
336354 """Search through the tree for label, returning True iff it is
337355 found somewhere in the tree.
338356 Guaranteed to run in O(log(n)) time.
339357 """
340358 return self .search (label ) is not None
341359
342- def search (self , label : int ) -> RedBlackTree :
360+ def search (self , label : int ) -> RedBlackTree | None :
343361 """Search through the tree for label, returning its node if
344362 it's found, and None otherwise.
345363 This method is guaranteed to run in O(log(n)) time.
346364 """
347365 if self .label == label :
348366 return self
349- elif label > self .label :
367+ elif self . label is not None and label > self .label :
350368 if self .right is None :
351369 return None
352370 else :
@@ -357,12 +375,12 @@ def search(self, label: int) -> RedBlackTree:
357375 else :
358376 return self .left .search (label )
359377
360- def floor (self , label : int ) -> int :
378+ def floor (self , label : int ) -> int | None :
361379 """Returns the largest element in this tree which is at most label.
362380 This method is guaranteed to run in O(log(n)) time."""
363381 if self .label == label :
364382 return self .label
365- elif self .label > label :
383+ elif self .label is not None and self . label > label :
366384 if self .left :
367385 return self .left .floor (label )
368386 else :
@@ -374,13 +392,13 @@ def floor(self, label: int) -> int:
374392 return attempt
375393 return self .label
376394
377- def ceil (self , label : int ) -> int :
395+ def ceil (self , label : int ) -> int | None :
378396 """Returns the smallest element in this tree which is at least label.
379397 This method is guaranteed to run in O(log(n)) time.
380398 """
381399 if self .label == label :
382400 return self .label
383- elif self .label < label :
401+ elif self .label is not None and self . label < label :
384402 if self .right :
385403 return self .right .ceil (label )
386404 else :
@@ -392,7 +410,7 @@ def ceil(self, label: int) -> int:
392410 return attempt
393411 return self .label
394412
395- def get_max (self ) -> int :
413+ def get_max (self ) -> int | None :
396414 """Returns the largest element in this tree.
397415 This method is guaranteed to run in O(log(n)) time.
398416 """
@@ -402,7 +420,7 @@ def get_max(self) -> int:
402420 else :
403421 return self .label
404422
405- def get_min (self ) -> int :
423+ def get_min (self ) -> int | None :
406424 """Returns the smallest element in this tree.
407425 This method is guaranteed to run in O(log(n)) time.
408426 """
@@ -413,15 +431,15 @@ def get_min(self) -> int:
413431 return self .label
414432
415433 @property
416- def grandparent (self ) -> RedBlackTree :
434+ def grandparent (self ) -> RedBlackTree | None :
417435 """Get the current node's grandparent, or None if it doesn't exist."""
418436 if self .parent is None :
419437 return None
420438 else :
421439 return self .parent .parent
422440
423441 @property
424- def sibling (self ) -> RedBlackTree :
442+ def sibling (self ) -> RedBlackTree | None :
425443 """Get the current node's sibling, or None if it doesn't exist."""
426444 if self .parent is None :
427445 return None
@@ -432,11 +450,15 @@ def sibling(self) -> RedBlackTree:
432450
433451 def is_left (self ) -> bool :
434452 """Returns true iff this node is the left child of its parent."""
435- return self .parent and self .parent .left is self
453+ if self .parent is None :
454+ return False
455+ return self .parent .left is self .parent .left is self
436456
437457 def is_right (self ) -> bool :
438458 """Returns true iff this node is the right child of its parent."""
439- return self .parent and self .parent .right is self
459+ if self .parent is None :
460+ return False
461+ return self .parent .right is self
440462
441463 def __bool__ (self ) -> bool :
442464 return True
@@ -452,21 +474,21 @@ def __len__(self) -> int:
452474 ln += len (self .right )
453475 return ln
454476
455- def preorder_traverse (self ) -> Iterator [int ]:
477+ def preorder_traverse (self ) -> Iterator [int | None ]:
456478 yield self .label
457479 if self .left :
458480 yield from self .left .preorder_traverse ()
459481 if self .right :
460482 yield from self .right .preorder_traverse ()
461483
462- def inorder_traverse (self ) -> Iterator [int ]:
484+ def inorder_traverse (self ) -> Iterator [int | None ]:
463485 if self .left :
464486 yield from self .left .inorder_traverse ()
465487 yield self .label
466488 if self .right :
467489 yield from self .right .inorder_traverse ()
468490
469- def postorder_traverse (self ) -> Iterator [int ]:
491+ def postorder_traverse (self ) -> Iterator [int | None ]:
470492 if self .left :
471493 yield from self .left .postorder_traverse ()
472494 if self .right :
@@ -488,15 +510,17 @@ def __repr__(self) -> str:
488510 indent = 1 ,
489511 )
490512
491- def __eq__ (self , other ) -> bool :
513+ def __eq__ (self , other : object ) -> bool :
492514 """Test if two trees are equal."""
515+ if not isinstance (other , RedBlackTree ):
516+ return NotImplemented
493517 if self .label == other .label :
494518 return self .left == other .left and self .right == other .right
495519 else :
496520 return False
497521
498522
499- def color (node ) -> int :
523+ def color (node : RedBlackTree | None ) -> int :
500524 """Returns the color of a node, allowing for None leaves."""
501525 if node is None :
502526 return 0
@@ -699,19 +723,12 @@ def main() -> None:
699723 >>> pytests()
700724 """
701725 print_results ("Rotating right and left" , test_rotations ())
702-
703726 print_results ("Inserting" , test_insert ())
704-
705727 print_results ("Searching" , test_insert_and_search ())
706-
707728 print_results ("Deleting" , test_insert_delete ())
708-
709729 print_results ("Floor and ceil" , test_floor_ceil ())
710-
711730 print_results ("Tree traversal" , test_tree_traversal ())
712-
713731 print_results ("Tree traversal" , test_tree_chaining ())
714-
715732 print ("Testing tree balancing..." )
716733 print ("This should only be a few seconds." )
717734 test_insertion_speed ()
0 commit comments