# In [201]:
from typing import List
from collections import defaultdict

# Sample input: number of tree nodes.
n = 10

# Edge i joins node first[i] and node second[i] (n - 1 edges for n nodes).
first, second = (
    [6, 8, 3, 6, 4, 1, 8, 5, 1],
    [9, 9, 5, 7, 8, 8, 10, 8, 2],
)

# values[i] is stored on node i + 1 (nodes are numbered from 1).
values = [17, 29, 3, 20, 11, 8, 3, 23, 5, 15]

# Node numbers whose subtree prime-count is requested.
queries = [1, 8, 9, 6, 4, 3]

# Registry of every Node created so far, keyed by node number (filled by get_node).
nodes = {}

class Node:
    """A vertex of the rooted tree.

    Attributes:
        children: List of child ``Node`` objects (a new list when none given).
        val: Payload value attached to the node (``None`` until assigned).
        index: Node number; it is returned directly by ``__hash__``, so it
            should be an int.
        parent: Parent ``Node``; any falsy ``parent`` argument becomes ``None``.
    """

    def __init__(self, index, children=None, val=None, parent=None):
        # A falsy ``children`` (None or an empty list) is replaced by a new
        # list so instances never share one.
        self.children = children if children else []
        self.val = val
        self.index = index
        self.parent = parent if parent else None

    def __hash__(self):
        # Nodes hash by their number.
        return self.index

    def __repr__(self):
        return f"Node({self.index})"

    def __str__(self):
        head = f"Node: {self.index}, val: {self.val}"
        return f"{head}, children: {self.children}, parent: {self.parent!r}"
    

# Taken from: 
# https://stackoverflow.com/questions/4114167/checking-if-a-number-is-a-prime-number-in-python
from math import sqrt; from itertools import count, islice
def is_prime(n):
    """Return True if *n* is prime, by trial division from 2 up to ~sqrt(n).

    Any value not greater than 1 is reported as not prime.
    """
    if n > 1:
        # islice yields int(sqrt(n) - 1) candidates starting at 2, i.e. the
        # divisors 2 .. floor(sqrt(n)).
        divisors = islice(count(2), int(sqrt(n) - 1))
        return not any(n % d == 0 for d in divisors)
    return False


def get_node(index):
    """Return the registered Node for *index*, creating and registering it on first use."""
    existing = nodes.get(index)
    if existing:
        return existing
    fresh = nodes[index] = Node(index)
    return fresh
        
def make_adjancencies(first: List, second: List):
    """Return adjacency lists for the undirected edges ``(first[i], second[i])``.

    The result is a ``defaultdict(list)``. Each endpoint is appended to the
    other endpoint's list, in input order.
    """
    neighbours = defaultdict(list)
    for edge in zip(first, second):
        # Record the edge in both directions: (a, b) then (b, a).
        for src, dst in (edge, edge[::-1]):
            neighbours[src].append(dst)
    return neighbours


def make_tree(first: List, second: List) -> Node:
    """Build the tree rooted at node 1 and return its root ``Node``.

    ``first[i]`` and ``second[i]`` are the endpoints of the i-th undirected
    edge. Edges are oriented away from node 1. Every vertex reachable from
    node 1 gets a ``Node`` in the module-level ``nodes`` registry (via
    ``get_node``), with its ``children`` and ``parent`` links filled in.
    Children keep the order in which their edges appear in the input.

    The registry is cleared first. A second call therefore rebuilds the tree
    rather than appending duplicate children to nodes left over from the
    previous call. With no edges at all, a lone root node is returned.
    """
    adjacency = make_adjancencies(first, second)
    nodes.clear()
    root = get_node(1)

    # Iterative DFS with an explicit stack, because a recursive walk would
    # exceed Python's recursion limit on long, path-like trees. `seen` skips
    # each node's parent, and would also stop a cycle in malformed input from
    # looping forever.
    seen = {1}
    stack = [1]
    while stack:
        index = stack.pop()
        parent_node = get_node(index)
        for neighbour in adjacency[index]:
            if neighbour in seen:
                continue
            seen.add(neighbour)
            child = get_node(neighbour)
            parent_node.children.append(child)
            child.parent = parent_node
            stack.append(neighbour)

    return root


# Linear in len(values).
def assign_values(values):
    """Store ``values[i]`` as the ``val`` of registered node ``i + 1``.

    Every node number 1 .. len(values) must already be in ``nodes``.
    """
    registry = nodes
    for position, value in enumerate(values):
        registry[position + 1].val = value
        

def primeQuery(n, first, second, values, queries):
    """Answer each query with the number of prime-valued nodes in that node's subtree.

    The tree is built from the edge lists ``first``/``second`` and rooted at
    node 1. Node ``i + 1`` receives ``values[i]``. ``n`` is accepted for
    interface compatibility but is not used.

    Returns a list with one count per entry of ``queries``, in the same order.
    """
    root = make_tree(first, second)
    assign_values(values)
    prime_counts = calculate_primes(root)
    return [prime_counts[query] for query in queries]
    

def get_leafs(nodes):
    """Return every node in the ``nodes`` mapping that has no children.

    Leaves come back in the mapping's iteration order.
    """
    leaves = []
    for candidate in nodes.values():
        if candidate.children:
            continue
        leaves.append(candidate)
    return leaves
    

def recur(node, add, counts):
    """Walk from *node* up to the root, adding newly found primes to every node on the path.

    Args:
        node: Starting node (followed through ``.parent`` links); ``None`` is a no-op.
        add: Number of not-yet-counted prime nodes already found below *node*
            on this walk.
        counts: Mapping of node index -> running count. It must return 0 for
            missing keys (e.g. a ``defaultdict``). A node whose index is not
            yet a key is on its first visit.

    A prime-valued node is counted only on its first visit. The running total
    is then added to that node and to each of its ancestors.

    The walk is a loop rather than recursion, so trees deeper than Python's
    recursion limit are handled.
    """
    while node is not None:
        # Check membership first. is_prime is pure, so skipping it for nodes
        # that were already visited only saves work.
        if node.index not in counts and is_prime(node.val):
            add += 1
        counts[node.index] += add
        node = node.parent

def calculate_primes(root, predicate=None):
    """Count, for every node under *root*, the prime-valued nodes in its subtree.

    Args:
        root: Tree root. Nodes expose ``index``, ``val`` and ``children``.
            ``None`` yields an empty result.
        predicate: Test applied to each node's ``val``. It defaults to
            ``is_prime``; passing another predicate reuses the same subtree
            count for other node properties.

    Returns:
        A ``defaultdict`` mapping node index -> number of nodes in that node's
        subtree (itself included) whose value satisfies *predicate*. Indices
        that are not present read as 0.
    """
    if predicate is None:
        predicate = is_prime
    counts = defaultdict(lambda: 0)
    if root is None:
        return counts

    # Iterative pre-order walk with an explicit stack, so deep trees cannot
    # hit the recursion limit. Each entry records the parent's index. `seen`
    # makes sure a node listed twice in a children list is only counted once.
    order = []
    seen = set()
    stack = [(root, None)]
    while stack:
        node, parent_index = stack.pop()
        if node.index in seen:
            continue
        seen.add(node.index)
        order.append((node, parent_index))
        stack.extend((child, node.index) for child in node.children)

    # Reversed pre-order visits each node after all of its descendants, so a
    # subtree total is complete before it is pushed up to the parent. The
    # predicate runs once per node, instead of one walk to the root per leaf.
    for node, parent_index in reversed(order):
        counts[node.index] += 1 if predicate(node.val) else 0
        if parent_index is not None:
            counts[parent_index] += counts[node.index]
    return counts
        
    

# Run the sample input through the solver; the cell shows the returned list.
primeQuery(n=n, first=first, second=second, values=values, queries=queries)

# Out[201]: [7, 5, 2, 1, 0, 1]

## 