In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise, squeeze_patches, read_div2k_data

#from model_global_dfn import GDFN
from model_baseline import Unet
from model_mwcnn import MWCNN
from model_mwkpn import MWKPN
from model_kpn import KPN, LossFunc, LossBasic

# Report the TensorFlow version and whether a GPU is visible to it.
# NOTE(review): tf.test.is_gpu_available() is reported as deprecated in recent
# TensorFlow 2.x releases (tf.config.list_physical_devices('GPU') is the suggested
# replacement) — confirm against the installed version before switching.
gpu_ok = tf.test.is_gpu_available()
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

# Analyse de l'influence de l'ondelette - KPN - div2k

In [None]:
# Data preparation: seed NumPy, load the 'div2k' patches and add the extra axes.
seed = 42
np.random.seed(seed)

(train_X_p, train_Y_p), (test_X_p, test_Y_p) = read_data('div2k')
N_ims = len(train_X_p)

# squeeze_patches returns two values; the second one is kept under a label_* name.
train_X_p, label_train_X_p = squeeze_patches(train_X_p)
train_Y_p, label_train_Y_p = squeeze_patches(train_Y_p)
test_X_p, label_test_X_p = squeeze_patches(test_X_p)
test_Y_p, label_test_Y_p = squeeze_patches(test_Y_p)

# Inputs get a new axis at position 1, targets get a new trailing axis.
train_X_p = train_X_p[:, None]
train_Y_p = train_Y_p[..., None]
test_X_p = test_X_p[:, None]
test_Y_p = test_Y_p[..., None]

print('\nTrain data:')
for caption, arr in (('train_X_p:', train_X_p), ('train_Y_p:', train_Y_p)):
    print(caption, arr.shape)

print('\nTest data:')
for caption, arr in (('test_X_p:', test_X_p), ('test_Y_p:', test_Y_p)):
    print(caption, arr.shape)

In [None]:
# When False, only channel 0 of the inputs is kept (as a size-1 trailing axis);
# when True, the inputs are left as loaded.
use_noise_map = False
#use_noise_est = False   # if True, use a model to estimate noise map, if False, use known info

if not use_noise_map:
    train_X_p = train_X_p[..., :1]
    test_X_p = test_X_p[..., :1]

print('Train data:')
for caption, arr in (('train_X_p:', train_X_p), ('train_Y_p:', train_Y_p)):
    print(caption, arr.shape)

print('\nTest data:')
for caption, arr in (('test_X_p:', test_X_p), ('test_Y_p:', test_Y_p)):
    print(caption, arr.shape)

In [None]:
# Use tf.data API to shuffle and batch data.
batch_size = 16

# Training pipeline: repeat, shuffle with a 5000-element buffer, batch, prefetch 1 batch.
train_dataset = tf.data.Dataset.from_tensor_slices((train_X_p,train_Y_p))
train_dataset = train_dataset.repeat().shuffle(5000).batch(batch_size).prefetch(1)

# Test pipeline: batched in the stored order (no shuffle, no repeat), so
# test_dataset.take(1) always yields the same first batch.
test_dataset = tf.data.Dataset.from_tensor_slices((test_X_p,test_Y_p))
test_dataset = test_dataset.batch(batch_size).prefetch(1)

In [None]:
# KPN restricted to a single 3x3 kernel size, with spatial attention and bias terms enabled.
model = KPN(
    color=False, burst_length=1, blind_est=True, sep_conv=False, kernel_size=[3],
    channel_att=False, spatial_att=True, core_bias=True, use_bias=True,
)
# Alternative kept for reference (same flags, three kernel sizes):
#model = MWKPN(color=False, burst_length=1, blind_est=True, sep_conv=False, kernel_size=[3,5,7],
#             channel_att=False, spatial_att=True, core_bias=True, use_bias=True)

sub_dir = 'transfer_to_div2k'
filename = 'kpn_ks3_satt_bias_combinedsymetricloss'

# Restore the checkpoint stored under model_weights/<sub_dir>/<filename>.ckpt.
load_model = True
if load_model:
    weights_path = "model_weights/{}/{}.ckpt".format(sub_dir, filename)
    model.load_weights(filepath=weights_path)

## Analyse the kernels

In [None]:
# Run the model on the first test batch and average the predicted 3x3 kernels.
# NOTE: the loop variables and batch_size/N/height/width/color stay defined as
# notebook globals after the loop (batch_size overwrites the value used for batching).
for step, (batch_test_X, batch_test_Y) in enumerate(test_dataset.take(1)):
    # The model is called with the batch and with channel 0 of the batch
    # (re-expanded to a trailing axis); the third return value is the raw kernel output.
    pred_test_Y, _, core = model(batch_test_X, tf.expand_dims(batch_test_X[...,0], axis=-1))
    
    batch_size, N, height, width, color = tf.expand_dims(batch_test_X[...,0], axis=-1).shape 
    # _convert_dict returns (core, bias); core is indexed by kernel size below.
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)
    #print(core[3].shape, core[5].shape, core[7].shape, bias.shape)
    
    # Average over the first three axes left after squeezing, then view the 9 values as 3x3.
    core3 = tf.reshape(tf.reduce_mean(tf.squeeze(core[3]), [0,1,2], keepdims=False), [3,3]).numpy()
    #core5 = tf.reshape(tf.reduce_mean(tf.squeeze(core[5]), [0,1,2], keepdims=False), [5,5]).numpy()
    #core7 = tf.reshape(tf.reduce_mean(tf.squeeze(core[7]), [0,1,2], keepdims=False), [7,7]).numpy()
    print(core3.shape)
    #print(core5.shape)
    #print(core7.shape)


# Show the batch-averaged 3x3 kernel in the first slot of a 1x3 layout.
# The other two slots were used for 5x5 / 7x7 kernels, which this 3x3-only model does not produce.
fig = plt.figure(figsize=(30, 10))
ax = fig.add_subplot(1, 3, 1)
ax.imshow(core3, cmap='gray')
ax.set_title('3x3 filter')

#plt.savefig('./eval/' + sub_dir + '/kpn357_symetricloss/kernels.png')
plt.show()

## Analyse the filters applied to the images

In [None]:
# Re-run the model on the first test batch; pred_test_Y, core, bias and the shape
# variables stay defined as notebook globals and are read by the following cells.
for step, (batch_test_X, batch_test_Y) in enumerate(test_dataset.take(1)):
    pred_test_Y, _, core = model(batch_test_X, tf.expand_dims(batch_test_X[...,0], axis=-1))
    
    batch_size, N, height, width, color = tf.expand_dims(batch_test_X[...,0], axis=-1).shape 
    # _convert_dict returns (core, bias); core is indexed by kernel size (3 here).
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)
    #print(core[3].shape, core[5].shape, core[7].shape, bias.shape)
    print(pred_test_Y.shape)
    print(core[3].shape)

In [None]:
# One panel per kernel coefficient (9 in total): the batch-averaged coefficient map,
# titled with the mean of that coefficient over the whole batch.
# The tensor -> NumPy conversion and the reductions do not depend on the loop index,
# so they are computed once instead of on every iteration.
core3_np = core[3].numpy()
mean_planes = core3_np.mean(axis=0).squeeze()
all_planes = core3_np.squeeze()

plt.figure(figsize = (10, 10))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.imshow(mean_planes[...,i], cmap='gray')
    plt.title('plan {} mean {}'.format(i+1, all_planes[...,i].mean()))
    plt.axis('off')
plt.show()

In [None]:
# 16 rows (one per sample of the batch) in a 16x6 grid; only the first four columns are used:
# noisy input, ground truth, model output, and the kernel averaged over its last axis.
# Columns 5 and 6 were used for 5x5 / 7x7 kernels, which this 3x3-only model does not produce.
plt.figure(figsize = (30,80))
for i in range(16):
    panels = (
        (batch_test_X[i, ...,0], 'noisy image'),
        (batch_test_Y[i], 'ground truth'),
        (pred_test_Y[i], 'recovered image'),
        (tf.reduce_mean(core[3][i], axis=-1),
         'filter 3x3 {:.3f}'.format(tf.reduce_mean(core[3][i]).numpy().squeeze())),
    )
    for col, (img, caption) in enumerate(panels, start=1):
        plt.subplot(16, 6, 6*i+col)
        plt.imshow(img.numpy().squeeze(), cmap='gray')
        plt.title(caption)
        plt.axis('off')

#plt.savefig('./eval/' + sub_dir + '/kpn357_symetricloss/recovered_images.png')
plt.show()

In [None]:
# Histogram of the absolute kernel coefficients of the current batch, in 0.1-wide bins on [0, 1].
bins = 0.1*np.arange(11)

core3_values = core[3].numpy()

plt.figure(figsize = (30,10))

plt.subplot(1, 3, 1)
plt.hist(np.abs(core3_values.flatten().squeeze()), bins=bins)
# The title reports the mean over the whole batch. It previously indexed core[3] with a
# loop variable `i` left over from an earlier cell, i.e. the mean of one arbitrary sample.
plt.title('filter 3x3 {:.3f}'.format(core3_values.mean()))
plt.ylabel('number')
# The x axis holds absolute coefficient values, not mean values.
plt.xlabel('absolute value')

# Slots 2 and 3 of the 1x3 layout were used for 5x5 / 7x7 kernels,
# which this 3x3-only model does not produce.

# plt.savefig('./eval/' + sub_dir + '/kpn357_symetricloss/filter_vdistribution.png')
plt.show()

In [None]:
# 16x9 grid: row i is a sample of the batch, column j is the j-th kernel coefficient map.
plt.figure(figsize = (90,160))
for cell in range(16*9):
    i, j = divmod(cell, 9)
    plane = core[3][i,...,j].numpy().squeeze()
    plt.subplot(16, 9, cell+1)
    plt.imshow(plane, cmap='gray')
    plt.title("mean {:.3f}".format(plane.mean()), fontsize=40)
    plt.axis('off')

#plt.savefig('./eval/' + sub_dir + '/kpn3/filter_3x3.png')
plt.show()

In [None]:
plt.figure(figsize = (30,10))

def normalize(im):
    """Min-max rescale a single array so its values span [0, 1].

    Note: this shadows the `normalize` imported from `utils`. A constant array
    yields a division by zero (no guard, same as before).
    """
    lowest = im.min()
    value_range = im.max() - lowest
    return (im - lowest) / value_range
    
# Element-wise product of one kernel coefficient map with the (unshifted) noisy image,
# for coefficient index 4 (centre tap of a row-major 3x3 kernel) and index 0 (corner tap).
noisy0 = batch_test_X[0, ...,0].numpy().squeeze()

for slot, n in enumerate((4, 0), start=1):
    weighted = core[3][0,...,n].numpy().squeeze() * noisy0
    plt.subplot(1, 3, slot)
    plt.imshow(weighted, cmap='gray')
    plt.title('filtered image No.{} std {:.3f}'.format(n+1, normalize(weighted).std()))
    plt.axis('off')

# Third panel: the model output for the same sample. The panel is titled
# "final filtered image", but it previously displayed the noisy input and reported the
# std of the n=0 product above.
# NOTE(review): assumes pred_test_Y is what "final filtered image" refers to — confirm.
final = pred_test_Y[0].numpy().squeeze()
plt.subplot(1, 3, 3)
plt.imshow(final, cmap='gray')
plt.title('final filtered image std {:.3f}'.format(normalize(final).std()))
plt.axis('off')

plt.show()

# Analyse the kernels

In [None]:
def normalize(ims):
    """Min-max rescale every item of `ims` independently and stack the results.

    Each item is mapped to [0, 1] using its own minimum and maximum; a constant
    item yields a division by zero (no guard). Redefines the single-array
    `normalize` from an earlier cell with batch semantics.
    """
    return np.array([(im - im.min()) / (im.max() - im.min()) for im in ims])


# Fraction of the per-pixel kernels of each batch that is kept for the analysis.
N_rate = 0.02

# Collect a random subset of predicted 3x3 kernels over the whole test set.
core3_selected_all = []
for step, (batch_test_X, batch_test_Y) in enumerate(test_dataset):
    pred_test_Y, _, core = model(batch_test_X, tf.expand_dims(batch_test_X[...,0], axis=-1))
    
    batch_size, N, height, width, color = tf.expand_dims(batch_test_X[...,0], axis=-1).shape 
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)
    
    # Flatten to one 9-value row per kernel. This overwrites core[3] with a NumPy array,
    # so after this cell core[3] no longer has a .numpy() method.
    core[3] = tf.reshape(core[3], [-1,9]).numpy()
    # randint draws indices independently, so the same kernel can be picked more than once.
    core3_ind = np.random.randint(core[3].shape[0], size=int(N_rate*core[3].shape[0]))
    core3_selected = core[3][core3_ind]
    
    core3_selected_all.append(core3_selected)
    
# Stack the per-batch samples into a single (n_kernels, 9) array.
core3_selected_all = np.concatenate(core3_selected_all, axis=0)
#core3_selected_all = normalize(core3_selected_all)
print(core3_selected_all.shape)

## Calculate the standard deviation of the kernels

In [None]:
# Standard deviation of the 9 coefficients of every sampled kernel (one value per row),
# computed in a single vectorized call instead of a Python loop over rows.
stds_all = np.std(core3_selected_all, axis=1)

In [None]:
def create_bins(vmin, vmax, n):
    """Return the n + 1 edges of n equal-width bins covering [vmin, vmax].

    The first edge is `vmin` itself; edge k (1 <= k <= n) is vmin + k/n * (vmax - vmin).
    """
    span = vmax - vmin
    return [vmin] + [vmin + k/n*span for k in range(1, n + 1)]

# Histogram of the per-kernel standard deviations in 10 equal-width bins over their range.
bins = create_bins(stds_all.min(), stds_all.max(), 10)
fig, ax = plt.subplots()
ax.hist(stds_all, bins=bins)
ax.set_xlabel('standard deviation')
ax.set_ylabel('numbers of kernels')
#plt.savefig('./eval/' + sub_dir + '/kpn3/kernels_std.png')
plt.show()

## Visualize the kernels

In [None]:
# Show 16 randomly drawn sampled kernels in a 4x4 grid; each title is the row index drawn.
# The indices come from the global NumPy RNG, so the picture depends on how many random
# numbers were consumed before this cell ran.
plt.figure(figsize = (10,10))
for i in range(16):
    ind = np.random.randint(core3_selected_all.shape[0], size=1)[0]
    plt.subplot(4,4,i+1)
    plt.imshow(core3_selected_all[ind].reshape(3,3), cmap='gray')
    plt.title(ind)
    plt.axis('off')
#plt.savefig('./eval/' + sub_dir + '/kpn3/kernels_exemples.png')
plt.show()

## Cluster the kernels

In [None]:
from sklearn.cluster import KMeans

# Cluster the sampled kernels (rows of 9 coefficients) into 3 groups with a fixed seed.
n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
# y_preds holds the cluster index assigned to each row of core3_selected_all.
y_preds = kmeans.fit_predict(core3_selected_all)

In [None]:
def normalize(ims):
    """Rescale each item of `ims` to [0, 1] using that item's own min and max.

    Returns the rescaled items stacked into one array. An identical definition
    already exists earlier in the notebook; a constant item divides by zero.
    """
    rescaled = []
    for item in ims:
        low, high = item.min(), item.max()
        rescaled.append((item - low) / (high - low))
    return np.array(rescaled)

normalize(kmeans.cluster_centers_)

In [None]:
kmeans.cluster_centers_

In [None]:
# One panel per k-means centre, each shown as a 3x3 kernel (one row, n_clusters columns).
plt.figure(figsize = (5*n_clusters,5))
for i in range(n_clusters):
    plt.subplot(1,n_clusters,i+1)
    plt.imshow(kmeans.cluster_centers_[i].reshape(3,3), cmap='gray')
    plt.axis('off')
#plt.savefig('./eval/' + sub_dir + '/kpn3/kernels_kmeans3.png')
plt.show()

## PCA/ICA

In [None]:
from sklearn.decomposition import PCA, FastICA
from sklearn.mixture import GaussianMixture as GMM

# Project the sampled kernels onto their first two principal components.
pca = PCA(n_components = 2)
pca.fit(core3_selected_all)

core3_proj = pca.transform(core3_selected_all)
print(core3_proj.shape)

# Cluster labels are only used to colour the scatter plot.
n_clusters = 10
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
#kmeans = GMM(n_components=n_clusters, random_state=0)
kmeans.fit(core3_selected_all)

# plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib releases no longer provide.
# The number of colours follows n_clusters instead of a hard-coded 10.
plt.figure(figsize = (8,8))
plt.scatter(core3_proj[:,0], core3_proj[:,1], edgecolor='none', c = kmeans.predict(core3_selected_all), alpha = 0.5, cmap=plt.get_cmap('nipy_spectral', n_clusters))
#plt.savefig('./eval/' + sub_dir + '/kpn3/kernels_pca.png')
plt.show()

In [None]:
# NOTE: despite its name, `kmeans` is rebound here to a Gaussian mixture model
# (GMM is sklearn's GaussianMixture), fitted on the same sampled kernels.
kmeans = GMM(n_components=n_clusters, random_state=0)
kmeans.fit(core3_selected_all)

# Displayed as the cell output: the mixture component assigned to each kernel.
kmeans.predict(core3_selected_all)

In [None]:
# Full 9-component PCA of the sampled kernels; plot the running sum of the explained
# variance (absolute values, not ratios) against the number of components kept.
pca_full = PCA(n_components=9).fit(core3_selected_all)
cum_var = pca_full.explained_variance_.cumsum()

axis_x = np.arange(1,10)
fig, ax = plt.subplots()
ax.plot(axis_x, cum_var)
ax.set_xlabel('reduced dimensionality')
ax.set_ylabel('cumulative variance')
plt.show()

# SVD

In [None]:
from scipy.linalg import svd

# Thin SVD of the (n_kernels, 9) matrix of sampled kernels.
U, s, Vh = svd(core3_selected_all, full_matrices=False)
print(U.shape, s.shape, Vh.shape)
print(s)

# Keep the K leading singular triplets and rebuild the rank-K approximation.
K = 1
U_new, Vh_new = U[:, :K], Vh[:K, :]
s_new = np.diag(s[:K])
print()
print(U_new.shape, s_new.shape, Vh_new.shape)

core3_selected_all_svd = U_new @ s_new @ Vh_new
print(core3_selected_all_svd.shape)

In [None]:
# One panel per singular direction: the row-average of the rank-1 matrix U[:, i] Vh[i, :]
# (the singular value s[i] is not applied), shown as a 3x3 kernel.
# The row-average of that outer product equals mean(U[:, i]) * Vh[i, :], so the
# (n_kernels, 9) matrix no longer has to be materialised for each component.
plt.figure(figsize = (10,10))
for i in range(9):
    plt.subplot(3,3,i+1)
    plt.imshow((U[:,i].mean() * Vh[i,:]).reshape(3,3), cmap='gray')
    plt.axis('off')
plt.show()

# Analyse the clusters

In [None]:
# Cluster every per-pixel kernel of the first test batch into 5 groups.
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters, random_state=0)

core3_all = []
for step, (batch_test_X, batch_test_Y) in enumerate(test_dataset.take(1)):
    pred_test_Y, _, core = model(batch_test_X, tf.expand_dims(batch_test_X[...,0], axis=-1))
    
    # height / width set here are read as globals by the patch-extraction cells below.
    batch_size, N, height, width, color = tf.expand_dims(batch_test_X[...,0], axis=-1).shape 
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)
    
    # One row of 9 coefficients per predicted kernel.
    core3 = tf.reshape(core[3], [-1,9]).numpy()
    core3_all.append(core3)

core3_all = np.concatenate(core3_all, axis=0)
# After fitting, kmeans.labels_ holds one cluster index per row of core3_all.
kmeans.fit(core3_all)

In [None]:
# Gather the 3x3 neighbourhood of every pixel: zero-pad the two spatial axes by K//2 and
# stack the K*K shifted views along a new last axis (row-major over the offsets).
# Relies on the notebook globals `height` and `width` set by the previous inference cell.
K = 3
half = K//2
frame_pad = tf.pad(batch_test_X, paddings=[[0,0], [0,0], [half,half], [half,half], [0,0]], mode='constant')
batch_test_X_flatten = tf.stack(
    [frame_pad[:, :, i:i+height, j:j+width, :] for i in range(K) for j in range(K)],
    axis=-1,
)
print(batch_test_X_flatten.shape)

# One row of 9 values per pixel.
batch_test_X_flatten = batch_test_X_flatten.numpy().reshape(-1, 9)
print(batch_test_X_flatten.shape)

In [None]:
# Mean 3x3 input patch of every k-means cluster. kmeans.labels_ (fitted on core3_all) is
# used to index batch_test_X_flatten, which relies on both arrays having one row per
# pixel in the same order.
test_all = dict()
for i in range(n_clusters):
    test_all[i] = batch_test_X_flatten[np.where(kmeans.labels_==i)[0]].mean(axis=0).reshape(3,3)


# Top row: mean input patch per cluster; bottom row: the matching cluster centre (mean kernel).
plt.figure(figsize = (10*n_clusters,20))
for i in range(n_clusters):
    plt.subplot(2, n_clusters, i+1)
    plt.imshow(test_all[i].reshape(3,3), cmap='gray')
    plt.title('mean patch for label {}, mean {:.3f}, std {:.3f}'.format(i, test_all[i].mean(), test_all[i].std()), fontsize=25)
    plt.axis('off')

for i in range(n_clusters):
    plt.subplot(2, n_clusters, i+n_clusters+1)
    plt.imshow(kmeans.cluster_centers_[i].reshape(3,3), cmap='gray')
    plt.title('mean kernel for label {}, mean {:.3f}, std {:.3f}'.format(i, kmeans.cluster_centers_[i].mean(), kmeans.cluster_centers_[i].std()), fontsize=25)
    plt.axis('off')
plt.show()

In [None]:
# Keep at most num_exemples patches / kernels for each of the first three cluster labels.
num_exemples = 100
print(np.where(kmeans.labels_==0)[0].shape, np.where(kmeans.labels_==1)[0].shape, np.where(kmeans.labels_==2)[0].shape)

test_all0 = batch_test_X_flatten[np.where(kmeans.labels_==0)[0][:num_exemples]]
test_all1 = batch_test_X_flatten[np.where(kmeans.labels_==1)[0][:num_exemples]]
test_all2 = batch_test_X_flatten[np.where(kmeans.labels_==2)[0][:num_exemples]]

kernel_all0 = core3_all[np.where(kmeans.labels_==0)[0][:num_exemples]]
kernel_all1 = core3_all[np.where(kmeans.labels_==1)[0][:num_exemples]]
kernel_all2 = core3_all[np.where(kmeans.labels_==2)[0][:num_exemples]]

# 10 rows x 6 columns: for every label a (patch, kernel) pair of columns.
# The label-2 patch column previously plotted test_all1 by mistake; driving the columns
# from these tuples keeps each label's patch and kernel paired.
patch_groups = (test_all0, test_all1, test_all2)
kernel_groups = (kernel_all0, kernel_all1, kernel_all2)

plt.figure(figsize = (20*3,100))
for i in range(10):
    for label, (patches, kernels) in enumerate(zip(patch_groups, kernel_groups)):
        plt.subplot(10, 6, 6*i + 2*label + 1)
        plt.imshow(patches[i].reshape(3,3), cmap='gray')
        plt.title('test(label{})'.format(label), fontsize=25)
        plt.axis('off')

        plt.subplot(10, 6, 6*i + 2*label + 2)
        plt.imshow(kernels[i].reshape(3,3), cmap='gray')
        plt.title('kernel(label{})'.format(label), fontsize=25)
        plt.axis('off')
plt.show()

# Cluster the kernels by the direction of their corresponding patches 

In [None]:
from skimage.filters import sobel_h, sobel_v
from skimage.util import pad
from skimage.feature import hog

def vote_dir(im):
    """Count, per gradient octant, how many pixels of a 3x3 patch point that way.

    The patch (9 values) is reshaped to 3x3, padded by one pixel with edge
    replication, filtered with `sobel_h` and `sobel_v`, and cropped back to 3x3.
    Each of the 9 pixels then casts one vote for one of 8 bins, chosen from the
    signs of the two responses and from which one is larger in magnitude.

    Args:
        im: array with 9 elements (anything `reshape(3, 3)` accepts).

    Returns:
        Length-8 float array of vote counts; the counts sum to 9.

    Boundary handling: a pixel whose `sobel_h` response is exactly 0 with a positive
    `sobel_v` response votes for bin 2, and one with a negative `sobel_h` response
    and a zero `sobel_v` response votes for bin 4. Both previously fell through to
    the last branch and voted for bin 6, the opposite side of the circle. Pixels
    with both responses equal to 0 still vote for bin 7.
    """
    im = pad(im.reshape(3,3), (1,1), 'edge')
    grad_h = sobel_h(im) # response of the horizontal-edge filter (up/down trend)
    grad_v = sobel_v(im) # response of the vertical-edge filter (left/right trend)

    # Drop the padding ring again so that only the 9 original pixels vote.
    grad_h = grad_h[1:-1,1:-1].flatten()
    grad_v = grad_v[1:-1,1:-1].flatten()

    vote = np.zeros((8))
    for gh, gv in zip(grad_h, grad_v):
        if gh > 0 and gv > 0:
            vote[0 if gh > gv else 1] += 1
        elif gh <= 0 and gv > 0:
            vote[2 if abs(gh) < gv else 3] += 1
        elif gh < 0 and gv <= 0:
            vote[4 if gh < gv else 5] += 1
        else:
            # Remaining cases: gh >= 0 and gv <= 0 (including the all-zero gradient).
            vote[6 if gh < abs(gv) else 7] += 1
    return vote

def ims_dir(ims):
    """Return, for every patch in `ims`, the index of the bin with the most votes.

    Each patch is passed to `vote_dir`; ties are resolved by `np.argmax`, i.e.
    the lowest bin index wins.
    """
    return np.array([np.argmax(vote_dir(patch)) for patch in ims])

In [None]:
dirs = ims_dir(batch_test_X_flatten)

In [None]:
# Group the patches and their predicted kernels by the dominant direction bin of the patch.
dirs = np.array(dirs)
batch_test_X_flatten_by_dir = dict()
core3_all_by_dir = dict()
for i in range(8):
    batch_test_X_flatten_by_dir[i] = batch_test_X_flatten[np.where(dirs==i)[0]]
    core3_all_by_dir[i] = core3_all[np.where(dirs==i)[0]]
    
# Top row: mean patch per direction bin; bottom row: mean kernel of the same pixels.
# A bin without any member has an empty group here, so its mean is taken over zero rows.
plt.figure(figsize=(40,10))
for i in range(8):
    plt.subplot(2,8,i+1)
    plt.imshow(batch_test_X_flatten_by_dir[i].mean(axis=0).reshape(3,3), cmap='gray')
    plt.title('patch dir {}'.format(i))
    plt.axis('off')
    
    plt.subplot(2,8,i+9)
    plt.imshow(core3_all_by_dir[i].mean(axis=0).reshape(3,3), cmap='gray')
    plt.title('kernel dir {}'.format(i))
    plt.axis('off')
plt.show()

In [None]:
core3_all_by_dir[2].mean(axis=0).reshape(3,3)

In [None]:
''' test and visualize '''
# Visual sanity check of the gradient computation used by vote_dir, on one patch
# (row 1 of batch_test_X_flatten): pad by one pixel, filter, crop back to 3x3.
p = pad(batch_test_X_flatten[1].reshape(3,3), (1,1), 'edge')
#fd, hog_im = hog(p, orientations = 8, pixels_per_cell = (1,1), cells_per_block = (1,1), visualize = True) 
#Hrr, Hrc, Hcc = hessian_matrix(p, sigma=0.1, mode='constant')
grad_h = sobel_h(p) # response of the horizontal-edge filter (up/down trend)
grad_v = sobel_v(p) # response of the vertical-edge filter (left/right trend)
# Wh = grad_h**2 - grad_v**2
# Wv = 2*grad_h*grad_v
# Wh = Wh[1:-1,1:-1]
# Wv = Wv[1:-1,1:-1]
# Remove the padding ring from the responses and from the patch itself.
grad_h = grad_h[1:-1,1:-1]
grad_v = grad_v[1:-1,1:-1]
p = p[1:-1,1:-1]

# Three of the four slots of a 1x4 layout: the patch and its two filter responses.
plt.figure(figsize=(20,5))
plt.subplot(141)
plt.imshow(p, cmap='gray')
plt.title('patch')
plt.subplot(142)
plt.imshow(grad_h, cmap='gray')
plt.title('vertical gradient')
plt.subplot(143)
plt.imshow(grad_v, cmap='gray')
plt.title('horizontal gradient')
# plt.subplot(144)
# plt.imshow(Hcc, cmap='gray')
plt.show()

# Using the clusters to simulate the kernels

In [None]:
#kernel_size = [3,5,7]

def apply_filtering(frames, core, bias, kernel_size):
    """Apply per-pixel kernels to `frames` and average the results over kernel sizes.

    Args:
        frames: 5-D tensor; axes 2 and 3 are zero-padded and shifted, so they are
            treated as height and width. The original comments give the layout as
            (bs, N, h, w, color).
        core: mapping from kernel size K to the per-pixel weights, multiplied
            element-wise with the (bs, N, h, w, color, K*K) stack of shifted frames.
        bias: value added to the kernel-size-averaged prediction.
        kernel_size: list of kernel sizes. It is processed in reverse order and the
            stack for a smaller kernel is cut out of the previous one, so the list
            is presumably expected in ascending order — TODO confirm.

    Returns:
        (pred_img, pred_img_i): `pred_img_i` is the biased prediction per frame,
        `pred_img` is its mean over axis 1.
    """
    # Take the spatial size from the input itself. The function previously read
    # `height` / `width` from notebook globals, which silently tied it to whichever
    # batch was processed last.
    height, width = frames.shape[2], frames.shape[3]

    img_stack = []
    pred_img = []
    kernel = kernel_size[::-1]
    for index, K in enumerate(kernel):
        if index == 0:
            # Largest kernel first: build the K*K shifted copies of the padded frames.
            frame_pad = tf.pad(frames, paddings=[[0,0], [0,0], [K//2,K//2], [K//2,K//2], [0,0]], mode='constant')
            for i in range(K):
                for j in range(K):
                    img_stack.append(frame_pad[:, :, i:i+height, j:j+width,:])
            img_stack = tf.stack(img_stack, axis=-1)                 # (bs, N, h, w, color, K*K)
        else:
            # Smaller kernel: keep only the central K x K window of the previous stack
            # (row-major indices into the previous, larger window).
            k_diff = (kernel[index-1] - kernel[index]) // 2
            k_chosen = []
            for i in range(k_diff, kernel[index-1]-k_diff):
                k_chosen += [i*kernel[index-1]+j for j in range(k_diff, kernel[index-1]-k_diff)]
            img_stack = tf.convert_to_tensor(img_stack.numpy()[..., k_chosen])
        # Weighted sum over the K*K neighbours of every pixel.
        pred_img.append(tf.reduce_sum(tf.math.multiply(core[K], img_stack), axis=-1, keepdims=False))
    pred_img = tf.stack(pred_img, axis=0)                           # (nb_kernels, bs, N, h, w, color)
    pred_img_i = tf.reduce_mean(pred_img, axis=0, keepdims=False)   # (bs, N, h, w, color)

    pred_img_i += bias

    pred_img = tf.reduce_mean(pred_img_i, axis=1, keepdims=False)          # (bs, h, w, color)
    return pred_img, pred_img_i

In [None]:
# Replace every predicted kernel by the mean of its GMM component and filter the
# first test batch with these clustered kernels.
n_clusters = 3
#kmeans = KMeans(n_clusters=n_clusters, random_state=0)
gmm = GMM(n_components=n_clusters, random_state=0)

for step, (batch_test_X, batch_test_Y) in enumerate(test_dataset.take(1)):
    '''obtain core by kpn model'''
    pred_test_Y, _, core = model(batch_test_X, tf.expand_dims(batch_test_X[...,0], axis=-1))
    batch_size, N, height, width, color = tf.expand_dims(batch_test_X[...,0], axis=-1).shape 
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)
#     print(core[3].shape, core[5].shape, core[7].shape, bias.shape)
    # One row of 9 coefficients per predicted kernel.
    core3_all = tf.reshape(core[3], [-1, 9]).numpy()
    print(core3_all.shape)
    
    '''cluster the kernels by gmm and obtain the labels (or clustered kernels)'''
    #kmeans.fit(core3_all)
    #core3_all_clustered = kmeans.cluster_centers_[kmeans.labels_]  # use kmeans to cluster the kernels
    
    #gmm.fit(core3_all)
    # core3_all_clustered is reused three times: component labels, then the component
    # mean of every kernel, then a {kernel_size: weights} dict for apply_filtering.
    core3_all_clustered = gmm.fit_predict(core3_all)  # use gmm to cluster the kernels
    core3_all_clustered = gmm.means_[core3_all_clustered]
    core3_all_clustered = core3_all_clustered.reshape(batch_size, N, height, width, color, -1)
    core3_all_clustered = dict({3: core3_all_clustered}) # use dict
    print(core3_all_clustered[3].shape)
    
    '''apply filters'''
    # The bias predicted by the model is reused unchanged with the clustered kernels.
    pred_test_Y3_clustered, _ = apply_filtering(batch_test_X, core3_all_clustered, bias, kernel_size = [3])
    print(pred_test_Y3_clustered.shape)
    
#     pred_test_Y5, _ = apply_filtering(batch_test_X, core[5], bias, kernel_size = [5])
#     pred_test_Y7, _ = apply_filtering(batch_test_X, core[7], bias, kernel_size = [7])

In [None]:
# 16 rows (one per sample of the batch) in a 16x6 grid; only the first four columns are used:
# noisy input, ground truth, model output, and the output obtained with the clustered kernels.
# Columns 5 and 6 were used for 5x5 / 7x7 kernels, which this 3x3-only model does not produce.
plt.figure(figsize = (30,80))
for i in range(16):
    panels = (
        (batch_test_X[i, ...,0], 'noisy image'),
        (batch_test_Y[i], 'ground truth'),
        (pred_test_Y[i], 'recovered image'),
        (pred_test_Y3_clustered[i], 'recovered image by clustered kernels'),
    )
    for col, (img, caption) in enumerate(panels, start=1):
        plt.subplot(16, 6, 6*i+col)
        plt.imshow(img.numpy().squeeze(), cmap='gray')
        plt.title(caption)
        plt.axis('off')

# plt.savefig('./eval/' + sub_dir + '/kpn3/recovered_images_by_30clustered_kernels.png')
plt.show()

In [None]:
import math
# Fetch the 3x3 patch around every pixel (zero padding at the image borders).
# Relies on the notebook globals `height` and `width` set by the previous inference cell.
batch_test_X_flatten = []
K = 3
frame_pad = tf.pad(batch_test_X, paddings=[[0,0], [0,0], [K//2,K//2], [K//2,K//2], [0,0]], mode='constant')
for i in range(K):
    for j in range(K):
        batch_test_X_flatten.append(frame_pad[:, :, i:i+height, j:j+width,:])
batch_test_X_flatten = tf.stack(batch_test_X_flatten, axis=-1).numpy().reshape(-1,9)

# Build the dictionaries: patches and kernels grouped by GMM component.
# The component of every kernel is predicted once; it was previously recomputed
# twice per cluster inside the loop.
gmm_labels = gmm.predict(core3_all)
batch_test_X_dict = dict()
core3_all_dict = dict()
for i in range(n_clusters):
#     batch_test_X_dict[i] = batch_test_X_flatten[np.where(kmeans.labels_==i)[0]]
#     core3_all_dict[i] = core3_all[np.where(kmeans.labels_==i)[0]]
    members = np.where(gmm_labels==i)[0]
    batch_test_X_dict[i] = batch_test_X_flatten[members]
    core3_all_dict[i] = core3_all[members]

# Draw the mean patch of every component (top row) and its mean kernel (bottom row).
plt.figure(figsize=(5*n_clusters,10))
for i in range(n_clusters):
    plt.subplot(2,n_clusters,i+1)
    plt.imshow(batch_test_X_dict[i].mean(axis=0).reshape(3,3), cmap='gray')
    plt.title('patch label {}'.format(i), fontsize=20)
    plt.axis('off')
    
    plt.subplot(2,n_clusters,i+n_clusters+1)
    plt.imshow(core3_all_dict[i].mean(axis=0).reshape(3,3), cmap='gray')
    plt.title('kernel label {}'.format(i), fontsize=20)
    plt.axis('off')
plt.show()

In [None]:
'''save the patches and their corresponding kernels to a txt file using pickle'''
import pickle

# Reduce every cluster to its mean patch and its mean kernel.
# The two dictionaries must be distinct objects: the chained `a = b = dict()` used
# before bound both names to the same dict, so each mean patch was immediately
# overwritten by the mean kernel stored under the same key.
batch_test_X_dict_sum = dict()
core3_all_dict_sum = dict()
for i in range(n_clusters):
    batch_test_X_dict_sum[i] = batch_test_X_dict[i].mean(axis=0)
    core3_all_dict_sum[i] = core3_all_dict[i].mean(axis=0)

'''remove nan from dictionaries'''
def remove_nan_from_dict(my_dict):
    """Return a copy of `my_dict` without the entries whose value contains a NaN.

    Every value is iterated element by element and tested with `math.isnan`.
    The input dictionary is left untouched. The number of dropped entries is
    printed. Relies on `math` being imported elsewhere in the notebook.
    """
    nan_keys = [key for key, values in my_dict.items()
                if any(math.isnan(item) for item in values)]
    cleaned = my_dict.copy()
    for key in nan_keys:
        del cleaned[key]
    print(len(nan_keys), 'items removed from the dictionary')
    return cleaned
# Drop the clusters whose mean contains NaN, then pickle both dictionaries.
batch_test_X_dict_sum = remove_nan_from_dict(batch_test_X_dict_sum)
core3_all_dict_sum = remove_nan_from_dict(core3_all_dict_sum)

# The files are binary pickles despite the .txt extension ('b' opens them in binary mode).
for out_path, payload in (("kernels/patches50.txt", batch_test_X_dict_sum),
                          ("kernels/kernels50.txt", core3_all_dict_sum)):
    with open(out_path, "wb") as f:
        pickle.dump(payload, f)
        
# with open('kernels/kernels.txt', 'rb') as f:
#     a = pickle.loads(f.read())  

# Use the patches information to decide which kernels to apply

In [None]:
# Cluster the input patches (not the kernels) with a GMM, give every pixel the mean
# kernel of its patch cluster, and filter the first test batch with those kernels.
n_clusters = 10
gmm = GMM(n_components=n_clusters, random_state=0)

for step, (batch_test_X, batch_test_Y) in enumerate(test_dataset.take(1)):
    # Obtain the per-pixel kernels from the KPN model.
    pred_test_Y, _, core = model(batch_test_X, tf.expand_dims(batch_test_X[...,0], axis=-1))

    batch_size, N, height, width, color = tf.expand_dims(batch_test_X[...,0], axis=-1).shape
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)
#     print(core[3].shape, core[5].shape, core[7].shape, bias.shape)
    core3_all = tf.reshape(core[3], [-1, 9]).numpy()
    print(core3_all.shape)

    # Obtain the 3x3 patch around every pixel (zero padding at the borders).
    batch_test_X_flatten = []
    K = 3
    frame_pad = tf.pad(batch_test_X, paddings=[[0,0], [0,0], [K//2,K//2], [K//2,K//2], [0,0]], mode='constant')
    for i in range(K):
        for j in range(K):
            batch_test_X_flatten.append(frame_pad[:, :, i:i+height, j:j+width,:])
    batch_test_X_flatten = tf.stack(batch_test_X_flatten, axis=-1)
    batch_test_X_flatten = batch_test_X_flatten.numpy().reshape(-1, 9)
    print(batch_test_X_flatten.shape)

    # Cluster the patches with the GMM and obtain one label per pixel.
    labels = gmm.fit_predict(batch_test_X_flatten)

    # Mean kernel of every patch cluster (NaN for a cluster without members).
    core3_all_clustered_dict = dict()
    for i in range(n_clusters):
        core3_all_clustered_dict[i] = core3_all[labels==i].mean(axis=0)
    # Look up the mean kernel of every pixel with one fancy-indexing operation
    # instead of a Python-level loop over all pixels.
    cluster_means = np.stack([core3_all_clustered_dict[k] for k in range(n_clusters)])
    core3_all_clustered = cluster_means[labels]
    core3_all_clustered = core3_all_clustered.reshape(batch_size, N, height, width, color, -1)
    core3_all_clustered = dict({3: core3_all_clustered}) # apply_filtering indexes core by kernel size
    print(core3_all_clustered[3].shape)

    # Apply the simulated kernels, reusing the bias predicted by the model.
    pred_test_Y3_clustered, _ = apply_filtering(batch_test_X, core3_all_clustered, bias, kernel_size = [3])
    print(pred_test_Y3_clustered.shape)
    
#     pred_test_Y5, _ = apply_filtering(batch_test_X, core[5], bias, kernel_size = [5])
#     pred_test_Y7, _ = apply_filtering(batch_test_X, core[7], bias, kernel_size = [7])

In [None]:
# 16 rows (one per sample of the batch) in a 16x6 grid; only the first four columns are used:
# noisy input, ground truth, model output, and the output obtained with the simulated kernels.
# Columns 5 and 6 were used for 5x5 / 7x7 kernels, which this 3x3-only model does not produce.
plt.figure(figsize = (30,80))
for i in range(16):
    panels = (
        (batch_test_X[i, ...,0], 'noisy image'),
        (batch_test_Y[i], 'ground truth'),
        (pred_test_Y[i], 'recovered image'),
        (pred_test_Y3_clustered[i], 'recovered image by simulated kernels'),
    )
    for col, (img, caption) in enumerate(panels, start=1):
        plt.subplot(16, 6, 6*i+col)
        plt.imshow(img.numpy().squeeze(), cmap='gray')
        plt.title(caption)
        plt.axis('off')

#plt.savefig('./eval/' + sub_dir + '/kpn3/recovered_images_by_10clustered_kernels.png')
plt.show()