In [1]:
try:
    # Essential imports
    import re
    import os
    import sys
    import json
    import glob
    import torch
    import shutil
    import random
    import torch.nn as nn
    import logging
    import rasterio
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import geopandas as gpd
    import matplotlib.pyplot as plt
    from PIL import Image
    from datetime import datetime
    from pathlib import Path
    from scipy import stats
    from typing import Dict, Tuple, Optional
    from collections import defaultdict
    from rasterio import plot
    from rasterio.mask import mask
    from shapely.ops import unary_union
    from shapely.wkt import dumps, loads
    from shapely.geometry import mapping, box, Polygon, MultiPolygon
    from rasterio.windows import from_bounds, bounds as window_bounds
    from s2cloudless import S2PixelCloudDetector
    from tqdm import tqdm
    from tqdm.notebook import tqdm
    from sklearn.model_selection import TimeSeriesSplit
except Exception as e:
    print(f"Error : {e}")

In [2]:
# Print the PyTorch version
print(f"PyTorch version: {torch.__version__}")

# Check if running in Google Colab
if "google.colab" in str(get_ipython()):
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = 'cpu'
        print("GPU not available in Colab, consider enabling a GPU runtime.")
# Running on a local machine
else:
    if torch.backends.mps.is_available():
        device = 'mps'
        print(f"Is Apple MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
        print(f"Is Apple MPS available? {torch.backends.mps.is_available()}")
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

# TODO: Add support for AMD ROCm GPU if needed

# Print the device being used
print(f"Using device: {device}")

PyTorch version: 2.6.0.dev20241112
Is Apple MPS (Metal Performance Shader) built? True
Is Apple MPS available? True
Using device: mps


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.conv(x)

class UNetDiff(nn.Module):
    def __init__(self):
        super().__init__()
        
        # ResNet-18 backbone
        resnet = models.resnet18(pretrained=True)
        
        # Encoder (modified ResNet-18)
        self.encoder1 = nn.Sequential(
            nn.Conv2d(27, 64, kernel_size=7, stride=2, padding=3, bias=False),
            resnet.bn1,
            resnet.relu
        )
        self.pool = resnet.maxpool
        self.encoder2 = resnet.layer1
        self.encoder3 = resnet.layer2
        
        # Decoder
        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder3 = ConvBlock(128, 64)
        
        self.upconv2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder2 = ConvBlock(96, 32)
        
        self.upconv1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.decoder1 = ConvBlock(80, 16)
        
        self.final_conv = nn.Conv2d(16, 1, kernel_size=1)
        
    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)  # 64 channels
        pool1 = self.pool(enc1)
        
        enc2 = self.encoder2(pool1)  # 64 channels
        enc3 = self.encoder3(enc2)  # 128 channels
        
        # Decoder with skip connections
        dec3 = self.upconv3(enc3)
        dec3 = torch.cat([dec3, enc2], dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc1], dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, x], dim=1)
        dec1 = self.decoder1(dec1)
        
        # Final 1x1 convolution
        out = self.final_conv(dec1)
        return torch.sigmoid(out)

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        
    def forward(self, predictions, targets):
        predictions = predictions.view(-1)
        targets = targets.view(-1)
        
        intersection = (predictions * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            predictions.sum() + targets.sum() + self.smooth
        )
        return 1 - dice

class CombinedLoss(nn.Module):
    def __init__(self, bce_weight=0.2, dice_weight=0.8):
        super().__init__()
        self.bce_loss = nn.BCELoss()
        self.dice_loss = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        
    def forward(self, predictions, targets):
        bce = self.bce_loss(predictions, targets)
        dice = self.dice_loss(predictions, targets)
        return self.bce_weight * bce + self.dice_weight * dice

def create_optimizer_and_scheduler(model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10, 40, 80, 150],
        gamma=0.1
    )
    
    return optimizer, scheduler

import albumentations as A
from torch.utils.data import Dataset, DataLoader

# Define augmentations
train_transform = A.Compose([
    A.RandomCrop(int(224 * 0.7), int(224 * 0.7)),
    A.Resize(224, 224),
    A.RandomBrightnessContrast(p=0.5),
    A.ElasticTransform(p=0.5),
    A.GridDistortion(p=0.5),
    A.CoarseDropout(max_holes=8, max_height=20, max_width=20, p=0.5)
])

def train_model(model, train_loader, val_loader, num_epochs=200):
    model = model.to(device)
    
    criterion = CombinedLoss()
    optimizer, scheduler = create_optimizer_and_scheduler(model)
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
        # Validation phase
        model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()
        
        scheduler.step()
        
        # Print epoch results
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_unet_diff.pth')

In [4]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path

class TemporalStackDataset(Dataset):
    def __init__(self, root_dir, transform=None, split='train'):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.split = split
        
        # Get all plot directories
        plot_dirs = sorted([d for d in self.root_dir.iterdir() if d.is_dir()])
        
        # Split train/val
        split_idx = int(0.8 * len(plot_dirs))
        self.plot_dirs = plot_dirs[:split_idx] if split == 'train' else plot_dirs[split_idx:]
    
    def __len__(self):
        return len(self.plot_dirs)
    
    def load_temporal_stack(self, stack_dir):
        npy_files = sorted(list(stack_dir.glob('*.npy')))
        if not npy_files:
            raise ValueError(f"No .npy files found in {stack_dir}")
        return np.concatenate([np.load(f) for f in npy_files], axis=-1)
    
    def __getitem__(self, idx):
        plot_dir = self.plot_dirs[idx]
        
        # Load stacks
        pre_event_dir = plot_dir / 'Pre-event'
        post_event_dir = plot_dir / 'Post-event'
        
        try:
            pre_stack = self.load_temporal_stack(pre_event_dir)
            post_stack = self.load_temporal_stack(post_event_dir)
            
            # Concatenate pre and post stacks
            x = np.concatenate([pre_stack, post_stack], axis=-1)
            
            # Create dummy mask if needed
            mask = np.zeros((224, 224, 1), dtype=np.float32)
            
            if self.transform:
                transformed = self.transform(image=x, mask=mask)
                x = transformed['image']
                mask = transformed['mask']
            
            # Convert to torch tensors
            x = torch.from_numpy(x).float().permute(2, 0, 1)
            mask = torch.from_numpy(mask).float()
            
            return x, mask
            
        except Exception as e:
            print(f"Error loading data from {plot_dir}: {str(e)}")
            raise

In [5]:
from pathlib import Path
import numpy as np

def inspect_data_shapes(root_dir):
    root = Path(root_dir)
    for plot_dir in root.glob('PLOT-*'):
        print(f"\nInspecting {plot_dir.name}")
        
        # Check pre-event
        pre_dir = plot_dir / 'Pre-event'
        pre_files = list(pre_dir.glob('*.npy'))
        if pre_files:
            pre_shape = np.load(pre_files[0]).shape
            print(f"Pre-event shape: {pre_shape}")
            
        # Check post-event    
        post_dir = plot_dir / 'Post-event'
        post_files = list(post_dir.glob('*.npy'))
        if post_files:
            post_shape = np.load(post_files[0]).shape
            print(f"Post-event shape: {post_shape}")

# Run inspection
inspect_data_shapes('../Datasets/Testing/TemporalStacks')


Inspecting PLOT-00026
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00019
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00021
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00017
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00028
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00044
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00043
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00011
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00016
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00029
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00020
Pre-event shape: (9, 224, 224)
Post-event shape: (9, 224, 224)

Inspecting PLOT-00027
Pre-event shape: (9, 224, 224)


In [6]:
# from dataset import TemporalStackDataset
import torch
from torch.utils.data import DataLoader

# Initialize datasets with fewer workers for debugging
train_dataset = TemporalStackDataset(
    root_dir='../Datasets/Testing/TemporalStacks',
    transform=train_transform,
    split='train'
)

val_dataset = TemporalStackDataset(
    root_dir='../Datasets/Testing/TemporalStacks',
    transform=None,
    split='val'
)

# Create data loaders with fewer workers
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0  # Start with 0 for debugging
)

val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0  # Start with 0 for debugging
)

# Test loading one batch
try:
    test_batch = next(iter(train_loader))
    print("Test batch shape:", test_batch[0].shape)
except Exception as e:
    print("Error loading test batch:", str(e))

Error loading data from ../Datasets/Testing/TemporalStacks/PLOT-00015: No .npy files found in ../Datasets/Testing/TemporalStacks/PLOT-00015/Pre-event
Error loading test batch: No .npy files found in ../Datasets/Testing/TemporalStacks/PLOT-00015/Pre-event
