In [1]:
import json
import tensorflow as tf

def load_data(json_file, image_dir="../data/images", image_ext=".png"):
    """Read an annotation JSON file into parallel lists of image paths and keypoints.

    The file must contain a JSON array of records, each with an ``'id'``
    (used to build the image filename) and a ``'kps'`` entry (that image's
    keypoints, returned unchanged).

    Args:
        json_file: Path to the JSON annotation file.
        image_dir: Directory holding the images; joined to the filename with
            ``/``. Defaults to the previously hard-coded ``"../data/images"``.
        image_ext: Filename extension including the dot. Defaults to ``".png"``.

    Returns:
        ``(image_paths, keypoints)``: two lists in file order, where
        ``image_paths[i]`` is ``f"{image_dir}/{id}{image_ext}"`` for record ``i``
        and ``keypoints[i]`` is that record's ``'kps'`` value.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    image_paths = [f"{image_dir}/{item['id']}{image_ext}" for item in data]
    # NOTE(review): the input pipeline reshapes these to (x, y) pairs and the
    # model predicts 28 values, so each entry presumably holds 28 numbers,
    # i.e. 14 (x, y) keypoints — confirm against the annotation format.
    keypoints = [item['kps'] for item in data]

    return image_paths, keypoints

# Parallel (image path, keypoints) lists for the training and validation splits.
train_image_paths, train_keypoints = load_data(json_file='../data/data_train.json')
val_image_paths, val_keypoints = load_data(json_file='../data/data_val.json')


In [2]:
def parse_function(filename, keypoints, target_size=(224, 224), source_size=(720, 1280)):
    """Load one example: decode and resize the image, and rescale its keypoints to match.

    Intended for ``tf.data.Dataset.map``; the extra arguments have defaults,
    so ``map`` can still call it with just ``(filename, keypoints)``.

    Args:
        filename: Scalar string tensor holding the image path.
        keypoints: Float tensor (or list) of interleaved ``x, y`` coordinates,
            in pixels of an image of ``source_size``.
        target_size: ``(height, width)`` to resize the image to (same order
            as ``tf.image.resize``).
        source_size: ``(height, width)`` of the original images that the
            keypoint coordinates refer to. The default ``(720, 1280)``
            reproduces the previously hard-coded ``224/1280, 224/720`` scaling.

    Returns:
        ``(image, keypoints)``: the resized RGB image with pixel values left in
        ``[0, 255]``, and the flattened keypoints scaled into the resized
        image's pixel space.
    """
    target_h, target_w = target_size
    source_h, source_w = source_size

    image = tf.io.read_file(filename)
    # expand_animations=False keeps the result 3-D even for GIFs (first frame
    # only); otherwise decode_image can return a 4-D tensor and the shape
    # check below fails.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    # Declare rank 3 with 3 channels so resize can be traced inside Dataset.map.
    image = tf.ensure_shape(image, [None, None, 3])
    # Pixels are intentionally NOT rescaled to [0, 1]: the model applies
    # mobilenet_v3.preprocess_input, which (per the Keras docs) expects [0, 255].
    image = tf.image.resize(image, (target_h, target_w))

    # Scale each (x, y) pair: x by the width ratio, y by the height ratio.
    keypoints = tf.reshape(keypoints, [-1, 2])
    scale = tf.constant([[target_w / source_w, target_h / source_h]], dtype=keypoints.dtype)
    keypoints = tf.reshape(keypoints * scale, [-1])

    return image, keypoints


def create_dataset(image_paths, keypoints):
    """Build an unbatched ``tf.data`` pipeline of ``(image, keypoints)`` examples.

    Args:
        image_paths: Sequence of image file paths (strings).
        keypoints: Sequence of per-image keypoint lists, parallel to
            ``image_paths``. Every entry must have the same length so they
            stack into one rectangular float32 tensor.

    Returns:
        A ``tf.data.Dataset`` that applies ``parse_function`` to each
        ``(path, keypoints)`` pair in parallel.

    Raises:
        ValueError: if ``image_paths`` and ``keypoints`` differ in length.
    """
    if len(image_paths) != len(keypoints):
        raise ValueError(
            f"image_paths ({len(image_paths)}) and keypoints ({len(keypoints)}) "
            "must have the same length"
        )
    # Explicit string dtype: an empty list would otherwise become a float
    # tensor, which tf.io.read_file in parse_function cannot accept.
    image_paths = tf.constant(image_paths, dtype=tf.string)
    keypoints = tf.constant(keypoints, dtype=tf.float32)
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, keypoints))
    return dataset.map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)

# Unbatched (image, keypoints) pipelines for each split.
train_dataset = create_dataset(image_paths=train_image_paths, keypoints=train_keypoints)
val_dataset = create_dataset(image_paths=val_image_paths, keypoints=val_keypoints)


In [3]:
def prepare_for_training(ds, batch_size=128, shuffle_buffer_size=10000, shuffle=True):
    """Optionally shuffle, then batch and prefetch a dataset for ``model.fit``.

    Args:
        ds: Unbatched ``tf.data.Dataset``.
        batch_size: Number of examples per batch.
        shuffle_buffer_size: Number of elements held in the shuffle buffer.
            NOTE: when ``ds`` yields already-decoded images (as the pipeline
            built with ``parse_function`` does), the buffer holds decoded
            images. A 224x224x3 float32 image is ~0.6 MB, so 10000 elements is
            ~6 GB of host memory.
        shuffle: Whether to shuffle at all. Evaluation data does not need it,
            and skipping it avoids filling the buffer.

    Returns:
        The (shuffled), batched, prefetched dataset.
    """
    if shuffle:
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    ds = ds.batch(batch_size)
    # Prepare upcoming batches while the current one is being consumed.
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

train_dataset = prepare_for_training(train_dataset)
# Validation: smaller batches, and no shuffling (order is irrelevant for evaluation).
val_dataset = prepare_for_training(val_dataset, batch_size=16, shuffle=False)

In [75]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras.applications.mobilenet_v3 import preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

def create_model(num_outputs=28, input_shape=(224, 224, 3)):
    """Build a keypoint regressor on a frozen, ImageNet-pretrained MobileNetV3Small.

    Args:
        num_outputs: Number of regression outputs. The default 28 is the
            original head size; it must equal the length of each flattened
            keypoint target vector.
        input_shape: Shape of one input image ``(height, width, channels)``.

    Returns:
        An uncompiled ``Model`` mapping images to ``num_outputs`` values.
    """
    ## Define input layer
    inputs = Input(shape=input_shape)
    # Keras documents mobilenet_v3.preprocess_input as a pass-through (the
    # model rescales internally); kept so the preprocessing step is explicit.
    preprocessed_input = preprocess_input(inputs)

    ## Load pre-trained MobileNetV3Small without the final (top) classification layers
    base_model = MobileNetV3Small(weights='imagenet', include_top=False, input_tensor=preprocessed_input)
    # Freeze the backbone so only the new regression head is trained.
    base_model.trainable = False
    output = base_model.output

    ## Flatten the feature maps into one vector per image
    output = Flatten()(output)

    # Linear output for coordinate regression. A ReLU here would clamp
    # predictions at 0 and give zero gradient to any unit whose pre-activation
    # is negative, so those outputs could get stuck at 0.
    final_output = Dense(num_outputs, activation=None)(output)

    ## Create our own network/model
    model = Model(inputs=inputs, outputs=final_output)

    return model


model = create_model()
# Mean absolute error (L1) loss on the keypoint coordinates.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss='mae')

# Stop once val_loss has not improved for 5 epochs, then restore the best weights.
earlystopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Save the model whenever val_loss reaches a new best.
# NOTE: only takes effect if passed in `callbacks=` to model.fit.
checkpoint_path = "model_checkpoint.model.keras"
checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True)


In [57]:
# Train with early stopping and best-model checkpointing (both monitor val_loss).
history = model.fit(train_dataset, validation_data=val_dataset, epochs=60, callbacks=[earlystopping, checkpoint])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [16]:
# Re-compiling installs a fresh Adam optimizer (its moment estimates restart)
# but keeps the current weights; training then continues from them.
# Loss is mean absolute error (L1).
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss='mae')
history = model.fit(train_dataset, validation_data=val_dataset, epochs=60, callbacks=[earlystopping, checkpoint])

Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60


In [35]:
model.save('models/keypoints_1_6950