In [None]:
# Mount Google Drive at /content/drive; the data and model paths used in the
# cells below all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install opendatasets pandas
!pip install evaluate datasets
!pip install transformers[torch] accelerate -U
!pip install rasterio

In [None]:
# unzip the folder and save the data
!unzip /content/drive/MyDrive/boundary_demarcation/Planet/DATA_full/planet_patches_png.zip -d /content/drive/MyDrive/boundary_demarcation/Planet/DATA_full/Images
!unzip /content/drive/MyDrive/boundary_demarcation/Planet/DATA_full/planet_masks_png.zip -d /content/drive/MyDrive/boundary_demarcation/Planet/DATA_full/Masks

In [None]:
# import libraries
import cv2
import tifffile
from pathlib import Path
import shutil
import concurrent.futures
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import (
    SegformerForSemanticSegmentation,
    TrainingArguments, Trainer,
    SegformerImageProcessor)
from datasets import Dataset, Image
import evaluate
import matplotlib.pyplot as plt

In [None]:
# Default matplotlib figure size for every plot that does not pass its own figsize.
plt.rcParams['figure.figsize'] = 12, 12

In [None]:
# Root folder on Drive; the 'Images' and 'Masks' sub-folders are read from here.
DATA_DIR = Path('/content/drive/MyDrive/boundary_demarcation/Planet/DATA_full')

In [None]:
# Hugging Face checkpoint that the image processor and the model are loaded from.
MODEL_CHECKPOINT = 'nvidia/mit-b4'

# Training configuration.
VAL_SIZE = 0.2    # fraction of the samples held out by train_test_split
BATCH_SIZE = 4    # per-device batch size for both training and evaluation
EPOCHS = 10       # number of training epochs
LR = 6e-5         # learning rate

# Image size in pixels. NOTE(review): not referenced by the cells visible here.
IMG_SIZE = 512

## Data processing

In [None]:
# Collect every PNG below DATA_DIR/Images (searched recursively) as a string path.
images_path = DATA_DIR / 'Images'
images = [str(png_file) for png_file in images_path.glob('**/*.png')]
print(f'{len(images)} images detected.')

# Do the same for the masks below DATA_DIR/Masks.
masks_path = DATA_DIR / 'Masks'
masks = [str(png_file) for png_file in masks_path.glob('**/*.png')]
print(f'{len(masks)} masks detected.')

In [None]:
# Put both path lists into sorted order so that index i of `images` lines up
# with index i of `masks`. This presumes an image and its mask sort into the
# same position (e.g. through matching file names) -- verify against the data.
images, masks = sorted(images), sorted(masks)

In [None]:
# Hold out VAL_SIZE of the image/mask pairs for validation. Both lists go
# through the same call so they are split together, and the fixed random_state
# keeps the split repeatable between runs.
split_lists = train_test_split(
    images,
    masks,
    test_size=VAL_SIZE,
    shuffle=True,
    random_state=0,
)
train_images, val_images, train_masks, val_masks = split_lists
print(f'Train images: {len(train_images)}\nValidation images: {len(val_images)}')

In [None]:
# Spot-check the training mask paths. Only the first few entries are shown so
# the cell output stays readable regardless of how large the dataset is.
train_masks[:5]

In [None]:
def create_dataset(image_paths, mask_paths):
    """Build a Hugging Face ``Dataset`` of image/mask pairs from file paths.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the masks, listed in the same order as
            ``image_paths``.

    Returns:
        A ``Dataset`` with the columns ``pixel_values`` and ``label``, both cast
        to the ``Image`` feature.
    """
    columns = {'pixel_values': image_paths, 'label': mask_paths}
    dataset = Dataset.from_dict(columns)
    # Cast the path columns one after the other, in insertion order.
    for column_name in columns:
        dataset = dataset.cast_column(column_name, Image())
    return dataset


# One dataset per split, built from the paths produced by the train/val split.
ds_train = create_dataset(train_images, train_masks)
ds_valid = create_dataset(val_images, val_masks)

In [None]:
# Display the training dataset's summary.
ds_train

In [None]:
# Display the validation dataset's summary.
ds_valid

### Display and check images and masks

In [None]:
import cv2
import matplotlib.pyplot as plt

# Sample image (TIFF) and its mask (PNG) used for a quick visual sanity check.
image_path = '/content/drive/MyDrive/boundary_demarcation/results/water/53987_image.tif'
mask_path = '/content/drive/MyDrive/boundary_demarcation/results/water/53987_mask.png'

# tifffile reads the TIFF image; OpenCV reads the mask with IMREAD_UNCHANGED so
# its stored values are not converted.
image = tifffile.imread(image_path)
mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
# cv2.imread reports a missing/unreadable file by returning None instead of
# raising, so fail here with a clear message rather than later inside imshow.
if mask is None:
    raise FileNotFoundError(f'Could not read mask: {mask_path}')

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(6, 6))

# Left panel: the image.
ax1.imshow(image)
ax1.set_title('Image')
ax1.axis('Off')  # Turn off axis

# Right panel: the mask, drawn with a blue colour map.
ax2.imshow(mask, cmap='Blues')
ax2.set_title('Mask')
ax2.axis('Off')  # Turn off axis

# Show the plot
plt.show()

In [None]:
# List the distinct pixel values that occur in the sample mask
# (presumably the class ids -- compare with id2label further down).
np.unique(mask)

In [None]:
# Report the sample image's size, channel count and dtype.
# The three-way unpack requires a 3-D array and raises for anything else.
data_type = image.dtype
height, width, channels = image.shape

summary_lines = (
    f'Image Dimensions: {width}x{height}',
    f'Number of Channels: {channels}',
    f'Data Type: {data_type}',
)
print(*summary_lines, sep='\n')

## Image preprocessing

In [None]:
# Image preprocessing native to the pretrained model: load the image processor
# that belongs to MODEL_CHECKPOINT. It is applied to the images and masks below.
feature_extractor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)

In [None]:
def apply_transforms(batch):
    """Run the image processor over one batch of examples.

    Args:
        batch: Mapping that holds the images under ``pixel_values`` and the
            masks under ``label``.

    Returns:
        The output of ``feature_extractor`` called with the images and masks.
    """
    batch_images = list(batch['pixel_values'])
    batch_masks = list(batch['label'])
    return feature_extractor(batch_images, batch_masks)


# Register the transform on both splits.
for dataset_split in (ds_train, ds_valid):
    dataset_split.set_transform(apply_transforms)

In [None]:
# Display both datasets again now that the transform has been registered.
ds_train, ds_valid

## Model Training

In [None]:
# Two-class label map for water segmentation on RGB imagery.
id2label = {0: 'background', 1: 'water'}
label2id = {class_name: class_id for class_id, class_name in id2label.items()}
num_labels = len(id2label)

# Load the pre-trained checkpoint with a head sized for `num_labels` classes.
# ignore_mismatched_sizes=True lets loading proceed when checkpoint tensors do
# not match the shapes required by this label configuration.
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_CHECKPOINT,
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

In [None]:
# Load the 'mean_iou' metric from the `evaluate` library; compute_metrics() uses it.
metric = evaluate.load('mean_iou')


def compute_metrics(eval_pred):
    """Compute mean-IoU metrics for one evaluation pass of the Trainer.

    Args:
        eval_pred: Pair ``(logits, labels)`` of NumPy arrays. The class scores
            sit on axis 1 of ``logits``; the last two axes of ``labels`` give
            the resolution the logits are resized to.

    Returns:
        dict with the aggregate values returned by the ``mean_iou`` metric plus
        one ``accuracy_<class>`` and one ``iou_<class>`` entry per class in
        ``id2label``.
    """
    # Background (id 0) is a real class in this two-class problem, so it must be
    # scored. Ignoring id 0 would drop every background pixel from the metric:
    # false-positive water would go unpenalised and the background IoU would be
    # meaningless. 255 is the conventional "unlabelled" value instead.
    # NOTE(review): assumes the masks contain only the class ids 0 and 1.
    ignore_index = 255

    logits, labels = eval_pred
    with torch.no_grad():
        logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label, then pick the best class
        pred_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode='bilinear',
            align_corners=False,
        ).argmax(dim=1)

    pred_labels = pred_tensor.detach().cpu().numpy()
    # The metric's `_compute` is called directly, exactly as before.
    metrics = metric._compute(
        predictions=pred_labels,
        references=labels,
        num_labels=len(id2label),
        ignore_index=ignore_index,
        reduce_labels=feature_extractor.do_reduce_labels,
    )

    # add per category metrics as individual key-value pairs
    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()

    metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
    metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

    return metrics

In [None]:
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases, and this notebook installs transformers with `-U`. Pick whichever
# keyword the installed version accepts so the cell runs on old and new versions.
import inspect

_eval_strategy_kwarg = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)

# Evaluate and checkpoint every 20 steps, keep at most 3 checkpoints, and
# reload the best checkpoint once training finishes.
training_args = TrainingArguments(
    'segformer_finetuned_water_planet_RGB_fullimage_2607',
    learning_rate=LR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    save_total_limit=3,
    save_strategy='steps',
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=False,
    report_to='none',
    **{_eval_strategy_kwarg: 'steps'},
)

In [None]:
# Wire the model, the training arguments, both dataset splits and the metric
# function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics,
)

# Start fine-tuning.
trainer.train()

In [None]:
# save the trained model
# NOTE(review): only the model is written here; the image processor is not
# saved next to it.
model.save_pretrained('/content/drive/MyDrive/MA/Models/segformer_finetuned_water_planet_RGB_fullimage_2607')

In [None]:
# Label map; must match the one the model was fine-tuned with.
id2label = {0: 'background', 1: 'water'}
label2id = {label: id for id, label in id2label.items()}
num_labels = len(id2label)

# Reload the fine-tuned weights from Drive. `ignore_mismatched_sizes` is
# deliberately not passed: the saved checkpoint must match this label
# configuration exactly, and a mismatch should raise instead of silently
# replacing the trained head with freshly initialised weights.
model = SegformerForSemanticSegmentation.from_pretrained(
    '/content/drive/MyDrive/MA/Models/segformer_finetuned_water_planet_RGB_fullimage_2607',
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
# Make inference mode explicit (disables dropout etc.).
model.eval()

# The image processor comes from the base checkpoint.
feature_extractor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)

## Visualise and compare the satellite image, true mask and predicted mask

In [None]:
# Run the model on every validation sample and plot the image, the true mask
# and the predicted mask side by side.
for i, (image_path, mask_path) in enumerate(zip(val_images, val_masks)):
    # The validation paths are collected with a '*.png' glob, and tifffile can
    # only open TIFF files. Choose the reader from the file suffix; OpenCV
    # returns BGR, so convert to RGB for the model and for plotting.
    if Path(image_path).suffix.lower() in ('.tif', '.tiff'):
        image = tifffile.imread(image_path)
    else:
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    if mask is None:
        raise FileNotFoundError(f'Could not read mask: {mask_path}')
    print(f'Validation image #{i + 1}')

    # Preprocess and move the tensors to wherever the model lives.
    inputs = feature_extractor(images=image, return_tensors='pt').to(model.device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

        # Rescale logits to original image size
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],  # (height, width)
            mode='bilinear',
            align_corners=False,
        )

    # Apply argmax on the class dimension
    pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3)

    ax1.imshow(image)
    ax1.set_title('Image')
    ax1.axis('Off')

    ax2.imshow(mask, cmap='Blues')
    ax2.set_title('True mask')
    ax2.axis('Off')

    ax3.imshow(pred_mask, cmap='Blues')
    ax3.set_title('Predicted mask')
    ax3.axis('Off')
    plt.show()
    # Release the figure; otherwise one figure per sample stays in memory.
    plt.close(fig)

In [None]:
# Spot-check the validation image paths (first few only, to keep the output short).
val_images[:5]

In [None]:
# Spot-check the validation mask paths (first few only, to keep the output short).
val_masks[:5]

### Predict and save all images

In [None]:
import os
import numpy as np
import tifffile
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as nnf
# Imported under an alias on purpose: a plain `from PIL import Image` would
# rebind the name `Image`, which create_dataset() needs to stay the
# `datasets.Image` feature if that cell is ever re-run.
from PIL import Image as PILImage


def _read_rgb_image(path):
    """Read an image file as an RGB array, choosing the reader by file suffix."""
    if os.path.splitext(path)[1].lower() in ('.tif', '.tiff'):
        return tifffile.imread(path)
    # tifffile only opens TIFF files; everything else goes through OpenCV,
    # which returns BGR and signals failure with None instead of raising.
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def save_predictions(val_images, val_masks, model, feature_extractor, output_folder):
    """Predict a mask for every image, save it as PNG and plot the comparison.

    Args:
        val_images: Paths of the images to run the model on.
        val_masks: Paths of the matching ground-truth masks, in the same order.
        model: Segmentation model; called with the processor output and
            expected to expose ``.logits`` on its result.
        feature_extractor: Image processor that turns an image into model inputs.
        output_folder: Directory for the predicted masks; created if missing.

    Each prediction is written to ``<output_folder>/<mask name>_pred.png`` as a
    uint8 image holding the predicted class id per pixel.

    Raises:
        FileNotFoundError: If an image or mask file cannot be read.
    """
    os.makedirs(output_folder, exist_ok=True)

    for i, (image_path, mask_path) in enumerate(zip(val_images, val_masks)):
        image = _read_rgb_image(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')
        print(f'Validation image #{i + 1}')

        # Preprocess and move the tensors to wherever the model lives.
        inputs = feature_extractor(images=image, return_tensors='pt').to(model.device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = model(**inputs).logits

            # Rescale logits to original image size
            upsampled_logits = nnf.interpolate(
                logits,
                size=image.shape[:2],  # (height, width)
                mode='bilinear',
                align_corners=False,
            )

        # Apply argmax on the class dimension
        pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

        # Create the output file path with '_pred' attached to the original mask name
        mask_stem = os.path.splitext(os.path.basename(mask_path))[0]
        pred_mask_path = os.path.join(output_folder, f"{mask_stem}_pred.png")

        # Save the predicted mask as a PNG file
        PILImage.fromarray(pred_mask.astype(np.uint8)).save(pred_mask_path)

        # Display the images and masks
        fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3)

        ax1.imshow(image)
        ax1.set_title('Image')
        ax1.axis('Off')

        ax2.imshow(mask, cmap='Blues')
        ax2.set_title('True mask')
        ax2.axis('Off')

        ax3.imshow(pred_mask, cmap='Blues')
        ax3.set_title('Predicted mask')
        ax3.axis('Off')

        plt.show()
        # Release the figure; otherwise one figure per sample stays in memory.
        plt.close(fig)


# Drive folder that receives the predicted masks.
output_folder = "/content/drive/MyDrive/MA/Output/Prediction_masks"


# Predict, save and plot every validation sample.
save_predictions(val_images, val_masks, model, feature_extractor, output_folder)