# Test the trained model on satellite test images

In [6]:

# Reload dti_util automatically when its source is edited.
# Mode 1 only reloads the modules registered with %aimport.
%load_ext autoreload
%autoreload 1

%aimport dti_util


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import gc
import glob
import os

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import psutil
import yaml
import autokeras as ak #Necessary to import otherwise load_model does not work
from tensorflow.keras.models import load_model
from tqdm.notebook import tqdm

from dti_util import default, stack_training, unstack_training, clean_ax


In [8]:
# Notebook-wide matplotlib defaults:
# - saved figures are 150 dpi, tightly cropped with a small margin, on white;
# - axis labels use the same font size as axes titles.
# (plt.rcParams and mpl.rcParams are the same object.)
mpl.rcParams.update({
    'savefig.dpi': 150,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
    'axes.labelsize': mpl.rcParams['axes.titlesize'],
    'savefig.facecolor': 'white',
})


In [9]:
# Name of the experiment (sub-directory of rootdir)
name = 'exp-sr-sat-hr-smallpatch'

# Name of the model (sub-directory of the experiment directory)
model_name = "long"

# Root of all directories used. Defaults to the original shared-storage
# location; set the DTI_ROOTDIR environment variable to run elsewhere
# without editing the notebook.
rootdir = os.environ.get('DTI_ROOTDIR', '/mnt/sfe-ns9602k/Julien/data')

# Whether to save the figures, and the sub-folder they are written to
savefig = True
figdir = 'figs'


In [10]:
def _read_yml(path):
    """Parse the YAML parameter file at `path` with yaml's FullLoader."""
    with open(path) as fh:
        return yaml.load(fh, Loader=yaml.FullLoader)


def _print_params(params):
    """Print every `key : value` pair of a parameter dict, one per line."""
    for k, v in params.items():
        print(k, ' : ', v)


# Experiment output directory and the model's own sub-directory
expdir = os.path.join(rootdir, name)
model_dir = os.path.join(expdir, model_name)

# Experiment parameters: entries of data_params.yml override dti_util.default
exp_dict = {**default, **_read_yml(os.path.join(expdir, 'data_params.yml'))}
print('---- EXPERIMENT SETTING ----')
_print_params(exp_dict)

# Model (training) parameters
print('\n---- MODEL SETTING ----')
dmod = _read_yml(os.path.join(model_dir, 'model_params.yml'))
_print_params(dmod)

# Satellite-data parameters
print('\n---- SAT SETTING ----')
dsat = _read_yml(os.path.join(expdir, 'sat_params.yml'))
_print_params(dsat)

---- EXPERIMENT SETTING ----
smooth_output  :  4
strides_test  :  2
smooth_drift  :  30
smooth_sic  :  6
smooth_sit  :  30
scale  :  False
epsi  :  None
targetname  :  h
targetfullname  :  SIT
colnames  :  ('e2_0', 'c', 'h', 'e1_0')
datadir  :  /mnt/sfe-ns9602k/Julien/data/anton/shom5km_defor_4cnn
dsize  :  7
end_train  :  400
itest  :  1
name  :  exp-sr-sat-hr-smallpatch
othernames  :  ('c',)
rootdir  :  /mnt/sfe-ns9602k/Julien/data
start_train  :  10
strides  :  20
subd  :  2
th_dam  :  0.0
th_sic  :  0.2
th_sit  :  0.0
traindir  :  /mnt/sfe-ns9602k/Julien/data/exp-sr-sat-hr-smallpatch/train

---- MODEL SETTING ----
epochs  :  300
fname_score_saliency  :  saliency_score
fname_score_shuffle  :  shuffle_score
log_dir  :  /mnt/sfe-ns9602k/.tools/deep-learn-1603280065-tensorboard/autokeras/exp-sr-sat-hr-smallpatch-long
max_trials  :  50
ntrain  :  None
patience  :  15
split_seed  :  1
target_th  :  0.95
test_size  :  0.15
type  :  reg

---- SAT SETTING ----
datadir  :  /mnt/sfe-ns9602k/J

In [11]:
# Satellite inputs and predictions live under dsat['path']; figures are
# written to the `figdir` sub-folder of it.
sat_path = dsat['path']
fig_path = os.path.join(sat_path, figdir)

# exist_ok=True replaces the isdir() check, avoiding the check-then-create
# race; a regular file at fig_path still raises FileExistsError.
os.makedirs(fig_path, exist_ok=True)

In [12]:
# Colour-scale limits (vmin, vmax) for each model variable
e1lim = (-5e-7,5e-7)
e2lim = (0.,1e-6)
lims = {
    'log_deformation_0': (-4, 1),
    'log_deformation_1': (-4, 1),
    'h':(exp_dict['th_sit'],3),
    'c':(0,1),
    'd':(exp_dict['th_dam'], 1.),
    'e1_0': e1lim,
    'e1_1': e1lim,
    'e2_0': e2lim,
    'e2_1': e2lim,
}

# Correspondence between model variable names and satellite variable names.
# The divergence/shear names get a '_sar' suffix when dsat['sar'] is truthy.
suff = '_sar' if dsat['sar'] else ''
mod2sat = dict(
    c='sic',
    h='sit',
    e1_0='divergence'+suff,
    e2_0='shear'+suff)
print(mod2sat)

# The same colour limits, keyed by satellite variable name (mapped names only)
lims_sat = {mod2sat[k]:v for k,v in lims.items() if k in mod2sat}
lims_sat

{'c': 'sic', 'h': 'sit', 'e1_0': 'divergence', 'e2_0': 'shear'}


{'sit': (0.0, 3),
 'sic': (0, 1),
 'divergence': (-5e-07, 5e-07),
 'shear': (0.0, 1e-06)}

In [13]:
# Load the trained model. compile=False: the model is only used for
# inference (model.predict) in this notebook.
model_dir = os.path.join(expdir, model_name)
model = load_model(os.path.join(model_dir, model_name), compile=False)
scale = exp_dict['scale']

# Truthiness rather than `is True`, so any true-valued config entry
# (e.g. 1 or a numpy bool) enables scaling.
if scale:
    epsi = exp_dict['epsi']
    th_dam = exp_dict['th_dam']

    print('Scaling the output')
    from dti_util import code_dam, decode_dam

    def norm(x):
        """Normalize a target field with dti_util.code_dam."""
        return code_dam(x, epsi=epsi, vmin=th_dam)

    def denorm(x):
        """Invert the target normalization with dti_util.decode_dam."""
        return decode_dam(x, epsi=epsi, vmin=th_dam)
else:
    print('No scaling')

    def norm(x):
        """Identity: targets are used unscaled."""
        return x

    def denorm(x):
        """Identity: model outputs are used as they are."""
        return x

No scaling


In [14]:
# Print the layer-by-layer architecture (output shapes, parameter counts) of the loaded model
model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 7, 7, 4)]         0         
_________________________________________________________________
cast_to_float32 (CastToFloat (None, 7, 7, 4)           0         
_________________________________________________________________
normalization (Normalization (None, 7, 7, 4)           9         
_________________________________________________________________
conv2d (Conv2D)              (None, 5, 5, 32)          1184      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 5, 5, 32)          9248      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 2, 2, 32)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 2, 2, 32)         

In [15]:
# Turn the filename template into a glob pattern by using '*' for the date,
# then keep the sorted base names of the matching satellite input files.
fname = dsat['tname'].format(date='*')
matches = sorted(glob.glob(os.path.join(sat_path, fname)))
lfiles = [os.path.basename(path) for path in matches]
n = len(lfiles)
print(f'{n} files found')

31 files found


In [16]:
# Subset of lfiles to process; the default covers every file found.
indx = slice(None, n)
lfiles

['sat_input_20210101.npz',
 'sat_input_20210201.npz',
 'sat_input_20210301.npz',
 'sat_input_20210401.npz',
 'sat_input_20210501.npz',
 'sat_input_20210601.npz',
 'sat_input_20210701.npz',
 'sat_input_20210801.npz',
 'sat_input_20210901.npz',
 'sat_input_20211001.npz',
 'sat_input_20211101.npz',
 'sat_input_20211201.npz',
 'sat_input_20211301.npz',
 'sat_input_20211401.npz',
 'sat_input_20211501.npz',
 'sat_input_20211601.npz',
 'sat_input_20211701.npz',
 'sat_input_20211801.npz',
 'sat_input_20211901.npz',
 'sat_input_20212001.npz',
 'sat_input_20212101.npz',
 'sat_input_20212201.npz',
 'sat_input_20212301.npz',
 'sat_input_20212401.npz',
 'sat_input_20212501.npz',
 'sat_input_20212601.npz',
 'sat_input_20212701.npz',
 'sat_input_20212801.npz',
 'sat_input_20212901.npz',
 'sat_input_20213001.npz',
 'sat_input_20213101.npz']

In [17]:
# The date starts where the '{date}' placeholder sits in the filename template
# and is 8 characters long: YYYYDDMM (e.g. sat_input_20213101.npz).
idate = dsat['tname'].index('{')
dateind = slice(idate, idate + len('YYYYDDMM'))

In [18]:
# Run the model on every selected satellite input file and save the
# reassembled prediction next to it (dsat['outname']).
dsize = exp_dict['dsize']
strides_test = dsat['strides_test']
subd = dsat['subd']
proc = psutil.Process()

for file in tqdm(lfiles[indx]):
    date = file[dateind]
    fname = dsat['tname'].format(date=date)
    with np.load(os.path.join(sat_path, fname)) as data:
        Xtest = data['Xtest']
        mask_test = data['mask_test']
        ny = data['ny']
        nx = data['nx']
    # No output is written for dates without input patches; the plotting
    # cell skips dates whose output file is missing.
    if Xtest.size > 0:
        ypredict_tmp = denorm(model.predict(Xtest))
        # Reassemble the per-patch predictions on the (ny, nx) grid
        # (presumably; see dti_util.unstack_training).
        X2, y2_pred = unstack_training(Xtest, ypredict_tmp, mask_test, ny=ny, nx=nx,
                                       subd=subd, strides=strides_test, squeezey=False)
        np.save(os.path.join(sat_path, dsat['outname'].format(date=date)), y2_pred.squeeze())
        del X2, y2_pred, ypredict_tmp
    del Xtest, mask_test
    # In the recorded run RSS grew steadily (~2.6 -> 4.0 GB over 20 files);
    # collect explicitly so the arrays freed above are released between files.
    gc.collect()
    tqdm.write(f'{fname}: {proc.memory_info().rss / (1024 * 1024):.1f} MiB')

HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))

sat_input_20210101.npz
2602.53125
sat_input_20210201.npz
2744.40234375
sat_input_20210301.npz
2751.84375
sat_input_20210401.npz
2861.48828125
sat_input_20210501.npz
2871.359375
sat_input_20210601.npz
2982.6015625
sat_input_20210701.npz
3092.875
sat_input_20210801.npz
3094.578125
sat_input_20210901.npz
3202.46875
sat_input_20211001.npz
3202.89453125
sat_input_20211101.npz
3309.8125
sat_input_20211201.npz
3417.3984375
sat_input_20211301.npz
3419.265625
sat_input_20211401.npz
3528.6328125
sat_input_20211501.npz
3528.60546875
sat_input_20211601.npz
3645.37109375
sat_input_20211701.npz
3754.4609375
sat_input_20211801.npz
3755.37890625
sat_input_20211901.npz
3866.41015625
sat_input_20212001.npz
3978.6640625
sat_input_20212101.npz
3979.56640625
sat_input_20212201.npz
2690.5859375
sat_input_20212301.npz
2802.28125
sat_input_20212401.npz
2913.9296875
sat_input_20212501.npz
2914.25
sat_input_20212601.npz
3026.5234375
sat_input_20212701.npz
3139.4375
sat_input_20212801.npz
3139.859375
sat_input_2

In [19]:

# Satellite names of the model input channels (e.g. 'h' -> 'sit')
colnames = [mod2sat[c] for c in exp_dict['colnames']]
nc = len(colnames)

# Pixel window shown in every panel. ylim=(200, 700) puts row 200 at the bottom,
# so images are drawn flipped relative to imshow's default (kept as before).
zoom = (200, 700)
# Colour range of the prediction-minus-input SIT difference map
diff_lim = (-.8, .8)
# Input channel holding the SIT field (used for the difference map)
ih = colnames.index('sit')
# Bottom row: one panel per input channel. At least 2 columns, so the top
# row can be split between the NN map and the difference map.
ncols = max(nc, 2)
half = ncols // 2


def _show_panel(fig, ax, img, title, cmap, vmin, vmax, extend=None, zoom=zoom):
    """Draw `img` on `ax` cropped to `zoom`, with a title and a horizontal colorbar.

    `extend` is passed to `fig.colorbar` only when given, so matplotlib's
    default is used otherwise.
    """
    co = ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_xlim(zoom)
    ax.set_ylim(zoom)
    ax.set_title(title)
    clean_ax([ax])
    cbar_kw = {} if extend is None else {'extend': extend}
    fig.colorbar(co, ax=ax, orientation='horizontal', **cbar_kw)


for file in tqdm(lfiles[indx]):
    date = file[dateind]
    # Dates without a saved prediction had no valid input: skip them before
    # loading the (unneeded) input image.
    fname = dsat['outname'].format(date=date)
    pred_file = os.path.join(dsat['path'], fname)
    if not os.path.isfile(pred_file):
        continue
    y2_pred = np.load(pred_file)
    fname = dsat['timname'].format(date=date)
    Xim = np.load(os.path.join(dsat['path'], fname))

    fig = plt.figure(figsize=(5*nc, 17))
    gs = gridspec.GridSpec(nrows=2, ncols=ncols, height_ratios=[2, 1])

    # Bottom row: the model input channels
    for ic, c in enumerate(colnames):
        vmin, vmax = lims_sat[c]
        _show_panel(fig, fig.add_subplot(gs[1, ic]), Xim[0, ..., ic], c, 'jet', vmin, vmax)

    # Top left: SIT predicted by the network
    vmin, vmax = lims_sat['sit']
    _show_panel(fig, fig.add_subplot(gs[0, :half]), y2_pred, 'NN SIT', 'jet', vmin, vmax)

    # Top right: prediction minus input SIT
    _show_panel(fig, fig.add_subplot(gs[0, half:]), y2_pred - Xim[0, ..., ih],
                'HR minus LR', 'bwr', diff_lim[0], diff_lim[1], extend='both')

    fig.suptitle(date)
    if savefig:
        fig.savefig(os.path.join(fig_path, f'SIT-NN-{date}.png'))
    else:
        plt.show()
    # Close in both cases so figures do not accumulate over the loop
    plt.close(fig)

HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))


