# Globals

In [1]:
cost_dict = {
    "add" : 0,
    "mul" : 0,
    "relu" : 0,
    "sigmoid" : 0,
    "mux" : 0,
    "sqrt" : 0,
}

## Functions

In [2]:
def add_cost_dicts(dict1, dict2) :
    ret_dict = cost_dict.copy()
    
    for k, v in dict1.items() :
        ret_dict[k] = dict1[k] + dict2[k]
        
    return ret_dict

# Classes

### Building Blocks

In [3]:
class Node() :
    def __init__(self) :
        self.fwd_costs = cost_dict.copy()
        self.back_costs = cost_dict.copy()
        self.is_fwd_computed = False
        self.is_back_computed = False
        self.components = []
        
    def compute_forward(self) :
        for c in self.components :
            if not c.is_fwd_computed :
                c.compute_forward()
            self.fwd_costs = add_cost_dicts(self.fwd_costs, c.fwd_costs)
            
    def compute_backward(self) :
        for c in self.components: 
            if not c.is_back_computed :
                c.compute_back()
            self.back_costs = add_cost_dicts(self.back_costs, c.back_costs)
        
class Sigmoid(Node) :
    def __init__(self, sz) :
        super(Sigmoid, self).__init__()
        self.sz = sz
        
    def compute_forward(self) :
        super(Sigmoid, self).compute_forward()
        self.fwd_costs["sigmoid"] += self.sz
        
    def compute_backward(self) :
        super(Sigmoid, self).compute_backward()

class Relu(Node) :
    def __init__(self, sz) :
        super(Relu, self).__init__()
        self.sz = sz
        
    def compute_forward(self) :
        super(Relu, self).compute_forward()
        self.fwd_costs["relu"] += self.sz
        
    def compute_backward(self) :
        super(Relu, self).compute_backward()
        
class Add(Node) :
    def __init__(self, sz) :
        super(Add, self).__init__()
        self.sz = sz
        
    def compute_forward(self) :
        super(Add, self).compute_forward()
        self.fwd_costs["add"] += self.sz
        
    def compute_backward(self) :
        super(Add, self).compute_backward()
        
class Mul(Node) :
    def __init__(self, sz) :
        super(Mul, self).__init__()
        self.sz = sz
        
    def compute_forward(self) :
        super(Mul, self).compute_forward()
        self.fwd_costs["mul"] += self.sz
        
    def compute_backward(self) :
        super(Mul, self).compute_backward()
                
class Tanh(Node) :
    def __init__(self, sz) :
        super(Tanh, self).__init__()
        self.sz = sz
        
    def compute_forward(self) :
        super(Tanh, self).compute_forward()
        self.fwd_costs["mul"] += 2*self.sz
        self.fwd_costs["add"] += self.sz
        
    def compute_backward(self) :
        super(Tanh, self).compute_backward()
        
class GemmAdd(Node) :
    def __init__(self, d1, d2) :
        super(GemmAdd, self).__init__()
        self.d1 = d1
        self.d2 = d2
        
    def compute_forward(self) :
        super(GemmAdd, self).compute_forward()
        self.fwd_costs["add"] += self.d1*self.d2
        
    def compute_backward(self) :
        super(GemmAdd, self).compute_backward()
        
class GemmAdd3(Node) :
    def __init__(self, d1, d2, d3) :
        super(GemmAdd3, self).__init__()
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        
    def compute_forward(self) :
        super(GemmAdd3, self).compute_forward()
        self.fwd_costs["add"] += self.d1*self.d2*self.d3
        
    def compute_backward(self) :
        super(GemmAdd, self).compute_backward()
    
class MatMul(Node) :
    def __init__(self, s1, s2, s3) :
        super(MatMul, self).__init__()
        self.s1 = s1
        self.s2 = s3
        self.s3 = s3
        
    def compute_forward(self) :
        super(MatMul, self).compute_forward()
        self.fwd_costs["mul"] += self.s1*self.s2*self.s3
        self.fwd_costs["add"] += self.s1*(self.s2-1)*self.s3
        
    def compute_backward(self) :
        super(MatMul, self).compute_backward()
        
class MatMul3(Node) :
    def __init__(self, d1, d2, d3, d4) :
        super(MatMul3, self).__init__()
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        self.d4 = d4
        
        self.components += [
            MatMul(d2, d3, d4)
            for _ in range(d1)
        ]
    
    def compute_forward(self) :
        super(MatMul3, self).compute_forward()
    
    def compute_backward(self) :
        super(MatMul3, self).compute_backward()

### Layer Definitions

In [4]:
class FAL(Node) :
    def __init__(self, d1, d2, d3, d4) :
        super(FAL, self).__init__()
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        self.d4 = d4
        
        self.components = [
            MatMul3(d1, d3, d2, d2),
            MatMul3(d1, d2, d3, d4),
            GemmAdd3(d1, d2, d4),
            Relu(d1*d2*d4)
        ]
    
    def compute_forward(self) :
        super(FAL, self).compute_forward()
        
    def compute_backward(self) :
        super(FAL, self).compute_backward()
                
class LSTMStep(Node) :
    def __init__(self, t, inputdim, hiddendim, dim1, dim3) :
        super(LSTMStep, self).__init__()
        self.t = t
        self.inputdim = inputdim
        self.hiddendim = hiddendim
        self.dim1 = dim1
        self.dim3 = dim3
        
        self.components = [
            MatMul(dim1, inputdim, dim3),
            MatMul(dim1, hiddendim, dim3),
            Tanh(dim1*hiddendim),
            Add(dim1*dim3),
            GemmAdd(dim1, dim3),
            Sigmoid(dim1*hiddendim),
            Sigmoid(dim1*hiddendim),
            Tanh(dim1*hiddendim),
            Mul(dim1*hiddendim),
            Mul(dim1*hiddendim),
            Add(dim1*hiddendim),
            Sigmoid(dim1*hiddendim),
            Tanh(dim1*hiddendim),
            Mul(dim1*hiddendim)
        ]
        
    def compute_forward(self) :
        super(LSTMStep, self).compute_forward()
        
    def compute_backward(self) :
        super(LSTMStep, self).compute_backward()
                        
class LSTM(Node) :
    def __init__(self, numunits, idim, hdim, d1, d3) :
        super(LSTM, self).__init__()
        self.numunits = numunits
        self.idim = idim
        self.hdim = hdim
        self.d1 = d1
        self.d3 = d3
        
        self.components += [
            LSTMStep(t, idim, hdim, d1, d3)
            for t in range(numunits)
        ]
        
    def compute_forward(self) :
        super(LSTM, self).compute_forward()
        
    def compute_backward(self) :
        super(LSTM, self).compute_backward()
        
# The way this functions is defined assumes only 2 FAL and 1 LSTM layer. 
# Generalize this to any number of FAL and LSTM layers by modifying the constructor signature
class GCNLSTM(Node) :
    def __init__(self, d1, d2, d3, d4, unitstotal, idim, hdim, dim1, dim3) :
        super(GCNLSTM, self).__init__()
        self.d1 = d1
        self.d2 = d2 
        self.d3 = d3
        self.d4 = d4
        self.unitstotal = unitstotal
        self.idim = idim
        self.hdim = hdim
        self.dim1 = dim1
        self.dim3 = dim3
        
        self.components = [
            FAL(d1, d2, d3, d4),
            FAL(d1, d2, d4, d4),
            LSTM(unitstotal, idim, hdim, dim1, dim3),
            MatMul(dim1, hdim, d2),
            GemmAdd(dim1, d2),
            Sigmoid(dim1*d2)
        ]
        
    def compute_forward(self) :
        super(GCNLSTM, self).compute_forward()
        
    def compute_backward(self) :
        super(GCNLSTM, self).compute_backward()

## Get Cost

In [30]:
# Batch size 23
pass1 = GCNLSTM(23, 270, 3, 4, 4, 270, 4, 23, 16)
pass1.compute_forward()

# Batch size 10
pass2 = GCNLSTM(10, 270, 3, 4, 4, 270, 4, 10, 16)
pass2.compute_forward()

forward_cost = add_cost_dicts(pass1.fwd_costs, pass2.fwd_costs)

In [31]:
print(forward_cost)

{'add': 19538046, 'mul': 19603056, 'relu': 71280, 'sigmoid': 10494, 'mux': 0, 'sqrt': 0}
