# MobileNetV2-Based Model
This script is the implementation of a MobileNetV2-based model.

### Image Metadata Preprocessing
The code in the cell below loads all `.png` files from the dataset, extracts metadata from each filename (row, column, field, plane, channel, cell number, plate ID, and source), and maps plate columns to biological class labels.

It then removes invalid entries, sorts the dataset for consistency, and creates a structured DataFrame (`img_path_pd`) for later usage.

In [None]:
import os
from pathlib import Path
import platform
import pandas as pd
import tensorflow as tf

# Report interpreter / framework versions alongside the run output.
print('Python =', platform.python_version())
print('TensorFlow =', tf.__version__)

# Make GPUs 0 .. NUM_GPU-1 visible to TensorFlow.
NUM_GPU = 4
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu_index) for gpu_index in range(NUM_GPU))

# Do not truncate long cell contents (e.g. file paths) when a DataFrame is displayed.
pd.options.display.max_colwidth = None

# Name fragments used later when assembling output paths and file names.
CODE_FOLDER_NAME = 'codes'
FOLDER = 'home/featurize'

# Base path to the processed image dataset
BASE_PATH = 'home/featurize/F1_balanced'

# Collect every .png below BASE_PATH (recursively) into a one-column table of string paths.
all_png_paths = list(Path(BASE_PATH).glob('**/*.png'))
img_path_pd = pd.DataFrame({'path': all_png_paths}).astype(str)

# File name without directory and without the extension.
img_path_pd['filename'] = img_path_pd['path'].map(lambda p: Path(p).stem)

# Directory that contains each image.
img_path_pd['folder'] = img_path_pd['path'].map(lambda p: str(Path(p).parent))

# "front" identifier: the text before the first '-' with the dash re-appended,
# e.g. r02c02f01p03-
img_path_pd['front'] = img_path_pd['filename'].map(lambda stem: stem.partition('-')[0] + '-')

# Fixed-width positional fields at the start of the file name.
img_path_pd['row'] = img_path_pd['filename'].str.slice(0, 3)      # e.g., r02
img_path_pd['column'] = img_path_pd['filename'].str.slice(3, 6)   # e.g., c02
img_path_pd['field'] = img_path_pd['filename'].str.slice(6, 9)    # e.g., f01
img_path_pd['plane'] = img_path_pd['filename'].str.slice(9, 12)   # e.g., p03

# Channel tag: first three characters after the first '-' (e.g., ch1).
img_path_pd['channel'] = img_path_pd['filename'].str.split('-').str.get(1).str.slice(stop=3)

# Composite position keys.
img_path_pd['rcf'] = img_path_pd['row'] + img_path_pd['column'] + img_path_pd['field']
img_path_pd['rc'] = img_path_pd['row'] + img_path_pd['column']

# Cell number (1-3 digits following "_Cell_"), kept as a string.
img_path_pd['cell_no_str'] = img_path_pd['filename'].str.extract(r'_Cell_(\d{1,3})', expand=False)

# Plate id: the digit in "_P<d>" that is immediately followed by "_F<d>".
img_path_pd['plate_id'] = img_path_pd['filename'].str.extract(r'_P(\d)_F\d', expand=False).astype(int)

# Data source: the digit of the trailing "_F<d>" suffix (F1, F2, F3).
img_path_pd['source'] = img_path_pd['filename'].str.extract(r'_F(\d)$', expand=False).astype(int)

# Numeric plate column, e.g. 'c02' -> 2.
img_path_pd['column_id'] = img_path_pd['column'].str.slice(start=1).astype(int)

# Plate column number -> class label; the labels below belong to columns 2..11 in order.
column_label_map = dict(enumerate(
    [
        'PARENT',
        'TREM2_KO',
        'R47H',
        'H157Y',
        'PLCG2_KO',
        'P522R',
        'P522R_HET',
        'SHIP1_KO',
        'ABI3_KO',
        'S209F',
    ],
    start=2,
))
img_path_pd['class_name'] = img_path_pd['column_id'].map(column_label_map)

# Drop images whose column has no class label, then order the rows
# by position, plane, channel and source with a fresh 0..n-1 index.
img_path_pd = (
    img_path_pd
    .dropna(subset=['class_name'])
    .sort_values(by=['rcf', 'plane', 'channel', 'source'])
    .reset_index(drop=True)
)

# Output preview
print(img_path_pd.shape)
img_path_pd.head()

### Cell-Level Grouping and Train/Validation Split
This code assigns numerical labels to classes based on a fixed order, constructs a unique `cell_id` (combining position, cell number, and plate ID) to represent each 5-channel image group, and filters out incomplete cells. 

It then ensures one label per cell, stratifies by class, and splits the dataset into training and validation sets at the **cell level** to avoid data leakage.

Finally, it retrieves all 5-channel image paths for each split and checks the label distribution to confirm balanced sampling across classes.

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np

# Fixed class order: a class's numeric label is its position in this list.
custom_label_order = [
    'PARENT', 'TREM2_KO', 'R47H', 'H157Y', 'PLCG2_KO',
    'P522R', 'P522R_HET', 'SHIP1_KO', 'ABI3_KO', 'S209F'
]

# create name → index mapping
label2id = {name: idx for idx, name in enumerate(custom_label_order)}

# Add a numerical label to the DataFrame
img_path_pd['label'] = img_path_pd['class_name'].map(label2id)
print('Custom label mapping: ', label2id)

# cell_id identifies one 5-channel image group.
# It is built from front + cell number + plate id + source. The source is part of
# the key because every filename carries a "_F<d>" suffix (see the regexes used to
# parse plate_id / source): without it, images of different sources that share
# position, cell number and plate would collapse into one id and be miscounted by
# the 5-channel completeness check. For a single-source dataset this only appends
# a constant suffix, so grouping and splitting are unaffected.
img_path_pd['cell_id'] = (
    img_path_pd['front'] +
    'Cell_' + img_path_pd['cell_no_str'] +
    '_P' + img_path_pd['plate_id'].astype(str) +
    '_F' + img_path_pd['source'].astype(str)
)

# Number of (non-null) channel entries recorded for every cell_id.
cell_channel_counts = img_path_pd.groupby('cell_id')['channel'].count()
has_five_channels = cell_channel_counts.eq(5)
print(f"Cells with 5 channels: {has_five_channels.sum()}")
print(f"Cells with incomplete channels: {(~has_five_channels).sum()}")

# Retain only the rows that belong to cells with exactly 5 channel entries.
complete_cells = cell_channel_counts.index[has_five_channels]
img_path_pd_filtered = (
    img_path_pd
    .loc[img_path_pd['cell_id'].isin(complete_cells)]
    .reset_index(drop=True)
)

print(f"Original cells: {img_path_pd['cell_id'].nunique()}")
print(f"Complete cells: {len(complete_cells)}")

# One row per cell (its first image) carrying the cell's label; this is the unit that gets split.
cell_df = img_path_pd_filtered[['cell_id', 'label']].drop_duplicates('cell_id')

print(f"Unique cells for splitting: {len(cell_df)}")
print("Label distribution:")
print(cell_df['label'].value_counts().sort_index())

# 80/20 split at cell level, stratified by label, so that all images of a cell
# end up on the same side of the split.
train_cell_ids, val_cell_ids = train_test_split(
    cell_df['cell_id'],
    stratify=cell_df['label'],
    test_size=0.2,
    random_state=1,
)

# Image-level tables: every channel image of the cells assigned to each split.
in_train = img_path_pd_filtered['cell_id'].isin(train_cell_ids)
in_val = img_path_pd_filtered['cell_id'].isin(val_cell_ids)
pl_train_pd = img_path_pd_filtered[in_train].reset_index(drop=True)
pl_val_pd = img_path_pd_filtered[in_val].reset_index(drop=True)

print(f'Training size: {pl_train_pd.shape}')
print(f'Validation size: {pl_val_pd.shape}')
print(f'Training unique cells: {pl_train_pd["cell_id"].nunique()}')
print(f'Validation unique cells: {pl_val_pd["cell_id"].nunique()}')

# Reverse lookup (label id -> class name) for the per-class report below.
name_by_label = {idx: name for name, idx in label2id.items()}

# Per-class cell counts in the training split.
print("\nTraining set label distribution:")
train_label_counts = (
    pl_train_pd.drop_duplicates('cell_id')['label'].value_counts().sort_index()
)
for label, count in train_label_counts.items():
    class_name = name_by_label[label]
    print(f"  {class_name} (label {label}): {count} cells")

# Per-class cell counts in the validation split.
print("\nValidation set label distribution:")
val_label_counts = (
    pl_val_pd.drop_duplicates('cell_id')['label'].value_counts().sort_index()
)
for label, count in val_label_counts.items():
    class_name = name_by_label[label]
    print(f"  {class_name} (label {label}): {count} cells")

pl_train_pd.head()

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Root path where PNG images are stored
# NOTE(review): the metadata table was scanned from a different BASE_PATH in the
# first cell; confirm that 'All' contains the same folder/file layout.
BASE_PATH = 'All'

# Ensure consistent channel order
channel_order = sorted(img_path_pd['channel'].unique())  # e.g., ['ch1', 'ch2', ..., 'ch5']

# One row per cell (its first image) with everything needed to rebuild the file
# names of all its channels. 'source' is included because the file names end in
# "_F<source>" and because later cells read val_cell_info['source'].
cell_info_columns = ['cell_id', 'front', 'cell_no_str', 'plate_id', 'source', 'label']
train_cell_info = pl_train_pd.drop_duplicates('cell_id')[cell_info_columns]
val_cell_info = pl_val_pd.drop_duplicates('cell_id')[cell_info_columns]

# Convert to numpy arrays for dataset creation (ids as strings so they can be
# joined into paths inside the tf.data pipeline).
train_fronts = train_cell_info['front'].to_numpy()
train_cells = train_cell_info['cell_no_str'].to_numpy()
train_plates = train_cell_info['plate_id'].astype(str).to_numpy()
train_sources = train_cell_info['source'].astype(str).to_numpy()
train_labels = train_cell_info['label'].to_numpy()

val_fronts = val_cell_info['front'].to_numpy()
val_cells = val_cell_info['cell_no_str'].to_numpy()
val_plates = val_cell_info['plate_id'].astype(str).to_numpy()
val_sources = val_cell_info['source'].astype(str).to_numpy()
val_labels = val_cell_info['label'].to_numpy()


def stack_img(front, cell_no_str, plate_id, source, label):
    """Load every channel image of one cell and stack them along the last axis.

    Args:
        front: string tensor, file-name prefix ending in '-' (e.g. 'r02c02f01p03-').
        cell_no_str: string tensor, the cell number.
        plate_id: string tensor, the plate digit.
        source: string tensor, the source digit of the "_F<d>" file-name suffix.
        label: the cell's numeric label; passed through unchanged.

    Returns:
        (img_stack, label): img_stack is float32 with one channel per entry of
        channel_order, scaled with x / 127.5 - 1.0.
    """
    img_list = []
    for ch in channel_order:
        # Path structure:
        #   BASE_PATH/<front><ch>sk1fk1fl1/<front><ch>sk1fk1fl1_Cell_<n>_P<plate>_F<source>.png
        # The "_F<source>" suffix matches the patterns used to parse plate_id and
        # source from the file names.
        folder = tf.strings.join([BASE_PATH, '/', front, ch, 'sk1fk1fl1'])
        filename = tf.strings.join(
            [front, ch, 'sk1fk1fl1_Cell_', cell_no_str, '_P', plate_id, '_F', source, '.png'])
        full_path = tf.strings.join([folder, '/', filename])

        # Read as single-channel image and rescale (0..255 maps to -1..1).
        img = tf.io.decode_png(tf.io.read_file(full_path), channels=1)
        img = tf.cast(img, tf.float32) / 127.5 - 1.0
        img_list.append(img)

    img_stack = tf.concat(img_list, axis=-1)  # shape: (H, W, len(channel_order))
    return img_stack, label


# Batch size
BATCH_SIZE = 32

# Build the training dataset (reshuffled every epoch over the full set of cells)
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_fronts, train_cells, train_plates, train_sources, train_labels))
train_batches = (train_ds
    .shuffle(buffer_size=len(train_fronts), reshuffle_each_iteration=True)
    .map(stack_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE))

# Build the validation dataset. It is NOT shuffled, so its order stays aligned
# with val_cell_info / val_labels; it carries the same 5 components that
# stack_img expects.
val_ds = tf.data.Dataset.from_tensor_slices(
    (val_fronts, val_cells, val_plates, val_sources, val_labels))
val_batches = (val_ds
    .map(stack_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE))

print(f"Training cells: {len(train_fronts)}")
print(f"Validation cells: {len(val_fronts)}")

# Visualise the first sample of one training batch
for images, labels in train_batches.take(1):
    sample_image = images[0]
    sample_label = labels[0].numpy()

    print(f"Label index: {sample_label}")
    print(f"Image shape: {sample_image.shape}")
    print(f"Pixel range: min={tf.reduce_min(sample_image).numpy():.2f}, max={tf.reduce_max(sample_image).numpy():.2f}")

    id2label = {v: k for k, v in label2id.items()}
    label_name = id2label[sample_label]
    print(f"Label name: {label_name}")

    # Plot each channel side by side
    plt.figure(figsize=(15, 3))
    for i in range(sample_image.shape[-1]):
        plt.subplot(1, sample_image.shape[-1], i + 1)
        plt.imshow(sample_image[:, :, i], cmap='gray')
        plt.title(f"Channel {i + 1}")
        plt.axis('off')
    plt.suptitle(f"Sample Image - Label: {sample_label} → {label_name}", fontsize=14)
    plt.tight_layout()
    plt.show()

### (Optional) Overlapping Check

In [None]:
import pandas as pd
import numpy as np

def check_cell_level_overlap(pl_train_pd, pl_val_pd):
    """Print a report on cell_id overlap between the training and validation tables.

    Args:
        pl_train_pd: DataFrame with a 'cell_id' column (training split).
        pl_val_pd: DataFrame with a 'cell_id' column (validation split).

    Returns:
        True when no cell_id occurs in both tables, False otherwise.
    """
    print("=" * 50)
    print("CELL-LEVEL OVERLAP CHECK")
    print("=" * 50)

    # Get unique cell_ids from each set
    train_cells = set(pl_train_pd['cell_id'].unique())
    val_cells = set(pl_val_pd['cell_id'].unique())

    # Ids present in both splits, and ids present in either split.
    cell_overlap = train_cells & val_cells
    all_cells = train_cells | val_cells

    # Basic stats. The total is the size of the union, so an id that leaked into
    # both splits is counted once rather than twice.
    print(f"Training unique cells: {len(train_cells)}")
    print(f"Validation unique cells: {len(val_cells)}")
    print(f"Total unique cells: {len(all_cells)}")
    print(f"Overlapping cells: {len(cell_overlap)}")

    # Share of all distinct cells found in each split (0 when both tables are empty).
    total_original = len(all_cells)
    train_ratio = len(train_cells) / total_original if total_original else 0.0
    val_ratio = len(val_cells) / total_original if total_original else 0.0

    print(f"\nSplit ratios:")
    print(f"  Training: {train_ratio:.1%}")
    print(f"  Validation: {val_ratio:.1%}")

    # Overlap result
    print(f"\n{'='*20} RESULT {'='*20}")
    if not cell_overlap:
        print(" SUCCESS: No cell overlap found!")
        print("   Data split is clean - no risk of data leakage")
    else:
        print(f" PROBLEM: Found {len(cell_overlap)} overlapping cells!")
        print("   This indicates data leakage risk")

        # Show up to 10 examples, sorted so the report is reproducible.
        overlap_examples = sorted(cell_overlap)[:10]
        print(f"   Examples of overlapping cell_ids:")
        for cell_id in overlap_examples:
            print(f"     - {cell_id}")

        if len(cell_overlap) > 10:
            print(f"     ... and {len(cell_overlap) - 10} more")

    print("=" * 50)

    return len(cell_overlap) == 0

# Overlap report for the current split; True means no shared cell_id.
is_clean = check_cell_level_overlap(pl_train_pd, pl_val_pd)

# Second check: the two splits together must contain exactly the filtered cells.
print("\nVERIFICATION:")
original_cells = set(img_path_pd_filtered['cell_id'].unique())
all_split_cells = set(pl_train_pd['cell_id'].unique()) | set(pl_val_pd['cell_id'].unique())

missing_cells = original_cells.difference(all_split_cells)
extra_cells = all_split_cells.difference(original_cells)

print(f"Original total cells: {len(original_cells)}")
print(f"Cells after split: {len(all_split_cells)}")
print(f"Missing cells: {len(missing_cells)}")
print(f"Extra cells: {len(extra_cells)}")

if not missing_cells and not extra_cells:
    print(" All original cells are properly accounted for")

# Report each kind of discrepancy; individual ids are listed only for short lists (<= 10).
if missing_cells:
    print(f" {len(missing_cells)} cells are missing after split")
    if len(missing_cells) <= 10:
        print("Missing cells:")
        for cell_id in missing_cells:
            print(f"  - {cell_id}")
if extra_cells:
    print(f" {len(extra_cells)} extra cells appeared after split")
    if len(extra_cells) <= 10:
        print("Extra cells:")
        for cell_id in extra_cells:
            print(f"  - {cell_id}")

# Per-label cell counts of both splits, followed by each label's train/val percentage.
section_rule = '=' * 20
print(f"\n{section_rule} LABEL DISTRIBUTION CHECK {section_rule}")

print("Training set label distribution:")
train_label_counts = (
    pl_train_pd.drop_duplicates('cell_id')['label'].value_counts().sort_index()
)
for label, count in train_label_counts.items():
    print(f"  Label {label}: {count} cells")

print("\nValidation set label distribution:")
val_label_counts = (
    pl_val_pd.drop_duplicates('cell_id')['label'].value_counts().sort_index()
)
for label, count in val_label_counts.items():
    print(f"  Label {label}: {count} cells")

# For every label seen in either split, show how its cells are divided.
print("\nStratification check:")
all_cells = len(original_cells)
labels_in_either_split = sorted(train_label_counts.index.union(val_label_counts.index))
for label in labels_in_either_split:
    train_count = train_label_counts.get(label, 0)
    val_count = val_label_counts.get(label, 0)
    total_count = train_count + val_count
    if total_count > 0:
        train_pct = train_count / total_count * 100
        val_pct = val_count / total_count * 100
    else:
        train_pct = 0
        val_pct = 0
    print(f"  Label {label}: Train {train_pct:.1f}% ({train_count}), Val {val_pct:.1f}% ({val_count})")

print("=" * 70)

### Build and Train the MobileNetV2 Model

In [None]:
# Build model
IMG_HEIGHT = 540
IMG_WIDTH = 540
EPOCH_INITIAL = 25
BASE_LEARNING_RATE = 0.002
GLOBAL_BATCH_SIZE = BATCH_SIZE

# One output unit per class of the fixed label order (10 classes).
NUM_CLASSES = len(custom_label_order)

# Multi GPU training: everything that creates variables is built inside the strategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():

    # MobileNetV2 backbone (no classifier head, no pretrained weights) whose input
    # has one channel per entry of channel_order.
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, len(channel_order)),
        include_top=False,
        weights=None)

    # Reference 3-channel MobileNetV2 initialised from a local weights file;
    # it is used only as a source of weights.
    local_weights_path = '/home/featurize/work/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5'
    base_weights = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        include_top=False,
        weights=local_weights_path)

    # Layers 0 and 1 are NOT copied: the first convolution of the reference model
    # expects 3 input channels while this backbone takes len(channel_order).
    # Every layer from index 2 on receives the reference weights and is frozen;
    # layers 0-1 keep their initial weights and stay trainable.
    copy_weight_from = 2
    target_layers = base_model.layers
    source_layers = base_weights.layers
    for i in range(copy_weight_from, len(target_layers)):
        target_layers[i].set_weights(source_layers[i].get_weights())
        target_layers[i].trainable = False

    print("Number of layers in the base model: ", len(target_layers))

    # Images are already normalised by the input pipeline, so the input goes
    # straight into the backbone.
    inputs = tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, len(channel_order)))

    # training=False keeps the backbone's batch-norm layers in inference mode.
    x = base_model(inputs, training=False)
    x = tf.keras.layers.MaxPool2D(pool_size=2)(x)

    # The pooling layer is named explicitly: feature extraction looks it up with
    # model.get_layer('global_average_pooling2d'), and an auto-generated name gets
    # a numeric suffix whenever this cell is executed more than once.
    x = tf.keras.layers.GlobalAveragePooling2D(name='global_average_pooling2d')(x)

    # Dropout (20%) to reduce overfitting
    x = tf.keras.layers.Dropout(0.2)(x)

    # Final classification layer; softmax outputs class probabilities,
    # matching from_logits=False in the loss below.
    outputs = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)

    # Wraps the defined layers into a Model object.
    model = tf.keras.Model(inputs, outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=BASE_LEARNING_RATE),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy']
    )

    model.summary(line_length=120)

# EarlyStopping on validation accuracy. With patience=80 it only stops a run
# that goes 80 epochs without improvement.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    verbose=1,
    patience=80,
    mode='max',
    restore_best_weights=True)


# Output directory for logs and per-epoch checkpoints.
path = os.path.join('work', CODE_FOLDER_NAME, FOLDER)
os.makedirs(path, exist_ok=True)

# Per-epoch metrics of the initial training phase. FOLDER is interpolated into
# the file name and may itself contain '/', in which case the log file lives in
# a sub-directory of `path`; create that directory so the logger can open the file.
initial_log_path = os.path.join(path, 'log_initial_{}.csv'.format(FOLDER))
os.makedirs(os.path.dirname(initial_log_path), exist_ok=True)
csv_logger = tf.keras.callbacks.CSVLogger(initial_log_path, append=False)

    
class save_per_epoch(tf.keras.callbacks.Callback):
    """Keras callback that saves the whole model to an .h5 file after every epoch.

    Files are written to the module-level `path` directory as
    'model_temp_<FOLDER>_<epoch>.h5', where <epoch> is the epoch index passed to
    on_epoch_end.
    """

    def __init__(self, verbose=0):
        super().__init__()
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        checkpoint_path = os.path.join(path, 'model_temp_{}_{}.h5'.format(FOLDER, epoch))
        # FOLDER may contain '/', which puts the file in a sub-directory of `path`;
        # make sure that directory exists before saving.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        self.model.save(checkpoint_path)

class validate_all_val(tf.keras.callbacks.Callback):
    """Keras callback that re-evaluates the model on val_batches after every epoch.

    Prints the accuracy reported by model.evaluate and, as a cross-check, the
    accuracy computed from the argmax of model.predict against the cell-level
    validation labels.
    """

    def __init__(self, verbose=0):
        super().__init__()
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        self.loss, self.acc = self.model.evaluate(val_batches, verbose=self.verbose)
        print(' - all_val_accuracy: {0:.4f}'.format(self.acc))

        pred = self.model.predict(val_batches)
        label_pred = np.argmax(pred, axis=1)

        # val_batches yields one sample per cell in the (unshuffled) order of
        # val_labels, so predictions are compared with the cell-level labels.
        # The image-level table pl_val_pd has one row per channel image and
        # therefore a different length.
        print('accuracy:', float(np.mean(label_pred == val_labels)))

In [None]:
# Phase 1: the copied backbone layers are frozen; train for EPOCH_INITIAL epochs,
# logging metrics to CSV and saving the model after every epoch.
history = model.fit(
    x=train_batches,
    validation_data=val_batches,
    epochs=EPOCH_INITIAL,
    callbacks=[early_stopping, csv_logger, save_per_epoch()],
    verbose=1,
)

In [None]:
# Fine-tune the model for an additional 30 epochs after the initial training phase.
EPOCH_FINE = 30
EPOCH_TOTAL = EPOCH_INITIAL + EPOCH_FINE

# If validation accuracy doesn't improve for 20 epochs, stop training and restore best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    verbose=1,
    patience=20,
    mode='max',
    restore_best_weights=True)

# Per-epoch metrics of the fine-tuning phase. FOLDER may contain '/', which puts
# the log file in a sub-directory of `path`; create it so the logger can open the file.
fine_log_path = os.path.join(path, 'log_fine_{}.csv'.format(FOLDER))
os.makedirs(os.path.dirname(fine_log_path), exist_ok=True)
csv_logger = tf.keras.callbacks.CSVLogger(fine_log_path, append=False)

with strategy.scope():
    # Make the whole backbone trainable for fine-tuning.
    base_model.trainable = True

    # (Optional) Fine-tune only top N layers
    # fine_tune_at = 100
    # for layer in base_model.layers[:fine_tune_at]:
    #     layer.trainable = False

    # Recompile so the changed trainable flags take effect; a much smaller
    # learning rate is used for this phase.
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy']
    )

    # summary() prints the table itself and returns None, so it is not wrapped in print().
    model.summary()

In [None]:
# Phase 2: continue training up to EPOCH_TOTAL, starting from the epoch that
# follows the last epoch recorded in `history`.
history_fine = model.fit(
    x=train_batches,
    validation_data=val_batches,
    initial_epoch=history.epoch[-1] + 1,
    epochs=EPOCH_TOTAL,
    callbacks=[early_stopping, csv_logger, save_per_epoch()],
    verbose=1,
)

In [None]:
# Restore the weights of a per-epoch checkpoint written by save_per_epoch.
# The number in the file name is the epoch index that callback received in
# on_epoch_end (here: 25).
# NOTE(review): confirm this is the intended epoch, e.g. the one with the best
# validation accuracy.

# Load model with strategy.scope
# https://www.tensorflow.org/tutorials/distribute/keras
with strategy.scope():
    model.load_weights(os.path.join(path, 'model_temp_{}_25.h5'.format(FOLDER)))
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    model.summary()

# Load model without strategy.scope
# model = tf.keras.models.load_model('train_v2.1_20_02.h5')
# model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.00001),
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#               metrics=['accuracy'])

### Accuracy

In [None]:
# Load accuracy logs from the CSV files (instead of directly from the History objects).
initial_log = pd.read_csv(os.path.join(path, 'log_initial_{}.csv'.format(FOLDER)))
fine_log = pd.read_csv(os.path.join(path, 'log_fine_{}.csv'.format(FOLDER)))

acc = initial_log['accuracy'].tolist()
val_acc = initial_log['val_accuracy'].tolist()
acc_fine = fine_log['accuracy'].tolist()
val_acc_fine = fine_log['val_accuracy'].tolist()

# Accuracy curves over both phases, with a vertical marker where fine-tuning starts.
plt.figure(figsize=(8, 8))
plt.subplot(2, 2, 1)
plt.plot(acc + acc_fine, label='Training Accuracy')
plt.plot(val_acc + val_acc_fine, label='Validation Accuracy')
fine_tune_start = EPOCH_INITIAL - 1
plt.plot([fine_tune_start, fine_tune_start], plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel('epoch')

# Save the figure (the directory is created if needed) before showing it.
root_path = '/home/featurize/results'
accuracy_dir = os.path.join(root_path, 'Accuracy')
os.makedirs(accuracy_dir, exist_ok=True)
save_path = os.path.join(accuracy_dir, 'all.png')

plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()

### Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [None]:
import os

# Predicted class per validation cell.
pred = model.predict(val_batches)
label_pred = pred.argmax(axis=1)

# Cell-level ground truth, in the same order as val_batches.
y_true = val_cell_info['label'].to_numpy()

# Label ids in their defined order, plus the id -> name lookup for the tick labels.
labels = [*label2id.values()]
id2label = dict(zip(label2id.values(), label2id.keys()))
tick_names = [id2label[label_id] for label_id in labels]

# Rows are true labels, columns are predicted labels.
cf_matrix = confusion_matrix(y_true, label_pred, labels=labels)

# Output directory
root_path = '/home/featurize/results/confusion_matrix'
os.makedirs(root_path, exist_ok=True)
save_path = os.path.join(root_path, 'all.png')

# Draw the annotated heatmap.
plt.figure(figsize=(10, 8))
sns.heatmap(
    cf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=tick_names,
    yticklabels=tick_names,
    linewidths=0.5,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()

# Save to file, then display.
plt.savefig(save_path, dpi=300)
plt.show()

print(f"Confusion matrix saved to: {save_path}")

### Feature Extraction

In [None]:
import os
from pathlib import Path
import platform
import pandas as pd
import tensorflow as tf

# Report interpreter / framework versions alongside the run output.
print('Python =', platform.python_version())
print('TensorFlow =', tf.__version__)

# Make GPUs 0 .. NUM_GPU-1 visible to TensorFlow.
NUM_GPU = 4
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(NUM_GPU)))

# Do not truncate long cell contents (e.g. file paths) when a DataFrame is displayed.
pd.options.display.max_colwidth = None

# Name fragments for output paths, and the root folder of the images to scan.
CODE_FOLDER_NAME = 'codes'
FOLDER = 'vall'
BASE_PATH = 'All'

# Collect every .png below BASE_PATH (recursively) into a one-column table of string paths.
all_png_paths = list(Path(BASE_PATH).glob('**/*.png'))
img_path_pd = pd.DataFrame({'path': all_png_paths}).astype(str)

# File name without directory and without the extension.
img_path_pd['filename'] = img_path_pd['path'].map(lambda p: Path(p).stem)

# Directory that contains each image.
img_path_pd['folder'] = img_path_pd['path'].map(lambda p: str(Path(p).parent))

# "front" identifier: the text before the first '-' with the dash re-appended,
# e.g. r02c02f01p03-
img_path_pd['front'] = img_path_pd['filename'].map(lambda stem: stem.partition('-')[0] + '-')

# Fixed-width positional fields at the start of the file name.
img_path_pd['row'] = img_path_pd['filename'].str.slice(0, 3)      # e.g., r02
img_path_pd['column'] = img_path_pd['filename'].str.slice(3, 6)   # e.g., c02
img_path_pd['field'] = img_path_pd['filename'].str.slice(6, 9)    # e.g., f01
img_path_pd['plane'] = img_path_pd['filename'].str.slice(9, 12)   # e.g., p03

# Channel tag: first three characters after the first '-' (e.g., ch1).
img_path_pd['channel'] = img_path_pd['filename'].str.split('-').str.get(1).str.slice(stop=3)

# Composite position keys.
img_path_pd['rcf'] = img_path_pd['row'] + img_path_pd['column'] + img_path_pd['field']
img_path_pd['rc'] = img_path_pd['row'] + img_path_pd['column']

# Cell number (1-3 digits following "_Cell_"), kept as a string.
img_path_pd['cell_no_str'] = img_path_pd['filename'].str.extract(r'_Cell_(\d{1,3})', expand=False)

# Plate id: the digit in "_P<d>" that is immediately followed by "_F<d>".
img_path_pd['plate_id'] = img_path_pd['filename'].str.extract(r'_P(\d)_F\d', expand=False).astype(int)

# Data source: the digit of the trailing "_F<d>" suffix.
img_path_pd['source'] = img_path_pd['filename'].str.extract(r'_F(\d)$', expand=False).astype(int)

# Numeric plate column, e.g. 'c02' -> 2.
img_path_pd['column_id'] = img_path_pd['column'].str.slice(start=1).astype(int)

# Plate column number -> class label; the labels below belong to columns 2..11 in order.
column_label_map = dict(enumerate(
    [
        'PARENT',
        'TREM2_KO',
        'R47H',
        'H157Y',
        'PLCG2_KO',
        'P522R',
        'P522R_HET',
        'SHIP1_KO',
        'ABI3_KO',
        'S209F',
    ],
    start=2,
))
img_path_pd['class_name'] = img_path_pd['column_id'].map(column_label_map)

# Drop images whose column has no class label, then order the rows
# by position, plane, channel and source with a fresh 0..n-1 index.
img_path_pd = (
    img_path_pd
    .dropna(subset=['class_name'])
    .sort_values(by=['rcf', 'plane', 'channel', 'source'])
    .reset_index(drop=True)
)

print(img_path_pd.shape)
img_path_pd.head()

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np

# Fixed class order: a class's numeric label is its position in this list.
custom_label_order = [
    'PARENT', 'TREM2_KO', 'R47H', 'H157Y', 'PLCG2_KO',
    'P522R', 'P522R_HET', 'SHIP1_KO', 'ABI3_KO', 'S209F'
]

# name → index mapping
label2id = dict(zip(custom_label_order, range(len(custom_label_order))))

# Numeric label for every image row.
img_path_pd['label'] = img_path_pd['class_name'].map(label2id)
print('Custom label mapping: ', label2id)

# cell_id = front + "Cell_<n>" + "_P<plate>" + "_F<source>"; all channel images
# of one cell share this id.
cell_tag = 'Cell_' + img_path_pd['cell_no_str']
plate_tag = '_P' + img_path_pd['plate_id'].astype(str)
source_tag = '_F' + img_path_pd['source'].astype(str)
img_path_pd['cell_id'] = img_path_pd['front'] + cell_tag + plate_tag + source_tag

# Number of (non-null) channel entries recorded for every cell_id.
cell_channel_counts = img_path_pd.groupby('cell_id')['channel'].count()
has_five_channels = cell_channel_counts.eq(5)
print(f"Cells with 5 channels: {has_five_channels.sum()}")
print(f"Cells with incomplete channels: {(~has_five_channels).sum()}")

# Retain only the rows that belong to cells with exactly 5 channel entries.
complete_cells = cell_channel_counts.index[has_five_channels]
img_path_pd_filtered = (
    img_path_pd
    .loc[img_path_pd['cell_id'].isin(complete_cells)]
    .reset_index(drop=True)
)

print(f"Original cells: {img_path_pd['cell_id'].nunique()}")
print(f"Complete cells: {len(complete_cells)}")

# One row per cell (its first image) carrying the cell's label; this is the unit that gets split.
cell_df = img_path_pd_filtered[['cell_id', 'label']].drop_duplicates('cell_id')

print(f"Unique cells for splitting: {len(cell_df)}")
print("Label distribution:")
print(cell_df['label'].value_counts().sort_index())

# Stratification key "<label>_<source>": the source digit is taken from the
# "_F<d>" suffix of cell_id, so each split keeps the label x source proportions.
source_digit = cell_df['cell_id'].str.extract(r'_F(\d)$', expand=False)
cell_df['stratify_col'] = cell_df['label'].astype(str) + '_' + source_digit

# 80/20 split at cell level so all images of a cell stay on the same side.
train_cell_ids, val_cell_ids = train_test_split(
    cell_df['cell_id'],
    stratify=cell_df['stratify_col'],
    test_size=0.2,
    random_state=1,
)

# Image-level tables: every channel image of the cells assigned to each split.
in_train = img_path_pd_filtered['cell_id'].isin(train_cell_ids)
in_val = img_path_pd_filtered['cell_id'].isin(val_cell_ids)
pl_train_pd = img_path_pd_filtered[in_train].reset_index(drop=True)
pl_val_pd = img_path_pd_filtered[in_val].reset_index(drop=True)

print(f'Training size: {pl_train_pd.shape}')
print(f'Validation size: {pl_val_pd.shape}')
print(f'Training unique cells: {pl_train_pd["cell_id"].nunique()}')
print(f'Validation unique cells: {pl_val_pd["cell_id"].nunique()}')

# Reverse lookup (label id -> class name) for the per-class report below.
name_by_label = {idx: name for name, idx in label2id.items()}

# Per-class cell counts in the training split.
print("\nTraining set label distribution:")
train_label_counts = (
    pl_train_pd.drop_duplicates('cell_id')['label'].value_counts().sort_index()
)
for label, count in train_label_counts.items():
    class_name = name_by_label[label]
    print(f"  {class_name} (label {label}): {count} cells")

# Per-class cell counts in the validation split.
print("\nValidation set label distribution:")
val_label_counts = (
    pl_val_pd.drop_duplicates('cell_id')['label'].value_counts().sort_index()
)
for label, count in val_label_counts.items():
    class_name = name_by_label[label]
    print(f"  {class_name} (label {label}): {count} cells")

pl_train_pd.head()

### Extract correctly predicted features only

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os

# Predict on validation set
pred_probs = model.predict(val_batches)
label_pred = np.argmax(pred_probs, axis=1)

# Ground truth labels (cell-level)
y_true = val_cell_info['label'].to_numpy()

# Predictions are matched to val_cell_info purely by position, so a length
# mismatch would mean the rows below get paired with the wrong predictions.
if len(label_pred) != len(y_true):
    raise ValueError(
        f"model.predict returned {len(label_pred)} predictions but val_cell_info "
        f"has {len(y_true)} rows; they must be aligned one-to-one"
    )

# Find the indices of correctly predicted samples
correct_indices = np.where(label_pred == y_true)[0]

# Filter the information of correctly predicted cells (aligned with the order in val_batches)
correct_val_info = val_cell_info.iloc[correct_indices].reset_index(drop=True)

# Rebuild the Dataset to include only correctly predicted cells.
# The element order (front, cell, plate, source, label) is what gets passed to stack_img.
correct_fronts = correct_val_info['front'].to_numpy()
correct_cells = correct_val_info['cell_no_str'].to_numpy()
correct_plates = correct_val_info['plate_id'].astype(str).to_numpy()
correct_sources = correct_val_info['source'].astype(str).to_numpy()
correct_labels = correct_val_info['label'].to_numpy()

correct_ds = tf.data.Dataset.from_tensor_slices(
    (correct_fronts, correct_cells, correct_plates, correct_sources, correct_labels)
)
correct_batches = (
    correct_ds
    .map(stack_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Extract features: a sub-model that stops at the pooling layer.
model_part = tf.keras.Model(
    inputs=model.input,
    outputs=model.get_layer('global_average_pooling2d').output  # ensure the layer name is correct
)

feature = model_part.predict(correct_batches)

# Construct the feature DataFrame with columns f_0 ... f_{d-1}; it shares the
# (reset) index of correct_val_info so the two can be concatenated row-wise.
feature_pd = pd.DataFrame(
    feature,
    index=correct_val_info.index,
    columns=[f'f_{i}' for i in range(feature.shape[1])]
)

# Merge with metadata
feature_pd = pd.concat([feature_pd, correct_val_info], axis=1)

feature_output_path = os.path.join(
    '/home/featurize/features',
    'all.csv'
)

os.makedirs(os.path.dirname(feature_output_path), exist_ok=True)

feature_pd.to_csv(feature_output_path, index=False)

print(f"Feature extraction completed. Saved {feature_pd.shape[0]} correct samples to:\n{feature_output_path}")

## Similarity

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
import os

# Input features and output location for the class-similarity analysis.
feature_csv_path = '/home/featurize/features/all.csv'
output_dir = '/home/featurize/results/similarity'
os.makedirs(output_dir, exist_ok=True)

# Load the feature table; feature columns are the ones prefixed with 'f_'.
feature_pd = pd.read_csv(feature_csv_path)
feature_columns = [col for col in feature_pd.columns if col.startswith('f_')]

# Keep labelled rows only, then force features to numeric (unparseable -> 0.0).
feature_pd = feature_pd.dropna(subset=['label'])
feature_pd[feature_columns] = (
    feature_pd[feature_columns]
    .apply(pd.to_numeric, errors='coerce')
    .fillna(0.0)
)

# Class-name <-> id mappings, in the display order used for the labels.
custom_label_order = [
    'PARENT', 'TREM2_KO', 'R47H', 'H157Y', 'PLCG2_KO',
    'P522R', 'P522R_HET', 'SHIP1_KO', 'ABI3_KO', 'S209F'
]
label2id = {name: idx for idx, name in enumerate(custom_label_order)}
id2label = {v: k for k, v in label2id.items()}

class_names = sorted(feature_pd['label'].unique())
n_class = len(class_names)

similarity_matrix = np.zeros((n_class, n_class))

# Slice the feature rows of every class once, up front.
feats_by_class = {
    name: feature_pd.loc[feature_pd['label'] == name, feature_columns].values
    for name in class_names
}

# Average sample-level pairwise cosine similarity for each pair of classes.
# Only the upper triangle (including the diagonal) is computed; each value is
# mirrored into the lower triangle.
for i, label_i in enumerate(class_names):
    for j in range(i, n_class):
        label_j = class_names[j]
        pair_mean = np.mean(cosine_similarity(feats_by_class[label_i], feats_by_class[label_j]))
        similarity_matrix[i, j] = pair_mean
        similarity_matrix[j, i] = pair_mean

simi_df = pd.DataFrame(similarity_matrix, index=class_names, columns=class_names)

# Integer labels are shown by class name; ids without a name are kept as-is.
if np.issubdtype(simi_df.index.dtype, np.integer):
    simi_df.index = [id2label.get(label_id, label_id) for label_id in simi_df.index]
    simi_df.columns = [id2label.get(label_id, label_id) for label_id in simi_df.columns]

# Colour scale spans the observed similarity range.
vmin = similarity_matrix.min()
vmax = similarity_matrix.max()
print(f" Similarity range: min = {vmin:.4f}, max = {vmax:.4f}")

# Hide the strict upper triangle so only the lower triangle and diagonal are drawn.
mask = np.triu(np.ones((n_class, n_class), dtype=bool), k=1)

heatmap_options = dict(
    mask=mask,
    annot=True,
    fmt='.4f',
    cmap='Blues',
    vmin=vmin,
    vmax=vmax,
    linewidths=0.5,
    annot_kws={"fontsize": 8},
    cbar_kws={'label': 'Cosine Similarity'},
)

plt.figure(figsize=(10, 8))
sns.heatmap(simi_df, **heatmap_options)
plt.title("Similarity Matrix", fontsize=14)
plt.xticks(rotation=45)
plt.tight_layout()

# Persist both the numeric matrix and the rendered heatmap.
simi_df.to_csv(os.path.join(output_dir, 'all.csv'))
plt.savefig(os.path.join(output_dir, 'all.png'), dpi=300)

print("Pairwise similarity matrix saved to:", output_dir)

## UMAP

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import umap
from sklearn.preprocessing import StandardScaler

# Input features and output location for the UMAP plots.
feature_csv_path = '/home/featurize/features/all.csv'
output_dir = '/home/featurize/results/umap'
os.makedirs(output_dir, exist_ok=True)

# Load features; rows without a label are dropped.
feature_pd = pd.read_csv(feature_csv_path)
feature_columns = [col for col in feature_pd.columns if col.startswith('f_')]
feature_pd = feature_pd.dropna(subset=['label'])

# Standardize features (missing values are filled with 0.0 first).
features = feature_pd[feature_columns].fillna(0.0).astype(float)
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

# UMAP embedding with cosine distance and a fixed seed.
umap_settings = {
    'n_neighbors': 5,
    'min_dist': 0.5,
    'metric': 'cosine',
    'random_state': 42,
    'spread': 1.5,
}
reducer = umap.UMAP(**umap_settings)
embedding = reducer.fit_transform(features_scaled)


# Class-name <-> id mappings.
custom_label_order = [
    'PARENT', 'TREM2_KO', 'R47H', 'H157Y', 'PLCG2_KO',
    'P522R', 'P522R_HET', 'SHIP1_KO', 'ABI3_KO', 'S209F'
]
label2id = {name: idx for idx, name in enumerate(custom_label_order)}
id2label = dict(enumerate(custom_label_order))

# Integer labels are replaced by their class name; any other dtype is used as-is.
labels_are_integer = np.issubdtype(feature_pd['label'].dtype, np.integer)
feature_pd['class_name'] = (
    feature_pd['label'].map(id2label) if labels_are_integer else feature_pd['label']
)

# Scatter plot of the 2-D embedding, coloured by class.
plt.figure(figsize=(10, 8))
sns.scatterplot(
    x=embedding[:, 0],
    y=embedding[:, 1],
    hue=feature_pd['class_name'],
    palette='tab10',
    s=40,
    edgecolor='none',
    alpha=0.9,
)
plt.title('UMAP projection of features', fontsize=14)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xlabel('UMAP-1')
plt.ylabel('UMAP-2')
plt.tight_layout()

# Save before show() so the written figure is the one just drawn.
plt.savefig(os.path.join(output_dir, 'all_by_label.png'), dpi=300)
plt.show()

## UMAP colored by dataset and plate ID

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import umap
from sklearn.preprocessing import StandardScaler

feature_csv_path = '/home/featurize/features/all.csv'
output_dir = '/home/featurize/results/umap'
os.makedirs(output_dir, exist_ok=True)

# Load features
feature_pd = pd.read_csv(feature_csv_path)
feature_columns = [col for col in feature_pd.columns if col.startswith('f_')]
feature_pd = feature_pd.dropna(subset=['plate_id'])  # drop rows without plate_id

# Standardize features (missing values are filled with 0.0 first)
features = feature_pd[feature_columns].fillna(0.0).astype(float)
features_scaled = StandardScaler().fit_transform(features)

# UMAP
reducer = umap.UMAP(
    n_neighbors=5,
    min_dist=0.5,
    metric='cosine',
    random_state=42,
    spread=1.5
)
embedding = reducer.fit_transform(features_scaled)

# Build one categorical label per (source, plate_id) pair, e.g. 'F1_P2'.
raw_combined_label = 'F' + feature_pd['source'].astype(str) + '_P' + feature_pd['plate_id'].astype(str)

# Legend order: F1_P1, F1_P2, ..., F3_P3
desired_order = [f'F{f}_P{p}' for f in (1, 2, 3) for p in (1, 2, 3)]
print("Desired order:", desired_order)

# pd.Categorical turns any value outside `categories` into NaN, which would
# silently remove those points from the plot and the counts, so report them.
unexpected_labels = sorted(set(raw_combined_label) - set(desired_order))
if unexpected_labels:
    print(f"Warning: labels not in desired_order will be left out of the plot: {unexpected_labels}")

# Convert combined_label into an ordered categorical variable
feature_pd['combined_label'] = pd.Categorical(
    raw_combined_label,
    categories=desired_order,
    ordered=True
)

# Visualization
plt.figure(figsize=(10, 8))
sns.scatterplot(
    x=embedding[:, 0], y=embedding[:, 1],
    hue=feature_pd['combined_label'],
    palette='Set1',
    s=40,
    edgecolor='none',
    alpha=0.9
)
plt.title('UMAP projection of features (colored by Dataset_Plate)', fontsize=14)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Dataset_Plate')
plt.xlabel('UMAP-1')
plt.ylabel('UMAP-2')
plt.tight_layout()

plt.savefig(os.path.join(output_dir, 'all_by_plate.png'), dpi=300)
plt.show()

# Distribution of the Dataset_Plate groups
print("Dataset-Plate distribution in the data:")
combined_counts = feature_pd['combined_label'].value_counts().sort_index()
for label, count in combined_counts.items():
    print(f"  {label}: {count} samples")

print(f"Total samples: {len(feature_pd)}")
# The earlier `plate_counts` variable no longer exists; count the groups that
# actually occur (a categorical value_counts can also list empty categories).
print(f"Number of plates: {int((combined_counts > 0).sum())}")