In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
class WildfireDataset(Dataset):
    """Dataset of (image, binary mask) pairs for wildfire segmentation.

    Images are loaded as RGB and masks as single-channel grayscale. The i-th
    image is paired with the i-th mask, so both sources must line up by index.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        """
        image_dir: directory with input images, OR an ordered sequence of image file paths.
        mask_dir: directory with segmentation masks, OR an ordered sequence of mask
            file paths aligned index-by-index with the images.
        transform: torchvision transforms to apply (same for image and mask).
            It must end by producing a tensor, because the mask is thresholded
            at 0.5 afterwards.

        Raises ValueError when the number of images and masks differ.
        """
        self.image_dir, self.image_list = self._resolve_source(image_dir)
        self.mask_dir, self.mask_list = self._resolve_source(mask_dir)
        # Pairing is purely positional, so a count mismatch means misaligned pairs.
        if len(self.image_list) != len(self.mask_list):
            raise ValueError(
                f"Number of images ({len(self.image_list)}) does not match "
                f"number of masks ({len(self.mask_list)})"
            )
        self.transform = transform

    @staticmethod
    def _resolve_source(source):
        """Return (base_dir, entries) for a directory path or a sequence of file paths."""
        if isinstance(source, (str, os.PathLike)):
            # Directory mode: sorted listing gives a deterministic order.
            return source, sorted(os.listdir(source))
        # Sequence mode: keep the caller's order, it encodes the image/mask pairing.
        return None, [os.fspath(path) for path in source]

    @staticmethod
    def _full_path(base_dir, entry):
        """Join `entry` onto `base_dir`; entries from an explicit list are already full paths."""
        return entry if base_dir is None else os.path.join(base_dir, entry)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_path = self._full_path(self.image_dir, self.image_list[idx])
        mask_path = self._full_path(self.mask_dir, self.mask_list[idx])
        # Context managers release the file handles as soon as the pixels are converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')  # grayscale for binary mask

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)
        else:
            image = transforms.ToTensor()(image)
            mask = transforms.ToTensor()(mask)

        # Binarize the mask: values above 0.5 (after conversion to a tensor)
        # are treated as the wildfire risk area.
        mask = (mask > 0.5).float()
        return image, mask

In [3]:
# Residual Convolutional Block
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU stages plus an additive shortcut.

    The shortcut is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main branch's channels.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = F.relu(self.bn2(hidden))
        # The shortcut is added after the final ReLU.
        return hidden + self.shortcut(x)

# Attention Gate for skip connections
class AttentionGate(nn.Module):
    """Additive attention gate that rescales skip features with a sigmoid map."""

    def __init__(self, F_g, F_l, F_int):
        """
        F_g: Number of channels in the gating (decoder) signal.
        F_l: Number of channels in the skip (encoder) connection.
        F_int: Number of intermediate channels.
        """
        super().__init__()

        def project(channels):
            # 1x1 conv + batch-norm mapping `channels` to the intermediate width.
            return nn.Sequential(
                nn.Conv2d(channels, F_int, kernel_size=1),
                nn.BatchNorm2d(F_int),
            )

        self.W_g = project(F_g)
        self.W_x = project(F_l)
        # Collapse to one channel and squash to (0, 1) to get the attention map.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

    def forward(self, g, x):
        # g: gating signal (from decoder), x: features from encoder
        combined = F.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        # The single-channel map broadcasts across all channels of x.
        return x * attention

In [4]:
# U-Net with Residual Blocks and Attention Gates
class UNet(nn.Module):
    """U-Net built from ResidualConvBlock stages with attention-gated skip connections.

    in_channels: channels of the input image.
    out_channels: channels of the output logit map (no activation is applied).
    features: channel widths of the encoder stages, shallowest first. The
        bottleneck uses twice the last width and the decoder mirrors the encoder.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Work on a private copy so neither the default nor a caller's list is mutated.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder path
        current_channels = in_channels
        for feature in features:
            self.downs.append(ResidualConvBlock(current_channels, feature))
            current_channels = feature

        # Bottleneck
        self.bottleneck = ResidualConvBlock(features[-1], features[-1] * 2)
        current_channels = features[-1] * 2

        # Decoder path: self.ups alternates [up-conv, conv block, up-conv, conv block, ...]
        self.ups = nn.ModuleList()
        self.attention_gates = nn.ModuleList()
        for feature in reversed(features):
            # Up-convolution takes whatever the previous stage actually produced.
            self.ups.append(nn.ConvTranspose2d(current_channels, feature, kernel_size=2, stride=2))
            # Attention gate: gating signal channels = feature, skip connection channels = feature
            self.attention_gates.append(AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2))
            # Decoder conv block; note that concatenation doubles channels (skip + upsampled)
            self.ups.append(ResidualConvBlock(feature * 2, feature))
            current_channels = feature

        # Final 1x1 conv to produce segmentation map
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        # Encoder: keep the pre-pooling activations for the skip connections.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        # Bottleneck
        x = self.bottleneck(x)
        # Decoder consumes skips deepest-first.
        stages = zip(self.ups[0::2], self.ups[1::2], self.attention_gates, reversed(skip_connections))
        for upsample, conv_block, gate, skip_connection in stages:
            x = upsample(x)
            # Pooling floors odd sizes, so the upsampled map can be smaller than
            # the skip; resize to match before gating and concatenation.
            if x.shape[-2:] != skip_connection.shape[-2:]:
                x = F.interpolate(x, size=skip_connection.shape[-2:], mode="bilinear", align_corners=False)
            # Apply attention gate to encoder features
            attn = gate(x, skip_connection)
            # Concatenate along channel dimension
            x = torch.cat((attn, x), dim=1)
            x = conv_block(x)
        return self.final_conv(x)

In [5]:
def iou_metric(pred, target, threshold=0.5, eps=1e-6):
    """Batch-mean intersection-over-union of thresholded 4-D tensors.

    Both `pred` and `target` are binarized with `> threshold`; the ratio is
    computed per sample over dims 1-3 and then averaged. `eps` is added to
    numerator and denominator, so two empty masks score 1.
    """
    reduce_dims = (1, 2, 3)
    pred_bin = pred.gt(threshold).float()
    target_bin = target.gt(threshold).float()
    overlap = torch.sum(pred_bin * target_bin, dim=reduce_dims)
    # For 0/1 values the element-wise maximum is the set union.
    combined = torch.sum(torch.maximum(pred_bin, target_bin), dim=reduce_dims)
    per_sample = (overlap + eps) / (combined + eps)
    return torch.mean(per_sample)

def dice_coefficient(pred, target, threshold=0.5, eps=1e-6):
    """Batch-mean Dice score of thresholded 4-D tensors.

    Both inputs are binarized with `> threshold`. Per sample the score is
    (2 * |A ∩ B| + eps) / (|A| + |B| + eps), reduced over dims 1-3, and the
    per-sample scores are averaged.
    """
    reduce_dims = (1, 2, 3)
    pred_bin = pred.gt(threshold).float()
    target_bin = target.gt(threshold).float()
    overlap = torch.sum(pred_bin * target_bin, dim=reduce_dims)
    pred_area = torch.sum(pred_bin, dim=reduce_dims)
    target_area = torch.sum(target_bin, dim=reduce_dims)
    per_sample = (2 * overlap + eps) / (pred_area + target_area + eps)
    return torch.mean(per_sample)

def dice_loss(pred, target, threshold=0.5, eps=1e-6):
    """Soft Dice loss: 1 - batch-mean Dice overlap between `pred` and `target`.

    pred: 4-D tensor of probabilities in [0, 1]. It is used as-is, NOT
        thresholded, so the loss stays differentiable with respect to `pred`.
    target: 4-D ground-truth tensor, binarized with `> threshold`.
    threshold: cut-off applied to `target` only.
    eps: smoothing term added to numerator and denominator.

    Returns a scalar tensor.
    """
    reduce_dims = (1, 2, 3)
    target = (target > threshold).to(pred.dtype)
    # Thresholding `pred` here would cut the autograd graph and leave this
    # term with zero gradient, so the raw probabilities are used instead.
    intersection = (pred * target).sum(dim=reduce_dims)
    denominator = pred.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()

class BCEDiceLoss(nn.Module):
    """Sum of binary cross-entropy (computed on logits) and `dice_loss`."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        """
        pred: raw logits from the model.
        target: ground-truth mask with the same shape as `pred`.
        """
        cross_entropy = self.bce(pred, target)
        # dice_loss is given probabilities, so the logits go through a sigmoid first.
        probabilities = torch.sigmoid(pred)
        return cross_entropy + dice_loss(probabilities, target)

In [6]:
def visualize_results(image, true_mask, pred_mask, prob_map):
    """
    Draws a three-panel figure:
    1. A heatmap of the predicted probability map.
    2. An overlay of the predicted binary mask on the original image.
    3. An overlay of the ground truth mask on the original image.

    image: Tensor of shape (C, H, W)
    true_mask: Tensor of shape (1, H, W)
    pred_mask: Tensor of shape (1, H, W)
    prob_map: Tensor of shape (1, H, W)
    """
    # Channels-last layout is what imshow expects for colour images.
    image_np = image.permute(1, 2, 0).cpu().numpy()

    def as_plane(tensor):
        # Drop singleton dims so the map is a plain 2-D array.
        return tensor.squeeze().cpu().numpy()

    def tinted(mask_plane, channel):
        # 50/50 blend of the image with a pure colour wherever the mask is set.
        colour = np.zeros_like(image_np)
        colour[..., channel] = 1
        blended = 0.5 * image_np + 0.5 * colour
        return np.where(mask_plane[..., None] > 0.5, blended, image_np)

    fig, (ax_heat, ax_pred, ax_true) = plt.subplots(1, 3, figsize=(15, 4))

    # Heatmap
    heat = ax_heat.imshow(as_plane(prob_map), cmap='hot')
    ax_heat.set_title('Prediction Heatmap')
    fig.colorbar(heat, ax=ax_heat)

    # Overlay prediction: red for predicted risk area
    ax_pred.imshow(tinted(as_plane(pred_mask), 0))
    ax_pred.set_title('Overlayed Prediction')

    # Overlay ground truth: green for true wildfire regions
    ax_true.imshow(tinted(as_plane(true_mask), 1))
    ax_true.set_title('Ground Truth Overlay')

    fig.tight_layout()
    plt.show()

In [7]:
def train_model(model, train_loader, val_loader, num_epochs, device,
                lr=1e-3, checkpoint_path="best_model.pth"):
    """Train `model` with Adam and BCEDiceLoss, checkpointing the best validation IoU.

    model: network producing logits; must already be on `device`.
    train_loader / val_loader: loaders yielding (images, masks) batches.
    num_epochs: number of passes over `train_loader`.
    device: device the batches are moved to.
    lr: Adam learning rate.
    checkpoint_path: where the best `state_dict` is written.

    Returns the best validation IoU observed (-inf if no epoch ran).
    Raises ValueError if either loader yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty")
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = BCEDiceLoss()
    # Start below any reachable IoU so the first epoch always writes a
    # checkpoint, even when validation IoU is exactly 0.
    best_val_iou = float("-inf")

    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        train_samples = 0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch is not over-represented.
            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            train_samples += batch_size

        avg_loss = loss_sum / max(train_samples, 1)

        model.eval()
        iou_sum = 0.0
        dice_sum = 0.0
        val_samples = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                probs = torch.sigmoid(model(images))
                batch_size = images.size(0)
                # Both metrics return a batch mean; rescale to a per-sample sum.
                iou_sum += iou_metric(probs, masks).item() * batch_size
                dice_sum += dice_coefficient(probs, masks).item() * batch_size
                val_samples += batch_size
        avg_val_iou = iou_sum / max(val_samples, 1)
        avg_val_dice = dice_sum / max(val_samples, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Val IoU: {avg_val_iou:.4f}, Val Dice: {avg_val_dice:.4f}")

        # Save best model
        if avg_val_iou > best_val_iou:
            best_val_iou = avg_val_iou
            torch.save(model.state_dict(), checkpoint_path)

    print("Training complete. Best Val IoU:", best_val_iou)
    return best_val_iou

In [None]:
def _ensure_gdown():
    """Import gdown, installing it with the running interpreter's pip if it is missing."""
    try:
        import gdown
    except ImportError:
        import subprocess
        import sys
        # `sys.executable -m pip` targets the environment this kernel runs in.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
        import gdown
    return gdown


def _read_pairs(listing_path):
    """Parse `listing_path` into parallel lists of image paths and mask paths.

    Each non-empty line is expected to hold "image_path mask_path", separated
    by whitespace or by a comma. Lines with fewer than two fields are skipped.
    """
    image_paths = []
    mask_paths = []
    with open(listing_path, "r") as listing:
        for line in listing:
            line = line.strip()
            if not line:
                continue
            if "," in line:
                parts = [part.strip() for part in line.split(",")]
            else:
                parts = line.split()
            if len(parts) >= 2:
                image_paths.append(parts[0])
                mask_paths.append(parts[1])
    return image_paths, mask_paths


def _stage_file(src, dest_dir):
    """Move `src` into `dest_dir` unless a file with that basename is already there.

    Returns the destination path. Skipping existing files makes a re-run safe
    after the sources have already been moved.
    """
    import shutil
    dest = os.path.join(dest_dir, os.path.basename(src))
    if not os.path.exists(dest):
        shutil.move(src, dest)  # also works across filesystems
    return dest


def main(num_epochs=20, batch_size=4, image_size=256, seed=42):
    """Download the data, train the U-Net, and visualize predictions on the test split.

    num_epochs: training epochs (adjust as needed).
    batch_size: batch size for the training and validation loaders.
    image_size: images and masks are resized to (image_size, image_size).
    seed: seed for the train/val/test split and for torch's RNG.

    NOTE(review): the datasets below are built from lists of file paths, so
    WildfireDataset must accept path sequences and not only directory names.
    """
    import zipfile
    from sklearn.model_selection import train_test_split

    dataset_txt_url = "https://drive.google.com/uc?id=1MVmx927A4AUZksHqb3liW4lUf5ts9l6v"
    labels_zip_url = "https://drive.google.com/uc?id=1Gq0VXTElJWWfCcwnxLxA5CVWuls0is-y"

    # gdown is only needed (and only installed) when something must be downloaded.
    if not os.path.exists("dataset.txt"):
        _ensure_gdown().download(dataset_txt_url, output="dataset.txt", quiet=False)
    if not os.path.exists("labels.zip"):
        _ensure_gdown().download(labels_zip_url, output="labels.zip", quiet=False)
        # Extract with the standard library; existing files are overwritten.
        with zipfile.ZipFile("labels.zip") as archive:
            archive.extractall(".")

    image_paths, mask_paths = _read_pairs("dataset.txt")
    if not image_paths:
        raise RuntimeError("dataset.txt contains no 'image_path mask_path' entries")

    os.makedirs("data/images", exist_ok=True)
    os.makedirs("data/masks", exist_ok=True)
    # Move files into the data directories and point the path lists at the new locations.
    image_paths = [_stage_file(path, "data/images") for path in image_paths]
    mask_paths = [_stage_file(path, "data/masks") for path in mask_paths]

    # First split off a test set from the full dataset
    train_imgs, test_imgs, train_masks, test_masks = train_test_split(
        image_paths, mask_paths, test_size=0.2, random_state=seed)

    # From the remaining training data, split out a validation set
    train_imgs, val_imgs, train_masks, val_masks = train_test_split(
        train_imgs, train_masks, test_size=0.1, random_state=seed)

    # Define transforms: resize images/masks and convert to tensor.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor()
    ])

    # Create datasets and dataloaders
    train_dataset = WildfireDataset(train_imgs, train_masks, transform=transform)
    val_dataset = WildfireDataset(val_imgs, val_masks, transform=transform)
    test_dataset = WildfireDataset(test_imgs, test_masks, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Seed torch so weight initialization and shuffling are repeatable.
    torch.manual_seed(seed)

    # Initialize the model
    model = UNet(in_channels=3, out_channels=1)
    model = model.to(device)

    # Train the model
    train_model(model, train_loader, val_loader, num_epochs, device)

    # Load best model checkpoint for evaluation; map_location keeps this
    # working when the checkpoint was written from a different device.
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    model.eval()

    # Evaluate on test set and visualize results
    test_iou = 0.0
    test_dice = 0.0
    test_samples = 0
    with torch.no_grad():
        for idx, (image, mask) in enumerate(test_loader):
            image = image.to(device)
            mask = mask.to(device)
            output = model(image)
            prob_map = torch.sigmoid(output)
            pred_mask = (prob_map > 0.5).float()
            test_iou += iou_metric(prob_map, mask).item() * image.size(0)
            test_dice += dice_coefficient(prob_map, mask).item() * image.size(0)
            test_samples += image.size(0)
            print(f"Visualizing Test Image {idx+1}")
            visualize_results(image[0].cpu(), mask[0].cpu(), pred_mask[0].cpu(), prob_map[0].cpu())
    if test_samples:
        print(f"Test IoU: {test_iou / test_samples:.4f}, Test Dice: {test_dice / test_samples:.4f}")

if __name__ == "__main__":
    main()