In [1]:
import sys
# Make the parent directory importable so that the project package `model.*` resolves.
# NOTE(review): assumes this notebook lives one directory below the repository root — confirm.
sys.path.append(r'..') 
import os
# Intel OpenMP runtime switch: tolerate more than one OpenMP runtime being loaded into the
# process instead of aborting. This is an unsafe workaround, and it is set here before the
# numerical/deep-learning imports below are executed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Project modules: network wrapper, data helpers and configuration object.
from model.model import MeNet
from model.Gdata import Data
from model.config import Config
from tqdm import tqdm
import matplotlib.pyplot as plt 
import numpy as np
import json

## Train a new model

In [2]:
def load_json_file(fpath, encoding='utf-8'):
    """Load and return the parsed contents of a JSON file.

    Args:
        fpath: Path to the JSON file.
        encoding: Text encoding used to read the file. Defaults to UTF-8, the
            encoding required for JSON interchange (RFC 8259). A bare ``open()``
            would instead use the platform locale encoding, which can fail on
            non-ASCII content (e.g. on Windows with a GBK/cp1252 locale).

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).
    """
    with open(fpath, 'r', encoding=encoding) as f:
        return json.load(f)

In [None]:

# Directory of the generated dataset; it must contain a config.json file.
file_path = '' # the dataset path you saved
# os.path.join avoids the '' + '/config.json' -> '/config.json' (filesystem root) trap
# and works whether or not file_path already ends with a path separator.
data = load_json_file(os.path.join(file_path, 'config.json'))

print('Modulation:', data['modulate_aber'])
print('Initial Mode Ranges:', data['zernike_amplitude_ranges'])

In [None]:
net_architecture = 'MeNet'  # alternative: 'singleEncoder'
c = Config(
    # Aberration set-up, read from the dataset's config.json
    zernike_amplitude_ranges=data['zernike_amplitude_ranges'],
    modulate_aber=data['modulate_aber'],
    # Network options
    net_architecture=net_architecture,
    isMultiStream=3,  # presumably one stream per input (X1, X2, X3) — confirm in Config
    isRealTime=False,
    # PSF / optics parameters, mirrored from the dataset's config.json
    psf_na_detection=data['psf_na_detection'],
    psf_units=tuple(data['psf_units']),
    psf_n=data['psf_n'],
    psf_lam_detection=data['psf_lam_detection'],
    # Data location and regularisation settings
    dataFile=data['dataFile'],
    isRegular=data['isRegular'],
    regularValue=data['regularValue'],
)
# vars(c)  # uncomment to inspect every configuration field

In [None]:
# Wrap the configuration in a trainable network named 'test'.
# NOTE(review): basedir is presumably where checkpoints/logs for `name` are stored — confirm in MeNet.
model = MeNet(
    config=c,
    name='test',
    basedir='Models/',
)

In [None]:
# Summary of the configuration the model was built with.
# Note: some train_* values are overridden in a later cell, after this summary is printed.
print("MultiStream" if model.config.isMultiStream else "SingleStream")

print("Net_architecture:", model.config.net_architecture)
print("Batch_size:", model.config.train_batch_size)
print("train_steps_per_epoch:", model.config.train_steps_per_epoch)
print("The number of predicted aberrations:", model.config.n_channel_out)
print("Range of predicted aberrations: ", model.config.zernike_amplitude_ranges)
print("Induced bias :", model.config.bias_aber)
print("Net_Learning_rate:", model.config.train_learning_rate)
print("RealTime:", model.config.isRealTime)

In [11]:
# Training hyper-parameters, overriding the values Config was created with.
# NOTE(review): these are applied after MeNet(...) was constructed; confirm that
# model.train() reads them from model.config at call time.
model.config.train_batch_size = 64
model.config.train_steps_per_epoch = 1000
model.config.train_n_val = 32
model.config.train_learning_rate = 0.006


In [None]:
N_EPOCHS = 100  # number of training epochs passed to MeNet.train
model.train(epochs=N_EPOCHS)

In [None]:
# Load the held-out test split.
# NOTE(review): dataFile is concatenated (not path-joined) with the file name, so it is
# presumably a directory path ending in a separator, or a file-name prefix — confirm.
# The context manager closes the underlying .npz file handle. Indexing an NpzFile reads
# each array fully into memory, so the arrays remain valid after the file is closed.
with np.load(model.config.dataFile + 'data_test.npz') as test_data:
    X1_test = test_data['X1_test']
    X2_test = test_data['X2_test']
    X3_test = test_data['X3_test']
    Y_test = test_data['Y_test']

# All inputs and targets must describe the same number of samples (first axis).
assert len(X1_test) == len(X2_test) == len(X3_test) == len(Y_test), \
    "test inputs and targets have mismatched sample counts"

In [None]:
gt = Y_test  # ground-truth amplitudes; row i is compared against prediction row i below
# Multi-input feed for the Keras model; presumably keys match the model's input names — confirm.
X = dict(X1=X1_test, X2=X2_test, X3=X3_test)

In [None]:
# Noll indices of the predicted modes, in the config's key order (keys converted with int()).
zerns_noll = [int(noll_key) for noll_key in model.config.zernike_amplitude_ranges]
# Predicted amplitudes for every test sample.
pre = model.keras_model.predict(X)

In [None]:
# Predicted vs. experimentally introduced amplitude, one panel per Zernike mode.
# Points on the dashed y = x line are perfect predictions.
N_PLOT_COLS = 4
# Extent of the y = x reference line: +/-0.5 plus a 0.05 margin.
# NOTE(review): hard-coded; it could instead be derived from
# model.config.zernike_amplitude_ranges for the plotted mode.
lower_limit, upper_limit = -0.5 - 0.05, 0.5 + 0.05

gt_arr = np.asarray(gt)
pre_arr = np.asarray(pre)
n_modes = gt_arr.shape[1]
# Keep the original 2 x 4 grid for up to 8 modes. Add rows beyond that
# (plt.subplot(2, 4, 9) would raise).
n_rows = max(2, int(np.ceil(n_modes / N_PLOT_COLS)))

fig, axes = plt.subplots(n_rows, N_PLOT_COLS, figsize=(24, 10),
                         facecolor='w', edgecolor='w', squeeze=False)
axes = axes.ravel()
XX = np.linspace(lower_limit, upper_limit, num=50)

for mode_num, ax in enumerate(axes[:n_modes]):
    noll = zerns_noll[mode_num]
    # One vectorised call per panel instead of one ax.plot call per sample.
    ax.plot(gt_arr[:, mode_num], pre_arr[:, mode_num], "o", color='black', markersize=4)
    # y = x reference line, drawn on top of the points.
    ax.plot(XX, XX, ls="--", color="k")

    # Raw f-strings keep the LaTeX backslashes (\mathbf, \mu) without relying on
    # invalid Python escape sequences (warned about since Python 3.6/3.12).
    ax.set_xlabel(rf'Experimentally introduced amplitude for $\mathbf{{Z_{{{noll}}}}}$ / $\mathbf{{\mu m}}$',
                  size=12, fontweight='bold', labelpad=4)
    ax.set_ylabel(rf'Predicted amplitude $\mathbf{{a_{{{noll}}}}}$ / $\mathbf{{\mu m}}$',
                  size=12, fontweight='bold', labelpad=-5)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_linewidth(2)  # thicken the x axis
    ax.spines['left'].set_linewidth(2)    # thicken the y axis
    plt.setp(ax.get_xticklabels() + ax.get_yticklabels(),
             size=14, color="black", fontweight='bold')

# Hide grid cells that have no mode to display.
for ax in axes[n_modes:]:
    ax.set_visible(False)

plt.show()
    
