## Вариант 19
$$ z = bab + a^2b - b^2 $$

## Operation

In [26]:
class Operation:
    """Base class for graph nodes that compute a value from other nodes.

    On construction the operation links itself into the graph: it records its
    inputs, appends itself to each input's ``output_nodes`` and registers
    itself in ``_default_graph.operations``.

    Args:
        input_nodes: nodes whose outputs feed this operation. Defaults to no
            inputs.
    """

    def __init__(self, input_nodes=None):
        # A fresh list per instance: a literal ``[]`` default is evaluated once
        # and would be shared by every Operation built without arguments.
        if input_nodes is None:
            input_nodes = []

        self.input_nodes = input_nodes  # nodes this operation consumes
        self.output_nodes = []  # nodes consuming this operation's output

        # Wire the reverse edge so each input knows who consumes it.
        for node in input_nodes:
            node.output_nodes.append(self)

        _default_graph.operations.append(self)

    def compute(self):
        """Placeholder; subclasses override this with one argument per input.

        The session calls ``compute(*inputs)`` with the evaluated outputs of
        ``input_nodes``, in order.
        """
        pass

## Example Operations

### Addition

In [2]:
class add(Operation):
    """Graph node whose output is the sum of its two inputs."""

    def __init__(self, x, y):
        super().__init__([x, y])

    def compute(self, x_var, y_var):
        """Return ``x_var + y_var``, keeping the operands on ``self.inputs``."""
        operands = [x_var, y_var]
        self.inputs = operands
        left, right = operands
        return left + right

### Subtraction

In [3]:
class substitute(Operation):
    """Graph node whose output is the difference ``x - y`` of its inputs.

    Despite the historical name, this operation is *subtraction*. The
    correctly named alias ``subtract`` is defined right after the class; the
    old name is kept so existing callers keep working.
    """

    def __init__(self, x, y):
        super().__init__([x, y])

    def compute(self, x_var, y_var):
        """Return ``x_var - y_var``, keeping the operands on ``self.inputs``."""
        self.inputs = [x_var, y_var]
        return x_var - y_var


# Correctly named alias for the subtraction node.
subtract = substitute

### Multiplication

In [4]:
class multiply(Operation):
    """Graph node whose output is the product of its two inputs."""

    def __init__(self, a, b):
        super().__init__([a, b])

    def compute(self, a_var, b_var):
        """Return ``a_var * b_var``, keeping the operands on ``self.inputs``."""
        factors = [a_var, b_var]
        self.inputs = factors
        first, second = factors
        return first * second

### Squaring

In [5]:
class squaring(Operation):
    """Graph node whose output is the square of its single input."""

    def __init__(self, a):
        super().__init__([a])

    def compute(self, a_var):
        """Return ``a_var`` raised to the power 2, keeping it on ``self.inputs``."""
        self.inputs = [a_var]
        squared = pow(a_var, 2)
        return squared

### Matrix Multiplication

In [6]:
class matmul(Operation):
    """Graph node whose output is the matrix product of its two inputs.

    The first operand must provide a ``.dot`` method (e.g. a NumPy array).
    """

    def __init__(self, a, b):
        super().__init__([a, b])

    def compute(self, a_mat, b_mat):
        """Return ``a_mat.dot(b_mat)``, keeping the operands on ``self.inputs``."""
        matrices = [a_mat, b_mat]
        self.inputs = matrices
        product = matrices[0].dot(matrices[1])
        return product

## Placeholders

In [7]:
class Placeholder:
    """Graph input whose value is not stored on the node itself.

    The node only tracks its consumers and registers itself in
    ``_default_graph.placeholders``; the session reads its value from the
    ``feed_dict`` it is given.
    """

    def __init__(self):
        self.output_nodes = []  # operations consuming this placeholder
        graph = _default_graph
        graph.placeholders.append(self)

## Variables

In [8]:
class Variable:
    """Graph node holding a stored value.

    Args:
        initial_value: value kept in ``self.value``; defaults to ``None``.
    """

    def __init__(self, initial_value=None):
        self.output_nodes = []  # operations consuming this variable
        self.value = initial_value
        # Register with whichever graph is currently the default.
        graph = _default_graph
        graph.variables.append(self)

## Graph

In [9]:
class Graph:
    """Registry of the operations, placeholders and variables of one graph.

    Node constructors append themselves to these lists on the module-level
    ``_default_graph``, which ``set_as_default`` points at this instance.
    """

    def __init__(self):
        self.operations, self.placeholders, self.variables = [], [], []

    def set_as_default(self):
        """Make this graph the module-level ``_default_graph``."""
        global _default_graph
        _default_graph = self

## The Graph

$$ z = bab + a^2b - b^2 $$

With $a = 2$ and $b = 3$:

$$ z = 18 + 12 - 9 = 21 $$


In [10]:
g = Graph()  # container for the nodes of the expression graph

In [11]:
g.set_as_default()  # nodes constructed from here on register themselves on g

In [12]:
a = Variable(2)  # a = 2

In [13]:
b = Variable(3)  # b = 3

In [14]:
bab = multiply(b, multiply(a, b))  # b * (a * b)

In [15]:
a2b = multiply(squaring(a), b)  # a**2 * b

In [16]:
b2 = squaring(b)  # b**2

In [17]:
z = substitute(add(bab, a2b), b2)  # (bab + a2b) - b2; `substitute` computes a difference

## Session

In [18]:
import numpy as np

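### Traversing Operation Nodes

The nodes are visited in post-order, so every node's inputs are evaluated before the node itself.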

In [19]:
def traverse_postorder(operation):
    """Return the nodes ``operation`` depends on, in post-order.

    Every node appears after all of its inputs, and ``operation`` itself is
    last. Each node is listed once, even if it feeds several consumers.
    Without that, a shared node (like ``a`` or ``b`` above) is emitted, and
    later evaluated, once per path, which grows exponentially on
    diamond-shaped graphs.

    Args:
        operation: the node to start from; only ``Operation`` instances are
            descended into through ``input_nodes``.

    Returns:
        list of nodes in evaluation order.
    """
    nodes_postorder = []
    visited = set()  # ids of nodes already emitted (or in progress)

    def recurse(node):
        if id(node) in visited:
            return
        visited.add(id(node))
        if isinstance(node, Operation):
            for input_node in node.input_nodes:
                recurse(input_node)
        nodes_postorder.append(node)

    recurse(operation)
    return nodes_postorder

In [20]:
class Session:
    """Evaluates nodes of a computational graph."""

    def run(self, operation, feed_dict=None):
        """Evaluate ``operation`` and every node it depends on.

        Nodes are visited in post-order, so each node's inputs already have an
        ``output`` when the node is reached:

        * a ``Placeholder`` takes its value from ``feed_dict``;
        * a ``Variable`` takes its stored ``value``;
        * any other node is treated as an ``Operation`` and computes from
          its inputs' outputs.

        List outputs are converted to NumPy arrays.

        Args:
            operation: the node whose value is requested.
            feed_dict: mapping from ``Placeholder`` nodes to their values;
                defaults to an empty mapping.

        Returns:
            the ``output`` of ``operation``.

        Raises:
            KeyError: if a required placeholder is missing from ``feed_dict``.
        """
        # Avoid a shared mutable default argument.
        if feed_dict is None:
            feed_dict = {}

        for node in traverse_postorder(operation):
            if isinstance(node, Placeholder):
                try:
                    node.output = feed_dict[node]
                except KeyError:
                    raise KeyError(
                        f"no value supplied in feed_dict for placeholder {node!r}"
                    ) from None
            elif isinstance(node, Variable):
                node.output = node.value
            else:  # Operation
                node.inputs = [input_node.output for input_node in node.input_nodes]
                node.output = node.compute(*node.inputs)

            # Convert lists to numpy arrays so operations like matmul can use them.
            if isinstance(node.output, list):
                node.output = np.array(node.output)

        # Return the requested node value
        return operation.output

In [21]:
sess = Session()  # evaluator for graph nodes

In [22]:
result = sess.run(z, feed_dict={})  # no placeholders in this graph, so nothing to feed

In [23]:
result  # evaluated value of z

21

In [24]:
18+12-9  # hand check with a=2, b=3: b*a*b = 18, a**2*b = 12, b**2 = 9

21