In [1]:
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import Input, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img

In [2]:
# Dataset snapshot to load; every path below is derived from this tag so switching
# snapshots is a one-line change.
DATASET_TAG = '01272026'
DATASET_ROOT = f'./gazebo_dataset_{DATASET_TAG}/'

CSV_PATH = f'{DATASET_ROOT}labels_{DATASET_TAG}.csv'
IMG_DIR = f'{DATASET_ROOT}images/'
EDGE_DIR = f'{DATASET_ROOT}edge_detection_results_{DATASET_TAG}/'

In [3]:
df = pd.read_csv(CSV_PATH)

# Quick sanity summary of the label file: row count, a preview, and class balance.
for report in (
    f"Total samples: {len(df)}",
    f"\nFirst few rows:\n{df.head()}",
    f"\nDirection counts:\n{df['direction'].value_counts()}",
):
    print(report)

Total samples: 11876

First few rows:
                               current_image  \
0  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
1  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
2  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
3  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
4  Loc0-102ed7ec84c44be3b4066caccff2011e.png   

                           destination_image direction  
0  Loc0-5107f16132e14cbbae95826a39aa0643.png     right  
1  Loc0-ad13f58f6f9549b48af9a145c8398fde.png     right  
2  Loc9-fb497088983647238d06767871bef8f7.png  backward  
3  Loc9-057e8f35ac974ad487b4ca23310cb397.png  backward  
4  Loc9-15467f62ac4f489194c3adb5d3fea27e.png  backward  

Direction counts:
direction
forward     3030
backward    3018
left        2925
right       2903
Name: count, dtype: int64


In [4]:
# Integer class ids for the four movement directions; these are the targets for the
# 4-way softmax classifier.
direction_map = {'forward': 0, 'backward': 1, 'left': 2, 'right': 3}

# Series.map() silently turns any value missing from direction_map (typos, blanks, NaN)
# into NaN, which would yield float labels; fail loudly instead.
unknown_directions = set(df['direction'].unique()) - set(direction_map)
if unknown_directions:
    raise ValueError(
        f"Unmapped direction values in {CSV_PATH}: {sorted(unknown_directions, key=str)}"
    )
df['direction_label'] = df['direction'].map(direction_map)

# Number of federated clients and the (floor) size of each contiguous shard.
num_clients = 2
shard_size = len(df) // num_clients
print(f"first few rows after direction mapping:\n{df.head()}")

first few rows after direction mapping:
                               current_image  \
0  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
1  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
2  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
3  Loc0-102ed7ec84c44be3b4066caccff2011e.png   
4  Loc0-102ed7ec84c44be3b4066caccff2011e.png   

                           destination_image direction  direction_label  
0  Loc0-5107f16132e14cbbae95826a39aa0643.png     right                3  
1  Loc0-ad13f58f6f9549b48af9a145c8398fde.png     right                3  
2  Loc9-fb497088983647238d06767871bef8f7.png  backward                1  
3  Loc9-057e8f35ac974ad487b4ca23310cb397.png  backward                1  
4  Loc9-15467f62ac4f489194c3adb5d3fea27e.png  backward                1  


In [5]:
# Shards are contiguous slices of the CSV in its original row order (no global shuffle),
# so each client sees a different part of the file -- presumably to make client data
# non-IID; confirm this is intended.
# NOTE(review): the same current_image appears in several rows, and the train/val split
# below is per row, so one image can end up in both a client's train and val sets.
client_datasets = []

for i in range(num_clients):
    start = i * shard_size
    # The last client also takes the len(df) % num_clients leftover rows so no sample is
    # silently dropped when the row count is not divisible by num_clients.
    end = len(df) if i == num_clients - 1 else start + shard_size

    shard = df.iloc[start:end]

    # Shuffle only within this client's shard (a different seed per client).
    shard_shuffled = shard.sample(frac=1, random_state=42 + i).reset_index(drop=True)

    # 80/20 train/validation split of this client's shard.
    train_df, val_df = train_test_split(shard_shuffled, test_size=0.2, random_state=42)
    client_datasets.append((train_df, val_df))

    print(f"\nClient: {i+1}")
    print(f"Training samples: {len(train_df)}")
    print(f"Validation samples: {len(val_df)}")


Client: 1
Training samples: 4750
Validation samples: 1188

Client: 2
Training samples: 4750
Validation samples: 1188


In [6]:
# Path -> preprocessed image array. Shared across all create_dataset calls so an image
# referenced by many rows is decoded only once. Never evicted.
image_cache = {}

def load_image_cached(img_path):
    """Return the image at `img_path` resized to 128x128 and divided by 255.0.

    The result is memoised in `image_cache`, so repeated calls with the same path
    return the same array object (callers should not mutate it in place).
    """
    cached = image_cache.get(img_path)
    if cached is None:
        pixels = img_to_array(load_img(img_path, target_size=(128, 128)))
        cached = pixels / 255.0
        image_cache[img_path] = cached
    return cached

In [7]:
def create_dataset(dataframe, image_dir, batch_size=32, suffix=None):
    """Load the (current, destination) image pairs and direction labels for `dataframe`.

    Args:
        dataframe: rows with 'current_image', 'destination_image' and
            'direction_label' columns.
        image_dir: directory holding the image files.
        batch_size: unused; kept so existing callers remain valid.
        suffix: string inserted between a file's stem and '.png'. Defaults to
            '_hed' when `image_dir` is EDGE_DIR and '' otherwise.

    Returns:
        (current_images, dest_images, labels) as numpy arrays, in `dataframe` row order.
    """
    if suffix is None:
        # Edge maps are stored as '<stem>_hed.png' inside EDGE_DIR.
        suffix = '_hed' if image_dir == EDGE_DIR else ''

    def _image_path(file_name):
        # splitext (rather than slicing off 4 chars) handles any extension length.
        stem, _ext = os.path.splitext(file_name)
        return os.path.join(image_dir, stem + suffix + '.png')

    current_images = np.array(
        [load_image_cached(_image_path(name)) for name in dataframe['current_image']]
    )
    dest_images = np.array(
        [load_image_cached(_image_path(name)) for name in dataframe['destination_image']]
    )
    labels = dataframe['direction_label'].to_numpy(copy=True)

    return current_images, dest_images, labels

In [8]:
client_train_data = []
client_val_data = []

for i, (train_df, val_df) in enumerate(client_datasets):
    print(f"\nClient: {i+1}")

    print("Loading training data...")
    X_train_current, X_train_dest, y_train = create_dataset(train_df, IMG_DIR)
    X_train_current_hed, X_train_dest_hed, y_train_hed = create_dataset(train_df, EDGE_DIR)

    print("Loading validation data...")
    X_val_current, X_val_dest, y_val = create_dataset(val_df, IMG_DIR)
    X_val_current_hed, X_val_dest_hed, y_val_hed = create_dataset(val_df, EDGE_DIR)

    # RGB and HED arrays come from the same rows in the same order. Only the RGB labels
    # are kept, so verify the two views actually line up before discarding the HED ones.
    assert np.array_equal(y_train, y_train_hed), f"client {i+1}: RGB/HED train labels differ"
    assert np.array_equal(y_val, y_val_hed), f"client {i+1}: RGB/HED val labels differ"
    assert X_train_current.shape == X_train_current_hed.shape, f"client {i+1}: RGB/HED train shapes differ"
    assert X_val_current.shape == X_val_current_hed.shape, f"client {i+1}: RGB/HED val shapes differ"

    # Tuple layout: (current_rgb, dest_rgb, current_hed, dest_hed, labels).
    client_train_data.append((X_train_current, X_train_dest, X_train_current_hed, X_train_dest_hed, y_train))
    client_val_data.append((X_val_current, X_val_dest, X_val_current_hed, X_val_dest_hed, y_val))

    print(f"\nTraining data shapes:")
    print(f"Current images: {X_train_current.shape}")
    print(f"Destination images: {X_train_dest.shape}")
    print(f"Labels: {y_train.shape}")


Client: 1
Loading training data...
Loading validation data...

Training data shapes:
Current images: (4750, 128, 128, 3)
Destination images: (4750, 128, 128, 3)
Labels: (4750,)

Client: 2
Loading training data...
Loading validation data...

Training data shapes:
Current images: (4750, 128, 128, 3)
Destination images: (4750, 128, 128, 3)
Labels: (4750,)


In [26]:
# for each client
for i, (Xc, Xd, Xch, Xdh, y) in enumerate(client_train_data):
    print("Client", i+1)
    print(np.unique(y, return_counts=True))
    print("n_train:", len(y))
for i, (Xc, Xd, Xch, Xdh, y) in enumerate(client_val_data):
    print("Client val", i+1, np.unique(y, return_counts=True))


Client 1
(array([0, 1, 2, 3]), array([1236, 1217, 1168, 1129]))
n_train: 4750
Client 2
(array([0, 1, 2, 3]), array([1167, 1193, 1147, 1243]))
n_train: 4750
Client val 1 (array([0, 1, 2, 3]), array([311, 306, 312, 259]))
Client val 2 (array([0, 1, 2, 3]), array([316, 302, 298, 272]))


In [34]:
def rgb_encoder(input_shape):
    """Build the RGB feature extractor shared by the current and destination branches.

    Two conv(3x3, relu) + max-pool stages (32 then 96 filters), a final 128-filter conv,
    LayerNormalization and global average pooling. The resulting sub-model is named
    "RGB_Encoder" so it can be looked up with `get_layer`.
    """
    image_in = Input(shape=input_shape)

    features = image_in
    for filters in (32, 96):
        features = layers.Conv2D(filters, 3, activation='relu', padding='same')(features)
        features = layers.MaxPooling2D()(features)

    features = layers.Conv2D(128, 3, activation='relu', padding='same')(features)
    features = layers.LayerNormalization()(features)
    embedding = layers.GlobalAveragePooling2D()(features)

    return models.Model(image_in, embedding, name="RGB_Encoder")


In [35]:
def hed_encoder(input_shape):
    """Build the (smaller) feature extractor for HED edge-map images.

    One conv(16, 3x3, relu) + max-pool stage, a conv(32, 3x3, relu), LayerNormalization
    and global average pooling. The sub-model is named "HED_Encoder" for `get_layer`.
    """
    def conv(filters, tensor):
        return layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    edges_in = Input(shape=input_shape)

    hidden = conv(16, edges_in)
    hidden = layers.MaxPooling2D()(hidden)
    hidden = conv(32, hidden)
    hidden = layers.LayerNormalization()(hidden)
    embedding = layers.GlobalAveragePooling2D()(hidden)

    return models.Model(edges_in, embedding, name="HED_Encoder")


In [36]:
def build_siamese_model(H=128, W=128):
    """Build and compile the 4-way direction classifier over (current, destination) views.

    The same RGB and HED encoders (shared weights) embed both the current and the
    destination view. The head sees [current, destination, dest - current,
    |dest - current|] and predicts one of 4 classes with a softmax.

    Model inputs, in order: [current_rgb, current_hed, dest_rgb, dest_hed], each (H, W, 3).
    Compiled with Adam and sparse categorical cross-entropy (integer labels).
    """
    shape = (H, W, 3)
    current_rgb, current_hed, dest_rgb, dest_hed = (Input(shape=shape) for _ in range(4))

    rgb_enc = rgb_encoder(shape)
    hed_enc = hed_encoder(shape)

    def embed(rgb_view, hed_view):
        # Joint embedding of one location: RGB features followed by HED features.
        return layers.Concatenate()([rgb_enc(rgb_view), hed_enc(hed_view)])

    curr_feat = embed(current_rgb, current_hed)
    dest_feat = embed(dest_rgb, dest_hed)

    # Signed difference keeps direction information; the absolute value exposes magnitude.
    delta = layers.Subtract()([dest_feat, curr_feat])
    abs_delta = layers.Lambda(lambda t: tf.abs(t))(delta)
    merged = layers.Concatenate()([curr_feat, dest_feat, delta, abs_delta])

    hidden = layers.Dense(256, activation='relu')(merged)
    hidden = layers.Dropout(0.2)(hidden)
    hidden = layers.Dense(128, activation='relu')(hidden)
    probs = layers.Dense(4, activation='softmax')(hidden)

    model = models.Model(inputs=[current_rgb, current_hed, dest_rgb, dest_hed], outputs=probs)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

In [37]:
# Server-side model: its weights are broadcast to every client at the start of each round.
global_model = build_siamese_model(H=128, W=128)

In [38]:
# One locally-trained model per client, with the same architecture as global_model.
# build_siamese_model() already compiles with this exact optimizer/loss/metrics, so the
# former second compile() was redundant and has been dropped.
client_models = [build_siamese_model() for _ in range(num_clients)]

In [39]:
def federated_average(models, client_sizes=None):
    """Average the weights of several identically-structured models (FedAvg aggregation).

    Args:
        models: non-empty iterable of models exposing ``get_weights()``. Every model
            must return the same number of arrays, with matching shapes.
        client_sizes: optional per-model sample counts. When given, model ``k`` gets
            weight ``client_sizes[k] / sum(client_sizes)``. When ``None`` (default), all
            models count equally, i.e. a plain element-wise mean.

    Returns:
        List of numpy arrays, one per weight tensor, suitable for ``set_weights()``.
        Floating-point arrays keep the dtype that ``get_weights()`` returned.

    Raises:
        ValueError: if no models are given, the models expose different numbers of
            weight arrays, or ``client_sizes`` is not one non-negative count per model
            with a positive total.
    """
    # The parameter name `models` shadows the keras `models` module, but only inside
    # this function.
    model_list = list(models)
    if not model_list:
        raise ValueError("federated_average() needs at least one model")

    per_model = [m.get_weights() for m in model_list]
    # zip() would silently truncate to the shortest list; refuse mismatched models instead.
    if len({len(w) for w in per_model}) != 1:
        raise ValueError("models expose different numbers of weight arrays")

    coeffs = None
    if client_sizes is not None:
        sizes = np.asarray(client_sizes, dtype=np.float64)
        if sizes.shape != (len(model_list),) or np.any(sizes < 0) or sizes.sum() <= 0:
            raise ValueError(
                "client_sizes must be one non-negative count per model with a positive total"
            )
        coeffs = sizes / sizes.sum()

    new_weights = []
    for layer_weights in zip(*per_model):
        stacked = np.stack(layer_weights, axis=0)
        if coeffs is None:
            averaged = np.mean(stacked, axis=0)
        else:
            averaged = np.average(stacked, axis=0, weights=coeffs)
        if np.issubdtype(stacked.dtype, np.floating):
            # np.average promotes float32 to float64; keep the model's own dtype.
            averaged = averaged.astype(stacked.dtype, copy=False)
        new_weights.append(averaged)

    return new_weights


In [41]:
num_rounds = 25
local_epochs = 1
# Freeze rounds are 0-based round indices: an encoder is frozen at the start of round
# index N (printed as "Federated Round N+1") and stays frozen for the remaining rounds.
hed_freeze_round = 16
rgb_freeze_round = 20

# (round index, encoder sub-model name) pairs, checked in this order every round.
freeze_schedule = [
    (rgb_freeze_round, "RGB_Encoder"),
    (hed_freeze_round, "HED_Encoder"),
]


def _freeze_encoder(model_list, encoder_name):
    """Mark the named encoder sub-model non-trainable in each model, then recompile.

    Keras only applies a change to `trainable` after `compile()` is called again.
    """
    for model in model_list:
        model.get_layer(encoder_name).trainable = False
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )


for round_idx in range(num_rounds):
    print(f"\n===== Federated Round {round_idx+1} =====")

    for freeze_round, encoder_name in freeze_schedule:
        if round_idx == freeze_round:
            print(f"Freezing {encoder_name} and recompiling models")
            # Clients and the global model are frozen together so every model keeps the
            # same trainable configuration.
            _freeze_encoder(client_models + [global_model], encoder_name)

    # Broadcast: every client starts the round from the current global weights.
    global_weights = global_model.get_weights()
    for m in client_models:
        m.set_weights(global_weights)

    # Local training
    for i in range(num_clients):
        print(f"Training model for client {i+1}")
        X_train_current, X_train_dest, X_train_current_hed, X_train_dest_hed, y_train = client_train_data[i]
        X_val_current, X_val_dest, X_val_current_hed, X_val_dest_hed, y_val = client_val_data[i]
        # Input order matches the model definition: [current_rgb, current_hed, dest_rgb, dest_hed].
        client_models[i].fit(
            [X_train_current, X_train_current_hed, X_train_dest, X_train_dest_hed],
            y_train,
            validation_data=(
                [X_val_current, X_val_current_hed, X_val_dest, X_val_dest_hed],
                y_val,
            ),
            epochs=local_epochs,
            batch_size=16,
        )

    # Aggregate the client weights into the new global model.
    global_model.set_weights(federated_average(client_models))
    


===== Federated Round 1 =====
Training model for client 1
[1m297/297[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m92s[0m 292ms/step - accuracy: 0.2901 - loss: 1.3975 - val_accuracy: 0.3375 - val_loss: 1.3559
Training model for client 2
[1m297/297[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 289ms/step - accuracy: 0.2514 - loss: 1.4044 - val_accuracy: 0.2753 - val_loss: 1.3887

===== Federated Round 2 =====
Training model for client 1
[1m297/297[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 290ms/step - accuracy: 0.3105 - loss: 1.3573 - val_accuracy: 0.3636 - val_loss: 1.3189
Training model for client 2
[1m297/297[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 285ms/step - accuracy: 0.2802 - loss: 1.3789 - val_accuracy: 0.2988 - val_loss: 1.3538

===== Federated Round 3 =====
Training model for client 1
[1m297/297[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 290ms/step - accuracy: 0.3512 - loss: 1.3235 - val_accuracy: 0.4015 - val_loss: 1.26

In [42]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report

# Collect the global model's predictions on every client's validation split.
y_true_all, y_pred_all = [], []

for client_no in range(1, num_clients + 1):
    print(f"Evaluating validation data for client {client_no}")

    X_val_current, X_val_dest, X_val_current_hed, X_val_dest_hed, y_val = client_val_data[client_no - 1]
    # Same input order the model was built and trained with.
    val_inputs = [X_val_current, X_val_current_hed, X_val_dest, X_val_dest_hed]

    preds = global_model.predict(val_inputs, batch_size=16, verbose=0)
    y_pred = preds.argmax(axis=1)

    y_true_all.append(y_val)
    y_pred_all.append(y_pred)

# Flatten the per-client lists into single label arrays for the metrics cell.
y_true_all = np.concatenate(y_true_all)
y_pred_all = np.concatenate(y_pred_all)


Evaluating validation data for client 1
Evaluating validation data for client 2


In [43]:
# Label ids in ascending order with their direction names (from direction_map), so the
# per-class report shows 'forward', 'backward', ... instead of bare integers.
class_ids = sorted(direction_map.values())
id_to_name = {label: name for name, label in direction_map.items()}
class_names = [id_to_name[label] for label in class_ids]

accuracy = accuracy_score(y_true_all, y_pred_all)
# labels= makes the macro average cover every class even if one never occurs in
# y_true or y_pred. zero_division=0 gives sklearn's default value without the warning.
precision, recall, f1, _support = precision_recall_fscore_support(
    y_true_all,
    y_pred_all,
    labels=class_ids,
    average='macro',
    zero_division=0,
)

print("\n===== Global Model Validation Metrics =====")
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1 Score : {f1:.4f}")

print("\nPer-class metrics:")
print(classification_report(
    y_true_all,
    y_pred_all,
    labels=class_ids,
    target_names=class_names,
    zero_division=0,
))



===== Global Model Validation Metrics =====
Accuracy : 0.8422
Precision: 0.8422
Recall   : 0.8405
F1 Score : 0.8409

Per-class metrics:
              precision    recall  f1-score   support

           0       0.82      0.88      0.85       627
           1       0.85      0.85      0.85       608
           2       0.88      0.83      0.85       610
           3       0.82      0.80      0.81       531

    accuracy                           0.84      2376
   macro avg       0.84      0.84      0.84      2376
weighted avg       0.84      0.84      0.84      2376

