# Imports

In [None]:
import numpy as np
import sys
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam
import sys

# Specify path and Load Data

In [None]:
# Replace these with your own paths.
# Directory holding the .npy outputs written by "inference.py"; the cells
# below also write binded_properties.npy / binded_slots.npy into it.
path_name = "final_evaluation/texture_final_2"
# Directory that contains "ground_truth_properties.npy".
path_name_properties = "evaluation"

In [None]:
# Load every array that inference.py wrote into `path_name`.
all_images = np.load(f"{path_name}/all_images.npy")
all_masks = np.load(f"{path_name}/all_masks.npy")
all_recons = np.load(f"{path_name}/all_recons.npy")
all_predictions = np.load(f"{path_name}/all_predictions.npy")
all_masks_predicted = np.load(f"{path_name}/all_masks_predicted.npy")
all_slots = np.load(f"{path_name}/all_slots.npy")
all_positions = np.load(f"{path_name}/all_positions.npy")
all_sizes = np.load(f"{path_name}/all_sizes.npy")

# Move axis 2 behind axes 3 and 4 so that it ends up (together with axis 5)
# at the end, then collapse to five dimensions. The later cells index the
# first two axes with range(156) x range(32) and take argmax over the last.
# NOTE(review): presumably 156 batches x 32 samples of 128x128 masks over
# 11 slots - confirm against inference.py before changing these numbers.
all_masks_predicted = np.transpose(all_masks_predicted, [0, 1, 3, 4, 2, 5])
all_masks_predicted = all_masks_predicted.reshape(156, 32, 128, 128, 11)

In [None]:
# Ground-truth object properties live in their own directory, separate from
# the inference outputs.
all_properties = np.load(f"{path_name_properties}/ground_truth_properties.npy")


# Bind Slots to Ground-Truth Objects via Masks

In [None]:
def match_slots(mask_pred, mask_ref, min_overlap=0.35):
    """Pair every ground-truth object with the predicted slot covering it best.

    Args:
        mask_pred: Per-pixel slot scores whose last axis indexes the slots;
            each pixel is assigned to its ``argmax`` slot.
        mask_ref: Ground-truth segmentation holding one class id per pixel,
            where id 0 is the background. Any shape is accepted as long as
            it has one entry per pixel of ``mask_pred`` (the former
            hard-coded 128x128 reshape is no longer required).
        min_overlap: Minimum fraction of a ground-truth region that its best
            slot has to cover. Defaults to the previously hard-coded 0.35.

    Returns:
        A list of ``(object_index, slot_index)`` tuples, one per
        non-background class present in ``mask_ref``, where ``object_index``
        is the class id minus one. An empty list is returned as soon as any
        region (the background region included) falls below
        ``min_overlap``, i.e. the whole image is discarded.
    """
    # Flatten both masks to 1-D integer label lists of equal length.
    ref_labels = np.asarray(mask_ref).astype(int).ravel()
    slot_labels = np.asarray(mask_pred).argmax(axis=-1).ravel()

    pairs = []  # saved (ground-truth object, predicted slot) pairs
    for class_id in np.unique(ref_labels):
        slot_index, overlap = find_largest_overlap(slot_labels, ref_labels, class_id)
        if overlap < min_overlap:
            # The region is split over several slots: reject the image.
            return []
        if class_id != 0:
            # Class 0 is the background; objects are numbered from class 1.
            pairs.append((class_id - 1, slot_index))
    return pairs


def find_largest_overlap(mask_pred, mask_ref, j):
    """Find the predicted slot that covers most of ground-truth class ``j``.

    Args:
        mask_pred: Array of non-negative integer slot ids, one per pixel.
        mask_ref: Ground-truth class ids, one per pixel, aligned with
            ``mask_pred``.
        j: The ground-truth class id to look up.

    Returns:
        A tuple ``(slot_index, overlap_fraction)``: the slot id that occurs
        most often inside the region where ``mask_ref == j``, and the
        fraction of that region's pixels which belong to this slot.

    Raises:
        ValueError: If no pixel of ``mask_ref`` equals ``j``.
    """
    in_region = np.asarray(mask_ref) == j  # True where the ground-truth object is
    region_slots = np.asarray(mask_pred)[in_region]  # slot ids inside the region only
    if region_slots.size == 0:
        raise ValueError(f"class {j} does not occur in mask_ref")

    counts = np.bincount(region_slots)  # pixels per slot inside the region
    best_slot = np.argmax(counts)
    # Share of the ground-truth region claimed by the winning slot.
    overlap_percent = counts[best_slot] / region_slots.size
    return best_slot, overlap_percent

In [None]:
# Collect one (slot feature vector, ground-truth property vector) pair for
# every object that match_slots could bind to a predicted slot.
binded_slots = []
binded_properties = []
# 156 x 32 matches the leading dimensions of all_masks_predicted.
for i, j in np.ndindex(156, 32):
    pairs = match_slots(all_masks_predicted[i][j], all_masks[i][j])
    for refs, preds in pairs:
        binded_properties.append(all_properties[i][j][refs])
        # Feature vector: slot representation followed by the slot's
        # position and size parameters.
        concat_slots = np.concatenate(
            (all_slots[i][j][preds], all_positions[i][j][preds], all_sizes[i][j][preds]),
            axis=-1,
        )
        binded_slots.append(concat_slots)

In [None]:
# Save the bound (slot, property) pairs - these can further be used for
# computing disentanglement metrics, see
# https://github.com/google-research/disentanglement_lib
# Write the matched lists built above, not the raw all_* arrays.
np.save(path_name + "/binded_properties.npy", np.array(binded_properties))
np.save(path_name + "/binded_slots.npy", np.array(binded_slots))

# Predict Properties

In [None]:
# Inputs to model: bound (slot, property) pairs, one row per matched object.
# Use the matched lists, not the raw all_slots / all_properties arrays: the
# raw arrays are not aligned object-by-object.
X = np.array(binded_slots)
y = np.array(binded_properties)

# Validate the shape of the loaded data
print(f"Shape of X: {X.shape}")  
print(f"Shape of y: {y.shape}")
if X.ndim != 2 or y.ndim != 2 or len(X) == 0:
    raise ValueError("no bound (slot, property) pairs available to train on")

# Split the one-hot encoded labels into size, material and shape.
# NOTE(review): columns 0-2 are presumably the position, which is not
# predicted here - confirm against the ground-truth property layout.
y_size = y[:, 3:6]        # 3 classes
y_material = y[:, 6:66]   # 60 classes
y_shape = y[:, 66:70]     # 4 classes

# Split the data into training and testing sets. One call keeps all label
# arrays aligned with X and yields the same split as separate calls that
# share the same random_state.
(X_train, X_test,
 y_train_shape, y_test_shape,
 y_train_size, y_test_size,
 y_train_material, y_test_material) = train_test_split(
    X, y_shape, y_size, y_material, test_size=0.2, random_state=42)

# Neural network architecture - here: only a linear layer per property
# (a linear probe on top of the slot features).
input_layer = Input(shape=(X.shape[1],))
#x = Dense(128, activation='relu')(input_layer)
#x = Dense(128, activation='relu')(x)

# Output layers: one softmax head per property
output_shape = Dense(4, activation='softmax', name='shape')(input_layer)
output_size = Dense(3, activation='softmax', name='size')(input_layer)
output_material = Dense(60, activation='softmax', name='material')(input_layer)

# Create the model
model = Model(inputs=input_layer, outputs=[output_shape, output_size, output_material])

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss={'shape': 'categorical_crossentropy',
                    'size': 'categorical_crossentropy',
                    'material': 'categorical_crossentropy'},
              metrics={'shape': 'accuracy', 'size': 'accuracy', 'material': 'accuracy'})

# Train the model (a further 20% of the training rows is held out for validation)
history = model.fit(X_train, [y_train_shape, y_train_size, y_train_material],
                    epochs=500,
                    batch_size=32,
                    validation_split=0.2)

# Evaluate the model
# NOTE(review): the indexing below assumes evaluate() returns
# [total loss, 3 per-output losses, 3 per-output accuracies] in output
# order - verify for the installed Keras version.
eval_results = model.evaluate(X_test, [y_test_shape, y_test_size, y_test_material])
print(f"Test Results - Shape Loss: {eval_results[1]}, Size Loss: {eval_results[2]}, Material Loss: {eval_results[3]}")
print(f"Test Results - Shape Accuracy: {eval_results[4]}, Size Accuracy: {eval_results[5]}, Material Accuracy: {eval_results[6]}")

# Calculate the overall accuracy: turn softmax outputs and one-hot labels
# into class indices.
y_pred = model.predict(X_test)
y_pred_shape = np.argmax(y_pred[0], axis=1)
y_pred_size = np.argmax(y_pred[1], axis=1)
y_pred_material = np.argmax(y_pred[2], axis=1)

y_true_shape = np.argmax(y_test_shape, axis=1)
y_true_size = np.argmax(y_test_size, axis=1)
y_true_material = np.argmax(y_test_material, axis=1)

# A sample counts as correct only if all three properties are predicted correctly
correct_predictions = np.all([y_pred_shape == y_true_shape,
                              y_pred_size == y_true_size,
                              y_pred_material == y_true_material], axis=0)

# Overall accuracy
overall_accuracy = np.mean(correct_predictions)
print(f"Overall Accuracy (all properties correct): {overall_accuracy}")