In [1]:
import numpy

In [2]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [3]:
from functools import partial

In [4]:
import random
import string

In [5]:
def rand_string():
    return ''.join([random.choice(string.ascii_letters + string.digits) for n in range(32)])

In [6]:
# @jax.jit
def split_and_sample(key, shape):
    key, subkey = jrng.split(key)
    val = jrng.normal(subkey, shape=shape)
    return key, val

In [110]:
class Layer:
    def __init__(self, name=None):
        if name is None:
            self.name = F'Layer+{rand_string()}'
        else:
            self.name = name
    
    def __call__(self, p, x):
        return self.forward(p, x)
        
    def params(self):
        return None
    
    def init_params(self, rng):
        return rng, self.params()
    
    def forward(self, p, x):
        return x

In [111]:
class Linear(Layer):
    def __init__(self, d_in, d_out, name=None):
        super(Linear, self).__init__(name)
        
        self.weight = jnp.zeros((d_in, d_out))
        self.bias = jnp.zeros((d_out))
        
        if name is None:
            self.name = F'Linear+{rand_string()}'
        
    def params(self):
        return dict([('weight', self.weight), ('bias', self.bias)])
    
    def init_params(self, rng):
        rng, self.weight = split_and_sample(rng, self.weight.shape)
        return rng, self.params()
    
    def forward(self, p, x):
        return jnp.dot(x, p['weight']) + p['bias']

In [135]:
class Tanh(Layer):
    def __init__(self, name=None):
        super(Tanh, self).__init__(name)
        
        if name is None:
            self.name = F'Tanh+{rand_string()}'
            
    def forward(self, p, x):
        return jnp.tanh(x)

In [136]:
class Softmax(Layer):
    def __init__(self, name=None):
        super(Softmax, self).__init__(name)
        
        if name is None:
            self.name = F'Softmax+{rand_string()}'
            
    def forward(self, p, x):
        x_exp = jnp.exp(x)
        return x_exp / jnp.sum(x_exp)

In [185]:
class Model:
    def __init__(self, rng, layers, loss=None, name=None):        
        if name is None:
            name = F'Model+{rand_string()}'
            
        self.layers = layers
        self.loss = loss
            
        self.params = dict()
        for ll in self.layers:
            rng, pp = ll.init_params(rng)
            if pp is not None:
                self.params[ll.name] = pp
        self.params_values, self.params_tree = jax.tree_flatten(self.params)

    @partial(jax.jit, static_argnums=(0,))
    def forward_(self, p, x):
        h = x
        for ll in self.layers:
            h = ll(None if ll.name not in p else p[ll.name], h)
        return h    
    
    @partial(jax.jit, static_argnums=(0,))
    def loss_(self, p, x, y):
        def dummy(mymodel, params, x, y):
            return mymodel.loss(x, y)
        return jax.vmap(dummy, in_axes=(None,None,0,0))(self, self.params, self.forward_(p, x), y).mean()
    
    def forward(self, x, single=False):
        if single:
            return self.forward_(self.params, x)
        
        def dummy(mymodel, params, x):
            return mymodel.forward_(params, x)
        return jax.vmap(dummy, in_axes=(None, None, 0))(self, self.params, x)
    
    def loss_grad(self, x, y):
        return self.loss_(self.params, x, y), jax.grad(self.loss_)(self.params, x, y)

In [186]:
class SGD:
    def __init__(self, model, lr=0.01):
        self.lr = lr
        self.model = model
        
    def step(self, grad):
        for ll in self.model.layers:
            if ll.name not in self.model.params:
                continue
            pp = self.model.params[ll.name]
            gg = grad[ll.name]
            for kk in pp.keys():
                pp[kk] = pp[kk] - self.lr * gg[kk]

In [187]:
@jax.jit
def cross_entropy(p, y):
    return -jnp.take(jnp.log(p), y)

In [188]:
rng = jrng.PRNGKey(1234)

In [182]:
mymodel = Model(rng, [Linear(10,10), Tanh(), Linear(10,10), Softmax()], loss=cross_entropy)

In [189]:
optim = SGD(mymodel, lr=0.1)

In [190]:
target_labels = numpy.floor(numpy.random.rand(256)).astype('int')
inputs = numpy.random.rand(256,10)

for ii in range(1000):
    loss, grad = mymodel.loss_grad(inputs, target_labels)
    
    if numpy.mod(ii, 100) == 0:
        print(loss)
    optim.step(grad)

5.7819085


NameError: name 'lr' is not defined