In [None]:
%load_ext autoreload
%autoreload 2

from generative import dataset, utils, models
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint
from flax.training import orbax_utils
from tqdm import tqdm
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct 


In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
  """Collection of metrics carried inside the train state.

  Holds a single `loss` entry. `train_step` builds a per-batch update with
  `single_from_model_output(loss=...)` and merges it into the running
  collection; the training loop reads it back with `compute()`.
  """
  # Declared via metrics.Average.from_output('loss'), i.e. keyed on the
  # 'loss' keyword passed to single_from_model_output.
  loss: metrics.Average.from_output('loss')

class TrainState(train_state.TrainState):
  """flax `train_state.TrainState` extended with a `metrics` field."""
  # Running Metrics collection; swapped out as a whole with
  # state.replace(metrics=...) rather than mutated in place.
  metrics: Metrics

# Generator network instance shared by the cells below.
model = models.GeneratorLinear()
# Template input handed to module.init when the parameters are created.
# The model is later applied to `labels`, so this presumably stands in for a
# batch of one label row of width 9 -- TODO confirm against dataset.mnist_colored.
dummy_input = jnp.zeros((1, 9))

# Initialize the parameters, optimizer, etc.
def create_train_state(module, rng, learning_rate, dummy_input):
  """Build a fresh `TrainState` for `module`.

  Args:
    module: object exposing `init` and `apply` (a flax module in this notebook).
    rng: PRNG key used for the 'params' collection during initialization.
    learning_rate: step size handed to `optax.adam`.
    dummy_input: template input passed to `module.init`; only used to create
      the parameters.

  Returns:
    A `TrainState` holding the initial parameters, an Adam optimizer and an
    empty `Metrics` collection.
  """
  variables = module.init({'params': rng}, dummy_input)
  return TrainState.create(
      apply_fn=module.apply,
      params=variables['params'],
      tx=optax.adam(learning_rate),
      metrics=Metrics.empty(),
  )

@jax.jit
def train_step(state, images, labels):
    """Apply one optimizer update and record the batch loss.

    The model is evaluated on `labels` and its output is compared with
    `images` using the mean of `optax.l2_loss`.

    Args:
        state: current TrainState (provides apply_fn, params, metrics).
        images: target batch the model output is compared against.
        labels: batch fed to the model as input.

    Returns:
        The TrainState after the gradient step, with the batch loss merged
        into its metrics.
    """
    def mean_l2(params):
        # Forward pass with the candidate parameters.
        predicted = state.apply_fn({'params': params}, labels)
        return optax.l2_loss(predicted, images).mean()

    # Loss value and gradients with respect to the current parameters.
    loss, grads = jax.value_and_grad(mean_l2)(state.params)
    updated = state.apply_gradients(grads=grads)

    # Fold this batch's loss into the running metrics collection.
    batch_metrics = updated.metrics.single_from_model_output(loss=loss)
    return updated.replace(metrics=updated.metrics.merge(batch_metrics))

# Evaluation.
def pred_step(state, rng, batch_size):
    """Run the model on a freshly sampled batch of labels.

    Labels are drawn with `dataset.mnist_colored`; the images returned
    alongside them are discarded.

    Args:
        state: train state exposing `apply_fn` and `params`.
        rng: PRNG key forwarded to `dataset.mnist_colored`.
        batch_size: number of samples requested from the dataset helper.

    Returns:
        The output of `state.apply_fn` for the sampled labels.
    """
    _unused_images, labels = dataset.mnist_colored(
        dataset.mnist_char_set(), rng, batch_size=batch_size)
    return state.apply_fn({'params': state.params}, labels)


In [None]:
# TRAINING LOOP

# Run settings, named here so the loop below carries no magic numbers.
LEARNING_RATE = 3e-4
NUM_STEPS = 10 * 1000
TRAIN_BATCH_SIZE = 64
LOG_EVERY = 250      # steps between metric reports
PREVIEW_EVERY = 500  # steps between sample plots

global_rng = jax.random.PRNGKey(0)
sub_rng, global_rng = jax.random.split(global_rng)
state = create_train_state(model, sub_rng, LEARNING_RATE, dummy_input=dummy_input)

char_set = dataset.mnist_char_set()

for step in tqdm(range(NUM_STEPS)):
    # Split off a dedicated key for this batch; global_rng carries on.
    sub_rng, global_rng = jax.random.split(global_rng)
    images, labels = dataset.mnist_colored(char_set, sub_rng, batch_size=TRAIN_BATCH_SIZE)

    # Run a training step.
    state = train_step(state, images, labels)

    if step % LOG_EVERY == 0:
        # Report the metrics accumulated since the last report, tagged with
        # the real step index. The tag used to be step//50, which matched
        # neither the step nor the 250-step reporting interval.
        for metric, value in state.metrics.compute().items():
            tqdm.write(f'[{step}] {metric}: {value:.4f}')
        # Start a fresh accumulation window.
        state = state.replace(metrics=state.metrics.empty())

    if step % PREVIEW_EVERY == 0:
        # Plot a few generated samples to eyeball progress.
        sub_rng, global_rng = jax.random.split(global_rng)
        model_output = pred_step(state, sub_rng, batch_size=32)
        utils.plot_image_row(model_output[:8])

# Persist the final train state with Orbax.
ckpt = {'model': state}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
# force=True lets a re-run overwrite the checkpoint left by a previous run;
# without it save() raises because the destination directory already exists,
# which would abort Restart & Run All only after the whole training finished.
orbax_checkpointer.save('saves/oracle/', ckpt, save_args=save_args, force=True)

In [None]:
# Charts for Generated Data

char_set = dataset.mnist_char_set()

# All tables are indexed [color, char]; the (5, 6) grid presumably means
# 5 colors x 6 characters -- TODO confirm against dataset.categorize.
table_shape = (5, 6)
frequency_table = jnp.zeros(table_shape)
error_table = jnp.zeros(table_shape)
# One example per grid cell; trailing (28, 28, 3) is assumed to be the shape
# of a single generated image -- TODO confirm.
examples_table = jnp.zeros(table_shape + (28, 28, 3))

for i in tqdm(range(100)):
    # Generate a batch of 128 samples and categorize each of them.
    sub_rng, global_rng = jax.random.split(global_rng)
    images = pred_step(state, sub_rng, batch_size=128)
    chars, colors, error = dataset.categorize(char_set, images)

    # Same (color, char) index for all three tables.
    cell = (colors, chars)
    # .at[...].add accumulates repeated indices, so every sample is counted.
    frequency_table = frequency_table.at[cell].add(1)
    error_table = error_table.at[cell].add(error)
    # Keeps one image per cell; later batches overwrite earlier ones.
    examples_table = examples_table.at[cell].set(images)

utils.plot_frequency_table(frequency_table)
utils.plot_error_table(error_table, frequency_table)
utils.plot_examples_table(examples_table)