In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise, squeeze_patches

#from model_global_dfn import GDFN
from model_baseline import Unet
from model_mwcnn import MWCNN
from model_mwkpn import MWKPN
from model_kpn import KPN, LossFunc, LossBasic

# Report the TensorFlow version and whether a GPU is visible to TensorFlow.
# `tf.test.is_gpu_available()` is deprecated; listing the physical GPU devices is
# the supported replacement and does not need to initialise a session.
gpu_ok = len(tf.config.list_physical_devices('GPU')) > 0
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

In [None]:
# Data preparation: load the 'imagenet' images and their noisy versions
# (pair semantics come from `data.read_data` — presumably (clean, noisy); verify there).
ims, ims_noise = read_data('imagenet')

# Unpack the 4-D shape; the last axis size is kept in `color` (channel count).
N_ims, h, w, color = ims.shape
# Cast both arrays to float32, keeping only the first N_ims entries of each
# (truncates `ims_noise` if it is longer than `ims`).
ims, ims_noise = (arr[:N_ims].astype(np.float32) for arr in (ims, ims_noise))

In [None]:
# Hold out a fraction of the (noisy input, clean target) pairs for evaluation;
# the fixed random_state makes the split reproducible.
test_size = 0.1

train_X, test_X, train_Y, test_Y = train_test_split(
    ims_noise, ims, test_size=test_size, random_state=42
)

# Sanity-check shapes, dtypes and value ranges of each split.
for label, arr in (('Training X: ', train_X), ('Training Y: ', train_Y),
                   ('Testing X: ', test_X), ('Testing Y: ', test_Y)):
    print(label, arr.shape, arr.dtype, arr.max(), arr.min())

In [None]:
# Wrap the arrays in tf.data pipelines.
# Training: shuffle *before* repeat() so every element of one epoch is emitted
# before the next epoch starts (shuffling after repeat() mixes epoch boundaries).
# Testing: kept in order and not repeated, so one pass covers the test set once.
batch_size = 16
shuffle_buffer = 5000

train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_X, train_Y))
    .shuffle(shuffle_buffer)
    .repeat()
    .batch(batch_size)
    .prefetch(1)
)

test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_X, test_Y))
    .batch(batch_size)
    .prefetch(1)
)

In [None]:
# Build the global dynamic filter network analysed below and restore its trained weights.
# `num_filters` is also read by the plotting cell, so it must be defined here for
# Restart & Run All to work.
# The top-level `from model_global_dfn import GDFN` is commented out, so import it here.
from model_global_dfn import GDFN

num_filters = 5
model = GDFN(color=False, num_filters=num_filters, channel_att=False, spatial_att=True)

# Alternative model (NOTE(review): the inference cell's commented call suggests it
# takes two inputs, `model(x, x)` — confirm before switching):
#model = KPN(color=False, burst_length=1, blind_est=True, sep_conv=False, kernel_size=[3,5,7],
#            channel_att=False, spatial_att=True, core_bias=True, use_bias=True)

load_model = True
if load_model:
    model.load_weights(filepath="model_weights/global_dfn.ckpt")

# Analyse des filtres dynamiques globaux - ImageNet

In [None]:
# Number of image channels, used to slice the predicted filter cores when plotting.
# `color` normally already holds the integer channel count unpacked from `ims.shape`,
# which must be kept as-is (a grayscale count of 1 must stay 1).
# A boolean flag (the model-constructor convention, `color=False` -> grayscale)
# is mapped to 1 or 3 channels. The cell is idempotent: once `color` is an int,
# re-running it leaves the value unchanged.
if isinstance(color, (bool, np.bool_)):
    color = 3 if color else 1

# Timestamp used to name saved figures.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Run the model on a single test batch. The batch is stored under distinct names so
# the full `test_X` / `test_Y` arrays are not overwritten; overwriting them would
# silently shrink `test_dataset` if the tf.data cell were re-run.
# The model returns the denoised prediction and the predicted filter core.
batch_X, batch_Y = next(iter(test_dataset))
pred_Y, core = model(batch_X)

print(core.shape)

In [None]:
# Show each global dynamic filter, averaged over the batch axis.
# Filter i occupies `color ** 2` consecutive channels of the 4-D `core` tensor.
# NOTE(review): with color == 3 each slice has 9 channels and imshow cannot display
# it — confirm the intended core layout for colour models.
channels_per_filter = color ** 2

# squeeze=False keeps `axes` 2-D so a single filter (num_filters == 1) also works.
fig, axes = plt.subplots(num_filters, 1, figsize=(15, 5 * num_filters), squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    filter_core = core[:, :, :, i * channels_per_filter:(i + 1) * channels_per_filter]
    filter_mean = tf.reduce_mean(filter_core, axis=0)  # average over the batch
    ax.imshow(filter_mean.numpy().squeeze(), cmap='gray')
    ax.set_title(f'Global dynamic filter {i + 1}')
    ax.axis('off')

#fig.savefig('./eval/gdfn_' + current_time + '.png')
plt.show()