In [1]:
# Important notes:
# In order to run the notebook and script successfully in a conda environment,
# you need to install necessary dependencies:
# torch numpy matplotlib import_ipynb 

In [1]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [10]:
# Define the Dataset Class:
# Use PyTorch's Dataset class to encapsulate satellite image data and masks into datasets 
# that can be used for model training
class SatelliteDataset(Dataset):
    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        # Get orginal image and mask
        image = self.data_list[idx]['image']
        mask = self.data_list[idx]['mask']

        #print(f"Original image shape: {image.shape}")  # Print original image shape
        #print(f"Original mask shape: {mask.shape}")    # print original image shape
        #print(f"Original mask unique values: {np.unique(mask)}")  # Print unique number in mask

        # Normalization
        image = image.astype(np.float32) / 10000
        mask = mask.astype(np.float32)

        # Check the unique number in mask after normalization
        #print(f"Normalized mask unique values: {np.unique(mask)}")

        if self.transform:
            image = self.transform(image)

        # Convert to tensor
        image = torch.from_numpy(image.transpose(2, 0, 1))  # HWC -> CHW
        mask = torch.from_numpy(mask).unsqueeze(0)  # Add channel dimension
        
        #print(f"Transformed image shape: {image.shape}")  
        #print(f"Transformed mask shape: {mask.shape}")    
        #print(f"tensor mask unique values: {np.unique(mask)}")

        return image, mask

# Define the model:
# Use a pre-trained UNet model, for brevity, define a simple UNet model manually here
class UNet(nn.Module):
    def __init__(self, dropout_prob=0.5):
        super(UNet, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),  # Dropout after ReLU
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),  # Dropout after ReLU
            nn.MaxPool2d(kernel_size=2)
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),  # Dropout after ReLU
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Define loss function - DiceLoss:
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        
    def forward(self, outputs, targets):
        intersection = (outputs * targets).sum(dim=[2, 3])  # Sum over spatial dimensions
        union = outputs.sum(dim=[2, 3]) + targets.sum(dim=[2, 3])

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        dice_loss = 1 - dice.mean()
        return dice_loss

# Define loss function - FocalLoss:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        #BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')  # Using logits for numerical stability
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none') # Using numerical stability
        pt = torch.exp(-BCE_loss)  # Probability of true class
        Focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == 'mean':
            return Focal_loss.mean()
        elif self.reduction == 'sum':
            return Focal_loss.sum()
        else:
            return Focal_loss

# Define loss function - CombinedLoss:
# DiceLoss + FocalLoss
class CombinedLoss(nn.Module):
    def __init__(self, smooth=1.0, dice_weight=3.0, focal_weight=1.0, alpha=1.0, gamma=2.0):
        super(CombinedLoss, self).__init__()
        self.dice_loss = DiceLoss(smooth=smooth)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight

    def forward(self, outputs, targets):
        dice = self.dice_loss(outputs, targets)
        focal = self.focal_loss(outputs, targets)
        return self.dice_weight * dice + self.focal_weight * focal

# Define Training Functions:
def train_model(model, dataloader, criterion, optimizer, num_epochs=1):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, masks in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            
            #print(f"train_model - Model output min: {outputs.min()}, max: {outputs.max()}")
            
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

# Define Evaluation Functions:
# Use IoU, Dice Coefficient, Precision, Recall, F1 Score as the evaluation metric
def evaluate_model(model, dataloader):
    model.eval()
    
    thresholds = [0.1, 0.3, 0.5, 0.6, 0.7]
    for threshold in thresholds:
        total_iou = 0.0
        total_dice = 0.0
        total_precision = 0.0
        total_recall = 0.0
        total_f1 = 0.0
        num_batches = len(dataloader)
        
        for inputs, masks in dataloader:
            with torch.no_grad():
                outputs = model(inputs)
            
                predicted_masks = (outputs > threshold).float()
            
                # Print some example predictions and masks for debugging
                #print(f"evaluate_model- Model output min: {outputs.min()}, max: {outputs.max()}")
                #print(f"evaluate_model - predicted_masks min: {predicted_masks.min()}, max: {predicted_masks.max()}")
            
                intersection = (predicted_masks * masks).sum().float()
                union = (predicted_masks.sum() + masks.sum() - intersection).float()

                # Calculate IoU
                iou = intersection / union if union != 0 else torch.tensor(0.0)
                total_iou += iou.item()

                # Calculate Dice Coefficient
                dice = (2 * intersection) / (predicted_masks.sum() + masks.sum()) if (predicted_masks.sum() + masks.sum()) != 0 else torch.tensor(0.0)
                total_dice += dice.item()
    
                # Calculate Precision
                tp = (predicted_masks * masks).sum().float()  # True Positives
                fp = (predicted_masks * (1 - masks)).sum().float()  # False Positives
                precision = tp / (tp + fp) if (tp + fp) != 0 else torch.tensor(0.0)
                total_precision += precision.item()

                # Calculate Recall
                fn = ((1 - predicted_masks) * masks).sum().float()  # False Negatives
                recall = tp / (tp + fn) if (tp + fn) != 0 else torch.tensor(0.0)
                total_recall += recall.item()

                # Calculate F1 Score
                f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else torch.tensor(0.0)
                total_f1 += f1.item()
                
        print(f"Validation threshold: {threshold}")
        print(f"Validation IoU: {total_iou / num_batches:.4f}")
        print(f"Validation Dice Coefficient: {total_dice / num_batches:.4f}")
        print(f"Validation Precision: {total_precision / num_batches:.4f}")
        print(f"Validation Recall: {total_recall / num_batches:.4f}")
        print(f"Validation F1 Score: {total_f1 / num_batches:.4f}")

In [8]:
# Function for running inference and calculating IoU on validation set
def run_inference(model_file, dataset_file):
    with open(dataset_file, 'rb') as f:
        data = pickle.load(f)

    val_dataset = SatelliteDataset(data['val'])
    val_loader = DataLoader(val_dataset, batch_size=4)

    # Load the model
    model = UNet()
    model.load_state_dict(torch.load(model_file))

    # Run inference and evaluate
    evaluate_model(model, val_loader)

In [11]:
# In order to avoid running the rest of the code of the entire notebook when calling run_inference,
# separate the execution logic from the function definition by using if __name__ == ‘__main__’ block

# This way, operations such as training, validation, and model saving will only be performed when the notebook 
# is running directly, and will not be triggered when someone just imports and calls run_inference

if __name__ == "__main__":
    # Load the dataset.pickle file
    with open('dataset.pickle', 'rb') as f:
        data = pickle.load(f)

    # Dataset format: [{"image": np.array, "mask": np.array}]
    train_data = data['train']
    val_data = data['val']

    # Create train and validation datasets
    train_dataset = SatelliteDataset(train_data)
    val_dataset = SatelliteDataset(val_data)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=4)
    val_loader = DataLoader(val_dataset, batch_size=4)

    # Instantiate the model, loss function, and optimizer
    model = UNet(dropout_prob=0.3)
    # criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
    # criterion = DiceLoss(smooth=1.0)
    criterion = CombinedLoss(smooth=1.0, dice_weight=4.0, focal_weight=0.5, alpha=1.0, gamma=2.0)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Train the model
    train_model(model, train_loader, criterion, optimizer, num_epochs=5)

    # Evaluate the model
    evaluate_model(model, val_loader)

    # Save the trained model
    torch.save(model.state_dict(), "unet_model.pth")

Epoch [1/5], Loss: 3.8161
Epoch [2/5], Loss: 3.7946
Epoch [3/5], Loss: 3.7503
Epoch [4/5], Loss: 3.7121
Epoch [5/5], Loss: 3.6646
Validation threshold: 0.1
Validation IoU: 0.0709
Validation Dice Coefficient: 0.1299
Validation Precision: 0.0710
Validation Recall: 0.9728
Validation F1 Score: 0.1299
Validation threshold: 0.3
Validation IoU: 0.1124
Validation Dice Coefficient: 0.1964
Validation Precision: 0.1142
Validation Recall: 0.8684
Validation F1 Score: 0.1964
Validation threshold: 0.5
Validation IoU: 0.1336
Validation Dice Coefficient: 0.2275
Validation Precision: 0.1515
Validation Recall: 0.5372
Validation F1 Score: 0.2275
Validation threshold: 0.6
Validation IoU: 0.0144
Validation Dice Coefficient: 0.0283
Validation Precision: 0.1461
Validation Recall: 0.0171
Validation F1 Score: 0.0283
Validation threshold: 0.7
Validation IoU: 0.0000
Validation Dice Coefficient: 0.0001
Validation Precision: 0.1605
Validation Recall: 0.0000
Validation F1 Score: 0.0001
