In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
import subprocess

# Repository to work from, cloned onto Google Drive so code and outputs
# (e.g. the models/ folder written later) survive across Colab sessions.
REPO_URL = 'https://github.com/dungdinhhaha/ClassifyCell'
REPO_DIR = '/content/drive/MyDrive/PhatHienTeBao'

# Change to MyDrive first
os.chdir('/content/drive/MyDrive')
print('Current working directory for cloning:', os.getcwd())

# Clone only when the target folder is missing or empty. `git clone` refuses to
# clone into an existing non-empty directory, so without this guard every
# re-run of the cell (e.g. Restart & Run All) would attempt a failing clone.
if os.path.isdir(REPO_DIR) and os.listdir(REPO_DIR):
    print('Repository folder already exists, skipping clone:', REPO_DIR)
else:
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)

# Now change into the cloned repository for subsequent operations
os.chdir(REPO_DIR)
print('CWD:', os.getcwd())

## 2) Install Dependencies

In [None]:
!pip install -q opencv-python pillow scikit-learn matplotlib seaborn

import tensorflow as tf
import keras
import numpy as np

print(f'‚úÖ TensorFlow version: {tf.__version__}')
print(f'‚úÖ Keras version: {keras.__version__}')
print(f'‚úÖ NumPy version: {np.__version__}')

## 3) Load Images from Folders (Each Folder = 1 Class)

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import os

# Class id -> class name. The ids are consecutive integers starting at 0, so the
# map is built by enumerating the names in id order. Each id is also the name of
# that class's sub-folder under images_dir.
label_map = dict(enumerate([
    "back_ground",
    'ascus',
    'asch',
    'lsil',
    'hsil',
    'scc',
    'agc',
    'trichomonas',
    'candida',
    'flora',
    'herps',
    'actinomyces',
]))

num_classes = len(label_map)
print(f'Number of classes: {num_classes}')
print('Label map:', label_map)

# Image folder path - change this to wherever your image folder is located
images_dir = '/content/drive/MyDrive/PhatHienTeBao/images'
print(f'\nLoading images from: {images_dir}')

In [None]:
def load_images_from_folders(images_dir, label_map, target_size=128):
    """Load a labelled image dataset in which every sub-folder holds one class.

    Expected layout -- each sub-folder is named after a class id (a key of
    ``label_map``)::

        images_dir/
            1/  (ascus)
                image1.jpg
                image2.jpg
            2/  (asch)
                image3.jpg
            ...

    Files ending in .jpg / .jpeg / .png (any letter case) are loaded. Missing
    class folders and images that fail to open are reported and skipped.

    Args:
        images_dir: Root directory containing one sub-folder per class id.
        label_map: Mapping ``{class_id: class_name}``; only these ids are loaded.
        target_size: Every image is converted to RGB and resized to
            ``(target_size, target_size)``.

    Returns:
        Tuple ``(X, y)``. ``X`` is a float32 array of shape
        ``(N, target_size, target_size, 3)`` holding raw pixel values
        (no normalization -- the model does it). ``y`` is an integer array of
        shape ``(N,)`` with the class id of each image. When nothing is loaded
        both arrays are empty (N == 0) instead of raising.
    """
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    all_images = []
    all_labels = []

    print(f'üîÑ Loading images from {images_dir}...')

    # Iterate through each class folder
    for class_id in sorted(label_map.keys()):
        class_name = label_map[class_id]
        class_dir = os.path.join(images_dir, str(class_id))

        if not os.path.isdir(class_dir):
            print(f'‚ö†Ô∏è  Folder not found: {class_dir}')
            continue

        # Match the extension case-insensitively. Separate case-sensitive globs
        # ('*.jpg', '*.JPG', ...) miss variants such as '.JPEG' and list the same
        # file twice on case-insensitive file systems. Sorting makes the load
        # order -- and therefore the seeded train/val/test split -- reproducible.
        image_files = sorted(
            path for path in Path(class_dir).iterdir()
            if path.is_file() and path.suffix.lower() in image_suffixes
        )

        print(f'\nLoading Class {class_id} ({class_name}): {len(image_files)} images')

        for img_path in image_files:
            try:
                # `with` closes the file handle that Image.open keeps open.
                with Image.open(img_path) as img:
                    rgb = img.convert('RGB').resize((target_size, target_size))

                # Convert to numpy - NO NORMALIZATION (model will do it)
                all_images.append(np.array(rgb, dtype=np.float32))
                all_labels.append(class_id)
            except Exception as e:
                print(f'  ‚ùå Error loading {img_path}: {e}')

    if not all_images:
        # X.min()/X.max() below would raise on an empty array; return empty,
        # correctly shaped arrays so the caller sees "0 images" instead.
        print('\nNo images were loaded - check images_dir and the class folder names.')
        return (np.empty((0, target_size, target_size, 3), dtype=np.float32),
                np.empty((0,), dtype=np.int64))

    X = np.array(all_images)
    y = np.array(all_labels)

    print(f'\n‚úÖ Total images loaded: {len(all_images)}')
    print(f'üìä Dataset shape: X={X.shape}, y={y.shape}')
    print(f'üìä Pixel range: [{X.min():.1f}, {X.max():.1f}]')

    # Print class distribution
    print(f'\nüìä Class distribution:')
    unique, counts = np.unique(y, return_counts=True)
    for cls, cnt in zip(unique, counts):
        print(f'   Class {cls:2d} ({label_map.get(cls, "unknown"):15s}): {cnt:4d} samples')

    return X, y

# Load every class at 224x224, the input size build_classifier() is set up for.
X, y = load_images_from_folders(images_dir=images_dir, label_map=label_map, target_size=224)

## 4) Visualize Sample Images

In [None]:
# Show the first sample of each class, one panel per class (4 panels per row).
# X holds raw float32 pixels in [0, 255], but matplotlib clips float RGB data to
# [0, 1]. Without rescaling, almost every image would render as plain white.
n_cols = 4
n_rows = max(1, -(-num_classes // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
axes = np.asarray(axes).flatten()

for ax, cls in zip(axes, range(num_classes)):
    # Indices of every sample of this class; the first one is displayed.
    class_indices = np.where(y == cls)[0]

    if len(class_indices) > 0:
        ax.imshow(np.clip(X[class_indices[0]] / 255.0, 0, 1))
        ax.set_title(f'Class {cls}: {label_map.get(cls, "unknown")}\n({len(class_indices)} samples)')
    else:
        ax.text(0.5, 0.5, f'Class {cls}\nNo samples',
                ha='center', va='center', transform=ax.transAxes)

# Turn every panel's axes off, including spare panels when num_classes is not
# a multiple of n_cols.
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

## 5) Split Train/Val/Test (stratified 64% / 16% / 20%; augmentation is defined in Step 6 and applied during training in Step 7)

In [None]:
from sklearn.model_selection import train_test_split

# Two stratified, seeded splits that preserve the class ratios:
#   step 1 holds out 20% of all samples as the test set;
#   step 2 holds out 20% of the remainder as the validation set,
# which gives 64% train / 16% val / 20% test overall.
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, stratify=y_temp, test_size=0.2, random_state=42)

print(f'Train set: {len(X_train)} samples')
print(f'Val set:   {len(X_val)} samples')
print(f'Test set:  {len(X_test)} samples')

print('\n‚úÖ Data split complete!')

## 6) Build Classification Model (Transfer Learning)

In [None]:
from tensorflow.keras import layers, models
import tensorflow as tf

# Deliberately mild augmentation so cell morphology is preserved: a random
# left/right mirror plus a random rotation of at most 0.05 of a full turn
# (about +/-18 degrees).
train_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(factor=0.05),
    ],
    name='augmentation',
)

def build_classifier(input_shape=(224, 224, 3), num_classes=12):
    """Build a DenseNet121 transfer-learning classifier.

    The ImageNet-pretrained DenseNet121 backbone starts frozen. A dense head
    sits on top: global average pooling -> 1024 -> 512 -> 256 -> softmax, with
    batch normalization, dropout and L2 regularization.

    Args:
        input_shape: Shape of one input image, ``(height, width, channels)``.
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        ``(model, base_model)``. ``model`` is the full Keras model; it takes raw
        RGB pixel values in [0, 255]. ``base_model`` is the backbone, returned
        so it can be unfrozen for fine-tuning.
    """
    # Use DenseNet121 - excellent for medical imaging with fine details
    base_model = tf.keras.applications.DenseNet121(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )

    # Freeze the backbone completely for the first training phase.
    base_model.trainable = False

    inputs = layers.Input(shape=input_shape)

    # Reproduce the preprocessing the ImageNet DenseNet weights were trained with
    # (tf.keras.applications.densenet.preprocess_input, "torch" mode): scale to
    # [0, 1], then standardize each channel with the ImageNet mean / std.
    # Plain [0, 1] scaling alone gives the frozen backbone inputs from a
    # different distribution than the one its features were learned on.
    x = layers.Rescaling(1.0 / 255.0)(inputs)
    x = layers.Normalization(
        mean=[0.485, 0.456, 0.406],
        variance=[0.229 ** 2, 0.224 ** 2, 0.225 ** 2],
    )(x)

    # training=False keeps the backbone's BatchNormalization layers in inference
    # mode, including after the backbone is unfrozen for fine-tuning.
    x = base_model(x, training=False)

    # Dense classifier head for cytology
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(1024, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)

    x = layers.Dense(512, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.4)(x)

    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = tf.keras.Model(inputs, outputs)

    return model, base_model

# Instantiate the classifier; `base_model` is kept so the backbone can be unfrozen later.
model, base_model = build_classifier(num_classes=num_classes)

# Compile with plain sparse categorical cross-entropy on integer labels (no label smoothing).
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
    metrics=['accuracy'],
)

print('üìä Model: DenseNet121 (frozen) + Dense Head for Cytology')
model.summary()

## 7) Train Model

In [None]:
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf

# 'balanced' class weights, n_samples / (n_present_classes * count_of_class),
# to counteract class imbalance.
present_classes = np.unique(y_train)
class_weights = compute_class_weight(
    'balanced',
    classes=present_classes,
    y=y_train
)

# compute_class_weight returns one weight per entry of `classes`, in that order,
# so the weights are keyed by the actual class id. Enumerating them instead
# (position -> weight) would shift every weight after any class id missing from
# y_train, e.g. an empty or absent class folder.
# Missing ids get a neutral 1.0, so every id 0..num_classes-1 has an entry
# (tf.keras 2.x rejects class_weight dicts whose keys are not exactly 0..n-1).
class_weights_dict = {cls: 1.0 for cls in range(num_classes)}
class_weights_dict.update(
    (int(cls), float(weight)) for cls, weight in zip(present_classes, class_weights)
)

print('üìä Class weights:')
for cls, weight in class_weights_dict.items():
    print(f'   Class {cls} ({label_map.get(cls, "unknown")}): {weight:.3f}')

# NOTE: this cell keeps training whatever state `model` is in. Re-running it
# after Phase 2 resumes from the partially unfrozen backbone rather than
# starting over.

# Apply LIGHT augmentation to training data only -- on the fly, per batch.
# Augmenting X_train once up front would freeze a single random flip/rotation
# per image for all epochs and drop the un-augmented originals. Mapping the
# augmentation over a tf.data pipeline draws fresh random transforms for every
# batch of every epoch.
print('\nüîÑ Applying light augmentation to training data...')


def _augment_batch(images, labels):
    """Run a batch of images through each `train_augmentation` layer; labels pass through unchanged."""
    # training=True makes the preprocessing layers apply their random transforms.
    for aug_layer in train_augmentation.layers:
        images = aug_layer(images, training=True)
    return images, labels


def make_train_dataset(batch_size, seed=42):
    """Build the training input pipeline from X_train / y_train.

    Args:
        batch_size: Number of images per batch.
        seed: Shuffle seed.

    Returns:
        A prefetched ``tf.data.Dataset`` of ``(augmented_images, labels)``
        batches. Shuffling and augmentation are redone on every pass (epoch).
    """
    # Create the source tensors on the CPU so the training array stays in host memory.
    with tf.device('/CPU:0'):
        dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    # Keras' fit() shuffles NumPy inputs each epoch by default; do the same here.
    dataset = dataset.shuffle(len(X_train), seed=seed, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(_augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


print(f'Training data: {X_train.shape} (augmented per batch during training)')

# Create model directory
model_save_dir = '/content/drive/MyDrive/PhatHienTeBao/models'
os.makedirs(model_save_dir, exist_ok=True)

# Callbacks (shared by both training phases)
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(model_save_dir, 'best_cytology_model.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=20,
        restore_best_weights=True,
        verbose=1
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=0.3,
        patience=8,
        min_lr=1e-7,
        verbose=1
    )
]

# PHASE 1: Train with frozen backbone - MORE EPOCHS for cytology
print('\nüöÄ PHASE 1: Training classifier head (backbone frozen)...\n')
print('   Note: Cervical cytology is challenging - may take 40-50 epochs\n')

history1 = model.fit(
    make_train_dataset(batch_size=32),
    validation_data=(X_val, y_val),
    epochs=60,  # Increased from 50 to 60
    callbacks=callbacks,
    class_weight=class_weights_dict,
    verbose=1
)

best_val_acc = max(history1.history['val_accuracy'])
print(f'\n‚úÖ Phase 1 done! Best val accuracy: {best_val_acc:.2%}')

# PHASE 2: Unfreeze and fine-tune if Phase 1 accuracy > 30%
if best_val_acc > 0.30:
    print(f'\nüöÄ PHASE 2: Fine-tuning backbone (last 80 layers)...\n')
    base_model.trainable = True

    # Unfreeze only the last 80 backbone layers; everything before stays frozen.
    for layer in base_model.layers[:-80]:
        layer.trainable = False

    # Recompile (required after changing `trainable`) with a very low learning rate.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-6),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    history2 = model.fit(
        make_train_dataset(batch_size=16),
        validation_data=(X_val, y_val),
        epochs=40,  # More epochs for fine-tuning
        callbacks=callbacks,
        class_weight=class_weights_dict,
        verbose=1
    )

    print(f'\n‚úÖ Phase 2 done! Best val accuracy: {max(history2.history["val_accuracy"]):.2%}')

    # Combine histories (Phase 2 epochs appended after Phase 1)
    history = history1
    for key in ['loss', 'accuracy', 'val_loss', 'val_accuracy']:
        history.history[key].extend(history2.history[key])
else:
    print(f'\n‚ö†Ô∏è  Phase 1 accuracy: {best_val_acc:.2%}')
    print('   This is expected for cervical cytology - classes are very similar')
    print('   Model is learning but needs more epochs or better features')
    history = history1

print('\n‚úÖ Training complete!')
print(f'üìä Final best validation accuracy: {max(history.history["val_accuracy"]):.2%}')

## 8) Plot Training History

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# One entry per panel:
# (axes, metric key, train legend, val legend, y-axis label, title).
panel_specs = [
    (ax1, 'loss', 'Train Loss', 'Val Loss', 'Loss', 'Training & Validation Loss'),
    (ax2, 'accuracy', 'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Training & Validation Accuracy'),
]
for ax, metric, train_label, val_label, y_label, title in panel_specs:
    ax.plot(history.history[metric], label=train_label, marker='o')
    ax.plot(history.history['val_' + metric], label=val_label, marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

hist = history.history
print(f'\nüìä Final Results:')
print(f'   Train Accuracy: {hist["accuracy"][-1]:.4f}')
print(f'   Val Accuracy:   {hist["val_accuracy"][-1]:.4f}')
print(f'   Best Val Acc:   {max(hist["val_accuracy"]):.4f}')

## 9) Evaluate on Test Set

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Class probabilities for every test image -> most likely class id.
y_pred = model.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)

# Evaluate over every class that appears in the ground truth OR the predictions.
# Restricting the labels to y_test's classes silently drops predictions of any
# other class from the confusion matrix, so its rows no longer add up to the
# class support.
unique_classes = np.union1d(y_test, y_pred_classes)
target_names = [label_map.get(i, f'Class {i}') for i in unique_classes]

# zero_division=0 reports 0 instead of warning for metrics that are undefined,
# e.g. recall of a class that is predicted but absent from y_test.
print('Classification Report on TEST SET:')
print(classification_report(y_test, y_pred_classes,
                            labels=unique_classes,
                            target_names=target_names,
                            zero_division=0))

# Confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(y_test, y_pred_classes, labels=unique_classes)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=[label_map.get(i, f'{i}') for i in unique_classes],
            yticklabels=[label_map.get(i, f'{i}') for i in unique_classes])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Test Set)')
plt.tight_layout()
plt.show()

## 10) Visualize Predictions on Test Set

In [None]:
# Show up to 16 random test images with true label, predicted label and confidence.
fig, axes = plt.subplots(4, 4, figsize=(15, 15))
axes = axes.flatten()

# Never ask for more samples than the test set holds; choice(..., replace=False)
# raises otherwise. Seeded so the same images are shown on every run.
n_show = min(len(axes), len(X_test))
sample_rng = np.random.default_rng(42)
random_indices = sample_rng.choice(len(X_test), n_show, replace=False)

# One batched predict call instead of one call per image.
sample_preds = model.predict(X_test[random_indices], verbose=0) if n_show else []

for ax, sample_idx, pred in zip(axes, random_indices, sample_preds):
    img = X_test[sample_idx]
    true_class = y_test[sample_idx]

    pred_class = np.argmax(pred)
    pred_conf = pred[pred_class]

    # Normalize image for display [0, 1]
    ax.imshow(np.clip(img / 255.0, 0, 1))

    true_name = label_map.get(true_class, f'C{true_class}')
    pred_name = label_map.get(pred_class, f'C{pred_class}')

    correct = pred_class == true_class
    status = '‚úì' if correct else '‚úó'
    color = 'green' if correct else 'red'

    title = f'{status} True: {true_name}\nPred: {pred_name} ({pred_conf:.2f})'
    ax.set_title(title, color=color, fontweight='bold')
    ax.axis('off')

# Blank out panels left over when the test set has fewer than 16 images.
for ax in axes[n_show:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## 11) Save Model

In [None]:
import os

# Write the final trained model into the same folder as the training checkpoints.
model_save_dir = '/content/drive/MyDrive/PhatHienTeBao/models'
final_model_path = os.path.join(model_save_dir, 'cell_classifier_final.keras')
model.save(final_model_path)
print(f'‚úÖ Model saved to: {final_model_path}')

# Report every .keras file in the models folder together with its size on disk.
if os.path.exists(model_save_dir):
    print(f'\nüìÇ Models in {model_save_dir}:')
    keras_files = [name for name in os.listdir(model_save_dir) if name.endswith('.keras')]
    for name in keras_files:
        size_mb = os.path.getsize(os.path.join(model_save_dir, name)) / (1024 * 1024)
        print(f'   ‚úì {name} ({size_mb:.2f} MB)')

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import ListedColormap
import numpy as np

if 'predictions' in locals() and len(predictions) > 0:
    # Create heat map showing abnormal cells
    print('üìä Creating visualization...')

    # Class ids counted as normal vs. abnormal for this map.
    normal_classes = [0, 9, 8]  # background, flora, candida
    abnormal_classes = [1, 2, 3, 4, 5, 6, 7, 10, 11]  # All others

    # One grid cell per `stride` pixels.
    # NOTE(review): assumes pred['x'] / pred['y'] use the same coordinate space
    # as `width` / `height` (both come from the prediction cell) -- confirm there.
    patch_size = 224
    stride = 112

    # Make the grid large enough for every prediction. With width // stride
    # alone, a patch whose x (or y) lies in the last partial stride would raise
    # IndexError.
    grid_width = max(width // stride, max(p['x'] for p in predictions) // stride + 1)
    grid_height = max(height // stride, max(p['y'] for p in predictions) // stride + 1)

    heat_map = np.zeros((grid_height, grid_width))
    confidence_map = np.zeros((grid_height, grid_width))
    class_map = np.zeros((grid_height, grid_width), dtype=int)

    for pred in predictions:
        grid_x = pred['x'] // stride
        grid_y = pred['y'] // stride

        # 1.0 marks an abnormal prediction, 0.0 a normal one
        is_abnormal = pred['class'] in abnormal_classes
        heat_map[grid_y, grid_x] = 1.0 if is_abnormal else 0.0
        confidence_map[grid_y, grid_x] = pred['confidence']
        class_map[grid_y, grid_x] = pred['class']

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Heat map of abnormal cells (drawn once; the same image feeds the colorbar)
    im = axes[0].imshow(heat_map, cmap='RdYlGn_r', vmin=0, vmax=1)
    axes[0].set_title('Abnormal Cells Detection\n(Red = Abnormal, Green = Normal)')
    axes[0].set_xlabel('X')
    axes[0].set_ylabel('Y')
    plt.colorbar(im, ax=axes[0])

    # Confidence map
    im = axes[1].imshow(confidence_map, cmap='hot')
    axes[1].set_title('Prediction Confidence')
    axes[1].set_xlabel('X')
    axes[1].set_ylabel('Y')
    plt.colorbar(im, ax=axes[1])

    # Class map
    im = axes[2].imshow(class_map, cmap='tab20', vmin=0, vmax=num_classes)
    axes[2].set_title('Cell Type Distribution')
    axes[2].set_xlabel('X')
    axes[2].set_ylabel('Y')
    plt.colorbar(im, ax=axes[2], label='Class')

    plt.tight_layout()
    plt.show()

    # Statistics over the analysed patches themselves. Computing them over the
    # grid would count grid cells with no prediction as "normal" patches with
    # confidence 0, deflating the abnormal percentage and the mean confidence.
    abnormal_predictions = [p for p in predictions if p['class'] in abnormal_classes]
    confidences = np.array([p['confidence'] for p in predictions], dtype=float)
    abnormal_pct = 100 * len(abnormal_predictions) / len(predictions)

    print(f'\nüìà Statistics:')
    print(f'   Total patches analyzed: {len(predictions)}')
    print(f'   Abnormal patches: {len(abnormal_predictions)} ({abnormal_pct:.1f}%)')
    print(f'   Mean confidence: {confidences.mean():.3f}')
    print(f'   High confidence (>0.8): {(confidences > 0.8).sum()} patches')

    # ========== EXTRACT ABNORMAL CELLS ==========
    print(f'\nüî¥ ABNORMAL CELLS DETECTED:')
    print(f'   {"="*80}')
    print(f'   Total abnormal cells: {len(abnormal_predictions)}')
    print(f'   {"="*80}\n')

    # Sort by confidence (highest first)
    abnormal_predictions_sorted = sorted(abnormal_predictions,
                                         key=lambda p: p['confidence'],
                                         reverse=True)

    # Print coordinates and class info for the top 20
    for rank, pred in enumerate(abnormal_predictions_sorted[:20], 1):
        class_name = label_map.get(pred['class'], f"Class {pred['class']}")
        print(f"   {rank:2d}. Coords: ({pred['x']:5d}, {pred['y']:5d}) | "
              f"Class: {class_name:15s} | Confidence: {pred['confidence']:.3f}")

    if len(abnormal_predictions) > 20:
        print(f"   ... and {len(abnormal_predictions) - 20} more")

else:
    print('‚ö†Ô∏è  No predictions available. Run cell 12 first.')


## 14) Display Abnormal Cell Images

In [None]:
# Display abnormal cell patches with their coordinates and predictions


def _show_abnormal_patch(ax, patch, pred, loc_x, loc_y):
    """Draw one abnormal-cell patch on `ax`, titled with class/confidence/location and framed in red.

    Args:
        ax: Matplotlib Axes to draw on.
        patch: RGB pixel array with values in [0, 255] (rescaled here for display).
        pred: Prediction dict; its 'class' and 'confidence' entries are shown.
        loc_x, loc_y: Location printed in the title.
    """
    class_id = pred['class']
    class_name = label_map.get(class_id, f'Class {class_id}')
    confidence = pred['confidence']

    # Normalize for display [0, 1]
    ax.imshow(np.clip(patch / 255.0, 0, 1))
    ax.set_title(
        f'üî¥ {class_name}\nConfidence: {confidence:.3f}\nLocation: ({loc_x}, {loc_y})',
        fontsize=10, fontweight='bold', color='darkred'
    )
    # Hide ticks only. ax.axis('off') also stops the spines from being drawn,
    # which would make the red border below invisible.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_edgecolor('red')
        spine.set_linewidth(3)


if 'predictions' in locals() and len(predictions) > 0:
    normal_classes = [0, 9, 8]
    abnormal_classes = [1, 2, 3, 4, 5, 6, 7, 10, 11]

    # Get abnormal predictions sorted by confidence (highest first)
    abnormal_predictions = [p for p in predictions if p['class'] in abnormal_classes]
    abnormal_predictions_sorted = sorted(abnormal_predictions,
                                         key=lambda p: p['confidence'],
                                         reverse=True)

    num_to_show = min(12, len(abnormal_predictions_sorted))  # Show max 12
    top_predictions = abnormal_predictions_sorted[:num_to_show]
    patch_size = 224

    print(f'üì∏ Displaying top {num_to_show} abnormal cells (highest confidence)...\n')

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    # Patch coordinates go into px/py, NOT x/y. At notebook top level `y` is the
    # label array loaded earlier, and overwriting it breaks re-running earlier cells.
    try:
        slide = OpenSlide(svs_file)
        try:
            # Use detail_level if available (from smart zoom), otherwise level 0
            if 'detail_lvl' in locals() and detail_lvl is not None:
                use_level = detail_lvl
            elif 'level' in locals() and level is not None:
                use_level = level
            else:
                use_level = 0

            for ax, pred in zip(axes, top_predictions):
                px, py = pred['x'], pred['y']
                # NOTE(review): OpenSlide.read_region takes its location in level-0
                # coordinates for every level -- confirm the prediction cell stores
                # level-0 coordinates when use_level != 0.
                region = slide.read_region((px, py), use_level, (patch_size, patch_size))
                patch = np.array(region.convert('RGB'), dtype=np.float32)
                _show_abnormal_patch(ax, patch, pred, px, py)
        finally:
            # Always release the slide, even if reading a patch fails.
            slide.close()

    except Exception as e:
        print(f'‚ùå Error reading patches: {e}')
        print('‚ö†Ô∏è  Trying fallback method...')

        # Fallback: If OpenSlide fails, try PIL
        try:
            # NOTE(review): this decodes the whole image into memory; very large
            # slides may exceed RAM or trip PIL's decompression-bomb check.
            with Image.open(svs_file) as pil_img:
                img_array = np.array(pil_img.convert('RGB'), dtype=np.float32)

            for ax, pred in zip(axes, top_predictions):
                px, py = pred['x'], pred['y']
                patch = img_array[py:py + patch_size, px:px + patch_size]
                _show_abnormal_patch(ax, patch, pred, px, py)
        except Exception as e2:
            print(f'‚ùå Fallback also failed: {e2}')

    # Hide remaining subplots
    for ax in axes[num_to_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Save coordinates to file
    print(f'\nüíæ Saving abnormal cell coordinates...')
    output_path = '/content/drive/MyDrive/PhatHienTeBao/abnormal_cells_coordinates.txt'

    with open(output_path, 'w') as out_file:
        out_file.write(f"{'='*80}\n")
        out_file.write(f"ABNORMAL CELLS DETECTED - COORDINATES AND PREDICTIONS\n")
        out_file.write(f"{'='*80}\n")
        out_file.write(f"File: {svs_file}\n")
        out_file.write(f"Total abnormal cells found: {len(abnormal_predictions)}\n")
        out_file.write(f"Slide dimensions: {width} x {height}\n")
        out_file.write(f"{'='*80}\n\n")

        for rank, pred in enumerate(abnormal_predictions_sorted, 1):
            class_id = pred['class']
            class_name = label_map.get(class_id, f'Class {class_id}')
            out_file.write(
                f"{rank:4d}. X={pred['x']:6d}, Y={pred['y']:6d} | "
                f"Class: {class_name:15s} | Confidence: {pred['confidence']:.4f}\n"
            )

    print(f'‚úÖ Coordinates saved to: {output_path}')

else:
    print('‚ö†Ô∏è  No predictions available. Run cell 12 first.')

## 13) Visualize Heat Map of Predictions

In [None]:
## 12) Predict on Whole Slide Image (SVS)