In [25]:
import renn

import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp

from jax.experimental import stax, optimizers

import numpy as np

from tqdm import tqdm

from functools import partial

In [33]:
vocab_file = '/Users/nirum/code/reverse-engineering-neural-networks/vocab/ag_news.vocab'

# Load data
train_dset = renn.data.ag_news('train', vocab_file, sequence_length=50)
test_dset = renn.data.ag_news('test', vocab_file, sequence_length=50)

# Load vocab
with open(vocab_file, 'r') as f:
    vocab = f.readlines()
vocab_size = len(vocab)

example = next(iter(train_dset))

In [51]:
def SequenceSum():
    def init_fun(_, input_shape):
        return (input_shape[0], input_shape[2]), ()
    def apply_fun(_, inputs, **kwargs):
        return jnp.sum(inputs, axis=1)
    return init_fun, apply_fun

In [121]:
emb_size = 32
num_classes = 4
input_shape = (-1, 50)
l2_pen = 1e-3

# Linear model
init_fun, apply_fun = stax.serial(
    renn.embedding(vocab_size, emb_size),
    SequenceSum(),
    stax.Dense(num_classes),
    stax.LogSoftmax,
)

# Initialize
key = jax.random.PRNGKey(0)
output_shape, initial_params = init_fun(key, input_shape)

# Hack to set the embedding for 0 to 0
emb = params[0]
new_emb = np.array(emb)
new_emb[0] = np.zeros(emb_size)
initial_params = [jnp.array(new_emb), *params[1:]]

# Loss
def xent(params, batch):
    logits = apply_fun(params, batch['inputs'])
    targets = renn.one_hot(batch['labels'], num_classes)
    data_loss = -jnp.mean(jnp.sum(targets * logits, axis=1))
    reg_loss = l2_pen * renn.norm(params)
    return data_loss + reg_loss

f_df = jax.value_and_grad(xent)

# Accuracy
@jax.jit
def accuracy(params, batch):
    logits = apply_fun(params, batch['inputs'])
    predictions = jnp.argmax(logits, axis=1)
    return jnp.mean(predictions == batch['labels'])

In [134]:
# Train
learning_rate = optimizers.exponential_decay(2e-3, 1000, 0.8)
init_opt, update_opt, get_params = optimizers.adam(learning_rate)

state = init_opt(initial_params)
losses = []

@jax.jit
def step(k, opt_state, batch):
    params = get_params(opt_state)
    loss, gradients = f_df(params, batch)
    new_state = update_opt(k, gradients, opt_state)
    return new_state, loss

def test_acc(params):
    return jnp.array([accuracy(params, batch) for batch in tfds.as_numpy(test_dset)])

In [135]:
for epoch in range(3):
    print('=====================================')
    print(f'== Epoch #{epoch}')
    p = get_params(state)
    acc = np.mean(test_acc(p))
    print(f'== Test accuracy: {100. * acc:0.2f}%')
    print('=====================================')
    
    for batch in tfds.as_numpy(train_dset):
        k = len(losses)
        state, loss = step(k, state, batch)
        losses.append(loss)

        if k % 100 == 0:
            p = get_params(state)
            print(f'[step {k}]\tLoss: {np.mean(losses[k-100:k]):0.4f}', flush=True)

print('=====================================')
print(f'== Epoch #{epoch}')
p = get_params(state)
acc = np.mean(test_acc(p))
print(f'== Test accuracy: {100. * acc:0.2f}%')
print('=====================================')

== Epoch #0
== Test accuracy: 23.79%
[step 0]	Loss: nan
[step 100]	Loss: 0.7176
[step 200]	Loss: 0.3863
[step 300]	Loss: 0.3486
[step 400]	Loss: 0.3514
[step 500]	Loss: 0.3505
[step 600]	Loss: 0.3293
[step 700]	Loss: 0.3424
[step 800]	Loss: 0.3518
[step 900]	Loss: 0.3418
[step 1000]	Loss: 0.3438
[step 1100]	Loss: 0.3183
[step 1200]	Loss: 0.3243
[step 1300]	Loss: 0.3280
[step 1400]	Loss: 0.3340
[step 1500]	Loss: 0.3169
[step 1600]	Loss: 0.3031
[step 1700]	Loss: 0.3306
== Epoch #1
== Test accuracy: 90.74%
[step 1800]	Loss: 0.2796
[step 1900]	Loss: 0.2495
[step 2000]	Loss: 0.2194
[step 2100]	Loss: 0.2161
[step 2200]	Loss: 0.2212
[step 2300]	Loss: 0.2160
[step 2400]	Loss: 0.2262
[step 2500]	Loss: 0.2259
[step 2600]	Loss: 0.2244
[step 2700]	Loss: 0.2315
[step 2800]	Loss: 0.2324
[step 2900]	Loss: 0.2202
[step 3000]	Loss: 0.2260
[step 3100]	Loss: 0.2385
[step 3200]	Loss: 0.2200
[step 3300]	Loss: 0.2230
[step 3400]	Loss: 0.2283
== Epoch #2
== Test accuracy: 90.31%
[step 3500]	Loss: 0.2346
[ste