# Ki-67 Model Validation: Measure Your Model's Real F1 Score

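This notebook computes the **peak-based F1 score** of your trained model on the BCData validation split. Cell centres are read off the predicted heatmaps as local peaks and matched to ground-truth cells within a 10-pixel radius, separately for Ki-67-positive and Ki-67-negative cells; the two F1 scores are then averaged. (The final cells also print this average as a percentage labelled "accuracy" — it is an F1 score, not classification accuracy.)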

## Instructions:
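1. Upload your checkpoint `ki67-point-epoch=68-val_peak_f1_avg=0.8503.ckpt` to the `/content/` folder (or pass a different `checkpoint_path` to `get_actual_accuracy`).
2. Make sure `BCData.zip` is in the root of your Google Drive (the `MyDrive` folder); it is extracted to `/content/BCData`.
3. Run all cells below.
4. Read the positive, negative, and average F1 scores printed at the end.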

---

In [None]:
# Mount Google Drive and extract dataset
import os
import zipfile

from google.colab import drive

ZIP_PATH = "/content/drive/MyDrive/BCData.zip"
EXTRACT_DIR = "/content/"
DATASET_DIR = os.path.join(EXTRACT_DIR, "BCData")

drive.mount('/content/drive')

# Extract only once. A plain `unzip` (without -o) blocks on an interactive
# "replace?" prompt when this cell is re-run; zipfile raises on a bad archive
# instead of letting the success message below print regardless.
if os.path.isdir(DATASET_DIR):
    print(f"Dataset already present at {DATASET_DIR} - skipping extraction")
else:
    with zipfile.ZipFile(ZIP_PATH) as archive:
        archive.extractall(EXTRACT_DIR)
    if not os.path.isdir(DATASET_DIR):
        raise FileNotFoundError(
            f"{ZIP_PATH} was extracted but {DATASET_DIR} does not exist; "
            "check the archive's top-level folder name"
        )
    print("‚úÖ Dataset extracted to /content/BCData")

In [None]:
# Install dependencies
!pip install -q segmentation-models-pytorch albumentations h5py opencv-python-headless pytorch-lightning scipy scikit-image

In [None]:
# Import libraries
# --- standard library
import os

# --- PyTorch, Lightning and the segmentation backbone
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader

# --- augmentation / tensor conversion
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- numerics, image and HDF5 I/O
import numpy as np
import cv2
import h5py

# --- post-processing: peak detection and point matching
from scipy.ndimage import gaussian_filter  # not referenced elsewhere in this notebook
from skimage.feature import peak_local_max
from scipy.spatial.distance import cdist

print("‚úÖ Libraries imported successfully!")

In [None]:
# Classes copied from your training script. Keep them in sync with training:
# loading the checkpoint requires the same architecture, and the targets come from the same generator.
class ImprovedPointHeatmapGenerator:
    """Render point annotations as a heatmap of max-combined Gaussians.

    Each point contributes a Gaussian with standard deviation `sigma` on a square
    kernel of width int(6*sigma+1) (forced odd), normalised so its peak is 1.0.
    Overlapping Gaussians are merged with an element-wise max.
    """

    def __init__(self, sigma=8.0):
        self.sigma = sigma

    def generate_heatmap(self, points, image_shape=(640, 640)):
        """Return a float32 heatmap of shape `image_shape` with one Gaussian per point.

        Args:
            points: sequence of (x, y) pairs; x indexes columns and y indexes rows.
                Coordinates are truncated with int(). Points whose kernel window lies
                entirely outside the image are skipped; partially visible ones are cropped.
            image_shape: (height, width) of the output.

        Returns:
            np.ndarray (float32). Every in-bounds point location has value 1.0.
        """
        heatmap = np.zeros(image_shape, dtype=np.float32)
        if len(points) == 0:
            return heatmap

        height, width = image_shape[0], image_shape[1]
        kernel_size = int(6 * self.sigma + 1)
        if kernel_size % 2 == 0:
            kernel_size += 1
        half = kernel_size // 2

        # Peak-normalised 2-D Gaussian centred in the kernel.
        offsets = np.arange(0, kernel_size) - half
        gaussian = np.exp(-(offsets ** 2 + offsets[:, np.newaxis] ** 2) / (2 * self.sigma ** 2))
        gaussian = gaussian / gaussian.max()

        for point in points:
            px, py = int(point[0]), int(point[1])
            x_min, x_max = max(0, px - half), min(width, px + half + 1)
            y_min, y_max = max(0, py - half), min(height, py + half + 1)
            if x_min >= x_max or y_min >= y_max:
                # The kernel window is completely off-image. Negative slice bounds
                # below can then yield mismatched shapes and make np.maximum raise.
                continue

            # Sub-window of the kernel that overlaps the image (cropped at edges).
            k_x_min = max(0, half - px)
            k_x_max = min(kernel_size, half + (width - px))
            k_y_min = max(0, half - py)
            k_y_max = min(kernel_size, half + (height - py))

            heatmap[y_min:y_max, x_min:x_max] = np.maximum(
                heatmap[y_min:y_max, x_min:x_max],
                gaussian[k_y_min:k_y_max, k_x_min:k_x_max],
            )

        return heatmap

class Ki67PointDataset(Dataset):
    """Point-annotated Ki-67 dataset yielding (image, heatmaps, img_name).

    Files are located as:
        <image_dir>/<base>.png
        <annotation_dir>/positive/<base>.h5
        <annotation_dir>/negative/<base>.h5
    `heatmaps` has two channels: index 0 rendered from the `positive` annotations,
    index 1 from the `negative` ones (Gaussian targets with sigma=8).
    """

    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        # Optional albumentations pipeline. It must return tensors (e.g. end with
        # ToTensorV2), because `.permute` is called on the transformed mask.
        self.transform = transform
        self.heatmap_generator = ImprovedPointHeatmapGenerator(sigma=8.0)
        self.image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.png')])

    def __len__(self):
        return len(self.image_files)

    def load_points_from_h5(self, h5_path):
        """Return the `coordinates` dataset stored in `h5_path`.

        Missing files and files without a `coordinates` dataset yield an empty
        array (treated as "no cells"). A file that exists but cannot be read is
        also treated as empty, but a warning is emitted so annotations are not
        lost silently.
        """
        if not os.path.exists(h5_path):
            return np.array([])
        try:
            with h5py.File(h5_path, 'r') as f:
                if 'coordinates' in f:
                    return f['coordinates'][:]
        except Exception as exc:  # best effort, but never swallow KeyboardInterrupt/SystemExit
            import warnings
            warnings.warn(f"Could not read annotations from {h5_path}: {exc}")
        return np.array([])

    def __getitem__(self, idx):
        """Load one RGB image and render its two-channel target heatmap.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        base_name = os.path.splitext(img_name)[0]
        pos_h5 = os.path.join(self.annotation_dir, 'positive', f"{base_name}.h5")
        neg_h5 = os.path.join(self.annotation_dir, 'negative', f"{base_name}.h5")

        pos_points = self.load_points_from_h5(pos_h5)
        neg_points = self.load_points_from_h5(neg_h5)

        pos_heatmap = self.heatmap_generator.generate_heatmap(pos_points, image.shape[:2])
        neg_heatmap = self.heatmap_generator.generate_heatmap(neg_points, image.shape[:2])
        heatmaps = np.stack([pos_heatmap, neg_heatmap], axis=0)

        if self.transform:
            # albumentations expects an HWC mask; move channels back to the front afterwards.
            transformed = self.transform(image=image, mask=heatmaps.transpose(1, 2, 0))
            image = transformed['image']
            heatmaps = transformed['mask'].permute(2, 0, 1).float()
        else:
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
            heatmaps = torch.from_numpy(heatmaps).float()

        return image, heatmaps, img_name

class ImprovedKi67PointDetectionModel(pl.LightningModule):
    """smp U-Net producing 2 heatmap logit channels (no output activation).

    Args:
        encoder_name: segmentation_models_pytorch encoder backbone.
        learning_rate: stored in hparams; not used by any code in this notebook.
        encoder_weights: pretrained weights used to initialise the encoder
            (default 'imagenet', the original behaviour). When restoring with
            `load_from_checkpoint`, all weights are overwritten from the checkpoint,
            so `load_from_checkpoint(path, encoder_weights=None)` skips the download.
    """

    def __init__(self, encoder_name='efficientnet-b3', learning_rate=1e-4, encoder_weights='imagenet'):
        super().__init__()
        self.save_hyperparameters()
        self.model = smp.Unet(encoder_name=encoder_name, encoder_weights=encoder_weights,
                              in_channels=3, classes=2, activation=None)
        # Kept as a submodule to mirror training. Its BCE pos_weight is a registered
        # buffer, so the checkpoint state_dict presumably contains it; removing this
        # would likely break strict checkpoint loading.
        self.criterion = ImprovedHeatmapLoss(pos_weight=10.0)

    def forward(self, x):
        """Return raw heatmap logits from the U-Net."""
        return self.model(x)

class ImprovedHeatmapLoss(nn.Module):
    """Composite heatmap loss: 0.4 * weighted BCE + 0.4 * soft Dice + 0.2 * focal.

    All terms take raw logits. Kept for parity with the training script; nothing
    in this notebook calls it.
    """

    def __init__(self, pos_weight=10.0):
        super().__init__()
        # The attribute name `bce` shapes the state_dict keys; do not rename it.
        positive_weight = torch.tensor([pos_weight])
        self.bce = nn.BCEWithLogitsLoss(pos_weight=positive_weight, reduction='mean')

    def dice_loss(self, pred, target, smooth=1e-6):
        """1 minus the soft Dice score, summed over dims 2-3 and averaged over the rest."""
        probs = pred.sigmoid()
        spatial_dims = (2, 3)
        overlap = (probs * target).sum(dim=spatial_dims)
        denom = probs.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
        dice_per_map = (2.0 * overlap + smooth) / (denom + smooth)
        return 1.0 - dice_per_map.mean()

    def focal_loss(self, pred, target, alpha=0.25, gamma=2.0):
        """Mean of alpha * (1 - p_t) ** gamma * BCE, where p_t = exp(-BCE) per element."""
        per_element = nn.functional.binary_cross_entropy_with_logits(pred, target, reduction='none')
        p_t = torch.exp(-per_element)
        weighted = alpha * (1 - p_t) ** gamma * per_element
        return weighted.mean()

    def forward(self, pred, target):
        """Return (total_loss_tensor, {'bce', 'dice', 'focal'} as Python floats)."""
        terms = {
            'bce': self.bce(pred, target),
            'dice': self.dice_loss(pred, target),
            'focal': self.focal_loss(pred, target),
        }
        total = 0.4 * terms['bce'] + 0.4 * terms['dice'] + 0.2 * terms['focal']
        return total, {name: term.item() for name, term in terms.items()}


print("‚úÖ Classes defined successfully!")

In [None]:
# Main validation function
def _detect_peaks(heatmap, threshold, min_distance):
    """Return (row, col) coordinates of local maxima in `heatmap` above `threshold`.

    exclude_border=False keeps peaks within `min_distance` px of the image edge;
    peak_local_max's default would silently drop cells there.
    """
    return peak_local_max(heatmap, threshold_abs=threshold,
                          min_distance=min_distance, exclude_border=False)


def _load_true_points(dataset, img_name, subdir):
    """Return annotated cell centres of `img_name` as an (N, 2) float array of (row, col).

    Reads <annotation_dir>/<subdir>/<base>.h5 through the dataset's own loader, so
    missing or unreadable files are handled exactly as when it renders heatmaps.
    """
    base_name = os.path.splitext(img_name)[0]
    h5_path = os.path.join(dataset.annotation_dir, subdir, f"{base_name}.h5")
    points = np.asarray(dataset.load_points_from_h5(h5_path), dtype=np.float64)
    if points.size == 0:
        return np.empty((0, 2), dtype=np.float64)
    points = np.atleast_2d(points)[:, :2]
    # The heatmap generator treats point[0] as x (column) and point[1] as y (row);
    # flip to (row, col) to match peak_local_max's output convention.
    return points[:, ::-1]


def _match_points(pred_points, true_points, radius):
    """Count (tp, fp, fn) under a one-to-one matching within `radius` pixels.

    Each ground-truth cell can be credited to at most one prediction. The assignment
    maximises the number of matched pairs, then minimises their total distance.
    """
    from scipy.optimize import linear_sum_assignment

    n_pred, n_true = len(pred_points), len(true_points)
    if n_pred == 0 or n_true == 0:
        return 0, n_pred, n_true
    distances = cdist(pred_points, true_points)
    within = distances < radius
    # Out-of-radius pairs cost more than any sum of in-radius distances, so the
    # optimal assignment first minimises the number of invalid pairs it uses.
    penalty = radius * (min(n_pred, n_true) + 1)
    cost = np.where(within, distances, penalty)
    rows, cols = linear_sum_assignment(cost)
    tp = int(within[rows, cols].sum())
    return tp, n_pred - tp, n_true - tp


def _f1(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); 0.0 when there is nothing to score."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def get_actual_accuracy(checkpoint_path='/content/ki67-point-epoch=68-val_peak_f1_avg=0.8503.ckpt',
                        data_path='/content/BCData',
                        pred_threshold=0.3,
                        min_distance=10,
                        match_radius=10,
                        batch_size=8,
                        num_workers=2):
    """Compute the peak-based F1 score of a checkpoint on the BCData validation split.

    For every validation image and each class (heatmap channel 0 = positive,
    channel 1 = negative):
    - Predicted cells are local maxima of the sigmoid heatmap above `pred_threshold`,
      at least `min_distance` px apart.
    - Ground truth is the annotated point list from the .h5 files. It is not peaks
      re-detected from the rendered target heatmap, which would merge cells closer
      than `min_distance`.
    - Predictions and annotations are matched one-to-one within `match_radius` px.
    - TP/FP/FN are summed over the whole split before F1 is computed.
    The result may therefore differ from the `val_peak_f1_avg` in the checkpoint
    file name, which was presumably computed differently.

    Args:
        checkpoint_path: Lightning checkpoint of ImprovedKi67PointDetectionModel.
        data_path: root containing images/validation and annotations/validation.
        pred_threshold: minimum sigmoid value for a predicted peak.
        min_distance: minimum pixel separation between predicted peaks.
        match_radius: a prediction matches a cell if strictly closer than this (px).
        batch_size, num_workers: DataLoader settings.

    Returns:
        None if either path is missing. Otherwise a dict with:
        - 'f1_avg', 'f1_pos', 'f1_neg'.
        - 'accuracy_percent' (= 100 * f1_avg; an F1, not an accuracy).
        - 'counts' ({'positive'|'negative': {'tp', 'fp', 'fn'}}).
    """

    print("üîç Getting REAL Model Accuracy...")
    print(f"Model: {checkpoint_path}")
    print(f"Data: {data_path}")
    print("="*50)

    # Check files exist
    if not os.path.exists(checkpoint_path):
        print(f"‚ùå Model not found: {checkpoint_path}")
        return None

    if not os.path.exists(data_path):
        print(f"‚ùå Data not found: {data_path}")
        return None

    # Load model
    print("Loading model...")
    model = ImprovedKi67PointDetectionModel.load_from_checkpoint(checkpoint_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    print(f"‚úì Model loaded on {device}")

    # Validation data: normalisation only, no geometric transforms, so predicted
    # peak coordinates stay in the original image frame.
    val_transform = A.Compose([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    val_dataset = Ki67PointDataset(
        image_dir=os.path.join(data_path, 'images/validation'),
        annotation_dir=os.path.join(data_path, 'annotations/validation'),
        transform=val_transform
    )

    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print(f"‚úì Validation dataset: {len(val_dataset)} images")

    # Run validation, scoring each batch immediately so full-resolution
    # predictions for the whole split are never held in memory at once.
    print("Running validation...")
    class_names = ('positive', 'negative')  # heatmap channel order == annotation sub-folders
    counts = {name: {'tp': 0, 'fp': 0, 'fn': 0} for name in class_names}
    n_images = 0

    with torch.no_grad():
        for images, _target_heatmaps, names in val_loader:
            probs = torch.sigmoid(model(images.to(device))).float().cpu().numpy()
            for pred_maps, img_name in zip(probs, names):
                for channel, class_name in enumerate(class_names):
                    pred_points = _detect_peaks(pred_maps[channel], pred_threshold, min_distance)
                    true_points = _load_true_points(val_dataset, img_name, class_name)
                    tp, fp, fn = _match_points(pred_points, true_points, match_radius)
                    counts[class_name]['tp'] += tp
                    counts[class_name]['fp'] += fp
                    counts[class_name]['fn'] += fn
            n_images += len(names)

    print(f"‚úì Processed {n_images} images")

    # Calculate F1 scores
    f1_pos = _f1(**counts['positive'])
    f1_neg = _f1(**counts['negative'])
    f1_avg = (f1_pos + f1_neg) / 2

    # Results ("ACCURACY" below is the average F1 expressed as a percentage)
    print("\n" + "="*60)
    print("üéØ ACTUAL MODEL ACCURACY RESULTS")
    print("="*60)
    print(f"Peak-based F1 Score (Positive): {f1_pos:.4f}")
    print(f"Peak-based F1 Score (Negative): {f1_neg:.4f}")
    print(f"AVERAGE F1 SCORE: {f1_avg:.4f}")
    print(f"ACCURACY: {f1_avg*100:.2f}%")
    for class_name in class_names:
        c = counts[class_name]
        print(f"{class_name.capitalize()} cells: TP={c['tp']}  FP={c['fp']}  FN={c['fn']}")
    print()
    print("INTERPRETATION:")
    # NOTE(review): these bands and labels are informal; nothing here validates them.
    if f1_avg >= 0.95:
        print("üéâ EXCELLENT: 95%+ accuracy - World-class performance!")
    elif f1_avg >= 0.90:
        print("‚úÖ OUTSTANDING: 90%+ accuracy - Clinical gold standard!")
    elif f1_avg >= 0.85:
        print("üëç VERY GOOD: 85%+ accuracy - Better than most published methods!")
    elif f1_avg >= 0.80:
        print("üëå GOOD: 80%+ accuracy - Clinical grade performance!")
    elif f1_avg >= 0.70:
        print("‚ö†Ô∏è FAIR: 70%+ accuracy - Needs improvement")
    else:
        print("‚ùå POOR: <70% accuracy - Retrain needed")
    print("="*60)

    return {
        'f1_avg': f1_avg,
        'f1_pos': f1_pos,
        'f1_neg': f1_neg,
        'accuracy_percent': f1_avg * 100,
        'counts': counts,
    }

In [None]:
# RUN VALIDATION - This is the main cell to run!
# Evaluates the default checkpoint on the default data path and prints a short
# summary; get_actual_accuracy returns None when either path is missing.
print("üéØ Ki-67 Model Validation")
print("=" * 50)

results = get_actual_accuracy()

if not results:
    print("\n‚ùå Validation failed - check your paths!")
else:
    print("\n‚úÖ Validation Complete!")
    print(f"Final F1 Score: {results['f1_avg']:.4f}")
    print(f"Accuracy: {results['accuracy_percent']:.1f}%")