# Trim a Binary Search Tree 

## Problem Statement

Given the root of a binary search tree and two numbers, min and max, trim the tree so that every key in the resulting tree lies between min and max (inclusive). The resulting tree must still be a valid binary search tree. For example, suppose we get this tree as input:
___

![title](images/bst1.png)
___
and we’re given **min = 5** and **max = 13**; then the trimmed binary search tree should be: 
___
![title](images/bst_trim.png)
___
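In other words, every node whose key falls outside [min, max] is removed. Any surviving descendants of a removed node are reattached in its place, so the BST ordering is preserved.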


In [1]:
class Node:
    """A single node of a binary search tree.

    Attributes:
        key: the value stored in this node.
        left: the left child node, or None.
        right: the right child node, or None.
        parent: the node this one hangs from, or None (e.g. for the root).
    """

    def __init__(self, k, left=None, right=None, parent=None):
        # Assigned left to right in the same order as key, left, right, parent.
        self.key, self.left, self.right, self.parent = k, left, right, parent

## Solution 1 - Without Using Parent Pointers

In [2]:
def trimBST(tree, minVal, maxVal):
    """Trim a BST so that every remaining key lies in [minVal, maxVal].

    The tree is modified in place. The root of the trimmed tree is returned,
    and it may differ from ``tree``; ``None`` means no key was in range.

    The algorithm only follows ``key``/``left``/``right`` links. If the nodes
    also carry a ``parent`` attribute, it is updated to match the new shape,
    so the tree stays usable by parent-based code afterwards.

    Subtrees that are entirely out of range are dropped without being visited.
    """
    if not tree:
        return None

    if tree.key < minVal:
        # Every key in the left subtree is < tree.key < minVal, so that whole
        # side is discarded; only the right side can contain survivors.
        kept = trimBST(tree.right, minVal, maxVal)
    elif tree.key > maxVal:
        # Symmetric case: every key in the right subtree is > maxVal.
        kept = trimBST(tree.left, minVal, maxVal)
    else:
        # In range: keep this node and trim both sides.
        tree.left = trimBST(tree.left, minVal, maxVal)
        tree.right = trimBST(tree.right, minVal, maxVal)
        for child in (tree.left, tree.right):
            if child is not None and hasattr(child, "parent"):
                child.parent = tree
        return tree

    # `tree` is removed; the surviving subtree takes its place, so it inherits
    # tree's parent (the caller re-points its own child link via the return).
    if kept is not None and hasattr(kept, "parent"):
        kept.parent = getattr(tree, "parent", None)
    return kept

## Solution 2 - Using Parent Pointers

In [None]:
def trimBST(tree, minVal, maxVal):
    """Trim a BST in place using parent pointers; keep keys in [minVal, maxVal].

    Each out-of-range node is spliced out by attaching its only possible
    surviving subtree directly to the node's parent. For a key below minVal
    that is the right subtree; for a key above maxVal it is the left subtree.
    The parent's child slot is found by identity, so it does not matter
    whether the removed node was a left or a right child.

    Relies on every node's ``parent`` attribute being accurate.

    Returns the root of the trimmed tree, or ``None`` if no key was in range.
    This root may differ from ``tree`` when the original root is removed.
    """
    if not tree:
        return None

    if tree.key < minVal:
        replacement = tree.right  # the left subtree is entirely < minVal
    elif tree.key > maxVal:
        replacement = tree.left   # the right subtree is entirely > maxVal
    else:
        # In range: children re-link themselves into `tree` through their
        # parent pointers, so the return values are not needed here.
        trimBST(tree.left, minVal, maxVal)
        trimBST(tree.right, minVal, maxVal)
        return tree

    # Splice `tree` out: `replacement` takes its place under tree's parent.
    parent = tree.parent
    if replacement is not None:
        replacement.parent = parent
    if parent is not None:  # None when `tree` is the root
        if parent.left is tree:
            parent.left = replacement
        elif parent.right is tree:
            parent.right = replacement
    # The promoted subtree may itself be out of range; keep trimming it.
    return trimBST(replacement, minVal, maxVal)

## Test Cases

In [3]:
# Define functions to check BST
def tree_max(node):
    """Return the largest key in the subtree rooted at *node*.

    An empty subtree yields ``float("-inf")`` so it never wins a comparison.
    """
    return (
        max(node.key, tree_max(node.left), tree_max(node.right))
        if node
        else float("-inf")
    )

def tree_min(node):
    """Return the smallest key in the subtree rooted at *node*.

    An empty subtree yields ``float("inf")`` so it never wins a comparison.
    """
    return (
        min(node.key, tree_min(node.left), tree_min(node.right))
        if node
        else float("inf")
    )

def bst_check(node):
    """Return True if the tree rooted at *node* is a valid binary search tree.

    Every key in a node's left subtree must be <= the node's key. Every key in
    its right subtree must be >= the node's key. Equal keys are allowed on
    either side. An empty tree is valid.

    Each node is checked once against the tightest bounds inherited from its
    ancestors. This is O(n), instead of rescanning every subtree for its
    min/max at every node. A private sentinel marks "no bound", so keys only
    need to be mutually comparable; they need not be comparable with floats.
    """
    unbounded = object()

    def _within(n, lo, hi):
        # True if every key under n lies in [lo, hi] and ordering holds below.
        if n is None:
            return True
        if lo is not unbounded and n.key < lo:
            return False
        if hi is not unbounded and n.key > hi:
            return False
        return _within(n.left, lo, n.key) and _within(n.right, n.key, hi)

    return _within(node, unbounded, unbounded)

In [4]:
# Build the example tree from the problem statement:
#
#            8
#          /   \
#         3     10
#        / \      \
#       1   6      14
#          / \    /
#         4   7  13
#
# Each child is created with parent= set to the node it is attached under,
# so the parent links match the left/right links.
bst = Node(8)
bst.left = Node(3, parent=bst)
bst.left.left = Node(1, parent=bst.left)
bst.left.right = Node(6, parent=bst.left)
bst.left.right.left = Node(4, parent=bst.left.right)
bst.left.right.right = Node(7, parent=bst.left.right)
bst.right = Node(10, parent=bst)
bst.right.right = Node(14, parent=bst.right)
bst.right.right.left = Node(13, parent=bst.right.right)

print(tree_min(bst))  # recorded output: 1
print(tree_max(bst))  # recorded output: 14
# Last expression of the cell: a notebook displays its value (recorded: True);
# when run as a plain script the result is discarded.
bst_check(bst)

1
14


True

In [5]:
# Trim the example tree to [5, 13] and re-check it.
# NOTE(review): the return value is discarded. Solution 1 returns the root of
# the trimmed tree, so ignoring it is only safe here because the root key (8)
# lies inside [5, 13]. If both trimBST cells are executed, the later
# definition (Solution 2) is the one called.
trimBST(bst, 5, 13)

print(tree_min(bst))  # recorded output: 6
print(tree_max(bst))  # recorded output: 13
# Displayed by a notebook (recorded: True); discarded in a plain script.
bst_check(bst)

6
13


True