# Visitor

To illustrate this pattern, we use numerical addition: a base `Expression` class with two child classes, `DoubleExpression` (a single value) and `AdditionExpression` (the sum of two sub-expressions).

In [2]:
class Expression:
    """Common base type for every node of the expression tree."""

class DoubleExpression(Expression):
    """Leaf expression that wraps a single value."""

    def __init__(self, value):
        # Stored as-is; the examples pass plain numbers.
        self.value = value
        

class AdditionExpression(Expression):
    """Expression node representing the sum of two sub-expressions."""

    # Must be ``__init__`` (constructor), not ``__int__`` (int() conversion
    # hook); otherwise AdditionExpression(left, right) raises TypeError.
    def __init__(self, left, right):
        self.left = left
        self.right = right
        

## Intrusive visitor

This is the most naive way to extend the functionality of existing objects. It is not practical in real applications, however: it involves a lot of repetitive code (in this case, adding a `print` method to every concrete class, which does not scale) and is prone to errors.

In [None]:
# Scenario 1

# In this example, the added function performs the same operation for every
# subclass; only the function it invokes (`__str__`) differs by child type.

class Expression:
    """Base class for the intrusive printer example (scenario 1).

    ``print`` is shared by every subclass: it returns ``self.__str__()``, so
    each subclass only has to supply its own ``__str__``.
    """

    # A second, @staticmethod ``print`` previously preceded this one; it was
    # immediately shadowed by this definition and therefore dead code.
    def print(self):
        return self.__str__()

class DoubleExpression(Expression):
    """Leaf expression wrapping one value; renders as that value."""

    def __init__(self, value):
        self.value = value

    def __str__(self):
        rendered = str(self.value)
        return rendered

class AdditionExpression(Expression):
    """Binary node for ``left + right``; renders as ``(left + right)``."""

    def __init__(self, left:Expression, right:Expression):
        self.left, self.right = left, right

    def __str__(self):
        # Each side formats itself through the shared ``print`` hook.
        lhs = self.left.print()
        rhs = self.right.print()
        return f"({lhs} + {rhs})"

In [17]:
# usage
add = AdditionExpression(
    DoubleExpression(2),
    AdditionExpression(DoubleExpression(3), DoubleExpression(4)),
)

add.print()

'(2 + (3 + 4))'

In [16]:
# Scenario 2

# Sometimes the operation differs (and is more complex) depending on the child
# type, so the added function cannot be shared and each subclass implements it.

class Expression:
    """Base class for the intrusive printer example (scenario 2).

    Every subclass is expected to override ``print`` with its own formatting.
    """

    # A regular method rather than a @staticmethod: a static method that
    # declares ``self`` cannot be called on an instance without a TypeError.
    def print(self):
        raise NotImplementedError(f"{type(self).__name__} must implement print()")

class DoubleExpression(Expression):
    """Leaf expression holding a single value; prints as that value."""

    def __init__(self, value):
        self.value = value

    def print(self):
        text = str(self.value)
        return text

class AdditionExpression(Expression):
    """Sum of two sub-expressions; prints as ``(left + right)``."""

    def __init__(self, left:Expression, right:Expression):
        self.left, self.right = left, right

    def print(self):
        # Left operand is formatted before the right, as in a plain f-string.
        return "({} + {})".format(self.left.print(), self.right.print())


In [18]:
# usage won't change
add = AdditionExpression(
    DoubleExpression(2),
    AdditionExpression(DoubleExpression(3), DoubleExpression(4)),
)

add.print()

'(2 + (3 + 4))'

## Reflective printer

To overcome the previous scalability problem, we can create a printer object dedicated to the printing job. This respects the Single Responsibility Principle (SRP), saves code, and is less prone to errors.

In [10]:
# Scenario 1

class ExpressionPrinter:
    """Printer that remembers the text of the last expression it printed.

    In this scenario the formatting itself is delegated to the expression's
    own ``print`` method.
    """

    def __init__(self):
        self.string = ""

    def print(self, expr):
        rendered = expr.print()
        self.string = rendered

    def __str__(self):
        return self.string

In [11]:
add = AdditionExpression(
    DoubleExpression(2),
    AdditionExpression(DoubleExpression(3), DoubleExpression(4)),
)

ep = ExpressionPrinter()
ep.print(add)
print(str(ep))

(2 + (3 + 4))


In [60]:
# Scenario 2

class ExpressionPrinter:
    """Printer that formats an expression tree by inspecting node types.

    All formatting logic lives here, and children are formatted recursively
    by this printer, so the expression classes need no ``print`` method of
    their own.
    """

    def __init__(self):
        self.string = ""

    def print(self, expr):
        """Format *expr* and store the result in ``self.string``.

        Raises:
            TypeError: if *expr* is neither a DoubleExpression nor an
                AdditionExpression (previously this silently kept the text
                from the previous call).
        """
        self.string = self._format(expr)

    def _format(self, expr):
        """Return the textual form of *expr*, recursing into sub-expressions."""
        if isinstance(expr, DoubleExpression):
            return str(expr.value)
        if isinstance(expr, AdditionExpression):
            return f"({self._format(expr.left)} + {self._format(expr.right)})"
        raise TypeError(f"unsupported expression type: {type(expr).__name__}")

    def __str__(self):
        return self.string

In [61]:
add = AdditionExpression(
    DoubleExpression(2),
    AdditionExpression(DoubleExpression(3), DoubleExpression(4)),
)

ep = ExpressionPrinter()
ep.print(add)
print(str(ep))

(2 + (3 + 4))


## Dispatch

Dispatch is the problem of figuring out which function to call—specifically, how many pieces of information (for example, the runtime types of the objects involved) are required in order to make the call.

TODO

In [1]:
# method 1: relies on Python's dynamic typing (func accepts any object and print()
# uses its __str__); this would not work as-is in a statically typed language such as C++.

from abc import abstractmethod

class Stuff:
    """Base type for the dispatch examples.

    NOTE: ``@abstractmethod`` is only enforced for classes whose metaclass is
    ``ABCMeta`` (e.g. subclasses of ``abc.ABC``). ``Stuff`` is a plain class,
    so subclasses that do not define ``call`` can still be instantiated.
    """

    @abstractmethod
    def call(self):
        """Hook for subclasses; the base version does nothing."""

def func(obj):
    """Print *obj* using its ``str()`` representation."""
    text = str(obj)
    print(text)

class Foo(Stuff):
    """Concrete Stuff whose text form identifies it as Foo."""

    def __str__(self):
        label = "this is Foo"
        return label

class Bar(Stuff):
    """Concrete Stuff whose text form identifies it as Bar."""

    def __str__(self):
        label = "this is Bar"
        return label

In [3]:
stuff1, stuff2 = Foo(), Bar()
func(stuff1)
func(stuff2)

this is Foo
this is Bar


In [40]:
# method 2

from abc import abstractmethod

class Stuff:
    """Base type for the dispatch example where each subclass implements ``call``.

    NOTE: ``@abstractmethod`` here documents intent only; it is not enforced
    because ``Stuff`` does not use the ``ABCMeta`` metaclass.
    """

    @abstractmethod
    def call(self):
        """Dispatch hook; the base version does nothing."""

def func(obj):
    """Print *obj* using its ``str()`` representation."""
    text = str(obj)
    print(text)

class Foo(Stuff):
    """Stuff subclass that hands itself to ``func`` when ``call`` is invoked."""

    def __str__(self):
        return "this is Foo"

    def call(self):
        # Inside this method ``self`` is statically known to be a Foo, so the
        # call site itself selects the right overload in typed languages.
        func(self)

class Bar(Stuff):
    """Stuff subclass that hands itself to ``func`` when ``call`` is invoked."""

    def __str__(self):
        return "this is Bar"

    def call(self):
        # Same pattern as Foo: the receiver passes itself to the shared function.
        func(self)

In [41]:
foo, bar = Foo(), Bar()
foo.call()
bar.call()

this is Foo
this is Bar


## Classic visitor

In [7]:
from abc import abstractmethod
class Expression:
    """Base class of the expression tree used by the classic visitor examples.

    ``accept`` is the hook visitors rely on: it hands this object to
    ``visitor.visit``. Subclasses inherit it unchanged.
    """

    # Deliberately not marked @abstractmethod: this is the concrete default
    # that DoubleExpression and AdditionExpression use (neither overrides it).
    def accept(self, visitor):
        visitor.visit(self)

class DoubleExpression(Expression):
    """Leaf node of the expression tree holding a single value."""

    def __init__(self, value):
        self.value = value

    def __str__(self):
        rendered = str(self.value)
        return rendered
        
class AdditionExpression(Expression):
    """Node that adds two sub-expressions; renders as ``(left + right)``."""

    def __init__(self, left:Expression, right:Expression):
        self.left, self.right = left, right

    def __str__(self):
        # str.format applies format() to each operand, exactly like an f-string.
        return "({} + {})".format(self.left, self.right)


In [8]:
# Scenario 1

from io import StringIO

class ExpressionVisitor:
    """Interface for operations applied to an expression tree.

    NOTE: ``@abstractmethod`` documents intent only; it is not enforced
    because this class does not use the ``ABCMeta`` metaclass.
    """

    # Parameter renamed from the typo ``epxr`` to ``expr`` so the base
    # signature matches the overriding ``visit(self, expr)`` implementations.
    @abstractmethod
    def visit(self, expr):
        """Handle *expr*; concrete visitors override this."""

class ExpressionPrinter(ExpressionVisitor):
    """Visitor that appends the ``str()`` form of each visited expression to a buffer."""

    def __init__(self):
        self.oss = StringIO()

    def __str__(self):
        return self.oss.getvalue()

    def visit(self, expr):
        text = str(expr)
        self.oss.write(text)

In [9]:
add = AdditionExpression(
    DoubleExpression(2),
    AdditionExpression(DoubleExpression(3), DoubleExpression(4)),
)

printer = ExpressionPrinter()
add.accept(printer)  # Expression.accept forwards to printer.visit(add)
print(printer)

(2 + (3 + 4))


In [10]:
# Scenario 2

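# This is double-dispatch style: recursion goes through `accept`, although the
# per-type choice is still made by `isinstance` checks inside `visit`.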

from io import StringIO

class ExpressionVisitor:
    """Interface for operations applied to an expression tree.

    NOTE: ``@abstractmethod`` documents intent only; it is not enforced
    because this class does not use the ``ABCMeta`` metaclass.
    """

    # Parameter renamed from the typo ``epxr`` to ``expr`` so the base
    # signature matches the overriding ``visit(self, expr)`` implementations.
    @abstractmethod
    def visit(self, expr):
        """Handle *expr*; concrete visitors override this."""

class ExpressionPrinter(ExpressionVisitor):
    """Visitor that renders an expression tree into an in-memory text buffer.

    Sub-expressions are visited through ``accept`` so each child routes back
    into this same ``visit`` method.
    """

    def __init__(self):
        self.oss = StringIO()

    def visit(self, expr):
        out = self.oss
        if isinstance(expr, DoubleExpression):
            out.write(str(expr))
            return
        if not isinstance(expr, AdditionExpression):
            return  # any other type writes nothing
        out.write("(")
        expr.left.accept(self)
        out.write(" + ")
        expr.right.accept(self)
        out.write(")")

    def __str__(self):
        return self.oss.getvalue()

In [11]:
add = AdditionExpression(
    DoubleExpression(2),
    AdditionExpression(DoubleExpression(3), DoubleExpression(4)),
)

printer = ExpressionPrinter()
add.accept(printer)
print(str(printer))

(2 + (3 + 4))


## Adding another visitor

In [33]:
class ExpressionEvaluator(ExpressionVisitor):
    """Visitor that computes the numeric value of an expression tree.

    After ``visit`` (or ``expr.accept(evaluator)``) returns, the value is
    available in ``self.result``.
    """

    def __init__(self):
        # None until an expression has been visited, so __str__ never hits
        # an AttributeError on a fresh evaluator.
        self.result = None

    def visit(self, expr):
        """Evaluate *expr* into ``self.result``.

        Raises:
            TypeError: for an unsupported expression type (previously the
                stale result of an earlier visit was silently kept).
        """
        if isinstance(expr, DoubleExpression):
            self.result = expr.value
        elif isinstance(expr, AdditionExpression):
            # Each sub-visit overwrites self.result, so save the left value
            # before evaluating the right operand.
            expr.left.accept(self)
            left_value = self.result
            expr.right.accept(self)
            self.result = left_value + self.result
        else:
            raise TypeError(f"unsupported expression type: {type(expr).__name__}")

    def __str__(self):
        return f"{self.result}"


# The original (misspelled) name is kept as an alias for existing callers.
ExpressoinEvaluator = ExpressionEvaluator
    

In [36]:
printer = ExpressionPrinter()
evaluator = ExpressoinEvaluator()
# Expression.accept forwards each call to visitor.visit(add).
add.accept(printer)
add.accept(evaluator)
print(f"{printer} = {evaluator}")

(2 + (3 + 4)) = 9


## Another Example

In [101]:
class Node:
    """A tree node holding a value and an ordered list of child nodes."""

    def __init__(self, value):
        self.children = []
        self.value = value

    # The annotation is a string (forward reference): a bare ``Node`` is
    # evaluated while the class body is still executing, before the name
    # exists, which raises NameError on Python <= 3.13.
    def add(self, node: "Node"):
        """Append *node* as the last child of this node."""
        self.children.append(node)

    def get_value(self):
        """Return the value stored in this node."""
        return self.value

    def accept(self, visitor):
        """Hand this node to *visitor* by calling ``visitor.visit(self)``."""
        visitor.visit(self)

class NodeRunner:
    """Visitor that prints every node's value, children before their parent."""

    def visit(self, node):
        # Iterating an empty children list is a no-op, so no length check is needed.
        for kid in node.children:
            kid.accept(self)
        value = node.get_value()
        print(f"Node value: {value}")

class NodeSum:
    """Visitor that accumulates the sum of all node values into ``self.s``."""

    # Class-level default kept for backward compatibility. The running total
    # now lives on each instance; the old code did ``NodeSum.s += ...``, so
    # every instance (and every re-run of the cell) shared one growing counter.
    s = 0

    def __init__(self):
        self.s = 0

    def visit(self, node):
        for child in node.children:
            child.accept(self)
        self.s += node.get_value()

In [102]:
root = Node(1)
root.add(Node(2))
root.add(Node(3))
root.children[0].add(Node(4))

runner = NodeRunner()
# `checker = NodeChecker()` was removed: NodeChecker is not defined anywhere,
# so that line raised NameError before any visitor ran.
summer = NodeSum()
root.accept(runner)
root.accept(summer)
print(summer.s)

Node value: 4
Node value: 2
Node value: 3
Node value: 1
10
