In [None]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sys import stdout

from network import HandModel
from process_data import get_data_set, get_num_batches


In [None]:
# Create the network and run one forward pass on a random 1x368x368x3 input.
# Presumably this is done so the model's variables exist before any weights
# are loaded or copied into it — TODO confirm against HandModel.
model = HandModel()
# Not referenced by any other visible cell.
completedEpochs = 0
dummy_input = tf.random.normal([1, 368, 368, 3])
_ = model(dummy_input)

# Optional warm start from a previously saved file (currently disabled).
# model.load_weights('vgg19_pretrain_synth_dataset_epoch10_rate_5e-5.h5')

# Adam optimizer with a fixed learning rate of 5e-5.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

# Running mean of the loss values passed to it; reset at the start of each epoch
# by the training cell.
train_loss = tf.keras.metrics.Mean(name='train_loss')

# get_data_set() returns two values; only train_data is used by the visible cells.
train_data, val_data = get_data_set()
# Used as the denominator of the per-batch progress line in the training cell.
num_batches = get_num_batches()

In [None]:
# Take a single (image, label) element from the training dataset, convert both
# to NumPy, and print channel 0 of the first entry of the label.
# Assumes elements are batched so that lbl[0] is one sample's 3-D map — TODO confirm.
for image, label in train_data.take(1):
    img_array = image.numpy()
    lbl = label.numpy()
    print(lbl[0][:, :, 0])

In [None]:
# Print the layer-by-layer summary of the model's `feature_extraction` component.
model.feature_extraction.summary()

In [None]:
# Reference VGG19 without its classifier head, sized for 368x368 RGB inputs and
# ending in global average pooling. The `weights` argument is left at its default.
copy_model = tf.keras.applications.VGG19(
    include_top=False,
    input_shape=(368, 368, 3),
    pooling='avg',
)
copy_model.summary()

In [None]:
# Copy weights from VGG19 layers 1..15 into feature_extraction layers 0..14.
# The index is shifted by one between the two models; layer names and the
# destination weights before/after the copy are printed for inspection.
for i in range(1, 16):
    src_layer = copy_model.layers[i]
    dst_layer = model.feature_extraction.layers[i - 1]

    print(src_layer.name)
    print(dst_layer.name)
    print(dst_layer.get_weights())
    dst_layer.set_weights(src_layer.get_weights())
    print(dst_layer.get_weights())

In [None]:
def compute_loss(y_pred, intermediate_pred, y):
    """Mean-squared-error loss for the final and intermediate predictions.

    Args:
        y_pred: Final prediction, compared against ``y``.
        intermediate_pred: Indexable collection of intermediate predictions;
            each entry is compared against the same ``y``.
        y: Target passed as the reference value to the MSE loss.

    Returns:
        A ``(final_loss, total_loss)`` pair, where ``final_loss`` is the MSE of
        ``y_pred`` alone and ``total_loss`` is ``final_loss`` plus the MSE of
        every entry of ``intermediate_pred``.
    """
    mse = tf.keras.losses.MeanSquaredError()
    final_loss = mse(y, y_pred)

    # Accumulate in the same order as the entries appear, starting from the
    # final-stage loss.
    total_loss = final_loss
    for stage_idx in range(len(intermediate_pred)):
        stage_loss = mse(y, intermediate_pred[stage_idx])
        total_loss = total_loss + stage_loss

    return final_loss, total_loss

In [None]:
@tf.function
def train_step(x, y):
    """Run one optimisation step on a single batch.

    Uses the notebook-level ``model``, ``optimizer`` and ``compute_loss``.

    Args:
        x: Input batch passed to ``model``.
        y: Target batch passed to ``compute_loss``.

    Returns:
        The ``(final_loss, total_loss)`` pair produced by ``compute_loss``;
        gradients are taken with respect to ``total_loss``.
    """
    # NOTE(review): the model is called without an explicit `training` flag —
    # confirm HandModel has no layers that behave differently while training.
    with tf.GradientTape() as tape:
        y_pred, y_intermediate = model(x)
        fin_loss, total_loss = compute_loss(y_pred, y_intermediate, y)

    trainable = model.trainable_variables
    grads = tape.gradient(total_loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    return fin_loss, total_loss

In [None]:
# Training configuration for this run.
NUM_EPOCHS = 20
EARLY_STOP_LOSS = 0.0001
INTERRUPT_SAVE_PATH = "vgg19_train_full_dataset.h5"

try:
    for epoch in range(NUM_EPOCHS):
        # Start each epoch with an empty running mean.
        train_loss.reset_state()

        for batch_num, (x_batch, y_batch) in enumerate(train_data):
            fin_loss, total_loss = train_step(x_batch, y_batch)
            train_loss(total_loss)
            # Carriage return keeps the progress counter on a single line.
            print(f"\rProgress: {batch_num + 1}/{num_batches}", end='', flush=True)

        # Terminate the progress line before the epoch report.
        print()

        curr_loss = train_loss.result()
        print(f"Epoch {epoch + 1} ----- Loss: {curr_loss}")

        # Stop once the mean total loss for the epoch is small enough.
        if curr_loss < EARLY_STOP_LOSS:
            print("Finished training, ended early")
            break
except KeyboardInterrupt:
    # Interrupting the kernel saves the model before the cell exits.
    print("Stopping Training, saving weights")
    model.save(INTERRUPT_SAVE_PATH)


In [None]:
# Save the current model to an HDF5 file.
# NOTE(review): the filename mentions "pretrain_synth_dataset_epoch10" — confirm
# it still describes the run that produced this model.
model.save("vgg19_pretrain_synth_dataset_epoch10_rate_5e-5.h5")

In [None]:
from visualize_data import visualize_combined_map
from process_data import load_img_and_pos_with_num, create_heat_map

In [None]:
# Three panels for one randomly chosen sample: the input image with the loaded
# hand positions, the heat maps built from those positions, and the model's
# predicted heat maps.
fig, axes = plt.subplots(1, 3, figsize=(15,15))

index = np.random.randint(3000)

test_x, test_handpos, _ = load_img_and_pos_with_num(index)
test_y = create_heat_map(test_handpos)

# Add a leading batch axis before calling the model.
test_x = np.expand_dims(test_x, axis = 0)

y_pred, _ = model(test_x)

# Transposed view of the positions: rows 0 and 1 are the scatter x and y values.
handpos_rows = test_handpos.T
axes[0].imshow(test_x[0])
axes[0].scatter(handpos_rows[0], handpos_rows[1], color='green', s=50)

visualize_combined_map(test_x[0], test_y, axes[1])
# Every predicted channel except the last one is passed to the overlay.
visualize_combined_map(test_x[0], y_pred[0, :, :, :-1], axes[2])

plt.tight_layout()
plt.show()

In [None]:
def _build_and_load(weights_path, sample_input):
    """Create a HandModel, call it once on sample_input, then load weights from weights_path."""
    net = HandModel()
    net(sample_input)
    net.load_weights(weights_path)
    return net


# Two saved checkpoints loaded into separate model instances for comparison.
dummy_input = tf.random.normal((1, 368, 368, 3))
model1 = _build_and_load("good progress models/vgg19_train_full_dataset2.h5", dummy_input)
model2 = _build_and_load("good progress models/vgg19_train_full_dataset4.h5", dummy_input)

In [None]:
# 4x11 grid: panels alternate between model1 and model2, so each consecutive
# pair of panels shows the same output channel from the two models.
fig, axes = plt.subplots(4, 11, figsize=(15,15))

index = np.random.randint(3000)

test_x, test_handpos, _ = load_img_and_pos_with_num(index)
test_y = create_heat_map(test_handpos)

# Add a leading batch axis before calling the models.
test_x = np.expand_dims(test_x, axis = 0)

y_pred1, _ = model1(test_x)
y_pred2, _ = model2(test_x)

for i, ax in enumerate(axes.flat):
    # Even panels come from model1, odd panels from model2.
    source = y_pred1 if i % 2 == 0 else y_pred2
    ax.imshow(source[0, :, :, i // 2], cmap='gray')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.image as img

In [None]:
# Run the model on a photo read from disk and show the input next to the
# overlay of its predicted maps.
fig, axes = plt.subplots(1, 2, figsize=(15,15))

img_path = "test_photo.jpeg"

# Read the image and add a leading batch axis.
# NOTE(review): the image is passed at whatever size and dtype imread returns —
# confirm that matches what the model expects (it was built on 368x368x3 input).
test_x = np.expand_dims(img.imread(img_path), axis = 0)

y_pred, _ = model(test_x)

input_ax, overlay_ax = axes
input_ax.imshow(test_x[0])

# Every predicted channel except the last one is passed to the overlay.
visualize_combined_map(test_x[0], y_pred[0, :, :, :-1], overlay_ax)

plt.tight_layout()
plt.show()

In [None]:
# 3x7 grid: one grayscale panel per channel 0..20 of the most recent `y_pred`.
fig, axes = plt.subplots(3, 7, figsize=(15,15))

for i, ax in enumerate(axes.flat):
    channel_map = y_pred[0, :, :, i]
    ax.imshow(channel_map, cmap='gray')
    ax.axis('off')

plt.tight_layout()
plt.show()