# Creating new visitors

The [`FindNodes`](https://sites.ecmwf.int/docs/loki/master/loki.visitors.find.html#loki.visitors.find.FindNodes) visitor looks through a given IR tree and returns a list of matching instances of a specified [`Node`](https://sites.ecmwf.int/docs/loki/master/loki.ir.html#loki.ir.Node) type.

For `Node` types that could appear in a nested structure, for example [`Loop`](https://sites.ecmwf.int/docs/loki/master/loki.ir.html#loki.ir.Loop) or [`Conditional`](https://sites.ecmwf.int/docs/loki/master/loki.ir.html#loki.ir.Conditional), we may be interested in knowing at what depth they appear in a given IR tree.

This notebook will illustrate how this can be achieved by building a new `FindNodesDepth` visitor based on `FindNodes`.

## Dataclass to store return values

The default return value for `FindNodes` is a list of nodes. For `FindNodesDepth`, we would also like to return the depth of each node. We can create a new dataclass (essentially a C-style struct) called `DepthNode` to store both pieces of information:

In [1]:
from loki import Node
from dataclasses import dataclass

@dataclass
class DepthNode:
    node: Node
    depth: int
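
To see what `@dataclass` provides here, the following is a minimal standalone sketch. So that it runs without Loki, it annotates `node` as a plain `object` instead of `loki.Node` and uses a string as a stand-in for a real IR node; both are assumptions made for illustration only.

```python
from dataclasses import dataclass


@dataclass
class DepthNode:
    node: object  # stand-in for loki.Node so the sketch runs without Loki
    depth: int


# @dataclass generates __init__, __repr__ and __eq__ from the annotated fields
a = DepthNode('Loop:: k=1:n', 0)
b = DepthNode('Loop:: k=1:n', 0)
print(a)        # DepthNode(node='Loop:: k=1:n', depth=0)
print(a == b)   # True: instances are compared field by field
print(a.depth)  # 0
```

The generated `__eq__` is what makes it straightforward to compare `DepthNode` results in tests.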

## Modifying initialization method

`FindNodes` has two operating modes. The first, and default, mode (`mode='type'`) looks through a given IR tree and returns a list of matching instances of a specified `Node` type. The second, which is enabled by passing `mode='scope'` when creating the visitor, returns the [`InternalNode`](https://sites.ecmwf.int/docs/loki/master/loki.ir.html#loki.ir.InternalNode), i.e. the [`Scope`](https://sites.ecmwf.int/docs/loki/master/loki.scope.html#loki.scope.Scope), in which a specified `Node` appears.

For our new visitor, we are only interested in the default operating mode of `FindNodes`. Therefore, let us define a new `__init__` method for our `FindNodesDepth` class that no longer accepts a `mode` argument and always passes `mode='type'` to `FindNodes`:

In [2]:
from loki import FindNodes

class FindNodesDepth(FindNodes):
    def __init__(self, match, greedy=False):
        super().__init__(match, mode='type', greedy=greedy)
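
The effect of pinning `mode` in the subclass can be illustrated without Loki. In this minimal sketch, `ToyFindNodes` and `ToyFindNodesDepth` are hypothetical stand-ins that only mimic the constructor signatures shown above; they are not part of Loki.

```python
class ToyFindNodes:
    """Hypothetical stand-in for FindNodes: only the constructor is mimicked."""

    def __init__(self, match, mode='type', greedy=False):
        self.match = match
        self.mode = mode
        self.greedy = greedy


class ToyFindNodesDepth(ToyFindNodes):
    # `mode` is dropped from the signature and 'type' is always forwarded
    def __init__(self, match, greedy=False):
        super().__init__(match, mode='type', greedy=greedy)


v = ToyFindNodesDepth(int, greedy=True)
print(v.mode, v.greedy)  # type True

rejected = False
try:
    ToyFindNodesDepth(int, mode='scope')
except TypeError:
    # `mode` is no longer an accepted keyword, so scope mode cannot be requested
    rejected = True
print(rejected)  # True
```

Because `mode` is not in the subclass signature, requesting scope mode fails with a `TypeError` instead of silently producing results that `FindNodesDepth` was not designed for.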

## Modifying the `visit_Node` method

In order to achieve the desired functionality of our new visitor, we will need a new `visit_Node` method. We start from a copy of `FindNodes.visit_Node` and make only a few changes to it:

In [3]:
class FindNodesDepth(FindNodes):
    def __init__(self, match, greedy=False):
        super().__init__(match, mode='type', greedy=greedy)

    def visit_Node(self, o, **kwargs):
        ret = kwargs.pop('ret', self.default_retval())
        depth = kwargs.pop('depth', -1)  # -1 at the root: no matching ancestor yet
        if self.rule(self.match, o):
            depth += 1  # one level deeper than the nearest matching ancestor
            ret.append(DepthNode(o, depth))  # store the node together with its depth
            if self.greedy:
                return ret
        for i in o.children:
            ret = self.visit(i, depth=depth, ret=ret, **kwargs)  # pass depth down
        return ret or self.default_retval()

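The first change to `visit_Node` is the new statement `depth = kwargs.pop('depth', -1)`. When `visit_Node` is called on the root of the IR tree, no `depth` keyword is present, so `depth` is initialized to -1. When `visit_Node` is instead called recursively, this statement retrieves the depth of the nearest enclosing matching node, as passed down by the parent call. The next change is inside the `if self.rule(self.match, o):` block: if `Node` `o` matches the specified type, `depth` is incremented by 1 and a `DepthNode` object, rather than the bare node, is appended to the return list. The final change is in the loop over `o.children`, where `depth` is passed as a keyword argument to the recursive call to `visit`. This is what allows each child to retrieve the current value of `depth`. Because `depth` is a local variable, all siblings start from the same value, and nodes that do not match leave it unchanged.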
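
The same recursion can be tried out without Loki. The following is a minimal sketch on a hypothetical toy tree: `ToyNode`, `find_depth`, the `loop` helper and the matching on a `kind` string are illustrative stand-ins, not Loki API, and the tree only mimics the loop nest of `loop_fuse_v1` used below. `find_depth` follows the same steps as `visit_Node`: start at -1, increment on a match, stop early in greedy mode, and pass `depth` down to the children.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical minimal IR node: a kind, a label and a list of children."""
    kind: str
    label: str
    children: list = field(default_factory=list)


@dataclass
class DepthNode:
    node: ToyNode
    depth: int


def find_depth(o, kind, greedy=False, depth=-1, ret=None):
    """Collect DepthNode entries for every node of the given kind."""
    ret = [] if ret is None else ret
    if o.kind == kind:
        depth += 1  # one level deeper than the nearest matching ancestor
        ret.append(DepthNode(o, depth))
        if greedy:
            return ret  # greedy: do not search inside a match
    for child in o.children:
        # depth is a local variable, so every sibling starts from the same value
        ret = find_depth(child, kind, greedy, depth, ret)
    return ret


def loop(label, *children):
    """Helper to build a toy loop node."""
    return ToyNode('loop', label, list(children))


# Mimics the loop nest of loop_fuse_v1 (labels i1..i4 tell the inner loops apart)
body = ToyNode('section', 'body', [
    loop('k',
         loop('j', loop('i1'), loop('i2')),
         ToyNode('call', 'some_kernel'),
         loop('j', loop('i3'), loop('i4'))),
])

print([(d.node.label, d.depth) for d in find_depth(body, 'loop')])
# [('k', 0), ('j', 1), ('i1', 2), ('i2', 2), ('j', 1), ('i3', 2), ('i4', 2)]
print([(d.node.label, d.depth) for d in find_depth(body, 'loop', greedy=True)])
# [('k', 0)]
```

Since `depth` only grows on a match, non-matching nodes such as the toy `call` node do not add a level: the reported depth counts matching ancestors only. With `greedy=True` the search stops at the outermost match, so only the `k` loop is reported.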

Having now fully defined our new visitor, we can test it on the following routine containing nested loops:

In [4]:
from loki import Sourcefile
from loki import fgen

source = Sourcefile.from_file('src/loop_fuse.F90')
routine = source['loop_fuse_v1']
print(fgen(routine.body))


DO k=1,n
  DO j=1,n
    DO i=1,n
      var_out(i, j, k) = var_in(i, j, k)
    END DO
    DO i=1,n
      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)
    END DO
  END DO
  
  CALL some_kernel(n, var_out(1, 1, k))
  
  DO j=1,n
    DO i=1,n
      var_out(i, j, k) = var_out(i, j, k) + 1._JPRB
    END DO
    DO i=1,n
      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)
    END DO
  END DO
END DO



`loop_fuse_v1` contains a total of 7 loops, nested at most three levels deep. Since our visitor counts depth from 0, the reported depths should range from 0 to 2. Let us see if our new visitor can identify the loops and their depths correctly:

In [5]:
from loki import Loop

loops = FindNodesDepth(Loop).visit(routine.body)

for k, loop in enumerate(loops):
    print(k, loop.node, loop.depth)

depth = [0, 1, 2, 2, 1, 2, 2]
assert depth == [loop.depth for loop in loops]

0 Loop:: k=1:n 0
1 Loop:: j=1:n 1
2 Loop:: i=1:n 2
3 Loop:: i=1:n 2
4 Loop:: j=1:n 1
5 Loop:: i=1:n 2
6 Loop:: i=1:n 2


All the loops and their respective depths are identified correctly. We can do a similar test on nested `if` statements:

In [6]:
from loki import Subroutine
from loki import Conditional

fcode = """ 
subroutine nested_conditionals(i,j,k,h)
    
    logical,intent(in) :: i,j,k,h

    if(i)then
      if(j)then

        if(k)then
          ! do something
        else
          ! do something else
        endif
        
        if(h)then
          ! also test h
        endif

      endif
    endif

end subroutine nested_conditionals
"""

routine = Subroutine.from_source(fcode)

conds = FindNodesDepth(Conditional).visit(routine.body)
for k, cond in enumerate(conds):
    print(k, cond.node.condition, cond.depth)

depth = [0, 1, 2, 2]
assert depth == [cond.depth for cond in conds]

0 i 0
1 j 1
2 k 2
3 h 2
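
Once the visitor returns a list of `DepthNode` objects, simple post-processing can answer questions such as how many nesting levels a routine has. The following minimal sketch assumes a hypothetical stand-in for the `conds` list above, with string labels in place of real `Conditional` nodes, so that it runs without Loki.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class DepthNode:
    node: object  # stand-in: a string label instead of a loki Conditional
    depth: int


# Hypothetical stand-in for the list returned by FindNodesDepth(Conditional) above
conds = [DepthNode('i', 0), DepthNode('j', 1), DepthNode('k', 2), DepthNode('h', 2)]

# Depths are zero-based, so the number of nesting levels is the maximum depth + 1
# (default=-1 makes an empty result report 0 levels)
levels = max((c.depth for c in conds), default=-1) + 1

# Group the matched nodes by their depth, preserving the visiting order
by_depth = defaultdict(list)
for c in conds:
    by_depth[c.depth].append(c.node)

print(levels)          # 3
print(dict(by_depth))  # {0: ['i'], 1: ['j'], 2: ['k', 'h']}
```

The same two steps apply unchanged to the `loops` result from `loop_fuse_v1`, which also has three nesting levels.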
