<a href="https://colab.research.google.com/github/chjayarajesh/Brain-Stroke-detection-and-Segmentation/blob/main/Stroke_detection_and_segmentation_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Brain Stroke classification and segmentation**

# **Using the Kaggle API to import the dataset**

In [None]:
!pip install -q kaggle timm

# Upload kaggle.json from your Kaggle account (Account → API → Create New Token)
from google.colab import files
files.upload()

# Move kaggle.json to correct path
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


cp: cannot stat 'kaggle.json': No such file or directory
chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory


Download the dataset

In [None]:
# Download the Brain Stroke CT Dataset
!kaggle datasets download -d ozguraslank/brain-stroke-ct-dataset -p /content/dataset

# Unzip the dataset
!unzip -q /content/dataset/brain-stroke-ct-dataset.zip -d /content/dataset

Dataset URL: https://www.kaggle.com/datasets/ozguraslank/brain-stroke-ct-dataset
License(s): other
Downloading brain-stroke-ct-dataset.zip to /content/dataset
100% 1.41G/1.41G [00:16<00:00, 165MB/s]
100% 1.41G/1.41G [00:16<00:00, 94.0MB/s]


# **Data segregation and splitting into train, validation and test sets**

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split
import cv2  # Added import for OpenCV

# Step 1: Segregate images and overlays into train, val, test
def segregate_dataset(base_dir, classes, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Copy each class's PNG slices and overlays into train/val/test folders.

    Reads ``base_dir/<cls>/PNG`` (and, if present, ``base_dir/<cls>/OVERLAY``) and
    writes ``base_dir/<split>/PNG/<cls>`` and ``base_dir/<split>/OVERLAY/<cls>`` for
    split in train/val/test. Overlays are matched to images by filename. For the
    'Normal' class, a missing overlay is replaced by an all-zero mask with the same
    height/width as its CT image.

    Args:
        base_dir: Root of the extracted dataset.
        classes: Class folder names to process; missing or empty folders are skipped.
        train_ratio: Kept for interface compatibility. The train share is whatever
            remains after ``val_ratio + test_ratio`` is held out; a warning is
            printed if the three ratios do not sum to 1.
        val_ratio: Fraction of each class used for validation.
        test_ratio: Fraction of each class used for testing.
    """
    import numpy as np  # this cell does not import numpy itself; needed for zero masks

    holdout = val_ratio + test_ratio
    if abs(train_ratio + holdout - 1.0) > 1e-6:
        print(f"Warning: train_ratio={train_ratio} is ignored; train share is {1.0 - holdout:.2f}.")

    for cls in classes:
        png_dir = os.path.join(base_dir, cls, 'PNG')
        overlay_dir = os.path.join(base_dir, cls, 'OVERLAY')
        if not os.path.exists(png_dir):
            print(f"PNG folder not found for {cls}. Skipping...")
            continue

        # Sorted so the seeded split is identical regardless of filesystem listing order
        images = sorted(
            os.path.join(png_dir, f) for f in os.listdir(png_dir) if f.lower().endswith('.png')
        )
        if len(images) == 0:
            print(f"No PNG images found for {cls}. Skipping...")
            continue

        # Hold out val+test first, then split the holdout into val and test
        train_images, test_images = train_test_split(images, test_size=holdout, random_state=42)
        val_images, test_images = train_test_split(test_images, test_size=test_ratio / holdout, random_state=42)

        for split, split_images in [('train', train_images), ('val', val_images), ('test', test_images)]:
            split_png_dir = os.path.join(base_dir, split, 'PNG', cls)
            split_overlay_dir = os.path.join(base_dir, split, 'OVERLAY', cls)
            os.makedirs(split_png_dir, exist_ok=True)
            os.makedirs(split_overlay_dir, exist_ok=True)

            for img in split_images:
                filename = os.path.basename(img)
                shutil.copy(img, os.path.join(split_png_dir, filename))

                overlay_path = os.path.join(overlay_dir, filename)
                if os.path.exists(overlay_path):
                    shutil.copy(overlay_path, os.path.join(split_overlay_dir, filename))
                elif cls == 'Normal':
                    # Normal scans have no lesion: write an empty mask matching the image size
                    ct = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
                    mask_shape = ct.shape[:2] if ct is not None else (256, 256)
                    zero_mask = np.zeros(mask_shape, dtype=np.uint8)
                    cv2.imwrite(os.path.join(split_overlay_dir, filename), zero_mask)
                    print(f"Created zero mask for {filename} in {cls}")
                else:
                    print(f"Overlay not found for {filename}, skipping copy.")

        print(f"{cls}: Train {len(train_images)}, Val {len(val_images)}, Test {len(test_images)}")

# Dataset root and class folders (adjust to your Colab path / dataset layout)
base_dir = os.path.join('/content/dataset', 'Brain_Stroke_CT_Dataset')
classes = ['Bleeding', 'Ischemia', 'Normal']

# Copy every class into base_dir/{train,val,test}/{PNG,OVERLAY}/<class>
segregate_dataset(base_dir=base_dir, classes=classes)

# **1. Classification model training**

---


Preprocessing and augmentation of the dataset

---

ConvNeXt model (pretrained ConvNeXt-Base from `timm`)


In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from torchvision import transforms
import timm
from tqdm import tqdm  # For progress bar

base_dir = os.path.join('/content/dataset', 'Brain_Stroke_CT_Dataset')  # dataset root in Colab; adjust if yours differs

# Step 2: Preprocessing and Training
# Custom transform to replicate grayscale to 3 channels
class GrayscaleToRGB:
    """Transform that turns a single-channel tensor into a 3-channel one.

    Used after ``ToTensor`` so the (1, H, W) grayscale CT tensor matches the
    3-channel input expected by the pretrained backbone. Tensors whose first
    dimension is not 1 are returned unchanged (the same object).
    """

    def __call__(self, image):
        if image.shape[0] != 1:
            return image
        # Stack three copies of the single channel along the channel axis
        return torch.cat((image, image, image), dim=0)

# Custom Dataset with 3-channel replication
class CTDataset(Dataset):
    """Image-folder dataset: one sub-directory per class; every file inside is a sample.

    Images are opened as single-channel ('L') PIL images and passed through
    ``transform``, which is expected to yield a 3-channel tensor (e.g. via
    ``GrayscaleToRGB``). Labels are indices into the sorted list of class
    sub-directory names.

    Args:
        data_dir: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to each PIL image.

    Raises:
        ValueError: If no files are found in any class sub-directory.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Only directories are classes. A stray file (e.g. .DS_Store) would otherwise
        # get a class index and shift the labels of the real classes.
        self.class_names = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.class_names)}

        for cls_name in self.class_names:
            cls_dir = os.path.join(data_dir, cls_name)
            # Sorted for a deterministic sample order across filesystems
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                if os.path.isfile(img_path):
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[cls_name])
        if not self.images:
            raise ValueError(f"No images found in {data_dir}. Check your dataset path and ensure it contains image subfolders.")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``.

        If the file cannot be loaded/transformed, the error is printed and a zero
        tensor of shape (3, 224, 224) is returned in its place. That size matches
        the 224x224 transforms used in this notebook.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        try:
            image = Image.open(img_path).convert('L')  # Load as grayscale
            if self.transform:
                image = self.transform(image)  # Apply full transform pipeline
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(3, 224, 224), label  # Return 3-channel placeholder

# Load datasets from the folders written by segregate_dataset:
# base_dir/<split>/PNG/<class>/*.png. Pointing CTDataset at base_dir/<split> itself
# would treat 'OVERLAY' and 'PNG' as the classes and find no image files there.
train_dir = os.path.join(base_dir, 'train', 'PNG')
val_dir = os.path.join(base_dir, 'val', 'PNG')
test_dir = os.path.join(base_dir, 'test', 'PNG')

# Seed torch's RNG: it drives DataLoader shuffling, the random flips/rotations
# below and the initialisation of the new classifier head.
torch.manual_seed(42)

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    GrayscaleToRGB(),  # Replicate to 3 channels after ToTensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 3-channel normalization
])

val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    GrayscaleToRGB(),  # Replicate to 3 channels after ToTensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = CTDataset(train_dir, transform=train_transform)
val_dataset = CTDataset(val_dir, transform=val_test_transform)
test_dataset = CTDataset(test_dir, transform=val_test_transform)
# Labels are indices into each dataset's sorted class folders, so all splits must agree
assert train_dataset.class_names == val_dataset.class_names == test_dataset.class_names, \
    "train/val/test splits must contain the same class folders"

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Classes: {train_dataset.class_names}")

# Set device
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load pre-trained ConvNeXt-Base and give it one output per class folder
# (Bleeding, Ischemia, Normal for this dataset).
num_classes = len(train_dataset.class_names)
model = timm.create_model('convnext_base', pretrained=True)
model.reset_classifier(num_classes=num_classes)
model = model.to(device)


# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train a classifier, validating after every epoch and checkpointing the best.

    After each epoch the model is evaluated on ``val_loader``. Whenever validation
    accuracy improves, ``model.state_dict()`` is saved to 'best_model_3class.pth';
    the first epoch always writes one. At the end, a 2x2 figure of validation
    accuracy/precision/recall/F1 is shown and saved to 'validation_metrics_graphs.png'.
    Uses the module-level ``device``.

    Note: a later cell in this notebook defines a different ``train_model`` (for
    segmentation) under the same name.

    Returns:
        (val_accuracies, val_precisions, val_recalls, val_f1_scores): per-epoch lists.
    """
    # -inf guarantees the first epoch checkpoints, so the later reload always has a file
    best_acc = float('-inf')
    val_accuracies, val_precisions, val_recalls, val_f1_scores = [], [], [], []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        print(f"Starting epoch {epoch + 1}/{num_epochs}")
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
                pbar.update(1)

        # Validation
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        epoch_val_loss = val_loss / len(val_loader)
        epoch_acc = accuracy_score(all_labels, all_preds)
        # zero_division=0 gives the same values as sklearn's default but without the
        # UndefinedMetricWarning emitted when some class is never predicted.
        epoch_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        val_accuracies.append(epoch_acc)
        val_precisions.append(epoch_precision)
        val_recalls.append(epoch_recall)
        val_f1_scores.append(epoch_f1)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_acc:.4f}, Precision: {epoch_precision:.4f}, Recall: {epoch_recall:.4f}, F1: {epoch_f1:.4f}')

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), 'best_model_3class.pth')

    # Plot one panel per validation metric
    epochs = range(1, num_epochs + 1)
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle('Validation Metrics Over Epochs')
    panels = [
        (axs[0, 0], val_accuracies, 'Accuracy', 'b'),
        (axs[0, 1], val_precisions, 'Precision', 'r'),
        (axs[1, 0], val_recalls, 'Recall', 'g'),
        (axs[1, 1], val_f1_scores, 'F1 Score', 'm'),
    ]
    for ax, values, metric_name, color in panels:
        ax.plot(epochs, values, label=metric_name, color=color)
        ax.set_title(metric_name)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Score')
        ax.legend()

    plt.tight_layout()
    plt.savefig('validation_metrics_graphs.png')  # Save the figure
    plt.show()
    return val_accuracies, val_precisions, val_recalls, val_f1_scores

# Evaluate function
def evaluate_model(model, test_loader, class_names=None):
    """Evaluate a classifier on ``test_loader`` and plot its confusion matrix.

    Prints accuracy and weighted precision/recall/F1, then shows the confusion
    matrix and saves it to 'confusion_matrix.png'. Uses the module-level ``device``.

    Args:
        model: Classifier producing one logit per class.
        test_loader: Yields ``(inputs, labels)`` batches.
        class_names: Display names indexed by label. Defaults to
            ['Hemorrhagic', 'Ischemic', 'Normal'], matching the sorted class
            folders Bleeding, Ischemia, Normal of this dataset.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1' and 'confusion_matrix'.
    """
    if class_names is None:
        class_names = ['Hemorrhagic', 'Ischemic', 'Normal']

    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    # Fix the label set so the matrix is always len(class_names) square, even if a
    # class is absent from both the true labels and the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')

    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    fig.savefig('confusion_matrix.png')  # Save the confusion matrix
    plt.show()

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
    }

# Train the model and get per-epoch validation metrics
val_accuracies, val_precisions, val_recalls, val_f1_scores = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

# Reload the best checkpoint (highest validation accuracy) and evaluate it on the test split.
# map_location keeps the load working on whatever device this session uses.
model.load_state_dict(torch.load('best_model_3class.pth', map_location=device))
test_metrics = evaluate_model(model, test_loader)

# **2. Segmentation model training**
---


Preprocessing and augmentation of the dataset

---

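U-Net with an EfficientNet-B4 encoder (ImageNet weights), trained on intensity-threshold pseudo-masks rather than ground-truth lesion masks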

In [None]:
# Step 1: Install required packages
!pip install segmentation-models-pytorch -q

# Step 3: Preprocessing and Training
import os
import cv2
from glob import glob
import numpy as np
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.optim as optim  # Added import
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm

# Configuration
class Config:
    """Hyper-parameters, paths and runtime settings for the segmentation experiment."""

    # --- data ---
    # NOTE(review): this notebook only downloads the ozguraslank CT dataset into
    # /content/dataset; confirm where these cropped stroke JPEGs are expected to come from.
    DATASET_PATH = '/content/brain-stroke-images'  # Root path
    IMG_HEIGHT = 256
    IMG_WIDTH = 256

    # --- optimisation ---
    BATCH_SIZE = 8
    EPOCHS = 50
    LEARNING_RATE = 1e-4

    # --- runtime ---
    NUM_WORKERS = 2
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


config = Config()

# Preprocessing: Generate pseudo-masks for stroke images
def generate_pseudo_mask(image, threshold=0.1):  # Simple thresholding
    """Build a binary float mask by global intensity thresholding.

    Pixels whose grayscale value is strictly greater than ``threshold * 255``
    become 1.0; all others become 0.0.

    NOTE(review): with the default threshold of 0.1 this likely marks most
    non-background pixels (skull and brain), not only the lesion. Confirm this
    pseudo-label is what the segmentation model is meant to learn.

    Args:
        image: uint8 image (0-255), either RGB (H, W, 3) or already grayscale (H, W).
        threshold: Cut-off as a fraction of the 0-255 intensity range.

    Returns:
        float32 array of shape (H, W) with values in {0.0, 1.0}.
    """
    # Accept single-channel input as-is; cvtColor would reject a 2-D array
    gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(gray, threshold * 255, 255, cv2.THRESH_BINARY)
    return (mask / 255.0).astype(np.float32)

# Load data from cropped folders
def load_segmentation_data(base_dir, patterns=('*.jpg',)):
    """Collect cropped CT image paths with binary stroke labels.

    Looks in ``base_dir/stroke_cropped/CROPPED/<split>/<category>`` for split in
    TRAIN_CROP, VAL_CROP, TEST_CROP and category in NORMAL, STROKE. Images from
    all three source splits are pooled into one list.

    Args:
        base_dir: Dataset root.
        patterns: Glob patterns matched inside each category folder.

    Returns:
        (image_paths, mask_paths, labels). mask_paths is all ``None`` because masks
        are generated on the fly. labels are 1 for STROKE and 0 for NORMAL.
    """
    image_paths = []
    mask_paths = []  # Pseudo-masks generated on-the-fly
    labels = []  # 1 for STROKE, 0 for NORMAL

    for split in ['TRAIN_CROP', 'VAL_CROP', 'TEST_CROP']:
        for category in ['NORMAL', 'STROKE']:
            category_dir = os.path.join(base_dir, 'stroke_cropped/CROPPED', split, category)
            if not os.path.exists(category_dir):
                continue
            # glob order follows the filesystem; sort so the seeded train/val/test split
            # downstream is reproducible. The set drops duplicates from overlapping patterns.
            found = sorted({p for pattern in patterns for p in glob(os.path.join(category_dir, pattern))})
            label = 1 if category == 'STROKE' else 0
            image_paths.extend(found)
            mask_paths.extend([None] * len(found))
            labels.extend([label] * len(found))

    n_stroke = int(np.sum(labels))
    print(f"Loaded {len(image_paths)} images: {n_stroke} stroke, {len(labels) - n_stroke} normal.")
    return image_paths, mask_paths, labels

image_paths, mask_paths, labels = load_segmentation_data(config.DATASET_PATH)
if not image_paths:
    raise ValueError("No images found! Check dataset path.")

# 80/10/10 split, stratified on the stroke label at both stages
train_images, test_images, train_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=42, stratify=labels
)
val_images, test_images, val_labels, test_labels = train_test_split(
    test_images, test_labels, test_size=0.5, random_state=42, stratify=test_labels
)

for split_name, split_items in (("Training", train_images), ("Validation", val_images), ("Test", test_images)):
    print(f"{split_name} samples: {len(split_items)}")

# Dataset class
class StrokeDataset(Dataset):
    """Segmentation dataset pairing each CT image with a (pseudo-)mask.

    Images are read with OpenCV, converted BGR->RGB and resized to
    ``config.IMG_WIDTH`` x ``config.IMG_HEIGHT``. Stroke images (label 1) get a
    threshold pseudo-mask from ``generate_pseudo_mask``. Normal images (label 0)
    get an all-zero float32 mask.

    Args:
        image_paths: Paths to image files.
        labels: 0/1 labels parallel to ``image_paths``.
        transforms: Optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a dict with 'image' and 'mask' (albumentations-style).
    """

    def __init__(self, image_paths, labels, transforms=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        path = self.image_paths[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising; report the offending path
            # rather than letting cvtColor fail with an opaque assertion.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (config.IMG_WIDTH, config.IMG_HEIGHT))

        if self.labels[idx] == 1:  # Stroke: generate pseudo-mask
            msk = generate_pseudo_mask(img, threshold=0.1)
        else:  # Normal: zero mask
            msk = np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.float32)

        if self.transforms:
            transformed = self.transforms(image=img, mask=msk)
            img, msk = transformed['image'], transformed['mask']

        return img, msk

# Transforms
def get_transforms(is_training=True):
    """Return the albumentations pipeline for segmentation samples.

    Both pipelines end with ImageNet mean/std normalisation followed by
    ``ToTensorV2``. The training pipeline first applies a horizontal flip, mild
    brightness/contrast jitter and a small affine shift/scale/rotation. Callers in
    this notebook pass ``image=`` and ``mask=`` together, so the spatial
    augmentations move the mask with the image.
    """
    finishing = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    if not is_training:
        return A.Compose(finishing)

    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.3),
        A.Affine(translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)}, scale=(0.95, 1.05), rotate=(-5, 5), p=0.3),
    ]
    return A.Compose(augmentations + finishing)

train_dataset = StrokeDataset(train_images, train_labels, get_transforms(True))
val_dataset = StrokeDataset(val_images, val_labels, get_transforms(False))
test_dataset = StrokeDataset(test_images, test_labels, get_transforms(False))

# Pinned host memory only speeds up host->GPU copies; on CPU-only runtimes PyTorch
# warns that it is unused, so enable it only when training on CUDA.
pin_memory = config.DEVICE.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=pin_memory)

# Segmentation model: U-Net with an ImageNet-weighted EfficientNet-B4 encoder and one
# output channel of raw logits (activation=None; the loss applies the sigmoid itself).
model = smp.Unet(
    encoder_name='efficientnet-b4',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,
    activation=None,
).to(config.DEVICE)

# Loss and Optimizer
class DiceLoss(torch.nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)``, where ``p`` are
    predicted probabilities and ``t`` the target mask; both are flattened first.

    Uses ``torch.nn`` directly because this cell does not import ``nn`` itself.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred_prob, y_true):
        # reshape (not view) so non-contiguous inputs, e.g. after a transpose, also work
        y_pred = y_pred_prob.reshape(-1)
        y_true = y_true.reshape(-1)
        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1 - dice

class CombinedLoss(torch.nn.Module):
    """Equal-weight sum of BCE-with-logits and soft Dice loss.

    ``forward`` takes raw logits. BCE is applied to the logits directly, while
    Dice is computed on ``sigmoid(logits)``.

    Uses ``torch.nn`` directly because this cell does not import ``nn`` itself.
    """

    def __init__(self):
        super().__init__()
        self.bce = torch.nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, y_pred_logits, y_true):
        prob = torch.sigmoid(y_pred_logits)
        return 0.5 * self.bce(y_pred_logits, y_true) + 0.5 * self.dice(prob, y_true)

# 50/50 BCE + soft-Dice loss, optimised with Adam at the configured learning rate
criterion = CombinedLoss()
optimizer = optim.Adam(params=model.parameters(), lr=config.LEARNING_RATE)

# Training
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Masks get a channel axis via ``unsqueeze(1)`` to match the model's
    single-channel logits. This assumes they arrive as (B, H, W); TODO confirm if
    the transforms change. Dice/IoU are computed from sigmoid probabilities via
    ``calculate_dice_score`` / ``calculate_iou_score``.

    Returns:
        (mean loss, mean Dice, mean IoU), each averaged over batches.
    """
    model.train()
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0
    bar = tqdm(dataloader, desc='Training')
    for images, masks in bar:
        images = images.to(device)
        masks = masks.unsqueeze(1).to(device)

        logits = model(images)
        loss = criterion(logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        probs = torch.sigmoid(logits)
        batch_dice = calculate_dice_score(probs, masks)
        batch_iou = calculate_iou_score(probs, masks)
        batch_loss = loss.item()

        total_loss += batch_loss
        total_dice += batch_dice
        total_iou += batch_iou

        bar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Dice': f'{batch_dice:.4f}', 'IoU': f'{batch_iou:.4f}'})

    n_batches = len(dataloader)
    return total_loss / n_batches, total_dice / n_batches, total_iou / n_batches

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model on ``dataloader`` without gradient tracking.

    Returns:
        (mean loss, mean Dice, mean IoU), each averaged over batches.
    """
    model.eval()
    totals = {'loss': 0.0, 'dice': 0.0, 'iou': 0.0}
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.unsqueeze(1).to(device)  # add channel axis to match logits

            logits = model(images)
            probs = torch.sigmoid(logits)

            totals['loss'] += criterion(logits, masks).item()
            totals['dice'] += calculate_dice_score(probs, masks)
            totals['iou'] += calculate_iou_score(probs, masks)

    n_batches = len(dataloader)
    return totals['loss'] / n_batches, totals['dice'] / n_batches, totals['iou'] / n_batches

def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=50):
    """Train the segmentation model, keeping the weights with the best validation Dice.

    The best weights are saved to 'best_stroke_segmentation_model.pth' and loaded
    back into ``model`` before it is returned. If ``epochs`` is 0, or validation Dice
    is never a comparable number (e.g. NaN), no checkpoint is written and ``model``
    is returned unchanged.

    Note: this shares its name with the classification ``train_model`` defined in
    an earlier cell, but has a different signature.

    Returns:
        The model, with the best-Dice weights loaded when a checkpoint was saved.
    """
    checkpoint_path = 'best_stroke_segmentation_model.pth'
    # -inf guarantees the first epoch writes a checkpoint, so the reload below
    # never picks up a missing file or a stale one from an earlier run.
    best_dice = float('-inf')
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_dice, train_iou = train_epoch(model, train_loader, criterion, optimizer, config.DEVICE)
        val_loss, val_dice, val_iou = validate_epoch(model, val_loader, criterion, config.DEVICE)

        print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, Train IoU: {train_iou:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), checkpoint_path)
            print(f"New best model saved! Dice: {best_dice:.4f}")

    if best_dice == float('-inf'):
        return model  # nothing was saved during this call
    model.load_state_dict(torch.load(checkpoint_path, map_location=config.DEVICE))
    return model

# Metrics
def calculate_dice_score(y_pred_prob, y_true, smooth=1e-6):
    """Hard Dice coefficient over the whole batch, returned as a Python float.

    Probabilities are binarised with a strict ``> 0.5`` cut-off. ``smooth`` keeps
    the ratio defined (and equal to 1) when both masks are empty.
    """
    binary = y_pred_prob.gt(0.5).float()
    overlap = binary.mul(y_true).sum()
    denominator = binary.sum() + y_true.sum() + smooth
    return ((2. * overlap + smooth) / denominator).item()

def calculate_iou_score(y_pred_prob, y_true, smooth=1e-6):
    """Hard IoU (Jaccard index) over the whole batch, returned as a Python float.

    Probabilities are binarised with a strict ``> 0.5`` cut-off. ``smooth`` keeps
    the ratio defined (and equal to 1) when both masks are empty.
    """
    binary = y_pred_prob.gt(0.5).float()
    overlap = binary.mul(y_true).sum()
    union = binary.sum() + y_true.sum() - overlap
    return ((overlap + smooth) / (union + smooth)).item()

# Train the model for the configured number of epochs; the best-Dice weights are reloaded at the end
model = train_model(model, train_loader, val_loader, criterion, optimizer, epochs=config.EPOCHS)

# Step 3: Show 2 Images After Segmentation
def predict_and_show(model, image_paths, num_samples=2):
    """Run the segmentation model on the first ``num_samples`` images and plot them.

    Each image is preprocessed like the validation data: BGR->RGB, resize to
    ``config.IMG_WIDTH`` x ``config.IMG_HEIGHT``, then the non-training transforms.
    The predicted mask (sigmoid > 0.5) is shown next to the resized input.
    Unreadable files are reported and skipped. Asking for more samples than
    there are paths simply shows all of them.
    """
    import matplotlib.pyplot as plt  # this cell does not import matplotlib itself

    model.eval()
    # Named tfm so it isn't confused with torchvision's `transforms` imported earlier
    tfm = get_transforms(False)
    for i, path in enumerate(image_paths[:num_samples]):
        image = cv2.imread(path)
        if image is None:
            print(f"Could not read image: {path}, skipping.")
            continue
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_resized = cv2.resize(image_rgb, (config.IMG_WIDTH, config.IMG_HEIGHT))

        image_tensor = tfm(image=image_resized)['image'].unsqueeze(0).to(config.DEVICE)

        with torch.no_grad():
            # torch.sigmoid is numerically stable; np.exp(-x) overflows for very negative logits
            pred_prob = torch.sigmoid(model(image_tensor))[0, 0].cpu().numpy()
        pred_mask_binary = (pred_prob > 0.5).astype(np.uint8) * 255

        fig, (ax_in, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
        ax_in.imshow(image_resized)
        ax_in.set_title(f"Input Image {i+1}")
        ax_mask.imshow(pred_mask_binary, cmap='gray')
        ax_mask.set_title(f"Segmented Image {i+1}")
        for ax in (ax_in, ax_mask):
            ax.axis('off')
        plt.show()

# Visualise predictions for the first two held-out test images
predict_and_show(model, image_paths=test_images, num_samples=2)