In [None]:
from google.colab import drive

# Mount Google Drive (Colab only) so the MyDrive/ data directory used further
# down is reachable from this VM.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

import keras
from keras.callbacks import Callback
from keras.layers import Input, Conv2D, Conv2DTranspose, Add, Activation
from keras.models import Model
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import math
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pathlib import Path
import torch

# Show which GPUs TensorFlow can see (empty list means CPU-only).
print('GPU name: ', tf.config.list_physical_devices('GPU'))

# Resolve the data directory from the Drive mount point explicitly instead of
# Path.cwd(), which only points at the mount when the cwd happens to be /content.
base = Path('/content/drive') / 'MyDrive' / 'cvue23' / 'data'
# Fail fast here rather than with a confusing error in a later cell.
if not base.exists():
    raise FileNotFoundError(f'Data directory not found: {base}')
base


In [None]:

# Debugging helper: show a stack of images on a grid.
def plot_image_grid(images_array, grid_width=10, grid_height=10):
    '''Display images on a ``grid_height`` x ``grid_width`` grid, filled row by row.

    Args:
        images_array: array whose first axis indexes the images; each
            ``images_array[i]`` is drawn with ``imshow`` using a gray colormap.
        grid_width: number of columns.
        grid_height: number of rows.

    Raises:
        ValueError: if the number of images differs from
            ``grid_width * grid_height``.
    '''
    if images_array.shape[0] != grid_width * grid_height:
        raise ValueError("The number of images does not match the grid size.")

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without .flat/.flatten().
    fig, axes = plt.subplots(
        grid_height, grid_width, figsize=(grid_width, grid_height), squeeze=False
    )

    for image, ax in zip(images_array, axes.flat):
        ax.imshow(image, cmap='gray', interpolation='none')
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    # Release the figure: this helper is invoked repeatedly during training.
    plt.close(fig)

def load_data(path):
    '''Load an ``.npz`` archive and scale ``x`` and ``y`` from [0, 255] to [0, 1].

    Before scaling, asserts that both ``x`` and ``y`` have a maximum of exactly
    255 (e.g. guards against dividing already-normalized data again).

    Args:
        path: path to an ``.npz`` file containing ``x``, ``y`` and ``labels``.

    Returns:
        Tuple ``(x / 255, y / 255, labels)``.

    Raises:
        AssertionError: if ``y.max()`` or ``x.max()`` is not 255.
    '''
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once the arrays are materialized.
    with np.load(path) as archive:
        x, y, labels = archive['x'], archive['y'], archive['labels']

    assert y.max() == 255, f'expected y to peak at 255, got max {y.max()}'
    assert x.max() == 255, f'expected x to peak at 255, got max {x.max()}'

    return x / 255, y / 255, labels

def encoder(x, num_features, num_layers, residual_every=2):
    '''Build the convolutional encoder half of REDNet.

    A stride-2 3x3 ReLU convolution first halves the spatial resolution, then
    ``num_layers - 1`` stride-1 3x3 ReLU convolutions follow. The output of
    every ``residual_every``-th stride-1 convolution (counting from 1) is
    collected as a skip connection for the decoder.

    Returns:
        ``(x, residuals)``: the final feature tensor and the list of skip
        tensors in the order they were produced.
    '''
    x = Conv2D(num_features, kernel_size=3, strides=2, padding='same', activation='relu')(x)

    skips = []
    for layer_no in range(1, num_layers):
        x = Conv2D(num_features, kernel_size=3, padding='same', activation='relu')(x)
        # Keep layers residual_every, 2*residual_every, ... (1-based).
        if layer_no % residual_every == 0:
            skips.append(x)

    return x, skips

def decoder(x, num_features, num_layers, residuals, residual_every=2):
    '''Build the transposed-convolution decoder half of REDNet.

    Applies ``num_layers - 1`` stride-1 3x3 transposed convolutions, each
    followed by ReLU. It then applies a final stride-2 transposed convolution
    down to a single output channel, with no activation.

    Skip tensors are consumed last-in-first-out. Each is added (before the
    ReLU) at the decoder layer that mirrors the encoder layer that produced it.
    Decoder layer ``i`` mirrors encoder layer ``num_layers - 2 - i``, which the
    encoder keeps when ``(num_layers - 1 - i) % residual_every == 0``.

    Args:
        x: encoder output tensor.
        num_features: filters per hidden transposed convolution.
        num_layers: same value passed to the encoder.
        residuals: skip tensors from the encoder, in production order. The
            caller's list is not modified.
        residual_every: same spacing the encoder used to collect ``residuals``.

    Returns:
        The 1-channel output tensor.

    Raises:
        ValueError: if any skip tensor is left unused.
    '''
    # Work on a copy: popping from the caller's list was a hidden side effect.
    pending = list(residuals)

    for i in range(num_layers - 1):
        x = Conv2DTranspose(num_features, kernel_size=3, padding='same')(x)

        # Mirror of the encoder's `(layer_index + 1) % residual_every == 0`.
        # For residual_every == 2 this selects the same layers as the previous
        # `(i + 1 + num_layers)` test; for other values that test was asymmetric.
        if (num_layers - 1 - i) % residual_every == 0 and pending:
            x = Add()([x, pending.pop()])

        x = Activation('relu')(x)

    if pending:
        raise ValueError('There are unused residual connections')

    # Stride-2 transposed conv upsamples back (undoing the encoder's stride-2
    # conv) and produces the single output channel.
    x = Conv2DTranspose(1, kernel_size=3, strides=2, padding='same')(x)

    return x

def REDNet(num_layers, num_features, channel_size):
    '''Build a REDNet-style residual encoder-decoder with the Keras functional API.

    Args:
        num_layers: convolutions in the encoder (and transposed convolutions
            in the decoder); the model name reports the total, ``2 * num_layers``.
        num_features: filters per hidden layer.
        channel_size: number of input channels; height and width stay dynamic.

    Returns:
        A ``keras.Model`` mapping the input to a 1-channel, ReLU-activated output.

    NOTE(review): the encoder's stride-2 conv followed by the decoder's stride-2
    transposed conv presumably only restores the input size for even H and W;
    odd sizes would likely make the final ``Add`` fail. Confirm.
    '''
    inputs = Input(shape=(None, None, channel_size))

    encoded, skips = encoder(inputs, num_features, num_layers)
    decoded = decoder(encoded, num_features, num_layers, skips)

    # Global skip from the input. A 1x1 transposed conv projects the
    # channel_size input channels down to the decoder's single output channel.
    projected_input = Conv2DTranspose(1, kernel_size=1, padding='same')(inputs)
    summed = Add()([decoded, projected_input])
    outputs = Activation('relu')(summed)

    return Model(inputs=inputs, outputs=outputs, name=f'REDNet{num_layers*2}')

class PredictionCallback(Callback):
    '''Keras callback that periodically plots targets next to model predictions.

    On every epoch whose (0-based) index is a multiple of ``interval``, it runs
    the model on ``x_val`` and shows a 2-row grid via ``plot_image_grid``:
    ``y_val`` in the top row, predictions in the bottom row.

    Args:
        interval: plot every ``interval`` epochs (epoch 0 included).
        x_val: model inputs to predict on.
        y_val: matching targets; a trailing singleton channel axis is dropped
            for plotting.
    '''

    def __init__(self, interval, x_val, y_val):
        super().__init__()
        self.interval = interval
        self.x_val = x_val
        self.y_val = y_val

    @staticmethod
    def _drop_channel_axis(images):
        '''Turn ``(N, H, W, 1)`` into ``(N, H, W)``; other shapes are returned as-is.'''
        images = np.asarray(images)
        # Unlike .squeeze(), this never drops the batch axis when N == 1.
        if images.ndim == 4 and images.shape[-1] == 1:
            return images[..., 0]
        return images

    def on_epoch_end(self, epoch, logs=None):
        # `logs=None` instead of a mutable `{}` default; it is not used here.
        if epoch % self.interval != 0:
            return
        preds = self._drop_channel_axis(self.model.predict(self.x_val, verbose=0))
        targets = self._drop_channel_axis(self.y_val)
        plot_image_grid(np.concatenate([targets, preds]), len(preds), 2)


def load(path: Path, normalize=True):
    '''Load the ``x``, ``y`` and ``labels`` arrays from an ``.npz`` archive.

    Args:
        path: path to the ``.npz`` file.
        normalize: if true, divide ``x`` and ``y`` by 255 (yielding floats);
            ``labels`` is never scaled.

    Returns:
        Tuple ``(x, y, labels)``.
    '''
    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # open; the context manager closes it once the arrays are materialized.
    with np.load(path) as archive:
        x, y, labels = archive['x'], archive['y'], archive['labels']

    if normalize:
        x = x / 255
        y = y / 255

    return x, y, labels

def sort_by_number(path):
    '''Sort key: the integer after the last underscore in ``path``'s stem.

    For example ``Path('crop_12.npz')`` gives 12. Raises ValueError if that
    trailing token is not an integer.
    '''
    *_, number = path.stem.split('_')
    return int(number)


class AOSPyDataset(keras.utils.PyDataset):
    '''Keras dataset that serves batches of samples loaded from ``.npz`` files.

    Each file is read with ``load`` (which divides ``x`` and ``y`` by 255).
    Its ``x`` and ``y`` arrays form one sample, and ``batch_size`` consecutive
    files are stacked into one ``(x_batch, y_batch)`` pair. The last batch may
    be smaller than ``batch_size``. Files are served in the given order.

    Args:
        files: paths of the ``.npz`` files, in the order samples should be served.
        batch_size: number of files per batch (must be >= 1).
        **kwargs: forwarded to ``keras.utils.PyDataset`` (e.g. ``workers``).
    '''

    def __init__(self, files, batch_size, **kwargs):
        super().__init__(**kwargs)
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.files = list(files)
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches; the final one may be partial.
        return math.ceil(len(self.files) / self.batch_size)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f'batch index {idx} out of range for {len(self)} batches')

        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.files))

        xs, ys = [], []
        for path in self.files[low:high]:
            # Distinct names from the accumulators: rebinding x/y to the loaded
            # arrays would clobber the lists being appended to.
            x, y, _ = load(path)
            xs.append(x)
            ys.append(y)

        return np.stack(xs), np.stack(ys)


In [None]:
# Collect the training files and order them numerically by their trailing
# "_<n>" index (plain string sorting would put _10 before _2).
batch_size = 16
paths = sorted(
    (f for f in (base / 'crop_1').iterdir() if f.is_file()),
    key=sort_by_number,
)
# AOSPyDataset requires batch_size; calling it without one raised a TypeError.
aos_dataset = AOSPyDataset(paths, batch_size=batch_size)

In [None]:
# Held-out test split (presumably 100 samples, going by the filename), used as
# validation data and for the prediction callback.
val_file = base / 'part_1_original_test_100.npz'
x_val, y_val, val_labels = load_data(val_file)

In [None]:
# Build and compile the model.
# NOTE(review): channel_size is taken from the last axis of x_val, which
# assumes x_val is (N, H, W, C). Confirm.
model = REDNet(num_layers=5, num_features=64, channel_size=x_val.shape[-1])

opt = keras.optimizers.Adam(learning_rate=1e-4)
loss = keras.losses.MeanSquaredError(reduction="sum_over_batch_size", name="mse")

# Every 10 epochs, plot the first 30 validation targets against predictions.
prediction_callback = PredictionCallback(
    interval=10,
    x_val=x_val[:30],
    y_val=y_val[:30],
)

model.compile(optimizer=opt, loss=loss)


In [None]:
# Train on the AOS dataset. AOSPyDataset already yields (x, y) batches, so it
# goes to fit() directly; wrapping it in a torch DataLoader batched the batches
# a second time. Note that PyDataset batches are served in file order.
history = model.fit(
    x=aos_dataset,
    epochs=1,
    validation_data=(x_val, y_val),
    callbacks=[prediction_callback],
)

# Keras 3 only accepts weight files ending in `.weights.h5`; also make sure the
# target directory exists before writing.
weights_dir = base.parent / 'models'
weights_dir.mkdir(parents=True, exist_ok=True)
model.save_weights(str(weights_dir / 'new.weights.h5'))

# Markers keep single-epoch runs visible (a one-point line draws nothing).
fig, ax = plt.subplots()
ax.plot(history.history['loss'], marker='o', label='train')
ax.plot(history.history['val_loss'], marker='o', label='val')
ax.set(title='MSE Loss', xlabel='epoch', ylabel='loss')
ax.legend(loc='upper left')
plt.show()