In [55]:
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE

tf.__version__

'2.19.0'

# Models for classifiying cell type

## Problem statement

**Objective**:
The objective of this task is to classify cell type data, which consists of four distinct categories: a, b, c, and others. This problem is framed as a multi-class classification problem, where each sample (an image of a cell) must be assigned one of the four labels based on its features. The dataset is imbalanced, meaning certain classes have far fewer samples than others, which introduces challenges in model evaluation.

**Metric Selection**:
Given the class imbalance present in the dataset, accuracy alone is not a sufficient evaluation metric. In this context, the weighted F1 score is the preferred metric. The F1 score is the harmonic mean of precision and recall, and it balances these two metrics to provide a more comprehensive evaluation.
Since we have an imbalanced dataset, the weighted F1 score is particularly useful because it gives more weight to the larger classes while still considering performance on the smaller classes. This helps ensure that the model performs well across all classes, not just the majority class, and avoids the bias that could arise from imbalanced class distributions. The weighted F1 score is calculated by averaging the F1 scores of each class, weighted by the number of instances in that class. This approach ensures that the model’s performance on minority classes is not overlooked.

**Baseline Model**:
For a baseline model, we will start with a Random Forest classifier. Random Forest is chosen because it is a robust, interpretable model well-suited for classification tasks, and it performs well on medium-sized datasets. Random Forest does not require extensive hyperparameter tuning and can provide a solid baseline for comparison with more advanced models.

**Improvement through Neural Networks**:
To improve upon the baseline, we will experiment with Neural Networks (NN) and Convolutional Neural Networks (CNNs). Neural networks, particularly CNNs, are highly specialized for image classification tasks as they can automatically learn hierarchical features and patterns from image data. CNNs have been shown to outperform traditional machine learning models in many image-related tasks due to their ability to capture spatial relationships in images. These models are expected to provide a significant boost in performance compared to the Random Forest baseline.

## Baseline development

In [56]:
# For the random forest classifier, as the image does not represent, so we need to extract the features for classifying.

In [66]:
import pandas as pd
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import LabelEncoder
from skimage.feature import hog, local_binary_pattern
from skimage.color import rgb2gray
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import cv2

# Function to load and preprocess image data
def load_and_extract_features(image_path, size=(27, 27)):
    """
    Load an image, resize it, and extract features using HOG and LBP
    """
    try:
        # Load image (assuming grayscale medical images)
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            
        # 1. Basic statistical features
        mean_intensity = np.mean(img)
        std_intensity = np.std(img)
        min_intensity = np.min(img)
        max_intensity = np.max(img)
        
        # 2. Histogram features
        hist = cv2.calcHist([img], [0], None, [32], [0, 256])
        hist_features = hist.flatten() / np.sum(hist)  # Normalize
        
        # 3. Texture features using edge detection as proxy
        edges = cv2.Canny(img, 100, 200)
        edge_density = np.sum(edges > 0) / (img.shape[0] * img.shape[1])
        
        # 4. Shape features
        # Apply thresholding
        _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Calculate area and perimeter if contours exist
        if contours:
            largest_contour = max(contours, key=cv2.contourArea)
            area = cv2.contourArea(largest_contour)
            perimeter = cv2.arcLength(largest_contour, True)
            circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
        else:
            area = 0
            perimeter = 0
            circularity = 0
        
        # Combine all features
        image_features = [
            mean_intensity, std_intensity, min_intensity, max_intensity,
            edge_density, area, perimeter, circularity
        ]
        
        # Add histogram features
        image_features.extend(hist_features)
        
        # features.append(image_features)
        
        # # Combine all features
        # all_features = np.concatenate([hog_features, lbp_hist, color_stats])
        return image_features
        
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        # Return a vector of zeros in case of error
        return np.zeros(324) 

In [67]:
celltype_data = pd.read_csv('./data/train_cell_type.csv')
celltype_data.head()

Unnamed: 0,InstanceID,patientID,ImageName,cellTypeName,cellType,isCancerous,ImagePath,combined_label
0,19035,2,19035.png,fibroblast,0,0,./data/preprocessed/train/19035.png,0_0
1,19036,2,19036.png,fibroblast,0,0,./data/preprocessed/train/19036.png,0_0
2,19037,2,19037.png,fibroblast,0,0,./data/preprocessed/train/19037.png,0_0
3,19038,2,19038.png,fibroblast,0,0,./data/preprocessed/train/19038.png,0_0
4,19039,2,19039.png,fibroblast,0,0,./data/preprocessed/train/19039.png,0_0


In [68]:
# Load your existing training data
celltype_data = pd.read_csv('./data/train_cell_type.csv')
print("Dataset loaded with shape:", celltype_data.shape)
print("Sample of data:")
print(celltype_data.head())

# Check class distribution
print("\nCell type distribution:")
print(celltype_data['cellType'].value_counts())

# Check patient distribution
print("\nPatient distribution:")
print(f"Total unique patients: {celltype_data['patientID'].nunique()}")
print(f"Average images per patient: {celltype_data.shape[0] / celltype_data['patientID'].nunique():.2f}")

# Check class distribution per patient
patient_class_counts = celltype_data.groupby('patientID')['cellType'].value_counts().unstack().fillna(0)
print("\nClass distribution per patient (sample):")
print(patient_class_counts.head())

# Analyze if patients have multiple cell types
patients_with_multiple_types = patient_class_counts[patient_class_counts.sum(axis=1) > 1].shape[0]
print(f"\nPatients with multiple cell types: {patients_with_multiple_types} ({patients_with_multiple_types/celltype_data['patientID'].nunique()*100:.2f}%)")


Dataset loaded with shape: (12112, 8)
Sample of data:
   InstanceID  patientID  ImageName cellTypeName  cellType  isCancerous  \
0       19035          2  19035.png   fibroblast         0            0   
1       19036          2  19036.png   fibroblast         0            0   
2       19037          2  19037.png   fibroblast         0            0   
3       19038          2  19038.png   fibroblast         0            0   
4       19039          2  19039.png   fibroblast         0            0   

                             ImagePath combined_label  
0  ./data/preprocessed/train/19035.png            0_0  
1  ./data/preprocessed/train/19036.png            0_0  
2  ./data/preprocessed/train/19037.png            0_0  
3  ./data/preprocessed/train/19038.png            0_0  
4  ./data/preprocessed/train/19039.png            0_0  

Cell type distribution:
cellType
0    3028
1    3028
3    3028
2    3028
Name: count, dtype: int64

Patient distribution:
Total unique patients: 48
Average im

In [69]:
def print_recommendation():
    y = celltype_data['cellType'].values
    patient_ids = celltype_data['patientID'].values
    num_classes = len(np.unique(y))
    num_patients = len(np.unique(patient_ids))
    samples_per_class = np.bincount(y.astype(int))
    min_samples = np.min(samples_per_class)
    
    print("\n=== RECOMMENDATION ===")
    
    if num_patients < 10:
        print("You have a small number of patients (<10). K-fold cross-validation is strongly recommended.")
        recommended = "k-fold"
    elif min_samples < 20:
        print("You have few samples (<20) for at least one class. K-fold cross-validation is recommended.")
        recommended = "k-fold"
    elif patients_with_multiple_types / num_patients > 0.5:
        print("Most patients have multiple cell types. A simple train/validation split should be adequate.")
        recommended = "simple"
    else:
        print("Your dataset size is reasonable. Both approaches are valid.")
        print("For more reliable performance estimation, use K-fold cross-validation.")
        print("For faster training and iteration, use a simple train/validation split.")
        recommended = "either"
    
    # Final recommendation
    if recommended == "k-fold":
        print("\nFINAL RECOMMENDATION: Use patient-stratified K-fold cross-validation.")
    elif recommended == "simple":
        print("\nFINAL RECOMMENDATION: Use a simple patient-stratified train/validation split.")
    else:
        print("\nFINAL RECOMMENDATION: Either approach is valid. Use simple split for speed, K-fold for reliability.")
    
    # Additional notes
    print("\nNOTES:")
    print("- Always ensure patients are not split between train and validation sets")
    print("- Consider class imbalance in your model using class_weight='balanced'")
    print("- If you're tuning hyperparameters, use nested cross-validation or a separate test set")

print_recommendation()


=== RECOMMENDATION ===
Most patients have multiple cell types. A simple train/validation split should be adequate.

FINAL RECOMMENDATION: Use a simple patient-stratified train/validation split.

NOTES:
- Always ensure patients are not split between train and validation sets
- Consider class imbalance in your model using class_weight='balanced'
- If you're tuning hyperparameters, use nested cross-validation or a separate test set


In [70]:
def create_simple_validation_split(patient_ids, y, test_size=0.2, random_state=42):
    """
    Create a single validation split ensuring patients are not split between sets,
    while trying to maintain class distribution.
    """
    # Get unique patient IDs
    unique_patients = np.unique(patient_ids)
    
    # Calculate average class per patient (for stratification)
    patient_class_mapping = {}
    for patient in unique_patients:
        patient_mask = patient_ids == patient
        patient_classes = y[patient_mask]
        # Use most frequent class for this patient
        if len(patient_classes) > 0:
            patient_class_mapping[patient] = np.bincount(patient_classes.astype(int)).argmax()
        else:
            patient_class_mapping[patient] = -1  # No data for this patient
    
    # Create stratification array
    patient_strata = np.array([patient_class_mapping[p] for p in unique_patients])
    
    # Split patients into train and validation groups
    patients_train, patients_val = train_test_split(
        unique_patients, 
        test_size=test_size, 
        random_state=random_state,
        stratify=patient_strata if len(np.unique(patient_strata)) > 1 else None
    )
    
    # Create masks for train and validation sets
    train_mask = np.isin(patient_ids, patients_train)
    val_mask = np.isin(patient_ids, patients_val)
    
    return train_mask, val_mask, patients_train, patients_val

In [71]:
features = []
for idx, row in celltype_data.iterrows():
    try:
        image_path = row['ImagePath']
        feature_vector = load_and_extract_features(image_path)
        if np.isnan(feature_vector).sum() > 0:
            print(image_path)
        features.append(feature_vector)
    except Exception as e:
        print("can't load image: " + e)
X = np.array(features)
print(np.isnan(X).sum())

Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete
Shape features 
complete


In [72]:
patient_ids = celltype_data['patientID'].values
y = celltype_data['cellType'].values
train_mask, val_mask, patients_train, patients_val = create_simple_validation_split(patient_ids, y)
X_train, X_val = X[train_mask], X[val_mask]
y_train, y_val = y[train_mask], y[val_mask]

In [73]:
rd_clf = RandomForestClassifier()
rd_clf.fit(X_train, y_train)
print("Evaluating on validation set (simulated)...")
y_pred = rd_clf.predict(X_val)
print(classification_report(y_val, y_pred))

Evaluating on validation set (simulated)...
              precision    recall  f1-score   support

           0       0.56      0.33      0.41       874
           1       0.50      0.43      0.46       658
           2       0.50      0.70      0.58       498
           3       0.32      0.48      0.39       505

    accuracy                           0.46      2535
   macro avg       0.47      0.48      0.46      2535
weighted avg       0.48      0.46      0.45      2535



In [81]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from tensorflow.keras.metrics import Precision, Recall

from tensorflow.keras import layers, models, regularizers

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Load and prepare the data
celltype_data = pd.read_csv('./data/train_cell_type.csv')
print(f"Dataset shape: {celltype_data.shape}")

# Explore unique cell types for our multiclass classification
unique_cell_types = celltype_data['cellType'].unique()
print(f"Unique cell types: {unique_cell_types}")
print(f"Number of unique cell types: {len(unique_cell_types)}")

# Check class distribution
class_distribution = celltype_data['cellType'].value_counts()
print("Class distribution:")
print(class_distribution)

# Using the existing cellType values (0,1,2,3) instead of encoding
n_classes = len(unique_cell_types)

# Create patient-based train-validation split using the provided function
patient_ids = celltype_data['patientID'].values
labels = celltype_data['cellType'].values

# Split into train (80%) and validation (20%)
train_mask, val_mask, train_patients, val_patients = create_simple_validation_split(
    patient_ids, labels, test_size=0.2, random_state=42
)

# Create dataframes for each set
train_df = celltype_data[train_mask]
val_df = celltype_data[val_mask]

print(f"Training set: {len(train_df)} samples from {len(train_patients)} patients")
print(f"Validation set: {len(val_df)} samples from {len(val_patients)} patients")

# Function to load and preprocess images
def load_images(dataframe, img_size=(27, 27)):
    images = []
    for img_path in dataframe['ImagePath']:
        # Check if path exists and is valid
        if os.path.exists(img_path):
            img = load_img(img_path, color_mode='rgb', target_size=img_size)
            img_array = img_to_array(img)
            images.append(img_array)
        else:
            print(f"Warning: Image not found at {img_path}")
            # Add a blank image as a placeholder
            images.append(np.zeros((img_size[0], img_size[1], 3)))
    
    return np.array(images)

# Load images
print("Loading and preprocessing images...")
X_train = load_images(train_df)
X_val = load_images(val_df)

# Normalize pixel values to [0, 1]
# X_train = X_train / 255.0
# X_val = X_val / 255.0

# Get labels
y_train = train_df['cellType'].values
y_val = val_df['cellType'].values

# Convert labels to one-hot encoding
y_train_onehot = tf.keras.utils.to_categorical(y_train, num_classes=n_classes)
y_val_onehot = tf.keras.utils.to_categorical(y_val, num_classes=n_classes)

print(f"Input shape: {X_train.shape}")
print(f"Output shape: {y_train_onehot.shape}")

# ======================================================
# 1. CNN MODEL IMPLEMENTATION
# ======================================================

# def build_cnn_model(input_shape=(27, 27, 3), n_classes=n_classes):
#     model = models.Sequential([
#         # First Convolutional Block
#         layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=input_shape),
#         layers.BatchNormalization(),
#         layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
#         layers.BatchNormalization(),
#         layers.MaxPooling2D((2, 2)),
#         layers.Dropout(0.25),
        
#         # Second Convolutional Block
#         layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
#         layers.BatchNormalization(),
#         layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
#         layers.BatchNormalization(),
#         layers.MaxPooling2D((2, 2)),
#         layers.Dropout(0.25),
        
#         # Third Convolutional Block
#         layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
#         layers.BatchNormalization(),
#         layers.Dropout(0.25),
        
#         # Flatten and Dense layers
#         layers.Flatten(),
#         layers.Dense(256, activation='relu'),
#         layers.BatchNormalization(),
#         layers.Dropout(0.5),
#         layers.Dense(n_classes, activation='softmax')
#     ])
    
#     return model

def build_enhanced_cnn_model(input_shape=(27, 27, 3), n_classes=4):
    """
    Build an enhanced CNN model optimized for histopathology image classification.
    
    The architecture is designed for small (27x27) images and includes:
    - Residual connections
    - Spatial attention mechanism
    - Advanced regularization techniques
    - Feature pyramid network elements
    - Squeeze-and-excitation blocks
    
    Args:
        input_shape (tuple): Input shape of the images (height, width, channels)
        n_classes (int): Number of classes for classification
        
    Returns:
        tf.keras.Model: Compiled model ready for training
    """
    
    # Input layer
    inputs = layers.Input(shape=input_shape)
    
    # Initial normalization
    x = layers.Rescaling(1./255)(inputs)
    
    # Enhanced First Block with Residual Connection
    conv1 = layers.Conv2D(32, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-4))(x)
    conv1 = layers.LeakyReLU(alpha=0.1)(conv1)
    conv1 = layers.BatchNormalization()(conv1)
    conv1 = layers.Conv2D(32, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-4))(conv1)
    conv1 = layers.LeakyReLU(alpha=0.1)(conv1)
    conv1 = layers.BatchNormalization()(conv1)
    
    # Residual connection
    res1 = layers.Conv2D(32, (1, 1), padding='same')(x)
    conv1 = layers.add([conv1, res1])
    
    # Spatial attention
    attention1 = layers.Conv2D(1, (1, 1), padding='same', activation='sigmoid')(conv1)
    conv1 = layers.multiply([conv1, attention1])
    
    # Pooling
    pool1 = layers.MaxPooling2D((2, 2))(conv1)
    pool1 = layers.Dropout(0.2)(pool1)
    
    # Enhanced Second Block with Residual Connection
    conv2 = layers.Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-4))(pool1)
    conv2 = layers.LeakyReLU(alpha=0.1)(conv2)
    conv2 = layers.BatchNormalization()(conv2)
    conv2 = layers.Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-4))(conv2)
    conv2 = layers.LeakyReLU(alpha=0.1)(conv2)
    conv2 = layers.BatchNormalization()(conv2)
    
    # Residual connection
    res2 = layers.Conv2D(64, (1, 1), padding='same')(pool1)
    conv2 = layers.add([conv2, res2])
    
    # Squeeze-and-Excitation block
    se2 = layers.GlobalAveragePooling2D()(conv2)
    se2 = layers.Dense(64 // 4, activation='relu')(se2)
    se2 = layers.Dense(64, activation='sigmoid')(se2)
    se2 = layers.Reshape((1, 1, 64))(se2)
    conv2 = layers.multiply([conv2, se2])
    
    # Pooling
    pool2 = layers.MaxPooling2D((2, 2))(conv2)
    pool2 = layers.Dropout(0.3)(pool2)
    
    # Enhanced Third Block with Dilated Convolutions
    # Dilated convolutions increase receptive field without reducing spatial dimensions
    conv3 = layers.Conv2D(128, (3, 3), padding='same', dilation_rate=(2, 2), 
                         kernel_regularizer=regularizers.l2(1e-4))(pool2)
    conv3 = layers.LeakyReLU(alpha=0.1)(conv3)
    conv3 = layers.BatchNormalization()(conv3)
    
    # Mix of different kernel sizes for better feature extraction
    conv3_1 = layers.Conv2D(64, (1, 1), padding='same', kernel_regularizer=regularizers.l2(1e-4))(conv3)
    conv3_3 = layers.Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-4))(conv3)
    
    # Concatenate different kernel outputs
    conv3 = layers.concatenate([conv3_1, conv3_3])
    conv3 = layers.BatchNormalization()(conv3)
    conv3 = layers.Dropout(0.3)(conv3)
    
    # Spatial Pyramid Pooling to handle spatial variations
    # For small images, we'll use small pool sizes
    spp1 = layers.GlobalAveragePooling2D()(conv3)
    spp1 = layers.Reshape((1, 1, 128))(spp1)
    spp1 = layers.UpSampling2D(size=(conv3.shape[1], conv3.shape[2]))(spp1)
    
    spp2 = layers.AveragePooling2D(pool_size=(2, 2))(conv3)
    spp2 = layers.UpSampling2D(size=(2, 2))(spp2)
    
    # Concatenate SPP features with conv3
    concat = layers.concatenate([conv3, spp1, spp2])
    
    # Global Features
    flatten = layers.Flatten()(concat)
    
    # First Dense block with residual connection
    dense1 = layers.Dense(256)(flatten)
    dense1 = layers.LeakyReLU(alpha=0.1)(dense1)
    dense1 = layers.BatchNormalization()(dense1)
    
    # Second Dense block
    dense2 = layers.Dense(128)(dense1)
    dense2 = layers.LeakyReLU(alpha=0.1)(dense2)
    dense2 = layers.BatchNormalization()(dense2)
    dense2 = layers.Dropout(0.4)(dense2)
    
    # Output layer
    outputs = layers.Dense(n_classes, activation='softmax')(dense2)
    
    # Create model
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    
    return model

# ======================================================
# 2. FULLY CONNECTED NEURAL NETWORK IMPLEMENTATION
# ======================================================

def build_fcnn_model(input_shape=(27, 27, 3), n_classes=n_classes):
    # Calculate the flattened input size
    flat_size = input_shape[0] * input_shape[1] * input_shape[2]
    
    model = models.Sequential([
        # Flatten the input images
        layers.Flatten(input_shape=input_shape),
        
        # First dense layer
        layers.Dense(1024, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        
        # Second dense layer
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        
        # Third dense layer
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        
        # Output layer
        layers.Dense(n_classes, activation='softmax')
    ])
    
    return model

import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

def create_per_class_metrics(num_classes):
    """
    Create precision and recall metrics for each class.
    
    Args:
        num_classes (int): Number of classes in the classification problem
    
    Returns:
        list: List of precision and recall metrics for each class
    """
    metrics = []
    
    # Add overall metrics
    metrics.extend([
        tf.keras.metrics.Accuracy(name='accuracy'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall')
    ])
    
    # Add per-class metrics
    for i in range(num_classes):
        metrics.extend([
            tf.keras.metrics.Precision(name=f'precision_class_{i}'),
            tf.keras.metrics.Recall(name=f'recall_class_{i}')
        ])
    
    return metrics

def train_and_evaluate_model(model, model_name, X_train, y_train, X_val, y_val, 
                              epochs=300, batch_size=32, n_classes=None):
    """
    Train and evaluate a neural network model with comprehensive logging and visualization.
    
    Args:
        model (keras.Model): Compiled neural network model
        model_name (str): Name of the model for logging and plotting
        X_train (numpy.ndarray): Training input data
        y_train (numpy.ndarray): Training target data
        X_val (numpy.ndarray): Validation input data
        y_val (numpy.ndarray): Validation target data
        epochs (int, optional): Number of training epochs. Defaults to 50.
        batch_size (int, optional): Batch size for training. Defaults to 32.
        n_classes (int, optional): Number of classes for confusion matrix. Defaults to None.
    
    Returns:
        tuple: Trained model and training history
    """
    # Determine number of classes
    if n_classes is None:
        n_classes = y_train.shape[1] if len(y_train.shape) > 1 else len(np.unique(y_train))
    
    # Create metrics
    metrics = create_per_class_metrics(n_classes)
    
    # Compile the model
    model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=metrics
    )
    
    # Set up callbacks
    callbacks = [
        # EarlyStopping(
        #     patience=10, 
        #     restore_best_weights=True, 
        #     monitor='val_loss',
        #     min_delta=1e-4
        # ),
        # ReduceLROnPlateau(
        #     factor=0.5, 
        #     patience=5, 
        #     min_lr=1e-6, 
        #     monitor='val_loss',
        #     verbose=1
        # ),
        # ModelCheckpoint(
        #     f'best_{model_name}.h5', 
        #     save_best_only=True, 
        #     monitor='val_loss',
        #     verbose=1
        # )
    ]
    
    # Train the model
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks,
        verbose=1
    )
    
    # Plot training history
    plot_training_history(history, model_name)
    
    # Evaluate the model
    eval_results = model.evaluate(X_val, y_val)
    
    # Print metrics dynamically based on the order of metrics
    metric_names = ['loss', 'accuracy', 'precision', 'recall'] + \
                   [f'precision_class_{i}' for i in range(n_classes)] + \
                   [f'recall_class_{i}' for i in range(n_classes)]
    
    print(f"\n{model_name} Validation Metrics:")
    for name, value in zip(metric_names, eval_results):
        print(f"{name}: {value:.4f}")
    
    # Generate predictions
    y_pred_probs = model.predict(X_val)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true = np.argmax(y_val, axis=1)
    
    # Print classification report
    print(f"\n{model_name} Classification Report:")
    print(classification_report(y_true, y_pred))
    
    # Generate confusion matrix
    conf_matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        conf_matrix, 
        annot=True, 
        fmt='d', 
        cmap='viridis',
        xticklabels=range(n_classes),
        yticklabels=range(n_classes)
    )
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'{model_name} Confusion Matrix')
    plt.tight_layout()
    plt.show()
    
    return model, history

def plot_training_history(history, model_name):
    """
    Create a comprehensive visualization of model training history.
    
    Args:
        history (keras.callbacks.History): Training history object
        model_name (str): Name of the model for plot titles
    """
    # Create a figure with subplots for different metrics
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'{model_name} Training Metrics', fontsize=16)
    
    # Metrics to plot with their corresponding labels
    metrics = [
        ('loss', 'Loss'),
        ('accuracy', 'Accuracy'),
        ('precision', 'Precision'),
        ('recall', 'Recall')
    ]
    
    # Plot each metric
    for (metric_key, metric_title), ax in zip(metrics, axs.ravel()):
        # Train metric
        train_key = metric_key
        val_key = f'val_{metric_key}'
        
        # Check if metrics exist in history
        if train_key in history.history and val_key in history.history:
            ax.plot(history.history[train_key], label=f'Train {metric_title}')
            ax.plot(history.history[val_key], label=f'Validation {metric_title}')
            ax.set_title(f'{metric_title}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric_title)
            ax.legend()
        else:
            ax.text(0.5, 0.5, f'No {metric_key} data available', 
                    horizontalalignment='center', verticalalignment='center')
            ax.set_title(f'{metric_title} (Not Tracked)')
    
    # Adjust layout and display
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
    
    # Optional: Learning Rate Plot if ReduceLROnPlateau is used
    if 'lr' in history.history:
        plt.figure(figsize=(10, 5))
        plt.plot(history.history['lr'])
        plt.title(f'{model_name} Learning Rate')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.show()


# ======================================================
# TRAIN BOTH MODELS
# ======================================================

# Create and train CNN model
print("\n\n==== TRAINING CNN MODEL ====")
cnn_model = build_enhanced_cnn_model()
cnn_model.summary()
cnn_model, cnn_history = train_and_evaluate_model(
    cnn_model, 'CNN', X_train, y_train_onehot, X_val, y_val_onehot
)

# Create and train FCNN model
print("\n\n==== TRAINING FULLY CONNECTED NEURAL NETWORK ====")
fcnn_model = build_fcnn_model()
fcnn_model.summary()
fcnn_model, fcnn_history = train_and_evaluate_model(
    fcnn_model, 'FCNN', X_train, y_train_onehot, X_val, y_val_onehot
)

# ======================================================
# HYPERPARAMETER TUNING RECOMMENDATIONS
# ======================================================

print("\n\n==== HYPERPARAMETER TUNING RECOMMENDATIONS ====")
print("""
Hyperparameter tuning recommendations for cell type classification:

1. Learning Rate:
   - Initial value: 0.001 (current setting)
   - Tuning range: [0.0001, 0.01]
   - Justification: The learning rate controls step size during optimization. 
     Too high may cause divergence, too low leads to slow convergence. The optimal 
     learning rate depends on the specific dataset characteristics and model architecture.
   - Implementation: Use ReduceLROnPlateau callback to dynamically reduce learning rate 
     when validation metrics plateau.

2. Batch Size:
   - Initial value: 32 (current setting)
   - Tuning range: [16, 64]
   - Justification: Smaller batch sizes provide more parameter updates and potentially better 
     generalization but slower training. Larger batch sizes enable better gradient estimates 
     and faster training but may lead to poorer generalization.
   - Implementation: Grid search different batch sizes and measure validation performance.

3. Network Depth:
   - CNN: Test with 2-4 convolutional blocks
   - FCNN: Test with 2-5 dense layers
   - Justification: Deeper networks can learn more complex representations but require more 
     data and are prone to overfitting. For 27x27 images, extremely deep networks may be 
     unnecessary and could lead to overfitting.

4. Dropout Rate:
   - Current values: 0.25 (convolutional layers), 0.5 (dense layers)
   - Tuning range: [0.1, 0.5]
   - Justification: Dropout is a regularization technique to prevent overfitting. The optimal 
     rate depends on model complexity and dataset size. Early layers typically need less 
     dropout than later layers.

5. Number of Filters/Neurons:
   - CNN filters: Test [16, 32, 64] for first layer, doubling in each subsequent layer
   - FCNN neurons: Test different architectures like [512, 256], [1024, 512, 256], etc.
   - Justification: The number of filters/neurons controls model capacity. Too few may 
     result in underfitting, while too many can lead to overfitting and increased computational cost.

6. Data Augmentation Parameters:
   - Rotation range: 15-30 degrees
   - Width/height shift: 0.1-0.2
   - Zoom range: 0.1-0.15
   - Horizontal/vertical flip: True/False depending on cell orientation importance
   - Justification: Data augmentation increases effective training set size and improves 
     model generalization. The parameters should be tuned to match expected variations in the data 
     while preserving class-specific features.

7. Early Stopping Patience:
   - Current value: 10 epochs
   - Tuning range: [5, 15]
   - Justification: Patience determines how many epochs to wait for improvement before stopping 
     training. Too short may stop training prematurely, while too long wastes computational resources.

8. Optimizer:
   - Current choice: Adam
   - Alternatives: RMSprop, SGD with momentum
   - Justification: Different optimizers have different convergence properties. Adam generally 
     works well for most problems, but alternatives may perform better in specific cases.

Recommended Tuning Approach:
1. Start with learning rate and batch size tuning as these often have the largest impact
2. Proceed to network architecture tuning (depth, width)
3. Fine-tune regularization parameters (dropout)
4. Optimize data augmentation parameters
5. Consider ensemble methods combining CNN and FCNN predictions for potentially better performance
""")

# Function to visualize model architecture
def visualize_model_architecture(model, model_name):
    from tensorflow.keras.utils import plot_model
    plot_model(model, to_file=f'{model_name}_architecture.png', show_shapes=True, show_layer_names=True)
    print(f"Model architecture saved as {model_name}_architecture.png")

# Uncomment to visualize model architectures
# visualize_model_architecture(cnn_model, 'CNN')
# visualize_model_architecture(fcnn_model, 'FCNN')

# ======================================================
# DATA AUGMENTATION IMPLEMENTATION (FOR REFERENCE)
# ======================================================

print("\n\n==== DATA AUGMENTATION IMPLEMENTATION ====")
print("""
# Example implementation of data augmentation:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Create data augmentation generator
datagen = ImageDataGenerator(
    rotation_range=20,           # Random rotations (0-20 degrees)
    width_shift_range=0.15,      # Random horizontal shifts
    height_shift_range=0.15,     # Random vertical shifts
    zoom_range=0.15,             # Random zoom
    horizontal_flip=True,        # Random horizontal flips
    vertical_flip=False,         # No vertical flips (cell orientation may matter)
    fill_mode='nearest'          # Strategy for filling points outside the input boundaries
)

# Fit the generator on the training data
datagen.fit(X_train)

# Train using the generator
model.fit(
    datagen.flow(X_train, y_train_onehot, batch_size=32),
    validation_data=(X_val, y_val_onehot),
    steps_per_epoch=len(X_train) // 32,
    epochs=50,
    callbacks=callbacks
)
""")

# Save final models
cnn_model.save('final_cnn_model.h5')
fcnn_model.save('final_fcnn_model.h5')
print("Models saved successfully!")

Dataset shape: (12112, 8)
Unique cell types: [0 1 3 2]
Number of unique cell types: 4
Class distribution:
cellType
0    3028
1    3028
3    3028
2    3028
Name: count, dtype: int64
Training set: 9577 samples from 38 patients
Validation set: 2535 samples from 10 patients
Loading and preprocessing images...
Input shape: (9577, 27, 27, 3)
Output shape: (9577, 4)


==== TRAINING CNN MODEL ====




Epoch 1/300
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.0000e+00 - loss: 1.3059 - precision: 0.5554 - precision_class_0: 0.5554 - precision_class_1: 0.5554 - precision_class_2: 0.5554 - precision_class_3: 0.5554 - recall: 0.4065 - recall_class_0: 0.4065 - recall_class_1: 0.4065 - recall_class_2: 0.4065 - recall_class_3: 0.4065 - val_accuracy: 0.0000e+00 - val_loss: 7.9307 - val_precision: 0.1964 - val_precision_class_0: 0.1964 - val_precision_class_1: 0.1964 - val_precision_class_2: 0.1964 - val_precision_class_3: 0.1964 - val_recall: 0.1964 - val_recall_class_0: 0.1964 - val_recall_class_1: 0.1964 - val_recall_class_2: 0.1964 - val_recall_class_3: 0.1964
Epoch 2/300
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 34ms/step - accuracy: 0.0000e+00 - loss: 1.0754 - precision: 0.6205 - precision_class_0: 0.6205 - precision_class_1: 0.6205 - precision_class_2: 0.6205 - precision_class_3: 0.6205 - recall: 0.4362 - recall_clas

KeyboardInterrupt: 

## Using extra data to improve the current accuracy

Analysis on the given data: 
The given data does not have the labels for the cell type but only with the cancerous status. So, one possible way is to train the embeddings via the encoder/decoder neural network (i.e. similar to what BERT do), then fine tuned on the given dataset. 

Running the transfer learning: https://www.tensorflow.org/tutorials/images/transfer_learning#create_the_base_model_from_the_pre-trained_convnets
