In [None]:
from __future__ import absolute_import
import os
import os.path
from shutil import copyfile
import sys
sys.path.append(os.path.abspath('../../'))
# ----------------- import keras tools ----------------------
from keras.models import Model
from keras.layers import Input, Conv2D, Add, Reshape, Lambda, Concatenate, ZeroPadding2D
from keras import backend as K
from keras.utils import plot_model

from mnn.layers import CNNK1D, CNNR1D, CNNI1D, WaveLetC1D, InvWaveLetC1D
from mnn.layers import CNNK2D
from mnn.callback import SaveBestModel
# ---------------- import python packages --------------------
import argparse
import h5py
import numpy as np
import math
from matplotlib import pyplot as plt
import matplotlib.colors
from myplot import myPolarPlot

In [None]:
# Hyper-parameters declared CLI-style but parsed from an empty list, so the
# notebook always starts from these defaults (some are overridden further below).
parser = argparse.ArgumentParser(description='Scattering -- 2D')
parser.add_argument('--epoch', type=int, default=40, metavar='N',
                    help='# epochs for training in each round (default: %(default)s)')
parser.add_argument('--input-prefix', type=str, default='scafullV1N4', metavar='N',
                    help='prefix of input data filename (default: %(default)s)')
parser.add_argument('--alpha', type=int, default=40, metavar='N',
                    help='number of channels for the depth for training (default: %(default)s)')
parser.add_argument('--n-cnn', type=int, default=6, metavar='N',
                    help='number of 1D CNN layers per level (default: %(default)s)')
parser.add_argument('--n-cnn3', type=int, default=5, metavar='N',
                    help='number of 2D CNN layers in the refinement stage (default: %(default)s)')
parser.add_argument('--noise', type=float, default=0, metavar='noise',
                    help='noise on the measure data (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate for the first round (default: %(default)s)')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='batch size (default: %(default)s)')
parser.add_argument('--verbose', type=int, default=2, metavar='N',
                    help='verbose (default: %(default)s)')
parser.add_argument('--output-suffix', type=str, default=None, metavar='N',
                    help='suffix of the output filename (default: %(default)s)')
parser.add_argument('--percent', type=float, default=4./5., metavar='percent',
                    help='fraction of the total data used for training (default: %(default)s)')
parser.add_argument('--initialvalue', type=str, default=None, metavar='filename',
                    help='filename storing the weights of the model (default: %(default)s)')
parser.add_argument('--w-comp', type=int, default=1, metavar='N',
                    help='window size of the compress(default: %(default)s)')
parser.add_argument('--data-path', type=str, default='data/', metavar='string',
                    help='data path (default: %(default)s)')
parser.add_argument('--log-path', type=str, default='logs/', metavar='string',
                    help='log path (default: %(default)s)')
args = parser.parse_args(args=[])

In [None]:
# Notebook-specific overrides of the parser defaults.
args.output_suffix = 'ttT2'
args.input_prefix = 'scafullV2N4'
args.alpha = 40

# Short module-level aliases for the hyper-parameters used below.
N_epoch, alpha = args.epoch, args.alpha
N_cnn, N_cnn3 = args.n_cnn, args.n_cnn3
lr, percent, batch_size = args.lr, args.percent, args.batch_size
noise = args.noise
noise_rate = noise / 100.  # presumably `noise` is a percentage; this is the fraction
input_prefix, output_suffix = args.input_prefix, args.output_suffix
data_path, log_path = args.data_path + '/', args.log_path + '/'

print(f'N_epoch = {N_epoch}\t alpha = {alpha}\t (N_cnn, N_cnn3) = ({N_cnn}, {N_cnn3})\t batch size = {batch_size}')
print(f'lr = {lr:.2e}\t percent = {percent}\t noise = {noise}')
print(f'input_prefix = {input_prefix}\t output suffix = {output_suffix}')

In [None]:
# Create the log directory (including any missing parents); exist_ok avoids the
# check-then-create race and lets the cell be re-run.
os.makedirs(log_path, exist_ok=True)

# The log/model filename encodes the run configuration:
# 'S2d' + input prefix without its first 7 characters (e.g. 'scafullV2N4' -> 'V2N4')
# + number of CNN layers + alpha + noise level + suffix (or the PID if no suffix).
outputfilename  = log_path + 'S2d' + input_prefix[7:] \
    + 'Nc' + str(N_cnn) + 'Al' + str(alpha)
if abs(int(noise) - noise) < 1.e-6:
    # Integral noise levels are written without a decimal point.
    outputfilename += "Ns" + str(int(noise))
else:
    outputfilename += "Ns" + str(noise)
outputfilename += output_suffix or str(os.getpid())
modelfilename   = outputfilename + '.h5'
outputfilename += '.txt'
log_os          = open(outputfilename, "w+")

def output(obj):
    """Print ``obj`` and append its ``str`` form as one line to the run log.

    The log file is flushed after every write so it stays current during long
    training runs (SaveBestModel also reports through this function).
    """
    print(obj)
    log_os.write(str(obj) + '\n')
    log_os.flush()

def outputnewline():
    """Write an empty line to the run log only (nothing is printed) and flush it."""
    log_os.write('\n')
    log_os.flush()

output(f'output filename is {outputfilename}')

In [None]:
# Load measurements (network input) and coefficients (network target).
filenameIpt = data_path + input_prefix + '.h5'
print('Reading data...')
# `[:]` copies each dataset into memory, so the file can be closed right away.
with h5py.File(filenameIpt, 'r') as fin:
    InputArray = fin['measure'][:]
    OutputArray = fin['coe'][:]
Nsamples, Ns, Nd = InputArray.shape
assert OutputArray.shape[0] == Nsamples
Nsamples, Nt, Nr = OutputArray.shape

# Append a copy of the input cyclically shifted by Ns//2 along axis 1 as extra
# columns on axis 2 (doubling Nd), then keep the central half of axis 2.
Nd *= 2
InputArray = np.concatenate(
    [InputArray, np.roll(InputArray, -(Ns // 2), axis=1)], axis=2)
InputArray = InputArray[:, :, Nd//4:3*Nd//4]
print('Reading data finished')
Nsamples, Ns, Nd = InputArray.shape
print(f'Input shape is {InputArray.shape}')
print(f'Output shape is {OutputArray.shape}')

In [None]:
# Polar plotting mesh: angle th in [0, 2*pi] along the Nt axis (Nt+1 nodes)
# and radius rr in [0, 1] along the Nr axis (Nr+1 nodes), both of shape
# (Nr+1, Nt+1) after broadcasting.
t_nodes = np.linspace(0.0, 1.0, Nt + 1)
r_nodes = np.linspace(0, 1, Nr + 1)
x, y = np.broadcast_arrays(t_nodes[np.newaxis, :], r_nodes[:, np.newaxis])
th = 2.0 * np.pi * x
rr = y

In [None]:
# Visual sanity check of one sample: the target on the polar mesh, the target
# as a raw (Nt, Nr) image, and the preprocessed input as an image.
sample_idx = 4
n_col = 3
n_row = 1
plt.figure(figsize=(12, 4))
plt.subplot(n_row, n_col, 1, polar=True)  # a real bool; any non-empty string would also be truthy
plt.gca().pcolormesh(th, rr, np.transpose(OutputArray[sample_idx, :, :], (1, 0)))
plt.thetagrids([])
plt.rgrids([])

plt.subplot(n_row, n_col, 2)
plt.imshow(OutputArray[sample_idx, :, :])
plt.subplot(n_row, n_col, 3)
plt.imshow(InputArray[sample_idx, :, :])
plt.colorbar()
plt.tight_layout()

In [None]:
# Record the problem dimensions in the run log.
output('alpha                   = %d\t' % alpha)
outputnewline()
for line in ('Input data filename     = %s' % filenameIpt,
             "(Ns, Nd)                = (%d, %d)" % (Ns, Nd),
             "(Nt, Nr)                = (%d, %d)" % (Nt, Nr),
             "Nsamples                = %d" % Nsamples):
    output(line)
outputnewline()

In [None]:
# Split sizes: the leading `percent` fraction of the samples is used for training;
# the test set takes the remaining samples, capped at max(n_train, 5000).
n_train = int(Nsamples * percent)
n_left = Nsamples - n_train
n_test = min(n_left, max(5000, n_train))
# n_valid: samples scored by check_result / check_result_mh (also their predict batch size).
n_valid, BATCH_SIZE = 512, batch_size

In [None]:
# Input scaling knob (factor = 1 leaves the inputs unchanged).
# NOTE: InputArray and OutputArray are rescaled in place, so re-running this
# cell without reloading the data applies the scaling a second time.
factor = 1
InputArray *= factor
output(f'factor on the input data is {factor}')
mean_out = 0  # not read by any live code in this notebook

# Divide the targets by half their dynamic range (max - min). The range is
# logged before and after; the post-scaling max_out / min_out / pixel_max are
# the values later cells use for PSNR and clipping.
for rescale in (True, False):
    max_out, min_out = np.amax(OutputArray), np.amin(OutputArray)
    pixel_max = max_out - min_out
    if rescale:
        OutputArray /= 0.5 * pixel_max
    output(f'max / min of the output data are ({max_out:0.2f}, {min_out:0.2f})')

In [None]:
# Per-sample network input/output shapes and the training settings, logged.
n_input, n_output = (Ns, Nd), (Nt, Nr)
output("[n_input, n_output] = [(%d,%d),  (%d,%d)]" % (*n_input, *n_output))
for msg in ("[n_train, n_test, n_valid]   = [%d, %d, %d]" % (n_train, n_test, n_valid),
            "batch size = %d" % BATCH_SIZE,
            "noise rate = %.2e" % noise_rate):
    output(msg)

In [None]:
# Train/test split: the first n_train samples train, the next n_test test.
train_idx = slice(0, n_train)
test_idx = slice(n_train, n_train + n_test)
X_train, Y_train = InputArray[train_idx], OutputArray[train_idx]
X_test, Y_test = InputArray[test_idx], OutputArray[test_idx]

# ---------- add noise on the input data ----------------------
# Multiplicative Gaussian noise with relative standard deviation noise_rate.
# The training draw happens before the test draw (global NumPy RNG, unseeded).
noiseTrain = noise_rate * np.random.randn(n_train, Ns, Nd)
X_train = (1 + noiseTrain) * X_train
noiseTest = noise_rate * np.random.randn(n_test, Ns, Nd)
X_test = (1 + noiseTest) * X_test

In [None]:
# Odd weights 1, 3, ..., 2*Nr-1, one per index along the Nr axis; not used by
# the live PSNR-based metrics.
weight_pixel = 2 * np.arange(Nr) + 1

def PSNR(img1, img2, pixel_max=1.0):
    """Peak signal-to-noise ratio (dB) between two arrays.

    The difference is divided by ``pixel_max`` before squaring; the MSE is
    floored at 1e-10, so identical inputs give 100 dB instead of infinity.
    """
    scaled_err = (img1 - img2) / pixel_max
    mse = np.maximum(np.mean(np.square(scaled_err)), 1.e-10)
    return -10 * math.log10(mse)

def PSNRs(imgs1, imgs2, pixel_max=1.0):
    """Mean PSNR (dB) over a batch of 2D images stacked along axis 0.

    The MSE of each image (over axes 1 and 2) is floored at 1e-10, converted to
    PSNR, and the per-image PSNRs are averaged.
    """
    scaled_err = (imgs1 - imgs2) / pixel_max
    per_image_mse = np.maximum(np.mean(np.square(scaled_err), axis=(1, 2)), 1.e-10)
    return -10 * np.log10(per_image_mse).mean()

def test_data(model, X, Y):
    """Negated mean PSNR of ``model``'s predictions on ``X`` against ``Y``.

    Negated so that lower is better. Uses ``n_valid`` as the predict batch
    size and the global ``pixel_max`` as the PSNR range.
    """
    predictions = model.predict(X, n_valid)
    # Earlier metric (weighted relative error), kept for reference:
    # errs = np.linalg.norm((predictions - Y) * weight_pixel, axis=(1, 2)) / np.linalg.norm((Y+mean_out) * weight_pixel, axis=(1, 2))
    mean_psnr = PSNRs(predictions, Y, pixel_max)
    return -mean_psnr

def check_result(model):
    """Return (train_score, test_score) from ``test_data`` on the first ``n_valid`` samples of each split."""
    head = slice(0, n_valid)
    train_score = test_data(model, X_train[head, ...], Y_train[head, ...])
    test_score = test_data(model, X_test[head, ...], Y_test[head, ...])
    return train_score, test_score

def test_data_mh(model_mh, X, Y):
    """Negated mean PSNR for each of the first two outputs of a multi-output model.

    Both outputs are compared against the same target ``Y``; returns a 2-tuple
    in output order.
    """
    heads = model_mh.predict(X, n_valid)
    # Earlier metric (weighted relative error per head), kept for reference:
    # errs_i = np.linalg.norm((heads[i] - Y) * weight_pixel, axis=(1, 2)) / np.linalg.norm((Y+mean_out)*weight_pixel, axis=(1, 2))
    return tuple(-PSNRs(heads[i], Y, pixel_max) for i in (0, 1))

def check_result_mh(model_mh):
    """Score both heads of ``model_mh`` on the first ``n_valid`` test samples."""
    head = slice(0, n_valid)
    return test_data_mh(model_mh, X_test[head, ...], Y_test[head, ...])

def splitScaling1D(X, alpha):
    """Keras Lambda layer keeping channels [alpha, 2*alpha) of a 3D tensor (the upper half)."""
    upper_half = Lambda(lambda t: t[:, :, alpha:2 * alpha])
    return upper_half(X)


def splitWavelet1D(X, alpha):
    """Keras Lambda layer keeping channels [0, alpha) of a 3D tensor (the lower half)."""
    lower_half = Lambda(lambda t: t[:, :, 0:alpha])
    return lower_half(X)

def Padding_x(x, s):
    """Periodically pad tensor ``x`` by ``s`` entries on both ends of axis 1.

    The last ``s`` slices are prepended and the first ``s`` appended.
    """
    n = x.shape[1]
    wrap_front = x[:, n - s:n, ...]
    wrap_back = x[:, 0:s, ...]
    return K.concatenate([wrap_front, x, wrap_back], axis=1)

def __TriangleAdd(X, Y, alpha):
    """Return ``X`` with ``Y`` added to its channels [alpha, 2*alpha); channels [0, alpha) pass through."""
    untouched = X[:, :, 0:alpha]
    shifted = X[:, :, alpha:2 * alpha] + Y
    return K.concatenate([untouched, shifted], axis=2)

def TriangleAdd(X, Y, alpha):
    """Keras Lambda layer applying ``__TriangleAdd`` to the pair (X, Y)."""
    def _add(pair):
        return __TriangleAdd(pair[0], pair[1], alpha)
    return Lambda(_add)([X, Y])

In [None]:
# ---------------------------------------------------------------------------
# Network definition: builds the symbolic Keras graph only (no training here).
# Produces two tensors used by the models below:
#   Img -- coarse prediction from the multiscale 1D part,
#   Opt -- Img plus a 2D convolutional correction (residual connection).
# ---------------------------------------------------------------------------
bc = 'period'  # bc_padding mode; the per-level CNNK1D loops below pass the same 'period' literally
w_comp = args.w_comp  # second CNNK1D argument of the input layer (presumably its kernel width)
w_interp = w_comp     # same for the final CNNK1D that produces Nr channels
L = math.floor(math.log2(Ns)) - 2  # number of levels
m = Ns // 2**L     # size of the coarse grid
m = 2 * ((m+1)//2) - 1  # largest odd integer <= m; used for the coarsest-level CNNK1D layers
w = 2 * 3    # support of the wavelet function
n_b = 5      # bandsize of the matrix
output("(L, m) = (%d, %d)" % (L, m))

# Input: one (Ns, Nd) array per sample (n_input); first CNNK1D presumably maps
# it to alpha channels -- TODO confirm against mnn.layers.CNNK1D.
Ipt = Input(shape=n_input)
Ipt_c = CNNK1D(alpha, w_comp, activation='linear', bc_padding=bc)(Ipt)

# Forward pass over L levels. Each WaveLetC1D emits 2*alpha channels; the full
# output is kept in bt_list[ll] (index 0 stays None) and only channels
# alpha:2*alpha (splitScaling1D) feed the next level.
bt_list = (L+1) * [None]
b = Ipt_c
for ll in range(1, L+1):
    bt = WaveLetC1D(2*alpha, w, activation='linear', use_bias=False)(b)
    bt_list[ll] = bt
    b = splitScaling1D(bt, alpha)

# (b,t) --> d
# d^L = A^L * b^L
# Coarsest level: N_cnn ReLU CNNK1D layers on the last scaling part.
# (`k` is reused as a throwaway loop index here and below.)
d = b
for k in range(0, N_cnn):
    d = CNNK1D(alpha, m, activation='relu', bc_padding='period')(d)

# d = T^* * (D tb + (0,d))
# Backward pass from level L down to 1: run N_cnn ReLU CNNK1D layers (n_b as the
# second argument) on the stored level output, add the running d into its upper
# alpha channels (TriangleAdd), then apply InvWaveLetC1D with
# Nout = Nt // 2**(ll-1), so after level 1 the length along axis 1 is Nt.
for ll in range(L, 0, -1):
    d1 = bt_list[ll]
    for k in range(0, N_cnn):
        d1 = CNNK1D(2*alpha, n_b, activation='relu', bc_padding='period')(d1)

#     d11 = splitWavelet1D(d1, alpha)
#     d12 = splitScaling1D(d1, alpha)
#     d12 = Add()([d12, d])
#     d = Concatenate(axis=-1)([d11, d12])
#     d = Lambda(lambda x: TriangleAdd(x[0], x[1], alpha))([d1, d])
    # Equivalent of the commented-out split/add/concat above, as one Lambda layer.
    d = TriangleAdd(d1, d, alpha)
    d = InvWaveLetC1D(2*alpha, w//2, Nout=Nt//(2**(ll-1)), activation='linear', use_bias=False)(d)

Img_c = d

# Final CNNK1D with Nr as first argument; Img is reshaped to n_output + (1,)
# below, so it holds Nt*Nr values per sample (the target shape).
Img = CNNK1D(Nr, w_interp, activation='linear', bc_padding=bc)(Img_c)
Img_p = Reshape(n_output+(1,))(Img)
# 2D refinement blocks. Each one wraps axis 1 periodically by one row per side
# (Padding_x), zero-pads axis 2 by one per side (ZeroPadding2D((0, 1))), and
# applies a 3x3 Conv2D with Keras' default 'valid' padding, which restores the
# (Nt, Nr) spatial shape. N_cnn3 - 1 ReLU blocks with 4 filters ...
for k in range(0, N_cnn3-1):
    Img_p = Lambda(lambda x: Padding_x(x, 1))(Img_p)
    Img_p = ZeroPadding2D((0, 1))(Img_p)
    Img_p = Conv2D(4, 3, activation='relu')(Img_p)
    # Img_p = CNNK2D(4, 3, activation='relu', bc_padding=bc)(Img_p)

# ... followed by one linear block with a single filter.
Img_p = Lambda(lambda x: Padding_x(x, 1))(Img_p)
Img_p = ZeroPadding2D((0, 1))(Img_p)
Img_p = Conv2D(1, 3, activation='linear')(Img_p)
# Img_p = CNNK2D(1, 3, activation='linear', bc_padding=bc)(Img_p)
Opt = Reshape(n_output)(Img_p)
# Residual connection: final output = coarse image + 2D correction.
Opt = Add()([Img, Opt])

In [None]:
# plot_model(model, to_file='figeit2dInv.png', show_shapes=True)

In [None]:
# (batch size, learning rate) schedule for the multi-head stage:
# first grow the batch size BATCH_SIZE * 1, 2, 4, 8, 16 at the base rate, then
# keep the largest batch size and decay the rate by sqrt(0.1) per round.
# The largest batch size is named explicitly instead of relying on the loop
# variable `bs` leaking out of the first loop.
bs_exponents = range(0, 5)
lr_bs = [[BATCH_SIZE * 2**bs, lr] for bs in bs_exponents]
largest_bs = BATCH_SIZE * 2**bs_exponents[-1]
lr_bs += [[largest_bs, lr * math.sqrt(0.1)**ll] for ll in range(1, 5)]

print(lr_bs)

In [None]:
# (batch size, learning rate) schedule for the single-output stage:
# batch sizes BATCH_SIZE * 8 and * 16 at the base rate, then the largest batch
# size with the rate decayed by sqrt(0.1) per round. The largest batch size is
# named explicitly instead of relying on the leaked loop variable `bs`.
bs_exponents2 = range(3, 5)
lr_bs2 = [[BATCH_SIZE * 2**bs, lr] for bs in bs_exponents2]
largest_bs2 = BATCH_SIZE * 2**bs_exponents2[-1]
lr_bs2 += [[largest_bs2, lr * math.sqrt(0.1)**ll] for ll in range(1, 5)]

print(lr_bs2)

In [None]:
# Two-output model: the refined prediction Opt and the intermediate image Img,
# both trained with MSE and equal loss weights.
model_multihead = Model(inputs=Ipt, outputs=[Opt, Img])
model_multihead.compile(optimizer='Nadam', loss='mean_squared_error',
                        loss_weights=[1., 1.])
model_multihead.optimizer.schedule_decay = 0.004
output('number of params = %d' % model_multihead.count_params())
# Checkpoint callback scoring both heads on test data (test_weight=0.).
save_best_model_mh = SaveBestModel(
    modelfilename,
    check_result=check_result_mh,
    period=1,
    patience=10,
    output=output,
    test_weight=0.,
    verbose=2,
)

In [None]:
# Stage 1: train the two-output model over the (batch size, lr) schedule lr_bs.
# The first round runs 2 * N_epoch epochs and each later round N_epoch more
# (epoch numbers continue across rounds via initial_epoch). After every round
# the weights stored in `modelfilename` are reloaded.
N_epoch = 20
n_epochs_pre = 0
N_e = n_epochs_pre + 2 * N_epoch
for b_s, l_r in lr_bs:
    # Assign into the optimizer's learning-rate variable. Rebinding the
    # attribute (`optimizer.lr = l_r`) replaces the variable object; in
    # standalone Keras the already-built training function keeps the old
    # value, so the schedule would be silently ignored after the first round.
    K.set_value(model_multihead.optimizer.lr, l_r)
    model_multihead.stop_training = False
    model_multihead.fit(X_train, [Y_train, Y_train], batch_size=b_s, epochs=N_e,
                        initial_epoch=n_epochs_pre, verbose=2, callbacks=[save_best_model_mh])
    n_epochs_pre = N_e
    N_e += N_epoch
    model_multihead.load_weights(modelfilename, by_name=False)  # re-load the best model
    save_best_model_mh.best_epoch_update = n_epochs_pre
    # Quick visual check on the first two test samples: target, Opt head, Img head.
    Yhat_tmp = model_multihead.predict(X_test[0:100, ...], 100)
    Yhat0 = Yhat_tmp[0]
    Yhat1 = Yhat_tmp[1]
    for idx in (0, 1):
        datas = [np.transpose(a[idx, ...], [1, 0]) for a in (Y_test, Yhat0, Yhat1)]
        myPolarPlot(th, rr, datas)

In [None]:
# Restore the best multi-head weights, but only if the multi-head stage ran.
stage1_ran = n_epochs_pre > 0
if stage1_ran:
    model_multihead.load_weights(modelfilename, by_name=False)

In [None]:
# model: final model
model = Model(inputs=Ipt, outputs=Opt)
model.compile(loss='mean_squared_error', optimizer='Nadam')
model.optimizer.schedule_decay = (0.004)
output('number of params = %d' % model.count_params())

In [None]:
# Stage 2: fine-tune the single-output model over the schedule lr_bs2, scoring
# train and test subsets (test_weight=1.) for checkpointing.
N_epoch = 40
save_best_model = SaveBestModel(modelfilename, check_result=check_result, period=1,
                                    patience=10, output=output, test_weight=1., verbose=2)
# Continue the epoch counter from stage 1; start at 0 if stage 1 was skipped.
try:
    n_epochs_pre
except NameError:
    n_epochs_pre = 0

N_e = n_epochs_pre + 2 * N_epoch
for b_s, l_r in lr_bs2:
    print(b_s, l_r)
    # Update the learning-rate variable in place; rebinding `optimizer.lr`
    # is not seen by an already-built training function in standalone Keras.
    K.set_value(model.optimizer.lr, l_r)
    model.stop_training = False
    model.fit(X_train, Y_train, batch_size=b_s, epochs=N_e,
              initial_epoch=n_epochs_pre, verbose=2, callbacks=[save_best_model])
    n_epochs_pre = N_e
    N_e += N_epoch
    model.load_weights(modelfilename, by_name=False)  # re-load the best model
    save_best_model.best_epoch_update = n_epochs_pre
    # Quick visual check on the first two test samples: target, prediction, error.
    Yhat = model.predict(X_test[0:100, ...], 100)
    dY = Yhat - Y_test[0:100, ...]
    for idx in (0, 1):
        datas = [np.transpose(a[idx, ...], [1, 0]) for a in (Y_test, Yhat, dY)]
        myPolarPlot(th, rr, datas)

In [None]:
# Evaluate the final model on the first 100 test samples.
Yhat = model.predict(X_test[0:100, ...], 100)
print(f'PSNR for Opt: {PSNRs(Y_test[0:100, ...], Yhat, pixel_max):.5g}')
# Post-processing: clip predictions to [min_out, max_out], the range of the
# normalized targets (computed over the whole OutputArray).
Yhat = np.clip(Yhat, min_out, max_out)
print(f'PSNR for Opt after post-processing: {PSNRs(Y_test[0:100, ...], Yhat, pixel_max):.5g}')
dY = np.absolute(Yhat - Y_test[0:100, ...])
file_prefix = 'TTN4V1'
for idx in (0, 1, 2, 4, 8):
#     filename = 'figures/' + file_prefix + 'Sampling' + str(idx)
#     myPolarPlot(th, rr, [np.transpose(Y_test[idx, ...], [1, 0])], figsize=(6,6), n_col=1, filename=filename + 'Y.png')
#     myPolarPlot(th, rr, [np.transpose(Yhat[idx, ...], [1, 0])], figsize=(6,6), n_col=1, filename=filename + 'Ypred.png')
#     myPolarPlot(th, rr, [np.transpose(dY[idx, ...], [1, 0])], figsize=(6,6), n_col=1, filename=filename + 'dY.png')
    datas = [np.transpose(a[idx, ...], [1, 0]) for a in (Y_test, Yhat, dY)]
    myPolarPlot(th, rr, datas, colorbar='on', figsize=(12, 3))
    print(f'PSNR of sample {idx} is: {PSNR(Y_test[idx, ...], Yhat[idx, ...], pixel_max): .2f}')

In [None]:
# Evaluate on an external test file.
# NOTE(review): this cell looks carried over from another experiment. It reads
# 'measurement_diff_<data_type>' / 'us' datasets instead of the 'measure' /
# 'coe' datasets used above, `data_type` is never defined in this notebook, and
# the inputs do not get the shift/concatenate/crop preprocessing applied to
# InputArray. Set `data_type` and confirm TestInput has shape (n, Ns, Nd).
test_file = 'rteNus2Nua0SymG9'
n_test_new = 100
filenameTest = data_path + test_file + '.h5'
if 'data_type' not in globals():
    raise NameError("set `data_type` (suffix of the 'measurement_...' datasets) before running this cell")
with h5py.File(filenameTest, 'r') as fin:
    TestInput = fin['measurement_diff_' + data_type][0:n_test_new, :, :]
    TestInput_bg = fin['measurement_' + data_type][:]
    TestOutput = fin['us'][0:n_test_new, :, :]
TestOutput = np.transpose(TestOutput, [0, 2, 1])
# Scale inputs like the training inputs. NOTE(review): the original used an
# undefined `Input_factor`; `factor` is this notebook's input scale -- confirm.
TestInput *= factor
TestInput_bg *= factor
TestOutput -= 1

# ---------- add noise on the input data ----------------------
ns_rate = noise_rate
noiseTest = np.random.randn(n_test_new, Ns, Nd) * ns_rate
TestInput = TestInput * (1 + noiseTest) + TestInput_bg * noiseTest

TestYhat = model.predict(TestInput, min(n_test_new, 100))
TestYhat = np.clip(TestYhat, 0, 1)
dY = TestYhat - TestOutput
for idx in range(0, 10):
    datas = [np.transpose(a[idx, ...], [1, 0]) for a in (TestOutput, TestYhat, dY)]
    myPolarPlot(th, rr, datas)  # rr is the radial mesh (was the undefined `r`)