In [2]:
class Node:
    """A binary-tree node: a value plus optional left and right child nodes."""

    def __init__(self, data):
        self.data = data
        # Both children start empty; builders attach subtrees after construction.
        self.left, self.right = None, None


def BTthroughInput(values=None):
    """Build a binary tree from values given in pre-order, with -1 marking a missing child.

    Parameters
    ----------
    values : iterable, optional
        Pre-order node values, each converted with ``int``. When omitted, values
        are read one per line from ``input()`` (the original interactive behaviour).
        Passing an explicit sequence lets the notebook rebuild the same tree on
        Restart & Run All without manual typing.

    Returns
    -------
    Node or None
        Root of the built tree, or None if the first value is -1.

    Raises
    ------
    ValueError
        If `values` runs out before the tree is complete; errors from ``int``
        (e.g. a non-numeric entry) also propagate.
    """
    if values is None:
        def nextValue():
            return int(input())
    else:
        remaining = iter(values)

        def nextValue():
            try:
                return int(next(remaining))
            except StopIteration:
                raise ValueError("values ended before the tree was complete") from None

    def build():
        data = nextValue()
        if data == -1:
            return None
        root = Node(data)
        # Pre-order: the whole left subtree is consumed before the right one.
        root.left = build()
        root.right = build()
        return root

    return build()


def displayBT(root):
    """Print the tree in pre-order, one line per node.

    Each line shows the node's data followed by its present children as
    ``L :  <data>`` and/or ``R :  <data>``. Returns None.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if not node:
            continue
        print(node.data, end=" ")
        for label, child in (("L : ", node.left), ("R : ", node.right)):
            if child:
                print(label, child.data, end=" ")
        print()
        # Push right first so the left subtree is printed before the right one.
        pending.append(node.right)
        pending.append(node.left)


# Build the first sample tree interactively: values are entered in pre-order,
# one per prompt, with -1 for a missing child. Input that reproduces the saved
# output below (reconstructed from it):
#   1 2 4 -1 -1 5 -1 -1 3 6 -1 -1 7 -1 -1
bt = BTthroughInput()
displayBT(bt)

1 L :  2 R :  3 
2 L :  4 R :  5 
4 
5 
3 L :  6 R :  7 
6 
7 


In [7]:
# Second sample tree (pre-order, -1 = no child). Input that reproduces the saved
# output below (reconstructed from it):
#   2 3 -1 -1 4 -1 -1
bt1 = BTthroughInput()
displayBT(bt1)

2 L :  3 R :  4 
3 
4 


In [4]:
## Height of binary tree
def heightBT(root):
    """Return the number of levels in the tree (nodes on its longest root-to-leaf path).

    An empty tree (None) has height 0; a single node has height 1.
    """
    level = [root] if root else []
    depth = 0
    while level:
        depth += 1
        # Replace the current level with all non-empty children of its nodes.
        level = [child for node in level for child in (node.left, node.right) if child]
    return depth

# Height counts nodes on the longest root-to-leaf path (0 for an empty tree).
print(heightBT(bt))

3


In [8]:
## Determine if two trees are identical
def isIdentical(root1, root2):
    """Return True if both trees have the same shape and the same data at every node.

    Two empty trees are identical. If exactly one side is missing, the trees
    differ; that case is checked explicitly so `.data` is never read from None.
    """
    if root1 is None and root2 is None:
        return True
    if root1 is None or root2 is None:
        return False
    # Short-circuits on the first mismatch instead of exploring both subtrees.
    return (root1.data == root2.data
            and isIdentical(root1.left, root2.left)
            and isIdentical(root1.right, root2.right))

# A tree is identical to itself; bt and bt1 already differ at the root value.
print(isIdentical(bt,bt))
print(isIdentical(bt,bt1))


True
False


In [9]:
## Swap tree: mirror the tree by exchanging each node's left and right children
def swapTree(root):
    """Mirror the tree in place by exchanging every node's left and right child.

    Returns None; the tree rooted at `root` is modified directly.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if not node:
            continue
        node.left, node.right = node.right, node.left
        pending.extend((node.left, node.right))

# NOTE(review): swapTree mutates bt1 in place, so this cell is not idempotent;
# running it again flips bt1 back to its original orientation.
displayBT(bt1)
swapTree(bt1)
displayBT(bt1)

2 L :  3 R :  4 
3 
4 
2 L :  4 R :  3 
4 
3 


In [14]:
## Symmetric Tree: check whether the tree is a mirror image of itself
##      1
##    /   \
##   2     2
##  / \   / \
## 3   4 4   3
def isMirror(root1, root2):
    """Return True if the tree at `root2` is the mirror image of the tree at `root1`.

    Calling ``isMirror(t, t)`` therefore tests whether `t` is symmetric.
    Two empty trees mirror each other. If exactly one side is missing, they do
    not; that case is checked explicitly so `.data` is never read from None.
    """
    if root1 is None and root2 is None:
        return True
    if root1 is None or root2 is None:
        return False
    # Outer children must mirror each other, and so must the inner children.
    return (root1.data == root2.data
            and isMirror(root1.left, root2.right)
            and isMirror(root1.right, root2.left))

# Third sample tree (pre-order, -1 = no child). The structure printed by the
# displayBT(bt2) cell further down corresponds to input:
#   1 2 3 -1 -1 4 -1 -1 2 4 -1 -1 3 -1 -1
bt2 = BTthroughInput()
# Comparing the tree against itself with isMirror tests whether it is symmetric.
print(isMirror(bt2, bt2))

# NOTE(review): the saved output below shows only "True", without the lines this
# call prints; the cell was probably edited after it last ran. Re-run to refresh.
displayBT(bt2)

True


In [26]:
## Diameter of tree
def diameterOfBinaryTree(root):
    """Return the tree's diameter, counted as the number of nodes on its longest path.

    Naive version: the longest path either passes through `root`
    (left height + right height + 1) or lies entirely inside one subtree.
    Heights are recomputed with heightBT at every node, so a skewed tree costs
    O(n^2). An empty tree has diameter 0.

    NOTE(review): a later cell redefines `diameterOfBinaryTree`; once that cell
    has run, this definition is shadowed.
    """
    if root is None:
        return 0
    throughRoot = 1 + heightBT(root.left) + heightBT(root.right)
    withinSubtree = max(diameterOfBinaryTree(root.left),
                        diameterOfBinaryTree(root.right))
    return max(throughRoot, withinSubtree)

# Diameter counts nodes on the longest path; for bt that is e.g. 4-2-1-3-6, i.e. 5.
diameterOfBinaryTree(bt)

5

In [33]:
## diameter optimised
class Height:
    """Mutable integer box used to pass a subtree's height back out of diameterHelper."""
    def __init__(self):
        self.h = 0

def diameterHelper(root, height):
    """Return the diameter of the subtree at `root` and store its height in `height.h`.

    The diameter counts nodes on the longest path (a single node has diameter 1,
    an empty tree 0), and the height counts levels. Both are computed in a
    single O(n) post-order pass.

    Parameters
    ----------
    root : Node or None
        Subtree to measure.
    height : Height
        Output box; ``height.h`` is overwritten with the subtree height.
    """
    if not root:
        height.h = 0
        return 0

    # Each child needs its own box. Sharing one (``lh = rh = Height()``) lets the
    # right-subtree call overwrite the left subtree's height.
    lh, rh = Height(), Height()
    ld = diameterHelper(root.left, lh)
    rd = diameterHelper(root.right, rh)

    height.h = 1 + max(lh.h, rh.h)

    # Longest path either passes through root or lies inside one subtree.
    return max(1 + lh.h + rh.h, ld, rd)

def diameterOfBinaryTree(root):
    """Return the diameter (nodes on the longest path) of the tree in O(n) time."""
    return diameterHelper(root, Height())

# O(n) diameter of bt (nodes on the longest path); the saved output shows 5.
diameterOfBinaryTree(bt)

5

In [37]:
## Check whether the tree is height-balanced (at every node, subtree heights differ by at most 1)
def isBalanced(root):
    """Return True if, at every node, the left and right subtree heights differ by at most 1.

    Heights and balance are checked together in one post-order pass, so this
    runs in O(n). Recomputing the height at every node would be O(n^2) on a
    skewed tree. An empty tree is balanced.
    """
    def _heightOrNone(node):
        # Height of node's subtree, or None as soon as any subtree is unbalanced.
        if not node:
            return 0
        lh = _heightOrNone(node.left)
        if lh is None:
            return None
        rh = _heightOrNone(node.right)
        if rh is None or abs(lh - rh) > 1:
            return None
        return 1 + max(lh, rh)

    return _heightOrNone(root) is not None

# bt's subtree heights differ by at most 1 at every node, so this is True.
isBalanced(bt)

True

In [76]:
## Path Sum root to leaf: print all root-to-leaf paths whose node values sum to k
# --------------------------------------------------------
# DESCRIPTION:
# For a given Binary Tree of type integer and a number K, print out all root-to-leaf paths where the sum of all the node data along the path is equal to K.
# --------------------------------------------------------

def rootToLeafPathSum(root, k, path='', currSum=0):
    """Print every root-to-leaf path whose node values sum to `k`.

    Each matching path is printed on its own line as space-separated values.
    Paths through left children are printed before those through right children.

    Parameters
    ----------
    root : Node or None
        Subtree to search; None prints nothing.
    k : int
        Target sum.
    path : str
        Values on the path above `root`, each followed by a space (used by the recursion).
    currSum : int
        Sum of the values above `root` (used by the recursion).
    """
    if root is None:
        return

    currSum += root.data
    path += str(root.data) + " "

    if root.left is None and root.right is None:
        if currSum == k:
            # strip() removes the trailing separator space.
            print(path.strip())
        return

    rootToLeafPathSum(root.left, k, path, currSum)
    rootToLeafPathSum(root.right, k, path, currSum)

# bt2 is symmetric, so the path 1-2-4 (sum 7) occurs once in each half and is printed twice.
rootToLeafPathSum(bt2, 7)

1 2 4 
1 2 4 


In [78]:
## print nodes at distance k from node
def getallNodes(root, k):
    """Print, in pre-order, the data of every node exactly `k` levels below `root`.

    ``k == 0`` prints `root` itself. A negative `k` can never be reached, so the
    function returns immediately instead of walking the whole subtree.
    nodesAtDistanceK passes negative values whenever the target is far from an ancestor.
    """
    if not root or k < 0:
        return
    if k == 0:
        print(root.data)
        return
    getallNodes(root.left, k - 1)
    getallNodes(root.right, k - 1)


def nodesAtDistanceK(root, node, k):
    """Print the data of every node exactly `k` edges from the node whose data equals `node`.

    Returns the distance in edges from `root` down to the target node, or -1 if
    the target is not in this subtree. The recursion relies on this value.
    The target is matched by value. With duplicate values, the first match is
    used: `root` itself, then the left subtree, then the right subtree.
    """
    if not root:
        return -1

    if root.data == node:
        # Nodes below the target at depth k.
        getallNodes(root, k)
        return 0

    # Search the left subtree first; only if the target is absent there, the right.
    for near, far in ((root.left, root.right), (root.right, root.left)):
        dist = nodesAtDistanceK(near, node, k)
        if dist == -1:
            continue
        # The target sits `dist` edges below `near`, so `root` is dist + 1 edges away.
        if dist + 1 == k:
            print(root.data)
        else:
            # Nodes on the opposite side are reached via root, then `far`.
            getallNodes(far, k - dist - 2)
        return dist + 1

    return -1


# Prints the nodes 1 edge from the node with value 6 (its parent, 3). The call's
# return value (depth of node 6, here 2) is also echoed as the cell result;
# appending ";" would suppress it.
nodesAtDistanceK(bt, 6,1)

3


2

In [74]:
# Show bt2's structure: each line is a node followed by its left/right children.
displayBT(bt2)

1 L :  2 R :  2 
2 L :  3 R :  4 
3 
4 
2 L :  4 R :  3 
4 
3 
