# PathoVision: BreakHis Binary Classification (Benign vs Malignant)
Academic support model for histopathology screening. Not a clinical diagnostic tool.

**Goals**
- Binary classification (Benign/Malignant)
- ResNet50 transfer learning
- Grad-CAM explainability
- Kaggle/Colab ready
- Exportable for backend inference

## 1. Setup

In [None]:
import os
import random
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_curve, auc
)

# Seed Python's `random`, NumPy and PyTorch (CPU generator and every CUDA
# device) with one shared value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device  # bare expression: shown as the cell output when run in a notebook

## 2. Dataset Setup (BreakHis)
Set the dataset path below. Example for Kaggle: `/kaggle/input/breakhis`

In [None]:
DATA_ROOT = '/kaggle/input/breakhis'  # change if needed

# Expected BreakHis path structure (one possible layout):
# /BreaKHis_v1/histology_slides/breast/benign/SOB/.../40X/*.png
# /BreaKHis_v1/histology_slides/breast/malignant/SOB/.../40X/*.png

# Class membership is taken from a 'benign' / 'malignant' directory component
# anywhere below DATA_ROOT. glob returns matches in arbitrary,
# filesystem-dependent order, so the lists are sorted: without that, the
# seeded train/val/test split later in the notebook would assign different
# images to each split on different machines.
benign_paths = sorted(glob(os.path.join(DATA_ROOT, '**', 'benign', '**', '*.png'), recursive=True))
malignant_paths = sorted(glob(os.path.join(DATA_ROOT, '**', 'malignant', '**', '*.png'), recursive=True))

print('Benign images:', len(benign_paths))
print('Malignant images:', len(malignant_paths))

# Surface a wrong DATA_ROOT here instead of as an obscure error further down.
if not benign_paths or not malignant_paths:
    print(f'WARNING: at least one class has no images under {DATA_ROOT!r}; '
          'check DATA_ROOT before running the following cells.')

# Label encoding: 0 = benign, 1 = malignant.
all_paths = benign_paths + malignant_paths
all_labels = [0] * len(benign_paths) + [1] * len(malignant_paths)

# Build dataframe for easy splits
df = pd.DataFrame({'path': all_paths, 'label': all_labels})
df.head()

## 3. Train/Val/Test Split (70/15/15)

In [None]:
def _stratified_split(frame, holdout_fraction):
    """Split *frame* into (kept, held-out) parts, stratified on its 'label' column."""
    return train_test_split(
        frame,
        test_size=holdout_fraction,
        random_state=SEED,
        stratify=frame['label'],
    )


# Two stages: hold out 30% of the rows, then cut that pool in half to get the
# validation and test sets (70/15/15 overall).
# NOTE(review): rows are split per image. If several images originate from the
# same patient/slide they may land on both sides of the split -- confirm
# whether a patient-level (grouped) split is required.
train_df, temp_df = _stratified_split(df, 0.30)
val_df, test_df = _stratified_split(temp_df, 0.50)

print(f'Train: {len(train_df)} Val: {len(val_df)} Test: {len(test_df)}')
tuple(part['label'].value_counts() for part in (train_df, val_df, test_df))

## 4. Transforms and Dataset

In [None]:
IMG_SIZE = 224

# Per-channel statistics handed to T.Normalize in both pipelines (the usual
# ImageNet mean/std).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _resize_step():
    """Return a fresh transform that resizes to IMG_SIZE x IMG_SIZE."""
    return T.Resize((IMG_SIZE, IMG_SIZE))


def _tensor_steps():
    """Return fresh [ToTensor, Normalize] transforms that end each pipeline."""
    return [T.ToTensor(), T.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD))]


# Training pipeline: resize, light random augmentation, then tensor + normalise.
train_tfms = T.Compose([
    _resize_step(),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10),
    T.ColorJitter(brightness=0.1),
    *_tensor_steps(),
])

# Evaluation pipeline: the same steps without any random augmentation.
val_tfms = T.Compose([_resize_step(), *_tensor_steps()])

class BreakHisDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs described by a dataframe.

    The dataframe must provide a ``path`` column (image file location) and a
    ``label`` column whose values can be converted with ``int``.
    """

    def __init__(self, df, transforms=None):
        # Work on a copy whose rows are renumbered 0..n-1; the caller's frame
        # is left untouched.
        self.transforms = transforms
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Load the image at position *idx* as RGB and return it with its label."""
        record = self.df.iloc[idx]
        image = Image.open(record['path']).convert('RGB')
        if self.transforms:
            image = self.transforms(image)
        return image, int(record['label'])

# The training split uses the augmenting pipeline; validation and test share
# the augmentation-free one.
train_ds, val_ds, test_ds = (
    BreakHisDataset(frame, transforms=tfms)
    for frame, tfms in ((train_df, train_tfms), (val_df, val_tfms), (test_df, val_tfms))
)

BATCH_SIZE = 16
# Options common to all three loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

## 5. Model (ResNet50 Transfer Learning)

In [None]:
from torchvision.models import ResNet50_Weights

# ResNet50 initialised with torchvision's IMAGENET1K_V2 weights.
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Freeze every parameter, then re-enable gradients for the last stage only.
model.requires_grad_(False)
model.layer4.requires_grad_(True)

# Replace the classifier head with a new 2-way linear layer. It is created
# after the freeze above, so its parameters keep the default
# requires_grad=True and are trained together with layer4.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that are still trainable.
_trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(_trainable_params, lr=1e-4)

## 6. Training Loop with Early Stopping

In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with that epoch's validation loss. An
    epoch counts as an improvement when its loss is at or below
    ``best_loss - min_delta`` (the first finite loss always counts). After
    ``patience`` consecutive calls without improvement ``early_stop`` is set
    to ``True``; it is never reset.

    A NaN loss is always treated as "no improvement". It is never stored as
    ``best_loss``, so a diverged run is stopped after ``patience`` epochs
    instead of resetting the counter on every call.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        min_delta: minimum decrease of the loss that counts as improvement.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # consecutive calls without improvement
        self.best_loss = None     # lowest loss accepted so far
        self.early_stop = False   # becomes True once patience is exhausted

    def __call__(self, val_loss):
        """Record *val_loss* and update ``counter`` / ``early_stop``."""
        is_nan = val_loss != val_loss  # NaN is the only value unequal to itself
        improved = not is_nan and (
            self.best_loss is None
            or val_loss <= self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over *loader* and report its metrics.

    Batches are moved to the notebook-level ``device``. Returns a tuple
    ``(loss, accuracy)``: the accumulated loss and the number of correct
    argmax predictions, each divided by the number of samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in tqdm(loader, leave=False):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size (assumes the criterion
        # returns a per-batch mean) so the final division gives a
        # per-sample average.
        loss_sum += batch_loss.item() * images.size(0)
        predicted = logits.argmax(dim=1)
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, criterion):
    """Score *model* on *loader* without updating any weights.

    The model is put in eval mode and gradient tracking is disabled for the
    whole pass. Batches are moved to the notebook-level ``device``. Returns a
    tuple ``(loss, accuracy)``: the accumulated loss and the number of correct
    argmax predictions, each divided by the number of samples seen.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, leave=False):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Weight by batch size (assumes a per-batch mean loss) so the
            # final division gives a per-sample average.
            loss_sum += batch_loss.item() * images.size(0)
            predicted = logits.argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen

EPOCHS = 15
early_stopping = EarlyStopping(patience=5)

# Per-epoch metrics, appended in epoch order.
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

# Early stopping only fires after several non-improving epochs, so the weights
# present when the loop ends are not the best ones seen. Keep a CPU copy of the
# weights from the epoch with the lowest validation loss and restore it after
# the loop, so the cells that follow work with the best model.
best_val_loss = float('inf')
best_state = None

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    print(f'Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}')

    # A NaN loss compares False here, so it never replaces the snapshot.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    early_stopping(val_loss)
    if early_stopping.early_stop:
        print('Early stopping triggered')
        break

# best_state stays None when no epoch ran or every validation loss was NaN.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f'Restored best weights (Val Loss: {best_val_loss:.4f})')

## 7. Training Curves

In [None]:
# One row, two panels: loss on the left, accuracy on the right. Each panel
# overlays the training and validation series from `history`.
plt.figure(figsize=(12, 4))

# (history key suffix, legend word, title / y-axis word)
_panels = (
    ('loss', 'Loss', 'Loss'),
    ('acc', 'Acc', 'Accuracy'),
)
for _position, (_suffix, _short, _long) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _position)
    for _split in ('Train', 'Val'):
        plt.plot(history[f'{_split.lower()}_{_suffix}'], label=f'{_split} {_short}')
    plt.title(f'{_long} vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel(_long)
    plt.legend()

plt.show()

## 8. Evaluation on Test Set

In [None]:
# Collect ground-truth labels, hard predictions and the class-1 (malignant)
# probability for every test image.
model.eval()
all_labels, all_preds, all_probs = [], [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        outputs = model(images)
        probs = outputs.softmax(dim=1)
        preds = probs.argmax(dim=1)

        # Labels were never moved off the CPU; predictions must be brought back.
        all_labels.extend(labels.numpy())
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs[:, 1].cpu().numpy())

# Threshold-based metrics computed from the hard predictions.
acc = accuracy_score(all_labels, all_preds)
prec, rec, f1 = (
    metric(all_labels, all_preds)
    for metric in (precision_score, recall_score, f1_score)
)

for _title, _value in (
    ('Test Accuracy', acc),
    ('Precision', prec),
    ('Recall', rec),
    ('F1 Score', f1),
):
    print(f'{_title}: {_value:.4f}')

# Confusion matrix: rows are true classes, columns are predicted classes.
_class_ticks = ['Benign', 'Malignant']
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(5, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=_class_ticks,
    yticklabels=_class_ticks,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

# ROC curve built from the malignant-class probabilities; the dashed diagonal
# is the chance line.
fpr, tpr, _ = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(5, 4))
plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.4f}')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.show()

## 9. Grad-CAM Explainability

In [None]:
import cv2

class GradCAM:
    """Grad-CAM heatmaps computed from one layer of a model.

    Hooks on ``target_layer`` capture its forward output (the activations)
    and the gradient flowing back into that output. ``generate`` averages the
    gradient over the spatial dimensions to get one weight per channel, sums
    the activations with those weights, applies ReLU and rescales to [0, 1].

    The hooks stay attached to ``target_layer`` until ``remove_hooks`` is
    called.

    Args:
        model: network to explain.
        target_layer: sub-module of ``model`` whose output is visualised;
            assumed to produce an (N, C, H, W) tensor.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None     # gradient w.r.t. the layer output (last backward)
        self.activations = None   # layer output (last forward)
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the forward and backward hooks and keep their handles."""
        def forward_hook(module, input, output):
            # Only the values are needed later, so do not keep the autograd
            # graph alive through this reference.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self._hook_handles = [
            self.target_layer.register_forward_hook(forward_hook),
            self.target_layer.register_full_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks from ``target_layer``; safe to call repeatedly."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_tensor, class_idx=None):
        """Return the Grad-CAM map for the first sample of *input_tensor*.

        Args:
            input_tensor: model input. Only sample 0 is explained; when
                ``class_idx`` is None the batch must hold exactly one sample.
            class_idx: index of the output score to explain. Defaults to the
                highest-scoring class.

        Returns:
            A 2-D float32 NumPy array with values in [0, 1], at the spatial
            resolution of the target layer's output.

        Side effects: puts the model in eval mode, clears and then populates
        parameter gradients through the backward pass.
        """
        self.model.eval()
        # Gradients are required here even if the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            if class_idx is None:
                class_idx = output.argmax(dim=1).item()

            self.model.zero_grad()
            output[0, class_idx].backward()

        gradients = self.gradients[0]
        activations = self.activations[0]
        # One weight per channel: the spatial mean of that channel's gradient.
        weights = gradients.mean(dim=(1, 2))

        # Weighted sum over channels. The result lives on the same device as
        # the activations, independent of any notebook-level device variable.
        cam = (weights[:, None, None] * activations).sum(dim=0).float()

        cam = torch.relu(cam)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)  # epsilon guards an all-zero map
        return cam.detach().cpu().numpy()

def overlay_heatmap(image_path, cam, output_path='heatmap.png'):
    """Blend a class-activation map over an image and save the result.

    Args:
        image_path: path of the image to read with OpenCV.
        cam: 2-D float array, expected in [0, 1]; it is resized to the image
            size and values outside that range are clipped.
        output_path: destination file. Its parent directory is created when
            missing.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: the image could not be read.
        OSError: the blended image could not be written.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {image_path}')

    height, width = img.shape[:2]
    cam = cv2.resize(cam, (width, height))  # cv2.resize takes (width, height)
    # Clip before the uint8 cast so out-of-range values saturate, not wrap.
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam, 0.0, 1.0)), cv2.COLORMAP_JET)

    # imread and applyColorMap both yield BGR, which is also what imwrite
    # expects, so the blend needs no colour-space conversion.
    overlay = cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(output_path, overlay):
        raise OSError(f'Could not write heatmap to: {output_path}')
    return output_path

# Attach Grad-CAM to the last block of the model's `layer4` stage.
gradcam = GradCAM(model, model.layer4[-1])

# Example usage: explain the first image of the test split. The tensor fed to
# the model goes through the evaluation transforms; the overlay is drawn on
# the image re-read from disk.
sample_path = test_df['path'].iloc[0]
sample_img = Image.open(sample_path).convert('RGB')
input_tensor = val_tfms(sample_img)[None].to(device)  # [None] adds the batch dimension
cam = gradcam.generate(input_tensor)

os.makedirs('heatmaps', exist_ok=True)
heatmap_path = overlay_heatmap(sample_path, cam, output_path='heatmaps/sample_01.png')
print(f'Saved heatmap to: {heatmap_path}')

## 10. Inference Helper

In [None]:
# Output index -> human-readable class name.
class_names = {0: 'Benign', 1: 'Malignant'}


def predict_image(image_path, model, gradcam):
    """Classify one image file and save a Grad-CAM overlay for the prediction.

    Args:
        image_path: image file to classify.
        model: classifier returning one score per class.
        gradcam: object whose ``generate(input_tensor, class_idx=...)``
            returns the activation map to overlay.

    Returns:
        A dict with the predicted class name (``prediction``), its
        probability (``confidence``), both class probabilities
        (``probabilities``) and the saved overlay location
        (``heatmap_path``). The overlay is written under ``heatmaps/`` using
        the input file's base name.
    """
    model.eval()
    rgb_image = Image.open(image_path).convert('RGB')
    input_tensor = val_tfms(rgb_image).unsqueeze(0).to(device)

    # Class probabilities; no gradients are needed for the prediction itself.
    with torch.no_grad():
        probs = model(input_tensor).softmax(dim=1)[0].cpu().numpy()

    pred_idx = int(probs.argmax())
    confidence = float(probs[pred_idx])

    # Grad-CAM runs its own forward/backward pass, outside the no_grad block.
    cam = gradcam.generate(input_tensor, class_idx=pred_idx)
    os.makedirs('heatmaps', exist_ok=True)
    target_file = f'heatmaps/{os.path.basename(image_path)}'
    heatmap_path = overlay_heatmap(image_path, cam, output_path=target_file)

    probabilities = {
        'benign': float(probs[0]),
        'malignant': float(probs[1])
    }
    return {
        'prediction': class_names[pred_idx],
        'confidence': confidence,
        'probabilities': probabilities,
        'heatmap_path': heatmap_path
    }

# Example
result = predict_image(sample_path, model, gradcam)
result

## 11. Export Model

In [None]:
# Save only the model's state_dict (its parameters and buffers).
_export_path = 'models/pathovision_resnet50.pt'
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), _export_path)
print(f'Model saved to {_export_path}')

## 12. Simple FastAPI Inference Example

In [None]:
# Save as app.py for FastAPI usage.
# The template rebuilds the architecture trained in this notebook (ResNet50
# with a 2-way head) and loads the state_dict exported to
# models/pathovision_resnet50.pt, so it runs as-is once that file exists.
fastapi_example = r'''
from fastapi import FastAPI, File, HTTPException, UploadFile
import uvicorn
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import numpy as np
import os

# Weights file exported by the training notebook; override via environment.
MODEL_PATH = os.environ.get('PATHOVISION_MODEL', 'models/pathovision_resnet50.pt')

app = FastAPI()

# Load model: same architecture as in training, then the exported weights
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
model.eval()

val_tfms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    try:
        image = Image.open(file.file).convert('RGB')
    except OSError as err:
        # Covers truncated files and uploads that are not images at all
        raise HTTPException(status_code=400, detail='Uploaded file is not a readable image') from err
    input_tensor = val_tfms(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
    pred_idx = int(np.argmax(probs))
    return {
        'prediction': 'Malignant' if pred_idx == 1 else 'Benign',
        'confidence': float(probs[pred_idx]),
        'probabilities': {
            'benign': float(probs[0]),
            'malignant': float(probs[1])
        },
        # No Grad-CAM overlay is produced by this minimal example
        'heatmap_path': None
    }

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8000)
'''
print(fastapi_example)