Download colab_utils and import



<a href="https://colab.research.google.com/github/beangoben/intro_dl/blob/master/MNIST_LR_MLP_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [None]:
# Fetch the workshop helper module next to the notebook so that it can be imported below.
!wget https://raw.githubusercontent.com/beangoben/workshop_template/master/colab_utils.py -O colab_utils.py
# Remove the sample_data directory (rm -rf stays silent if it does not exist).
!rm -rf sample_data
import colab_utils
# Install Flax from the tip of its GitHub repository.
# NOTE(review): this install is unpinned, while later cells rely on `flax.nn` and
# `flax.optim` — confirm the installed revision still provides those modules.
!pip install git+https://github.com/google/flax.git
# Install UMAP (imported later as `umap`).
!pip install umap-learn

# MUY IMPORTANTE: USA UN GPU O TPU (choose runtime)

# Importa modulos


In [None]:
from collections import OrderedDict
import tqdm.auto as tqdm
from more_itertools import chunked
import colab_utils

import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import altair as alt

import umap
import sklearn
# Jax
import jax
from jax import numpy as jnp, random, jit, lax
# Flax
import flax
from flax import nn, optim
# Tensorflow
import tensorflow as tf
import tensorflow_datasets as tfds

# Run one trivial JAX operation and discard the result.
# NOTE(review): presumably this forces JAX backend initialisation (and any device
# warnings) to happen in this cell rather than later — confirm.
_ = jnp.square(2.)

# Both helpers come from colab_utils; judging by their names they report the versions
# of the listed modules and apply the workshop's matplotlib defaults (bodies not visible here).
colab_utils.print_module_versions([tf, tfds, jax])
colab_utils.matplotlib_settings()

Primero cargaremos algunos datos.

In [None]:
# Dataset to explore; 'fashion_mnist' is the other option supported by plot_color_legend.
dataset_name = 'mnist'

# batch_size=-1 requests each split as one single batch, and as_numpy converts
# the result into a dict of NumPy arrays.
train_ds = tfds.as_numpy(
    tfds.load(dataset_name, split=tfds.Split.TRAIN, batch_size=-1))
test_ds = tfds.as_numpy(
    tfds.load(dataset_name, split=tfds.Split.TEST, batch_size=-1))

# Split each dict into its image and label arrays.
x_train, y_train = (train_ds[field] for field in ('image', 'label'))
x_test, y_test = (test_ds[field] for field in ('image', 'label'))

# ¿Necesitamos DL? ¡Primero EDA!

In [None]:
def flatten_image(x):
    """Flatten a batch of images into a 2-D array with one row per image.

    Args:
        x: array with a `.shape` whose first axis indexes the images; every
            remaining axis (height, width, channels, ...) is folded into the
            feature axis.

    Returns:
        A NumPy array of shape `(x.shape[0], prod(x.shape[1:]))`.
    """
    # Multiply every trailing dimension explicitly instead of hard-coding
    # height * width (which only works when the channel axis has size 1) and
    # instead of using -1 (which NumPy cannot infer for an empty batch).
    n_features = int(np.prod(x.shape[1:]))
    return np.reshape(x, (x.shape[0], n_features))

# Flattened copies of both splits: one feature row per image, one label per row.
flat_x_train, flat_x_test = flatten_image(x_train), flatten_image(x_test)
flat_y_train, flat_y_test = y_train.ravel(), y_test.ravel()

# Sanity check: the number of rows must match the number of labels in each split.
print(flat_x_train.shape, y_train.shape)
print(flat_x_test.shape, y_test.shape)

In [None]:
def labels_to_nodecolors(labels):
    """Convert labels to colors.

    `labels` is flattened with `.ravel()` and each entry, cast to `int`, is
    used as an index into the 10-color seaborn "Set3" palette.

    Returns:
        A list with one palette color per label.
    """
    palette = sns.color_palette("Set3", 10)
    node_colors = []
    for label in labels.ravel():
        node_colors.append(palette[int(label)])
    return node_colors

def plot_color_legend(dataset_name):
    """Show the 10-color "Set3" palette with one class name under each swatch.

    Args:
        dataset_name: 'fashion_mnist' or 'mnist'; any other value raises
            KeyError from the lookup below.
    """
    fashion_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                     'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    names_by_dataset = {
        'fashion_mnist': fashion_names,
        'mnist': list(range(10)),
    }
    tick_labels = names_by_dataset[dataset_name]

    palette = sns.color_palette("Set3", 10)
    sns.palplot(palette)
    plt.xticks(np.arange(10), tick_labels, rotation=45)
    plt.show()


plot_color_legend(dataset_name)

In [None]:
import sklearn.pipeline
import sklearn.decomposition
import sklearn.preprocessing

def _scaled_reducer(reducer):
    """Build a pipeline that standardizes the features and then applies `reducer`."""
    return sklearn.pipeline.Pipeline([
        ('scaler', sklearn.preprocessing.StandardScaler()),
        ('dim_reduce', reducer),
    ])


pca_pipe = _scaled_reducer(sklearn.decomposition.PCA(2))
umap_pipe = _scaled_reducer(umap.UMAP())

# Each pipeline is fitted on the training images only and then used to project
# the test images.
x_umap = umap_pipe.fit(flat_x_train).transform(flat_x_test)
x_pca = pca_pipe.fit(flat_x_train).transform(flat_x_test)

print(x_umap.shape, x_pca.shape)

In [None]:
# Show the palette legend first, then the PCA projection of the test set.
plot_color_legend(dataset_name)

# One palette color per test point, taken from its label.
point_colors = labels_to_nodecolors(flat_y_test)
plt.scatter(x_pca[:, 0], x_pca[:, 1], c=point_colors, s=1, alpha=0.5)
plt.show()

In [None]:
# Random subset of at most 5000 test points, to keep the interactive charts light.
indices = np.random.permutation(len(x_umap))[:5000]

# One row per sampled point: both 2-D projections plus the class label.
vis_df = pd.DataFrame({
    'UMAP1': x_umap[indices, 0],
    'UMAP2': x_umap[indices, 1],
    'PC1': x_pca[indices, 0],
    'PC2': x_pca[indices, 1],
    'label': flat_y_test[indices],
})
vis_df

In [None]:
# Interactive scatter of the UMAP projection, colored by label; hovering a point
# shows its label.
umap_chart = alt.Chart(vis_df).mark_circle(size=10)
umap_chart = umap_chart.encode(
    x='UMAP1:Q',
    y='UMAP2:Q',
    color='label:N',
    tooltip=['label'],
)
umap_chart.interactive()

## Comparativa de PCA/UMAP

In [None]:
# A single interval selection shared by both charts, so brushing one highlights
# the same rows in the other.
brush = alt.selection(type='interval', resolve='global')


def _brushed_scatter(x_field, y_field):
    """Scatter of `vis_df` where points outside the brush are drawn in light gray."""
    highlight = alt.condition(brush, alt.Color('label:N'), alt.value('lightgray'))
    return (
        alt.Chart(vis_df)
        .mark_circle(size=4)
        .encode(x=x_field, y=y_field, color=highlight, tooltip=['label'])
        .add_selection(brush)
    )


scatter1 = _brushed_scatter('PC1:Q', 'PC2:Q')
scatter2 = _brushed_scatter('UMAP1:Q', 'UMAP2:Q')

# Place the PCA and UMAP views side by side.
scatter1 | scatter2

# Necesitamos DL? Un modelo lineal

Recargamos los datos, ahora con splits.

In [None]:
# Reload both splits as dicts of NumPy arrays (batch_size=-1 requests each split
# as one single batch), so the cells below start again from the raw images and labels.
train_ds, test_ds = (
    tfds.as_numpy(tfds.load(dataset_name, split=split, batch_size=-1))
    for split in (tfds.Split.TRAIN, tfds.Split.TEST)
)

x_train, y_train = train_ds['image'], train_ds['label']
x_test, y_test = test_ds['image'], test_ds['label']

Queremos :
$$ \hat{y} = W \cdot x + b $$

Tal que la diferencia $|y-\hat{y}|$ (el error) sea mínima.

* ¿Cuáles son las dimensiones de $W$ y $b$?
* ¿Cuántos parámetros necesitamos "aprender"?
* ¿Qué significa aprender?
* ¿Cómo optimizamos $W$ y $b$?


In [None]:
def random_layer_params(m, n, key, scale=1e-2):
  """Draw random initial parameters for one dense layer.

  Args:
    m: number of inputs of the layer.
    n: number of outputs of the layer.
    key: JAX PRNG key; it is split so weights and bias use independent keys.
    scale: factor applied to the standard-normal samples.

  Returns:
    A `(weights, bias)` tuple with shapes `(n, m)` and `(n,)`.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m))
  bias = scale * random.normal(bias_key, (n,))
  return weights, bias

# Model size: one output per class, one input per pixel (height * width of an image).
N_LABELS = 10
N_PIXELS = x_train.shape[1] * x_train.shape[2]

# `params` is a list with one (weights, bias) tuple per layer; the linear model
# has exactly one layer.
rng_key = random.PRNGKey(0)
params = [random_layer_params(N_PIXELS, N_LABELS, rng_key)]
params

## Primero el "modelo"

$$ \hat{y} = W \cdot x + b $$

In [None]:
from jax.scipy.special import logsumexp

def predict(params, x):
    """Log-probabilities of the linear classifier for one flattened image.

    Args:
        params: list whose first element is the `(weights, bias)` tuple.
        x: a single input vector (no batch axis).

    Returns:
        `W @ x + b` shifted by its log-sum-exp, i.e. a log-softmax over classes.
    """
    weights, bias = params[0]
    scores = jnp.dot(weights, x) + bias
    log_normalizer = logsumexp(scores)
    return scores - log_normalizer

In [None]:
def jnp_flat_image(x):
    """Flatten a batch of images into a 2-D JAX array with one row per image.

    Args:
        x: array with a `.shape` whose first axis indexes the images; every
            remaining axis (height, width, channels, ...) is folded into the
            feature axis.

    Returns:
        A JAX array of shape `(x.shape[0], prod(x.shape[1:]))`.
    """
    # Multiply every trailing dimension explicitly instead of hard-coding
    # height * width (which only works when the channel axis has size 1) and
    # instead of using -1 (which cannot be inferred for an empty batch).
    n_features = int(np.prod(x.shape[1:]))
    return jnp.reshape(x, (x.shape[0], n_features))

# Index with the list [0] rather than the scalar 0 so that the result keeps a
# leading batch axis of length 1.
one_image = x_train[[0]]
print(one_image.shape)
# Flatten that single-image batch; the cell output shows the resulting shape.
one_flat_image = jnp_flat_image(one_image)
one_flat_image.shape

In [None]:
# predict works on one example at a time, so pass row 0 of the flattened batch
# rather than the batch itself.
preds = predict(params, one_flat_image[0])
preds

# VMAP en acción: vectorización automática

In [None]:
# Vectorize predict over axis 0 of its second argument only: in_axes=(None, 0)
# keeps `params` shared by every example and maps over the rows of the batch.
batched_predict = jax.vmap(predict, (None, 0))
batched_preds = batched_predict(params, one_flat_image)
batched_preds

## ¿Cómo sabemos si estamos mejorando?

In [None]:
def one_hot(x, k, dtype=jnp.float32):
  """Return the one-hot encoding of the class indices in `x`.

  Args:
    x: 1-D array of class indices.
    k: number of classes, i.e. the width of each encoded row.
    dtype: dtype of the returned array.

  Returns:
    Array of shape `(len(x), k)` with a 1 in column `x[i]` of row `i`.
  """
  class_ids = jnp.arange(k)
  # Broadcasting a column of labels against the row of class ids gives a boolean mask.
  matches = x[:, None] == class_ids
  return jnp.array(matches, dtype)
  
# Rebind y_train / y_test to one-hot encodings of the labels stored in the
# reloaded train_ds / test_ds dicts.
y_train = one_hot(train_ds['label'], N_LABELS)
y_test = one_hot(test_ds['label'], N_LABELS)

In [None]:
# Rebind x_train / x_test to flattened JAX arrays with one row per image, built
# from the images stored in train_ds / test_ds.
x_train = jnp_flat_image(train_ds['image'])
x_test = jnp_flat_image(test_ds['image'])

In [None]:
def accuracy(params, images, targets):
  """Fraction of images whose highest-scoring class matches the target.

  Args:
    params: model parameters, passed through to `batched_predict`.
    images: batch of flattened images, one per row.
    targets: one-hot targets, one row per image.
  """
  log_probs = batched_predict(params, images)
  hits = jnp.argmax(log_probs, axis=1) == jnp.argmax(targets, axis=1)
  return jnp.mean(hits)

def loss(params, images, targets):
  """Negative mean of the log-probabilities selected by the one-hot targets.

  The mean runs over every entry of the (batch, classes) product, not only
  over the batch axis.
  """
  log_probs = batched_predict(params, images)
  weighted = log_probs * targets
  return -jnp.mean(weighted)


# Baseline metrics of the randomly initialised parameters on the training set.
accuracy(params, x_train, y_train), loss(params, x_train, y_train)

# GRAD & JIT: gradientes y compilación

In [None]:
LEARNING_RATE = 0.01


@jit
def update(params, x, y):
  """Apply one gradient-descent step and return the new parameters.

  Args:
    params: list of `(weights, bias)` tuples.
    x: batch of inputs passed to `loss`.
    y: batch of targets passed to `loss`.

  Returns:
    A new list of `(weights, bias)` tuples, each moved against its gradient
    by `LEARNING_RATE`.
  """
  # jax.grad differentiates with respect to the first argument (params), so
  # `grads` has the same list-of-tuples structure as `params`.
  grads = jax.grad(loss)(params, x, y)
  stepped = []
  for (w, b), (dw, db) in zip(params, grads):
    stepped.append((w - LEARNING_RATE * dw, b - LEARNING_RATE * db))
  return stepped

## Empieza el aprendizaje

In [None]:
BATCH_SIZE = 512
NUM_EPOCHS = 50
NUM_LABELS = 10  # NOTE(review): not read in this cell; kept in case later cells use it.

# Only full batches are used, so every call to the jitted `update` sees inputs of
# the same shape. Computing the count up front (instead of always discarding the
# last chunk) keeps the final batch when the dataset size is an exact multiple
# of BATCH_SIZE.
n_batches = len(x_train) // BATCH_SIZE

pbar = tqdm.tqdm(range(NUM_EPOCHS))
for epoch in pbar:
    # Fresh shuffle every epoch. Rows of `batch_indices` are NumPy index arrays:
    # JAX arrays reject indexing with plain Python lists.
    shuffled = np.random.permutation(len(x_train))
    batch_indices = shuffled[:n_batches * BATCH_SIZE].reshape(n_batches, BATCH_SIZE)
    for batch_index in batch_indices:
        x = x_train[batch_index]
        y = y_train[batch_index]
        params = update(params, x, y)

    # Convert the JAX scalars to Python floats so tqdm formats them as numbers.
    pbar.set_postfix({'train_acc': float(accuracy(params, x_train, y_train)),
                      'test_acc': float(accuracy(params, x_test, y_test))})

# Crea un modelo CNN con Flax



In [None]:
def _cast_example(example):
  """Cast one example to a float32 image and an int32 label."""
  return {'image': tf.cast(example['image'], tf.float32),
          'label': tf.cast(example['label'], tf.int32)}


# Training data stays a tf.data pipeline: cast, cache, shuffle with a buffer of
# 1000 examples, then batch by 128.
train_ds = (
    tfds.load('mnist', split=tfds.Split.TRAIN)
    .map(_cast_example)
    .cache()
    .shuffle(1000)
    .batch(128)
)

# Test data is loaded as one single batch of NumPy arrays and cast to the same dtypes.
test_ds = tfds.as_numpy(
    tfds.load('mnist', split=tfds.Split.TEST, batch_size=-1))
test_ds = {
    field: test_ds[field].astype(dtype)
    for field, dtype in (('image', jnp.float32), ('label', jnp.int32))
}

In [None]:
class CNN(flax.nn.Module):
  """A convolutional network.

  Two blocks of 3x3 convolution -> ReLU -> 2x2 average pooling (32 and then 64
  features), followed by a 256-unit dense layer with ReLU and a 10-unit dense
  layer whose outputs go through log-softmax.
  """

  def apply(self, x):
    for n_features in (32, 64):
      x = nn.Conv(x, features=n_features, kernel_size=(3, 3))
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Keep the batch axis and fold everything else into one feature axis.
    flat = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(flat, features=256))
    scores = nn.Dense(hidden, features=10)
    return nn.log_softmax(scores)

In [None]:
rng = random.PRNGKey(0)
# Split off a dedicated key for parameter initialisation and keep `rng` as the
# carry key, so the key consumed here is never reused for anything else.
rng, init_rng = random.split(rng)
# Initialise the parameters from the input shape alone: a batch of one
# 28x28 single-channel float32 image.
_, initial_params = CNN.init_by_shape(init_rng, [((1, 28, 28, 1), jnp.float32)])
model = flax.nn.Model(CNN, initial_params)

In [None]:
@jax.vmap
def cross_entropy_loss(logits, label):
  """Negated entry of `logits` at index `label`, vectorized over the batch.

  Written for a single example; `jax.vmap` maps it over the leading axis of
  both arguments. The CNN defined in this notebook already applies
  log-softmax, so the `logits` it produces are log-probabilities.
  """
  true_class_score = logits[label]
  return -true_class_score

def compute_metrics(logits, labels):
  """Return the mean loss and the accuracy of a batch of predictions.

  Args:
    logits: per-class scores, one row per example.
    labels: integer class labels, one per example.

  Returns:
    A dict with keys 'loss' and 'accuracy'.
  """
  per_example_loss = cross_entropy_loss(logits, labels)
  predicted = jnp.argmax(logits, -1)
  return {
      'loss': jnp.mean(per_example_loss),
      'accuracy': jnp.mean(predicted == labels),
  }

@jax.jit
def evaluate(model, eval_ds):
  """Compute loss and accuracy of `model` on `eval_ds`.

  Args:
    model: callable model applied to the images.
    eval_ds: dict with 'image' and 'label' entries; the images are divided
      by 255.0 here before being fed to the model.
  """
  scaled_images = eval_ds['image'] / 255.0
  predictions = model(scaled_images)
  return compute_metrics(predictions, eval_ds['label'])
  
@jax.jit
def train_step(optimizer, batch):
  """Run one optimisation step on `batch` and return the updated optimizer.

  Args:
    optimizer: optimizer whose `.target` is the current model.
    batch: dict with 'image' and 'label' entries. Unlike `evaluate`, the
      images are used as given, with no rescaling inside this function.
  """
  images, labels = batch['image'], batch['label']

  def loss_fn(model):
    # Mean cross-entropy of the model's predictions on this batch.
    return jnp.mean(cross_entropy_loss(model(images), labels))

  gradients = jax.grad(loss_fn)(optimizer.target)
  return optimizer.apply_gradient(gradients)

## Optimizador

In [None]:
# Create a Momentum optimizer (learning rate 0.1, beta 0.9) around the model; the
# training code reads the current model back through `optimizer.target`.
optimizer = flax.optim.Momentum(learning_rate=0.1, beta=0.9).create(model)

## El training loop

In [None]:
NUM_EPOCHS = 10
pbar = tqdm.tqdm(range(NUM_EPOCHS))
stats = []  # one metrics dict ('loss', 'accuracy') per epoch, measured on the test set
for epoch in pbar:
    for batch in tfds.as_numpy(train_ds):
        # train_step uses the images as given, so scale them here
        # (evaluate does its own division by 255).
        batch['image'] = batch['image'] / 255.0
        optimizer = train_step(optimizer, batch)
    stats.append(evaluate(optimizer.target, test_ds))
    # Show the metrics of the epoch that just finished (the latest entry, not the
    # first one), converted to Python floats so tqdm formats them as numbers.
    pbar.set_postfix({name: float(value) for name, value in stats[-1].items()})


# Aprendiendo representaciones

Visualizando los embeddings de la penúltima capa