In [1]:
!pip install git+https://github.com/deepmind/dm-haiku
!pip install optax

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-hqx0kbqk
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-hqx0kbqk
Collecting jmp>=0.0.2
  Downloading https://files.pythonhosted.org/packages/ff/5c/1482f4a4a502e080af2ca54d7f80a60b5d4735f464c151666d583b78c226/jmp-0.0.2-py3-none-any.whl
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.5.dev0-cp37-none-any.whl size=553003 sha256=057e07bbbf1371edd36ca02e258d81297f1ad6228c43f3649d60853f9c38a871
  Stored in directory: /tmp/pip-ephem-wheel-cache-wy6itgwv/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499
Successfully built dm-haiku
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.5.dev0 jmp-0.0.2
Collecting optax
[?25l  Downloading https://files.pythonhosted.org/packages/ec/

In [37]:
!git clone https://github.com/chao1224/BadGlobalMinima

Cloning into 'BadGlobalMinima'...
remote: Enumerating objects: 29, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 29 (delta 1), reused 4 (delta 0), pack-reused 20[K
Unpacking objects: 100% (29/29), done.


In [69]:
import haiku as hk
import optax
import jax
import jax.numpy as jnp
import tree
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import NamedTuple

AUGMENTATION = True
ADVERSARIAL = True

In [70]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [71]:
key1 = jax.random.PRNGKey(0)

if (ADVERSARIAL):
    y_train = jax.random.permutation(key1, y_train)

In [72]:
# (train_ds, test_ds), ds_info = tfds.load('cifar10', 
#                                           split=['train', 'test'], 
#                                           shuffle_files=True, 
#                                           with_info=True)

train_ds = tf.data.Dataset.from_tensor_slices({"image": x_train, "label": y_train})


# fig = tfds.show_examples(train_ds, ds_info)

In [73]:
def preprocess(example):
    image, label = example['image'], example['label']

    # Data augmentation
    image = tf.image.resize_with_crop_or_pad(image, 170, 170) # Adiciona 10 pixels
    image = tf.image.random_crop(image, size=[160, 160, 3]) # Corta de volta para 160
    image = tf.image.random_flip_left_right(image) 

    image = tf.cast(image, tf.float32)
    return {'image': image, 'label': label}

In [74]:
if (AUGMENTATION):
    train_ds = train_ds.map(
        preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_ds = train_ds.cache()
train_ds = train_ds.shuffle(ds_info.splits['train'].num_examples)
train_ds = train_ds.batch(128)
ds_numpy = tfds.as_numpy(train_ds)

Tensor("args_0:0", shape=(32, 32, 3), dtype=uint8) Tensor("args_1:0", shape=(1,), dtype=uint8)


In [None]:
def _forward(batch, is_training):
  """Forward application of the resnet."""
  images = batch['images']
  net = hk.nets.ResNet18(10,
                         resnet_v2=True,
                         bn_config={'decay_rate': 0.9})
  return net(images, is_training=is_training)

# Transform our forwards function into a pair of pure functions.
forward = hk.transform_with_state(_forward)

In [None]:
def make_optimizer():
  """SGD with momentum and a fixed lr."""
  return optax.chain(
      optax.trace(decay=0.9, nesterov=False), #momentum
      optax.scale(-1e-3))

In [None]:
def l2_loss(params):
  return 0.5 * sum(jnp.sum(jnp.square(p)) for p in params)

In [None]:
class TrainState(NamedTuple):
  params: hk.Params
  state: hk.State
  opt_state: optax.OptState

In [None]:
def loss_fn(params, state, batch):
  """Computes a regularized loss for the given batch."""
  logits, state = forward.apply(params, state, None, batch, is_training=True)
  labels = jax.nn.one_hot(batch['labels'], 10)
  loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
  l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
               if 'batchnorm' not in mod_name]
  loss = loss + 1e-4 * l2_loss(l2_params)
  return loss, (loss, state)

In [None]:
@jax.jit
def train_step(train_state, batch):
  """Applies an update to parameters and returns new state."""
  params, state, opt_state = train_state
  grads, (loss, new_state) = (
      jax.grad(loss_fn, has_aux=True)(params, state, batch))
  
  # Compute and apply updates via our optimizer.
  updates, new_opt_state = make_optimizer().update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)

  train_state = TrainState(new_params, new_state, new_opt_state)
  return train_state, loss

In [None]:
def initial_state(rng, batch):
  """Computes the initial network state."""
  params, state = forward.init(rng, batch, is_training=True)
  opt_state = make_optimizer().init(params)
  return TrainState(params, state, opt_state)

In [None]:
epochs = 5
rng = jax.random.PRNGKey(0)
batch = next(iter(ds_numpy))
train_state = initial_state(rng, batch)

In [None]:
for _ in range(epochs):
  total_losses = []
  for batch in ds_numpy:
    train_state, loss = train_step(train_state, batch)
    total_losses.append(loss)
  print(sum(total_losses)/len(total_losses))

1.484259
1.2876029
1.1279528
0.98609024
0.8480268
