In [54]:
import os
import datetime
import tensorflow as tf
import warnings
from tensorflow.keras import backend as K, preprocessing, \
                             models, layers, optimizers, \
                             utils, callbacks, initializers, \
                             activations, regularizers, applications, \
                             constraints, Model
import numpy as np
import numpy.random as rand
import scipy.fftpack as fft
import pandas as pd
import cv2 as cv
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm.notebook import tqdm, trange
from jupyterthemes import jtplot

In [4]:
%matplotlib inline
jtplot.reset()
# plt.switch_backend('Agg')
sns.set()
sns.set_context("notebook")
plt.rcParams["axes.axisbelow"] = True
plt.rcParams["figure.figsize"] = (5, 5)
plt.rcParams["figure.dpi"] = 100

## Input pipeline

In [144]:
traindir = "D:/Kenneth/Documents/VIP/Datasets/dr2imagenet/TRAIN"
testdir = "D:/Kenneth/Documents/VIP/Datasets/dr2imagenet/TEST"
sessiondir = datetime.datetime.now().strftime('%Y%m%d')
os.makedirs(sessiondir, exist_ok=True)

img_w, img_h, img_ch = 64, 64, 1
epochs = int(1e6)
batch_size = 10

In [190]:
def parse_image(filename):
    image = tf.io.read_file(filename)
    image = tf.image.decode_png(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [img_h, img_w])
    return image

def prepare_for_training(ds, cache=True):
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
    else:
        ds = ds.cache()
    ds = ds.repeat()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return ds

def tfdata_generator(datagen):
    ds = tf.data.Dataset.from_generator(
        lambda: datagen,
        output_types=(tf.float32),
        output_shapes=tf.TensorShape([None, None, None])
    )
    ds = ds.cache()
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds

In [191]:
filenames = os.listdir(testdir)
train_df = pd.DataFrame({
    'filenames': filenames
})
total_train = len(train_df)
total_train

12

In [192]:
train_datagen = preprocessing.image.ImageDataGenerator(rescale=1/255)
train_gen = train_datagen.flow_from_dataframe(
    train_df,
    testdir,
    x_col='filenames',
    y_col=None,
    target_size=(img_h, img_w),
    color_mode='grayscale',
    class_mode=None,
    batch_size=batch_size,
    shuffle=False,
)

Found 12 validated image filenames.


In [193]:
train_ds = tfdata_generator(train_gen)

## Model setup

PCAN weight update rule:

\begin{equation}
    \Delta \mathbf{w}_m = \gamma \left( x_m^\prime \mathbf{x} - x_m^\prime \mathbf{w}_m - \sum_{j=1}^{m-1} x_m^\prime x_j^\prime \mathbf{w}_j \right)
\end{equation}


In [194]:
model = tf.keras.Sequential([
    layers.Flatten(input_shape=(img_h, img_w, img_ch)),
    layers.Dense(
        units=img_h*img_w*img_ch,
        activation='linear'
    ),
    layers.Reshape(target_shape=(img_h, img_w, img_ch))
])

model.summary()

Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_13 (Flatten)         (None, 4096)              0         
_________________________________________________________________
dense_13 (Dense)             (None, 4096)              16781312  
_________________________________________________________________
reshape_13 (Reshape)         (None, 64, 64, 1)         0         
Total params: 16,781,312
Trainable params: 16,781,312
Non-trainable params: 0
_________________________________________________________________


In [195]:
def l2loss(y_true, y_pred):
    return tf.nn.l2_loss(y_pred - y_true)

def ssim(y_true, y_pred):
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

In [196]:
class StopOnValue(callbacks.Callback):
    def __init__(self, 
                 monitor='val_loss', 
                 value=0.00001, 
                 mode='min',
                 verbose=0):
        super(callbacks.Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose
        self.mode = mode
        if self.mode == 'min':
            self.compare_op = np.less
        elif self.mode == 'max':
            self.compare_op = np.greater
    
    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Early stopping requires %s available!' % self.monitor, 
                          RuntimeWarning)
            
        if self.compare_op(current, self.value):
            if self.verbose > 0:
                print('Epoch %05d: early stopping THR' % epoch)
            self.model.stop_training = True

## Training

In [199]:
optimizer = optimizers.SGD(learning_rate=1e-3)
stopval = StopOnValue(
    monitor='loss',
    value=0.1
)

In [200]:
model.compile(
    optimizer, 
    loss=l2loss, 
    metrics=[ssim]
)

In [201]:
history = model.fit(
    train_gen,
    epochs=5, 
    verbose=1, 
    steps_per_epoch=total_train//batch_size,
)

Train for 1 steps
Epoch 1/5


IndexError: tuple index out of range