<a href="https://colab.research.google.com/github/liangchow/zindi-amazon-secret-runway/blob/main/zindi_airstrip_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Setup.

In [None]:
%%capture
!pip -q install rasterio
!pip -q install torch
!pip -q install torchvision
!pip -q install albumentations
!pip -q install segmentation-models-pytorch

In [None]:
import os
import torch
import rasterio
import numpy as np
import albumentations as A
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from torchvision import transforms as T
from albumentations.pytorch import ToTensorV2
import torch.nn.functional as F

# Download training data to local compute node

## Mount your Google Drive

In [None]:
# Colab-specific module used to attach Google Drive to this runtime.
from google.colab import drive
# Mount Google Drive at /content/drive so later cells can read files under that path.
drive.mount('/content/drive')

Mounted at /content/drive


## Compress training files, copy over and uncompress

In [None]:
# Navigate to the shared directory
%cd /content/drive/MyDrive/Zindi-Amazon/training
# Zip the data
!zip -r /content/images.zip images
!zip -r /content/masks.zip masks
# Unzip the files
!unzip /content/images.zip -d /content
!unzip /content/masks.zip -d /content

/content/drive/.shortcut-targets-by-id/14mw0v8Bi-MzhsqSI0K3KO23YrUHttM7P/Zindi-Amazon/training
  adding: images/ (stored 0%)
  adding: images/Sentinel_AllBands_Training_Id_20.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_59.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_61.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_78.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_79.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_105.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_120.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_135.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_139.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_138.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_66.tif (deflated 4%)
  adding: images/Sentinel_AllBands_Training_Id_176.tif (deflated 5%)
  adding: images/Sentinel_AllBands_Training_Id_157.ti

# Custom classes and functions to handle Sentinel 1/2 data and corresponding masks

In [None]:
# Custom Dataset class for Sentinel 1/2 bands and mask
class Sentinel2Dataset(Dataset):
    """Paired image/mask dataset read from raster files with rasterio.

    Each item consists of the B4, B3 and B2 bands of one image file (located
    through the file's band descriptions and stacked channels-last) and band 1
    of the mask file stored at the same index.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths; ``mask_paths[i]`` is used as
            the mask for ``image_paths[i]``.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    # Band descriptions to load, in output channel order.
    BAND_NAMES = ('B4', 'B3', 'B2')

    def __init__(self, image_paths, mask_paths, transform=None):
        # Items are paired purely by index, so a length mismatch would
        # silently misalign (or drop) image/mask pairs.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def _read_bands(self, image_path):
        """Read the ``BAND_NAMES`` bands of ``image_path`` as one (H, W, C) array."""
        with rasterio.open(image_path) as src:
            # Map each wanted description to its 1-based rasterio band index.
            # If a description occurs more than once, the last occurrence wins.
            index_by_name = {
                desc: band_index
                for band_index, desc in enumerate(src.descriptions, start=1)
                if desc in self.BAND_NAMES
            }

            # Check if all required bands were found
            if any(name not in index_by_name for name in self.BAND_NAMES):
                raise ValueError(f"Not all bands found in image: {image_path}")

            # Stack the selected bands channels-last to form the image array
            return np.stack(
                [src.read(index_by_name[name]) for name in self.BAND_NAMES],
                axis=-1,
            )

    def __getitem__(self, idx):
        """Load, optionally transform, and return the ``(image, mask)`` pair at ``idx``.

        Returns:
            tuple: ``(image, mask)``. When the transform does not already
            produce tensors, ``image`` is a float32 tensor in channels-first
            layout and ``mask`` is an int64 tensor. Tensors produced by the
            transform are returned unchanged.

        Raises:
            ValueError: If the image file lacks one of the required bands.
        """
        image = self._read_bands(self.image_paths[idx])

        with rasterio.open(self.mask_paths[idx]) as mask_src:
            mask = mask_src.read(1)

        # Apply transformations if provided
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Only convert to PyTorch tensors if not already tensors (skip if the transform does it).
        # The cast happens in numpy first: torch.from_numpy cannot wrap arrays with
        # negative strides (e.g. flipped views), and older torch releases reject
        # some dtypes such as uint16.
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(
                np.ascontiguousarray(image, dtype=np.float32)
            ).permute(2, 0, 1)  # Channels first
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.ascontiguousarray(mask, dtype=np.int64))

        return image, mask


In [None]:
# Augmentations using albumentations and PyTorch's ToTensor
def get_augmentations(crop_size=224,
                      mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225),
                      max_pixel_value=255.0):
    """Build the training augmentation pipeline.

    The pipeline takes a random square crop, applies random horizontal and
    vertical flips and 90-degree rotations (each with probability 0.5),
    normalizes the image, and converts image and mask to tensors.

    Args:
        crop_size: Side length, in pixels, of the random square crop.
        mean: Per-channel mean passed to ``A.Normalize`` (defaults to the
            ImageNet values).
        std: Per-channel standard deviation passed to ``A.Normalize``
            (defaults to the ImageNet values).
        max_pixel_value: Value by which ``A.Normalize`` scales ``mean`` and
            ``std`` before applying them.
            NOTE(review): 255 matches 8-bit imagery; if the input rasters hold
            a different value range, pass the matching maximum here. Confirm
            against the actual data.

    Returns:
        A.Compose: The composed albumentations transform.
    """
    return A.Compose([
        A.RandomCrop(width=crop_size, height=crop_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=mean, std=std, max_pixel_value=max_pixel_value),
        ToTensorV2(),
    ])

In [None]:
# UNet model with ResNet50 encoder
def get_model(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=1):
    """Create a segmentation_models_pytorch U-Net.

    Called without arguments, this builds the original configuration: a
    ResNet50 encoder with ImageNet weights, 3 input channels, and a single
    output channel for binary segmentation.

    Args:
        encoder_name: Name of the encoder backbone passed to ``smp.Unet``.
        encoder_weights: Pre-trained weights identifier for the encoder.
        in_channels: Number of input image channels (3 for RGB; raise this
            when feeding additional bands).
        classes: Number of output channels (1 for binary segmentation).

    Returns:
        The constructed ``smp.Unet`` model.
    """
    return smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
    )

In [None]:
# Training loop
def train_model(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    For every batch the loss is computed, back-propagated, and applied with
    one optimizer step; the batch loss is printed.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Mean loss over the pass, weighted by batch size, or NaN when
        ``dataloader`` yields no batches.
    """
    model.train()

    loss_sum = 0.0
    sample_count = 0

    for images, masks in dataloader:
        images = images.to(device)
        # The loss is computed against float32 targets.
        masks = masks.to(device=device, dtype=torch.float32)
        # Add the channel dimension only when it is missing, so masks that
        # already carry one are not turned into 5-D tensors.
        if masks.dim() == images.dim() - 1:
            masks = masks.unsqueeze(1)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        print(f"Loss: {batch_loss}")

        batch_len = images.size(0)
        loss_sum += batch_loss * batch_len
        sample_count += batch_len

    if sample_count == 0:
        return float("nan")
    return loss_sum / sample_count

# Google Colab GPU
Check that the GPU is enabled in your Colab notebook by running the cell below.

In [None]:
# Check if a GPU is enabled: use the first CUDA device when available, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device: {}".format(device))

# Report the specific GPU model when running on the GPU
on_gpu = str(device) == "cuda:0"
if on_gpu:
    gpu_name = torch.cuda.get_device_name(0)
    print("GPU: {}".format(gpu_name))

Device: cpu


# Main loop

In [None]:
image_dir = '/content/images'
mask_dir = '/content/masks'
batch_size = 8
epochs = 1

# List image and mask file paths.
# Both lists are sorted the same way and paired by position, so this relies on
# image and mask file names sorting into the same order.
image_paths = [os.path.join(image_dir, img) for img in sorted(os.listdir(image_dir)) if img.endswith('.tif')]
mask_paths = [os.path.join(mask_dir, mask) for mask in sorted(os.listdir(mask_dir)) if mask.endswith('.tif')]

# Fail early with a clear message instead of training on nothing or on misaligned pairs.
if not image_paths:
    raise FileNotFoundError(f"No .tif images found in {image_dir}")
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Found {len(image_paths)} images in {image_dir} but {len(mask_paths)} masks in {mask_dir}"
    )

# Create dataset and dataloader
dataset = Sentinel2Dataset(image_paths=image_paths, mask_paths=mask_paths, transform=get_augmentations())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Load the model, define loss and optimizer
model = get_model().to(device)
criterion = torch.nn.BCEWithLogitsLoss()  # Binary cross-entropy loss computed on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop: one full pass over the dataloader per epoch
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train_model(model, dataloader, criterion, optimizer, device)




Epoch 1/1
Loss: 1.0809482336044312
Loss: 0.8926908373832703
Loss: 0.7775659561157227
Loss: 0.6872310042381287
Loss: 0.6094898581504822
Loss: 0.5601739883422852
Loss: 0.5101388096809387
Loss: 0.5109736919403076
Loss: 0.4596864879131317
Loss: 0.4398171603679657
Loss: 0.42671141028404236
Loss: 0.3719524145126343
Loss: 0.3568221628665924
Loss: 0.3903155028820038
