In [None]:
# Evaluate the ML model

import tensorflow as tf
from tensorflow import keras
import os
import time

import glob
import numpy as np
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import matplotlib as mpl
from utils import decoder_files_to_tensors, get_best_model_timestamp
from utils import unnormalize_params, assess_decoder
from utils import sample_files
from models import Decoder, mse_loss_encoder, mse_loss_decoder


In [None]:
# Initialize parameters
# data_dir = '/eos/user/k/kiliakis/tomo_data/datasets_decoder_02-12-22'
data_dir = './tomo_data/datasets_decoder_02-12-22'
# Fraction argument passed to sample_files() below (presumably the share of
# test files kept -- confirm in utils.sample_files). Increase for a fuller,
# slower evaluation.
TEST_SAMPLE_FRACTION = 0.01

# Initialize train / test / validation paths
ML_dir = os.path.join(data_dir, 'ML_data')
TESTING_PATH = os.path.join(ML_dir, 'TESTING')
assert os.path.exists(TESTING_PATH), f'Missing testing data: {TESTING_PATH}'

# Load the testing data (not the training data): select a subset of the files
file_names = sample_files(TESTING_PATH, TEST_SAMPLE_FRACTION, keep_every=1)
print(len(file_names))
# Fail early with a clear message instead of a ZeroDivisionError in the
# per-file timing below. len() works for both lists and numpy arrays.
assert len(file_names) > 0, f'No test files selected from {TESTING_PATH}'

# perf_counter is the monotonic clock intended for measuring intervals
# (`time` is already imported in the first cell).
start_t = time.perf_counter()
# read input, divide in features / label, create tensors
x_test, y_test = decoder_files_to_tensors(file_names, normalization='minmax')
total_time = time.perf_counter() - start_t
print(f'Elapsed time: {total_time:.3f}, Per file: {total_time/len(file_names):.3f}')

# VALIDATION_PATH = os.path.join(ML_dir, 'VALIDATION')
# assert os.path.exists(VALIDATION_PATH)

# # Then the validation data
# files = glob.glob(VALIDATION_PATH + '/*.pk')

# # Shuffle them
# np.random.shuffle(files)
# # read input, divide in features/ label, create tensors
# x_valid, y_valid = decoder_files_to_tensors(files)



In [None]:
# Which trained decoder to evaluate: a pinned trial timestamp. Swap in the
# commented call to pick the best trial automatically instead.
timestamp = '2022_12_13_15-37-27'
# timestamp = get_best_model_timestamp('./trials', model='dec')
print(timestamp)

# Each trial keeps its artifacts under ./trials/<timestamp>/{weights,plots}
trial_dir = os.path.join('./trials/', timestamp)
weights_dir, plots_dir = (os.path.join(trial_dir, sub_dir)
                          for sub_dir in ('weights', 'plots'))
assert os.path.exists(weights_dir)
os.makedirs(plots_dir, exist_ok=True)

# Load the saved decoder without its stored training configuration, then
# compile it with an MSE loss so that decoder.evaluate() reports the MSE.
# compile() requires an optimizer; no training happens in this notebook.
decoder_path = os.path.join(weights_dir, 'decoder.h5')
decoder = keras.models.load_model(decoder_path, compile=False)
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
decoder.compile(loss='mse', optimizer=optimizer)


In [None]:
# Evaluate the model on the test data only (the validation data is not loaded;
# see the commented-out validation code in the data-loading cell).
test_loss = decoder.evaluate(x_test, y_test)
print(f'Test loss: {test_loss:.4e}')

# Predictions on the test set, used for the per-pixel error analysis below.
test_pred = decoder.predict(x_test, verbose=0)
# Convert the labels to a numpy array for the elementwise comparisons below;
# asarray avoids a copy when y_test is already an ndarray (e.g. on re-run).
y_test = np.asarray(y_test)
# The error maps subtract predictions from labels elementwise; mismatched
# shapes would broadcast silently or fail obscurely, so check them up front.
assert y_test.shape == test_pred.shape, (
    f'label shape {y_test.shape} != prediction shape {test_pred.shape}')

# Calculate error per variable
# mses = mean_squared_error(y_test, test_pred, multioutput='raw_values')

# Validation evaluation (requires x_valid / y_valid, currently not loaded):
# valid_loss = decoder.evaluate(x_valid, y_valid)
# print(f'Valid loss: {valid_loss:.4e}')
# valid_pred = decoder.predict(x_valid, verbose=False)

In [None]:
%matplotlib inline

# Value range of the labels (helps interpret the error magnitudes below).
print(np.max(y_test))
print(np.min(y_test))

# Spatial size of one decoder output image.
IMAGE_SHAPE = (128, 128)

# Per-pixel errors averaged over all test samples. The residual is computed
# once and released afterwards to avoid keeping a full-size copy in memory.
residuals = y_test - test_pred
mse_image = np.mean(residuals ** 2, axis=0).reshape(IMAGE_SHAPE)   # mean squared error
me_image = np.mean(np.abs(residuals), axis=0).reshape(IMAGE_SHAPE)  # mean absolute error
del residuals

# Single-panel heat map of the mean absolute error
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(10, 8))
ax.set_xticks([])
ax.set_yticks([])
im = ax.imshow(me_image, cmap='jet', aspect='auto')
fig.colorbar(im, ax=ax)
ax.set_title('Mean Abs. Diff.')


In [None]:
%matplotlib inline
# Plot a few random test samples: ground truth, prediction and |difference|.

nrows = 5
# Spatial size of one decoder output image (same 128x128 as the error-map cell).
IMAGE_SHAPE = (128, 128)
# Fixed seed so that a re-run shows the same samples.
SAMPLE_SEED = 42

rng = np.random.default_rng(SAMPLE_SEED)
# Pick `nrows` distinct test images (fewer if the test set is smaller).
n_shown = min(nrows, len(y_test))
sample = rng.choice(len(y_test), size=n_shown, replace=False)

samples_real = y_test[sample]
samples_pred = test_pred[sample]


def _show_panel(ax, image, title, colorbar=False, **imshow_kwargs):
    """Draw `image` (reshaped to IMAGE_SHAPE) on `ax` with the 'jet' colormap.

    Ticks are hidden; `imshow_kwargs` (vmin, vmax, aspect, ...) are passed on
    to `ax.imshow`. When `colorbar` is true a colorbar is attached to `ax`.
    Returns the AxesImage.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    im = ax.imshow(np.reshape(image, IMAGE_SHAPE), cmap='jet', **imshow_kwargs)
    if colorbar:
        ax.figure.colorbar(im, ax=ax)
    ax.set_title(title)
    return im


# One row per sample: True | Predicted | Diff. squeeze=False keeps `axes`
# 2-D even when a single row is drawn.
fig, axes = plt.subplots(ncols=3, nrows=n_shown, figsize=(12, 4 * n_shown),
                         squeeze=False)
for i in range(n_shown):
    # NOTE(review): the +1 shift with vmin=0 / vmax=2 presumably assumes
    # values normalized to [-1, 1] -- confirm against the printed min/max.
    _show_panel(axes[i][0], samples_real[i] + 1, 'True', vmin=0, vmax=2)
    _show_panel(axes[i][1], samples_pred[i] + 1, 'Predicted', vmin=0, vmax=2)
    _show_panel(axes[i][2], np.abs(samples_real[i] - samples_pred[i]), 'Diff',
                colorbar=True, vmin=0, vmax=2, aspect='auto')
# Lay out the figure once, after all panels exist.
fig.tight_layout()
