In [None]:
from load_modules import *
import my_config
import time
# Turn on OpenCV's OpenEXR codec through its opt-in environment flag.
os.environ.update({"OPENCV_IO_ENABLE_OPENEXR": "1"})
# Reset Keras' global state so this run does not inherit models/graphs left in the kernel.
K.clear_session()

In [None]:
# Configure TensorFlow GPU usage when my_config.USE_GPU is set:
# list the visible GPUs, enable on-demand memory growth, select the async CUDA
# allocator, and fail fast if no GPU is usable.
if my_config.USE_GPU:
    # TF reads TF_GPU_ALLOCATOR when it creates the GPU devices, so set it
    # before anything below (gpu_device_name()) initialises them.
    os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

    # Query the GPU list once and reuse it below.
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        for device in gpus:
            print("Device:", device)
    else:
        print("No GPU devices found.")

    # Grow GPU memory on demand instead of reserving it all up front.
    # TF raises RuntimeError if the GPUs were already initialised.
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

    name = tf.test.gpu_device_name()
    if name != '/device:GPU:0':
        raise SystemError('GPU device not found')
    print('Found GPU at: {}'.format(name))
    print("Num GPUs Available: ", len(gpus))
else:
    print("No GPU")

In [None]:
# Load the encoder and decoder models from the paths configured in my_config
# and print their architectures.
encoder = load_model(my_config.ENCODER_PATH)
decoder = load_model(my_config.DECODER_PATH)
print("Loaded models")
# Model.summary() prints the table itself and returns None, so print the label
# first instead of concatenating its return value (which printed "...:None").
print("Encoder summary:")
encoder.summary()
print("Decoder summary:")
decoder.summary()
def encode(img, device='/device:GPU:0'):
    """Run the encoder on every pixel of ``img``.

    Args:
        img: image-like input (PIL image or array). It is flattened with
            ``reshape(-1, 3)`` into an (N, 3) float32 batch, so it is expected
            to hold exactly 3 channels (RGB); the reshape does not check this.
        device: TensorFlow device string to run the prediction on
            (defaults to the first GPU, as before).

    Returns:
        (pred_maps, elapsed): the encoder output for the batch and the
        wall-clock seconds spent in ``predict_on_batch``.
    """
    pixels = np.asarray(img).reshape(-1, 3).astype('float32')
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    with tf.device(device):
        pred_maps = encoder.predict_on_batch(pixels)
    elapsed = time.perf_counter() - start
    return pred_maps, elapsed
 
def decode(encoded, device='/device:GPU:0'):
    """Run the decoder on a batch of encoded parameter maps.

    Args:
        encoded: batch accepted by ``decoder.predict_on_batch`` (presumably
            one row per pixel, like the encoder output).
        device: TensorFlow device string to run the prediction on
            (defaults to the first GPU, as before).

    Returns:
        (recovered, elapsed): the decoder's raw output and the wall-clock
        seconds spent in ``predict_on_batch``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    with tf.device(device):
        recovered = decoder.predict_on_batch(encoded)
    elapsed = time.perf_counter() - start
    # NOTE: the output is returned unclipped; clip to [0, 1] at the call site
    # if it is to be displayed as an image.
    return recovered, elapsed

In [None]:
# Load the neutral and aged face textures, encode each into per-pixel
# parameter maps, preview the melanin (Cm) maps and cache the maps to CSV.
neutral_path = r"meta46/neutral/FaceColor_MAIN_LOD.PNG"
aged_path = r"meta46/old/FaceColor_MAIN_LOD.PNG"
IMAGE_SIZE = (256, 256)  # (width, height) passed to PIL's Image.resize
header = ["Cm", "Ch", "Bm", "Bh", "T"]  # names of encoder output columns 0-4


def load_rgb(path, size=IMAGE_SIZE):
    """Open ``path``, resize it to ``size`` and return ``(pil_image, rgb)``.

    ``rgb`` is the image as floats divided by 255 with every channel beyond the
    first three dropped, i.e. shape (height, width, 3) for an RGB/RGBA file.
    """
    pil_image = Image.open(path).resize(size)
    rgb = (np.asarray(pil_image) / 255.0)[:, :, :3]  # keep only R, G and B
    return pil_image, rgb


def split_maps(pred_maps):
    """Return columns 0-4 of ``pred_maps`` (Cm, Ch, Bm, Bh, T) as 1-D arrays."""
    return tuple(np.asarray(pred_maps[:, i]) for i in range(5))


neutral_pil, neutral_image = load_rgb(neutral_path)
aged_pil, aged_image = load_rgb(aged_path)

# Show both inputs as loaded.
for title, pil_image in (("Neutral", neutral_pil), ("Aged", aged_pil)):
    plt.imshow(pil_image)
    plt.title(title)
    plt.show()

# Both images were resized to IMAGE_SIZE, so they share one resolution.
# NumPy image arrays are (rows, cols) = (HEIGHT, WIDTH). WIDTH must be
# assigned before it is printed (it was previously printed before assignment).
HEIGHT, WIDTH = neutral_image.shape[:2]
print(f"width = {WIDTH}")

neutral_pred_maps, neutral_encode_time = encode(neutral_image)
neutral_Cm, neutral_Ch, neutral_Bm, neutral_Bh, neutral_T = split_maps(neutral_pred_maps)

aged_pred_maps, aged_encode_time = encode(aged_image)
aged_Cm, aged_Ch, aged_Bm, aged_Bh, aged_T = split_maps(aged_pred_maps)

# Melanin maps: reshape to (HEIGHT, WIDTH) to match the row-major pixel order.
for title, cm_map in (("Neutral Cm", neutral_Cm), ("Aged Cm", aged_Cm)):
    plt.imshow(cm_map.reshape((HEIGHT, WIDTH)))
    plt.title(title)
    plt.show()

# Comma-only header so pd.read_csv yields clean column names (no leading spaces).
csv_header = ",".join(header)
np.savetxt(f"aged_pred_maps{WIDTH}.csv", aged_pred_maps, delimiter=",", header=csv_header, comments="")
np.savetxt(f"neutral_pred_maps{WIDTH}.csv", neutral_pred_maps, delimiter=",", header=csv_header, comments="")


In [None]:
# Load cached encoder outputs at 2048x2048 resolution.
# NOTE(review): the encoding cell writes the *256* versions of these files; the
# 2048 CSVs must come from an earlier full-resolution run. Confirm they exist
# before a Restart & Run All.
PRED_MAP_RES = 2048
aged_pred_maps = pd.read_csv(f"aged_pred_maps{PRED_MAP_RES}.csv")
neutral_pred_maps = pd.read_csv(f"neutral_pred_maps{PRED_MAP_RES}.csv")
# Preview the first rows (and column headers) of each table.
for frame in (aged_pred_maps, neutral_pred_maps):
    print(frame.head())
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

# Per-pixel regression from the neutral maps to the aged maps:
# 5 inputs (Cm, Ch, Bm, Bh, T) -> two ReLU hidden layers -> 5 linear outputs.
SEED = 42            # used for TF's global seed and the train/test split
HIDDEN_UNITS = 65
EPOCHS = 65
BATCH_SIZE = 2048
TEST_FRACTION = 0.1

# Fix TF's global seed before building the model so weight initialisation and
# per-epoch shuffling repeat across runs (GPU kernels may still add some noise).
tf.random.set_seed(SEED)

model = Sequential([
    # First hidden layer; input_shape declares the 5 input features.
    Dense(HIDDEN_UNITS, activation='relu', input_shape=(5,)),
    Dense(HIDDEN_UNITS, activation='relu'),   # second hidden layer
    Dense(5, activation='linear'),            # outputs: the 5 aged maps
])
model.compile(optimizer='adam',
              loss='mean_squared_error',
              metrics=['mae'])
model.summary()

x = neutral_pred_maps  # inputs: neutral parameter maps
y = aged_pred_maps     # targets: aged parameter maps
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=TEST_FRACTION, random_state=SEED, shuffle=True)
with tf.device('/device:GPU:0'):
    history = model.fit(x_train, y_train,
                        validation_data=(x_test, y_test),
                        epochs=EPOCHS,
                        batch_size=BATCH_SIZE)


# Predict aged maps from the neutral ones, decode them back to RGB and show the
# result next to the neutral input image.
with tf.device('/device:GPU:0'):
    aged_pred_maps_prediction = model.predict(neutral_pred_maps)
aged_recovered, aged_decode_time = decode(aged_pred_maps_prediction)

# Each row is one pixel of a square texture: infer the side length instead of
# hard-coding 2048, and fail loudly if the rows do not form a square.
n_pixels = len(aged_pred_maps_prediction)
side = int(round(np.sqrt(n_pixels)))
assert side * side == n_pixels, f"{n_pixels} rows do not form a square image"

# NOTE(review): neutral_image is the resized image from the encoding cell, so
# its resolution may differ from the recovered image's.
plt.imshow(neutral_image)
plt.title("Original Image")
plt.show()
# imshow expects float RGB in [0, 1]; clip explicitly rather than relying on
# matplotlib's clip-with-warning.
plt.imshow(np.clip(aged_recovered.reshape((side, side, 3)), 0, 1))
plt.title("Aged Recovered Image")
plt.show()