# 🎯 Day 4.6 - Mask-Guided Attention Training (90%+ Target)

## 🎯 Learning Objectives

In this notebook, you'll:
1. **Use tumor masks** to create attention-weighted training
2. **Focus model on tumor regions** rather than skull/background
3. **Implement mask-based data augmentation**
4. **Train with region-of-interest (ROI) focus**
5. **Target 90%+ test accuracy** by leveraging segmentation masks

---

## 💡 **Why Masks Matter:**

**Current problem:** the model is trained on the **entire slice**, so its features can be driven by:
- ❌ Skull boundaries
- ❌ Background noise
- ❌ Non-tumor tissue

**With masks, we can:**
- ✅ Focus on **tumor regions only**
- ✅ Apply stronger augmentation to tumor area
- ✅ Weight loss by tumor importance
- ✅ Extract tumor-specific features

---

In [None]:
## 🔧 Setup

import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Flatten, Dense, 
    Dropout, Multiply, Lambda, Concatenate
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
import cv2

# Add src to path
sys.path.insert(0, '../..')
from src.modeling.model_cnn import enable_gpu_memory_growth

# Setup
print(f"TensorFlow: {tf.__version__}")
print(f"GPU: {tf.config.list_physical_devices('GPU')}")
enable_gpu_memory_growth()

sns.set_style('white')
plt.rcParams['figure.figsize'] = (14, 6)

np.random.seed(42)
tf.random.set_seed(42)

print("\n✅ Libraries imported successfully")
print(f"⏰ Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

In [None]:
## 📂 Paths Configuration

# Data paths
TRAIN_CSV = '../../outputs/data_splits/train_split.csv'
VAL_CSV = '../../outputs/data_splits/val_split.csv'
TEST_CSV = '../../outputs/data_splits/test_split.csv'

# Image directories
IMAGES_DIR = '../../outputs/ce_mri_enhanced'
MASKS_DIR = '../../outputs/ce_mri_masks'

# Output directories
MODELS_DIR = '../../outputs/models'
LOGS_DIR = '../../outputs/logs'
VIZ_DIR = '../../outputs/visualizations'
METRICS_DIR = '../../outputs/metrics'

for directory in [MODELS_DIR, LOGS_DIR, VIZ_DIR, METRICS_DIR]:
    os.makedirs(directory, exist_ok=True)

print("✅ Paths configured")

## 🎨 Custom Data Generator with Masks

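This generator will:
1. Load each image AND its corresponding mask (falling back to an all-ones mask if no mask file is found)
2. Apply the same geometric augmentation (rotation, flip, translation) to both, so they stay aligned
3. Down-weight pixels outside the tumor with `img * (0.3 + 0.7 * mask)`: tumor pixels keep full intensity, background pixels keep 30%
4. Return `(attended_image, one_hot_label)` batches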
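Step 3 is a soft mask rather than a hard multiply: each pixel is scaled by `α + (1 − α)·m`, where `m` is the (dilated) mask value and `α = 0.3` in the generator. The numpy sketch below isolates that one line so its effect can be checked on a toy 2×2 image; `soft_mask_attention` is a hypothetical helper name, not part of the original code.

```python
import numpy as np

def soft_mask_attention(img, mask, background_weight=0.3):
    """Scale pixels by background_weight outside the mask and by 1.0 inside it.

    Mirrors `img * (0.3 + 0.7 * mask)` in MaskGuidedDataGenerator.
    The helper name and signature are illustrative only.
    """
    img = np.asarray(img, dtype=np.float32)
    mask = np.asarray(mask, dtype=np.float32)
    weights = background_weight + (1.0 - background_weight) * mask
    return img * weights

# Toy example: mask covers the top-left and bottom-right pixels
img = np.array([[1.0, 0.5], [0.8, 0.2]], dtype=np.float32)
mask = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
print(soft_mask_attention(img, mask))
```

Masked pixels pass through unchanged, while the two background pixels drop to 30% of their value (0.5 becomes 0.15, 0.8 becomes 0.24). With an all-ones mask, which is the generator's fallback when no mask file exists, the image is returned unchanged, so those samples get no attention at all.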

In [None]:
class MaskGuidedDataGenerator(tf.keras.utils.Sequence):
    """Data generator that uses tumor masks for attention-focused training."""
    
    def __init__(self, csv_path, batch_size=32, target_size=(128, 128),
                 images_dir='../../outputs/ce_mri_enhanced',
                 masks_dir='../../outputs/ce_mri_masks',
                 augment=True, shuffle=True):
        
        self.df = pd.read_csv(csv_path)
        self.batch_size = batch_size
        self.target_size = target_size
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.augment = augment
        self.shuffle = shuffle
        self.indices = np.arange(len(self.df))
        
        # Augmentation parameters (zoom_range is defined but not applied in _augment)
        self.rotation_range = 15
        self.width_shift_range = 0.1
        self.height_shift_range = 0.1
        self.zoom_range = 0.1
        self.horizontal_flip = True
        
        if self.shuffle:
            np.random.shuffle(self.indices)
    
    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))
    
    def __getitem__(self, idx):
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_data = self.df.iloc[batch_indices]
        
        images = []
        labels = []
        
        for _, row in batch_data.iterrows():
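            # Load image (cv2.imread returns None instead of raising on a bad path)
            img_path = row['filepath']
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                raise FileNotFoundError(f"Could not read image: {img_path}")
            img = cv2.resize(img, self.target_size)
            img = img.astype('float32') / 255.0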
            
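            # Load corresponding mask. Assumes the mask file mirrors the image path with
            # images_dir swapped for masks_dir; replace() is a no-op if images_dir is absent.
            mask_path = img_path.replace(self.images_dir, self.masks_dir)
            if os.path.exists(mask_path):
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                mask = cv2.resize(mask, self.target_size)
                mask = (mask > 127).astype('float32')  # Binarize (assumes 0/255 mask files)
                
                # Dilate mask slightly to include tumor boundary
                kernel = np.ones((3, 3), np.uint8)
                mask = cv2.dilate(mask, kernel, iterations=1)
            else:
                # No mask on disk: an all-ones mask leaves the image unweighted.
                # cv2.resize takes (width, height), but arrays are (height, width).
                mask = np.ones(self.target_size[::-1], dtype='float32')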
            
            # Apply augmentation if enabled
            if self.augment:
                img, mask = self._augment(img, mask)
            
            # Apply mask attention: focus on tumor region
            # Keep some background context (0.3 weight)
            attended_img = img * (0.3 + 0.7 * mask)
            
            # Add channel dimension
            attended_img = np.expand_dims(attended_img, axis=-1)
            
            images.append(attended_img)
            
            # Label (0-indexed for model)
            label = row['label'] - 1
            labels.append(label)
        
        # Convert to arrays
        X = np.array(images)
        y = tf.keras.utils.to_categorical(labels, num_classes=3)
        
        return X, y
    
    def _augment(self, img, mask):
        """Apply same augmentation to image and mask."""
        if np.random.rand() < 0.5:
            # Rotation
            angle = np.random.uniform(-self.rotation_range, self.rotation_range)
            center = (img.shape[1] // 2, img.shape[0] // 2)
            M = cv2.getRotationMatrix2D(center, angle, 1.0)
            img = cv2.warpAffine(img, M, img.shape[::-1])
            mask = cv2.warpAffine(mask, M, mask.shape[::-1])
        
        if self.horizontal_flip and np.random.rand() < 0.5:
            # Horizontal flip
            img = cv2.flip(img, 1)
            mask = cv2.flip(mask, 1)
        
        if np.random.rand() < 0.5:
            # Translation
            tx = np.random.uniform(-self.width_shift_range, self.width_shift_range) * img.shape[1]
            ty = np.random.uniform(-self.height_shift_range, self.height_shift_range) * img.shape[0]
            M = np.float32([[1, 0, tx], [0, 1, ty]])
            img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
            mask = cv2.warpAffine(mask, M, (mask.shape[1], mask.shape[0]))
        
        return img, mask
    
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

print("✅ MaskGuidedDataGenerator class defined")
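The key property of `_augment` is that each random parameter is drawn once and reused for both arrays, so image and mask stay pixel-aligned. The numpy-only sketch below demonstrates that idea with a horizontal flip and an integer translation. It is a simplified stand-in: the generator itself uses `cv2.flip` and `cv2.warpAffine` and also rotates, and `paired_flip_shift` is a hypothetical name.

```python
import numpy as np

def paired_flip_shift(img, mask, max_shift=0.1, rng=None):
    """Apply ONE random horizontal flip and integer translation to both arrays.

    numpy-only sketch of paired augmentation. Parameters are sampled once
    and shared, which is what keeps img and mask aligned.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape
    flip = rng.random() < 0.5
    dx = int(rng.integers(-int(max_shift * w), int(max_shift * w) + 1))
    dy = int(rng.integers(-int(max_shift * h), int(max_shift * h) + 1))

    def transform(a):
        if flip:
            a = a[:, ::-1]  # mirror left-right
        out = np.zeros_like(a)  # zero fill, like warpAffine's default constant border
        src_y = slice(max(0, -dy), min(h, h - dy))
        dst_y = slice(max(0, dy), min(h, h + dy))
        src_x = slice(max(0, -dx), min(w, w - dx))
        dst_x = slice(max(0, dx), min(w, w + dx))
        out[dst_y, dst_x] = a[src_y, src_x]
        return out

    return transform(img), transform(mask)

# Usage: a toy "tumor" that exactly matches its mask stays matched after augmentation
rng = np.random.default_rng(0)
mask = np.zeros((128, 128), dtype=np.float32)
mask[40:60, 70:90] = 1.0
img = mask * 0.9
aug_img, aug_mask = paired_flip_shift(img, mask, rng=rng)
print(np.allclose(aug_img, aug_mask * 0.9))
```

Calling a random transform separately for the image and the mask would break this alignment, and the attention weights would then highlight the wrong pixels.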

In [None]:
## 📊 Create Data Generators

print("📂 Creating mask-guided data generators...\n")

# Training generator (with augmentation and mask attention)
train_gen = MaskGuidedDataGenerator(
    csv_path=TRAIN_CSV,
    batch_size=32,
    target_size=(128, 128),
    images_dir=IMAGES_DIR,
    masks_dir=MASKS_DIR,
    augment=True,
    shuffle=True
)

# Validation generator (no augmentation, but with mask attention)
val_gen = MaskGuidedDataGenerator(
    csv_path=VAL_CSV,
    batch_size=32,
    target_size=(128, 128),
    images_dir=IMAGES_DIR,
    masks_dir=MASKS_DIR,
    augment=False,
    shuffle=False
)

print(f"✅ Training generator: {len(train_gen)} batches")
print(f"✅ Validation generator: {len(val_gen)} batches")

# Visualize sample
print("\n📊 Sample batch:")
X_sample, y_sample = train_gen[0]
print(f"   Input shape: {X_sample.shape}")
print(f"   Label shape: {y_sample.shape}")
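Because the generator silently falls back to an all-ones mask, a path mismatch would quietly turn mask-guided training back into plain training. The sketch below counts how many CSV rows actually resolve to a mask file under the same `str.replace` rule. `mask_coverage` is a hypothetical helper, and the commented notebook call assumes the CSV's `filepath` column, as the generator does.

```python
import os

def mask_coverage(filepaths, images_dir, masks_dir):
    """Count how many image paths map to an existing mask file.

    Applies the same `img_path.replace(images_dir, masks_dir)` rule as
    MaskGuidedDataGenerator. Paths that do not contain images_dir are counted
    separately: for them replace() is a no-op, so the generator would read
    the image file itself as the mask.
    """
    stats = {'rows': 0, 'masks_found': 0, 'path_not_under_images_dir': 0}
    for path in filepaths:
        path = str(path)
        stats['rows'] += 1
        if images_dir not in path:
            stats['path_not_under_images_dir'] += 1
            continue
        if os.path.exists(path.replace(images_dir, masks_dir)):
            stats['masks_found'] += 1
    return stats

# In the notebook:
# print(mask_coverage(pd.read_csv(TRAIN_CSV)['filepath'], IMAGES_DIR, MASKS_DIR))
```

Ideally `masks_found` equals `rows`. Every row counted as `path_not_under_images_dir` is worse than a missing mask, because its "mask" would be the binarized image.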

In [None]:
## 🖼️ Visualize Mask-Attended Images

# Show examples of mask-attended images
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(8):
    row = i // 4
    col = i % 4
    
    img = X_sample[i, :, :, 0]
    label = np.argmax(y_sample[i])
    label_name = ['Meningioma', 'Glioma', 'Pituitary'][label]
    
    axes[row, col].imshow(img, cmap='gray')
    axes[row, col].set_title(f'{label_name}', fontweight='bold')
    axes[row, col].axis('off')

plt.suptitle('Mask-Attended Training Images (Tumor-Focused)', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
viz_path = os.path.join(VIZ_DIR, 'day4_06_mask_attended_samples.png')
plt.savefig(viz_path, dpi=300, bbox_inches='tight')
print(f"✅ Visualization saved: {viz_path}")
plt.show()

## 🏗️ Build Enhanced CNN Model

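A higher-capacity CNN for feature extraction from tumor regions: three VGG-style blocks of two 3×3 convolutions each (64, 128, 256 filters), followed by a 512-256 dense head.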
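It helps to know where that capacity actually sits. The pure-Python sketch below tallies parameters layer by layer for the architecture defined in the next cell (Conv2D: `k·k·in·out + out`; Dense: `in·out + out`; Dropout, pooling and Flatten have none). The helper names are hypothetical, and the total should match `model.count_params()`.

```python
def conv2d_params(kernel, in_ch, out_ch):
    return kernel * kernel * in_ch * out_ch + out_ch  # weights + biases

def dense_params(in_units, out_units):
    return in_units * out_units + out_units  # weights + biases

def mask_aware_cnn_param_breakdown(size=128, in_ch=1, num_classes=3):
    """Per-layer parameter counts for the MaskAwareCNN layout."""
    layers = {}
    ch = in_ch
    for block, width in enumerate([64, 128, 256], start=1):
        layers[f'block{block}_conv_a'] = conv2d_params(3, ch, width)
        layers[f'block{block}_conv_b'] = conv2d_params(3, width, width)
        ch = width
        size //= 2  # MaxPooling2D((2, 2)); 'same' padding keeps size inside a block
    flat = size * size * ch  # 16 * 16 * 256 after three poolings
    layers['dense_512'] = dense_params(flat, 512)
    layers['dense_256'] = dense_params(512, 256)
    layers['output'] = dense_params(256, num_classes)
    return layers

breakdown = mask_aware_cnn_param_breakdown()
total = sum(breakdown.values())
print(f"total: {total:,}")                                        # → total: 34,831,299
print(f"dense_512 share: {breakdown['dense_512'] / total:.1%}")  # → dense_512 share: 96.3%
```

So roughly 96% of the parameters sit in the first dense layer, which maps the flattened 65,536 features to 512 units; the convolutional blocks together hold about 1.1M. That layer is the main overfitting risk, which is presumably why the dense head is followed by `Dropout(0.5)`.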

In [None]:
from sklearn.utils.class_weight import compute_class_weight

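# Calculate class weights ('balanced': n_samples / (n_classes * count_per_class)).
# CSV labels are 1..3 and np.unique returns them sorted, so enumerate() maps
# label 1 → index 0, label 2 → index 1, label 3 → index 2 (matching `label - 1` in the generator).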
train_df = pd.read_csv(TRAIN_CSV)
class_weights_array = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_df['label']),
    y=train_df['label']
)
class_weights = {i: weight for i, weight in enumerate(class_weights_array)}

print("⚖️ Class Weights:")
for class_idx, weight in class_weights.items():
    tumor_type = {0: 'Meningioma', 1: 'Glioma', 2: 'Pituitary'}[class_idx]
    print(f"   Class {class_idx} ({tumor_type}): {weight:.4f}x")

# Build model
def build_mask_aware_cnn(input_shape=(128, 128, 1), num_classes=3, learning_rate=5e-5):
    """
    Enhanced CNN with more capacity for mask-attended features.
    """
    model = tf.keras.Sequential([
        # Block 1
        Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        # Block 2
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        # Block 3
        Conv2D(256, (3, 3), activation='relu', padding='same'),
        Conv2D(256, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        # Dense layers
        Flatten(),
        Dense(512, activation='relu'),
        Dropout(0.5),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ], name='MaskAwareCNN')
    
    # Compile with lower learning rate
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    
    return model

model = build_mask_aware_cnn(
    input_shape=(128, 128, 1),
    num_classes=3,
    learning_rate=5e-5
)

print(f"\n📊 Model Architecture:")
model.summary()

total_params = model.count_params()
print(f"\n✅ Total parameters: {total_params:,}")
print(f"   Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB (float32 weights)")
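The `'balanced'` heuristic in scikit-learn's `compute_class_weight` is documented as `n_samples / (n_classes * np.bincount(y))`. The numpy sketch below reproduces it, including the shift from 1-based CSV labels to 0-based model indices, so the printed weights can be checked by hand. `balanced_class_weights` is a hypothetical helper, and the toy labels are made up rather than taken from the real split.

```python
import numpy as np

def balanced_class_weights(labels):
    """'balanced' heuristic: n_samples / (n_classes * count_per_class).

    Returns a dict keyed by 0-based model index, assuming labels are the
    contiguous integers 1..K used in this notebook's CSVs.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)  # classes come back sorted
    weights = len(labels) / (len(classes) * counts)
    return {int(c) - 1: float(w) for c, w in zip(classes, weights)}

# Toy labels: 2 meningioma (1), 4 glioma (2), 2 pituitary (3)
print(balanced_class_weights([1, 1, 2, 2, 2, 2, 3, 3]))
# → {0: 1.3333333333333333, 1: 0.6666666666666666, 2: 1.3333333333333333}
```

A class that is half as frequent as average gets twice the weight, so each class contributes roughly equally to the summed loss.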

## 🎛️ Setup Callbacks

In [None]:
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, 
    CSVLogger, Callback
)

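# Callbacks. EarlyStopping tracks val_loss while ModelCheckpoint tracks val_accuracy,
# so the restored in-memory weights and the saved .h5 file can come from different epochs.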
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

checkpoint_path = os.path.join(MODELS_DIR, 'model_mask_aware_best.h5')
model_checkpoint = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1,
    mode='max'
)

csv_log_path = os.path.join(LOGS_DIR, 'training_log_mask_aware.csv')
csv_logger = CSVLogger(csv_log_path)

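# LR Tracker: records the learning rate at the end of each epoch
class LRTracker(Callback):
    def __init__(self):
        super().__init__()
        self.lrs = []
        self.epochs = []
    
    def on_epoch_end(self, epoch, logs=None):
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        self.lrs.append(lr)
        self.epochs.append(epoch + 1)
        if logs is not None:  # logs may be None when the callback is called directly
            logs['lr'] = lr

lr_tracker = LRTracker()

# Callbacks run in list order on a shared logs dict, so lr_tracker goes before
# csv_logger to make its 'lr' entry visible to the CSV log.
callbacks_list = [early_stopping, model_checkpoint, reduce_lr, lr_tracker, csv_logger]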

print("✅ Callbacks configured")
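Because the two callbacks monitor different metrics, they can select different epochs. The simplified sketch below shows which epoch each one would pick from a history. It ignores `min_delta` and patience, uses made-up toy values, and `selected_epochs` is a hypothetical helper. Depending on the Keras version, `restore_best_weights` may only take effect when early stopping actually triggers.

```python
def selected_epochs(val_loss, val_accuracy):
    """1-based epoch each callback would pick (simplified: no min_delta, no patience).

    - ModelCheckpoint(monitor='val_accuracy', mode='max', save_best_only=True)
      saves on strict improvement, i.e. the FIRST epoch reaching the maximum.
    - EarlyStopping(monitor='val_loss', restore_best_weights=True) tracks the
      FIRST epoch reaching the minimum val_loss.
    """
    best_acc_epoch = max(range(len(val_accuracy)), key=lambda i: val_accuracy[i]) + 1
    best_loss_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i]) + 1
    return {'checkpoint_file': best_acc_epoch, 'restored_in_memory': best_loss_epoch}

# Toy history (illustrative numbers, not from a real run)
print(selected_epochs(
    val_loss=[0.90, 0.70, 0.65, 0.66, 0.70],
    val_accuracy=[0.70, 0.78, 0.80, 0.83, 0.82],
))
# → {'checkpoint_file': 4, 'restored_in_memory': 3}
```

In this toy case, evaluating the in-memory `model` would use epoch-3 weights, while `model_mask_aware_best.h5` would hold epoch 4. Monitoring the same metric in both callbacks avoids the ambiguity.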

## 🚀 Train Mask-Aware Model

**Target: 90%+ test accuracy**

Training with:
- ✅ Mask-guided attention
- ✅ Class weights for balance
- ✅ Enhanced model capacity
- ✅ Lower learning rate (5e-5)
- ✅ Up to 25 epochs (early stopping may end training sooner)

In [None]:
import time

EPOCHS = 25

print("="*70)
print("🏋️ STARTING MASK-AWARE TRAINING")
print("="*70)
print(f"⏰ Start time: {datetime.now().strftime('%H:%M:%S')}")
print(f"📊 Max epochs: {EPOCHS}")
print(f"🎯 Target: 90%+ test accuracy")
print(f"⚖️ Using balanced class weights")
print(f"🎭 Using tumor mask attention")
print("="*70 + "\n")

start_time = time.time()

# Train
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=callbacks_list,
    class_weight=class_weights,
    verbose=1
)

training_time = time.time() - start_time
minutes = int(training_time // 60)
seconds = int(training_time % 60)

print("\n" + "="*70)
print("✅ TRAINING COMPLETED")
print("="*70)
print(f"⏰ End time: {datetime.now().strftime('%H:%M:%S')}")
print(f"⌛ Total time: {minutes}m {seconds}s")
print(f"📈 Epochs completed: {len(history.history['loss'])}")
print("="*70)

In [None]:
## 📊 Training Curves

train_loss = history.history['loss']
train_acc = history.history['accuracy']
val_loss = history.history['val_loss']
val_acc = history.history['val_accuracy']
epochs_completed = len(train_loss)
epochs_range = range(1, epochs_completed + 1)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Accuracy
axes[0].plot(epochs_range, train_acc, 'b-o', label='Training', linewidth=2)
axes[0].plot(epochs_range, val_acc, 'r-s', label='Validation', linewidth=2)
axes[0].axhline(y=max(val_acc), color='g', linestyle='--', alpha=0.5)
axes[0].set_title(f'Model Accuracy (Best Val: {max(val_acc)*100:.2f}%)', 
                  fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Loss
axes[1].plot(epochs_range, train_loss, 'b-o', label='Training', linewidth=2)
axes[1].plot(epochs_range, val_loss, 'r-s', label='Validation', linewidth=2)
axes[1].set_title('Model Loss', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
curves_path = os.path.join(VIZ_DIR, 'day4_06_mask_aware_training_curves.png')
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
print(f"✅ Training curves saved: {curves_path}")
plt.show()

print(f"\n📊 Training Summary:")
print(f"   Best validation accuracy: {max(val_acc)*100:.2f}%")
print(f"   Final training accuracy: {train_acc[-1]*100:.2f}%")
print(f"   Train-val accuracy gap (final epoch): {(train_acc[-1] - val_acc[-1])*100:.2f} pp")
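The gap printed above uses the last epoch, whose weights may not be the ones that are kept. The sketch below also reports the gap at the best-validation epoch. It operates on a plain `History.history`-style dict; `gap_report` is a hypothetical helper and the toy numbers are illustrative.

```python
def gap_report(history):
    """Summarize a Keras History.history dict at the best-validation epoch."""
    train_acc = history['accuracy']
    val_acc = history['val_accuracy']
    best = max(range(len(val_acc)), key=lambda i: val_acc[i])  # first epoch with max val_accuracy
    return {
        'best_epoch': best + 1,
        'best_val_acc': val_acc[best],
        'gap_at_best': train_acc[best] - val_acc[best],
        'gap_at_last': train_acc[-1] - val_acc[-1],
    }

# Toy history (illustrative); in the notebook: gap_report(history.history)
print(gap_report({'accuracy': [0.6, 0.8, 0.9, 0.95],
                  'val_accuracy': [0.55, 0.75, 0.8, 0.78]}))
```

A gap that widens between the best and the last epoch is the usual overfitting signature.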

## 🧪 Evaluate on Test Set

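Let's see if we hit our 90%+ target! Note that the test generator also applies the ground-truth tumor masks, so this score assumes masks are available at inference time.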

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Create test generator
test_gen = MaskGuidedDataGenerator(
    csv_path=TEST_CSV,
    batch_size=32,
    target_size=(128, 128),
    images_dir=IMAGES_DIR,
    masks_dir=MASKS_DIR,
    augment=False,
    shuffle=False
)

print("🔮 Generating predictions on test set...\n")

# Predict
y_pred_probs = model.predict(test_gen, verbose=1)
y_pred = np.argmax(y_pred_probs, axis=1)

# True labels
test_df = pd.read_csv(TEST_CSV)
y_true = (test_df['label'] - 1).values

# Calculate accuracy
test_accuracy = accuracy_score(y_true, y_pred)

print("\n" + "="*70)
print("🎯 TEST SET PERFORMANCE (MASK-AWARE MODEL)")
print("="*70)
print(f"\n📊 Overall Accuracy: {test_accuracy * 100:.2f}%")
print(f"   Correct: {np.sum(y_true == y_pred)} / {len(y_true)}")
print(f"   Incorrect: {np.sum(y_true != y_pred)} / {len(y_true)}")

if test_accuracy >= 0.90:
    print(f"\n🎉 ✅ TARGET ACHIEVED! {test_accuracy*100:.2f}% >= 90%")
else:
    print(f"\n⚠️ Target not yet reached ({test_accuracy*100:.2f}% < 90%)")
    print("   Consider: longer training, more augmentation, or ensemble methods")

print("="*70)

# Classification report
class_names = ['Meningioma', 'Glioma', 'Pituitary']
print("\n📋 CLASSIFICATION REPORT")
print("="*70)
y_true_labels = y_true + 1
y_pred_labels = y_pred + 1
report = classification_report(y_true_labels, y_pred_labels, target_names=class_names, digits=4)
print(report)

# Confusion matrix
cm = confusion_matrix(y_true_labels, y_pred_labels, labels=[1, 2, 3])
print("\n🔲 CONFUSION MATRIX")
print("="*70)
print(cm)
print("="*70)
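As a cross-check on the report above, accuracy, per-class recall and per-class precision can all be read off the confusion matrix. Rows are actual classes and columns are predicted classes, which is the layout `sklearn.metrics.confusion_matrix` returns. The recall vector is the diagonal of the row-normalized heatmap plotted next. `metrics_from_confusion` is a hypothetical helper, and the example matrix is a toy.

```python
import numpy as np

def metrics_from_confusion(cm):
    """Accuracy, per-class recall and precision from a (actual x predicted) matrix.

    Classes with no true samples (recall) or no predictions (precision) get 0.0.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    accuracy = tp.sum() / cm.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        recall = np.where(cm.sum(axis=1) > 0, tp / cm.sum(axis=1), 0.0)
        precision = np.where(cm.sum(axis=0) > 0, tp / cm.sum(axis=0), 0.0)
    return accuracy, recall, precision

# Toy 3-class matrix; in the notebook: metrics_from_confusion(cm)
acc, recall, precision = metrics_from_confusion([[8, 1, 1], [2, 17, 1], [0, 0, 10]])
print(acc, recall, precision)
```

These values should agree with the recall and precision columns of `classification_report`, since both are computed from the same predictions.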

In [None]:
## 📈 Visualize Confusion Matrix

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Raw counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            ax=axes[0], cbar_kws={'label': 'Count'})
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
axes[0].set_ylabel('Actual Class')
axes[0].set_xlabel('Predicted Class')

# Normalized
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='RdYlGn',
            xticklabels=class_names, yticklabels=class_names,
            ax=axes[1], cbar_kws={'label': 'Fraction of actual class'}, vmin=0, vmax=1)
axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Actual Class')
axes[1].set_xlabel('Predicted Class')

plt.tight_layout()
cm_path = os.path.join(VIZ_DIR, 'day4_06_mask_aware_confusion_matrix.png')
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
print(f"✅ Confusion matrix saved: {cm_path}")
plt.show()

## ✅ Summary

### What We Accomplished:

1. ✅ **Mask-guided attention** - Background pixels are scaled to 30% intensity so the tumor region dominates the input
2. ✅ **Enhanced architecture** - Deeper network with more capacity
3. ✅ **Class weighting** - Balanced training across classes
4. ✅ **Paired augmentation** - The same rotation, flip, and shift are applied to image and mask, keeping them aligned

### Results:

**Test Accuracy:** reported by the evaluation cell above (`test_accuracy`).

**Improvements over baseline (71.76%):**
- Absolute gain: `(test_accuracy - 0.7176) * 100` percentage points
- Relative improvement: `(test_accuracy / 0.7176 - 1) * 100` percent
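Markdown cells do not evaluate Python expressions, so these figures have to be computed in a code cell. Below is a minimal sketch, assuming the 71.76% baseline quoted here. `improvement_over_baseline` and `BASELINE_ACC` are hypothetical names.

```python
BASELINE_ACC = 0.7176  # baseline test accuracy quoted in this summary

def improvement_over_baseline(test_accuracy, baseline=BASELINE_ACC):
    """Return (absolute gain in percentage points, relative improvement in %)."""
    absolute_pp = (test_accuracy - baseline) * 100
    relative_pct = (test_accuracy / baseline - 1) * 100
    return absolute_pp, relative_pct

# In the notebook, after the evaluation cell has run:
# abs_pp, rel_pct = improvement_over_baseline(test_accuracy)
# print(f"Test accuracy: {test_accuracy*100:.2f}% | +{abs_pp:.2f} pp | +{rel_pct:.1f}% relative")
```

For example, a test accuracy of exactly 0.90 would give +18.24 percentage points and about +25.4% relative improvement.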

### If Target Not Reached (< 90%):

**Additional strategies:**
1. **Ensemble methods** - Combine multiple models
2. **Transfer learning** - Use pre-trained models (ResNet, EfficientNet)
3. **Longer training** - 40-50 epochs with careful monitoring
4. **External data** - Augment with similar datasets
5. **Test-time augmentation** - Average predictions over multiple augmentations
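Strategy 5 can be sketched in a few lines. The version below averages class probabilities over each image and its horizontal mirror, since horizontal flips are already part of training augmentation. It takes any `predict_fn` that maps a `(batch, height, width, channels)` array to probabilities, for example `lambda x: model.predict(x, verbose=0)`. `tta_predict` is a hypothetical helper, and a fuller TTA would also average over small shifts or rotations.

```python
import numpy as np

def tta_predict(predict_fn, X):
    """Average class probabilities over the batch and its horizontal flip.

    X has shape (batch, height, width, channels); reversing axis 2 mirrors
    each image left-right. Flipping the already mask-attended input keeps
    the attention weights aligned with the tumor.
    """
    probs = predict_fn(X)
    probs_flipped = predict_fn(X[:, :, ::-1, :])
    return (probs + probs_flipped) / 2.0

# Usage in the notebook (one mask-attended batch from the test generator):
# X_batch, _ = test_gen[0]
# probs = tta_predict(lambda x: model.predict(x, verbose=0), X_batch)
# preds = np.argmax(probs, axis=1)
```

Averaging probabilities rather than hard votes keeps the output a valid distribution, so `argmax` and the existing evaluation code apply unchanged.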

---

**Date:** October 22, 2025  
**Status:** ✅ Completed  
**Best Model:** `model_mask_aware_best.h5`