In [1]:
from discopy import *
from discopy.function import *

In [2]:
def scalar_mult(scalar):
    """Build a one-wire box that multiplies its input by ``scalar``.

    The result is a ``Function`` from ``Dim(1)`` to ``Dim(1)`` whose name
    embeds ``repr(scalar)``.
    """
    label = 'scalar_mult({})'.format(repr(scalar))

    def multiply(x):
        return scalar * x

    return Function(label, Dim(1), Dim(1), multiply)
    
def scalar_mults(dom, weights):
    """Tensor together one ``scalar_mult`` box per input wire.

    Starts from ``Id(0)`` and appends ``scalar_mult(weights[k])`` with ``@``
    for each ``k`` in ``range(dom)``, so only the first ``dom`` weights are read.
    """
    boxes = [scalar_mult(weights[k]) for k in range(dom)]
    diagram = Id(0)
    for box in boxes:
        diagram = diagram @ box
    return diagram
   
def bias(scalar):
    """Build a box with no input wires that emits the constant ``scalar``.

    The result is a ``Function`` from ``Dim(0)`` to ``Dim(1)``; its argument
    is ignored and a one-element array holding ``scalar`` is returned.
    """
    label = 'bias({})'.format(repr(scalar))

    def constant(_x):
        return np.array([scalar])

    return Function(label, Dim(0), Dim(1), constant)

def merge(cod, copies):
    """Build a box that adds ``copies`` consecutive chunks of length ``cod``.

    The box goes from ``cod * copies`` wires to ``cod`` wires: output ``i`` is
    the sum of the inputs at positions ``i + cod * j`` for ``j`` in
    ``range(copies)``.
    """
    label = 'merge({}, {})'.format(cod, copies)

    def add(x):
        totals = []
        for i in range(cod):
            # Gather the i-th entry of every chunk before summing.
            strided = [x[i + cod * j] for j in range(copies)]
            totals.append(np.sum(strided))
        return np.array(totals)

    return discofunc(cod * copies, cod, name=label)(add)

def split(dom, copies):
    """Build a box that duplicates its input ``copies`` times.

    The box goes from ``dom`` wires to ``dom * copies`` wires and returns the
    input concatenated with itself ``copies`` times.
    """
    label = 'split({}, {})'.format(dom, copies)

    def copy(x):
        repeated = [x] * copies
        return np.concatenate(repeated)

    return discofunc(dom, dom * copies, name=label)(copy)

@discofunc(1, 1)
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, wrapped by ``discofunc(1, 1)``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

In [3]:
def neuron(dom, cod, weights, beta=0):
    """Assemble a single neuron as a composite diagram.

    ``weights`` is a 1d array of length ``dom`` and ``beta`` is a scalar bias.
    The weighted inputs and the bias are tensored together, summed by
    ``merge(1, dom + 1)``, passed through ``sigmoid`` and finally copied onto
    ``cod`` output wires by ``split(1, cod)``.
    """
    weighted_inputs = scalar_mults(dom, weights) @ bias(beta)
    summed = weighted_inputs >> merge(1, dom + 1)
    activated = summed >> sigmoid
    return activated >> split(1, cod)

In [13]:
# Display the diagram of a neuron with 4 inputs, 4 output copies and bias 0.5.
neuron(dom=4, cod=4, weights=[0, 2.1, 0.3, 0.1], beta=0.5)

(((((((scalar_mult(0) @ scalar_mult(2.1)) @ scalar_mult(0.3)) @ scalar_mult(0.1)) @ bias(0.5)) >> merge(1, 5)) >> sigmoid) >> split(1, 4))

In [14]:
def layer(dom, cod, weights, biases):
    """Build a layer of ``cod`` single-output neurons sharing ``dom`` inputs.

    ``weights`` is an array of size cod x dom (cod = number of neurons) and
    ``biases`` is a 1d array of length cod.  ``split(dom, cod)`` first makes
    one copy of the input per neuron; neuron ``k`` is built from
    ``weights[k]`` and ``biases[k]``, and the neurons are tensored together.
    """
    units = [neuron(dom, 1, weights[k], biases[k]) for k in range(cod)]
    stacked = Id(0)
    for unit in units:
        stacked = stacked @ unit
    return split(dom, cod) >> stacked

In [5]:
from jax import jit, grad

# Evaluate a 4-input neuron with 4 output copies through jit.
sample_input = np.array([0., 0.3, 1.2, 3.2])
fan_out_neuron = neuron(4, 4, [0, 2.1, 0.3, 0.1])
print(jit(fan_out_neuron)(sample_input))


def single_output(x):
    """First (and only) output of the 4-input, 1-output neuron."""
    return neuron(4, 1, [0., 2.1, 0.3, 0.1])(x)[0]


# Gradient of that scalar output with respect to the input vector.
print(grad(single_output)(sample_input))

[0.7875132 0.7875132 0.7875132 0.7875132]
[0.         0.35140604 0.05020086 0.01673362]


In [12]:
def disconnected_layer(x):
    """Output of a one-neuron, three-input layer whose weights are all zero.

    ``layer`` is documented as taking ``biases`` of length ``cod``; with
    ``cod = 1`` that is a single bias, so exactly one is passed here.
    """
    return layer(3, 1, [[0., 0., 0.]], [0.])(x)[0]


# With every weight at zero the output does not depend on the input, so the
# gradient with respect to the input is expected to be exactly zero.
assert np.all(
    grad(disconnected_layer)(np.array([2., 3.4, 1.])) == np.array([0., 0., 0.])
), "gradient of a zero-weight layer should vanish"