# NN vs. CNN - Fashion MNIST usando JAX

Referências:

NN - https://coderzcolumn.com/tutorials/artificial-intelligence/guide-to-create-simple-neural-networks-using-jax#2

CNN - https://coderzcolumn.com/tutorials/artificial-intelligence/jax-guide-to-create-convolutional-neural-networks

## Fazendo os imports

Import das bibliotecas do JAX para criação das redes

In [None]:
import jax

# Print the installed JAX version (the networks below use jax.example_libraries).
print("JAX Version : {}".format(jax.__version__))

# stax: layer combinators used to build the networks;
# optimizers: provides the SGD optimizer used for training.
from jax.example_libraries import stax, optimizers
from jax import numpy as jnp
from jax import grad, value_and_grad

JAX Version : 0.4.26


Import de bibliotecas para criar tabelas referentes aos processos de treinamento e teste

In [None]:
import gspread
import pytz
import time
from datetime import datetime
from google.colab import auth

In [None]:
# Authenticate the Colab user with Google.
auth.authenticate_user()

# Authorization: get the default Google credentials and build a gspread client.
from google.auth import default
creds, _ = default()
gc = gspread.authorize(creds)
# Time zone used for the timestamps written to the log sheet.
fuso_horario = pytz.timezone('America/Sao_Paulo')

In [None]:
# Open the spreadsheet and the worksheet where one row per training epoch is logged.
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página1")

# 1-based column indices (as used by gspread's update_cell) of the log sheet.
n_coluna_DATA = 1 # Date and time
n_coluna_A = 2 # Shape of weight 1
n_coluna_B = 3 # Shape of bias 1
n_coluna_C = 4 # Shape of weight 2
n_coluna_D = 5 # Shape of bias 2
n_coluna_E = 6 # Shape of weight 3
n_coluna_F = 7 # Shape of bias 3
n_coluna_G = 8 # Training error
n_coluna_H = 9 # Number of iterations (NOTE(review): the trainers write the epoch number here)
n_coluna_I = 10 # Batch size
n_coluna_J = 11 # Number of test samples (NOTE(review): the trainers write str(X_train.shape) here)
n_coluna_K = 12 # Total number of samples (NOTE(review): the trainers end up writing str(X_test.shape) here)
n_coluna_L = 13 # Learning rate
n_coluna_M = 14 # Test error
n_coluna_N = 15 # Test time (NOTE(review): the trainers write the training time of the epoch, t1 - t0)

## Import do dataset

Usaremos o dataset Fashion MNIST. Este dataset é composto por imagens 28x28 em escala de cinza, distribuídas em 10 classes de peças de roupa, e é dividido em treino (60 mil amostras) e teste (10 mil amostras). Carregamos o dataset com o Keras e o convertemos em arrays do JAX. Depois, redimensionamos as imagens para o formato (28,28,1), em que a última dimensão corresponde ao canal de cor. A rede convolucional requer essa dimensão de canais; como as imagens estão em escala de cinza, elas não possuem os 3 canais RGB que imagens coloridas usualmente têm, então adicionamos um único canal extra. Por fim, normalizamos os pixels dividindo-os por 255 para facilitar a convergência do algoritmo de otimização.

In [None]:
from tensorflow import keras
# NOTE(review): train_test_split is not used in this cell.
from sklearn.model_selection import train_test_split

# Load Fashion MNIST through Keras as (train, test) splits.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

# Convert images and labels to float32 JAX arrays.
# NOTE(review): the labels become float32 too, which is why the classification
# report lists the classes as "0.0", "1.0", ...
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

# Add a trailing single-channel axis: each image becomes (28, 28, 1).
X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)

# Scale pixel intensities by 1/255 (presumably 0..255 -> [0, 1]).
X_train, X_test = X_train/255.0, X_test/255.0

# Distinct label values; len(classes) sets the size of the output layers below.
classes =  jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
[1m29515/29515[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
[1m26421880/26421880[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
[1m5148/5148[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
[1m4422102/4422102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

## NN

### NN 1 - Rede neural com camadas 784, 392 e 10

In [None]:
# NN 1: Flatten -> Dense(784) -> ReLU -> Dense(392) -> ReLU -> Dense(len(classes)) -> Softmax.
# stax.serial returns an (init_fn, apply_fn) pair; the Softmax makes apply_fn output
# class probabilities.
neural_net_init, neural_net_apply = stax.serial(  stax.Flatten,
                                                  stax.Dense(784),
                                                  stax.Relu,
                                                  stax.Dense(392),
                                                  stax.Relu,
                                                  stax.Dense(len(classes)),
                                                  stax.Softmax
                                                )

In [None]:
rng = jax.random.PRNGKey(123)

# Initialise NN 1 for inputs of shape (batch, 28, 28, 1); the batch size 18 appears
# to be arbitrary and only used for shape inference.
weights = neural_net_init(rng, (18,28,28,1))

# The init call returns (output_shape, params); keep only the params.
# params has one entry per layer: empty for Flatten/Relu/Softmax, (W, b) for Dense.
weights = weights[1]
for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (784, 784), Biases : (784,)
Weights : (784, 392), Biases : (392,)
Weights : (392, 10), Biases : (10,)


In [None]:
# Sanity check: untrained NN 1 on the first 5 training images; each row is a
# probability vector over the classes.
preds = neural_net_apply(weights, X_train[:5])

preds

Array([[0.07811955, 0.09067789, 0.09877805, 0.06826587, 0.17354672,
        0.07074962, 0.1470148 , 0.11571854, 0.06007824, 0.0970507 ],
       [0.12036594, 0.08408968, 0.10847562, 0.10073808, 0.14934102,
        0.05289947, 0.11336958, 0.0693934 , 0.1203151 , 0.08101206],
       [0.10484902, 0.08826307, 0.11507356, 0.09657142, 0.12428769,
        0.07600646, 0.10950141, 0.09494527, 0.09451821, 0.09598389],
       [0.1028449 , 0.08145509, 0.1158585 , 0.10533378, 0.1443685 ,
        0.07796986, 0.10524142, 0.08857726, 0.09122137, 0.08712925],
       [0.11193434, 0.06122527, 0.12651367, 0.12462392, 0.15348454,
        0.05267394, 0.11282269, 0.08244407, 0.08652247, 0.08775512]],      dtype=float32)

In [None]:
def CrossEntropyLoss(weights, input_data, actual):
    """Summed cross-entropy of NN 1's softmax output against the true labels.

    Args:
        weights: parameters accepted by ``neural_net_apply``.
        input_data: batch of images fed to the network.
        actual: class label of each row of ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of -log p(true class).
    """
    preds = neural_net_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax probability can underflow to exactly 0 in float32. log(0) = -inf and
    # 0 * -inf = NaN, so a single zero in ANY class column would turn the whole sum
    # into NaN. Flooring at the smallest normal float keeps every term finite.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import grad, value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train NN 1 with mini-batch SGD and log one row per epoch to the sheet.

    Relies on notebook globals: ``CrossEntropyLoss``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, the ``n_coluna_*`` indices and ``fuso_horario``.

    Args:
        X: training inputs, split along axis 0 into mini-batches.
        Y: training labels aligned with ``X``.
        epochs: number of passes over ``X``.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch may be smaller.

    Returns:
        The optimizer state after the last epoch.
    """
    n_samples = X.shape[0]
    for i in range(1, epochs + 1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # epoch-start timestamp for the sheet

        losses = []  # summed loss of each mini-batch
        # range() yields only non-empty batches. The previous
        # jnp.arange(n // batch_size + 1) produced an empty trailing batch when
        # n % batch_size == 0; MakePredictions skips such batches, but this
        # loop fed it to the network.
        for start in range(0, n_samples, batch_size):
            X_batch = X[start:start + batch_size]
            Y_batch = Y[start:start + batch_size]

            loss, gradients = value_and_grad(CrossEntropyLoss)(
                opt_get_weights(opt_state), X_batch, Y_batch)

            # The epoch index is passed as the step; with the constant
            # learning rate used for SGD here it has no effect.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        train_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(train_loss))
        t1 = time.time()

        params = opt_get_weights(opt_state)
        # Forward pass only: the test loss needs no gradient (value_and_grad
        # ran a full backward pass over the whole test set just to drop it).
        test_loss = float(CrossEntropyLoss(params, X_test, Y_test))

        # Shapes of W1, b1, W2, b2, W3, b3 taken from the real parameters; the
        # hard-coded tuples did not describe this 784-392-10 network.
        # Parameter-free layers contribute nothing; unused slots stay blank.
        shapes = [str(p.shape) for layer in params for p in layer]
        shapes = (shapes + [""] * 6)[:6]
        shape_columns = (n_coluna_A, n_coluna_B, n_coluna_C,
                         n_coluna_D, n_coluna_E, n_coluna_F)

        row = {n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S")}
        row.update(zip(shape_columns, shapes))
        row.update({
            n_coluna_G: train_loss,
            n_coluna_H: str(i),
            n_coluna_I: str(batch_size),
            n_coluna_J: str(X_train.shape),
            n_coluna_K: str(X_test.shape),
            n_coluna_L: str(learning_rate),
            n_coluna_M: test_loss,
            n_coluna_N: float(t1 - t0),
        })

        # One batched write per epoch instead of 15 update_cell calls (which can
        # run into the Sheets API write quota). The old learning-rate write to
        # column K was immediately overwritten and is dropped. USER_ENTERED
        # keeps the same value parsing as update_cell.
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        aba.update_cells(
            [gspread.Cell(proxima_linha_vazia, column, value)
             for column, value in row.items()],
            value_input_option="USER_ENTERED",
        )

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 50
batch_size=256

# Initialise from this cell's own `seed`. The original used `rng` from an earlier
# cell; it holds the same key only if that cell ran and `rng` was not reassigned
# since. The init call returns (output_shape, params); keep only the params.
weights = neural_net_init(seed, (batch_size,28,28,1))
weights = weights[1]


# SGD with a constant step size; returns (init, update, get_params) functions.
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 232.788
CrossEntropyLoss : 147.791
CrossEntropyLoss : 131.620
CrossEntropyLoss : 122.850
CrossEntropyLoss : 116.936
CrossEntropyLoss : 112.479
CrossEntropyLoss : 108.925
CrossEntropyLoss : 105.993
CrossEntropyLoss : 103.481
CrossEntropyLoss : 101.249
CrossEntropyLoss : 99.268
CrossEntropyLoss : 97.459
CrossEntropyLoss : 95.813
CrossEntropyLoss : 94.282
CrossEntropyLoss : 92.857
CrossEntropyLoss : 91.521
CrossEntropyLoss : 90.262
CrossEntropyLoss : 89.067
CrossEntropyLoss : 87.928
CrossEntropyLoss : 86.843
CrossEntropyLoss : 85.800
CrossEntropyLoss : 84.804
CrossEntropyLoss : 83.835
CrossEntropyLoss : 82.907
CrossEntropyLoss : 82.007
CrossEntropyLoss : 81.133
CrossEntropyLoss : 80.286
CrossEntropyLoss : 79.463
CrossEntropyLoss : 78.666
CrossEntropyLoss : 77.888
CrossEntropyLoss : 77.129
CrossEntropyLoss : 76.384
CrossEntropyLoss : 75.659
CrossEntropyLoss : 74.946
CrossEntropyLoss : 74.252
CrossEntropyLoss : 73.571
CrossEntropyLoss : 72.901
CrossEntropyLoss : 72.238
Cr

In [None]:
def MakePredictions(weights, input_data, batch_size=32):
    """Run NN 1 (``neural_net_apply``) over ``input_data`` in consecutive chunks.

    Args:
        weights: parameters accepted by ``neural_net_apply``.
        input_data: samples along axis 0.
        batch_size: maximum rows per chunk.

    Returns:
        A list with one output array per non-empty chunk, in input order
        (an empty list for empty input); callers concatenate it.
    """
    total = input_data.shape[0]
    return [
        neural_net_apply(weights, input_data[offset:offset + batch_size])
        for offset in range(0, total, batch_size)
    ]

In [None]:
# Class-probability outputs of the trained NN 1 on the test set, batch by batch.
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)

# Combine predictions of all batches. No .squeeze(): it was a no-op for more than
# one sample and would drop the row axis (breaking argmax(axis=1)) for a single one.
test_preds = jnp.concatenate(test_preds)

# Predicted label = index of the highest probability in each row.
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)

train_preds = jnp.concatenate(train_preds)

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples on the training and test sets.
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))

Train Accuracy : 0.911
Test  Accuracy : 0.880


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test set.
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.80      0.86      0.83      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.77      0.83      0.80      1000
         3.0       0.84      0.92      0.88      1000
         4.0       0.81      0.79      0.80      1000
         5.0       0.97      0.95      0.96      1000
         6.0       0.77      0.61      0.68      1000
         7.0       0.93      0.95      0.94      1000
         8.0       0.96      0.96      0.96      1000
         9.0       0.96      0.95      0.96      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000



### NN 2 - Rede neural com camadas 784, 784 e 10

In [None]:
# NN 2: Flatten -> Dense(784) -> ReLU -> Dense(784) -> ReLU -> Dense(len(classes)) -> Softmax.
# stax.serial returns an (init_fn, apply_fn) pair producing class probabilities.
neural_net2_init, neural_net2_apply = stax.serial(  stax.Flatten,
                                                  stax.Dense(784),
                                                  stax.Relu,
                                                  stax.Dense(784),
                                                  stax.Relu,
                                                  stax.Dense(len(classes)),
                                                  stax.Softmax
                                                )

In [None]:
rng = jax.random.PRNGKey(123)

# Initialise NN 2; the batch size 18 appears to be used only for shape inference.
weights = neural_net2_init(rng, (18,28,28,1))

weights = weights[1] ## The init call returns (output_shape, params); keep the params
#print(weights[2][0].shape)
# One entry per layer: empty for parameter-free layers, (W, b) for Dense.
for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (784, 784), Biases : (784,)
Weights : (784, 784), Biases : (784,)
Weights : (784, 10), Biases : (10,)


In [None]:
# Sanity check: untrained NN 2 on the first 5 training images.
preds = neural_net2_apply(weights, X_train[:5])

preds

Array([[0.07811955, 0.09067789, 0.09877805, 0.06826587, 0.17354672,
        0.07074962, 0.1470148 , 0.11571854, 0.06007824, 0.0970507 ],
       [0.12036594, 0.08408968, 0.10847562, 0.10073808, 0.14934102,
        0.05289947, 0.11336958, 0.0693934 , 0.1203151 , 0.08101206],
       [0.10484902, 0.08826307, 0.11507356, 0.09657142, 0.12428769,
        0.07600646, 0.10950141, 0.09494527, 0.09451821, 0.09598389],
       [0.1028449 , 0.08145509, 0.1158585 , 0.10533378, 0.1443685 ,
        0.07796986, 0.10524142, 0.08857726, 0.09122137, 0.08712925],
       [0.11193434, 0.06122527, 0.12651367, 0.12462392, 0.15348454,
        0.05267394, 0.11282269, 0.08244407, 0.08652247, 0.08775512]],      dtype=float32)

In [None]:
def CrossEntropyLoss2(weights, input_data, actual):
    """Summed cross-entropy of NN 2's softmax output against the true labels.

    Args:
        weights: parameters accepted by ``neural_net2_apply``.
        input_data: batch of images fed to the network.
        actual: class label of each row of ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of -log p(true class).
    """
    preds = neural_net2_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax probability can underflow to exactly 0 in float32. log(0) = -inf and
    # 0 * -inf = NaN, so a single zero in ANY class column would turn the whole sum
    # into NaN. Flooring at the smallest normal float keeps every term finite.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import grad, value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train NN 2 with mini-batch SGD and log one row per epoch to the sheet.

    NOTE: this redefines ``TrainModelInBatches``; after this cell runs, the
    name refers to the NN 2 trainer (it uses ``CrossEntropyLoss2``).

    Relies on notebook globals: ``CrossEntropyLoss2``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, the ``n_coluna_*`` indices and ``fuso_horario``.

    Args:
        X: training inputs, split along axis 0 into mini-batches.
        Y: training labels aligned with ``X``.
        epochs: number of passes over ``X``.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch may be smaller.

    Returns:
        The optimizer state after the last epoch.
    """
    n_samples = X.shape[0]
    for i in range(1, epochs + 1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # epoch-start timestamp for the sheet

        losses = []  # summed loss of each mini-batch
        # range() yields only non-empty batches. The previous
        # jnp.arange(n // batch_size + 1) produced an empty trailing batch when
        # n % batch_size == 0; MakePredictions2 skips such batches, but this
        # loop fed it to the network.
        for start in range(0, n_samples, batch_size):
            X_batch = X[start:start + batch_size]
            Y_batch = Y[start:start + batch_size]

            loss, gradients = value_and_grad(CrossEntropyLoss2)(
                opt_get_weights(opt_state), X_batch, Y_batch)

            # The epoch index is passed as the step; with the constant
            # learning rate used for SGD here it has no effect.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        train_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(train_loss))
        t1 = time.time()

        params = opt_get_weights(opt_state)
        # Forward pass only: the test loss needs no gradient (value_and_grad
        # ran a full backward pass over the whole test set just to drop it).
        test_loss = float(CrossEntropyLoss2(params, X_test, Y_test))

        # Shapes of W1, b1, W2, b2, W3, b3 taken from the real parameters
        # instead of hard-coded tuples. Parameter-free layers contribute
        # nothing; unused slots stay blank.
        shapes = [str(p.shape) for layer in params for p in layer]
        shapes = (shapes + [""] * 6)[:6]
        shape_columns = (n_coluna_A, n_coluna_B, n_coluna_C,
                         n_coluna_D, n_coluna_E, n_coluna_F)

        row = {n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S")}
        row.update(zip(shape_columns, shapes))
        row.update({
            n_coluna_G: train_loss,
            n_coluna_H: str(i),
            n_coluna_I: str(batch_size),
            n_coluna_J: str(X_train.shape),
            n_coluna_K: str(X_test.shape),
            n_coluna_L: str(learning_rate),
            n_coluna_M: test_loss,
            n_coluna_N: float(t1 - t0),
        })

        # One batched write per epoch instead of 15 update_cell calls (which can
        # run into the Sheets API write quota). The old learning-rate write to
        # column K was immediately overwritten and is dropped. USER_ENTERED
        # keeps the same value parsing as update_cell.
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        aba.update_cells(
            [gspread.Cell(proxima_linha_vazia, column, value)
             for column, value in row.items()],
            value_input_option="USER_ENTERED",
        )

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 50
batch_size=256

# Initialise from this cell's own `seed`. The original used `rng` from an earlier
# cell; it holds the same key only if that cell ran and `rng` was not reassigned
# since. The init call returns (output_shape, params); keep only the params.
weights = neural_net2_init(seed, (batch_size,28,28,1))
weights = weights[1]


# SGD with a constant step size; returns (init, update, get_params) functions.
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

# TrainModelInBatches here is the NN 2 version defined in the previous cell.
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 232.788
CrossEntropyLoss : 147.791
CrossEntropyLoss : 131.620
CrossEntropyLoss : 122.850
CrossEntropyLoss : 116.936
CrossEntropyLoss : 112.479
CrossEntropyLoss : 108.925
CrossEntropyLoss : 105.993
CrossEntropyLoss : 103.481
CrossEntropyLoss : 101.249
CrossEntropyLoss : 99.268
CrossEntropyLoss : 97.459
CrossEntropyLoss : 95.813
CrossEntropyLoss : 94.282
CrossEntropyLoss : 92.857
CrossEntropyLoss : 91.521
CrossEntropyLoss : 90.262
CrossEntropyLoss : 89.067
CrossEntropyLoss : 87.928
CrossEntropyLoss : 86.843
CrossEntropyLoss : 85.800
CrossEntropyLoss : 84.804
CrossEntropyLoss : 83.835
CrossEntropyLoss : 82.907
CrossEntropyLoss : 82.007
CrossEntropyLoss : 81.133
CrossEntropyLoss : 80.286
CrossEntropyLoss : 79.463
CrossEntropyLoss : 78.666
CrossEntropyLoss : 77.888
CrossEntropyLoss : 77.129
CrossEntropyLoss : 76.384
CrossEntropyLoss : 75.659
CrossEntropyLoss : 74.946
CrossEntropyLoss : 74.252
CrossEntropyLoss : 73.571
CrossEntropyLoss : 72.901
CrossEntropyLoss : 72.238
Cr

In [None]:
def MakePredictions2(weights, input_data, batch_size=32):
    """Run NN 2 (``neural_net2_apply``) over ``input_data`` in consecutive chunks.

    Args:
        weights: parameters accepted by ``neural_net2_apply``.
        input_data: samples along axis 0.
        batch_size: maximum rows per chunk.

    Returns:
        A list with one output array per non-empty chunk, in input order
        (an empty list for empty input); callers concatenate it.
    """
    total = input_data.shape[0]
    return [
        neural_net2_apply(weights, input_data[offset:offset + batch_size])
        for offset in range(0, total, batch_size)
    ]

In [None]:
# Class-probability outputs of the trained NN 2 on the test set, batch by batch.
test_preds = MakePredictions2(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)

# Combine predictions of all batches. No .squeeze(): it was a no-op for more than
# one sample and would drop the row axis (breaking argmax(axis=1)) for a single one.
test_preds = jnp.concatenate(test_preds)

# Predicted label = index of the highest probability in each row.
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions2(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)

train_preds = jnp.concatenate(train_preds)

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples for NN 2 on the training and test sets.
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))

Train Accuracy : 0.911
Test  Accuracy : 0.880


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 for NN 2 on the test set.
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.80      0.86      0.83      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.77      0.83      0.80      1000
         3.0       0.84      0.92      0.88      1000
         4.0       0.81      0.79      0.80      1000
         5.0       0.97      0.95      0.96      1000
         6.0       0.77      0.61      0.68      1000
         7.0       0.93      0.95      0.94      1000
         8.0       0.96      0.96      0.96      1000
         9.0       0.96      0.95      0.96      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000



### NN 3 - Rede neural com camadas 12544 e 10

In [None]:
# NN 3: Flatten -> Dense(12544) -> ReLU -> Dense(len(classes)) -> Softmax.
# stax.serial returns an (init_fn, apply_fn) pair producing class probabilities.
neural_net3_init, neural_net3_apply = stax.serial(  stax.Flatten,
                                                  stax.Dense(12544),
                                                  stax.Relu,
                                                  stax.Dense(len(classes)),
                                                  stax.Softmax
                                                )

In [None]:
rng = jax.random.PRNGKey(123)

# Initialise NN 3; the batch size 18 appears to be used only for shape inference.
weights = neural_net3_init(rng, (18,28,28,1))

weights = weights[1] ## The init call returns (output_shape, params); keep the params

# One entry per layer: empty for parameter-free layers, (W, b) for Dense.
for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (784, 12544), Biases : (12544,)
Weights : (12544, 10), Biases : (10,)


In [None]:
# Sanity check: untrained NN 3 on the first 5 training images.
preds = neural_net3_apply(weights, X_train[:5])

preds

Array([[0.11567732, 0.08312271, 0.09316629, 0.13119778, 0.0849104 ,
        0.09422207, 0.09260702, 0.10758439, 0.08504906, 0.11246294],
       [0.12771216, 0.10170111, 0.09442902, 0.12095149, 0.07102303,
        0.10548384, 0.09093792, 0.09384827, 0.08586559, 0.10804753],
       [0.11517212, 0.10051443, 0.08442578, 0.10548244, 0.08962353,
        0.10522561, 0.09384826, 0.10181855, 0.09511691, 0.10877232],
       [0.1249938 , 0.10030809, 0.08713156, 0.10648113, 0.07983516,
        0.10415314, 0.0952759 , 0.09500042, 0.08710919, 0.11971161],
       [0.13157985, 0.11046299, 0.08336561, 0.12212336, 0.07734967,
        0.10936862, 0.07990503, 0.09259556, 0.0831421 , 0.11010727]],      dtype=float32)

In [None]:
def CrossEntropyLoss3(weights, input_data, actual):
    """Summed cross-entropy of NN 3's softmax output against the true labels.

    Args:
        weights: parameters accepted by ``neural_net3_apply``.
        input_data: batch of images fed to the network.
        actual: class label of each row of ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of -log p(true class).
    """
    preds = neural_net3_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax probability can underflow to exactly 0 in float32. log(0) = -inf and
    # 0 * -inf = NaN, so a single zero in ANY class column would turn the whole sum
    # into NaN. Flooring at the smallest normal float keeps every term finite.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import grad, value_and_grad

def TrainModelInBatches3(X, Y, epochs, opt_state, batch_size=32):
    """Train NN 3 with mini-batch SGD and log one row per epoch to the sheet.

    Relies on notebook globals: ``CrossEntropyLoss3``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, the ``n_coluna_*`` indices and ``fuso_horario``.

    Args:
        X: training inputs, split along axis 0 into mini-batches.
        Y: training labels aligned with ``X``.
        epochs: number of passes over ``X``.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch may be smaller.

    Returns:
        The optimizer state after the last epoch.
    """
    n_samples = X.shape[0]
    for i in range(1, epochs + 1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # epoch-start timestamp for the sheet

        losses = []  # summed loss of each mini-batch
        # range() yields only non-empty batches. The previous
        # jnp.arange(n // batch_size + 1) produced an empty trailing batch when
        # n % batch_size == 0; MakePredictions3 skips such batches, but this
        # loop fed it to the network.
        for start in range(0, n_samples, batch_size):
            X_batch = X[start:start + batch_size]
            Y_batch = Y[start:start + batch_size]

            loss, gradients = value_and_grad(CrossEntropyLoss3)(
                opt_get_weights(opt_state), X_batch, Y_batch)

            # The epoch index is passed as the step; with the constant
            # learning rate used for SGD here it has no effect.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        train_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(train_loss))
        t1 = time.time()

        params = opt_get_weights(opt_state)
        # Forward pass only: the test loss needs no gradient (value_and_grad
        # ran a full backward pass over the whole test set, through the wide
        # 12544-unit layer, just to drop it).
        test_loss = float(CrossEntropyLoss3(params, X_test, Y_test))

        # Shapes of W1, b1, W2, b2, (W3, b3) taken from the real parameters
        # instead of hard-coded tuples. With only two Dense layers the last
        # two slots stay blank, as before.
        shapes = [str(p.shape) for layer in params for p in layer]
        shapes = (shapes + [""] * 6)[:6]
        shape_columns = (n_coluna_A, n_coluna_B, n_coluna_C,
                         n_coluna_D, n_coluna_E, n_coluna_F)

        row = {n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S")}
        row.update(zip(shape_columns, shapes))
        row.update({
            n_coluna_G: train_loss,
            n_coluna_H: str(i),
            n_coluna_I: str(batch_size),
            n_coluna_J: str(X_train.shape),
            n_coluna_K: str(X_test.shape),
            n_coluna_L: str(learning_rate),
            n_coluna_M: test_loss,
            n_coluna_N: float(t1 - t0),
        })

        # One batched write per epoch instead of 15 update_cell calls (which can
        # run into the Sheets API write quota). The old learning-rate write to
        # column K was immediately overwritten and is dropped. USER_ENTERED
        # keeps the same value parsing as update_cell.
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        aba.update_cells(
            [gspread.Cell(proxima_linha_vazia, column, value)
             for column, value in row.items()],
            value_input_option="USER_ENTERED",
        )

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 50
batch_size=256

# Initialise from this cell's own `seed`. The original used `rng` from an earlier
# cell; it holds the same key only if that cell ran and `rng` was not reassigned
# since. The init call returns (output_shape, params); keep only the params.
weights = neural_net3_init(seed, (batch_size,28,28,1))
weights = weights[1]


# SGD with a constant step size; returns (init, update, get_params) functions.
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches3(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 222.671
CrossEntropyLoss : 149.587
CrossEntropyLoss : 134.210
CrossEntropyLoss : 125.965
CrossEntropyLoss : 120.518
CrossEntropyLoss : 116.502
CrossEntropyLoss : 113.332
CrossEntropyLoss : 110.710
CrossEntropyLoss : 108.467
CrossEntropyLoss : 106.505
CrossEntropyLoss : 104.756
CrossEntropyLoss : 103.174
CrossEntropyLoss : 101.729
CrossEntropyLoss : 100.399
CrossEntropyLoss : 99.163
CrossEntropyLoss : 98.009
CrossEntropyLoss : 96.924
CrossEntropyLoss : 95.899
CrossEntropyLoss : 94.929
CrossEntropyLoss : 94.007
CrossEntropyLoss : 93.126
CrossEntropyLoss : 92.284
CrossEntropyLoss : 91.476
CrossEntropyLoss : 90.699
CrossEntropyLoss : 89.952
CrossEntropyLoss : 89.230
CrossEntropyLoss : 88.531
CrossEntropyLoss : 87.854
CrossEntropyLoss : 87.198
CrossEntropyLoss : 86.561
CrossEntropyLoss : 85.941
CrossEntropyLoss : 85.336
CrossEntropyLoss : 84.745
CrossEntropyLoss : 84.170
CrossEntropyLoss : 83.607
CrossEntropyLoss : 83.056
CrossEntropyLoss : 82.517
CrossEntropyLoss : 81.98

In [None]:
def MakePredictions3(weights, input_data, batch_size=32):
    """Run NN 3 (``neural_net3_apply``) over ``input_data`` in consecutive chunks.

    Args:
        weights: parameters accepted by ``neural_net3_apply``.
        input_data: samples along axis 0.
        batch_size: maximum rows per chunk.

    Returns:
        A list with one output array per non-empty chunk, in input order
        (an empty list for empty input); callers concatenate it.
    """
    total = input_data.shape[0]
    return [
        neural_net3_apply(weights, input_data[offset:offset + batch_size])
        for offset in range(0, total, batch_size)
    ]

In [None]:
# Class-probability outputs of the trained NN 3 on the test set, batch by batch.
test_preds = MakePredictions3(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)

# Combine predictions of all batches. No .squeeze(): it was a no-op for more than
# one sample and would drop the row axis (breaking argmax(axis=1)) for a single one.
test_preds = jnp.concatenate(test_preds)

# Predicted label = index of the highest probability in each row.
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions3(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)

train_preds = jnp.concatenate(train_preds)

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples for NN 3 on the training and test sets.
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))

Train Accuracy : 0.897
Test  Accuracy : 0.872


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 for NN 3 on the test set.
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.80      0.87      0.83      1000
         1.0       0.98      0.96      0.97      1000
         2.0       0.75      0.82      0.79      1000
         3.0       0.83      0.90      0.87      1000
         4.0       0.80      0.79      0.79      1000
         5.0       0.96      0.95      0.95      1000
         6.0       0.77      0.57      0.66      1000
         7.0       0.93      0.93      0.93      1000
         8.0       0.95      0.96      0.96      1000
         9.0       0.94      0.96      0.95      1000

    accuracy                           0.87     10000
   macro avg       0.87      0.87      0.87     10000
weighted avg       0.87      0.87      0.87     10000



### NN 4 - Rede neural com camadas 784 e 10

In [None]:
# NN 4: Flatten -> Dense(784) -> ReLU -> Dense(len(classes)) -> Softmax.
# stax.serial returns an (init_fn, apply_fn) pair producing class probabilities.
neural_net4_init, neural_net4_apply = stax.serial(  stax.Flatten,
                                                  stax.Dense(784),
                                                  stax.Relu,
                                                  stax.Dense(len(classes)),
                                                  stax.Softmax
                                                )

In [None]:
rng = jax.random.PRNGKey(123)

# Initialise NN 4; the batch size 18 appears to be used only for shape inference.
weights = neural_net4_init(rng, (18,28,28,1))

weights = weights[1] ## The init call returns (output_shape, params); keep the params

# One entry per layer: empty for parameter-free layers, (W, b) for Dense.
for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (784, 784), Biases : (784,)
Weights : (784, 10), Biases : (10,)


In [None]:
# Sanity check: untrained NN 4 on the first 5 training images.
preds = neural_net4_apply(weights, X_train[:5])

preds

Array([[0.09612467, 0.10410833, 0.18388611, 0.12166529, 0.05448725,
        0.08495695, 0.06539526, 0.07924715, 0.06066569, 0.14946328],
       [0.1347176 , 0.14377116, 0.13551466, 0.08447529, 0.04370638,
        0.10342203, 0.02061974, 0.09367398, 0.08857248, 0.15152664],
       [0.1026377 , 0.1213987 , 0.12713546, 0.07150632, 0.07939307,
        0.09092852, 0.08071359, 0.11131992, 0.08349042, 0.13147631],
       [0.10809993, 0.15784532, 0.16631818, 0.05506259, 0.06262441,
        0.0834934 , 0.06314867, 0.11416489, 0.06958278, 0.11965981],
       [0.08680553, 0.1079725 , 0.19477277, 0.05983886, 0.05173325,
        0.10810154, 0.04337629, 0.1107651 , 0.06790921, 0.16872497]],      dtype=float32)

In [None]:
def CrossEntropyLoss4(weights, input_data, actual):
    """Summed cross-entropy of NN 4's softmax output against the true labels.

    Args:
        weights: parameters accepted by ``neural_net4_apply``.
        input_data: batch of images fed to the network.
        actual: class label of each row of ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of -log p(true class).
    """
    preds = neural_net4_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax probability can underflow to exactly 0 in float32. log(0) = -inf and
    # 0 * -inf = NaN, so a single zero in ANY class column would turn the whole sum
    # into NaN. Flooring at the smallest normal float keeps every term finite.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import grad, value_and_grad

def TrainModelInBatches4(X, Y, epochs, opt_state, batch_size=32):
    """Train NN 4 with mini-batch SGD, logging one worksheet row per epoch.

    Uses notebook globals: ``CrossEntropyLoss4``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, ``fuso_horario`` and the ``n_coluna_*`` indices.

    Args:
        X: training inputs, split into mini-batches along axis 0.
        Y: labels aligned with ``X``.
        epochs: number of passes over the training data.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch holds any remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLoss4)
    n_samples = X.shape[0]

    for i in range(1, epochs+1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # timestamp for this epoch's row
        proxima_linha_vazia = len(aba.get_all_values()) + 1  # first empty row
        aba.update_cell(proxima_linha_vazia,
                        n_coluna_DATA,
                        agora.strftime("%d/%m/%Y %H:%M:%S"))

        losses = []  # summed loss of each mini-batch
        # Stepping by batch_size never yields an empty trailing batch. The old
        # arange(n // batch_size + 1) scheme produced one whenever n was a
        # multiple of batch_size, which added a spurious 0 to the epoch mean.
        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)

            # The step index passed to opt_update is the epoch number, as before.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        mean_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(mean_loss))
        t1 = time.time()

        # Update the sheet only once per epoch.
        aba.update_cell(proxima_linha_vazia, n_coluna_A, str((784, 784)))
        aba.update_cell(proxima_linha_vazia, n_coluna_B, str((784,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_C, str((784, 10)))
        aba.update_cell(proxima_linha_vazia, n_coluna_D, str((10,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_E, str())  # NN 4 has no third layer
        aba.update_cell(proxima_linha_vazia, n_coluna_F, str())
        aba.update_cell(proxima_linha_vazia, n_coluna_G, mean_loss)
        aba.update_cell(proxima_linha_vazia, n_coluna_H, str(i))
        aba.update_cell(proxima_linha_vazia, n_coluna_I, str(batch_size))
        # NOTE(review): the column legend calls J "test samples" and K "total
        # samples", but they receive X_train.shape and X_test.shape; confirm
        # against the sheet header.
        aba.update_cell(proxima_linha_vazia, n_coluna_J, str(X_train.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_K, str(X_test.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_L, str(learning_rate))
        # Only the loss value is logged, so evaluate it directly instead of
        # back-propagating through the whole test set and discarding the gradient.
        # NOTE: this is summed over every test sample, while column G averages
        # per-batch sums, so the two columns are on different scales.
        TestError = CrossEntropyLoss4(opt_get_weights(opt_state), X_test, Y_test)
        aba.update_cell(proxima_linha_vazia, n_coluna_M, float(TestError))
        aba.update_cell(proxima_linha_vazia, n_coluna_N, float(t1-t0))

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256

# stax init returns (output_shape, params); keep the params. Initialise from
# `seed`, the key defined in this cell, rather than the global `rng` left over
# from another cell (it holds the same key today but is easy to clobber).
weights = neural_net4_init(seed, (batch_size,28,28,1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches4(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 238.141
CrossEntropyLoss : 155.597
CrossEntropyLoss : 138.593
CrossEntropyLoss : 129.665
CrossEntropyLoss : 123.839
CrossEntropyLoss : 119.607
CrossEntropyLoss : 116.311
CrossEntropyLoss : 113.612
CrossEntropyLoss : 111.314
CrossEntropyLoss : 109.304
CrossEntropyLoss : 107.531
CrossEntropyLoss : 105.940
CrossEntropyLoss : 104.491
CrossEntropyLoss : 103.154
CrossEntropyLoss : 101.911
CrossEntropyLoss : 100.749
CrossEntropyLoss : 99.658
CrossEntropyLoss : 98.627
CrossEntropyLoss : 97.648
CrossEntropyLoss : 96.723
CrossEntropyLoss : 95.839
CrossEntropyLoss : 94.994
CrossEntropyLoss : 94.184
CrossEntropyLoss : 93.410
CrossEntropyLoss : 92.667


In [None]:
def MakePredictions4(weights, input_data, batch_size=32):
    """Run NN 4 over ``input_data`` in consecutive chunks of ``batch_size`` rows.

    Returns:
        A list with one softmax-probability array per non-empty chunk, in input
        order (the final chunk holds the remainder). Callers concatenate it.
    """
    total = input_data.shape[0]
    return [
        neural_net4_apply(weights, input_data[offset:offset + batch_size])
        for offset in range(0, total, batch_size)
    ]

In [None]:
final_weights = opt_get_weights(final_opt_state)

# Per-batch probability arrays -> one (N, classes) array -> most likely class.
test_preds = jnp.argmax(
    jnp.concatenate(MakePredictions4(final_weights, X_test, batch_size=batch_size)).squeeze(),
    axis=1,
)
train_preds = jnp.argmax(
    jnp.concatenate(MakePredictions4(final_weights, X_train, batch_size=batch_size)).squeeze(),
    axis=1,
)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 0, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples on each split.
acc_train = accuracy_score(Y_train, train_preds)
acc_test = accuracy_score(Y_test, test_preds)
print("Train Accuracy : {:.3f}".format(acc_train))
print("Test  Accuracy : {:.3f}".format(acc_test))

Train Accuracy : 0.876
Test  Accuracy : 0.857


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test split.
test_report = classification_report(Y_test, test_preds)
print("Test Classification Report ")
print(test_report)

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.79      0.85      0.82      1000
         1.0       0.98      0.96      0.97      1000
         2.0       0.73      0.81      0.77      1000
         3.0       0.82      0.89      0.86      1000
         4.0       0.77      0.76      0.76      1000
         5.0       0.95      0.93      0.94      1000
         6.0       0.73      0.54      0.62      1000
         7.0       0.92      0.92      0.92      1000
         8.0       0.94      0.95      0.94      1000
         9.0       0.94      0.96      0.95      1000

    accuracy                           0.86     10000
   macro avg       0.86      0.86      0.85     10000
weighted avg       0.86      0.86      0.85     10000



### (EXTRA) NN 5 - Rede neural com camadas densas de 12544, 12544, 12544 e 10 neurônios
O que aconteceria se tivéssemos uma super NN?


In [None]:
# NN 5: three very wide hidden layers (12544 units each, 16x the 784 input
# pixels), each followed by ReLU, then a Softmax classifier.
neural_net5_init, neural_net5_apply = stax.serial(
    stax.Flatten,
    *[layer for _ in range(3) for layer in (stax.Dense(12544), stax.Relu)],
    stax.Dense(len(classes)),
    stax.Softmax,
)

In [None]:
rng = jax.random.PRNGKey(123)

# stax init returns (output_shape, params); only the params are kept.
weights = neural_net5_init(rng, (18,28,28,1))[1]

# Print the shapes of every parametrised layer; Relu/Flatten/Softmax hold empty params.
for w in weights:
    if not w:
        continue
    w, b = w
    print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (784, 12544), Biases : (12544,)
Weights : (12544, 12544), Biases : (12544,)
Weights : (12544, 12544), Biases : (12544,)
Weights : (12544, 10), Biases : (10,)


In [None]:
# Sanity check: class probabilities for the first five training images under
# the freshly initialised weights (each row sums to 1 because of Softmax).
preds = neural_net5_apply(weights, X_train[0:5])

preds

Array([[0.09783493, 0.09493916, 0.08850034, 0.11536992, 0.10089111,
        0.09155177, 0.11366444, 0.10272201, 0.09576342, 0.09876294],
       [0.09001359, 0.09726787, 0.0983184 , 0.09964491, 0.11205891,
        0.09315657, 0.11798366, 0.0958628 , 0.09672335, 0.09896995],
       [0.09709807, 0.09883974, 0.09811977, 0.10343327, 0.10358185,
        0.09406121, 0.10775121, 0.10162062, 0.09505739, 0.10043693],
       [0.09554135, 0.10430618, 0.10017609, 0.10546475, 0.10143921,
        0.09168094, 0.1080877 , 0.09815247, 0.09263518, 0.10251615],
       [0.08937951, 0.10172193, 0.09391555, 0.10992814, 0.10593214,
        0.09310725, 0.11358219, 0.10120934, 0.09769433, 0.09352971]],      dtype=float32)

In [None]:
def CrossEntropyLoss5(weights, input_data, actual):
    """Summed cross-entropy of NN 5's softmax output against integer labels.

    Args:
        weights: NN 5 parameters as returned by ``neural_net5_init``.
        input_data: batch of images fed to ``neural_net5_apply``.
        actual: class labels for the batch, one-hot encoded over ``len(classes)``.

    Returns:
        Scalar: the negative log-likelihood summed over the batch (not averaged).
    """
    preds = neural_net5_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # Softmax probabilities can underflow to exactly 0 in float32. log(0) = -inf,
    # and 0 * -inf = NaN for the non-target classes would poison the sum and the
    # gradients. Clamping to the smallest positive normal float keeps every term
    # finite and leaves ordinary probabilities untouched.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import grad, value_and_grad

def TrainModelInBatches5(X, Y, epochs, opt_state, batch_size=32):
    """Train NN 5 with mini-batch SGD, logging one worksheet row per epoch.

    Uses notebook globals: ``CrossEntropyLoss5``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, ``fuso_horario`` and the ``n_coluna_*`` indices.

    Args:
        X: training inputs, split into mini-batches along axis 0.
        Y: labels aligned with ``X``.
        epochs: number of passes over the training data.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch holds any remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLoss5)
    n_samples = X.shape[0]

    for i in range(1, epochs+1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # timestamp for this epoch's row
        proxima_linha_vazia = len(aba.get_all_values()) + 1  # first empty row
        aba.update_cell(proxima_linha_vazia,
                        n_coluna_DATA,
                        agora.strftime("%d/%m/%Y %H:%M:%S"))

        losses = []  # summed loss of each mini-batch
        # Stepping by batch_size never yields an empty trailing batch. The old
        # arange(n // batch_size + 1) scheme produced one whenever n was a
        # multiple of batch_size, which added a spurious 0 to the epoch mean.
        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)

            # The step index passed to opt_update is the epoch number, as before.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        mean_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(mean_loss))
        t1 = time.time()

        # Update the sheet only once per epoch. These are NN 5's real layer
        # shapes; the previous code logged NN 4's 784-wide layers here. The sheet
        # has only three weight/bias column pairs but NN 5 has four Dense layers,
        # so the last pair records layers 3 and 4 together.
        aba.update_cell(proxima_linha_vazia, n_coluna_A, str((784, 12544)))
        aba.update_cell(proxima_linha_vazia, n_coluna_B, str((12544,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_C, str((12544, 12544)))
        aba.update_cell(proxima_linha_vazia, n_coluna_D, str((12544,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_E, "{}; {}".format((12544, 12544), (12544, 10)))
        aba.update_cell(proxima_linha_vazia, n_coluna_F, "{}; {}".format((12544,), (10,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_G, mean_loss)
        aba.update_cell(proxima_linha_vazia, n_coluna_H, str(i))
        aba.update_cell(proxima_linha_vazia, n_coluna_I, str(batch_size))
        # NOTE(review): the column legend calls J "test samples" and K "total
        # samples", but they receive X_train.shape and X_test.shape; confirm
        # against the sheet header.
        aba.update_cell(proxima_linha_vazia, n_coluna_J, str(X_train.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_K, str(X_test.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_L, str(learning_rate))
        # Only the loss value is logged, so evaluate it directly instead of
        # back-propagating through the whole test set (very costly for this
        # 12544-wide network) and discarding the gradient.
        # NOTE: this is summed over every test sample, while column G averages
        # per-batch sums, so the two columns are on different scales.
        TestError = CrossEntropyLoss5(opt_get_weights(opt_state), X_test, Y_test)
        aba.update_cell(proxima_linha_vazia, n_coluna_M, float(TestError))
        aba.update_cell(proxima_linha_vazia, n_coluna_N, float(t1-t0))

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256

# stax init returns (output_shape, params); keep the params. Initialise from
# `seed`, the key defined in this cell, rather than the global `rng` left over
# from another cell (it holds the same key today but is easy to clobber).
weights = neural_net5_init(seed, (batch_size,28,28,1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches5(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 219.660
CrossEntropyLoss : 139.220


Resposta: o treinamento falhou por falta de memória RAM após duas épocas.

## CNN

### CNN 1 - Conv+ReLU 32, Conv+ReLU 16, Flatten e Dense

In [None]:
conv_init, conv_apply = stax.serial(
    # Feature extractor: two 3x3 "SAME"-padded convolutions, each followed by
    # ReLU, so the 28x28 spatial size is kept (Dense input = 28*28*16 = 12544).
    stax.Conv(32, (3, 3), padding="SAME"), stax.Relu,
    stax.Conv(16, (3, 3), padding="SAME"), stax.Relu,
    # Classifier head: flatten the feature maps, one Dense layer, then Softmax.
    stax.Flatten,
    stax.Dense(len(classes)), stax.Softmax,
)

In [None]:
# Open the spreadsheet; CNN runs are logged on the worksheet "Página2".
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página2")

# 1-based column indices used by the training loop when logging an epoch.
n_coluna_DATA = 1 # date and time of the epoch
n_coluna_A = 2 # shape of weight 1
n_coluna_B = 3 # shape of bias 1
n_coluna_C = 4 # shape of weight 2
n_coluna_D = 5 # shape of bias 2
n_coluna_E = 6 # shape of weight 3
n_coluna_F = 7 # shape of bias 3
n_coluna_G = 8 # shape of weight 4
n_coluna_H = 9 # shape of bias 4
n_coluna_I = 10 # training error
n_coluna_J = 11 # number of iterations (epoch index)
n_coluna_K = 12 # batch size
n_coluna_L = 13 # number of test samples (NOTE(review): the loops write X_train.shape here)
n_coluna_M = 14 # total number of samples (NOTE(review): the loops write X_test.shape here)
n_coluna_N = 15 # learning rate
n_coluna_O = 16 # test error
n_coluna_P = 17 # "test time" (NOTE(review): the loops write the epoch's training duration)

In [None]:
rng = jax.random.PRNGKey(123)

# stax init returns (output_shape, params); only the params are kept.
weights = conv_init(rng, (18,28,28,1))[1]

# Print the shapes of every parametrised layer; Relu/Flatten/Softmax hold empty params.
for w in weights:
    if not w:
        continue
    w, b = w
    print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (12544, 10), Biases : (10,)


In [None]:
# Sanity check: class probabilities for the first five training images under
# the freshly initialised weights (each row sums to 1 because of Softmax).
preds = conv_apply(weights, X_train[0:5])

preds

Array([[0.1010583 , 0.08450612, 0.08876862, 0.10357413, 0.09424469,
        0.06837884, 0.13585268, 0.10559722, 0.10213768, 0.11588172],
       [0.10652137, 0.08236858, 0.11018915, 0.11165013, 0.08258918,
        0.07724519, 0.14238866, 0.09115906, 0.0884075 , 0.10748117],
       [0.09566505, 0.09545239, 0.10094573, 0.10249694, 0.09882495,
        0.08883353, 0.11344208, 0.09690957, 0.10275006, 0.10467971],
       [0.10200435, 0.08598676, 0.10633808, 0.10407417, 0.09481844,
        0.0829622 , 0.12135128, 0.093064  , 0.09924014, 0.11016052],
       [0.09390713, 0.08359886, 0.10012674, 0.11463808, 0.09753538,
        0.07206484, 0.12989205, 0.08960547, 0.10824842, 0.11038304]],      dtype=float32)

In [None]:
def CrossEntropyLossConv(weights, input_data, actual):
    """Summed cross-entropy of CNN 1's softmax output against integer labels.

    Args:
        weights: CNN 1 parameters as returned by ``conv_init``.
        input_data: batch of images fed to ``conv_apply``.
        actual: class labels for the batch, one-hot encoded over ``len(classes)``.

    Returns:
        Scalar: the negative log-likelihood summed over the batch (not averaged).
    """
    preds = conv_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # Softmax probabilities can underflow to exactly 0 in float32. log(0) = -inf,
    # and 0 * -inf = NaN for the non-target classes would poison the sum and the
    # gradients. Clamping to the smallest positive normal float keeps every term
    # finite and leaves ordinary probabilities untouched.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train CNN 1 (``conv_apply``) with mini-batch SGD, logging one row per epoch.

    Uses notebook globals: ``CrossEntropyLossConv``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, ``fuso_horario`` and the ``n_coluna_*`` indices.

    Args:
        X: training inputs, split into mini-batches along axis 0.
        Y: labels aligned with ``X``.
        epochs: number of passes over the training data.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch holds any remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLossConv)
    n_samples = X.shape[0]

    for i in range(1, epochs+1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # timestamp for this epoch's row
        proxima_linha_vazia = len(aba.get_all_values()) + 1  # first empty row
        aba.update_cell(proxima_linha_vazia,
                        n_coluna_DATA,
                        agora.strftime("%d/%m/%Y %H:%M:%S"))

        losses = []  # summed loss of each mini-batch
        # Stepping by batch_size never yields an empty trailing batch. The old
        # arange(n // batch_size + 1) scheme produced one whenever n was a
        # multiple of batch_size, which added a spurious 0 to the epoch mean.
        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)

            # The step index passed to opt_update is the epoch number, as before.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        mean_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(mean_loss))
        t1 = time.time()

        # CNN 1's real layer shapes: conv 1->32, conv 32->16, dense 12544->10.
        # The previous code logged CNN 2's third convolution and its 6272-input
        # dense layer. CNN 1 has no third convolution, so E/F stay empty and the
        # output layer goes in G/H, matching the other CNN rows on this sheet.
        aba.update_cell(proxima_linha_vazia, n_coluna_A, str((3, 3, 1, 32)))
        aba.update_cell(proxima_linha_vazia, n_coluna_B, str((1, 1, 1, 32)))
        aba.update_cell(proxima_linha_vazia, n_coluna_C, str((3, 3, 32, 16)))
        aba.update_cell(proxima_linha_vazia, n_coluna_D, str((1, 1, 1, 16)))
        aba.update_cell(proxima_linha_vazia, n_coluna_E, str())
        aba.update_cell(proxima_linha_vazia, n_coluna_F, str())
        aba.update_cell(proxima_linha_vazia, n_coluna_G, str((12544, 10)))
        aba.update_cell(proxima_linha_vazia, n_coluna_H, str((10,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_I, mean_loss)
        aba.update_cell(proxima_linha_vazia, n_coluna_J, str(i))
        aba.update_cell(proxima_linha_vazia, n_coluna_K, str(batch_size))
        aba.update_cell(proxima_linha_vazia, n_coluna_L, str(X_train.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_M, str(X_test.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_N, str(learning_rate))
        # Only the loss value is logged, so evaluate it directly instead of
        # back-propagating through the whole test set and discarding the gradient.
        # NOTE: this is summed over every test sample, while column I averages
        # per-batch sums, so the two columns are on different scales.
        TestError = CrossEntropyLossConv(opt_get_weights(opt_state), X_test, Y_test)
        aba.update_cell(proxima_linha_vazia, n_coluna_O, float(TestError))
        aba.update_cell(proxima_linha_vazia, n_coluna_P, float(t1-t0))

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 50
batch_size=256

# stax init returns (output_shape, params); keep the params. Initialise from
# `seed`, the key defined in this cell, rather than the global `rng` left over
# from another cell (it holds the same key today but is easy to clobber).
weights = conv_init(seed, (18,28,28,1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 237.197
CrossEntropyLoss : 137.667
CrossEntropyLoss : 119.940
CrossEntropyLoss : 110.607
CrossEntropyLoss : 104.348
CrossEntropyLoss : 99.563
CrossEntropyLoss : 95.596
CrossEntropyLoss : 92.126
CrossEntropyLoss : 89.062
CrossEntropyLoss : 86.342
CrossEntropyLoss : 83.902
CrossEntropyLoss : 81.713
CrossEntropyLoss : 79.717
CrossEntropyLoss : 77.903
CrossEntropyLoss : 76.245
CrossEntropyLoss : 74.700
CrossEntropyLoss : 73.277
CrossEntropyLoss : 71.944
CrossEntropyLoss : 70.720
CrossEntropyLoss : 69.546
CrossEntropyLoss : 68.426
CrossEntropyLoss : 67.369
CrossEntropyLoss : 66.337
CrossEntropyLoss : 65.365
CrossEntropyLoss : 64.449
CrossEntropyLoss : 63.565
CrossEntropyLoss : 62.700
CrossEntropyLoss : 61.857
CrossEntropyLoss : 61.050
CrossEntropyLoss : 60.265
CrossEntropyLoss : 59.496
CrossEntropyLoss : 58.760
CrossEntropyLoss : 58.045
CrossEntropyLoss : 57.351
CrossEntropyLoss : 56.662
CrossEntropyLoss : 56.007
CrossEntropyLoss : 55.362
CrossEntropyLoss : 54.738
CrossEn

In [None]:
def MakePredictions(weights, input_data, batch_size=32):
    """Run CNN 1 over ``input_data`` in consecutive chunks of ``batch_size`` rows.

    Returns:
        A list with one softmax-probability array per non-empty chunk, in input
        order (the final chunk holds the remainder). Callers concatenate it.
    """
    total = input_data.shape[0]
    return [
        conv_apply(weights, input_data[offset:offset + batch_size])
        for offset in range(0, total, batch_size)
    ]

In [None]:
final_weights = opt_get_weights(final_opt_state)

# Per-batch probability arrays -> one (N, classes) array -> most likely class.
test_preds = jnp.argmax(
    jnp.concatenate(MakePredictions(final_weights, X_test, batch_size=batch_size)).squeeze(),
    axis=1,
)
train_preds = jnp.argmax(
    jnp.concatenate(MakePredictions(final_weights, X_train, batch_size=batch_size)).squeeze(),
    axis=1,
)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 0], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples on each split.
acc_train = accuracy_score(Y_train, train_preds)
acc_test = accuracy_score(Y_test, test_preds)
print("Train Accuracy : {:.3f}".format(acc_train))
print("Test  Accuracy : {:.3f}".format(acc_test))

Train Accuracy : 0.934
Test  Accuracy : 0.898


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test split.
test_report = classification_report(Y_test, test_preds)
print("Test Classification Report ")
print(test_report)

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.83      0.85      0.84      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.83      0.86      0.85      1000
         3.0       0.87      0.93      0.90      1000
         4.0       0.88      0.77      0.82      1000
         5.0       0.98      0.98      0.98      1000
         6.0       0.71      0.73      0.72      1000
         7.0       0.95      0.96      0.96      1000
         8.0       0.98      0.97      0.98      1000
         9.0       0.97      0.96      0.96      1000

    accuracy                           0.90     10000
   macro avg       0.90      0.90      0.90     10000
weighted avg       0.90      0.90      0.90     10000



### CNN 2 - Conv+ReLU 32, Conv+ReLU 16, Conv+ReLU 8, Flatten e Dense

In [None]:
conv_init2, conv_apply2 = stax.serial(
    # Feature extractor: three 3x3 "SAME"-padded convolutions (32 -> 16 -> 8
    # channels), each followed by ReLU; the 28x28 spatial size is kept, so the
    # Dense input is 28*28*8 = 6272.
    stax.Conv(32, (3, 3), padding="SAME"), stax.Relu,
    stax.Conv(16, (3, 3), padding="SAME"), stax.Relu,
    stax.Conv(8, (3, 3), padding="SAME"), stax.Relu,
    # Classifier head: flatten the feature maps, one Dense layer, then Softmax.
    stax.Flatten,
    stax.Dense(len(classes)), stax.Softmax,
)

In [None]:
# Open the spreadsheet; CNN runs are logged on the worksheet "Página2".
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página2")

# 1-based column indices used by the training loop when logging an epoch.
n_coluna_DATA = 1 # date and time of the epoch
n_coluna_A = 2 # shape of weight 1
n_coluna_B = 3 # shape of bias 1
n_coluna_C = 4 # shape of weight 2
n_coluna_D = 5 # shape of bias 2
n_coluna_E = 6 # shape of weight 3
n_coluna_F = 7 # shape of bias 3
n_coluna_G = 8 # shape of weight 4
n_coluna_H = 9 # shape of bias 4
n_coluna_I = 10 # training error
n_coluna_J = 11 # number of iterations (epoch index)
n_coluna_K = 12 # batch size
n_coluna_L = 13 # number of test samples (NOTE(review): the loops write X_train.shape here)
n_coluna_M = 14 # total number of samples (NOTE(review): the loops write X_test.shape here)
n_coluna_N = 15 # learning rate
n_coluna_O = 16 # test error
n_coluna_P = 17 # "test time" (NOTE(review): the loops write the epoch's training duration)

In [None]:
rng = jax.random.PRNGKey(123)

# stax init returns (output_shape, params); only the params are kept.
weights = conv_init2(rng, (18,28,28,1))[1]

# Print the shapes of every parametrised layer; Relu/Flatten/Softmax hold empty params.
for w in weights:
    if not w:
        continue
    w, b = w
    print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (3, 3, 16, 8), Biases : (1, 1, 1, 8)
Weights : (6272, 10), Biases : (10,)


In [None]:
# Sanity check: class probabilities for the first five training images under
# the freshly initialised weights (each row sums to 1 because of Softmax).
preds = conv_apply2(weights, X_train[0:5])

preds

Array([[0.11038224, 0.10035543, 0.09388019, 0.09824558, 0.09850648,
        0.10333223, 0.10351744, 0.0947985 , 0.09528428, 0.10169768],
       [0.11012788, 0.09846697, 0.09884238, 0.10509334, 0.09760616,
        0.09469733, 0.10269596, 0.09468425, 0.09779315, 0.0999926 ],
       [0.10290439, 0.10052056, 0.09763322, 0.10209966, 0.09708749,
        0.10292315, 0.1029613 , 0.09659406, 0.09772114, 0.09955501],
       [0.10490854, 0.09879624, 0.09545968, 0.10218364, 0.09938176,
        0.10210045, 0.10189226, 0.09747745, 0.09721155, 0.10058835],
       [0.11069517, 0.10137364, 0.08963691, 0.10273396, 0.09639259,
        0.10485574, 0.10414395, 0.09456155, 0.09492914, 0.10067735]],      dtype=float32)

In [None]:
def CrossEntropyLossConv2(weights, input_data, actual):
    """Summed cross-entropy of CNN 2's softmax output against integer labels.

    Args:
        weights: CNN 2 parameters as returned by ``conv_init2``.
        input_data: batch of images fed to ``conv_apply2``.
        actual: class labels for the batch, one-hot encoded over ``len(classes)``.

    Returns:
        Scalar: the negative log-likelihood summed over the batch (not averaged).
    """
    preds = conv_apply2(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # Softmax probabilities can underflow to exactly 0 in float32. log(0) = -inf,
    # and 0 * -inf = NaN for the non-target classes would poison the sum and the
    # gradients. Clamping to the smallest positive normal float keeps every term
    # finite and leaves ordinary probabilities untouched.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train CNN 2 (``conv_apply2``) with mini-batch SGD, logging one row per epoch.

    Uses notebook globals: ``CrossEntropyLossConv2``, ``opt_update``,
    ``opt_get_weights``, ``learning_rate``, ``X_train``, ``X_test``, ``Y_test``,
    the worksheet ``aba``, ``fuso_horario`` and the ``n_coluna_*`` indices.

    Args:
        X: training inputs, split into mini-batches along axis 0.
        Y: labels aligned with ``X``.
        epochs: number of passes over the training data.
        opt_state: optimizer state returned by ``opt_init``.
        batch_size: rows per mini-batch; the last batch holds any remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLossConv2)
    n_samples = X.shape[0]

    for i in range(1, epochs+1):
        t0 = time.time()
        agora = datetime.now(fuso_horario)  # timestamp for this epoch's row
        proxima_linha_vazia = len(aba.get_all_values()) + 1  # first empty row
        aba.update_cell(proxima_linha_vazia,
                        n_coluna_DATA,
                        agora.strftime("%d/%m/%Y %H:%M:%S"))

        losses = []  # summed loss of each mini-batch
        # Stepping by batch_size never yields an empty trailing batch. The old
        # arange(n // batch_size + 1) scheme produced one whenever n was a
        # multiple of batch_size, which added a spurious 0 to the epoch mean.
        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            X_batch, Y_batch = X[start:end], Y[start:end]

            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)

            # The step index passed to opt_update is the epoch number, as before.
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss)

        mean_loss = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(mean_loss))
        t1 = time.time()

        # CNN 2's layer shapes: conv 1->32, conv 32->16, conv 16->8, dense 6272->10.
        aba.update_cell(proxima_linha_vazia, n_coluna_A, str((3, 3, 1, 32)))
        aba.update_cell(proxima_linha_vazia, n_coluna_B, str((1, 1, 1, 32)))
        aba.update_cell(proxima_linha_vazia, n_coluna_C, str((3, 3, 32, 16)))
        aba.update_cell(proxima_linha_vazia, n_coluna_D, str((1, 1, 1, 16)))
        aba.update_cell(proxima_linha_vazia, n_coluna_E, str((3, 3, 16, 8)))
        aba.update_cell(proxima_linha_vazia, n_coluna_F, str((1, 1, 1, 8)))
        aba.update_cell(proxima_linha_vazia, n_coluna_G, str((6272, 10)))
        aba.update_cell(proxima_linha_vazia, n_coluna_H, str((10,)))
        aba.update_cell(proxima_linha_vazia, n_coluna_I, mean_loss)
        aba.update_cell(proxima_linha_vazia, n_coluna_J, str(i))
        aba.update_cell(proxima_linha_vazia, n_coluna_K, str(batch_size))
        aba.update_cell(proxima_linha_vazia, n_coluna_L, str(X_train.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_M, str(X_test.shape))
        aba.update_cell(proxima_linha_vazia, n_coluna_N, str(learning_rate))
        # Only the loss value is logged, so evaluate it directly instead of
        # back-propagating through the whole test set and discarding the gradient.
        # NOTE: this is summed over every test sample, while column I averages
        # per-batch sums, so the two columns are on different scales.
        TestError = CrossEntropyLossConv2(opt_get_weights(opt_state), X_test, Y_test)
        aba.update_cell(proxima_linha_vazia, n_coluna_O, float(TestError))
        aba.update_cell(proxima_linha_vazia, n_coluna_P, float(t1-t0))

    return opt_state

In [None]:
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256

# stax init returns (output_shape, params); keep the params. Initialise from
# `seed`, the key defined in this cell, rather than the global `rng` left over
# from another cell (it holds the same key today but is easy to clobber).
weights = conv_init2(seed, (18,28,28,1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 237.197
CrossEntropyLoss : 137.667
CrossEntropyLoss : 119.940
CrossEntropyLoss : 110.607
CrossEntropyLoss : 104.348
CrossEntropyLoss : 99.563
CrossEntropyLoss : 95.596
CrossEntropyLoss : 92.126
CrossEntropyLoss : 89.062
CrossEntropyLoss : 86.342
CrossEntropyLoss : 83.902
CrossEntropyLoss : 81.713
CrossEntropyLoss : 79.717
CrossEntropyLoss : 77.903
CrossEntropyLoss : 76.245
CrossEntropyLoss : 74.700
CrossEntropyLoss : 73.277
CrossEntropyLoss : 71.944
CrossEntropyLoss : 70.720
CrossEntropyLoss : 69.546
CrossEntropyLoss : 68.426
CrossEntropyLoss : 67.369
CrossEntropyLoss : 66.337
CrossEntropyLoss : 65.365
CrossEntropyLoss : 64.449


In [None]:
def MakePredictions2(weights, input_data, batch_size=32):
    """Run CNN 2 over ``input_data`` in consecutive chunks of ``batch_size`` rows.

    Returns:
        A list with one softmax-probability array per non-empty chunk, in input
        order (the final chunk holds the remainder). Callers concatenate it.
    """
    total = input_data.shape[0]
    return [
        conv_apply2(weights, input_data[offset:offset + batch_size])
        for offset in range(0, total, batch_size)
    ]

In [None]:
final_weights = opt_get_weights(final_opt_state)

# Per-batch probability arrays -> one (N, classes) array -> most likely class.
test_preds = jnp.argmax(
    jnp.concatenate(MakePredictions2(final_weights, X_test, batch_size=batch_size)).squeeze(),
    axis=1,
)
train_preds = jnp.argmax(
    jnp.concatenate(MakePredictions2(final_weights, X_train, batch_size=batch_size)).squeeze(),
    axis=1,
)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 0], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples on each split.
acc_train = accuracy_score(Y_train, train_preds)
acc_test = accuracy_score(Y_test, test_preds)
print("Train Accuracy : {:.3f}".format(acc_train))
print("Test  Accuracy : {:.3f}".format(acc_test))

Train Accuracy : 0.910
Test  Accuracy : 0.887


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test split.
test_report = classification_report(Y_test, test_preds)
print("Test Classification Report ")
print(test_report)

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.82      0.86      0.84      1000
         1.0       0.99      0.96      0.98      1000
         2.0       0.77      0.88      0.82      1000
         3.0       0.86      0.93      0.89      1000
         4.0       0.87      0.72      0.79      1000
         5.0       0.98      0.97      0.97      1000
         6.0       0.73      0.66      0.69      1000
         7.0       0.94      0.97      0.95      1000
         8.0       0.98      0.97      0.97      1000
         9.0       0.97      0.96      0.96      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000



### CNN 3 - Conv32+ReLU, AvgPool(3,3), Conv16+ReLU, AvgPool(3,3), Flatten e Dense

In [None]:
conv_init3, conv_apply3 = stax.serial(
    # Feature extractor: each 3x3 "SAME"-padded convolution + ReLU is followed
    # by a 3x3 average pool. No strides are given to AvgPool, and the Dense input
    # stays 28*28*16 = 12544, so these pools smooth the maps rather than
    # downsample them.
    stax.Conv(32, (3, 3), padding="SAME"), stax.Relu,
    stax.AvgPool((3, 3), padding="SAME"),
    stax.Conv(16, (3, 3), padding="SAME"), stax.Relu,
    stax.AvgPool((3, 3), padding="SAME"),
    # Classifier head: flatten the feature maps, one Dense layer, then Softmax.
    stax.Flatten,
    stax.Dense(len(classes)), stax.Softmax,
)

In [None]:
# Open the spreadsheet; CNN runs are logged on the worksheet "Página2".
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página2")

# 1-based column indices used by the training loop when logging an epoch.
n_coluna_DATA = 1 # date and time of the epoch
n_coluna_A = 2 # shape of weight 1
n_coluna_B = 3 # shape of bias 1
n_coluna_C = 4 # shape of weight 2
n_coluna_D = 5 # shape of bias 2
n_coluna_E = 6 # shape of weight 3
n_coluna_F = 7 # shape of bias 3
n_coluna_G = 8 # shape of weight 4
n_coluna_H = 9 # shape of bias 4
n_coluna_I = 10 # training error
n_coluna_J = 11 # number of iterations (epoch index)
n_coluna_K = 12 # batch size
n_coluna_L = 13 # number of test samples (NOTE(review): the loops write X_train.shape here)
n_coluna_M = 14 # total number of samples (NOTE(review): the loops write X_test.shape here)
n_coluna_N = 15 # learning rate
n_coluna_O = 16 # test error
n_coluna_P = 17 # "test time" (NOTE(review): the loops write the epoch's training duration)

In [None]:
rng = jax.random.PRNGKey(123)

# stax init returns (output_shape, params); only the params are kept.
weights = conv_init3(rng, (18,28,28,1))[1]

# Print the shapes of every parametrised layer; Relu/AvgPool/Flatten/Softmax hold empty params.
for w in weights:
    if not w:
        continue
    w, b = w
    print("Weights : {}, Biases : {}".format(w.shape, b.shape))

Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (12544, 10), Biases : (10,)


In [None]:
# Sanity check: class probabilities for the first five training images under
# the freshly initialised weights (each row sums to 1 because of Softmax).
preds = conv_apply3(weights, X_train[0:5])

preds

Array([[0.09840897, 0.11047231, 0.11040694, 0.10387599, 0.0970936 ,
        0.09818683, 0.09109973, 0.09893171, 0.10123079, 0.09029309],
       [0.09430714, 0.10524321, 0.09991619, 0.10411722, 0.09135789,
        0.09486455, 0.10155654, 0.09576621, 0.10513603, 0.10773499],
       [0.09452868, 0.10610519, 0.09852562, 0.10138861, 0.09767579,
        0.10012037, 0.10115271, 0.10077183, 0.10145459, 0.09827655],
       [0.09371365, 0.10674328, 0.09712839, 0.10316705, 0.09963594,
        0.09866244, 0.10053799, 0.10001117, 0.09940703, 0.100993  ],
       [0.09558828, 0.11103137, 0.09993856, 0.10479355, 0.09826652,
        0.09980534, 0.10059945, 0.0988933 , 0.09651057, 0.09457301]],      dtype=float32)

In [None]:
def CrossEntropyLossConv3(weights, input_data, actual):
    """Summed cross-entropy of CNN 3's softmax output against integer labels.

    Args:
        weights: CNN 3 parameters as returned by ``conv_init3``.
        input_data: batch of images fed to ``conv_apply3``.
        actual: class labels for the batch, one-hot encoded over ``len(classes)``.

    Returns:
        Scalar: the negative log-likelihood summed over the batch (not averaged).
    """
    preds = conv_apply3(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # Softmax probabilities can underflow to exactly 0 in float32. log(0) = -inf,
    # and 0 * -inf = NaN for the non-target classes would poison the sum and the
    # gradients. Clamping to the smallest positive normal float keeps every term
    # finite and leaves ordinary probabilities untouched.
    log_preds = jnp.log(jnp.maximum(preds, jnp.finfo(preds.dtype).tiny))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train CNN 3 with mini-batch SGD and log one spreadsheet row per epoch.

    Besides its arguments this reads notebook globals: ``CrossEntropyLossConv3``,
    ``opt_update`` / ``opt_get_weights`` (from ``optimizers.sgd``), the gspread
    worksheet ``aba`` with the ``n_coluna_*`` column indices, ``fuso_horario``,
    ``X_train``, ``X_test``, ``Y_test`` and ``learning_rate``.

    Args:
        X: training inputs, one sample per entry of the first axis.
        Y: integer class labels aligned with ``X``.
        epochs: number of full passes over ``X``.
        opt_state: optimizer state to start from.
        batch_size: samples per update; the last batch holds the remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLossConv3)  # built once, reused per batch
    n_amostras = X.shape[0]
    passo = 0  # global update count: opt_update expects the step index, not the epoch
    for i in range(1, epochs + 1):
        agora = datetime.now(fuso_horario)  # epoch start, logged in the DATA column
        t0 = time.time()

        losses = []  # summed cross-entropy of each batch
        # range() yields no empty trailing batch, so a spurious zero loss cannot
        # enter the epoch mean when n_amostras is a multiple of batch_size.
        for inicio in range(0, n_amostras, batch_size):
            X_batch = X[inicio:inicio + batch_size]
            Y_batch = Y[inicio:inicio + batch_size]
            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)
            opt_state = opt_update(passo, gradients, opt_state)
            passo += 1
            losses.append(loss)

        erro_treino = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(erro_treino))
        t1 = time.time()  # training time only: no spreadsheet I/O, no test evaluation

        weights = opt_get_weights(opt_state)
        # Forward pass only: value_and_grad here would run a wasted backward
        # pass over the whole test set.
        erro_teste = float(CrossEntropyLossConv3(weights, X_test, Y_test))

        # Parameterised layers in order: conv 1, conv 2, dense.  Shapes are read
        # from the trained weights instead of being hard-coded.
        (w1, b1), (w2, b2), (w3, b3) = [camada for camada in weights if camada]
        janela_pool = str((3, 3))  # pooling-window size logged after each conv stage
        valores = {
            n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S"),
            n_coluna_A: str(w1.shape),
            n_coluna_B: str(b1.shape),
            n_coluna_C: janela_pool,
            n_coluna_D: str(w2.shape),
            n_coluna_E: str(b2.shape),
            n_coluna_F: janela_pool,
            n_coluna_G: str(w3.shape),
            n_coluna_H: str(b3.shape),
            n_coluna_I: erro_treino,
            n_coluna_J: str(i),
            n_coluna_K: str(batch_size),
            n_coluna_L: str(X_train.shape),
            n_coluna_M: str(X_test.shape),
            n_coluna_N: str(learning_rate),
            n_coluna_O: erro_teste,
            n_coluna_P: float(t1 - t0),
        }

        # Write the whole row in one batched request instead of one API call
        # per column (fewer calls against the Sheets request quota).
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        celulas = aba.range(proxima_linha_vazia, min(valores), proxima_linha_vazia, max(valores))
        for celula in celulas:
            celula.value = valores.get(celula.col, celula.value)
        # USER_ENTERED: the same input option gspread's update_cell uses.
        aba.update_cells(celulas, value_input_option="USER_ENTERED")

    return opt_state

In [None]:
# Hyper-parameters for CNN 3.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 50
batch_size = 256

# Initialise from this cell's own seed, so that the cell no longer depends on
# an `rng` defined in an earlier cell.  stax init returns
# (output_shape, params); only the params are kept.
weights = conv_init3(seed, (18, 28, 28, 1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 248.126
CrossEntropyLoss : 166.875
CrossEntropyLoss : 147.230
CrossEntropyLoss : 134.101
CrossEntropyLoss : 124.303
CrossEntropyLoss : 116.747
CrossEntropyLoss : 111.112
CrossEntropyLoss : 106.854
CrossEntropyLoss : 103.467
CrossEntropyLoss : 100.647
CrossEntropyLoss : 98.252
CrossEntropyLoss : 96.194
CrossEntropyLoss : 94.386
CrossEntropyLoss : 92.776
CrossEntropyLoss : 91.345
CrossEntropyLoss : 90.026
CrossEntropyLoss : 88.820
CrossEntropyLoss : 87.711
CrossEntropyLoss : 86.662
CrossEntropyLoss : 85.675
CrossEntropyLoss : 84.758
CrossEntropyLoss : 83.877
CrossEntropyLoss : 83.038
CrossEntropyLoss : 82.242
CrossEntropyLoss : 81.490
CrossEntropyLoss : 80.761
CrossEntropyLoss : 80.062
CrossEntropyLoss : 79.376
CrossEntropyLoss : 78.715
CrossEntropyLoss : 78.074
CrossEntropyLoss : 77.455
CrossEntropyLoss : 76.858
CrossEntropyLoss : 76.276
CrossEntropyLoss : 75.706
CrossEntropyLoss : 75.160
CrossEntropyLoss : 74.625
CrossEntropyLoss : 74.104
CrossEntropyLoss : 73.595
Cr

In [None]:
def MakePredictions3(weights, input_data, batch_size=32):
    """Run ``conv_apply3`` over ``input_data`` in consecutive slices.

    Args:
        weights: stax parameters for ``conv_apply3``.
        input_data: samples to predict, indexed along the first axis.
        batch_size: number of samples per forward pass.

    Returns:
        A list with one prediction array per non-empty slice, in input order.
    """
    n_cheias, resto = divmod(input_data.shape[0], batch_size)
    n_fatias = n_cheias + (1 if resto else 0)  # one extra slice for the remainder
    return [
        conv_apply3(weights, input_data[k * batch_size:(k + 1) * batch_size])
        for k in range(n_fatias)
    ]

In [None]:
# Predicted classes of the trained CNN 3 on the test and training sets.
# MakePredictions3 returns one array per batch; concatenate them and take the
# most probable class of each sample.  No .squeeze(): for a single-sample input
# it would drop the batch axis and make argmax(axis=1) fail.
test_preds = jnp.concatenate(
    MakePredictions3(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
)
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = jnp.concatenate(
    MakePredictions3(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
)
train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples, three decimals.
print(f"Train Accuracy : {accuracy_score(Y_train, train_preds):.3f}")
print(f"Test  Accuracy : {accuracy_score(Y_test, test_preds):.3f}")

Train Accuracy : 0.910
Test  Accuracy : 0.896


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test set, under a title line.
print("Test Classification Report ", classification_report(Y_test, test_preds), sep="\n")

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.83      0.85      0.84      1000
         1.0       0.99      0.96      0.98      1000
         2.0       0.85      0.83      0.84      1000
         3.0       0.86      0.92      0.89      1000
         4.0       0.84      0.84      0.84      1000
         5.0       0.98      0.97      0.98      1000
         6.0       0.73      0.69      0.71      1000
         7.0       0.94      0.96      0.95      1000
         8.0       0.97      0.97      0.97      1000
         9.0       0.96      0.96      0.96      1000

    accuracy                           0.90     10000
   macro avg       0.90      0.90      0.90     10000
weighted avg       0.90      0.90      0.90     10000



### CNN 4 - Conv32 + ReLU, MaxPool, Conv16 + ReLU, MaxPool, Flatten, Dense 10

In [None]:
conv_init4, conv_apply4 = stax.serial(
    # Stage 1: 32 3x3 filters + ReLU, then 3x3 max-pool with stride 2.
    stax.Conv(out_chan=32, filter_shape=(3, 3), padding="SAME"),
    stax.Relu,
    stax.MaxPool((3, 3), (2, 2), "SAME"),
    # Stage 2: 16 3x3 filters + ReLU, then another stride-2 max-pool.
    stax.Conv(out_chan=16, filter_shape=(3, 3), padding="SAME"),
    stax.Relu,
    stax.MaxPool((3, 3), (2, 2), "SAME"),
    # Head: flatten, one dense unit per class, softmax probabilities.
    stax.Flatten,
    stax.Dense(len(classes)),
    stax.Softmax,
)

In [None]:
# Open the results spreadsheet; the training loop writes one row per epoch to
# this worksheet, using the 1-based column indices below.
# NOTE(review): the CNN 5 and CNN 6 cells also open "Página2" but log 21
# columns, so their rows follow a different layout than this 17-column one —
# confirm that sharing the worksheet is intended.
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página2")

n_coluna_DATA = 1 # Date and time of the epoch
n_coluna_A = 2 # Shape of weight 1 (first conv kernel)
n_coluna_B = 3 # Shape of bias 1 (first conv bias)
n_coluna_C = 4 # Nominally "shape of weight 2"; the training loop writes the pooling window (3, 3) here
n_coluna_D = 5 # Nominally "shape of bias 2"; receives the second conv kernel shape
n_coluna_E = 6 # Nominally "shape of weight 3"; receives the second conv bias shape
n_coluna_F = 7 # Nominally "shape of bias 3"; receives the second pooling window (3, 3)
n_coluna_G = 8 # Nominally "shape of weight 4"; receives the dense kernel shape
n_coluna_H = 9 # Nominally "shape of bias 4"; receives the dense bias shape
n_coluna_I = 10 # Training loss (mean of the per-batch summed cross-entropy)
n_coluna_J = 11 # Epoch number
n_coluna_K = 12 # Batch size
n_coluna_L = 13 # Nominally "number of test samples"; receives str(X_train.shape) — NOTE(review): confirm
n_coluna_M = 14 # Nominally "total number of samples"; receives str(X_test.shape) — NOTE(review): confirm
n_coluna_N = 15 # Learning rate
n_coluna_O = 16 # Test loss (summed cross-entropy over the test set)
n_coluna_P = 17 # Nominally "test time"; receives the elapsed seconds measured by the training loop

In [None]:
rng = jax.random.PRNGKey(123)

# stax init functions return (output_shape, params); only the params are kept.
weights = conv_init4(rng, (18, 28, 28, 1))[1]

# Print the shape of every parameterised layer; falsy entries (layers without
# parameters) are skipped.
for w in weights:
    if not w:
        continue
    w, b = w
    print(f"Weights : {w.shape}, Biases : {b.shape}")

Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (784, 10), Biases : (10,)


In [None]:
# Forward pass of the freshly initialised CNN 4 on the first five training
# images; the model ends in a softmax, so each row is a probability vector.
preds = conv_apply4(weights, X_train[0:5])
preds

Array([[0.11346854, 0.11503609, 0.12250461, 0.08987208, 0.08465108,
        0.06523983, 0.10706779, 0.09952015, 0.09868933, 0.10395041],
       [0.12661624, 0.12114603, 0.10464215, 0.09760939, 0.07653171,
        0.0618769 , 0.1185155 , 0.0857824 , 0.09442716, 0.11285254],
       [0.11161763, 0.10252924, 0.10303943, 0.10422694, 0.0890063 ,
        0.08569656, 0.09722918, 0.10129583, 0.10760707, 0.09775182],
       [0.11328106, 0.1094126 , 0.10847753, 0.10078198, 0.07931047,
        0.08214597, 0.10063422, 0.09821174, 0.10790451, 0.09983996],
       [0.11225218, 0.1120536 , 0.10489712, 0.11213328, 0.08322569,
        0.06219784, 0.09779106, 0.09799976, 0.12137157, 0.09607792]],      dtype=float32)

In [None]:
def CrossEntropyLossConv4(weights, input_data, actual):
    """Summed cross-entropy of CNN 4's predictions against integer labels.

    Args:
        weights: stax parameters accepted by ``conv_apply4``.
        input_data: batch of input images.
        actual: integer class label of each sample in ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of ``-log p(true class)``.
    """
    preds = conv_apply4(weights, input_data)  # the model's last layer is stax.Softmax
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax output can underflow to exactly 0.  log(0) = -inf and
    # 0 * -inf = NaN, which would make the loss and every gradient NaN even
    # when the zero belongs to a non-target class.  Flooring at 1e-7 leaves
    # every probability >= 1e-7 unchanged.
    log_preds = jnp.log(jnp.maximum(preds, 1e-7))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train CNN 4 with mini-batch SGD and log one spreadsheet row per epoch.

    Besides its arguments this reads notebook globals: ``CrossEntropyLossConv4``,
    ``opt_update`` / ``opt_get_weights`` (from ``optimizers.sgd``), the gspread
    worksheet ``aba`` with the ``n_coluna_*`` column indices, ``fuso_horario``,
    ``X_train``, ``X_test``, ``Y_test`` and ``learning_rate``.

    Args:
        X: training inputs, one sample per entry of the first axis.
        Y: integer class labels aligned with ``X``.
        epochs: number of full passes over ``X``.
        opt_state: optimizer state to start from.
        batch_size: samples per update; the last batch holds the remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLossConv4)  # built once, reused per batch
    n_amostras = X.shape[0]
    passo = 0  # global update count: opt_update expects the step index, not the epoch
    for i in range(1, epochs + 1):
        agora = datetime.now(fuso_horario)  # epoch start, logged in the DATA column
        t0 = time.time()

        losses = []  # summed cross-entropy of each batch
        # range() yields no empty trailing batch, so a spurious zero loss cannot
        # enter the epoch mean when n_amostras is a multiple of batch_size.
        for inicio in range(0, n_amostras, batch_size):
            X_batch = X[inicio:inicio + batch_size]
            Y_batch = Y[inicio:inicio + batch_size]
            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)
            opt_state = opt_update(passo, gradients, opt_state)
            passo += 1
            losses.append(loss)

        erro_treino = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(erro_treino))
        t1 = time.time()  # training time only: no spreadsheet I/O, no test evaluation

        weights = opt_get_weights(opt_state)
        # Forward pass only: value_and_grad here would run a wasted backward
        # pass over the whole test set.
        erro_teste = float(CrossEntropyLossConv4(weights, X_test, Y_test))

        # Parameterised layers in order: conv 1, conv 2, dense.  Shapes are read
        # from the trained weights, so the log matches this architecture: the
        # two stride-2 max-pools shrink the dense kernel to (784, 10).
        (w1, b1), (w2, b2), (w3, b3) = [camada for camada in weights if camada]
        janela_pool = str((3, 3))  # pooling-window size logged after each conv stage
        valores = {
            n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S"),
            n_coluna_A: str(w1.shape),
            n_coluna_B: str(b1.shape),
            n_coluna_C: janela_pool,
            n_coluna_D: str(w2.shape),
            n_coluna_E: str(b2.shape),
            n_coluna_F: janela_pool,
            n_coluna_G: str(w3.shape),
            n_coluna_H: str(b3.shape),
            n_coluna_I: erro_treino,
            n_coluna_J: str(i),
            n_coluna_K: str(batch_size),
            n_coluna_L: str(X_train.shape),
            n_coluna_M: str(X_test.shape),
            n_coluna_N: str(learning_rate),
            n_coluna_O: erro_teste,
            n_coluna_P: float(t1 - t0),
        }

        # Write the whole row in one batched request instead of one API call
        # per column (fewer calls against the Sheets request quota).
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        celulas = aba.range(proxima_linha_vazia, min(valores), proxima_linha_vazia, max(valores))
        for celula in celulas:
            celula.value = valores.get(celula.col, celula.value)
        # USER_ENTERED: the same input option gspread's update_cell uses.
        aba.update_cells(celulas, value_input_option="USER_ENTERED")

    return opt_state

In [None]:
# Hyper-parameters for CNN 4.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size = 256

# Initialise from this cell's own seed, so that the cell no longer depends on
# an `rng` defined in an earlier cell.  stax init returns
# (output_shape, params); only the params are kept.
weights = conv_init4(seed, (18, 28, 28, 1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 333.543
CrossEntropyLoss : 192.880
CrossEntropyLoss : 162.156
CrossEntropyLoss : 144.185
CrossEntropyLoss : 132.848
CrossEntropyLoss : 124.941
CrossEntropyLoss : 118.918
CrossEntropyLoss : 114.187
CrossEntropyLoss : 110.250
CrossEntropyLoss : 106.891
CrossEntropyLoss : 103.960
CrossEntropyLoss : 101.355
CrossEntropyLoss : 99.083
CrossEntropyLoss : 97.043
CrossEntropyLoss : 95.194
CrossEntropyLoss : 93.542
CrossEntropyLoss : 92.038
CrossEntropyLoss : 90.652
CrossEntropyLoss : 89.373
CrossEntropyLoss : 88.196
CrossEntropyLoss : 87.094
CrossEntropyLoss : 86.051
CrossEntropyLoss : 85.080
CrossEntropyLoss : 84.161
CrossEntropyLoss : 83.287


In [None]:
def MakePredictions4(weights, input_data, batch_size=32):
    """Run ``conv_apply4`` over ``input_data`` in consecutive slices.

    Args:
        weights: stax parameters for ``conv_apply4``.
        input_data: samples to predict, indexed along the first axis.
        batch_size: number of samples per forward pass.

    Returns:
        A list with one prediction array per non-empty slice, in input order.
    """
    n_cheias, resto = divmod(input_data.shape[0], batch_size)
    n_fatias = n_cheias + (1 if resto else 0)  # one extra slice for the remainder
    return [
        conv_apply4(weights, input_data[k * batch_size:(k + 1) * batch_size])
        for k in range(n_fatias)
    ]

In [None]:
# Predicted classes of the trained CNN 4 on the test and training sets.
# MakePredictions4 returns one array per batch; concatenate them and take the
# most probable class of each sample.  No .squeeze(): for a single-sample input
# it would drop the batch axis and make argmax(axis=1) fail.
test_preds = jnp.concatenate(
    MakePredictions4(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
)
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = jnp.concatenate(
    MakePredictions4(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
)
train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples, three decimals.
print(f"Train Accuracy : {accuracy_score(Y_train, train_preds):.3f}")
print(f"Test  Accuracy : {accuracy_score(Y_test, test_preds):.3f}")

Train Accuracy : 0.891
Test  Accuracy : 0.880


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test set, under a title line.
print("Test Classification Report ", classification_report(Y_test, test_preds), sep="\n")

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.81      0.85      0.83      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.80      0.83      0.82      1000
         3.0       0.84      0.91      0.87      1000
         4.0       0.81      0.77      0.79      1000
         5.0       0.97      0.96      0.97      1000
         6.0       0.72      0.64      0.68      1000
         7.0       0.93      0.94      0.94      1000
         8.0       0.96      0.97      0.97      1000
         9.0       0.95      0.95      0.95      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000



### CNN 5 - Conv + ReLU, AvgPool, Conv + ReLU, AvgPool, Dense 784, 392 e 10

In [None]:
conv_init5, conv_apply5 = stax.serial(
    # Feature extractor: two conv + ReLU stages, each followed by a 3x3
    # average pool with the default stride (no downsampling).
    stax.Conv(out_chan=32, filter_shape=(3, 3), padding="SAME"),
    stax.Relu,
    stax.AvgPool(window_shape=(3, 3), padding="SAME"),
    stax.Conv(out_chan=16, filter_shape=(3, 3), padding="SAME"),
    stax.Relu,
    stax.AvgPool(window_shape=(3, 3), padding="SAME"),
    # Classifier: flatten, two hidden ReLU layers (784, 392), softmax output.
    stax.Flatten,
    stax.Dense(out_dim=784),
    stax.Relu,
    stax.Dense(out_dim=392),
    stax.Relu,
    stax.Dense(out_dim=len(classes)),
    stax.Softmax,
)

In [None]:
# Open the results spreadsheet; the training loop writes one row per epoch to
# this worksheet, using the 1-based column indices below.
# NOTE(review): the CNN 4 cell also opens "Página2" but logs only 17 columns,
# so its rows follow a different layout than this 21-column one — confirm that
# sharing the worksheet is intended.
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página2")

n_coluna_DATA = 1 # Date and time of the epoch
n_coluna_A = 2 # Shape of weight 1 (first conv kernel)
n_coluna_B = 3 # Shape of bias 1 (first conv bias)
n_coluna_C = 4 # Nominally "shape of weight 2"; the training loop writes the pooling window (3, 3) here
n_coluna_D = 5 # Nominally "shape of bias 2"; receives the second conv kernel shape
n_coluna_E = 6 # Nominally "shape of weight 3"; receives the second conv bias shape
n_coluna_F = 7 # Nominally "shape of bias 3"; receives the second pooling window (3, 3)
n_coluna_G = 8 # Nominally "shape of weight 4"; receives the first dense kernel shape
n_coluna_H = 9 # Nominally "shape of bias 4"; receives the first dense bias shape
n_coluna_I = 10 # Nominally "shape of weight 5"; receives the second dense kernel shape
n_coluna_J = 11 # Nominally "shape of bias 5"; receives the second dense bias shape
n_coluna_K = 12 # Nominally "shape of weight 6"; receives the output dense kernel shape
n_coluna_L = 13 # Nominally "shape of bias 6"; receives the output dense bias shape
n_coluna_M = 14 # Training loss (mean of the per-batch summed cross-entropy)
n_coluna_N = 15 # Epoch number
n_coluna_O = 16 # Batch size
n_coluna_P = 17 # Nominally "number of test samples"; receives str(X_train.shape) — NOTE(review): confirm
n_coluna_Q = 18 # Nominally "total number of samples"; receives str(X_test.shape) — NOTE(review): confirm
n_coluna_R = 19 # Learning rate
n_coluna_S = 20 # Test loss (summed cross-entropy over the test set)
n_coluna_T = 21 # Nominally "test time"; receives the elapsed seconds measured by the training loop

In [None]:
rng = jax.random.PRNGKey(123)

# stax init functions return (output_shape, params); only the params are kept.
weights = conv_init5(rng, (18, 28, 28, 1))[1]

# Print the shape of every parameterised layer; falsy entries (layers without
# parameters) are skipped.
for w in weights:
    if not w:
        continue
    w, b = w
    print(f"Weights : {w.shape}, Biases : {b.shape}")

Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (12544, 784), Biases : (784,)
Weights : (784, 392), Biases : (392,)
Weights : (392, 10), Biases : (10,)


In [None]:
# Forward pass of the freshly initialised CNN 5 on the first five training
# images; the model ends in a softmax, so each row is a probability vector.
preds = conv_apply5(weights, X_train[0:5])
preds

Array([[0.08712355, 0.1106542 , 0.11023826, 0.10060045, 0.10000064,
        0.0971093 , 0.10290173, 0.09644327, 0.09649688, 0.09843171],
       [0.08857621, 0.11876537, 0.1066256 , 0.10168226, 0.09600412,
        0.10515672, 0.09139074, 0.09814203, 0.0953457 , 0.09831124],
       [0.09616838, 0.10643027, 0.10425208, 0.10144627, 0.09744291,
        0.10131925, 0.09687588, 0.09902634, 0.0976822 , 0.09935643],
       [0.0911025 , 0.10938841, 0.10549676, 0.1026509 , 0.09709535,
        0.10177306, 0.09533088, 0.10234383, 0.09711587, 0.09770246],
       [0.08662689, 0.11696312, 0.10433838, 0.10052797, 0.09732302,
        0.10304537, 0.09425217, 0.10079762, 0.09916727, 0.09695816]],      dtype=float32)

In [None]:
def CrossEntropyLossConv5(weights, input_data, actual):
    """Summed cross-entropy of CNN 5's predictions against integer labels.

    Args:
        weights: stax parameters accepted by ``conv_apply5``.
        input_data: batch of input images.
        actual: integer class label of each sample in ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of ``-log p(true class)``.
    """
    preds = conv_apply5(weights, input_data)  # the model's last layer is stax.Softmax
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax output can underflow to exactly 0.  log(0) = -inf and
    # 0 * -inf = NaN, which would make the loss and every gradient NaN even
    # when the zero belongs to a non-target class.  Flooring at 1e-7 leaves
    # every probability >= 1e-7 unchanged.
    log_preds = jnp.log(jnp.maximum(preds, 1e-7))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train CNN 5 with mini-batch SGD and log one spreadsheet row per epoch.

    Besides its arguments this reads notebook globals: ``CrossEntropyLossConv5``,
    ``opt_update`` / ``opt_get_weights`` (from ``optimizers.sgd``), the gspread
    worksheet ``aba`` with the ``n_coluna_*`` column indices, ``fuso_horario``,
    ``X_train``, ``X_test``, ``Y_test`` and ``learning_rate``.

    Args:
        X: training inputs, one sample per entry of the first axis.
        Y: integer class labels aligned with ``X``.
        epochs: number of full passes over ``X``.
        opt_state: optimizer state to start from.
        batch_size: samples per update; the last batch holds the remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLossConv5)  # built once, reused per batch
    n_amostras = X.shape[0]
    passo = 0  # global update count: opt_update expects the step index, not the epoch
    for i in range(1, epochs + 1):
        agora = datetime.now(fuso_horario)  # epoch start, logged in the DATA column
        t0 = time.time()

        losses = []  # summed cross-entropy of each batch
        # range() yields no empty trailing batch, so a spurious zero loss cannot
        # enter the epoch mean when n_amostras is a multiple of batch_size.
        for inicio in range(0, n_amostras, batch_size):
            X_batch = X[inicio:inicio + batch_size]
            Y_batch = Y[inicio:inicio + batch_size]
            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)
            opt_state = opt_update(passo, gradients, opt_state)
            passo += 1
            losses.append(loss)

        erro_treino = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(erro_treino))
        t1 = time.time()  # training time only: no spreadsheet I/O, no test evaluation

        weights = opt_get_weights(opt_state)
        # Forward pass only: value_and_grad here would run a wasted backward
        # pass over the whole test set.
        erro_teste = float(CrossEntropyLossConv5(weights, X_test, Y_test))

        # Parameterised layers in order: conv 1, conv 2, dense 784, dense 392,
        # output dense.  Shapes are read from the trained weights instead of
        # being hard-coded.
        (w1, b1), (w2, b2), (w3, b3), (w4, b4), (w5, b5) = [
            camada for camada in weights if camada
        ]
        janela_pool = str((3, 3))  # pooling-window size logged after each conv stage
        valores = {
            n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S"),
            n_coluna_A: str(w1.shape),
            n_coluna_B: str(b1.shape),
            n_coluna_C: janela_pool,
            n_coluna_D: str(w2.shape),
            n_coluna_E: str(b2.shape),
            n_coluna_F: janela_pool,
            n_coluna_G: str(w3.shape),
            n_coluna_H: str(b3.shape),
            n_coluna_I: str(w4.shape),
            n_coluna_J: str(b4.shape),
            n_coluna_K: str(w5.shape),
            n_coluna_L: str(b5.shape),
            n_coluna_M: erro_treino,
            n_coluna_N: str(i),
            n_coluna_O: str(batch_size),
            n_coluna_P: str(X_train.shape),
            n_coluna_Q: str(X_test.shape),
            n_coluna_R: str(learning_rate),
            n_coluna_S: erro_teste,
            n_coluna_T: float(t1 - t0),
        }

        # Write the whole row in one batched request instead of one API call
        # per column (fewer calls against the Sheets request quota).
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        celulas = aba.range(proxima_linha_vazia, min(valores), proxima_linha_vazia, max(valores))
        for celula in celulas:
            celula.value = valores.get(celula.col, celula.value)
        # USER_ENTERED: the same input option gspread's update_cell uses.
        aba.update_cells(celulas, value_input_option="USER_ENTERED")

    return opt_state

In [None]:
# Hyper-parameters for CNN 5.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size = 256

# Initialise from this cell's own seed, so that the cell no longer depends on
# an `rng` defined in an earlier cell.  stax init returns
# (output_shape, params); only the params are kept.
weights = conv_init5(seed, (18, 28, 28, 1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 228.992
CrossEntropyLoss : 147.848
CrossEntropyLoss : 130.127
CrossEntropyLoss : 119.748
CrossEntropyLoss : 112.216
CrossEntropyLoss : 106.669
CrossEntropyLoss : 102.270
CrossEntropyLoss : 98.411
CrossEntropyLoss : 95.277
CrossEntropyLoss : 92.520
CrossEntropyLoss : 90.084
CrossEntropyLoss : 87.906
CrossEntropyLoss : 85.822
CrossEntropyLoss : 84.092
CrossEntropyLoss : 82.436
CrossEntropyLoss : 80.748
CrossEntropyLoss : 79.194
CrossEntropyLoss : 77.764
CrossEntropyLoss : 76.440
CrossEntropyLoss : 75.028
CrossEntropyLoss : 73.847
CrossEntropyLoss : 72.577
CrossEntropyLoss : 71.449
CrossEntropyLoss : 70.359
CrossEntropyLoss : 69.285


In [None]:
def MakePredictions5(weights, input_data, batch_size=32):
    """Run ``conv_apply5`` over ``input_data`` in consecutive slices.

    Args:
        weights: stax parameters for ``conv_apply5``.
        input_data: samples to predict, indexed along the first axis.
        batch_size: number of samples per forward pass.

    Returns:
        A list with one prediction array per non-empty slice, in input order.
    """
    n_cheias, resto = divmod(input_data.shape[0], batch_size)
    n_fatias = n_cheias + (1 if resto else 0)  # one extra slice for the remainder
    return [
        conv_apply5(weights, input_data[k * batch_size:(k + 1) * batch_size])
        for k in range(n_fatias)
    ]

In [None]:
# Predicted classes of the trained CNN 5 on the test and training sets.
# MakePredictions5 returns one array per batch; concatenate them and take the
# most probable class of each sample.  No .squeeze(): for a single-sample input
# it would drop the batch axis and make argmax(axis=1) fail.
test_preds = jnp.concatenate(
    MakePredictions5(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
)
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = jnp.concatenate(
    MakePredictions5(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
)
train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 3], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples, three decimals.
print(f"Train Accuracy : {accuracy_score(Y_train, train_preds):.3f}")
print(f"Test  Accuracy : {accuracy_score(Y_test, test_preds):.3f}")

Train Accuracy : 0.896
Test  Accuracy : 0.875


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test set, under a title line.
print("Test Classification Report ", classification_report(Y_test, test_preds), sep="\n")

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.80      0.86      0.83      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.71      0.85      0.77      1000
         3.0       0.85      0.92      0.88      1000
         4.0       0.78      0.81      0.79      1000
         5.0       0.98      0.95      0.96      1000
         6.0       0.83      0.52      0.64      1000
         7.0       0.92      0.97      0.94      1000
         8.0       0.96      0.98      0.97      1000
         9.0       0.96      0.94      0.95      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.87     10000
weighted avg       0.88      0.88      0.87     10000



### CNN 6 - Conv + ReLU, MaxPool, Conv + ReLU, MaxPool, Dense 784, 392 e 10

In [None]:
conv_init6, conv_apply6 = stax.serial(
    # Feature extractor: two conv + ReLU stages, each followed by a 3x3
    # max-pool with stride 2.
    stax.Conv(out_chan=32, filter_shape=(3, 3), padding="SAME"),
    stax.Relu,
    stax.MaxPool((3, 3), (2, 2), "SAME"),
    stax.Conv(out_chan=16, filter_shape=(3, 3), padding="SAME"),
    stax.Relu,
    stax.MaxPool((3, 3), (2, 2), "SAME"),
    # Classifier: flatten, two hidden ReLU layers (784, 392), softmax output.
    stax.Flatten,
    stax.Dense(out_dim=784),
    stax.Relu,
    stax.Dense(out_dim=392),
    stax.Relu,
    stax.Dense(out_dim=len(classes)),
    stax.Softmax,
)

In [None]:
# Open the results spreadsheet; the training loop writes one row per epoch to
# this worksheet, using the 1-based column indices below.
# NOTE(review): the CNN 4 cell also opens "Página2" but logs only 17 columns,
# so its rows follow a different layout than this 21-column one — confirm that
# sharing the worksheet is intended.
planilha = gc.open("EP 1 - NN vs CNN - Fashion Mnist usando JAX")
aba = planilha.worksheet("Página2")

n_coluna_DATA = 1 # Date and time of the epoch
n_coluna_A = 2 # Shape of weight 1 (first conv kernel)
n_coluna_B = 3 # Shape of bias 1 (first conv bias)
n_coluna_C = 4 # Nominally "shape of weight 2"; the training loop writes the pooling window (3, 3) here
n_coluna_D = 5 # Nominally "shape of bias 2"; receives the second conv kernel shape
n_coluna_E = 6 # Nominally "shape of weight 3"; receives the second conv bias shape
n_coluna_F = 7 # Nominally "shape of bias 3"; receives the second pooling window (3, 3)
n_coluna_G = 8 # Nominally "shape of weight 4"; receives the first dense kernel shape
n_coluna_H = 9 # Nominally "shape of bias 4"; receives the first dense bias shape
n_coluna_I = 10 # Nominally "shape of weight 5"; receives the second dense kernel shape
n_coluna_J = 11 # Nominally "shape of bias 5"; receives the second dense bias shape
n_coluna_K = 12 # Nominally "shape of weight 6"; receives the output dense kernel shape
n_coluna_L = 13 # Nominally "shape of bias 6"; receives the output dense bias shape
n_coluna_M = 14 # Training loss (mean of the per-batch summed cross-entropy)
n_coluna_N = 15 # Epoch number
n_coluna_O = 16 # Batch size
n_coluna_P = 17 # Nominally "number of test samples"; receives str(X_train.shape) — NOTE(review): confirm
n_coluna_Q = 18 # Nominally "total number of samples"; receives str(X_test.shape) — NOTE(review): confirm
n_coluna_R = 19 # Learning rate
n_coluna_S = 20 # Test loss (summed cross-entropy over the test set)
n_coluna_T = 21 # Nominally "test time"; receives the elapsed seconds measured by the training loop

In [None]:
rng = jax.random.PRNGKey(123)

# stax init functions return (output_shape, params); only the params are kept.
weights = conv_init6(rng, (18, 28, 28, 1))[1]

# Print the shape of every parameterised layer; falsy entries (layers without
# parameters) are skipped.
for w in weights:
    if not w:
        continue
    w, b = w
    print(f"Weights : {w.shape}, Biases : {b.shape}")

Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (784, 784), Biases : (784,)
Weights : (784, 392), Biases : (392,)
Weights : (392, 10), Biases : (10,)


In [None]:
# Forward pass of the freshly initialised CNN 6 on the first five training
# images; the model ends in a softmax, so each row is a probability vector.
preds = conv_apply6(weights, X_train[0:5])
preds

Array([[0.09757035, 0.10241958, 0.0900976 , 0.11813903, 0.10072365,
        0.082774  , 0.08404551, 0.10989817, 0.11372444, 0.10060761],
       [0.10525773, 0.10380331, 0.08879522, 0.11272974, 0.09696601,
        0.08687066, 0.08535547, 0.10144036, 0.1212389 , 0.0975426 ],
       [0.10229913, 0.10589906, 0.0968143 , 0.10394232, 0.09907755,
        0.09173166, 0.09412654, 0.09975762, 0.10959445, 0.09675736],
       [0.10250545, 0.10784738, 0.09309515, 0.10935327, 0.09636433,
        0.09162021, 0.09051629, 0.10029337, 0.111427  , 0.09697749],
       [0.09967467, 0.10568917, 0.09336875, 0.11485872, 0.09853793,
        0.08356208, 0.08581802, 0.0999238 , 0.12121885, 0.09734792]],      dtype=float32)

In [None]:
def CrossEntropyLossConv6(weights, input_data, actual):
    """Summed cross-entropy of CNN 6's predictions against integer labels.

    Args:
        weights: stax parameters accepted by ``conv_apply6``.
        input_data: batch of input images.
        actual: integer class label of each sample in ``input_data``.

    Returns:
        Scalar sum (not mean) over the batch of ``-log p(true class)``.
    """
    preds = conv_apply6(weights, input_data)  # the model's last layer is stax.Softmax
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # A softmax output can underflow to exactly 0.  log(0) = -inf and
    # 0 * -inf = NaN, which would make the loss and every gradient NaN even
    # when the zero belongs to a non-target class.  Flooring at 1e-7 leaves
    # every probability >= 1e-7 unchanged.
    log_preds = jnp.log(jnp.maximum(preds, 1e-7))
    return - jnp.sum(one_hot_actual * log_preds)

In [None]:
from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    """Train CNN 6 with mini-batch SGD and log one spreadsheet row per epoch.

    Besides its arguments this reads notebook globals: ``CrossEntropyLossConv6``,
    ``opt_update`` / ``opt_get_weights`` (from ``optimizers.sgd``), the gspread
    worksheet ``aba`` with the ``n_coluna_*`` column indices, ``fuso_horario``,
    ``X_train``, ``X_test``, ``Y_test`` and ``learning_rate``.

    Args:
        X: training inputs, one sample per entry of the first axis.
        Y: integer class labels aligned with ``X``.
        epochs: number of full passes over ``X``.
        opt_state: optimizer state to start from.
        batch_size: samples per update; the last batch holds the remainder.

    Returns:
        The optimizer state after the final epoch.
    """
    loss_and_grad = value_and_grad(CrossEntropyLossConv6)  # built once, reused per batch
    n_amostras = X.shape[0]
    passo = 0  # global update count: opt_update expects the step index, not the epoch
    for i in range(1, epochs + 1):
        agora = datetime.now(fuso_horario)  # epoch start, logged in the DATA column
        t0 = time.time()

        losses = []  # summed cross-entropy of each batch
        # range() yields no empty trailing batch, so a spurious zero loss cannot
        # enter the epoch mean when n_amostras is a multiple of batch_size.
        for inicio in range(0, n_amostras, batch_size):
            X_batch = X[inicio:inicio + batch_size]
            Y_batch = Y[inicio:inicio + batch_size]
            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)
            opt_state = opt_update(passo, gradients, opt_state)
            passo += 1
            losses.append(loss)

        erro_treino = float(jnp.array(losses).mean())
        print("CrossEntropyLoss : {:.3f}".format(erro_treino))
        t1 = time.time()  # training time only: no spreadsheet I/O, no test evaluation

        weights = opt_get_weights(opt_state)
        # Forward pass only: value_and_grad here would run a wasted backward
        # pass over the whole test set.
        erro_teste = float(CrossEntropyLossConv6(weights, X_test, Y_test))

        # Parameterised layers in order: conv 1, conv 2, dense 784, dense 392,
        # output dense.  Shapes are read from the trained weights, so the log
        # matches this architecture: the two stride-2 max-pools make the first
        # dense kernel (784, 784).
        (w1, b1), (w2, b2), (w3, b3), (w4, b4), (w5, b5) = [
            camada for camada in weights if camada
        ]
        janela_pool = str((3, 3))  # pooling-window size logged after each conv stage
        valores = {
            n_coluna_DATA: agora.strftime("%d/%m/%Y %H:%M:%S"),
            n_coluna_A: str(w1.shape),
            n_coluna_B: str(b1.shape),
            n_coluna_C: janela_pool,
            n_coluna_D: str(w2.shape),
            n_coluna_E: str(b2.shape),
            n_coluna_F: janela_pool,
            n_coluna_G: str(w3.shape),
            n_coluna_H: str(b3.shape),
            n_coluna_I: str(w4.shape),
            n_coluna_J: str(b4.shape),
            n_coluna_K: str(w5.shape),
            n_coluna_L: str(b5.shape),
            n_coluna_M: erro_treino,
            n_coluna_N: str(i),
            n_coluna_O: str(batch_size),
            n_coluna_P: str(X_train.shape),
            n_coluna_Q: str(X_test.shape),
            n_coluna_R: str(learning_rate),
            n_coluna_S: erro_teste,
            n_coluna_T: float(t1 - t0),
        }

        # Write the whole row in one batched request instead of one API call
        # per column (fewer calls against the Sheets request quota).
        proxima_linha_vazia = len(aba.get_all_values()) + 1
        celulas = aba.range(proxima_linha_vazia, min(valores), proxima_linha_vazia, max(valores))
        for celula in celulas:
            celula.value = valores.get(celula.col, celula.value)
        # USER_ENTERED: the same input option gspread's update_cell uses.
        aba.update_cells(celulas, value_input_option="USER_ENTERED")

    return opt_state

In [None]:
# Hyper-parameters for CNN 6.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size = 256

# Initialise from this cell's own seed, so that the cell no longer depends on
# an `rng` defined in an earlier cell.  stax init returns
# (output_shape, params); only the params are kept.
weights = conv_init6(seed, (18, 28, 28, 1))[1]

opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)

CrossEntropyLoss : 279.547
CrossEntropyLoss : 166.982
CrossEntropyLoss : 141.139
CrossEntropyLoss : 126.828
CrossEntropyLoss : 117.492
CrossEntropyLoss : 110.436
CrossEntropyLoss : 104.791
CrossEntropyLoss : 100.109
CrossEntropyLoss : 96.048
CrossEntropyLoss : 92.522
CrossEntropyLoss : 89.380
CrossEntropyLoss : 86.645
CrossEntropyLoss : 84.113
CrossEntropyLoss : 81.884
CrossEntropyLoss : 79.858
CrossEntropyLoss : 77.857
CrossEntropyLoss : 76.222
CrossEntropyLoss : 74.567
CrossEntropyLoss : 73.086
CrossEntropyLoss : 71.738
CrossEntropyLoss : 70.375
CrossEntropyLoss : 69.084
CrossEntropyLoss : 67.882
CrossEntropyLoss : 66.781
CrossEntropyLoss : 65.687


In [None]:
def MakePredictions6(weights, input_data, batch_size=32):
    """Run ``conv_apply6`` over ``input_data`` in consecutive slices.

    Args:
        weights: stax parameters for ``conv_apply6``.
        input_data: samples to predict, indexed along the first axis.
        batch_size: number of samples per forward pass.

    Returns:
        A list with one prediction array per non-empty slice, in input order.
    """
    n_cheias, resto = divmod(input_data.shape[0], batch_size)
    n_fatias = n_cheias + (1 if resto else 0)  # one extra slice for the remainder
    return [
        conv_apply6(weights, input_data[k * batch_size:(k + 1) * batch_size])
        for k in range(n_fatias)
    ]

In [None]:
# Predicted classes of the trained CNN 6 on the test and training sets.
# MakePredictions6 returns one array per batch; concatenate them and take the
# most probable class of each sample.  No .squeeze(): for a single-sample input
# it would drop the batch axis and make argmax(axis=1) fail.
test_preds = jnp.concatenate(
    MakePredictions6(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
)
test_preds = jnp.argmax(test_preds, axis=1)

train_preds = jnp.concatenate(
    MakePredictions6(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
)
train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]

(Array([9, 2, 1, 1, 6], dtype=int32), Array([9, 0, 0, 3, 1], dtype=int32))

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of correctly classified samples, three decimals.
print(f"Train Accuracy : {accuracy_score(Y_train, train_preds):.3f}")
print(f"Test  Accuracy : {accuracy_score(Y_test, test_preds):.3f}")

Train Accuracy : 0.910
Test  Accuracy : 0.893


In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 on the test set, under a title line.
print("Test Classification Report ", classification_report(Y_test, test_preds), sep="\n")

Test Classification Report 
              precision    recall  f1-score   support

         0.0       0.79      0.91      0.84      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.80      0.86      0.83      1000
         3.0       0.87      0.93      0.90      1000
         4.0       0.83      0.81      0.82      1000
         5.0       0.98      0.97      0.97      1000
         6.0       0.81      0.59      0.68      1000
         7.0       0.94      0.96      0.95      1000
         8.0       0.97      0.98      0.98      1000
         9.0       0.96      0.95      0.96      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000

