In [1]:
%config Completer.use_jedi = False

In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime
from tensorflow.keras.layers import Conv2D

from preprocessing.data.imagenet_labels import imagenet_labels

In [3]:
from models.ResBlock import ResBlock

In [4]:
from modules.TFR_load import TFR_load

In [5]:
# Hyper params

BATCH_SIZE     = 8
NUM_OF_CLASS   = 7
LEARNING_RATE  = 1e-4
COSINE_DECAY   = True
NUM_TRAIN_DATA = 60000
NUM_TEST_DATA  = 10000

TR_STEPS_PER_EPOCH = NUM_TRAIN_DATA//BATCH_SIZE
TE_STEPS_PER_EPOCH = NUM_TEST_DATA//BATCH_SIZE

In [6]:
import tensorflow_datasets as tfds

ds_name = 'mnist'
builder = tfds.builder(ds_name)

tr_ds, te_ds = builder.as_dataset(split = ['train', 'test'], shuffle_files = True)

tr_ds = tr_ds.batch(BATCH_SIZE)
print(len(tr_ds))
te_ds = te_ds.batch(BATCH_SIZE)
print(len(te_ds))

7500
1250


In [7]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.conv = Conv2D(
            16,
            (5, 5),
            strides = (1, 1),
            padding='same',
            activation = "relu")
        
        self.ResBlock = ResBlock(out_channel = 32, pooling = 'max')
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(10, activation = 'softmax')
    def call(self, x):
        x = self.conv(x)
        x = self.ResBlock(x)
        x = self.flatten(x)
        x = self.dense(x)
        return x

model = MyModel()

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, epsilon=1e-08)
loss = tf.keras.losses.CategoricalCrossentropy()

for example in tr_ds:
    images, labels = example["image"], example["label"]
    with tf.GradientTape() as tape:
        logits = model(images / 255)
        loss_value = loss(logits, tf.one_hot(labels, 10))
    print(loss_value)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

tf.Tensor(14.329683, shape=(), dtype=float32)
tf.Tensor(14.966236, shape=(), dtype=float32)
tf.Tensor(14.363052, shape=(), dtype=float32)
tf.Tensor(15.004982, shape=(), dtype=float32)
tf.Tensor(14.663286, shape=(), dtype=float32)
tf.Tensor(14.271727, shape=(), dtype=float32)
tf.Tensor(15.438803, shape=(), dtype=float32)
tf.Tensor(14.569505, shape=(), dtype=float32)
tf.Tensor(15.110025, shape=(), dtype=float32)
tf.Tensor(14.663126, shape=(), dtype=float32)
tf.Tensor(12.894722, shape=(), dtype=float32)
tf.Tensor(11.901726, shape=(), dtype=float32)
tf.Tensor(13.78829, shape=(), dtype=float32)
tf.Tensor(14.178573, shape=(), dtype=float32)
tf.Tensor(14.210624, shape=(), dtype=float32)
tf.Tensor(14.399593, shape=(), dtype=float32)
tf.Tensor(13.51449, shape=(), dtype=float32)
tf.Tensor(14.958836, shape=(), dtype=float32)
tf.Tensor(15.484111, shape=(), dtype=float32)
tf.Tensor(12.283693, shape=(), dtype=float32)
tf.Tensor(14.774632, shape=(), dtype=float32)
tf.Tensor(13.73737, shape=(), dtype=

tf.Tensor(10.026777, shape=(), dtype=float32)
tf.Tensor(12.537981, shape=(), dtype=float32)
tf.Tensor(13.659925, shape=(), dtype=float32)
tf.Tensor(9.1675415, shape=(), dtype=float32)
tf.Tensor(12.823554, shape=(), dtype=float32)
tf.Tensor(12.406855, shape=(), dtype=float32)
tf.Tensor(10.739494, shape=(), dtype=float32)
tf.Tensor(13.856304, shape=(), dtype=float32)
tf.Tensor(10.062316, shape=(), dtype=float32)
tf.Tensor(11.2070055, shape=(), dtype=float32)
tf.Tensor(10.062761, shape=(), dtype=float32)
tf.Tensor(12.388039, shape=(), dtype=float32)
tf.Tensor(9.239778, shape=(), dtype=float32)
tf.Tensor(10.469443, shape=(), dtype=float32)
tf.Tensor(13.43004, shape=(), dtype=float32)
tf.Tensor(12.662985, shape=(), dtype=float32)
tf.Tensor(6.1153803, shape=(), dtype=float32)
tf.Tensor(11.487537, shape=(), dtype=float32)
tf.Tensor(11.319624, shape=(), dtype=float32)
tf.Tensor(9.880111, shape=(), dtype=float32)
tf.Tensor(11.590711, shape=(), dtype=float32)
tf.Tensor(11.799498, shape=(), dtype

tf.Tensor(10.620656, shape=(), dtype=float32)
tf.Tensor(11.976242, shape=(), dtype=float32)
tf.Tensor(11.660032, shape=(), dtype=float32)
tf.Tensor(9.91906, shape=(), dtype=float32)
tf.Tensor(9.216521, shape=(), dtype=float32)
tf.Tensor(3.868093, shape=(), dtype=float32)
tf.Tensor(5.6896744, shape=(), dtype=float32)
tf.Tensor(8.415145, shape=(), dtype=float32)
tf.Tensor(8.094677, shape=(), dtype=float32)
tf.Tensor(8.249207, shape=(), dtype=float32)
tf.Tensor(7.158701, shape=(), dtype=float32)
tf.Tensor(10.019398, shape=(), dtype=float32)
tf.Tensor(8.490942, shape=(), dtype=float32)
tf.Tensor(15.454088, shape=(), dtype=float32)
tf.Tensor(5.8039026, shape=(), dtype=float32)
tf.Tensor(7.535574, shape=(), dtype=float32)
tf.Tensor(5.470776, shape=(), dtype=float32)
tf.Tensor(9.189715, shape=(), dtype=float32)
tf.Tensor(8.534654, shape=(), dtype=float32)
tf.Tensor(9.432415, shape=(), dtype=float32)
tf.Tensor(11.343064, shape=(), dtype=float32)
tf.Tensor(5.376523, shape=(), dtype=float32)
tf.

tf.Tensor(8.898828, shape=(), dtype=float32)
tf.Tensor(9.20113, shape=(), dtype=float32)
tf.Tensor(5.5681496, shape=(), dtype=float32)
tf.Tensor(5.6651354, shape=(), dtype=float32)
tf.Tensor(9.4309, shape=(), dtype=float32)
tf.Tensor(9.418917, shape=(), dtype=float32)
tf.Tensor(7.9664326, shape=(), dtype=float32)
tf.Tensor(9.3547535, shape=(), dtype=float32)
tf.Tensor(4.4989433, shape=(), dtype=float32)
tf.Tensor(11.415926, shape=(), dtype=float32)
tf.Tensor(5.6703835, shape=(), dtype=float32)
tf.Tensor(6.445294, shape=(), dtype=float32)
tf.Tensor(7.3075686, shape=(), dtype=float32)
tf.Tensor(7.2636957, shape=(), dtype=float32)
tf.Tensor(4.7930126, shape=(), dtype=float32)
tf.Tensor(7.3937874, shape=(), dtype=float32)
tf.Tensor(3.6849945, shape=(), dtype=float32)
tf.Tensor(7.3069916, shape=(), dtype=float32)
tf.Tensor(8.460479, shape=(), dtype=float32)
tf.Tensor(10.4261, shape=(), dtype=float32)
tf.Tensor(7.2565618, shape=(), dtype=float32)
tf.Tensor(6.5152626, shape=(), dtype=float32)