In [1]:
import kerax
from kerax.layers import Concatenate
from kerax.layers import Conv2D, Dense, Flatten, Input, BatchNormalization, Activation, Add
from kerax.models import Model
from kerax.losses import CategoricalCrossEntropy
import numpy as np
import jax.numpy as jnp
from kerax import activations
# Switch on kerax's JIT execution flag up front, before any layer below is built.
# NOTE(review): presumably this routes layer calls through jax.jit — confirm against the kerax docs.
kerax.enable_jit_execution(True)

In [2]:
# Synthetic stand-in data: 128 single-channel 28x28 "images" and 128 rows of
# 10 target values, all drawn uniformly from [0, 1). A seeded generator is used
# so a fresh kernel reproduces exactly the same arrays on every run.
# NOTE(review): the targets are uniform noise, not one-hot / probability rows;
# that is enough for shape checks but not for a meaningful loss value.
DATA_SEED = 0
data_rng = np.random.default_rng(DATA_SEED)

train_x = data_rng.random((128, 28, 28, 1))
train_y = data_rng.random((128, 10))

val_x = data_rng.random((128, 28, 28, 1))
val_y = data_rng.random((128, 10))

In [3]:
# Functional-style model: a 28x28x1 input and a stem convolution, followed by two
# branches that both start from conv1 and are merged by Add ahead of the dense head.
# Statement order is left untouched: the layer names recorded later in this
# notebook (Conv2D_1 ... Conv2D_4) appear to follow construction order.
inputs = Input((28, 28, 1))
# NOTE(review): this layer passes the activation as `activations.ReLU`, the others
# pass the string 'relu' — confirm both spellings are equivalent in kerax.
conv1 = Conv2D(64, 3, activation=activations.ReLU)(inputs)
# Branch A: one 'same'-padded convolution applied directly to conv1.
conv2 = Conv2D(128, 3, activation='relu', padding='same')(conv1)
# Branch B: relu -> conv -> relu -> conv, also starting from conv1.
# NOTE(review): conv1 was already built with a ReLU activation, so act1 applies
# relu a second time — confirm this is intended.
act1 = Activation('relu')(conv1)
conv3 = Conv2D(128, 3, padding='same')(act1)
act2 = Activation('relu')(conv3)
conv4 = Conv2D(128, 3, padding='same')(act2)
# Merge point of the two branches (Add is called with a list of both branch ends).
add = Add()([conv2, conv4])
flatten = Flatten()(add)
dense1 = Dense(512, activation='relu')(flatten)
# 10-unit softmax head; 10 matches the second dimension of train_y / val_y.
dense2 = Dense(10, activation='softmax')(dense1)


In [11]:
from collections import namedtuple, deque

class Graph:
    """Collects the layers reachable from a model's inputs/outputs and their edges.

    Layers are duck-typed. This class only relies on each layer exposing:
    ``index`` (sortable id), ``next`` (iterable of downstream layers),
    ``prev`` (iterable of upstream layers), ``output`` (``None`` until the
    layer has been evaluated) and on the layer being callable with data.

    Keyword arguments:
        input / inputs:   a single input layer, or a (possibly nested) list or
                          tuple of input layers.
        output / outputs: a single output layer, or a (possibly nested) list or
                          tuple of output layers.

    Attributes:
        input, output:     the single layer that was passed, or ``None`` when a
                           sequence was passed instead.
        inputs, outputs:   flat lists of input / output layers.
        layers:            every discovered layer, sorted by ``index``.
        connected_layers:  ``Layer(layer1, layer2)`` named tuples, one per
                           ``layer1 -> layer2`` edge.

    Raises:
        ValueError: when inputs or outputs are missing, or are passed under the
            wrong keyword (sequence under the singular name or vice versa).
    """

    def __init__(self, **kwargs):
        self._validate_init(**kwargs)
        self.connected_layers = []
        self.connection = namedtuple('Layer', ['layer1', 'layer2'])
        self.layers = []
        self.connect_layers()

    def flatten(self, x):
        """Return a flat list with the leaves of an arbitrarily nested list/tuple.

        Order is preserved. Every element is visited: a nested sequence no
        longer ends the traversal early, so items following it are kept.
        """
        flat = []

        def _walk(items):
            for item in items:
                if isinstance(item, (list, tuple)):
                    _walk(item)
                else:
                    flat.append(item)

        _walk(x)
        return flat

    @staticmethod
    def _is_given(value):
        """Tell whether a keyword value counts as 'provided'.

        Sequences count when non-empty; anything else counts unless it is
        ``None``/``False``. This avoids relying on the truthiness of layer objects.
        """
        if isinstance(value, (list, tuple)):
            return len(value) > 0
        return value is not None and value is not False

    def _resolve_endpoint(self, kwargs, singular, plural):
        """Validate one singular/plural keyword pair.

        Returns ``(single_layer_or_None, flat_list_of_layers)``.
        """
        single = kwargs.get(singular)
        many = kwargs.get(plural)

        if self._is_given(many) and not isinstance(many, (list, tuple)):
            raise ValueError(
                f"Use '{singular}' argument instead of '{plural}' "
                'if you want to pass a single layer'
            )
        if self._is_given(single) and isinstance(single, (list, tuple)):
            raise ValueError(
                f"Use '{plural}' argument instead of '{singular}' "
                'if you want to pass a list or a tuple'
            )

        # The plural keyword wins when both are supplied.
        if self._is_given(many):
            flat = self.flatten(many)
            if not flat:
                # e.g. [[]]: a non-empty container that holds no layers at all.
                raise ValueError(f'{plural} should be provided')
            return None, flat
        if self._is_given(single):
            return single, [single]
        raise ValueError(f'{plural} should be provided')

    def _validate_init(self, **kwargs):
        """Check the constructor keywords and populate input(s) / output(s)."""
        self.input, self.inputs = self._resolve_endpoint(kwargs, 'input', 'inputs')
        self.output, self.outputs = self._resolve_endpoint(kwargs, 'output', 'outputs')

    def get_layers(self):
        """Breadth-first discovery of every layer reachable through ``next``.

        The search is seeded with the input and output layers. ``self.layers``
        is rebuilt from scratch, so repeated calls do not accumulate duplicates.
        """
        seeds = self.flatten(self.inputs) if self.inputs else [self.input]
        seeds += self.flatten(self.outputs) if self.outputs else [self.output]

        visited = set()
        queue = deque()
        discovered = []

        def _visit(layer):
            # A layer may be reachable by several paths, or be listed as both an
            # input and an output; it is recorded only once.
            if layer.index not in visited:
                visited.add(layer.index)
                discovered.append(layer)
                queue.append(layer)

        for layer in seeds:
            _visit(layer)

        while queue:
            for successor in queue.popleft().next:
                _visit(successor)

        self.layers = discovered

    def connect_layers(self):
        """Sort the discovered layers by ``index`` and rebuild the edge list."""
        self.get_layers()
        self.layers = sorted(self.layers, key=lambda x: x.index)

        # Rebuilt (not extended) so that calling this again stays idempotent.
        self.connected_layers = [
            self.connection(layer, successor)
            for layer in self.layers
            for successor in layer.next
        ]

    def check_dependencies(self, layer):
        """Return True when every layer in ``layer.prev`` already has an output."""
        return all(p.output is not None for p in layer.prev)

    def have_dependencies(self, layer):
        """Return True when ``layer`` has at least one upstream layer."""
        return len(layer.prev) != 0

    def same_input_len(self, inputs):
        """Return True when ``inputs`` has one entry per graph input layer."""
        return len(inputs) == len(self.inputs)

    def flow_data(self, *args):
        """Feed one positional argument per input layer and propagate along the edges.

        Raises:
            ValueError: when the number of arguments differs from the number of
                input layers.
        """
        if not self.same_input_len(args):
            raise ValueError(f'Not the same input length expected {len(self.inputs)} found {len(args)}')

        for arg, input_layer in zip(args, self.inputs):
            input_layer(arg)

        for edge in self.connected_layers:
            if edge.layer1.output is not None:
                # Hand over the upstream layer's computed output, not the layer
                # object itself.
                # NOTE(review): a layer with several predecessors is still invoked
                # once per incoming edge with a single value — confirm what
                # multi-input layers expect here.
                edge.layer2(edge.layer1.output)

                        


In [12]:
# Build the graph from the single input layer and the final dense layer; the
# singular keywords are used because one layer (not a list) is passed for each.
g = Graph(input=inputs, output=dense2)

In [13]:
# Display the discovered layers (Graph keeps this list sorted by each layer's `index`).
g.layers

[<Input Layer>,
 <Convolutional Layer with input shape (None, 28, 28, 1) and output shape (None, 26, 26, 64)>,
 <Convolutional Layer with input shape (None, 26, 26, 64) and output shape (None, 26, 26, 128)>,
 <relu Activation Layer with input shape (None, 26, 26, 64) and output shape (None, 26, 26, 64)>,
 <Convolutional Layer with input shape (None, 26, 26, 64) and output shape (None, 26, 26, 128)>,
 <relu Activation Layer with input shape (None, 26, 26, 128) and output shape (None, 26, 26, 128)>,
 <Convolutional Layer with input shape (None, 26, 26, 128) and output shape (None, 26, 26, 128)>,
 <Add Layer with input shape (None, 26, 26, 128) and output shape (None, 26, 26, 128)>,
 <Flatten Layer with input shape (None, 26, 26, 128) and output shape (None, 86528)>,
 <Dense Layer with input shape (None, 86528) and output shape (None, 512)>,
 <Dense Layer with input shape (None, 512) and output shape (None, 10)>]

In [18]:
# Feed the batch into the input layer, then walk the edges in order and evaluate
# a target layer as soon as everything it depends on has produced an output.
inputs(train_x)
for edge in g.connected_layers:
    if edge.layer1.output is None:
        continue
    # Readiness is decided by the target's predecessors (`prev`). Looking at its
    # successors (`next`) instead would block every layer on a fresh pass, since
    # nothing downstream has an output yet.
    # NOTE(review): a layer with several predecessors still receives only the
    # output of the edge that completed its dependencies — confirm what
    # multi-input layers such as Add expect.
    if all(dep.output is not None for dep in edge.layer2.prev):
        edge.layer2(edge.layer1.output)
        

True
True


In [20]:
# One "<source name> <target name>" line per edge, in stored edge order.
for source, target in g.connected_layers:
    print(source.name, target.name)

Input_1 Conv2D_1
Conv2D_1 Conv2D_2
Conv2D_1 Activation_1
Conv2D_2 Add_1
Activation_1 Conv2D_3
Conv2D_3 Activation_2
Activation_2 Conv2D_4
Conv2D_4 Add_1
Add_1 Flatten_1
Flatten_1 Dense_1
Dense_1 Dense_2


In [26]:
# Feed the original batch into the model's input layer again.
inputs(train_x)

# NOTE(review): `inputs2` is not defined in any cell visible in this notebook, so
# this cell depends on hidden kernel state and will not survive Restart & Run All.
# The exception recorded under this cell reports an expected shape of
# (None, 28, 28, 1) against this (128, 26, 26, 1) batch.
train_x2 = np.random.random((128, 26,26,1))
inputs2(train_x2)
# Re-evaluate every layer after the first one, in the order stored in g.layers,
# skipping any further Input layers.
for i in g.layers[1:]:
    if isinstance(i, Input):
        continue
    # NOTE(review): only the first predecessor's output is forwarded, so a layer
    # with several predecessors (e.g. Add) sees a single branch — confirm intent.
    i(i.prev[0].output)


Exception: Not expected shape, input dims should be (None, 28, 28, 1) found (128, 26, 26, 1)

(128, 10)