# All Nodes Distance K in Binary Tree

We are given a binary tree (with root node root), a target node, and an integer value K.

Return a list of the values of all nodes that have a distance K from the target node.  The answer can be returned in any order.


**Example 1:**

![Binary tree for Example 1](assets/all_node_distance_k_in_binary.png)
```
Input: root = [3,5,1,6,2,0,8,null,null,7,4], target = 5, K = 2

Output: [7,4,1]
```
Explanation: 
The nodes that are a distance 2 from the target node (with value 5)
have values 7, 4, and 1.



Note that the inputs "root" and "target" are actually TreeNodes.
The descriptions of the inputs above are just serializations of these objects.

Note:

- The given tree is non-empty.
- Each node in the tree has a unique value, with 0 <= node.val <= 500.
- The target node is a node in the tree.
- 0 <= K <= 1000.

## Communication

We can solve this problem with two DFS functions. The first one seeks the target node, and the second one walks down a subtree to collect the nodes that are exactly K edges from the target.

The seeker DFS reports, for each subtree it explores, whether the target lies inside it. A subtree without the target returns -1. The target itself returns 1, and each ancestor on the path back to the root returns one more than the value it received from its child. As a result, the value an ancestor receives from a child is exactly that ancestor's distance from the target, and the single path that leads to the target is separated from all paths that do not.

Because the seeker does not look below the target once it finds it, the second DFS handles the downward search. It starts once at the target with distance 0. It also starts at every ancestor on the path: if the target was found through the left child, the second DFS explores the right child starting at the ancestor's distance plus one, and if it was found through the right child, it explores the left child instead. An ancestor is added to the result when its own distance equals K, and the second DFS adds a node as soon as it reaches distance K.

Each node is visited a constant number of times, so the time complexity is $O(n)$. The recursion stack takes $O(h)$ space, where $h$ is the height of the tree; this is $O(n)$ in the worst case of a skewed tree, and the result list is also at most $O(n)$.
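The return-value convention of the seeker can be written out explicitly. The notation here is introduced only for this sketch: $t$ is the target, $d(v)$ is the number of edges between $t$ and node $v$, and $S(v)$ is the subtree rooted at $v$.

$$
\mathrm{dfs}(v) =
\begin{cases}
d(v) + 1 & \text{if } t \in S(v), \\
-1 & \text{otherwise.}
\end{cases}
$$

If $t \in S(c)$ for a child $c$ of $v$, then $d(v) = d(c) + 1 = \mathrm{dfs}(c)$, so the value a node receives from its child is exactly the distance it compares with $K$. For the other child $s$ of $v$, any node $u$ that sits $j$ levels below $s$ is reached through $v$, so

$$
d(u) = d(v) + 1 + j,
$$

which is why the second DFS is started at $s$ with distance $d(v) + 1$. In Example 1, $\mathrm{dfs}(5) = 1$, so the root has $d(3) = 1$ and the second DFS reaches node $1$ with distance $2 = K$; the downward search from the target records $7$ and $4$, which sit two levels below node $5$.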

## Coding

```python
from collections import deque


# Definition for a binary tree node (the same shape LeetCode passes in).
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


class Solution(object):
    def distanceK(self, root, target, K):
        ans = []

        def dfs(node):
            # Seeker: returns (distance from target to node) + 1, i.e. the
            # distance from target to node's parent, or -1 if target is not
            # in this subtree.
            if node is None:
                return -1
            if node is target:
                # Collect the K-distant nodes below the target itself.
                subtree_traverse(node, 0)
                return 1
            L = dfs(node.left)
            if L != -1:
                # Target is on the left, L edges from this node.
                if L == K:
                    ans.append(node.val)
                # Search the other side, one edge further away.
                subtree_traverse(node.right, L + 1)
                return L + 1
            # Only search right when the target was not found on the left.
            R = dfs(node.right)
            if R != -1:
                if R == K:
                    ans.append(node.val)
                subtree_traverse(node.left, R + 1)
                return R + 1
            return -1

        def subtree_traverse(node, dist):
            # Walk down from node, which is dist edges from the target,
            # recording every node exactly K edges away.
            if node is None or dist > K:
                return
            if dist == K:
                ans.append(node.val)
                return
            subtree_traverse(node.left, dist + 1)
            subtree_traverse(node.right, dist + 1)

        dfs(root)
        return ans

    def createTree(self, nums, target):
        # Build a tree from a LeetCode-style level-order list (None marks a
        # missing child) and return (root, node whose value is target).
        if not nums or nums[0] is None:
            return None, None
        root = TreeNode(nums[0])
        tnode = root if root.val == target else None
        queue = deque([root])
        i = 1
        while queue and i < len(nums):
            current = queue.popleft()
            # The next two entries are current's left and right children.
            if i < len(nums) and nums[i] is not None:
                current.left = TreeNode(nums[i])
                queue.append(current.left)
                if nums[i] == target:
                    tnode = current.left
            i += 1
            if i < len(nums) and nums[i] is not None:
                current.right = TreeNode(nums[i])
                queue.append(current.right)
                if nums[i] == target:
                    tnode = current.right
            i += 1
        return root, tnode

    def unit_tests(self):
        test_cases = [
            [[3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], 5, 2, [7, 4, 1]]
        ]
        for index, tc in enumerate(test_cases):
            head, target = self.createTree(tc[0], tc[1])
            output = self.distanceK(head, target, tc[2])
            assert set(output) == set(tc[3]), 'test#{0} failed'.format(index)
            print('test#{0} passed'.format(index))


Solution().unit_tests()
```

```
test#0 passed
```
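
For comparison, a common alternative technique (not the approach described in this notebook) treats the tree as an undirected graph: record each node's parent, then run a breadth-first search from the target and stop after K levels. The sketch below is a minimal, hypothetical version of that idea; the function name `distance_k_bfs` and the `parent` map are our own, and `TreeNode` is redefined so the block runs on its own. It is also $O(n)$ in time, but it keeps an explicit $O(n)$ parent map instead of relying only on the recursion stack.

```python
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def distance_k_bfs(root, target, K):
    # Hypothetical alternative: parent map + level-by-level BFS from target.
    # Nodes are hashed by identity (default object hashing).
    parent = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)

    # Expand one level per unit of distance; neighbors are both children
    # and the parent, so the search can move up as well as down.
    frontier = [target]
    seen = {target}
    for _ in range(K):
        next_frontier = []
        for node in frontier:
            for nb in (node.left, node.right, parent[node]):
                if nb is not None and nb not in seen:
                    seen.add(nb)
                    next_frontier.append(nb)
        frontier = next_frontier
    return [node.val for node in frontier]


# Example 1, wired by hand.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

print(sorted(distance_k_bfs(root, root.left, 2)))  # [1, 4, 7]
```

Building the parent map first means the BFS never has to distinguish upward moves from downward ones, at the cost of an extra dictionary; the two-DFS approach avoids that map by folding the upward direction into the seeker's return value.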


## Reference
- [Leetcode](https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree/)