In [1]:
%%capture
! pip install rasterio


In [2]:
import os
import pandas as pd
import numpy as np
import sys
import shutil
import re
from PIL import Image
import rasterio
import matplotlib.pyplot as plt
import dask.array as da
from scipy.ndimage import binary_dilation
from skimage.morphology import disk  # For circular structuring elements
import torch
from torchvision import transforms
import torchvision.transforms.functional as vF
import torch.nn.functional as F
import gdown
from tqdm import tqdm
import random

from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, jaccard_score, hamming_loss, label_ranking_loss, coverage_error, classification_report
import sklearn.metrics as metr



In [3]:
import torch
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset

def set_seed(seed):
    """
    Set random seeds for NumPy, PyTorch (CPU and GPU), and Python's random module.
    
    Args:
        seed (int): Seed value for RNGs
    """
    # Python random
    random.seed(seed)
    
    # NumPy
    np.random.seed(seed)
    
    # PyTorch CPU
    torch.manual_seed(seed)
    
    # PyTorch GPU (CUDA)
    torch.cuda.manual_seed(seed)  # Current GPU
    torch.cuda.manual_seed_all(seed)  # All GPUs
    
    # Ensure deterministic behavior
    #torch.use_deterministic_algorithms(True)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
    """
    Initialize random seed for DataLoader workers.
    Ensures each worker has a unique but reproducible RNG state.
    
    Args:
        worker_id (int): Worker ID
    """
    max_seed = 2**32 - 1  # NumPy seed limit
    worker_seed = (torch.initial_seed() + worker_id) % max_seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    

In [4]:
def create_LR_dataframe(splits_path, mode='train'):
    split_images_files = {'train' : 'train_X.txt', 'val' : 'val_X.txt', 'test' : 'test_X.txt'}
    split_masks_files = {'train' : 'train_masks.txt', 'val' : 'val_masks.txt', 'test' : 'test_masks.txt'}  
    with open(os.path.join(splits_path, split_images_files[mode]), "r") as file:
        images = file.readlines()  # Reads all lines into a list
        images = [image.strip() for image in images]  # Remove any trailing newline characters
    with open(os.path.join(splits_path, split_masks_files[mode]), "r") as file:
        masks = file.readlines()  # Reads all lines into a list
        masks = [mask.strip() for mask in masks]  # Remove any trailing newline characters
    df = pd.DataFrame({'image' : images, 'mask' : masks})
    return df



In [5]:
# from Sagar and Navodita's code
def compute_fdi_from_tiff(tiff_path):
    with rasterio.open(tiff_path) as src:
        # Assuming band order follows your stacked TIFF (B1–B12, skipping B10 if needed)
        # Band indices are 1-based in rasterio
        R665 = src.read(4)    # B4
        R859 = src.read(9)    # B8A
        R1610 = src.read(10)  # B11
        # Convert to float and mask invalid values
        R665 = R665.astype(np.float32)
        R859 = R859.astype(np.float32)
        R1610 = R1610.astype(np.float32)
        # Calculate FDI
        FDI = R859 - (R665 + ((R1610 - R665) * (859 - 665) / (1610 - 665)))
        return FDI

def cvt_to_fdi(images):
    fdi_images = []
    batch = images.copy()
    if len(images.shape) == 3 : 
        batch = batch[None, :]
    for i in range(batch.shape[0]):
        im = batch[i]
        R665 = im[3]   # B4
        R859 = im[8]   # B8A
        R1610 = im[0]  # B11
        # Convert to float and mask invalid values
        R665 = R665.astype(np.float32)
        R859 = R859.astype(np.float32)
        R1610 = R1610.astype(np.float32)
        # Calculate FDI
        FDI = R859 - (R665 + ((R1610 - R665) * (859 - 665) / (1610 - 665)))
        fdi_images.append(FDI)
    return np.array(fdi_images)
    
def compute_ndwi(tiff_path):
    with rasterio.open(tiff_path) as src:
        Rgreen = src.read(3).astype(np.float32)  # Band 3 (Green)
        Rnir = src.read(8).astype(np.float32)    # Band 8 (NIR)
        ndwi = (Rgreen - Rnir) / (Rgreen + Rnir + 1e-6)  # avoid divide-by-zero
    return ndwi
def plot_fdi(fdi_array, ndwi, img_path, mask_path):
    with rasterio.open(img_path) as src:
        rgb = src.read([4, 3, 2])
        rgb = np.transpose(rgb, (1, 2, 0))
    # Normalization
    rgb = rgb.astype(np.float32)
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min())
    with rasterio.open(mask_path) as src:
        mask = src.read(1)
    # Create binary mask
    mask_binary = mask > 0
    # Plot side-by-side
    fig, axs = plt.subplots(1, 4, figsize=(15, 5))
    axs[0].imshow(rgb)
    axs[0].set_title("RGB Patch")
    axs[1].imshow(mask_binary)  #, cmap='gray')
    axs[1].set_title("Binary Mask (._cl.tif)")
    axs[2].imshow(fdi_array)
    axs[2].set_title("FDI")
    axs[3].imshow(ndwi)
    axs[3].set_title("NDWI")
    for ax in axs:
        ax.axis('off')

    # with rasterio.open(patch_path) as patch_src:
    #     rgb = patch_src.read([4, 3, 2])  # Use bands B4, B3, B2 for RGB
    #     rgb = np.transpose(rgb, (1, 2, 0))
    #     rgb = (rgb - np.min(rgb)) / (np.max(rgb) - np.min(rgb) + 1e-6)
    import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# List of image and mask file paths (replace with your file paths)
image_mask_pairs = [
    ('path_to_image1.jpg', 'path_to_mask1.png'),
    ('path_to_image2.jpg', 'path_to_mask2.png'),
    # Add more pairs as needed
]


def cvt_RGB(images):
    rgb_images = []
    for i in range(images.shape[0]):
        rgb = images[i][[4-1, 3-1, 2-1]] # Use bands B4, B3, B2 for RGB
        rgb = np.transpose(rgb, (1, 2, 0))
        rgb = (rgb - np.min(rgb)) / (np.max(rgb) - np.min(rgb) + 1e-6)
        rgb_images.append(rgb)
    return np.array(rgb_images)

def display(images, masks):
    # Determine the number of pairs
    num_pairs = images.shape[0]

    # Calculate layout: use 2 columns per pair (image + mask), adjust rows dynamically
    cols = 2  # One column for image, one for mask
    rows = num_pairs  # One row per pair

    # Create a figure with subplots
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))

    # Handle case of single pair (axes is not a 2D array)
    if num_pairs == 1:
        axes = np.array([axes]).reshape(1, -1)

    # Iterate through each pair and display image and mask
    for idx, (image, mask) in enumerate(zip(images, masks)):

        # Display the original image
        axes[idx, 0].imshow(image)
        axes[idx, 0].set_title(f'Image {idx + 1}')
        axes[idx, 0].axis('off')  # Hide axes
    
        # Display the segmentation mask
        axes[idx, 1].imshow(mask, cmap='gray')  # Adjust cmap if needed
        axes[idx, 1].set_title(f'Mask {idx + 1}')
        axes[idx, 1].axis('off')  # Hide axes

    # Adjust layout to prevent overlap
    plt.tight_layout()
    
    # Show the plot
    plt.show()

In [6]:

def extract_date_tile(filename):
    """Extract date and tile from filename using regex."""
    pattern = r'^(\d{1,2}-\d{1,2}-\d{2})_([A-Z0-9]+)_\d+$'
    match = re.match(pattern, filename)
    if not match:
        raise ValueError(f"Invalid filename format: {filename}")
    return match.groups()  # Returns tuple (date, tile)

def create_marida_df(data_path, mode='train'):
    """Create DataFrame from MARIDA dataset files."""
    # Determine split file based on mode
    split_files = {'train': 'train_X.txt', 'val': 'val_X.txt', 'test': 'test_X.txt'}
    items_list_path = os.path.join(data_path, 'splits', split_files[mode])

    # Read items list
    with open(items_list_path, 'r') as file:
        items = [item.strip() for item in file]

    # Base path for patches
    items_path = os.path.join(data_path, 'patches')

    # Prepare data lists
    data = {
        'image': [],
        'mask': [],
        'confidence': [],
        'date': [],
        'tile': []
    }

    # Process each item
    for item in items:
        tile = "_".join(item.split("_")[:-1])
        tile_path = os.path.join(items_path, f"S2_{tile}")

        # Define file paths
        base_name = f'S2_{item}'
        paths = {
            'image': os.path.join(tile_path, f'{base_name}.tif'),
            'mask': os.path.join(tile_path, f'{base_name}_cl.tif'),
            'confidence': os.path.join(tile_path, f'{base_name}_conf.tif')
        }

        # Check if all files exist
        if all(os.path.exists(p) for p in paths.values()):
            data['image'].append(paths['image'])
            data['mask'].append(paths['mask'])
            data['confidence'].append(paths['confidence'])
            date, tile = extract_date_tile(item)
            data['date'].append(date)
            data['tile'].append(tile)

    return pd.DataFrame(data)

# MARIDA labels dictionary
MARIDA_LABELS = {
    i: label for i, label in enumerate([
        'Marine Debris', 'Dense Sargassum', 'Sparse Sargassum', 'Natural Organic Material',
        'Ship', 'Clouds', 'Marine Water', 'Sediment-Laden Water', 'Foam', 'Turbid Water',
        'Shallow Water', 'Waves', 'Cloud Shadows', 'Wakes', 'Mixed Water'
    ], 1)
}

In [7]:
import rasterio
import numpy as np

def compute_invalid_pixels(image_paths, mask_paths):
    """
    Compute per-band statistics for Sentinel-2 L1C ACOLITE-processed images using segmentation masks.
    Creates a mask to exclude invalid pixels (NaNs, negative values, specified no-data value).
    
    Parameters:
    - image_paths: List of paths to image files (e.g., GeoTIFF with 11 bands).
    - mask_paths: List of paths to segmentation mask files (single-band, integer class labels).
    - class_labels: List of mask class labels to include (e.g., [1, 2] for vegetation and water).
                   If None, include all non-zero labels (excluding background).
    - invalid_value: Optional value to treat as invalid in images (e.g., -9999).
    
    Returns:
    - mean_per_band: List of per-band means for each image.
    - std_per_band: List of per-band standard deviations for each image.
    """
    mean_per_band = []  # Initialize as list
    std_per_band = []   # Initialize as list
    positive_pixels = []
    tot_pixels = [];
    images_with_invalid_pixels = []
    black_list = []
    accumulator = None
    no_data_pixels = []
    neg_pixels = []
    nan_pixels = []
    gt1_pixels = []
    imgs_with_invalid = []
    positive_pixels = []
    min_vals = []
    max_vals = []
    for img_path, mask_path in zip(image_paths, mask_paths):
        # Load image and mask
        with rasterio.open(img_path) as src_img, rasterio.open(mask_path) as src_mask:
            image = src_img.read()  # Shape: (bands, height, width)
            mask = src_mask.read(1)  # Shape: (height, width)
            
            # Convert image to float for NaN handling
            image = image.astype(float)

            nan_mask = np.isnan(image)
            neg_mask = (image < 0)
            too_big_mask = (image > 1)
            no_data_mask = (image == src_img.nodata)
            nan_pixels.append(np.sum(nan_mask))
            neg_pixels.append(np.sum(neg_mask))
            gt1_pixels.append(np.sum(too_big_mask))
            no_data_pixels.append(np.sum(no_data_mask))
            imgs_with_invalid.append(img_path)
            positive_pixels.append(np.sum(mask > 0))
            min_vals.append(np.min(image))
            max_vals.append(np.max(image))
    df = pd.DataFrame({'image' : imgs_with_invalid, 'no data pixels' : no_data_pixels, 'negative pixels' : neg_pixels,
                      'nan pixels' : nan_pixels, 'high value pixels' :  gt1_pixels, 'debris pixels' : positive_pixels,
                      'min values' : min_vals, 'max values' : max_vals})
    return df

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size=16

In [9]:
def compute_stats(image_files, discard_negatives = False, discard_gt_1 = False):
    bands_std = []
    bands_mean = []
    valid_pixels = []

    for band_idx in range(11):
        arrays = [da.from_array(rasterio.open(f).read(band_idx + 1), chunks='auto')
                  for f in image_files]
        stack = da.stack(arrays)
        #valid = (stack != rasterio.open(image_files[0]).nodata) & (stack >= 0)
        if discard_negatives and  discard_gt_1: 
            valid = da.stack([da.from_array(rasterio.open(f).read(band_idx + 1) != rasterio.open(f).nodata, chunks='auto')
                              & (da.from_array(rasterio.open(f).read(band_idx + 1), chunks='auto') >= 0) & 
                              (da.from_array(rasterio.open(f).read(band_idx + 1), chunks='auto') <= 1) 
                              for f in image_files])
        elif discard_gt_1 :
            valid = da.stack([da.from_array(rasterio.open(f).read(band_idx + 1) != rasterio.open(f).nodata, chunks='auto')
                              & (da.from_array(rasterio.open(f).read(band_idx + 1), chunks='auto') <= 1)  
                              for f in image_files])
        elif discard_negatives:
            valid = da.stack([da.from_array(rasterio.open(f).read(band_idx + 1) != rasterio.open(f).nodata, chunks='auto')
                  & (da.from_array(rasterio.open(f).read(band_idx + 1), chunks='auto') >= 0) 
                  for f in image_files])
        else :
            valid = da.stack([da.from_array(rasterio.open(f).read(band_idx + 1) != rasterio.open(f).nodata, chunks='auto')
                  for f in image_files])
                         
        # Compute number of valid pixels
        valid_count = da.sum(valid).compute()
        valid_pixels.append(valid_count)
        mean = da.nanmean(stack[valid]).compute()
        std = da.nanstd(stack[valid]).compute()
        bands_mean.append(mean)
        bands_std.append(std)
        print(f"Band {band_idx} - Mean: {mean}, Std: {std}")
    return {'mean' : np.array(bands_mean), 'std': np.array(bands_std),'valid pixels' : np.array(valid_pixels) }


In [10]:
def computing_labeled_pixels_stats(mask_paths):
    arrays = [da.from_array(rasterio.open(f).read(1), chunks='auto')
                  for f in mask_paths]
    stack = da.stack(arrays)
    valid = stack > 0
    labeled_count = da.sum(valid).compute()
    return labeled_count

In [11]:
def compute_invalid_mask(path):
    with rasterio.open(path) as src:
        image = src.read()
        
        invalid_mask = image == src.nodata
        invalid_mask |= np.isnan(image)
        invalid_mask |= image < 0
        invalid_mask |= image > 1
        invalid_mask = np.any(invalid_mask, axis=0)
        return invalid_mask

In [12]:
def get_invalid_mask(image, no_data):
    invalid_mask = image == no_data
    invalid_mask |= np.isnan(image)
    #invalid_mask |= image < 0
    invalid_mask |= image > 1
    #invalid_mask = np.any(invalid_mask, axis=0)
    return invalid_mask  #torch.fromnumpy(invalid_mask)

In [13]:
def select_bg_pixels(image, debris_mask, r1=5, r2=20, target_ratio=5):
    H, W = debris_mask.shape
    
    #target_ratio = 5  # Debris-to-background ratio (1:5)

    # Create structuring elements (circular or square)
    se_r1 = disk(r1) if r1 > 0 else np.ones((1, 1))  # Inner dilation kernel
    se_r2 = disk(r2)                         # Outer dilation kernel
    #print('before binary dilation')
    # Dilate debris mask with r1 and r2
    dilated_r1 = binary_dilation(debris_mask, structure=se_r1)
    dilated_r2 = binary_dilation(debris_mask, structure=se_r2)
    #print('before anular mask')
    # Compute annular region: pixels in dilated_r2 but not in dilated_r1
    annular_mask = dilated_r2 & ~dilated_r1

    # Sample background pixels from annular region
    valid_background_coords = np.where(annular_mask)
    num_debris = np.sum(debris_mask)
    num_background = min(len(valid_background_coords[0]), num_debris * target_ratio)
    if num_background > 0:
        sample_idx = np.random.choice(len(valid_background_coords[0]), size=num_background, replace=False)
        background_coords = [(valid_background_coords[0][i], valid_background_coords[1][i]) for i in sample_idx]
    else:
        print("Warning: No valid background pixels in annular region. Increase r2 or check mask.")

    # Create background mask (optional, for visualization or training)
    background_mask = np.zeros_like(debris_mask)
    for x, y in background_coords:
        background_mask[x, y] = 1
    return background_mask

# Optional: Filter by features (e.g., RGB values for water-like pixels)
# Example: If image is RGB, filter pixels with low green channel (common for water)
# image = ...  # Your RGB or multispectral image
# valid_background = [coord for coord in background_coords if image[coord[0], coord[1], 1] < threshold]


In [14]:
def batch_process_marida_masks(masks, dataset_ids, device='cpu'):
    """
    Process masks for dataset_id == 0 (MARIDA) at the batch level.
    - Set classes [1, 2, 3, 4, 9] to 2 (debris).
    - Set class 0 to 0 (unlabeled), other classes to 1 (non-debris).
    
    Args:
        masks: Tensor [batch_size, H, W] (integer-valued masks)
        dataset_ids: Tensor [batch_size] (dataset IDs)
        device: Device for PyTorch operations ('cpu' or 'cuda')
    
    Returns:
        marida_masks: Tensor [batch_size, H, W] with values 0, 1, 2
    """
    batch_size, H, W = masks.shape
    marida_masks = torch.zeros_like(masks, dtype=torch.int64, device=device)
    
    # Identify masks with dataset_id == 0
    marida_mask = (dataset_ids == 0)  # [batch_size], boolean
    if not marida_mask.any():
        return marida_masks
    
    # Select masks for dataset_id == 0
    selected_masks = masks[marida_mask]  # [num_marida, H, W]
    
    # Set classes [1, 2, 3, 4, 9] to 2
    debris_classes = torch.tensor([1, 2, 3, 4, 9], device=device)
    is_debris = torch.isin(selected_masks, debris_classes)  # [num_marida, H, W]
    marida_masks[marida_mask] = torch.where(
        is_debris,
        torch.tensor(2, dtype=torch.int64, device=device),
        selected_masks  # Temporarily keep original values
    )
    # for idx in range( marida_masks[marida_mask].shape[0]):
    #     print(f' {idx} has {torch.sum(is_debris[idx])} : {torch.unique(marida_masks[marida_mask][idx])}')
    # Set non-zero, non-debris pixels to 1
    marida_masks[marida_mask] = torch.where(
        (marida_masks[marida_mask] != 0) & (marida_masks[marida_mask] != 2),
        torch.tensor(1, dtype=torch.int64, device=device),
        marida_masks[marida_mask]
    )
    # print('only 3 values : ')
    # for idx in range( marida_masks[marida_mask].shape[0]):
    #     print(f' {idx} has {torch.sum(is_debris[idx])} : {torch.unique(marida_masks[marida_mask][idx])}')
    marida_masks[marida_mask] = marida_masks[marida_mask] - 1
    #print('after subtr')
    # for idx in range( marida_masks[marida_mask].shape[0]):
    #     print(f' {idx} has {torch.sum(is_debris[idx])} : {torch.unique(marida_masks[marida_mask][idx])}')
    return marida_masks



# # Custom collate function
# def custom_collate_fn(batch):
#     images, masks, dataset_ids = zip(*batch)
#     images = torch.stack(images)
#     masks = torch.stack(masks)
#     dataset_ids = torch.tensor(dataset_ids, dtype=torch.long)
    
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     images, masks, dataset_ids = images.to(device), masks.to(device), dataset_ids.to(device)
    
#     final_masks = batch_select_bg_pixels(images, masks, dataset_ids, r1=5, r2=20, 
#                                          target_ratio=5, threshold=0.5, device=device)
    
#     return images, masks, final_masks, dataset_ids



In [15]:

def torch_dilate(mask, kernel_size, device='cpu'):
    """Apply dilation to a batch of masks using PyTorch convolution."""
    kernel = torch.ones(1, 1, kernel_size, kernel_size, device=device, dtype=torch.float32)
    mask = mask.float().unsqueeze(1)  # [batch_size, 1, H, W]
    dilated = torch.nn.functional.conv2d(mask, kernel, padding=kernel_size // 2) > 0
    return dilated.squeeze(1).bool()  # [batch_size, H, W]

def batch_select_bg_pixels(images, masks, dataset_ids, r1=5, r2=20, target_ratio=5, threshold=None, device='cpu'):
    """
    Compute annular background masks for a batch of masks, only for dataset_id == 1.
    - Set debris pixels (masks == 1) to 2 in bg_masks.
    - Set randomly sampled annular pixels to 1 in bg_masks.
    
    Args:
        images: Tensor [batch_size, C, H, W] 
        masks: Tensor [batch_size, H, W] (binary debris masks)
        dataset_ids: Tensor [batch_size] (dataset IDs)
        r1, r2: Radii for inner and outer dilation
        target_ratio: Debris-to-background pixel ratio
        threshold: Optional threshold for filtering (e.g., green channel)
        device: Device for PyTorch operations ('cpu' or 'cuda')
    
    Returns:
        bg_masks: Tensor [batch_size, H, W] with values 0 (default), 1 (background), 2 (debris)
    """

    batch_size, H, W = masks.shape
    # Initialize bg_masks with zeros (int64 to support values 0, 1, 2)
    bg_masks = torch.zeros_like(masks, dtype=torch.int64, device=device)
    
    # Identify masks to process (dataset_id == 1)
    valid_mask = (dataset_ids == 1)  # [batch_size], boolean{
    #print(f'LR indices {valid_mask}')
    if not valid_mask.any():
        return bg_masks  # Return zeros if no masks need processing
    
    # Select masks for dataset_id == 1
    selected_masks = masks[valid_mask]  # [num_valid, H, W]
    # for idx in range(selected_masks.shape[0]):
    #     print(f'num debris pixels : {torch.sum(selected_masks[idx])}')
    # Set debris pixels to 2 for selected masks
    bg_masks[valid_mask] = selected_masks * 2  # Where selected_masks == 1, set bg_masks to 2
    
    # Perform dilation on selected masks
    dilated_r1 = torch_dilate(selected_masks, 2 * r1 + 1, device=device)  # [num_valid, H, W]
    dilated_r2 = torch_dilate(selected_masks, 2 * r2 + 1, device=device)  # [num_valid, H, W]
    annular_masks = dilated_r2 & ~dilated_r1  # [num_valid, H, W]
    
    # Sample background pixels for each selected mask
    for idx in range(annular_masks.shape[0]):
        valid_coords = torch.where(annular_masks[idx])  # Tuple of (row, col) indices
        #print(f'unique values in mask {idx} : {torch.unique(selected_masks[idx])}')
        num_debris = torch.sum(selected_masks[idx] > 0).item()
        #print(f'num debris for index {idx} : {num_debris}')
        num_background = min(len(valid_coords[0]), int(num_debris * target_ratio))
        
        if num_background > 0:
            # Randomly sample indices and set to 1
            sample_indices = torch.randperm(len(valid_coords[0]), device=device)[:num_background]
            bg_masks[valid_mask.nonzero(as_tuple=True)[0][idx], 
                     valid_coords[0][sample_indices], 
                     valid_coords[1][sample_indices]] = 1
        else :
            print(f'no background selected for index {idx}. Num debrid : {num_debris} Num background : {num_background}')
            print(f'valid coords {len(valid_coords)}')
            print(f'unique valus : {torch.unique(selected_masks[idx])}')
    
    # # Optional: Filter by image features (e.g., green channel) for dataset_id == 1
    # if threshold is not None and images is not None:
    #     valid_pixels = images[valid_mask, 1, :, :] < threshold  # Green channel
    #     # Only apply filtering to background pixels (value 1), preserve debris pixels (value 2)
    #     bg_masks[valid_mask] = torch.where(
    #         bg_masks[valid_mask] == 1,
    #         bg_masks[valid_mask] & valid_pixels,
    #         bg_masks[valid_mask]
    #     )
    bg_masks[valid_mask] = bg_masks[valid_mask] - 1
    return bg_masks

# Custom collate function
def custom_collate_fn(batch):
    # print(f'custom collate function batch {len(batch)}')
    # print(f'custom collate function batch type {type(batch)}')
    # print(f'custom collate function batch[1] type {type(batch[1])}')
    # print(f'custom collate function batch[1] len  {len(batch[1])}')
    images, masks, dataset_ids = zip(*batch)
    images = torch.stack(images)  # [batch_size, C, H, W]
    masks = torch.stack(masks)    # [batch_size, H, W]
    dataset_ids = torch.tensor(dataset_ids, dtype=torch.long)  # [batch_size]
    
    # Move to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images, masks, dataset_ids = images.to(device), masks.to(device), dataset_ids.to(device)
    
    # Compute background masks
    lr_masks = batch_select_bg_pixels(images, masks, dataset_ids, r1=5, r2=20, 
                                      target_ratio=40, device=device)
    marida_masks = batch_process_marida_masks(masks, dataset_ids, device=device)
    masks = lr_masks + marida_masks
    
    return images, masks, dataset_ids




In [16]:
# Seeding for reproducibility
seed = 42
set_seed(seed)

In [17]:
%%capture
# Download some pre-computed data 


file_id = "1NvgyeN-k-pRXF114IFhB-AcW86M9e7km"
gdown.download(f'https://drive.google.com/uc?id={file_id}', '/kaggle/working/marida_stats.npz', quiet=False)
file_id = "160mw3xELYG_44yemtf2vrBg6n-Q_e8LO"
gdown.download(f'https://drive.google.com/uc?id={file_id}', '/kaggle/working/marida_df_invalid_info.csv', quiet=False)
file_id = "1a-4sJZ4NUZsNHuaSBRKoAG7Dy2O4kHWj"
gdown.download(f'https://drive.google.com/uc?id={file_id}', '/kaggle/working/litter_rows_df_invalid_info.csv', quiet=False)
file_id = "1sEiP73c4I3S4KK-58c6J02wgmayu3EYa"
gdown.download(f'https://drive.google.com/uc?id={file_id}', '/kaggle/working/litter_rows_stats.npz', quiet=False)
file_id = "1wrD41CDQud69AMOyHigw0-DR85Id4zDM"
gdown.download(f'https://drive.google.com/uc?id={file_id}', '/kaggle/working/global_stats.npz', quiet=False)

In [18]:
# check that the 
! ls /kaggle/input/litter-windrows-patches
# add the lr dataset to path to import code to prepare the dataset
sys.path.append('/kaggle/input/litter-windrows-patches')
# import functions to prepare dataset
from prepare_dataset import  get_image_and_mask_paths, split_and_save_data

annotations  prepare_dataset.ipynb  README.md
patches      prepare_dataset.py     splits


In [19]:
#! git clone https://github.com/sheikhazhanmohammed/SADMA.git

In [20]:
#sys.path.append('/kaggle/working/SADMA')

In [21]:
# define a variable for the lr dataset
LW_path = '/kaggle/input/litter-windrows-patches'
lr_images, lr_masks = get_image_and_mask_paths(LW_path)
! mkdir ./LR_splits
split_and_save_data(lr_images, lr_masks, './LR_splits' )

In [22]:
! ls ./LR_splits/splits
LR_splits_path = '/kaggle/working/LR_splits/splits'

test_masks.txt	train_masks.txt  val_masks.txt
test_X.txt	train_X.txt	 val_X.txt


In [23]:
# from IPython.display import display

# with open(LR_splits_path+'/train_X.txt', "r") as file:
#     display(file.read())


In [24]:
! ls /kaggle/input/marida-marine-debrish-dataset
MARIDA_path = '/kaggle/input/marida-marine-debrish-dataset'

labels_mapping.txt  patches  shapefiles  splits


In [25]:
lr_df = create_LR_dataframe(LR_splits_path)
pd.set_option("display.max_colwidth", None)
lr_df.head()

Unnamed: 0,image,mask
0,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20161126T094332_R036_T33SXC/S2A_MSIL1C_20161126T094332_R036_T33SXC_646080_4256500.tif,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20161126T094332_R036_T33SXC/S2A_MSIL1C_20161126T094332_R036_T33SXC_646080_4256500_cl.tif
1,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20180615T100031_R122_T33TUL/S2A_MSIL1C_20180615T100031_R122_T33TUL_353760_5056480.tif,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20180615T100031_R122_T33TUL/S2A_MSIL1C_20180615T100031_R122_T33TUL_353760_5056480_cl.tif
2,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20160814T100032_R122_T33TUL/S2A_MSIL1C_20160814T100032_R122_T33TUL_374240_5056480.tif,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20160814T100032_R122_T33TUL/S2A_MSIL1C_20160814T100032_R122_T33TUL_374240_5056480_cl.tif
3,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20210331T100021_R122_T33TUL/S2A_MSIL1C_20210331T100021_R122_T33TUL_379360_5020640.tif,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20210331T100021_R122_T33TUL/S2A_MSIL1C_20210331T100021_R122_T33TUL_379360_5020640_cl.tif
4,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20150830T100016_R122_T33TUL/S2A_MSIL1C_20150830T100016_R122_T33TUL_358880_5066720.tif,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20150830T100016_R122_T33TUL/S2A_MSIL1C_20150830T100016_R122_T33TUL_358880_5066720_cl.tif


In [26]:
marida_df = create_marida_df(MARIDA_path)
marida_df.head()

Unnamed: 0,image,mask,confidence,date,tile
0,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_0.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_0_cl.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_0_conf.tif,1-12-19,48MYU
1,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_1.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_1_cl.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_1_conf.tif,1-12-19,48MYU
2,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_2.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_2_cl.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_2_conf.tif,1-12-19,48MYU
3,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_3.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_3_cl.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_3_conf.tif,1-12-19,48MYU
4,/kaggle/input/marida-marine-debrish-dataset/patches/S2_11-1-19_19QDA/S2_11-1-19_19QDA_0.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_11-1-19_19QDA/S2_11-1-19_19QDA_0_cl.tif,/kaggle/input/marida-marine-debrish-dataset/patches/S2_11-1-19_19QDA/S2_11-1-19_19QDA_0_conf.tif,11-1-19,19QDA


In [27]:
print(f'litter rows length {len(lr_df)}')
print(f'marida length {len(marida_df)}')

litter rows length 1221
marida length 694


In [28]:
marida_val_df = create_marida_df(MARIDA_path, 'val')
lr_val_df = create_LR_dataframe(LR_splits_path, 'val')


In [29]:
lr_val_df_invalid = compute_invalid_pixels(lr_val_df['image'].tolist(), lr_val_df['mask'].tolist())

In [30]:
lr_test_df = create_LR_dataframe(LR_splits_path, 'test')
lr_test_df_invalid = compute_invalid_pixels(lr_test_df['image'].tolist(), lr_test_df['mask'].tolist())

In [31]:
#marida_df_invalid = compute_invalid_pixels(marida_df['image'].tolist(), marida_df['mask'].tolist())
#marida_df_invalid.to_csv('/kaggle/working/marida_with_invalid.csv')
marida_df_invalid = pd.read_csv('/kaggle/working/marida_df_invalid_info.csv')
marida_df_invalid.head()

Unnamed: 0.1,Unnamed: 0,image,no data pixels,negative pixels,nan pixels,high value pixels,debris pixels,min values,max values
0,0,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_0.tif,0,0,0,0,529,0.014372,0.271291
1,1,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_1.tif,0,0,0,0,34,0.013171,0.097217
2,2,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_2.tif,0,0,0,0,33,0.014072,0.102946
3,3,/kaggle/input/marida-marine-debrish-dataset/patches/S2_1-12-19_48MYU/S2_1-12-19_48MYU_3.tif,0,0,0,0,408,0.012871,0.177958
4,4,/kaggle/input/marida-marine-debrish-dataset/patches/S2_11-1-19_19QDA/S2_11-1-19_19QDA_0.tif,0,0,0,0,20,0.000247,0.102157


In [32]:
##lr_df_invalid = compute_invalid_pixels(lr_df['image'].tolist(), lr_df['mask'].tolist())
#lr_df_invalid.to_csv('/kaggle/working/lr_with_invalid.csv')
lr_df_invalid = pd.read_csv('/kaggle/working/litter_rows_df_invalid_info.csv')
lr_df_invalid.head()

Unnamed: 0.1,Unnamed: 0,image,no data pixels,negative pixels,nan pixels,high value pixels,debris pixels,min values,max values
0,0,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20161126T094332_R036_T33SXC/S2A_MSIL1C_20161126T094332_R036_T33SXC_646080_4256500.tif,0,0,0,0,30,0.000256,0.061144
1,1,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20180615T100031_R122_T33TUL/S2A_MSIL1C_20180615T100031_R122_T33TUL_353760_5056480.tif,0,22493,0,0,102,-0.001337,0.087866
2,2,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20160814T100032_R122_T33TUL/S2A_MSIL1C_20160814T100032_R122_T33TUL_374240_5056480.tif,0,3533,0,0,45,-0.001495,0.36959
3,3,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20210331T100021_R122_T33TUL/S2A_MSIL1C_20210331T100021_R122_T33TUL_379360_5020640.tif,0,382,0,0,27,-0.000738,0.042183
4,4,/kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20150830T100016_R122_T33TUL/S2A_MSIL1C_20150830T100016_R122_T33TUL_358880_5066720.tif,0,26933,0,0,235,-0.001922,0.487958


In [33]:
lr_blacklist = lr_df_invalid[lr_df_invalid['high value pixels']>0]
#lr_blacklist.head()
lr_df_filt = lr_df.drop(lr_blacklist.index)

In [34]:
# lr valid = 79495168

In [35]:
#lr_stats = compute_stats(lr_df_filt['image'].tolist())
#np.savez("/kaggle/working/lr_stats.npz", first=lr_stats['mean'], second=lr_stats['std'])
#marida_stats = compute_stats(marida_df['image'].tolist())
#np.savez("/kaggle/working/my_marida_stats.npz", first=marida_stats['mean'], second=marida_stats['std'])
#global_stats = compute_stats(marida_df['image'].tolist() + lr_df_filt['image'].to_list())
#np.savez("/kaggle/working/global_stats.npz", first=global_stats['mean'], second=global_stats['std'])

In [36]:
global_stats = np.load('/kaggle/working/global_stats.npz')
global_bands_mean = global_stats['first']
global_bands_std = global_stats['second']

In [37]:
# global_bands_mean =np.array([0.03721786, 0.03547978, 0.03033651, 0.01722546, 0.01574046,
#         0.01738895, 0.01939084, 0.01724032, 0.01895351, 0.0109694 ,
#         0.00784716])
# global_bands_std = np.array([0.03185222, 0.03198375, 0.03251331, 0.03379553, 0.03407218,
#         0.04551132, 0.05334419, 0.05064404, 0.0578197 , 0.03721222,
#         0.02560836])

In [38]:
#computing_labeled_pixels_stats(lr_df_filt['mask'].tolist())
#computing_labeled_pixels_stats(marida_df['mask'].tolist())

In [39]:
marida_classes_distr = np.array([0.00452, 0.00203, 0.00254, 0.00168, 0.00766, 0.15206, 0.20232,
 0.35941, 0.00109, 0.20218, 0.03226, 0.00693, 0.01322, 0.01158, 0.00052])
lr_debris_pixels = 92090
marida_pixels = 429412
marida_debris_pixels = np.sum(marida_classes_distr[[0,1,2,3,8]]) * marida_pixels
print(f'marida debris pixels {marida_debris_pixels}')
tot_glob_pixels = (len(lr_df_filt) + len(marida_df))*256**2
marida_debris_fraction = np.sum(marida_classes_distr[[0,1,2,3,8]])
#debris_fraction = (lr_debris_pixels + marida_debris_pixels)/tot_glob_pixels
print(f'marida_debris_fraction : {marida_debris_fraction}')

marida debris pixels 5092.826320000001
marida_debris_fraction : 0.011860000000000002


In [40]:
# Computing here the percentage of debris pixels across the two datasets
# This will be used as class distribution to generate weights for the loss function
LR_ratio = 40 # for 1 debrix pixel, choose 90 bk pixels

# For MARIDA the loss function uses only pixels in the 15 classes 
# The fraction of classes assimilated to marine debris is 
marida_debrix_pixels_distr = np.sum(marida_classes_distr[[0,1,2,3,8]])
# For LR the DataSet will sample backgroung pixels with a given ratio, stored in the variable LR_ratio
# Then the effective ratio 
effective_ratio = (1/40 * len(lr_df_filt) + 0.011860000000000002 * len(marida_df))/(len(lr_df_filt) + len(marida_df))
#print(f'effective global ratio {effective_ratio}')
class_distribution = np.array([1 - effective_ratio, effective_ratio])
print(f'class distribution {class_distribution}')

class distribution [0.97978194 0.02021806]


In [41]:
# MARIDA statistics

class_distr = np.array([0.00452, 0.00203, 0.00254, 0.00168, 0.00766, 0.15206, 0.20232,
 0.35941, 0.00109, 0.20218, 0.03226, 0.00693, 0.01322, 0.01158, 0.00052])

bands_mean = np.array([0.05197577, 0.04783991, 0.04056812, 0.03163572, 0.02972606, 0.03457443,
 0.03875053, 0.03436435, 0.0392113,  0.02358126, 0.01588816]).astype(np.float32)

bands_std = np.array([0.04725893, 0.04743808, 0.04699043, 0.04967381, 0.04946782, 0.06458357,
 0.07594915, 0.07120246, 0.08251058, 0.05111466, 0.03524419]).astype(np.float32)

In [42]:
# Other code references  
# https://github.com/MarcCoru/marinedebrisdetector

In [43]:
# MARIDA CLASSES
# {
#  1: "Marine Debris",
#  2: "Dense Sargassum", 
#  3: "Sparse Sargassum", 
#  4: "Natural Organic Material", 
#  5: "Ship", 
#  6: "Clouds", 
#  7: "Marine Water", 
#  8: "Sediment-Laden Water", 
#  9: "Foam", 
#  10: "Turbid Water", 
#  11: "Shallow Water", 
#  12: "Waves", 
#  13: "Cloud Shadows", 
#  14: "Wakes", 
#  15: "Mixed Water"
# }


# From marinedebrisdetector 
# DEBRIS_CLASSES = [1,2,3,4,9]

In [44]:
# https://drive.google.com/drive/folders/1rntiw5BvOs80eIbpOu7dk9g1BfOVw61-?usp=drive_link

In [45]:

class RandomRotationTransform:
    """Rotate by one of the given angles."""

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return vF.rotate(x, angle)
    
def gen_weights(class_distribution, c = 1.02):
    return 1/torch.log(c + class_distribution)
    
transformTrain = transforms.Compose([transforms.ToTensor(),
                                    RandomRotationTransform([-90, 0, 90, 180]),
                                    transforms.RandomHorizontalFlip()])
    
transformTest = transforms.Compose([transforms.ToTensor()])
    
standardization = transforms.Normalize(global_bands_mean, global_bands_std) 

In [46]:
def gen_weights(class_distribution, c = 1.02):
    return 1/torch.log(c + class_distribution)

In [47]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision import transforms
# model import UNet, AttentionUNet, ResidualAttentionUNet  # From original script
#f#rom dataloader import bands_mean, bands_std, RandomRotationTransform, class_distr, gen_weights
#from metrics import Evaluation
#from customLosses import FocalLoss
import pandas as pd
from torch.utils.data import Dataset

class MergedSegmentationDataset(Dataset):
    """
    df_dataset1 : MARIDA dataset
    df_dataset2 : LR dataset
    """
    def __init__(self, df_dataset1, df_dataset2, bands_mean, bands_std, transform=None, standardization=None):
        """
        df_dataset1 : MARIDA
        df_dataset2 : Litter Windrows
        """
        self.bands_mean = bands_mean
        self.bands_std = bands_std
        self.transform = transform
        self.standardization = standardization
        self.image_paths = []
        self.mask_paths = []
        self.dataset_ids = []
        self.image_paths = df_dataset1['image'].tolist() + df_dataset2['image'].tolist() 
        self.mask_paths =  df_dataset1['mask'].tolist() + df_dataset2['mask'].tolist() 
        self.dataset_ids = [0] * len(df_dataset1['image']) + [1] * len(df_dataset2['image'])
        # Generate shuffled indices
        indices = np.random.permutation(len(self.image_paths))
        self.image_paths = np.array(self.image_paths)[indices]
        self.mask_paths = np.array(self.mask_paths)[indices]
        self.dataset_ids = np.array(self.dataset_ids)[indices]        
        #print(self.dataset_ids)
        if self.transform is None:
            self.transform = transforms.Compose([transforms.ToTensor()])
        ## preloading images in memory 

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        #print(f'idx {idx}')
        # Load Classsification Mask
        dataset_id = self.dataset_ids[idx]
        # Open t#he GeoTIFF image file
        #print(f'image path {self.image_paths[idx]}')
        #print(f'mask path {self.mask_paths[idx]}')
        with rasterio.open(self.image_paths[idx]) as src:
            #print(f#"Number of bands: {dataset.count}")  # Check the number of bands
            # Read all bands as a NumPy array
            image = src.read()
            #print(f'image shape {image.shape}')
            invalid_mask = get_invalid_mask(image, src.nodata)
            #print(bands.shape)  # Shape will be (bands, height, width)
            #print(f'invalid mask shape {invalid_mask.shape}')
            with rasterio.open(self.mask_paths[idx]) as src_mask:
                mask = src_mask.read().astype(int)
            # if dataset_id == 0: #MARIDA
            #     #print(f'sample from marida')
            #     temp = mask.copy()
            #     #assimilate several classes to marine debris
            #     temp[temp==1]=2
            #     temp[temp==2]=2          
            #     temp[temp==3]=2          
            #     temp[temp==4]=2          
            #     temp[temp==9]=2          
            #     # Leaving unlabeled pixels to 0 and pixels in classes not in [1,2,3, 4,9] to 1
            #     temp[(temp != 0) & (temp != 2)] = 1
                
            #     # Categories from 1 to 0
            #     mask = np.copy(temp)
            # else : #LR
            #     #print('sample from litter rows')
            #     bg_mask = select_bg_pixels(image, mask[0], target_ratio=40)
            #     #print(f'bg mask shape {bg_mask.shape}')
            #     mask[mask==1] = 2
            #     mask[bg_mask[None,...].astype(bool)] = 1
            #print(f'mask before inputing {mask.shape}')
            debris_before_invalid = np.sum(mask)
            invalid_pixels = np.sum(np.any(invalid_mask, axis=0))
            mask[np.any(invalid_mask.astype(bool), axis=0, keepdims=True)] = 0 #I guess it makes sense not to feed invalid pixels to the loss function
            #print(f'before inputing 2')
            image[invalid_mask.astype(bool)] = np.tile(self.bands_mean[:, np.newaxis, np.newaxis], (1, 256, 256))[invalid_mask.astype(bool)]
            #print(f'after inputing')
            ## Since the model sees unvalid pixels anyway, it's better (?) to replace those with mean values ? 
            #print(f'mask type before transh {type(mask)} - {mask.dtype}')
            #print(f'image type before transh {type(image)} - {image.dtype}')
            #############
            debris_after_invalid = np.sum(mask)
            #############
            if self.transform is not None:
                # applying the same rotation on the image-mask pair
                #print(f'transform - image shape {image.shape}')
                #print(f'transform - mask shape {mask.shape}')
                stack = np.concatenate([image, mask], axis=0).astype(np.float32) 
                stack = np.transpose(stack,(1, 2, 0)) #to channel last
                #print(f'stack shape before transfrom {stack.shape}')
                stack = self.transform(stack) #expects channel last, returns channel first
               
                #print(f'stack shape after transfrom {stack.shape}')
                image = stack[:-1,:,:]
                mask = stack[-1,:,:].long()
                #print(f'image type {image.dtype}')
                #print(f'image shape after transform {image.shape}')
                #print(f'mask shape after transform {mask.shape}')

                   
            
            if self.standardization is not None:
                image = self.standardization(image)
                
            #mask = mask - 1 Moved to collate function
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask).to(torch.long)
            else:
                mask = mask.to(torch.long)
            if isinstance(image, np.ndarray):
                image = torch.from_numpy(image).to(torch.float32)
            else:
                im = image.to(torch.float32)
            if torch.sum(mask) == 0 :
                print(f'{self.mask_paths[idx]} has no debris pixels')
                print(f'debris pixels before invalid mask : {debris_before_invalid}')
                print(f'debris pixels after invalid mask : {debris_after_invalid}')
                print(f'invalid pixels : {invalid_pixels}')
           
        ## Add logic for transform

            return image, mask, dataset_id

In [48]:
def conv3x3(in_channels, out_channels, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.mlp = nn.Sequential(nn.Conv2d(channels, channels // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(channels // 16, channels, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class ResidualBlock(nn.Module):
    def __init__(self, inputChannel, outputChannel, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(inputChannel, outputChannel, stride)
        self.bn1 = nn.BatchNorm2d(outputChannel)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(outputChannel, outputChannel)
        self.bn2 = nn.BatchNorm2d(outputChannel)
        self.downsample = downsample
        self.ca = ChannelAttention(outputChannel)
        self.sa = SpatialAttention()
        
    # def forward(self, x):
    #     residual = x
    #     out = self.conv1(x)
    #     out = self.bn1(out)
    #     out = self.relu(out)
    #     out = self.conv2(out)
    #     out = self.bn2(out)
    #     if self.downsample:
    #         residual = self.downsample(x)
    #     out += residual
    #     out = self.relu(out)
    #     caOutput = self.ca(out)
    #     out = caOutput * out
    #     saOutput = self.sa(out)
    #     out = saOutput * out
    #     return out, saOutput

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.bn2(out)
        out = self.relu(out)
        caOutput = self.ca(out)
        out = caOutput * out
        saOutput = self.sa(out)
        out = saOutput * out
        return out, saOutput


class DownSampleWithAttention(nn.Module):
    def __init__(self, inputChannel, outputChannel):
        super().__init__()
        self.convolution = nn.Sequential(
            nn.Conv2d(inputChannel, outputChannel, kernel_size=3, padding=1),
            nn.BatchNorm2d(outputChannel),
            nn.LeakyReLU(0.2),
            nn.Conv2d(outputChannel, outputChannel, kernel_size=3, padding=1),
            nn.BatchNorm2d(outputChannel),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2)
        )
        self.ca = ChannelAttention(outputChannel)
        self.sa = SpatialAttention()
    
    def forward(self,x):
        x = self.convolution(x)
        caOutput = self.ca(x)
        x = caOutput * x
        saOutput = self.sa(x)
        x = saOutput * x
        return x, saOutput

    
class UpSampleWithAttention(nn.Module):
    def __init__(self, inputChannel, outputChannel):
        super().__init__()
        self.convolution = nn.Sequential(
            nn.Conv2d(inputChannel, outputChannel, kernel_size=3, padding=1),
            nn.BatchNorm2d(outputChannel),
            nn.LeakyReLU(0.2),
            nn.Conv2d(outputChannel, outputChannel, kernel_size=3, padding=1),
            nn.BatchNorm2d(outputChannel),
            nn.LeakyReLU(0.2)
        )
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.ca = ChannelAttention(outputChannel)
        self.sa = SpatialAttention()
    
    def forward(self, x):
        x = self.upsample(x)
        x = self.convolution(x)
        caOutput = self.ca(x)
        x = caOutput * x
        saOutput = self.sa(x)
        x = saOutput * x
        return x, saOutput

class ResidualAttentionUNet(nn.Module):
  def __init__(self, inputChannel, outputChannel):
    super().__init__()
    self.downsample1 = DownSampleWithAttention(inputChannel, 32)
    self.downsample2 = DownSampleWithAttention(32, 64)
    self.downsample3 = DownSampleWithAttention(64, 128)
    self.downsample4 = DownSampleWithAttention(128, 256)
    self.downsample5 = DownSampleWithAttention(256, 512)

    self.residualBlock1 = ResidualBlock(512, 512)
    self.residualBlock2 = ResidualBlock(512, 512)
    self.residualBlock3 = ResidualBlock(512, 512)

    self.upsample1 = UpSampleWithAttention(512, 256)
    self.upsample2 = UpSampleWithAttention(512, 128)
    self.upsample3 = UpSampleWithAttention(256, 64)
    self.upsample4 = UpSampleWithAttention(128, 32)
    self.upsample5 = UpSampleWithAttention(64, 32)
    self.classification = nn.Sequential(
            nn.Conv2d(32, outputChannel, kernel_size=1),
        )

  def forward(self, x):
    scale128, sa128down = self.downsample1(x)
    scale64, sa64down = self.downsample2(scale128)
    scale32, sa32down = self.downsample3(scale64)
    scale16, sa64down = self.downsample4(scale32)
    scale8, sa8down = self.downsample5(scale16)
    scale8, sa8down = self.residualBlock1(scale8)
    scale8, sa8down = self.residualBlock2(scale8)
    scale8, sa8down = self.residualBlock3(scale8)
    upscale16, sa16up = self.upsample1(scale8)
    upscale16 = torch.cat([upscale16, scale16], dim=1)
    upscale32, sa32up = self.upsample2(upscale16)
    upscale32 = torch.cat([upscale32, scale32], dim=1)
    upscale64, sa64up = self.upsample3(upscale32)
    upscale64 = torch.cat([upscale64, scale64], dim=1)
    upscale128, sa128up = self.upsample4(upscale64)
    upscale128 = torch.cat([upscale128, scale128], dim=1)
    upscale256, sa256up = self.upsample5(upscale128)
    finaloutput = self.classification(upscale256)
    return finaloutput

In [49]:
def Evaluation(y_predicted, y_true):

    micro_prec = precision_score(y_true, y_predicted, average='micro')
    macro_prec = precision_score(y_true, y_predicted, average='macro')
    weight_prec = precision_score(y_true, y_predicted, average='weighted')
    
    micro_rec = recall_score(y_true, y_predicted, average='micro')
    macro_rec = recall_score(y_true, y_predicted, average='macro')
    weight_rec = recall_score(y_true, y_predicted, average='weighted')
        
    macro_f1 = f1_score(y_true, y_predicted, average="macro")
    micro_f1 = f1_score(y_true, y_predicted, average="micro")
    weight_f1 = f1_score(y_true, y_predicted, average="weighted")
        
    subset_acc = accuracy_score(y_true, y_predicted)
    
    iou_acc = jaccard_score(y_true, y_predicted, average='macro')

    info = {
            "macroPrec" : macro_prec,
            "microPrec" : micro_prec,
            "weightPrec" : weight_prec,
            "macroRec" : macro_rec,
            "microRec" : micro_rec,
            "weightRec" : weight_rec,
            "macroF1" : macro_f1,
            "microF1" : micro_f1,
            "weightF1" : weight_f1,
            "subsetAcc" : subset_acc,
            "IoU": iou_acc
            }
    
    return info

In [50]:
transformTrain = transforms.Compose([transforms.ToTensor(),
                                    RandomRotationTransform([-90, 0, 90, 180]),
                                    transforms.RandomHorizontalFlip()])
    
transformTest = transforms.Compose([transforms.ToTensor()])
    
standardization = transforms.Normalize(global_bands_mean.tolist(), global_bands_std.tolist())
merged_ds = MergedSegmentationDataset(marida_df, lr_df_filt, global_bands_mean, global_bands_std, transform=transformTrain, standardization= standardization)
val_ds = MergedSegmentationDataset(marida_val_df, lr_val_df, global_bands_mean, global_bands_std, transform=transformTest, standardization= standardization )


In [51]:
trainLoader = DataLoader(merged_ds,
                        batch_size=batch_size, 
                        shuffle=True,  #num_workers=4, 
                        #pin_memory=True,
                        #prefetch_factor=2,
                        collate_fn=custom_collate_fn,
                        #worker_init_fn=worker_init_fn,
                        #generator=torch.Generator().manual_seed(seed) 
                        )


testLoader = DataLoader(val_ds, 
                        batch_size=batch_size, 
                        shuffle=False,
                        collate_fn=custom_collate_fn,
                        #worker_init_fn=worker_init_fn,
                        #generator=torch.Generator().manual_seed(seed) 
                        )
                        
    

In [52]:
model = ResidualAttentionUNet(11, 2).to(device)
weight = gen_weights(torch.from_numpy(class_distribution), c = 1.03).to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean', weight=weight.to(torch.float32))
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4, weight_decay=1e-2)

# assuming about 40 reductions => .9 ** 40 = 1e-2, starting from 8e-4 ending with 8e-6
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5)

In [53]:
output_classes = 2
metrics_history = []
epochs = 30
for epoch in range(1, epochs+1):
    model.train()
    pb = tqdm(trainLoader, desc=f"epoch {epoch}/{epochs}: ")
    yTrue = []
    yPredicted = []

    bg_yTrue = []
    bg_yPredicted = []
    for image, target, _ in pb:
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()

        logits = model(image)
        # print(f'logits shape : {logits.shape}')
        # print(f'target shape : {target.shape}')
        # print(f'image dtype {image.dtype}')
        # print(f'logits dtype {logits.dtype}')
        # print(f'target dtype {target.dtype}')
        loss = criterion(logits, target)

        loss.backward()
        optimizer.step()
        pb.set_postfix(loss=loss.item())

        if epoch % 10 == 0:
            with torch.no_grad():
                logits = logits.detach()
                logits = torch.movedim(logits, (0,1,2,3), (0,3,1,2))
                logits = logits.reshape((-1,output_classes))
                target = target.reshape(-1)
                ###################################################################################
                mask = target != -1
                ###################################################################################
                
                # bg_logits = logits[~mask]
                # bg_target = target[~mask]
    
                # only considering annotated pixels
                logits = logits[mask]
                target = target[mask]
    
                probs = F.softmax(logits, dim=1).cpu().numpy()
                target = target.cpu().numpy()
                yPredicted += probs.argmax(1).tolist()
                yTrue += target.tolist()
        
                
                # bg_probs = torch.nn.functional.softmax(bg_logits, dim=1).cpu().numpy()
                # bg_target = bg_target.cpu().numpy()
                
                # bg_yPredicted += bg_probs.argmax(1).tolist()
                # bg_yTrue += bg_target.tolist()


    if epoch % 10 == 0:
        yPredicted = np.asarray(yPredicted)
        yTrue = np.asarray(yTrue)
        acc = Evaluation(yPredicted, yTrue)
        print(acc)
    
        # bg_yPredicted = np.asarray(bg_yPredicted)
        # bg_yTrue = np.asarray(bg_yTrue)
        # bg_acc = Evaluation(bg_yPredicted, bg_yTrue)
        # print("background:", bg_acc)


    model.eval()
    yTrue = []
    yPredicted = []
    testLossF = []
    
    # bg_yTrue = []
    # bg_yPredicted = []
    with torch.no_grad():
        for image, target, _ in testLoader:

            image, target = image.to(device), target.to(device)
            logits = model(image)
            # print(f'image dtype {image.dtype}')
            # print(f'logits dtype {logits.dtype}')
            # print(f'target dtype {target.dtype}')
            # print(f'test - target shape {target.shape}')
            # print(f'test - logit shape {logits.shape}')
            loss = criterion(logits, target)

            logits = torch.movedim(logits, (0,1,2,3), (0,3,1,2))
            logits = logits.reshape((-1,output_classes))
            target = target.reshape(-1)
            ###################################################################################
            mask = target != -1
            ###################################################################################
            
            # bg_logits = logits[~mask]
            # bg_target = target[~mask]
            
            logits = logits[mask]
            target = target[mask]
            

            probs = F.softmax(logits, dim=1).cpu().numpy()
            target = target.cpu().numpy()
            # testBatches += target.shape[0]
            testLossF.append((loss.data*target.shape[0]).tolist())
            yPredicted += probs.argmax(1).tolist()
            yTrue += target.tolist()


            # bg_probs = torch.nn.functional.softmax(bg_logits, dim=1).cpu().numpy()
            # bg_target = bg_target.cpu().numpy()

            # bg_yPredicted += bg_probs.argmax(1).tolist()
            # bg_yTrue += bg_target.tolist()
        
        yPredicted = np.asarray(yPredicted)
        yTrue = np.asarray(yTrue)
        print('########### Validation Set Evaluation : #############')
        acc = Evaluation(yPredicted, yTrue)
        metrics_history.append(acc)

        # bg_yPredicted = np.asarray(bg_yPredicted)
        # bg_yTrue = np.asarray(bg_yTrue)
        # bg_acc = Evaluation(bg_yPredicted, bg_yTrue)
        print(acc)
        # print("background:", bg_acc)
    scheduler.step(sum(testLossF) / len(testLoader.dataset))

  invalid_mask |= image > 1
epoch 1/30: 100%|██████████| 120/120 [04:11<00:00,  2.10s/it, loss=0.486]


########### Validation Set Evaluation : #############
{'macroPrec': 0.5880462650937845, 'microPrec': 0.908698085273679, 'weightPrec': 0.9794917433645293, 'macroRec': 0.8992861264924438, 'microRec': 0.908698085273679, 'weightRec': 0.908698085273679, 'macroF1': 0.6244537412354473, 'microF1': 0.908698085273679, 'weightF1': 0.9369564383805971, 'subsetAcc': 0.908698085273679, 'IoU': 0.5408997398567341}


  invalid_mask |= image > 1
epoch 2/30: 100%|██████████| 120/120 [02:36<00:00,  1.31s/it, loss=0.245] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.8519043421647148, 'microPrec': 0.9872139067358982, 'weightPrec': 0.9870861589209086, 'macroRec': 0.844754637131489, 'microRec': 0.9872139067358982, 'weightRec': 0.9872139067358982, 'macroF1': 0.8482907456724691, 'microF1': 0.9872139067358982, 'weightF1': 0.9871482960307061, 'subsetAcc': 0.9872139067358982, 'IoU': 0.7645872008155867}


  invalid_mask |= image > 1
epoch 3/30: 100%|██████████| 120/120 [02:37<00:00,  1.32s/it, loss=0.171] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.6612411592807774, 'microPrec': 0.960189560030885, 'weightPrec': 0.9805717927950888, 'macroRec': 0.8753448571033854, 'microRec': 0.960189560030885, 'weightRec': 0.960189560030885, 'macroF1': 0.7208372643674215, 'microF1': 0.960189560030885, 'weightF1': 0.9680800656642571, 'subsetAcc': 0.960189560030885, 'IoU': 0.6300890759257904}


  invalid_mask |= image > 1
epoch 4/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.0868]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8301358609111384, 'microPrec': 0.9868115521112992, 'weightPrec': 0.9882112512545314, 'macroRec': 0.8926760584382254, 'microRec': 0.9868115521112992, 'weightRec': 0.9868115521112992, 'macroF1': 0.8585351927600802, 'microF1': 0.9868115521112992, 'weightF1': 0.9873819975229571, 'subsetAcc': 0.9868115521112992, 'IoU': 0.7768819566435514}


  invalid_mask |= image > 1
epoch 5/30: 100%|██████████| 120/120 [02:36<00:00,  1.31s/it, loss=0.395] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.7214726966330601, 'microPrec': 0.9733369224121443, 'weightPrec': 0.9859412784651994, 'macroRec': 0.9428920351584213, 'microRec': 0.9733369224121443, 'weightRec': 0.9733369224121443, 'macroF1': 0.7920630398361085, 'microF1': 0.9733369224121443, 'weightF1': 0.9777620538882531, 'subsetAcc': 0.9733369224121443, 'IoU': 0.6996222427301875}


  invalid_mask |= image > 1
epoch 6/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.0703]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8304474079465927, 'microPrec': 0.9870028122321473, 'weightPrec': 0.9885815135307731, 'macroRec': 0.900985975046958, 'microRec': 0.9870028122321473, 'weightRec': 0.9870028122321473, 'macroF1': 0.8621068182507771, 'microF1': 0.9870028122321473, 'weightF1': 0.9876293773438131, 'subsetAcc': 0.9870028122321473, 'IoU': 0.7813278858339494}


  invalid_mask |= image > 1
epoch 7/30: 100%|██████████| 120/120 [02:37<00:00,  1.32s/it, loss=0.16]  


########### Validation Set Evaluation : #############
{'macroPrec': 0.8062619582648308, 'microPrec': 0.9849357861853523, 'weightPrec': 0.9875604951655406, 'macroRec': 0.9008525449932168, 'microRec': 0.9849357861853523, 'weightRec': 0.9849357861853523, 'macroF1': 0.8468175873130594, 'microF1': 0.9849357861853523, 'weightF1': 0.9859428637066713, 'subsetAcc': 0.9849357861853523, 'IoU': 0.762369886393308}


  invalid_mask |= image > 1
epoch 8/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.193] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.8432240808519205, 'microPrec': 0.986862554810192, 'weightPrec': 0.9870410673488468, 'macroRec': 0.8523734218764348, 'microRec': 0.986862554810192, 'weightRec': 0.986862554810192, 'macroF1': 0.8477351174919954, 'microF1': 0.986862554810192, 'weightF1': 0.98694895425445, 'subsetAcc': 0.986862554810192, 'IoU': 0.7638550390242833}


  invalid_mask |= image > 1
epoch 9/30: 100%|██████████| 120/120 [02:37<00:00,  1.32s/it, loss=0.0891]


########### Validation Set Evaluation : #############
{'macroPrec': 0.741203985610704, 'microPrec': 0.9768745262770155, 'weightPrec': 0.9873918827589045, 'macroRec': 0.9579096630522819, 'microRec': 0.9768745262770155, 'weightRec': 0.9768745262770155, 'macroF1': 0.8132194949499056, 'microF1': 0.9768745262770155, 'weightF1': 0.9804465677096167, 'subsetAcc': 0.9768745262770155, 'IoU': 0.7226175899997852}


  invalid_mask |= image > 1
epoch 10/30: 100%|██████████| 120/120 [02:39<00:00,  1.33s/it, loss=0.128] 


{'macroPrec': 0.7683597270969219, 'microPrec': 0.9795061537278222, 'weightPrec': 0.9877115484649577, 'macroRec': 0.9619192300174098, 'microRec': 0.9795061537278222, 'weightRec': 0.9795061537278222, 'macroF1': 0.8373729716366622, 'microF1': 0.9795061537278222, 'weightF1': 0.9822152445771444, 'subsetAcc': 0.9795061537278222, 'IoU': 0.7501708312381097}


  invalid_mask |= image > 1


########### Validation Set Evaluation : #############
{'macroPrec': 0.8423053091116409, 'microPrec': 0.9881305385743329, 'weightPrec': 0.9895120593392289, 'macroRec': 0.9110159129570583, 'microRec': 0.9881305385743329, 'weightRec': 0.9881305385743329, 'macroF1': 0.873339280442998, 'microF1': 0.9881305385743329, 'weightF1': 0.9886717303818092, 'subsetAcc': 0.9881305385743329, 'IoU': 0.7957258883532594}


  invalid_mask |= image > 1
epoch 11/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.57]  


########### Validation Set Evaluation : #############
{'macroPrec': 0.759683118093358, 'microPrec': 0.979750511797916, 'weightPrec': 0.983101487854572, 'macroRec': 0.8427542469868469, 'microRec': 0.979750511797916, 'weightRec': 0.979750511797916, 'macroF1': 0.7950770727716553, 'microF1': 0.979750511797916, 'weightF1': 0.9811459033154277, 'subsetAcc': 0.979750511797916, 'IoU': 0.7042812078728516}


  invalid_mask |= image > 1
epoch 12/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.216] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.825029672825959, 'microPrec': 0.9860875971353484, 'weightPrec': 0.9872708941652028, 'macroRec': 0.8757862202333815, 'microRec': 0.9860875971353484, 'weightRec': 0.9860875971353484, 'macroF1': 0.8484574719839837, 'microF1': 0.9860875971353484, 'weightF1': 0.9865913328256588, 'subsetAcc': 0.9860875971353484, 'IoU': 0.7645546607024434}


  invalid_mask |= image > 1
epoch 13/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.202] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.7588902806018363, 'microPrec': 0.979734927639921, 'weightPrec': 0.9857138809202162, 'macroRec': 0.9036053851818977, 'microRec': 0.979734927639921, 'weightRec': 0.979734927639921, 'macroF1': 0.8142543352227103, 'microF1': 0.979734927639921, 'weightF1': 0.9819450480509633, 'subsetAcc': 0.979734927639921, 'IoU': 0.7244005684423471}


  invalid_mask |= image > 1
epoch 14/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.141] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.8126813452750962, 'microPrec': 0.9861343496093335, 'weightPrec': 0.9895072920974555, 'macroRec': 0.9396294170294288, 'microRec': 0.9861343496093335, 'weightRec': 0.9861343496093335, 'macroF1': 0.8647419513600556, 'microF1': 0.9861343496093335, 'weightF1': 0.9873032739093467, 'subsetAcc': 0.9861343496093335, 'IoU': 0.7844476556952917}


  invalid_mask |= image > 1
epoch 15/30: 100%|██████████| 120/120 [02:39<00:00,  1.33s/it, loss=0.0929]


########### Validation Set Evaluation : #############
{'macroPrec': 0.7603655264621089, 'microPrec': 0.9799701067514822, 'weightPrec': 0.9881345105064392, 'macroRec': 0.9573274368460899, 'microRec': 0.9799701067514822, 'weightRec': 0.9799701067514822, 'macroF1': 0.8297417747835893, 'microF1': 0.9799701067514822, 'weightF1': 0.9827117188077206, 'subsetAcc': 0.9799701067514822, 'IoU': 0.741550668336961}


  invalid_mask |= image > 1
epoch 16/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.0815]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8257375217633102, 'microPrec': 0.9874760039385417, 'weightPrec': 0.9903303601588304, 'macroRec': 0.9460764095016057, 'microRec': 0.9874760039385417, 'weightRec': 0.9874760039385417, 'macroF1': 0.8759282437864566, 'microF1': 0.9874760039385417, 'weightF1': 0.9884516903743165, 'subsetAcc': 0.9874760039385417, 'IoU': 0.7989503191006054}


  invalid_mask |= image > 1
epoch 17/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.102] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.8829037482498396, 'microPrec': 0.9885144755576649, 'weightPrec': 0.9878661630627272, 'macroRec': 0.8319234344639976, 'microRec': 0.9885144755576649, 'weightRec': 0.9885144755576649, 'macroF1': 0.8555017513719353, 'microF1': 0.9885144755576649, 'weightF1': 0.9881049560424653, 'subsetAcc': 0.9885144755576649, 'IoU': 0.7735138232668023}


  invalid_mask |= image > 1
epoch 18/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.102] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.8480879087417952, 'microPrec': 0.9882382109386622, 'weightPrec': 0.989171065366306, 'macroRec': 0.8973840143631757, 'microRec': 0.9882382109386622, 'weightRec': 0.9882382109386622, 'macroF1': 0.8710105243768678, 'microF1': 0.9882382109386622, 'weightF1': 0.9886270564441362, 'subsetAcc': 0.9882382109386622, 'IoU': 0.792763752411922}


  invalid_mask |= image > 1
epoch 19/30: 100%|██████████| 120/120 [02:39<00:00,  1.33s/it, loss=0.0428]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8112746897738243, 'microPrec': 0.9861471002840567, 'weightPrec': 0.9898750451113504, 'macroRec': 0.9494395973102969, 'microRec': 0.9861471002840567, 'weightRec': 0.9861471002840567, 'macroF1': 0.8669796627834931, 'microF1': 0.9861471002840567, 'weightF1': 0.98740391459943, 'subsetAcc': 0.9861471002840567, 'IoU': 0.7872626844428303}


  invalid_mask |= image > 1
epoch 20/30: 100%|██████████| 120/120 [02:39<00:00,  1.33s/it, loss=0.0985]


{'macroPrec': 0.7812816967179355, 'microPrec': 0.9814361002472006, 'weightPrec': 0.9885163295453285, 'macroRec': 0.9668835210093809, 'microRec': 0.9814361002472006, 'weightRec': 0.9814361002472006, 'macroF1': 0.8492302269283183, 'microF1': 0.9814361002472006, 'weightF1': 0.9837337759677555, 'subsetAcc': 0.9814361002472006, 'IoU': 0.7645263751478732}


  invalid_mask |= image > 1


########### Validation Set Evaluation : #############
{'macroPrec': 0.7623249269986121, 'microPrec': 0.9802704559782955, 'weightPrec': 0.9887168055340849, 'macroRec': 0.9687169684182605, 'microRec': 0.9802704559782955, 'weightRec': 0.9802704559782955, 'macroF1': 0.8341477331333056, 'microF1': 0.9802704559782955, 'weightF1': 0.9830480830119825, 'subsetAcc': 0.9802704559782955, 'IoU': 0.7466256100037287}


  invalid_mask |= image > 1
epoch 21/30: 100%|██████████| 120/120 [02:37<00:00,  1.32s/it, loss=0.131] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.7833013852744681, 'microPrec': 0.9827412533913253, 'weightPrec': 0.9871310698461134, 'macroRec': 0.9130040125919762, 'microRec': 0.9827412533913253, 'weightRec': 0.9827412533913253, 'macroF1': 0.8352258887161801, 'microF1': 0.9827412533913253, 'weightF1': 0.984347041308781, 'subsetAcc': 0.9827412533913253, 'IoU': 0.7483955116000515}


  invalid_mask |= image > 1
epoch 22/30: 100%|██████████| 120/120 [02:41<00:00,  1.34s/it, loss=0.111] 


########### Validation Set Evaluation : #############
{'macroPrec': 0.7954423243423113, 'microPrec': 0.9845688501016512, 'weightPrec': 0.9896370993885479, 'macroRec': 0.9589776935243659, 'microRec': 0.9845688501016512, 'weightRec': 0.9845688501016512, 'macroF1': 0.8582553269822328, 'microF1': 0.9845688501016512, 'weightF1': 0.9862389592819394, 'subsetAcc': 0.9845688501016512, 'IoU': 0.7760992880625648}


  invalid_mask |= image > 1
epoch 23/30: 100%|██████████| 120/120 [02:39<00:00,  1.33s/it, loss=0.0525]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8163941547721898, 'microPrec': 0.986543787942112, 'weightPrec': 0.9897964600980873, 'macroRec': 0.9427352267528569, 'microRec': 0.986543787942112, 'weightRec': 0.986543787942112, 'macroF1': 0.8683671396571463, 'microF1': 0.986543787942112, 'weightF1': 0.9876626577556429, 'subsetAcc': 0.986543787942112, 'IoU': 0.7890905325510985}


  invalid_mask |= image > 1
epoch 24/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.0798]


########### Validation Set Evaluation : #############
{'macroPrec': 0.7696632727579982, 'microPrec': 0.9813273452386856, 'weightPrec': 0.9887925685265664, 'macroRec': 0.9637187345721362, 'microRec': 0.9813273452386856, 'weightRec': 0.9813273452386856, 'macroF1': 0.83909336961762, 'microF1': 0.9813273452386856, 'weightF1': 0.9837921194561936, 'subsetAcc': 0.9813273452386856, 'IoU': 0.7525528679962301}


  invalid_mask |= image > 1
epoch 25/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.0734]


########### Validation Set Evaluation : #############
{'macroPrec': 0.841994444541895, 'microPrec': 0.9887709057937649, 'weightPrec': 0.9907879457327673, 'macroRec': 0.9409451884987441, 'microRec': 0.9887709057937649, 'weightRec': 0.9887709057937649, 'macroF1': 0.8848353671299538, 'microF1': 0.9887709057937649, 'weightF1': 0.9894801867059186, 'subsetAcc': 0.9887709057937649, 'IoU': 0.8108871328021097}


  invalid_mask |= image > 1
epoch 26/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=1.61]  


########### Validation Set Evaluation : #############
{'macroPrec': 0.8685085457079849, 'microPrec': 0.9898192946043395, 'weightPrec': 0.9903941252439711, 'macroRec': 0.9057676986880282, 'microRec': 0.9898192946043395, 'weightRec': 0.9898192946043395, 'macroF1': 0.8861912449202807, 'microF1': 0.9898192946043395, 'weightF1': 0.990064234889695, 'subsetAcc': 0.9898192946043395, 'IoU': 0.8128749370382383}


  invalid_mask |= image > 1
epoch 27/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.0824]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8233153000386644, 'microPrec': 0.9867591326707705, 'weightPrec': 0.9890569850326878, 'macroRec': 0.918972737826548, 'microRec': 0.9867591326707705, 'weightRec': 0.9867591326707705, 'macroF1': 0.8645915368593295, 'microF1': 0.9867591326707705, 'weightF1': 0.9876119134996764, 'subsetAcc': 0.9867591326707705, 'IoU': 0.7843779247115112}


  invalid_mask |= image > 1
epoch 28/30: 100%|██████████| 120/120 [02:37<00:00,  1.31s/it, loss=0.0494]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8244799430946932, 'microPrec': 0.9871402361708307, 'weightPrec': 0.9897592365686588, 'macroRec': 0.9340958333864748, 'microRec': 0.9871402361708307, 'weightRec': 0.9871402361708307, 'macroF1': 0.8708655165189632, 'microF1': 0.9871402361708307, 'weightF1': 0.988068832624309, 'subsetAcc': 0.9871402361708307, 'IoU': 0.7923770949963399}


  invalid_mask |= image > 1
epoch 29/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.00737]


########### Validation Set Evaluation : #############
{'macroPrec': 0.8027491491842487, 'microPrec': 0.9852219679958064, 'weightPrec': 0.989389201946361, 'macroRec': 0.9463885058757027, 'microRec': 0.9852219679958064, 'weightRec': 0.9852219679958064, 'macroF1': 0.8598674673522029, 'microF1': 0.9852219679958064, 'weightF1': 0.9866372032020722, 'subsetAcc': 0.9852219679958064, 'IoU': 0.7782119172697235}


  invalid_mask |= image > 1
epoch 30/30: 100%|██████████| 120/120 [02:38<00:00,  1.32s/it, loss=0.162] 


{'macroPrec': 0.7949022986753064, 'microPrec': 0.9832922954996636, 'weightPrec': 0.9892278653011504, 'macroRec': 0.9691895384469061, 'microRec': 0.9832922954996636, 'weightRec': 0.9832922954996636, 'macroF1': 0.8606463963789573, 'microF1': 0.9832922954996636, 'weightF1': 0.9851944216858431, 'subsetAcc': 0.9832922954996636, 'IoU': 0.7788010027953466}


  invalid_mask |= image > 1


########### Validation Set Evaluation : #############
{'macroPrec': 0.8364416811801099, 'microPrec': 0.9881404557657842, 'weightPrec': 0.9902034364782886, 'macroRec': 0.9327927532217986, 'microRec': 0.9881404557657842, 'weightRec': 0.9881404557657842, 'macroF1': 0.8782114963513497, 'microF1': 0.9881404557657842, 'weightF1': 0.9888828800754125, 'subsetAcc': 0.9881404557657842, 'IoU': 0.8020389634856595}


In [54]:


# Save everything in a checkpoint
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'epoch': 10  # Optional: Save the epoch number
}

torch.save(checkpoint, 'model_checkpoint_30_epochs.pth')


In [55]:
marida_test_df = create_marida_df(MARIDA_path, 'test')
empty_df =  pd.DataFrame(columns=marida_test_df.columns)
marida_test_ds = MergedSegmentationDataset(marida_test_df, empty_df, global_bands_mean, global_bands_std, transform=transformTest, standardization= standardization )

marida_testLoader = DataLoader(marida_test_ds, 
                        batch_size=batch_size, 
                        shuffle=False,
                        collate_fn=custom_collate_fn,
                        #worker_init_fn=worker_init_fn,
                        #generator=torch.Generator().manual_seed(seed) 
                        )

test_metrics_history = []
model.eval()
yTrue = []
yPredicted = []
testLossF = []
with torch.no_grad():
    for image, target, _ in marida_testLoader:

        image, target = image.to(device), target.to(device)
        logits = model(image)
        # print(f'image dtype {image.dtype}')
        # print(f'logits dtype {logits.dtype}')
        # print(f'target dtype {target.dtype}')
        # print(f'test - target shape {target.shape}')
        #print(f'test - logit shape {logits.shape}')
        loss = criterion(logits, target)

        logits = torch.movedim(logits, (0,1,2,3), (0,3,1,2))
        logits = logits.reshape((-1,output_classes))
        target = target.reshape(-1)
        ###################################################################################
        mask = target != -1
        ###################################################################################
        
        # bg_logits = logits[~mask]
        # bg_target = target[~mask]
        
        logits = logits[mask]
        target = target[mask]
        

        probs = F.softmax(logits, dim=1).cpu().numpy()
        print(f'test - probs shape {probs.shape}')
        target = target.cpu().numpy()
        # testBatches += target.shape[0]
        testLossF.append((loss.data*target.shape[0]).tolist())
        yPredicted += probs.argmax(1).tolist()
        yTrue += target.tolist()


        # bg_probs = torch.nn.functional.softmax(bg_logits, dim=1).cpu().numpy()
        # bg_target = bg_target.cpu().numpy()

        # bg_yPredicted += bg_probs.argmax(1).tolist()
        # bg_yTrue += bg_target.tolist()
    
    yPredicted = np.asarray(yPredicted)
    yTrue = np.asarray(yTrue)
    acc = Evaluation(yPredicted, yTrue)
    test_metrics_history.append(acc)

    # bg_yPredicted = np.asarray(bg_yPredicted)
    # bg_yTrue = np.asarray(bg_yTrue)
    # bg_acc = Evaluation(bg_yPredicted, bg_yTrue)
    print(acc)
                    

test - probs shape (7177, 2)
test - probs shape (17871, 2)
test - probs shape (2032, 2)
test - probs shape (4769, 2)
test - probs shape (3834, 2)
test - probs shape (3491, 2)
test - probs shape (18137, 2)


  invalid_mask |= image > 1


test - probs shape (21602, 2)
test - probs shape (9102, 2)
test - probs shape (7595, 2)
test - probs shape (5344, 2)
test - probs shape (14574, 2)
test - probs shape (1183, 2)
test - probs shape (19516, 2)
test - probs shape (6836, 2)
test - probs shape (3626, 2)
test - probs shape (1100, 2)
test - probs shape (11923, 2)
test - probs shape (23276, 2)
test - probs shape (5089, 2)
test - probs shape (1839, 2)
test - probs shape (4423, 2)
test - probs shape (524, 2)
{'macroPrec': 0.7237010091998346, 'microPrec': 0.9850459040454063, 'weightPrec': 0.9908632264056609, 'macroRec': 0.907481110001789, 'microRec': 0.9850459040454063, 'weightRec': 0.9850459040454063, 'macroF1': 0.7875735559452557, 'microF1': 0.9850459040454063, 'weightF1': 0.9872194988371132, 'subsetAcc': 0.9850459040454063, 'IoU': 0.6980411337174643}


In [None]:
# All black
# /kaggle/input/litter-windrows-patches/patches/S2A_MSIL1C_20180916T101021_R022_T33TUL/S2A_MSIL1C_20180916T101021_R022_T33TUL_366560_5053920.tif

In [None]:
#Lightning implementation. To be used later.

# class BinaryClassificationModel(pl.LightningModule):
#     def __init__(self, hparams):
#         super().__init__()
#         self.save_hyperparameters(hparams)

#         # Model selection
#         if hparams.model_name == "resattunet":
#             self.model = ResidualAttentionUNet(11, 11)
#             # Modify for binary classification
#             self.model.decoder = nn.Sequential(
#                 self.model.decoder,
#                 nn.AdaptiveAvgPool2d(1),
#                 nn.Flatten(),
#                 nn.Linear(11, 2)  # Binary output
#             )
#         elif hparams.model_name == "attunet":
#             self.model = AttentionUNet(11, 11)
#             self.model.decoder = nn.Sequential(
#                 self.model.decoder,
#                 nn.AdaptiveAvgPool2d(1),
#                 nn.Flatten(),
#                 nn.Linear(11, 2)
#             )
#         elif hparams.model_name == "unet":
#             self.model = UNet(11, 11)
#             self.model.decoder = nn.Sequential(
#                 self.model.decoder,
#                 nn.AdaptiveAvgPool2d(1),
#                 nn.Flatten(),
#                 nn.Linear(11, 2)
#             )
#         else:
#             raise ValueError("Invalid model name")

#         # Loss function
#         if hparams.focal_loss:
#             self.criterion = FocalLoss()
#         else:
#             weight = gen_weights(class_distr, c=1.03)[:2]  # Binary classes
#             self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=-1)

#         # Track best metrics
#         self.best_macro_f1 = 0.0
#         self.best_micro_f1 = 0.0
#         self.best_weight_f1 = 0.0

#     def forward(self, x):
#         return self.model(x)

#     def training_step(self, batch, batch_idx):
#         images, labels, _ = batch
#         logits = self(images)
#         loss = self.criterion(logits, labels)
#         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
#         return loss

#     def validation_step(self, batch, batch_idx):
#         images, labels, _ = batch
#         logits = self(images)
#         loss = self.criterion(logits, labels)
#         probs = torch.softmax(logits, dim=1).cpu().numpy()
#         labels = labels.cpu().numpy()
#         preds = probs.argmax(1)
#         return {"loss": loss, "preds": preds.tolist(), "labels": labels.tolist()}

#     def validation_epoch_end(self, outputs):
#         preds = np.concatenate([o["preds"] for o in outputs])
#         labels = np.concatenate([o["labels"] for o in outputs])
#         loss = torch.stack([o["loss"] for o in outputs]).mean()
#         acc = Evaluation(preds, labels)

#         self.log("val_loss", loss, prog_bar=True)
#         self.log("val_macro_precision", acc["macroPrec"], prog_bar=True)
#         self.log("val_macro_recall", acc["macroRec"])
#         self.log("val_macro_f1", acc["macroF1"])
#         self.log("val_micro_precision", acc["microPrec"])
#         self.log("val_micro_recall", acc["microRec"])
#         self.log("val_micro_f1", acc["microF1"])
#         self.log("val_weight_precision", acc["weightPrec"])
#         self.log("val_weight_recall", acc["weightRec"])
#         self.log("val_weight_f1", acc["weightF1"])
#         self.log("val_iou", acc["IoU"])

#         # Update best metrics
#         if acc["macroF1"] > self.best_macro_f1:
#             self.best_macro_f1 = acc["macroF1"]
#         if acc["microF1"] > self.best_micro_f1:
#             self.best_micro_f1 = acc["microF1"]
#         if acc["weightF1"] > self.best_weight_f1:
#             self.best_weight_f1 = acc["weightF1"]

#     def configure_optimizers(self):
#         optimizer = optim.Adam(
#             self.parameters(),
#             lr=self.hparams.initial_lr,
#             weight_decay=self.hparams.decay_lr
#         )
#         if self.hparams.scheduler_lr == "rop":
#             scheduler = optim.lr_scheduler.ReduceLROnPlateau(
#                 optimizer, mode="min", factor=0.1, patience=10, verbose=True
#             )
#             return {
#                 "optimizer": optimizer,
#                 "lr_scheduler": scheduler,
#                 "monitor": "val_loss"
#             }
#         else:
#             scheduler = optim.lr_scheduler.MultiStepLR(
#                 optimizer, milestones=[40, 80, 120, 160], gamma=0.5, verbose=True
#             )
#             return {"optimizer": optimizer, "lr_scheduler": scheduler}

#     def train_dataloader(self):
#         transform = transforms.Compose([
#             transforms.ToTensor(),
#             RandomRotationTransform([-90, 0, 90, 180]),
#             transforms.RandomHorizontalFlip(),
#             transforms.Normalize(bands_mean, bands_std)
#         ])
#         dataset = MergedSegmentationDataset(
#             dataset1_paths=("path/to/dataset1/images", "path/to/dataset1/masks"),
#             dataset2_paths=("path/to/dataset2/images", "path/to/dataset2/masks"),
#             transform=transform
#         )
#         return DataLoader(
#             dataset,
#             batch_size=self.hparams.train_batch_size,
#             shuffle=True,
#             num_workers=4,
#             worker_init_fn=seed_worker,
#             generator=torch.Generator().manual_seed(0)
#         )

#     def val_dataloader(self):
#         transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize(bands_mean, bands_std)
#         ])
#         dataset = MergedSegmentationDataset(
#             dataset1_paths=("path/to/dataset1/images", "path/to/dataset1/masks"),
#             dataset2_paths=("path/to/dataset2/images", "path/to/dataset2/masks"),
#             transform=transform
#         )
#         return DataLoader(
#             dataset,
#             batch_size=self.hparams.test_batch_size,
#             shuffle=False,
#             num_workers=4,
#             worker_init_fn=seed_worker,
#             generator=torch.Generator().manual_seed(0)
#         )

# def seed_worker(worker_id):
#     worker_seed = torch.initial_seed() % 2**32
#     np.random.seed(worker_seed)
#     random.seed(worker_seed)

# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--train_batch_size', type=int, default=8)
#     parser.add_argument('--test_batch_size', type=int, default=4)
#     parser.add_argument('--total_epochs', type=int, default=50)
#     parser.add_argument('--experiment_name', type=str, required=True)
#     parser.add_argument('--initial_lr', type=float, default=1e-3)
#     parser.add_argument('--decay_lr', type=float, default=0)
#     parser.add_argument('--scheduler_lr', type=str, default="ms")
#     parser.add_argument('--focal_loss', type=bool, default=False)
#     parser.add_argument('--model_name', type=str, default="resattunet")
#     args = parser.parse_args()

#     # Set seeds for reproducibility
#     pl.seed_everything(0, workers=True)

#     # Initialize model
#     model = BinaryClassificationModel(args)

#     # Logger
#     logger = TensorBoardLogger(save_dir=args.experiment_name, name="logs")

#     # Callbacks for saving best models
#     checkpoint_macro = ModelCheckpoint(
#         dirpath=args.experiment_name,
#         filename="bestMacroF1Model",
#         monitor="val_macro_f1",
#         mode="max",
#         save_top_k=1
#     )
#     checkpoint_micro = ModelCheckpoint(
#         dirpath=args.experiment_name,
#         filename="bestMicroF1Model",
#         monitor="val_micro_f1",
#         mode="max",
#         save_top_k=1
#     )
#     checkpoint_weight = ModelCheckpoint(
#         dirpath=args.experiment_name,
#         filename="bestWeightF1Model",
#         monitor="val_weight_f1",
#         mode="max",
#         save_top_k=1
#     )

#     # Trainer
#     trainer = pl.Trainer(
#         max_epochs=args.total_epochs,
#         accelerator="gpu" if torch.cuda.is_available() else "cpu",
#         devices=1,
#         logger=logger,
#         callbacks=[checkpoint_macro, checkpoint_micro, checkpoint_weight],
#         deterministic=True
#     )

#     # Train
#     trainer.fit(model)

# # if __name__ == "__main__":
# #     main()