In [2]:
!python3 -m pip install git+https://github.com/deepmind/dm-haiku
!python3 -m pip install optax

Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-hs1u0f9t
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-hs1u0f9t
Defaulting to user installation because normal site-packages is not writeable


In [37]:
# Fetch the BadGlobalMinima repository into the current working directory.
# NOTE(review): `git clone` refuses to clone into an existing non-empty
# directory, so re-running this cell reports an error instead of updating.
!git clone https://github.com/chao1224/BadGlobalMinima

Cloning into 'BadGlobalMinima'...
remote: Enumerating objects: 29, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 29 (delta 1), reused 4 (delta 0), pack-reused 20[K
Unpacking objects: 100% (29/29), done.


In [4]:
!pip install tensorflow-datasets

Defaulting to user installation because normal site-packages is not writeable
Collecting tensorflow-datasets
  Downloading tensorflow_datasets-4.3.0-py3-none-any.whl (3.9 MB)
[K     |████████████████████████████████| 3.9 MB 22.7 MB/s eta 0:00:01
Collecting promise
  Downloading promise-2.3.tar.gz (19 kB)
Collecting dill
  Downloading dill-0.3.3-py2.py3-none-any.whl (81 kB)
[K     |████████████████████████████████| 81 kB 1.6 MB/s  eta 0:00:01
Collecting importlib-resources
  Downloading importlib_resources-5.1.4-py3-none-any.whl (26 kB)
Collecting tensorflow-metadata
  Downloading tensorflow_metadata-1.0.0-py3-none-any.whl (48 kB)
[K     |████████████████████████████████| 48 kB 923 kB/s  eta 0:00:01
Building wheels for collected packages: promise
  Building wheel for promise (setup.py) ... [?25ldone
[?25h  Created wheel for promise: filename=promise-2.3-py3-none-any.whl size=21494 sha256=838c679b33bf5fbf2f59acf114f5cee0d4407894b2bb59b493d94463fab1161e
  Stored in directory: /home/b

In [1]:
import haiku as hk
import optax
import jax
import jax.numpy as jnp
import tree
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import NamedTuple
from resnet import ResNet18

# Experiment switches read by later cells.
AUGMENTATION = True  # when True, `preprocess` is mapped over the training dataset
ADVERSARIAL = True  # when True, the training labels are randomly permuted

ModuleNotFoundError: No module named 'tensorflow'

In [6]:
# Load CIFAR-10 through Keras as (images, labels) pairs for the train and
# test splits; the archive is downloaded on first use.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [7]:
# Fixed PRNG key, so the label permutation below is the same on every fresh run.
key1 = jax.random.PRNGKey(0)

# With ADVERSARIAL set, shuffle the training labels so that they no longer
# line up with their images (the images themselves are left untouched).
# NOTE(review): `y_train` is overwritten, so re-running this cell permutes the
# already-permuted labels again; restart from the loading cell for a clean state.
if (ADVERSARIAL):
    y_train = jax.random.permutation(key1, y_train)



In [27]:
# We need ds_info later: it supplies the training-set size that is used as
# the shuffle buffer size when the input pipeline is built.
(train_ds, test_ds), ds_info = tfds.load('cifar10', 
                                         split=['train', 'test'], 
                                         shuffle_files=True, 
                                         with_info=True)

# Replace the TFDS training split with a dataset built from the in-memory
# arrays, so that `y_train` as modified by earlier cells is what gets trained on.
train_ds = tf.data.Dataset.from_tensor_slices({"image": x_train, "label": y_train})


# fig = tfds.show_examples(train_ds, ds_info)

In [28]:
def preprocess(example, pad=10):
    """Randomly augments one training example and casts its image to float32.

    The image is zero-padded by `pad` pixels in total along each spatial
    dimension (centred), randomly cropped back to its original height and
    width, and randomly flipped left-right.

    The crop size is taken from the input image rather than hard-coded, so the
    augmentation preserves the resolution of the data it is applied to. With
    the default `pad` a 160x160 image is padded to 170x170 and cropped back to
    160x160; a 32x32 image is padded to 42x42 and cropped back to 32x32
    instead of being blown up to a mostly-empty 160x160 canvas.

    Args:
      example: dict with an 'image' tensor of shape (height, width, channels)
        and a 'label' tensor.
      pad: total number of pixels added to the height and to the width before
        the random crop.

    Returns:
      A dict with the augmented float32 'image' and the unchanged 'label'.
    """
    image, label = example['image'], example['label']

    # Use the static dimensions when they are known and fall back to the
    # dynamic shape otherwise, so the output keeps a static shape if possible.
    dynamic_shape = tf.shape(image)
    height, width, channels = [
        dim if dim is not None else dynamic_shape[axis]
        for axis, dim in enumerate(image.shape.as_list())
    ]

    # Data augmentation
    image = tf.image.resize_with_crop_or_pad(image, height + pad, width + pad)  # add `pad` pixels
    image = tf.image.random_crop(image, size=[height, width, channels])  # crop back to the input size
    image = tf.image.random_flip_left_right(image)

    image = tf.cast(image, tf.float32)
    return {'image': image, 'label': label}

In [29]:
# Build the training input pipeline: cache -> shuffle -> per-example map -> batch.
#
# The cache and the shuffle buffer come BEFORE the random augmentation:
#  * caching after a random map would store one fixed augmented copy of every
#    image and replay it on each epoch, which silently disables augmentation;
#  * buffering the raw examples instead of the augmented float32 images keeps
#    the cache and the shuffle buffer as small as the source data.
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(ds_info.splits['train'].num_examples)

if AUGMENTATION:
    train_ds = train_ds.map(
        preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
else:
    # `preprocess` also casts images to float32; do the same cast here so the
    # network receives float32 images with or without augmentation.
    train_ds = train_ds.map(
        lambda example: {'image': tf.cast(example['image'], tf.float32),
                         'label': example['label']},
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_ds = train_ds.batch(128)
# Iterable of batches as dicts of NumPy arrays, consumed by the training loop.
ds_numpy = tfds.as_numpy(train_ds)

In [34]:
def _forward(batch, is_training):
    """Runs the ResNet-18 classifier on the images of a batch.

    Args:
      batch: mapping that holds the input images under the 'image' key.
      is_training: forwarded unchanged to the network call.

    Returns:
      The network output for the batch images.
    """
    # The first argument (10) is presumably the number of output classes --
    # confirm against the `resnet` module.
    resnet = ResNet18(10, bn_config={'decay_rate': 0.9})
    return resnet(batch['image'], is_training=is_training)


# Wrap `_forward` with Haiku so it is exposed as pure `init` / `apply`
# functions that take and return the network state explicitly.
forward = hk.transform_with_state(_forward)

In [35]:
def make_optimizer():
    """Builds the optimizer: SGD with momentum and a constant learning rate.

    Returns:
      An optax gradient transformation that first accumulates a momentum
      trace (decay 0.9, no Nesterov) and then scales the result by -1e-3.
    """
    momentum = optax.trace(decay=0.9, nesterov=False)
    # Negative scale: updates are added to the parameters, so this steps
    # against the gradient with a step size of 1e-3.
    descent_step = optax.scale(-1e-3)
    return optax.chain(momentum, descent_step)

In [36]:
def l2_loss(params):
    """Returns half of the sum of squared entries over all arrays in `params`.

    Args:
      params: iterable of arrays.

    Returns:
      0.5 * sum_i sum(params[i] ** 2); zero for an empty iterable.
    """
    squared_total = 0
    for param in params:
        squared_total = squared_total + jnp.sum(jnp.square(param))
    return 0.5 * squared_total

In [37]:
class TrainState(NamedTuple):
    """Everything that is carried from one training step to the next."""
    # Network parameters; these are what the optimizer updates.
    params: hk.Params
    # Network state returned by the forward pass alongside its output
    # (presumably batch-norm statistics -- confirm against the resnet module).
    state: hk.State
    # State of the optax optimizer.
    opt_state: optax.OptState

In [53]:
def loss_fn(params, state, batch):
    """Computes the regularized training loss for one batch.

    The loss is the mean softmax cross-entropy between the network logits and
    the one-hot labels, plus 1e-4 times the L2 penalty of every parameter
    whose module name does not contain 'batchnorm'.

    Labels are flattened to one entry per example and logits are kept as
    (batch, 10), so labels of shape (batch,) and (batch, 1) are handled the
    same way. Reshaping the logits to the label layout instead is only correct
    for (batch, 1) labels; with 1-D labels the cross-entropy would broadcast
    to a (batch, batch) matrix without raising an error.

    Args:
      params: network parameters.
      state: network state.
      batch: mapping with 'image' and integer 'label' entries.

    Returns:
      A `(loss, (loss, new_state))` pair: the scalar to differentiate, and an
      auxiliary tuple with the same loss and the updated network state.
    """
    num_classes = 10
    logits, state = forward.apply(params, state, None, batch, is_training=True)
    logits = logits.reshape(-1, num_classes)
    labels = jax.nn.one_hot(jnp.reshape(batch['label'], (-1,)), num_classes)
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()

    # Weight decay on everything except batch-norm parameters.
    l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
                 if 'batchnorm' not in mod_name]
    loss = loss + 1e-4 * l2_loss(l2_params)
    return loss, (loss, state)

In [49]:
@jax.jit
def train_step(train_state, batch):
    """Runs one optimization step and returns the updated training state.

    Args:
      train_state: (params, state, opt_state) triple.
      batch: batch passed through to `loss_fn`.

    Returns:
      A `(new_train_state, loss)` pair.
    """
    params, net_state, opt_state = train_state

    # With has_aux=True the differentiated function returns (loss, aux) and
    # the gradient function returns (grads, aux).
    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, (loss, new_state) = grad_fn(params, net_state, batch)

    # Turn the gradients into parameter updates and apply them.
    optimizer = make_optimizer()
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)

    return TrainState(new_params, new_state, new_opt_state), loss

In [50]:
def initial_state(rng, batch):
    """Creates the starting `TrainState`.

    Args:
      rng: PRNG key used to initialize the network.
      batch: example batch passed to the network initializer.

    Returns:
      A `TrainState` holding the freshly initialized parameters, network
      state and optimizer state.
    """
    net_params, net_state = forward.init(rng, batch, is_training=True)
    optimizer = make_optimizer()
    return TrainState(
        params=net_params,
        state=net_state,
        opt_state=optimizer.init(net_params),
    )

In [51]:
# Number of full passes over the training data.
epochs = 5
# PRNG key used to initialize the network.
rng = jax.random.PRNGKey(0)
# Pull one batch so `initial_state` has an example input to initialize with.
batch = next(iter(ds_numpy))
train_state = initial_state(rng, batch)

In [54]:
# Training loop: make `epochs` passes over the data and print the mean
# per-batch loss of each pass.
for _ in range(epochs):
    total_losses = []  # losses of the batches seen in this epoch
    for batch in ds_numpy:
        train_state, loss = train_step(train_state, batch)
        total_losses.append(loss)
    # Mean of the per-batch losses returned by `train_step`.
    print(sum(total_losses)/len(total_losses))

KeyboardInterrupt: 