<a href="https://colab.research.google.com/github/nisreen28/3breast_cancer_SVC_ECE.ipynb/blob/main/CODE%20FOR%20FP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
import random
import cv2
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from sklearn.metrics import confusion_matrix, precision_score, recall_score
import seaborn as sns
from torch.amp import GradScaler, autocast
from numba import jit

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Device and Output Directory
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
OUTPUT_DIR = "/content/drive/MyDrive/LeishManiaPlots"  # Save to Google Drive
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Part 1: Data Preparation, Preprocessing and Visualization
def generate_synthetic_data(img_size=256, num_objects=20, uninfected=False):
    image = np.random.normal(0.5, 0.1, (img_size, img_size)).astype(np.float32)
    mask = np.zeros((img_size, img_size), dtype=np.float32)
    is_infected = not uninfected
    if is_infected:
        for _ in range(num_objects + np.random.randint(-5, 5)):
            center = (np.random.randint(20, img_size-20), np.random.randint(20, img_size-20))
            radius = np.random.randint(2, 5)
            cv2.circle(image, center, radius, 0.8, -1)
            cv2.circle(mask, center, radius, 1.0, -1)
    if np.random.rand() > 0.5:
        flip_code = np.random.randint(0, 2)
        image = cv2.flip(image, flip_code)
        mask = cv2.flip(mask, flip_code)
    angle = np.random.choice([0, 90, 180, 270])
    M = cv2.getRotationMatrix2D((img_size//2, img_size//2), angle, 1)
    image = cv2.warpAffine(image, M, (img_size, img_size))
    mask = cv2.warpAffine(mask, M, (img_size, img_size))
    return image, mask, is_infected

def add_gaussian_noise(image, std=0.2):
    noise = np.random.normal(0, std, image.shape).astype(np.float32)
    return np.clip(image + noise, 0, 1)

@jit(nopython=True)
def kalman_filter(image, a=0.3, b=0.3):
    h, w = image.shape
    filtered = np.zeros((h, w), dtype=np.float32)
    for m in range(h):
        for n in range(w):
            left = filtered[m-1, n] if m > 0 else image[m, n]
            top = filtered[m, n-1] if n > 0 else image[m, n]
            filtered[m, n] = a * left + b * top + (1 - a - b) * image[m, n]
    return filtered

def sobel_edges(image):
    sobel_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
    return (magnitude / magnitude.max()).astype(np.float32)

def visualize_preprocessing(clean, noisy, denoised, edges, mask, is_infected, prefix="example"):
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))
    axes[0].imshow(clean, cmap='gray'); axes[0].set_title('Clean Image')
    axes[1].imshow(noisy, cmap='gray'); axes[1].set_title('Noisy Image')
    axes[2].imshow(denoised, cmap='gray'); axes[2].set_title('Denoised (Kalman)')
    axes[3].imshow(edges, cmap='gray'); axes[3].set_title('Sobel Edges')
    axes[4].imshow(mask, cmap='gray'); axes[4].set_title('Ground Truth Mask')
    for ax in axes: ax.axis('off')
    status = "Infected" if is_infected else "Uninfected"
    plt.suptitle(f"{prefix} - {status}", fontsize=16)
    plt.savefig(os.path.join(OUTPUT_DIR, f"{prefix}_preprocessing.png"))
    plt.close()

clean_inf, mask_inf, is_infected_inf = generate_synthetic_data(uninfected=False)
noisy_inf = add_gaussian_noise(clean_inf, std=0.2)
denoised_inf = kalman_filter(noisy_inf, a=0.3, b=0.3)
edges_inf = sobel_edges(denoised_inf)
visualize_preprocessing(clean_inf, noisy_inf, denoised_inf, edges_inf, mask_inf, is_infected_inf, prefix="infected")

clean_uninf, mask_uninf, is_infected_uninf = generate_synthetic_data(uninfected=True)
noisy_uninf = add_gaussian_noise(clean_uninf, std=0.2)
denoised_uninf = kalman_filter(noisy_uninf, a=0.3, b=0.3)
edges_uninf = sobel_edges(denoised_uninf)
visualize_preprocessing(clean_uninf, noisy_uninf, denoised_uninf, edges_uninf, mask_uninf, is_infected_uninf, prefix="uninfected")

# Part 2: Dataset Creation
class LeishManiaDataset(Dataset):
    def __init__(self, num_samples=1000, img_size=256, is_test=False, noise_std=0.2, kalman_a=0.3, kalman_b=0.3, uninfected_prob=0.2, use_kalman_edges=True):
        self.num_samples = num_samples
        self.img_size = img_size
        self.is_test = is_test
        self.noise_std = noise_std
        self.kalman_a = kalman_a
        self.kalman_b = kalman_b
        self.uninfected_prob = uninfected_prob
        self.use_kalman_edges = use_kalman_edges
        self.samples = [self._generate_sample() for _ in range(num_samples)] if self.is_test else None

    def _generate_sample(self):
        uninfected = np.random.rand() < self.uninfected_prob
        clean, mask, is_infected = generate_synthetic_data(self.img_size, uninfected=uninfected)
        noisy = add_gaussian_noise(clean, std=self.noise_std)
        if self.use_kalman_edges:
            denoised = kalman_filter(noisy, a=self.kalman_a, b=self.kalman_b)
            edges = sobel_edges(denoised)
            inputs = torch.tensor(np.stack([noisy, denoised, edges], axis=0), dtype=torch.float32)
        else:
            inputs = torch.tensor(noisy, dtype=torch.float32).unsqueeze(0)
        target = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)
        return inputs, target, is_infected

    def __len__(self): return self.num_samples

    def __getitem__(self, idx):
        if self.is_test and self.samples is not None:
            return self.samples[idx][0], self.samples[idx][1]
        else:
            inputs, target, _ = self._generate_sample()
            return inputs, target

NOISE_STD = 0.2
KALMAN_A = 0.3
KALMAN_B = 0.3
UNINFECTED_PROB = 0.2

train_dataset_knet = LeishManiaDataset(num_samples=800, noise_std=NOISE_STD, kalman_a=KALMAN_A, kalman_b=KALMAN_B, uninfected_prob=UNINFECTED_PROB, use_kalman_edges=True)
val_dataset_knet = LeishManiaDataset(num_samples=200, is_test=True, noise_std=NOISE_STD, kalman_a=KALMAN_A, kalman_b=KALMAN_B, uninfected_prob=UNINFECTED_PROB, use_kalman_edges=True)
test_dataset_knet = LeishManiaDataset(num_samples=200, is_test=True, noise_std=NOISE_STD, kalman_a=KALMAN_A, kalman_b=KALMAN_B, uninfected_prob=UNINFECTED_PROB, use_kalman_edges=True)

train_dataset_base = LeishManiaDataset(num_samples=800, noise_std=NOISE_STD, kalman_a=KALMAN_A, kalman_b=KALMAN_B, uninfected_prob=UNINFECTED_PROB, use_kalman_edges=False)
val_dataset_base = LeishManiaDataset(num_samples=200, is_test=True, noise_std=NOISE_STD, kalman_a=KALMAN_A, kalman_b=KALMAN_B, uninfected_prob=UNINFECTED_PROB, use_kalman_edges=False)
test_dataset_base = LeishManiaDataset(num_samples=200, is_test=True, noise_std=NOISE_STD, kalman_a=KALMAN_A, kalman_b=KALMAN_B, uninfected_prob=UNINFECTED_PROB, use_kalman_edges=False)

train_loader_knet = DataLoader(train_dataset_knet, batch_size=8, shuffle=True)
val_loader_knet = DataLoader(val_dataset_knet, batch_size=8, shuffle=False)
test_loader_knet = DataLoader(test_dataset_knet, batch_size=8, shuffle=False)

train_loader_base = DataLoader(train_dataset_base, batch_size=8, shuffle=True)
val_loader_base = DataLoader(val_dataset_base, batch_size=8, shuffle=False)
test_loader_base = DataLoader(test_dataset_base, batch_size=8, shuffle=False)

# Part 3: Model Definitions
class KNetSeg(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(256, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)
        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)

class BaselineUNet(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(256, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)
        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)

# Loss Function
class HybridLoss(nn.Module):
    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha
    def forward(self, pred, target):
        bce_loss = nn.functional.binary_cross_entropy_with_logits(pred, target)
        pred_sigmoid = torch.sigmoid(pred)
        intersection = (pred_sigmoid * target).sum()
        dice_loss = 1 - (2. * intersection + 1e-5) / (pred_sigmoid.sum() + target.sum() + 1e-5)
        total_loss = self.alpha * dice_loss + (1 - self.alpha) * bce_loss
        return total_loss, bce_loss, dice_loss

# Early Stopping
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
        self.best_state = None
    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_state = model.state_dict()
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Training Function with Loss Components
def train(model, train_loader, val_loader, epochs, device, model_name):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)  # Removed verbose
    early_stopping = EarlyStopping(patience=5, min_delta=0.005)
    scaler = GradScaler('cuda')
    criterion = HybridLoss(alpha=0.8)

    train_losses, train_bce_losses, train_dice_losses = [], [], []
    val_losses, val_bce_losses, val_dice_losses, val_dices = [], [], [], []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_bce_loss = 0.0
        epoch_dice_loss = 0.0
        train_preds, train_targets = [], []

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            with autocast('cuda'):
                outputs = model(inputs)
                total_loss, bce_loss, dice_loss = criterion(outputs, targets)
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += total_loss.item()
            epoch_bce_loss += bce_loss.item()
            epoch_dice_loss += dice_loss.item()

            preds = torch.sigmoid(outputs) > 0.5
            train_preds.extend(preds.cpu().numpy().flatten())
            train_targets.extend(targets.cpu().numpy().flatten())

        val_loss, val_bce, val_dice, val_dice_score = validate(model, val_loader, device, criterion)
        train_losses.append(epoch_loss / len(train_loader))
        train_bce_losses.append(epoch_bce_loss / len(train_loader))
        train_dice_losses.append(epoch_dice_loss / len(train_loader))
        val_losses.append(val_loss)
        val_bce_losses.append(val_bce)
        val_dice_losses.append(val_dice)
        val_dices.append(val_dice_score)

        print(f"{model_name} Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_losses[-1]:.4f} (BCE: {train_bce_losses[-1]:.4f}, Dice: {train_dice_losses[-1]:.4f}) | "
              f"Val Loss: {val_loss:.4f} (BCE: {val_bce:.4f}, Dice: {val_dice:.4f}) | Val Dice Score: {val_dice_score:.4f}")

        scheduler.step(val_dice_score)
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"{model_name} Early stopping triggered")
            model.load_state_dict(early_stopping.best_state)
            break
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f'{model_name}_best_model.pth'))

    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f'{model_name}_final.pth'))
    return model, train_losses, train_bce_losses, train_dice_losses, val_losses, val_bce_losses, val_dice_losses, val_dices

# Validation Function with Loss Components
def validate(model, loader, device, criterion):
    model.eval()
    val_loss = 0.0
    val_bce = 0.0
    val_dice = 0.0
    total_dice_score = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            with autocast('cuda'):
                outputs = model(inputs)
                total_loss, bce_loss, dice_loss = criterion(outputs, targets)
            preds = torch.sigmoid(outputs)
            intersection = (preds * targets).sum()
            dice_score = (2. * intersection) / (preds.sum() + targets.sum() + 1e-5)
            val_loss += total_loss.item()
            val_bce += bce_loss.item()
            val_dice += dice_loss.item()
            total_dice_score += dice_score.item()
    return (val_loss / len(loader), val_bce / len(loader), val_dice / len(loader), total_dice_score / len(loader))

# Testing Function with Metrics and Loss Components
def test(model, test_loader, device, criterion, model_name):
    model.eval()
    test_loss = 0.0
    test_bce = 0.0
    test_dice = 0.0
    test_dice_score = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss, bce_loss, dice_loss = criterion(outputs, targets)
            preds = torch.sigmoid(outputs) > 0.5

            test_loss += total_loss.item()
            test_bce += bce_loss.item()
            test_dice += dice_loss.item()
            intersection = (preds * targets).sum()
            dice_score = (2. * intersection) / (preds.sum() + targets.sum() + 1e-5)
            test_dice_score += dice_score.item()

            all_preds.extend(preds.cpu().numpy().flatten())
            all_targets.extend(targets.cpu().numpy().flatten())

    test_loss /= len(test_loader)
    test_bce /= len(test_loader)
    test_dice /= len(test_loader)
    test_dice_score /= len(test_loader)
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    cm = confusion_matrix(all_targets, all_preds)
    tn, fp, fn, tp = cm.ravel()
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    print(f"{model_name} Test Loss: {test_loss:.4f} (BCE: {test_bce:.4f}, Dice: {test_dice:.4f}) | "
          f"Test Dice Score: {test_dice_score:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | "
          f"F1: {f1:.4f} | IoU: {iou:.4f} | Specificity: {specificity:.4f}")
    print(f"{model_name} Confusion Matrix: TN={tn}, FP={fp}, FN={fn}, TP={tp}")

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'{model_name} Test Set Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(os.path.join(OUTPUT_DIR, f'{model_name}_test_confusion_matrix.png'))
    plt.close()

    return test_loss, test_bce, test_dice, test_dice_score, precision, recall, f1, iou, specificity, cm

# Visualization Function
def visualize_test_results(model, test_loader, device, model_name, num_examples=3):
    model.eval()
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            if i >= num_examples:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            preds = torch.sigmoid(outputs) > 0.5

            fig, axes = plt.subplots(1, 4 if model_name == "KNetSeg" else 2, figsize=(20 if model_name == "KNetSeg" else 10, 5))
            if model_name == "KNetSeg":
                axes[0].imshow(inputs[0, 0].cpu().numpy(), cmap='gray')
                axes[0].set_title('Noisy Input')
                axes[1].imshow(inputs[0, 1].cpu().numpy(), cmap='gray')
                axes[1].set_title('Denoised Input')
                axes[2].imshow(inputs[0, 2].cpu().numpy(), cmap='gray')
                axes[2].set_title('Edge Input')
                axes[3].imshow(preds[0, 0].cpu().numpy(), cmap='gray')
                axes[3].set_title('Predicted Mask')
            else:
                axes[0].imshow(inputs[0, 0].cpu().numpy(), cmap='gray')
                axes[0].set_title('Noisy Input')
                axes[1].imshow(preds[0, 0].cpu().numpy(), cmap='gray')
                axes[1].set_title('Predicted Mask')
            for ax in axes:
                ax.axis('off')
            plt.savefig(os.path.join(OUTPUT_DIR, f'{model_name}_test_example_{i+1}.png'))
            plt.close()

# Train and Test Both Models
EPOCHS = 50

# KNetSeg
set_seed(42)
knet_model = KNetSeg(in_channels=3).to(DEVICE)
knet_trained, knet_train_losses, knet_train_bce_losses, knet_train_dice_losses, knet_val_losses, knet_val_bce_losses, knet_val_dice_losses, knet_val_dices = train(
    knet_model, train_loader_knet, val_loader_knet, EPOCHS, DEVICE, "KNetSeg"
)
knet_model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, 'KNetSeg_best_model.pth'), weights_only=True))
knet_test_results = test(knet_model, test_loader_knet, DEVICE, HybridLoss(), "KNetSeg")
visualize_test_results(knet_model, test_loader_knet, DEVICE, "KNetSeg")

# BaselineUNet
set_seed(42)
base_model = BaselineUNet(in_channels=1).to(DEVICE)
base_trained, base_train_losses, base_train_bce_losses, base_train_dice_losses, base_val_losses, base_val_bce_losses, base_val_dice_losses, base_val_dices = train(
    base_model, train_loader_base, val_loader_base, EPOCHS, DEVICE, "BaselineUNet"
)
base_model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, 'BaselineUNet_best_model.pth'), weights_only=True))
base_test_results = test(base_model, test_loader_base, DEVICE, HybridLoss(), "BaselineUNet")
visualize_test_results(base_model, test_loader_base, DEVICE, "BaselineUNet")

# Comparison Plots
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(knet_train_losses) + 1), knet_train_losses, label='KNetSeg Train Loss')
plt.plot(range(1, len(base_train_losses) + 1), base_train_losses, label='BaselineUNet Train Loss')
plt.plot(range(1, len(knet_val_losses) + 1), knet_val_losses, label='KNetSeg Val Loss')
plt.plot(range(1, len(base_val_losses) + 1), base_val_losses, label='BaselineUNet Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Comparison')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(OUTPUT_DIR, 'loss_comparison.png'))
plt.close()

plt.figure(figsize=(12, 6))
plt.plot(range(1, len(knet_val_dices) + 1), knet_val_dices, label='KNetSeg Val Dice')
plt.plot(range(1, len(base_val_dices) + 1), base_val_dices, label='BaselineUNet Val Dice')
plt.xlabel('Epoch')
plt.ylabel('Dice Score')
plt.title('Validation Dice Score Comparison')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(OUTPUT_DIR, 'dice_comparison.png'))
plt.close()

# Plot Loss Components
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(knet_train_bce_losses) + 1), knet_train_bce_losses, label='KNetSeg Train BCE')
plt.plot(range(1, len(knet_train_dice_losses) + 1), knet_train_dice_losses, label='KNetSeg Train Dice')
plt.plot(range(1, len(base_train_bce_losses) + 1), base_train_bce_losses, label='BaselineUNet Train BCE')
plt.plot(range(1, len(base_train_dice_losses) + 1), base_train_dice_losses, label='BaselineUNet Train Dice')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training BCE and Dice Loss Comparison')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(OUTPUT_DIR, 'train_loss_components.png'))
plt.close()

plt.figure(figsize=(12, 6))
plt.plot(range(1, len(knet_val_bce_losses) + 1), knet_val_bce_losses, label='KNetSeg Val BCE')
plt.plot(range(1, len(knet_val_dice_losses) + 1), knet_val_dice_losses, label='KNetSeg Val Dice')
plt.plot(range(1, len(base_val_bce_losses) + 1), base_val_bce_losses, label='BaselineUNet Val BCE')
plt.plot(range(1, len(base_val_dice_losses) + 1), base_val_dice_losses, label='BaselineUNet Val Dice')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Validation BCE and Dice Loss Comparison')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(OUTPUT_DIR, 'val_loss_components.png'))
plt.close()

print(f"Training and testing complete for both models. Results and comparisons saved in {OUTPUT_DIR}")

Mounted at /content/drive
KNetSeg Epoch 1/50 | Train Loss: 0.7213 (BCE: 0.1729, Dice: 0.8584) | Val Loss: 0.4750 (BCE: 0.0248, Dice: 0.5875) | Val Dice Score: 0.4125
KNetSeg Epoch 2/50 | Train Loss: 0.2229 (BCE: 0.0109, Dice: 0.2759) | Val Loss: 0.1280 (BCE: 0.0082, Dice: 0.1579) | Val Dice Score: 0.8421
KNetSeg Epoch 3/50 | Train Loss: 0.1040 (BCE: 0.0077, Dice: 0.1281) | Val Loss: 0.1112 (BCE: 0.0093, Dice: 0.1367) | Val Dice Score: 0.8633
KNetSeg Epoch 4/50 | Train Loss: 0.0835 (BCE: 0.0070, Dice: 0.1027) | Val Loss: 0.0775 (BCE: 0.0073, Dice: 0.0950) | Val Dice Score: 0.9050
KNetSeg Epoch 5/50 | Train Loss: 0.0690 (BCE: 0.0066, Dice: 0.0845) | Val Loss: 0.0825 (BCE: 0.0099, Dice: 0.1006) | Val Dice Score: 0.8994
KNetSeg Epoch 6/50 | Train Loss: 0.0625 (BCE: 0.0061, Dice: 0.0767) | Val Loss: 0.0638 (BCE: 0.0068, Dice: 0.0780) | Val Dice Score: 0.9220
KNetSeg Epoch 7/50 | Train Loss: 0.0557 (BCE: 0.0062, Dice: 0.0681) | Val Loss: 0.0540 (BCE: 0.0059, Dice: 0.0660) | Val Dice Score: 0