# In [None]:
# Cavity Detection System in Colab - UNET + Training + Testing
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# ----------------------------- UNet Model -----------------------------
class UNet(nn.Module):
    """Two-level U-Net that maps an image batch to per-pixel logits.

    The encoder has three convolution stages (64, 128, 256 channels) separated
    by 2x2 max-pooling; the decoder upsamples twice with transposed
    convolutions and concatenates the matching encoder output (skip
    connection) before each decoder stage. No activation is applied to the
    final 1x1 convolution, so the output is raw logits.

    Args:
        in_channels: Number of channels in the input images (default 3).
        out_channels: Number of channels in the output map (default 1).

    Submodule names are part of the checkpoint format (``state_dict`` keys),
    so they must not be renamed.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        def conv_block(block_in, block_out):
            # Two 3x3 convolutions with padding=1: channel count changes,
            # spatial size does not.
            return nn.Sequential(
                nn.Conv2d(block_in, block_out, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(block_out, block_out, 3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Decoder stages take 2x the upconv output channels because the skip
        # connection is concatenated along the channel axis.
        self.upconv3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = conv_block(256, 128)

        self.upconv2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = conv_block(128, 64)

        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the stride-2 transposed
        convolution the upsampled map can be one pixel short of the skip
        tensor. When the sizes already agree the tensor is returned untouched.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad takes padding for the last dimension first: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x):
        """Return logits with the same height and width as ``x``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        d3 = self._match_size(self.upconv3(e3), e2)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3(d3)

        d2 = self._match_size(self.upconv2(d3), e1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2(d2)

        out = self.final(d2)
        return out

# ----------------------------- Dataset -----------------------------
class CavityDataset(Dataset):
    """Image/mask pairs read from ``root_dir`` and ``root_dir/labels``.

    Every non-directory entry of ``root_dir`` is a candidate image. A candidate
    is kept only when a label file is found for it in ``root_dir/labels``:

    1. a label whose name without extension equals the image's name without
       extension (``img1.png`` -> ``img1.png`` / ``img1.jpg``), otherwise
    2. the first label (in sorted order) that starts with the image's base
       name followed by a non-alphanumeric character
       (``img1.png`` -> ``img1_mask.png``, but never ``img10_mask.png``).

    Args:
        root_dir: Directory holding the images and the ``labels`` sub-directory.
        transform: Optional callable applied to each RGB image.
        label_transform: Optional callable applied to each grayscale label.
            When omitted, ``transform`` is applied to the label as well.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist (from ``os.listdir``).
    """

    def __init__(self, root_dir, transform=None, label_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Falling back to the image transform keeps two-argument callers unchanged.
        self.label_transform = transform if label_transform is None else label_transform
        self.labels_dir = os.path.join(root_dir, 'labels')

        candidates = sorted(
            name for name in os.listdir(root_dir)
            if not os.path.isdir(os.path.join(root_dir, name)) and name != 'labels'
        )

        if os.path.isdir(self.labels_dir):
            # Sub-directories (e.g. notebook checkpoint folders) cannot be
            # opened as images, so they are not label candidates.
            label_names = sorted(
                name for name in os.listdir(self.labels_dir)
                if not os.path.isdir(os.path.join(self.labels_dir, name))
            )
        else:
            print(f"Warning: Labels directory not found at {self.labels_dir}")
            label_names = []

        # Stem -> label name; setdefault keeps the first name in sorted order
        # when several labels share a stem.
        labels_by_stem = {}
        for name in label_names:
            labels_by_stem.setdefault(os.path.splitext(name)[0], name)

        self.image_paths = []
        self.label_paths = []
        for img_name in candidates:
            img_base = os.path.splitext(img_name)[0]
            label_name = self._find_label(img_base, label_names, labels_by_stem)
            if label_name is not None:
                self.image_paths.append(img_name)
                self.label_paths.append(label_name)

        print(f"Loaded {len(self.image_paths)} image-label pairs from {root_dir}")

    @staticmethod
    def _find_label(img_base, label_names, labels_by_stem):
        """Return the label file name for ``img_base``, or ``None`` if none matches."""
        exact = labels_by_stem.get(img_base)
        if exact is not None:
            return exact
        base_len = len(img_base)
        for name in label_names:
            # Require a separator right after the prefix so "img1" does not
            # capture "img10_mask.png".
            if (name.startswith(img_base) and len(name) > base_len
                    and not name[base_len].isalnum()):
                return name
        return None

    def __len__(self):
        """Return the number of matched image/label pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, label)`` after the transforms."""
        img_path = os.path.join(self.root_dir, self.image_paths[idx])
        label_path = os.path.join(self.labels_dir, self.label_paths[idx])

        # convert() returns an independent image, so the file handles can be
        # released as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        with Image.open(label_path) as label_file:
            label = label_file.convert('L')

        if self.transform is not None:
            image = self.transform(image)
        if self.label_transform is not None:
            label = self.label_transform(label)

        return image, label

# ----------------------------- Training Setup -----------------------------
# Preprocessing shared by every split: resize to 128x128, then convert to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

# One dataset per split; each constructor reports how many pairs it found.
train_dataset = CavityDataset(root_dir='/content/dataset/train', transform=transform)
valid_dataset = CavityDataset(root_dir='/content/dataset/valid', transform=transform)
test_dataset = CavityDataset(root_dir='/content/dataset/test', transform=transform)

# Only the training split is shuffled; the test loader yields one sample at a
# time so each prediction can be scored and plotted individually.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=5, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

model = UNet().to(device)
# The model outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# ----------------------------- Training Loop -----------------------------
num_epochs = 15
best_val_loss = float('inf')

# Both averages below divide by the number of batches. Checking up front gives
# a clear error instead of a ZeroDivisionError after a full training epoch.
if len(train_loader) == 0 or len(valid_loader) == 0:
    raise ValueError(
        "Training and validation loaders must each yield at least one batch "
        f"(train batches: {len(train_loader)}, valid batches: {len(valid_loader)})"
    )

print("Starting training...")

for epoch in range(num_epochs):
    # ---- optimisation pass over the training split ----
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # ---- evaluation pass over the validation split (no gradients) ----
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for images, masks in valid_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, masks).item()

    # Mean of the per-batch losses.
    val_loss /= len(valid_loader)

    # Keep only the weights with the lowest validation loss seen so far; the
    # evaluation code reloads this file.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), '/content/unet_cavity_best.pth')
        print(f"Epoch {epoch+1}: ✅ Saved new best model")

    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {running_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}")

# ----------------------------- Testing + Visualization -----------------------------
# map_location lets a checkpoint written on the GPU be restored on a CPU-only
# runtime (and vice versa).
model.load_state_dict(torch.load('/content/unet_cavity_best.pth', map_location=device))
model.eval()

test_dice = 0.0
num_samples = 0
max_plots = 4    # rows in the figure; only the first few samples are drawn
dice_eps = 1e-6  # smoothing term; also makes an empty-vs-empty pair score 1

plt.figure(figsize=(15, 10))

with torch.no_grad():
    for idx, (images, masks) in enumerate(test_loader):
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)

        # The test loader uses batch_size=1, so squeeze() drops the batch and
        # channel axes and leaves a 2-D map.
        pred = torch.sigmoid(outputs).squeeze().cpu().numpy()
        pred_binary = (pred > 0.5).astype(np.float32)
        mask = masks.squeeze().cpu().numpy()

        # Dice is accumulated for every test sample, not just the plotted
        # ones, so the printed average really covers the whole test set.
        intersection = np.sum(pred_binary * mask)
        dice = (2. * intersection + dice_eps) / (np.sum(pred_binary) + np.sum(mask) + dice_eps)
        test_dice += dice
        num_samples += 1

        if idx >= max_plots:
            continue

        img = images.squeeze().permute(1, 2, 0).cpu().numpy()
        # Min-max normalise for display; a constant image would divide by zero.
        value_range = img.max() - img.min()
        if value_range > 0:
            img = (img - img.min()) / value_range
        else:
            img = np.zeros_like(img)

        plt.subplot(max_plots, 3, idx * 3 + 1)
        plt.imshow(img)
        plt.title(f"Image {idx+1}")
        plt.axis('off')

        plt.subplot(max_plots, 3, idx * 3 + 2)
        plt.imshow(mask, cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')

        plt.subplot(max_plots, 3, idx * 3 + 3)
        plt.imshow(pred_binary, cmap='gray')
        plt.title(f"Prediction (Dice: {dice:.3f})")
        plt.axis('off')

plt.tight_layout()
plt.savefig("/content/cavity_results.png")
plt.show()

if num_samples:
    print(f"✅ Average Dice coefficient on test set: {test_dice / num_samples:.4f}")
else:
    print("Warning: test set is empty; no Dice coefficient computed")
