In [None]:
%%capture
# Install the third-party packages this notebook depends on (versions are not pinned,
# so results may differ between runs). %%capture hides the pip output.
! pip install rasterio
! pip install pyresample
! pip install netCDF4
! pip install dvc dagshub

In [None]:
import os
import sys
import rasterio
from rasterio.windows import Window
from rasterio.windows import from_bounds
from rasterio.enums import Resampling
from rasterio.transform import Affine
import shutil
import numpy as np
import tempfile
from pyproj import Transformer
import gdown
import xarray as xr
import dask
import dask.dataframe as dd
import glob
from tqdm import tqdm
import cv2
import dagshub
from dagshub.upload import Repo
import matplotlib.pyplot as plt
import pandas as pd
from glob import glob
import requests



In [None]:
# Clone the ACOLITE repository; it is imported from source in the next cell.
# Re-running this cell makes git report that ./acolite already exists.
! git clone https://github.com/acolite/acolite.git

In [None]:
# Make the cloned ACOLITE checkout importable from source.
ACOLITE_PATH = "./acolite"
sys.path.append(ACOLITE_PATH)
# Import the ACOLITE processing entry point (called once per tile region in process_tile).
from acolite.acolite.acolite_run import acolite_run

In [None]:
# Downloading the Littern Windrows Catalog Annotaiton from Zenodo 
record_id = "11045944"
netcdf_file_name = "WASP_LW_SENT2_MED_L1C_B_201506_202109_10m_6y_NRT_v1.0.nc" 
zenodo_url = f"https://zenodo.org/api/records/{record_id}"

# Get the actual download URL

r = requests.get(zenodo_url).json()
download_url = None

for file in r['files']:
    print(file)
    if file['key'] == netcdf_file_name:
        download_url = file['links']['self']
        break

if download_url:
    !wget -O /kaggle/working/annotations.nc {download_url}
else:
    print("File not found in Zenodo record")

In [None]:
# Authenticate with DagsHub using tokens stored as Kaggle secrets.
from dagshub.auth import add_app_token
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# NOTE(review): dagshub_username is read but not used in the cells shown here.
dagshub_username = user_secrets.get_secret("DAGSHUB_USERNAME")
dagshub_token = user_secrets.get_secret("DAGSHUB_TOKEN")
add_app_token(dagshub_token)

In [None]:
# List the current working directory (debugging aid).
! ls

In [None]:
! ls /kaggle/input/litter-windrows-batch-cala
input_subdirs = [os.path.join('/kaggle/input', f) for f in os.listdir('/kaggle/input/') if 'windrows' in f]
if len(input_subdirs) != 0:
    safe_files_dir = input_subdirs[0]
else:
    print('No Litter Windrows Batch Dataset Found in /kaggle/input dir. Add a Litter Windrows Batch Dataset to the notebook !')
#safe_files_dir ='/kaggle/input/litter-windrows-batch-cala'

In [None]:
def get_utm_limit(utm_x, utm_y,utm_crs, box_meters, pixel_size=10):
    """Compute the WGS84 limit enclosing a square UTM box centred on (utm_x, utm_y).

    The limit is returned in the ``[lat_min, lon_min, lat_max, lon_max]`` order
    that process_tile passes to ACOLITE as its ``limit`` setting.

    Parameters
    ----------
    utm_x, utm_y : float
        Box centre in ``utm_crs`` coordinates (metres).
    utm_crs : CRS-like
        Source CRS accepted by ``pyproj.Transformer.from_crs``.
    box_meters : float
        Side length of the square box, in metres.
    pixel_size : float, optional
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    limit : list of float
        ``[lat_min, lon_min, lat_max, lon_max]`` enclosing all four box corners.
    utm_box : tuple of float
        ``(left, bottom, right, top)`` of the box in ``utm_crs``.
    """
    wgs84_crs = "EPSG:4326"
    utm_to_wgs = Transformer.from_crs(utm_crs, wgs84_crs, always_xy=True)

    half_box = box_meters / 2
    box_left = utm_x - half_box
    box_right = utm_x + half_box
    box_bottom = utm_y - half_box
    box_top = utm_y + half_box

    # Transform all four corners. UTM grid north is rotated relative to true north
    # (meridian convergence), so the bottom-left/top-right corners alone can give a
    # lat/lon box that clips the other two corners of the UTM box.
    corners_x = [box_left, box_right, box_right, box_left]
    corners_y = [box_bottom, box_bottom, box_top, box_top]
    lons, lats = utm_to_wgs.transform(corners_x, corners_y)
    limit = [float(min(lats)), float(min(lons)), float(max(lats)), float(max(lons))]

    print(f"UTM box: left={box_left:.1f}, bottom={box_bottom:.1f}, right={box_right:.1f}, top={box_top:.1f}")
    print(f"WGS84 limit: {limit}")
    return limit, (box_left, box_bottom, box_right, box_top)

In [None]:
def get_dataset_tiles(safe_files_dir):
    """Return the SAFE product names found in ``safe_files_dir``.

    Every directory entry is taken as a product folder name and gets the
    ``.SAFE`` suffix appended, in ``os.listdir`` order.
    """
    return [f"{entry}.SAFE" for entry in os.listdir(safe_files_dir)]

In [None]:
# List the tiles (SAFE product names) available in the attached dataset.
tiles = get_dataset_tiles(safe_files_dir)
tiles

In [None]:
def extract_crs_and_bounds(safe_file):
    """Read the georeferencing of a SAFE product from its B02 band.

    Uses the first ``*B02.jp2`` file in ``GRANULE/<first granule>/IMG_DATA/``
    and prints its path, CRS, bounds and count of non-zero pixels.

    Parameters
    ----------
    safe_file : str
        Path of the ``.SAFE`` product directory.

    Returns
    -------
    tuple
        ``(crs, bounds, transform, res)`` of the B02 band.

    Raises
    ------
    ValueError
        If the IMG_DATA folder contains no B02 file.
    """
    granule_path = os.path.join(safe_file, 'GRANULE/')
    granule_subdirs = os.listdir(granule_path)
    img_data_path = os.path.join(granule_path, granule_subdirs[0], 'IMG_DATA/')
    B02_files = [f for f in os.listdir(img_data_path) if f.endswith('B02.jp2')]
    if not B02_files:
        raise ValueError(f'No B02 band jp2 file found in {img_data_path}')
    band_path = os.path.join(img_data_path, B02_files[0])
    print(band_path)
    with rasterio.open(band_path) as src:
        # Capture the metadata while the dataset is open rather than reading
        # attributes from the closed dataset after the `with` block.
        crs, bounds, transform, res = src.crs, src.bounds, src.transform, src.res
        print(f"CRS: {crs}")
        print(f"Bounds: {bounds}")
        data = src.read(1)
        print(f"B02 Valid Pixels: {np.sum(data > 0)} / {data.size}")
    return crs, bounds, transform, res

In [None]:
def get_all_tiles(files, batch_size=10):
    """Build the sorted list of unique S2 products referenced by the annotation parquet files.

    Parameters
    ----------
    files : list of str
        Parquet files of the annotation catalogue.
    batch_size : int, optional
        Number of files read per Dask batch (default 10).

    Returns
    -------
    list of str
        Unique ``s2_product`` values (bytes values are decoded as UTF-8);
        empty if ``files`` is empty.
    """
    batch_uniques = []
    for i in range(0, len(files), batch_size):
        batch_files = files[i:i + batch_size]
        # Only the product column is needed, so avoid loading the rest of the table.
        batch_df = dd.read_parquet(batch_files, engine='pyarrow', columns=['s2_product']).compute()
        batch_uniques.append(batch_df['s2_product'].unique())
        print(f"Batch {i//batch_size + 1}: {len(batch_df)} rows (from {len(batch_files)} files)")

    if not batch_uniques:
        # np.hstack raises on an empty list.
        return []
    tiles = np.unique(np.hstack(batch_uniques))
    # Product names are expected as bytes; tolerate values that are already str.
    return [t.decode('utf-8') if isinstance(t, bytes) else str(t) for t in tiles]

In [None]:
def find_jp2_ref_file(safe_file):
    """Return the path of the B02 JPEG2000 band inside a SAFE product.

    Looks in ``GRANULE/<first granule listed>/IMG_DATA/`` and returns the
    first file whose name ends with ``B02.jp2``.

    Raises
    ------
    ValueError
        If that IMG_DATA folder holds no B02 file.
    """
    granule_root = os.path.join(safe_file, 'GRANULE/')
    first_granule = os.listdir(granule_root)[0]
    img_dir = os.path.join(granule_root, first_granule, 'IMG_DATA/')
    match = next((name for name in os.listdir(img_dir) if name.endswith('B02.jp2')), None)
    if match is None:
        raise ValueError(f'No B2 band jpg2 file retrieve in the safe file {safe_file}')
    return os.path.join(img_dir, match)
        
def check_for_invalid_data(file_path, bbox):
    """Report whether a georeferenced box of band 1 contains invalid pixels.

    Parameters
    ----------
    file_path : str
        Raster readable by rasterio.
    bbox : tuple
        ``(left, bottom, right, top)`` in the raster's CRS, as expected by
        the dataset's ``window`` method.

    Returns
    -------
    numpy.bool_
        True if any pixel in the box is 0, or 65535 for uint16 data.
    """
    with rasterio.open(file_path) as src:
        values = src.read(1, window=src.window(*bbox))

    # 0 is treated as no-data (Sentinel-2 fill value); uint16 rasters also flag 65535.
    bad = values == 0
    if values.dtype == np.uint16:
        bad |= values == 65535

    found = bad.any()
    if not found:
        print("No invalid pixels found in the bounding box")
    else:
        print(f"Found {bad.sum()} invalid pixels in the bounding box")
    return found

def generate_tile_regions(safe_file_path, region_size=3000, stride=2700, n_steps=4):
    """Split a Sentinel-2 tile into overlapping pixel regions free of invalid data.

    Regions are laid out on an ``n_steps x n_steps`` grid with the given stride
    (defaults: 3000 px regions every 2700 px, i.e. 300 px overlap). They are clipped
    to the size of the tile's B02 band. Regions containing any invalid pixel (see
    check_for_invalid_data) are dropped.

    Parameters
    ----------
    safe_file_path : str
        Path of the ``.SAFE`` product directory.
    region_size, stride, n_steps : int, optional
        Region side, grid step and number of steps per axis, in pixels.

    Returns
    -------
    dict[int, tuple]
        Consecutive ids -> ``(xmin, ymin, xmax, ymax)`` in pixel coordinates
        of the B02 band (x = column, y = row).
    """
    ref_file = find_jp2_ref_file(safe_file_path)
    with rasterio.open(ref_file) as src:
        w, h = src.width, src.height
        ref_transform = src.transform

    regions = []
    for i in range(n_steps):
        for j in range(n_steps):
            regions.append((i * stride, j * stride,
                            min(w, i * stride + region_size), min(h, j * stride + region_size)))

    filtered_regions = []
    for r in regions:
        # check_for_invalid_data expects georeferenced (left, bottom, right, top)
        # bounds, so convert the pixel region through the band's affine transform;
        # passing raw pixel indices would select a window far outside the raster.
        pixel_window = Window(col_off=r[0], row_off=r[1], width=r[2] - r[0], height=r[3] - r[1])
        geo_bbox = rasterio.windows.bounds(pixel_window, ref_transform)
        if not check_for_invalid_data(ref_file, geo_bbox):
            filtered_regions.append(r)
    return {i: r for i, r in enumerate(filtered_regions)}

def assign_patches(tile_regions, patches):
    """Group patches by the first tile region that fully contains them.

    Parameters
    ----------
    tile_regions : dict
        Region id -> ``(xmin, ymin, xmax, ymax)``, e.g. from generate_tile_regions.
    patches : iterable
        Patches whose first four items are ``(xmin, ymin, xmax, ymax)``.

    Returns
    -------
    dict
        Region id -> list of patches, in patch order. Patches outside every
        region are dropped; regions without patches are absent.
    """
    def _contains(region, patch):
        # The patch box must lie entirely inside the region box (edges inclusive).
        return (patch[0] >= region[0] and patch[1] >= region[1]
                and patch[2] <= region[2] and patch[3] <= region[3])

    grouped = {}
    for patch in patches:
        owners = (rid for rid, region in tile_regions.items() if _contains(region, patch))
        for rid in owners:
            grouped.setdefault(rid, []).append(patch)
            break  # the first containing region wins
    return grouped

In [None]:
def annotations_from_netcdf(target, netcdf_file_path, output_dir) :
    """Extract the annotated debris pixels of one S2 product from the LW NetCDF catalogue.

    Products are matched on fields 2, 4 and 5 of the underscore-separated name.
    For standard compact SAFE names these are the sensing datetime, relative
    orbit and tile id. Rows whose pixel_x/pixel_y are -999 or null are dropped.
    The result is also saved as ``<output_dir>/<target without .SAFE>_LWC_annotations.csv``.

    Parameters
    ----------
    target : str
        Product name, e.g. the ``.SAFE`` directory name.
    netcdf_file_path : str
        Path of the annotation NetCDF file.
    output_dir : str
        Directory for the CSV output (created if missing).

    Returns
    -------
    pandas.DataFrame
        ``pixel_x``, ``pixel_y``, ``lat_centroid``, ``lon_centroid`` of the
        matching annotations (plus any dimension columns added by xarray).
    """
    # Create output directory if it does not exist
    os.makedirs(output_dir, exist_ok=True)
    # Step 1: Open NetCDF file with optimized chunks
    try:
        # Chunk by filaments for better memory management
        ds = xr.open_dataset(netcdf_file_path, engine='netcdf4', chunks={'n_filaments': 1000})
    except Exception as e:
        print(f"Error opening NetCDF file: {e}")
        raise

    try:
        # Step 2: Build the target key from the selected name fields
        selected_indices = [2, 4, 5]
        parts = target.split('_')
        proc_target = '_'.join(parts[i] for i in selected_indices)
        print(f'Target pattern: {proc_target}')

        def is_target_product(x):
            """True if product name ``x`` has the same key fields as the target."""
            try:
                fields = str(x).strip().split('_')
                return '_'.join([fields[i] for i in selected_indices]) == proc_target
            except IndexError:
                # Names with too few fields cannot match the target.
                return False

        # Step 3: Decode product names (bytes or str) to stripped str
        s2_strings = xr.apply_ufunc(
            lambda x: str(x.decode('utf-8').strip()) if isinstance(x, bytes) else str(x).strip(),
            ds['s2_product'],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[object]
        )

        # Boolean mask of the annotations that belong to the target product
        mask = xr.apply_ufunc(
            is_target_product,
            s2_strings,
            vectorize=True,
            dask='parallelized',
            output_dtypes=[bool]
        )

        # Step 4: Apply filtering (the mask is computed before drop=True)
        ds_filtered = ds.where(mask.compute(), drop=True)

        # Step 5: Convert to Dask DataFrame
        df = ds_filtered[['pixel_x', 'pixel_y', 'lat_centroid', 'lon_centroid']].to_dask_dataframe()
        # Step 6: Drop rows without usable pixel coordinates
        df = df[
            (df['pixel_x'] != -999) &
            (df['pixel_y'] != -999) &
            df['pixel_x'].notnull() &
            df['pixel_y'].notnull()
        ]

        # Step 7: Materialise as pandas (same rows as concatenating every partition)
        filtered_pandas = df.compute()
    finally:
        # Release the NetCDF file even if filtering fails.
        ds.close()

    print(f"Final filtered DataFrame shape: {filtered_pandas.shape}")
    print(filtered_pandas.head())

    # Step 8: Save results
    output_fname = target[:-len('.SAFE')] if target.endswith('.SAFE') else target
    output_csv = f"{output_dir}/{output_fname}_LWC_annotations.csv"
    filtered_pandas.to_csv(output_csv, index=False)
    print(f"Saved filtered DataFrame: {output_csv}")
    return filtered_pandas

In [None]:

def build_tile_df(target, files, batch_size=10):
    """Collect the parquet annotations belonging to one target product.

    A row matches when its UTF-8-decoded ``s2_product`` contains both field 2
    of ``target`` and fields 4-5 joined by '_' (literal substring match).

    Parameters
    ----------
    target : str
        Product name (e.g. the SAFE directory name).
    files : list of str
        Annotation parquet files.
    batch_size : int, optional
        Number of files read per Dask batch (default 10).

    Returns
    -------
    pandas.DataFrame or None
        Matching rows with a fresh RangeIndex, or None when nothing matches.
    """
    tg_id_1 = target.split('_')[2]
    tg_id_2 = '_'.join(target.split('_')[4:6])
    matches = []
    for i in range(0, len(files), batch_size):
        batch_files = files[i:i + batch_size]
        batch_df = dd.read_parquet(batch_files, engine='pyarrow').compute()
        batch_df["s2_product"] = batch_df["s2_product"].str.decode('utf-8')
        product = batch_df["s2_product"]
        # regex=False: the ids are literal substrings, not patterns.
        batch_df = batch_df[product.str.contains(tg_id_1, regex=False) & product.str.contains(tg_id_2, regex=False)]
        if not batch_df.empty:
            matches.append(batch_df)
        print(f"Batch {i//batch_size + 1}: {len(batch_df)} rows (from {len(batch_files)} files)")
    if not matches:
        return None
    # Concatenate once instead of re-copying the growing frame for every batch.
    return pd.concat(matches, ignore_index=True)

In [None]:
def read_tile_df(target, tile_parquet_dir):
    """Lazily open the per-tile parquet ``LWR_<name>.parquet`` as a Dask DataFrame.

    ``target`` is expected to end with the 5-character ``.SAFE`` suffix, which
    is dropped to build the file name inside ``tile_parquet_dir``.
    """
    stem = target[:-5]
    return dd.read_parquet(os.path.join(tile_parquet_dir, f'LWR_{stem}.parquet'), engine='pyarrow')

In [None]:

def build_tiles_parquet_annotations(tiles, files, batch_size=10, output_dir='.'):
    """Write one parquet file of valid annotations per target product.

    For every target in ``tiles``, the annotation parquet ``files`` are scanned in
    batches of ``batch_size``. A row is kept when its decoded ``s2_product``
    contains the target's ids (field 2, and fields 4-5 joined by '_') and its
    pixel_x/pixel_y are non-null and not -999. Kept rows are saved to
    ``<output_dir>/LWR_<target without its 5-char suffix>.parquet``.
    Targets with no matching rows are reported and skipped.
    """
    for target in tiles:
        tg_id_1 = target.split('_')[2]
        tg_id_2 = '_'.join(target.split('_')[4:6])
        matches = []
        for i in range(0, len(files), batch_size):
            batch_files = files[i:i + batch_size]
            batch_df = dd.read_parquet(batch_files, engine='pyarrow').compute()
            batch_df["s2_product"] = batch_df["s2_product"].str.decode('utf-8')
            keep = (
                batch_df["s2_product"].str.contains(tg_id_1, regex=False)
                & batch_df["s2_product"].str.contains(tg_id_2, regex=False)
                & batch_df["pixel_x"].notna()
                & batch_df["pixel_y"].notna()
                & (batch_df["pixel_x"] != -999)
                & (batch_df["pixel_y"] != -999)
            )
            batch_df = batch_df[keep]
            if not batch_df.empty:
                matches.append(batch_df)
            print(f"Batch {i//batch_size + 1}: {len(batch_df)} rows (from {len(batch_files)} files)")
        if not matches:
            # Nothing to save for this target; skip instead of failing on an empty result.
            print(f'No valid annotations found for {target}; no parquet written')
            continue
        df = pd.concat(matches, ignore_index=True)
        df.to_parquet(os.path.join(output_dir, f'LWR_{target[:-5]}.parquet'))

In [None]:

def get_relative_path(target_path, start_path):
    """Express ``target_path`` relative to the directory ``start_path`` (os.path.relpath)."""
    relative = os.path.relpath(target_path, start=start_path)
    return relative


def get_all_files_recursive(directory):
    """Return absolute paths of every regular file under ``directory``, in os.walk order."""
    return [
        candidate
        for root, _subdirs, names in os.walk(directory)
        for candidate in (os.path.abspath(os.path.join(root, name)) for name in names)
        if os.path.isfile(candidate)
    ]

def upload_to_dagshub(root_folder, src_folder, dst_folder, batch_id, dagshub_token,
                      repo_owner="elena-andreini",
                      repo_name="TriesteItalyChapter_PlasticDebrisDetection"):
    """Upload every file under ``src_folder`` to a DagsHub repository folder and commit with DVC.

    Each file keeps its path relative to ``src_folder`` inside ``dst_folder``.

    Parameters
    ----------
    root_folder : str
        Unused; kept for backward compatibility.
    src_folder : str
        Local directory whose files (recursively) are uploaded.
    dst_folder : str
        Destination folder relative to the repository root, in the form
        ``'./dest_folder'``.
    batch_id : str
        Identifier included in the commit message.
    dagshub_token : str
        DagsHub app token used for authentication.
    repo_owner, repo_name : str, optional
        Target repository (defaults to the project repository).
    """
    add_app_token(dagshub_token)
    repo = Repo(repo_owner, repo_name)
    ds = repo.directory(dst_folder)
    files = get_all_files_recursive(src_folder)
    if not files:
        print(f'No files found in {src_folder}; nothing to upload')
        return
    print(f'uploading files {files}')
    rel_files_paths = [get_relative_path(f, src_folder) for f in files]
    print(f'relative files {rel_files_paths}')

    for rel_path in rel_files_paths:
        # Remote paths always use '/', whatever the local OS separator is.
        ds.add(file=os.path.join(src_folder, rel_path), path='./' + rel_path.replace(os.sep, '/'))

    # Commit the changes
    ds.commit(f'Adding batch {batch_id}  to {dst_folder} folder', versioning='dvc')
    print(f'Uploaded batch {batch_id}')

In [None]:

# ACOLITE processing settings, adapted from the MAP-Mapper paper.
# The per-region 'limit' (lat/lon box) is supplied by process_tile.
settings = {
    # Export L2R products as GeoTIFF (stack_bands reads the rhos_*.tif files)
    'l2r_export_geotiff': True,  
    # Delete the intermediate NetCDF files once the GeoTIFFs have been written
    'l1r_delete_netcdf' : True,
    'l2w_delete_netcdf' : True,
    'verbosity': 5,  # Detailed logging
    's2_target_res': 10,  # 10m resolution
    'resampling_method': 'nearest' , # Set to nearest neighbor
    'atmospheric_correction_method': 'dark_spectrum',
    'dsf_exclude_bands' : ['B9', 'B10'],
    # Resolves a NoneType error; see this thread - https://odnature.naturalsciences.be/remsem/acolite-forum/viewtopic.php?t=319
    'geometry_type' : 'grids',
    # Masking
    'l2w_mask' : False,
    # Sun glint correction
    'glint_correction' : True,
    'dsf_residual_glint_correction' : True,
    'dsf_residual_glint_correction_method' : 'alternative', # an index error occurs with the default method
    'dsf_residual_glint_wave_range' : [1500,2400],
    'glint_force_band' : None,
    'glint_mask_rhos_wave' : 1600,
    'glint_mask_rhos_threshold' : 0.11,
    'reproject' : False
}


In [None]:
def generate_tile_mask(tile_df, shape=(10980, 10980), dtype=np.uint8):
    """Rasterise annotated debris pixels into a binary mask for a whole tile.

    Based on dhia's code. Filaments are not filtered by ``box_dims`` size.

    Parameters
    ----------
    tile_df : DataFrame
        Annotations with ``pixel_x`` and ``pixel_y`` columns. Rows with NaN in
        either column are ignored. ``pixel_x`` indexes the first mask axis and
        ``pixel_y`` the second; values are truncated to int.
    shape : tuple of int, optional
        Mask shape (default 10980 x 10980).
    dtype : numpy dtype, optional
        Mask dtype. uint8 (default) keeps a full tile at ~120 MB instead of the
        ~965 MB a float64 mask needs.

    Returns
    -------
    mask : numpy.ndarray
        1 where at least one annotation falls, 0 elsewhere.
    pixels : int
        Number of rows with valid coordinates (duplicates counted).
    """
    mask = np.zeros(shape, dtype=dtype)
    coords = np.asarray(tile_df[["pixel_x", "pixel_y"]], dtype=float).reshape(-1, 2)
    valid = ~np.isnan(coords).any(axis=1)
    # astype(int) truncates toward zero, like int() on each value.
    xs = coords[valid, 0].astype(int)
    ys = coords[valid, 1].astype(int)
    mask[xs, ys] = 1
    return mask, int(valid.sum())
    
def find_subregions_efficient(binary_image, subregion_size, threshold):
    """Find grid cells of the tile annotation mask that contain enough debris pixels.

    The image is tiled into non-overlapping ``(h, w)`` cells starting at (0, 0).
    Partial cells at the right/bottom edges are ignored, and only pixels exactly
    equal to 1 are counted.

    Parameters
    ----------
    binary_image : numpy.ndarray
        2-D annotation mask.
    subregion_size : tuple of int
        ``(h, w)`` cell size in pixels.
    threshold : int
        A cell is kept when its debris-pixel count is strictly greater than this.

    Returns
    -------
    list of tuple
        ``(x, y, x + w, y + h, total)`` per kept cell in row-major order, where
        ``x`` indexes the second array axis and ``y`` the first.
    """
    h, w = subregion_size
    img_h, img_w = binary_image.shape
    print(f'input maks shape : {binary_image.shape}')
    n_rows, n_cols = img_h // h, img_w // w
    binary = binary_image[:n_rows * h, :n_cols * w] == 1
    # Exact per-cell sums via a block reshape. This replaces the integral image and
    # drops the unused full-size 'covered' array.
    cell_totals = binary.reshape(n_rows, h, n_cols, w).sum(axis=(1, 3))
    subregions = []
    for row, col in np.argwhere(cell_totals > threshold):
        y, x = int(row) * h, int(col) * w
        subregions.append((x, y, x + w, y + h, int(cell_totals[row, col])))
    return subregions

def minimal_1024_cover(small_regions, cell_size=1024):
    """Return the origins of the ``cell_size`` grid cells touched by small regions.

    Could be used to find larger regions around patches for ACOLITE. The DSF
    algorithm may not be fully reliable on regions of 2560 x 2560 m (256 px
    patches) unless they are very homogeneous.

    Parameters
    ----------
    small_regions : iterable of tuple
        ``(x, y, x2, y2, total)`` tuples as produced by find_subregions_efficient
        (x = column, y = row).
    cell_size : int, optional
        Coarse grid cell size in pixels (default 1024).

    Returns
    -------
    set of tuple
        ``(row_origin, col_origin)`` of the cell holding each region's top-left
        corner; an empty set when there are no regions.
    """
    # Snap each region's top-left corner (row = y, col = x) to the coarse grid.
    return {
        ((y // cell_size) * cell_size, (x // cell_size) * cell_size)
        for (x, y, _, _, _) in small_regions
    }


def upload_to_drive(file_path, drive_folder_id):
    """Upload ``file_path`` to a Google Drive folder, then delete the local copy.

    NOTE(review): gdown is documented as a download tool; confirm that
    ``gdown.upload`` exists in the installed version before relying on this.
    """
    gdown.upload(file_path, parent_id=drive_folder_id, quiet=False)
    print(f"Uploaded {file_path} to Drive folder {drive_folder_id}")
    # Reclaim local (Kaggle) disk space; only reached if the upload call returned.
    os.remove(file_path)


In [None]:
def validate_image(image):
    """Return False when more than 10% of the values are NaN, infinite or negative.

    The percentage is computed over every value of ``image`` (all bands) and
    printed; a skip message is printed when the image is rejected.
    """
    total_pixels = image.size
    invalid_count = (~np.isfinite(image) | (image < 0)).sum()
    invalid_percentage = (invalid_count / total_pixels) * 100
    print(f"Percentage of invalid values (NaN, Inf, negative) across all bands: {invalid_percentage:.2f}%")
    acceptable = not (invalid_percentage > 10)
    if not acceptable:
        print('skipping region, too many invalid pixels ')
    return acceptable

In [None]:
#### Stacking 


def stack_bands(acolite_output_dir, stacked_dir, tile):
    """Stack ACOLITE surface-reflectance GeoTIFFs into one multi-band GeoTIFF.

    Every ``*rhos*.tif`` file in ``acolite_output_dir`` is mapped to a Sentinel-2
    band name from the wavelength suffix of its name (``..._rhos_<nm>.tif``).
    Files are sorted by wavelength and written as float32 bands, with the band
    names as band descriptions.

    Parameters
    ----------
    acolite_output_dir : str
        Directory holding the ACOLITE rhos GeoTIFFs.
    stacked_dir : str
        Directory where the stacked GeoTIFF is written.
    tile : str
        SAFE product name, used to build the output file name.

    Returns
    -------
    stacked_tif : str
        Path of the written GeoTIFF.
    valid : bool
        Result of validate_image on the stacked data.

    Raises
    ------
    FileNotFoundError
        If there is no rhos GeoTIFF, or no B4 reference band.
    ValueError
        If a wavelength matches no band, or the band grids differ.
    """
    # Mapping wavelength (nm, as in the rhos file suffix) to band names
    wl_to_band  ={
        443 : 'B1',
        492 : 'B2',
        560 : 'B3',
        665 : 'B4',
        704 : 'B5',
        740 : 'B6',
        783 : 'B7',
        833 : 'B8',
        865 : 'B8A',
        1614 : 'B11',
        2202 : 'B12'
    }
    # Wavelengths that are not exact keys (e.g. the slightly different band
    # centres of another Sentinel-2 unit, such as 442/559/2186 nm) are matched to
    # the nearest band within this tolerance, instead of raising KeyError.
    max_wl_offset = 25

    def _wavelength(fname):
        """Wavelength suffix of a rhos file name: '..._rhos_443.tif' -> 443."""
        return int(os.path.splitext(fname)[0].split('_')[-1])

    def _band_name(wavelength):
        """Band name for a wavelength: exact key first, else nearest within max_wl_offset."""
        if wavelength in wl_to_band:
            return wl_to_band[wavelength]
        nearest = min(wl_to_band, key=lambda nominal: abs(nominal - wavelength))
        if abs(nearest - wavelength) > max_wl_offset:
            raise ValueError(f"No Sentinel-2 band matches wavelength {wavelength} nm")
        return wl_to_band[nearest]

    # endswith('.tif') also skips sidecar files such as '.tif.aux.xml'.
    rhos_files = sorted(
        (f for f in os.listdir(acolite_output_dir) if 'rhos' in f and f.endswith('.tif')),
        key=_wavelength,
    )
    if not rhos_files:
        raise FileNotFoundError(f"No rhos GeoTIFF files found in {acolite_output_dir}")
    band_files = [(_band_name(_wavelength(f)), f) for f in rhos_files]

    # Reference band (B4) provides the grid and metadata; all bands are expected at 10 m
    reference_band = 'B4'
    reference_names = [f for b, f in band_files if b == reference_band]
    if not reference_names:
        raise FileNotFoundError(f"Reference band {reference_band} not found in {acolite_output_dir}")
    reference_file = os.path.join(acolite_output_dir, reference_names[0])

    # Open the reference GeoTIFF to get metadata
    with rasterio.open(reference_file) as ref:
        ref_profile = ref.profile
        ref_transform = ref.transform
        ref_crs = ref.crs
        ref_width = ref.width
        ref_height = ref.height
        ref_resolution = ref.res

    # Initialize an array to store all bands
    stacked_data = np.zeros((len(band_files), ref_height, ref_width), dtype=np.float32)

    # Read each band (no resampling: every band must share the reference grid)
    for i, (band, filename) in enumerate(band_files):
        file_path = os.path.join(acolite_output_dir, filename)
        with rasterio.open(file_path) as src:
            if src.res != ref_resolution:
                raise ValueError(f"Resolution mismatch in {filename}: expected {ref_resolution}, got {src.res}")
            if src.width != ref_width or src.height != ref_height:
                raise ValueError(f"Dimensions mismatch in {filename}: expected {ref_width}x{ref_height}, got {src.width}x{src.height}")
            stacked_data[i] = src.read(1, out_dtype=np.float32)

    # Update the profile for the stacked GeoTIFF
    stacked_profile = ref_profile.copy()
    stacked_profile.update({
        'count': len(band_files),
        'dtype': np.float32,  # For reflectance
        'transform': ref_transform,
        'crs': ref_crs,
        'width': ref_width,
        'height': ref_height,
        'nodata': -999  # Optional: Set nodata value (adjust if ACOLITE uses NaN)
    })
    # Output name 'S2_<fields 2..-4 of the product name>_stacked.tiff'.
    # NOTE(review): for standard compact SAFE names these fields are the sensing
    # datetime and the processing baseline (e.g. N0204), not the tile id that a
    # MARIDA-style S2_DATE_TILE name would contain; confirm this is intended.
    stacked_tif_fname =  'S2_'+('_'.join(tile.split('_')[2:-3]))+'_stacked.tiff'
    stacked_tif = os.path.join(stacked_dir, stacked_tif_fname)
    # Save the stacked GeoTIFF
    with rasterio.open(stacked_tif, 'w', **stacked_profile) as dst:
        dst.write(stacked_data)
        # Band descriptions follow the wavelength order (B1, B2, ...)
        for i, (band, _) in enumerate(band_files, 1):
            dst.set_band_description(i, band)

    print(f"Stacked GeoTIFF saved to: {stacked_tif}")

    # Verify the stacked GeoTIFF
    with rasterio.open(stacked_tif) as stacked:
        print(f"Stacked GeoTIFF:")
        print(f"Bounds: {stacked.bounds}")
        print(f"Dimensions: {stacked.width} columns, {stacked.height} rows")
        print(f"Resolution: {stacked.res}")
        print(f"CRS: {stacked.crs}")
        print(f"Number of bands: {stacked.count}")
        print(f"Band descriptions: {stacked.descriptions}")
    valid = validate_image(stacked_data)
    return stacked_tif, valid

In [None]:
######## Cropping


def crop_stacked_file(stacked_tif, cropped_tif, utm_bounds, crs, tile_res):
    """Crop a stacked GeoTIFF to georeferenced bounds and write it as a new GeoTIFF.

    Band descriptions are copied, and the result is re-opened and printed as a check.

    Parameters
    ----------
    stacked_tif : str
        Source multi-band GeoTIFF (e.g. produced by stack_bands).
    cropped_tif : str
        Output path.
    utm_bounds : tuple of float
        ``(left, bottom, right, top)`` in the source raster's CRS.
    crs, tile_res :
        Unused; kept for backward compatibility (the CRS is taken from the source file).
    """
    with rasterio.open(stacked_tif) as src:
        # from_bounds gives fractional offsets/lengths when the bounds are not exactly
        # on the pixel grid (floating-point error included). Snap them to whole pixels
        # so the window and the output size are integral.
        exact = from_bounds(*utm_bounds, transform=src.transform)
        window = Window(
            col_off=round(exact.col_off),
            row_off=round(exact.row_off),
            width=round(exact.width),
            height=round(exact.height),
        )
        data = src.read(window=window)
        # Georeferencing of the cropped image
        transform = rasterio.windows.transform(window, src.transform)
        band_descriptions = src.descriptions

        # The output size is taken from the array actually read, so they always agree.
        with rasterio.open(
            cropped_tif,
            'w',
            driver='GTiff',
            height=data.shape[1],
            width=data.shape[2],
            count=src.count,
            dtype=data.dtype,
            crs=src.crs,
            transform=transform,
        ) as dst:
            # Preserve band descriptions (e.g., B1, B2, ...)
            for i, desc in enumerate(band_descriptions, 1):
                if desc:
                    dst.set_band_description(i, desc)
            dst.write(data)

    print(f"Cropped GeoTIFF saved to: {cropped_tif}")

    # Verify the cropped GeoTIFF
    with rasterio.open(cropped_tif) as cropped:
        print(f"Cropped GeoTIFF:")
        print(f"Bounds: {cropped.bounds}")
        print(f"Dimensions: {cropped.width} columns, {cropped.height} rows")
        print(f"Resolution: {cropped.res}")
        print(f"CRS: {cropped.crs}")
        print(f"Number of bands: {cropped.count}")
        print(f"Band descriptions: {cropped.descriptions}")

In [None]:
##### Run if you want to extract all tiles from the LW annotation file
#%%time
#all_tiles = get_all_tiles(annotations_files)

In [None]:
def process_tile(target, safe_files_dir, annotations_file, output_dir, remote_dataset_location = './LWC_DataSet/test_data', batch_id='', clean_patches = True, dagshub_token=''):
    acolite_output_dir = os.path.join(output_dir, 'acolite_output/')
    stacked_dir = os.path.join(output_dir, 'stacked/')
    final_dir = os.path.join(output_dir, 'patches/')
    os.makedirs(acolite_output_dir, exist_ok=True)
    os.makedirs(stacked_dir, exist_ok=True)
    if clean_patches : 
        shutil.rmtree(acolite_output_dir, ignore_errors=True)
    os.makedirs(final_dir, exist_ok=True)
    print(f'building annotations dataframe for tile {target}')
    #tile_df = build_tile_df(target, annotations_files)
    tile_df = annotations_from_netcdf(target, annotations_file, output_dir)
    mask, debris_pixels_count = generate_tile_mask(tile_df)
    print(f'tile mask shape {mask.shape}')
    if debris_pixels_count == 0:
        print(f'no valid debris pixels in tile {target}')
        return False
    patches = find_subregions_efficient(mask, (256, 256), 25) #### CHANGED THRESHOLD TO 25 HERE TO EXLUDE MINIMAL PLASTIC PATCHES 
    print(f'patches with marine debris : {patches}')
    target_file_path = os.path.join(os.path.join(safe_files_dir, target[:-5]), target)
    crs, bounds, tile_transform, tile_res = extract_crs_and_bounds(target_file_path)       
    print(f'crs : {crs}')
    print(f'bounds : {bounds}')
    res = tile_res[0] #sentinel 2 max resolution
    tile_regions = generate_tile_regions(target_file_path)
    patches_per_region = assign_patches(tile_regions, patches)
    print(f'patches per region {patches_per_region}')
    for i, ppr in patches_per_region.items(): #iterates over regions containing patches
        shutil.rmtree(stacked_dir, ignore_errors=True)
        os.makedirs(stacked_dir, exist_ok=True)
        current_region = tile_regions[i]
        current_center = ((current_region[0] + current_region[2])/2 * res + bounds.left,
        (current_region[1] + current_region[3])/2 * (-res) + bounds.top)
        #subregion_centers = [((int(p[1] + p[3])/2) * res + bounds.left, int((p[0] + p[2])/2) * res + bounds.bottom) for p in patches]
        print(f'region center : {current_center}')
        #subregion_size = (256 * res, 256 * res)
        current_region_size = ((current_region[2] - current_region[0]) * res, 
                               (current_region[3] - current_region[1]) * res )
        try:
            # Check disk usage
            print("Disk usage before run:")
            !du -sh /kaggle/input/*
            !df -h /kaggle/tmp /kaggle/working
        
            # Clear output directory
            shutil.rmtree(acolite_output_dir, ignore_errors=True)
            os.makedirs(acolite_output_dir, exist_ok=True)


        # Process each subregion
        #for i, (subregion_utm_x, subregion_utm_y) in enumerate(subregion_centers):
            # Clear output directory
            subregion_utm_x, subregion_utm_y = current_center
            print(f'Cleaning previous acolite output')
            shutil.rmtree(acolite_output_dir, ignore_errors=True)
            os.makedirs(acolite_output_dir, exist_ok=True)
            print(f"\nProcessing subregion {i+1}/{len(tile_regions)} centered at UTM ({subregion_utm_x}, {subregion_utm_y})")
    
            # Compute UTM-based limit
            limit, utm_box = get_utm_limit(subregion_utm_x, subregion_utm_y, crs, current_region_size[0]) #subregion_size[0])
            print(f'setting limit : {limit}')
            settings['limit'] = limit
            print(f"Settings for subregion {i+1}: {settings}")
            output_files = acolite_run(settings=settings, inputfile=target_file_path, output=acolite_output_dir)
            print(f"Generated output files {output_files}")
            print('stacking bands')
            stacked_tiff, valid = stack_bands(acolite_output_dir, stacked_dir, target)
            if not valid : 
                print(f'skipping region {current_region} : too many invalid pixels')
                continue
            patch_dir =  os.path.join(final_dir, '_'.join((stacked_tiff.split('/')[-1]).split('_')[:-1]))# '_'.join(stacked_tiff.split('_')[:-1])
            print(f'patch dir {patch_dir}')
            for p in ppr:
                print(f'cropping  patch {p}')
                patch_utm_bound = (p[0] * res + bounds.left, (p[3]) * (-res) + bounds.top,  ## CHANGED -1 removed
                                   p[2] * res + bounds.left,  p[1] * (-res) + bounds.top)
                patch_name_prefix = patch_dir.split('/')[-1]
                patch_name = f'{patch_name_prefix}_{int(patch_utm_bound[0])}_{int(patch_utm_bound[1])}.tif'  ## CHANGED one f removed from .tiff
                patch_pathname = os.path.join(patch_dir, patch_name)
                os.makedirs(patch_dir, exist_ok=True)
                print(f'patch pathname {patch_pathname}')
                crop_stacked_file(stacked_tiff, os.path.join(patch_dir, patch_pathname), patch_utm_bound, crs, tile_res)
                patch_mask = mask[p[1] : p[3], p[0] : p[2]]
                print(f'patch mask shape {patch_mask.shape}')
                try:
                    #cv2.imwrite(os.path.join(patch_dir, f'{patch_pathname[:-4]}_cl.tif'), patch_mask)
                    written_patch_path = os.path.join(patch_dir, patch_name)
                    mask_output_path = os.path.join(patch_dir, f'{patch_name[:-4]}_cl.tif')
                    print(f"Writing mask for: {mask_output_path}")
                    print(f"Mask shape: {patch_mask.shape}, dtype: {patch_mask.dtype}")
                    with rasterio.open(written_patch_path) as src:
                        meta = src.meta.copy()
                        print(f"Original patch size: {(src.width, src.height)}")
                        if patch_mask.shape != (src.height, src.width):
                            raise ValueError(" Mask size does not match patch size")
                    meta.update({
                        "count": 1,
                        "dtype": "uint8",
                        "compress": "lzw"
                    })
                    with rasterio.open(mask_output_path, "w", **meta) as dst:
                        dst.write(patch_mask.astype("uint8"), 1)

                except Exception as e:
                    print(f"Error writing mask: {e}")
            # Insert Sanity Check for ACOLITE output ? 
            
            # Upload to DagsHub
            if dagshub_token:
                upload_to_dagshub(output_dir, final_dir, remote_dataset_location, batch_id, dagshub_token)
    
    

            # Final disk usage
            print("\nDisk usage after run:")
            ! du -sh /kaggle/input/*
            ! df -h /kaggle/tmp /kaggle/working
      
    
        except Exception as e:
            print(f"Exception captured: {e}")


In [None]:
### Setting the directory of the annotations file split and converted to parquet format
#parquet_annotations_dir = '/kaggle/input/lw-parquet/kaggle/working'
#annotations_files = [os.path.join(parquet_annotations_dir, f) for f in os.listdir(parquet_annotations_dir)]

In [None]:
#build_tiles_parquet_annotations([tiles[5]], annotations_files)

In [None]:
def process_batch(tiles, safe_files_dir, annotations_file, batch_id='', dagshub_token='',
                  output_dir='/kaggle/working'):
    """Run ``process_tile`` on every tile (SAFE product) of one batch.

    Parameters
    ----------
    tiles : iterable
        Tiles (SAFE files) belonging to batch ``batch_id``.
    safe_files_dir : str
        Directory containing the SAFE products; forwarded to ``process_tile``.
    annotations_file : str
        Path to the annotations file (e.g. the NetCDF downloaded at the top of
        the notebook); forwarded to ``process_tile``.
    batch_id : str, optional
        Batch identifier forwarded to ``process_tile``.
    dagshub_token : str or None, optional
        DagsHub token. When non-empty, ``clean_patches=True`` is passed to
        ``process_tile``. When empty or None, ``clean_patches=False`` and an
        empty token is forwarded.
    output_dir : str, optional
        Working directory forwarded to ``process_tile`` (defaults to the
        previously hard-coded ``/kaggle/working``).
    """
    # Normalise None to "" so process_tile always receives a string token;
    # the original len(dagshub_token) raised TypeError for None.
    token = dagshub_token or ''
    clean_patches = bool(token)
    for target in tiles:
        process_tile(target, safe_files_dir, annotations_file, output_dir, batch_id,
                     clean_patches=clean_patches, dagshub_token=token)
    

In [None]:
process_batch(tiles,safe_files_dir, '/kaggle/working/annotations.nc',  batch_id='', dagshub_token='')

In [None]:
# Push the generated patches to the DagsHub dataset repository.
# TODO: could this be automated inside process_tile by passing
#       dagshub_token=dagshub_token to process_batch instead of ''?
upload_to_dagshub(
    '/kaggle/working',            # output dir
    '/kaggle/working/patches',    # final dir holding the generated patches
    './LWC_DataSet/test_data',    # remote dataset location
    "",                           # batch id (none)
    dagshub_token=dagshub_token,
)

In [None]:
# Sagar's & Navodita's added matching visualisations

In [None]:
def compute_fdi_from_tiff(tiff_path, red_band=4, nir_band=9, swir_band=10):
    """Compute a Floating Debris Index (FDI) map from a stacked multi-band GeoTIFF.

    FDI = NIR - (RED + (SWIR - RED) * (859 - 665) / (1610 - 665))

    This is NIR reflectance minus a baseline interpolated linearly between the
    red (665 nm) and SWIR (1610 nm) bands at the NIR wavelength (859 nm).
    NOTE(review): the FDI of Biermann et al. (2020) anchors the baseline on the
    red-edge band (B6) and scales the slope by 10. This simplified red-anchored
    variant is kept as-is; confirm it is intended.

    Parameters
    ----------
    tiff_path : str
        Path to the stacked patch GeoTIFF.
    red_band, nir_band, swir_band : int, optional
        1-based rasterio band indices. The defaults (4, 9, 10) are the indices
        used originally, labelled B4, B8A and B11. NOTE(review): 9 -> B8A and
        10 -> B11 only hold if the stack is B1-B8, B8A, B11, ... (B9 and B10
        absent). Confirm against the code that writes the patches.

    Returns
    -------
    numpy.ndarray
        float32 FDI array. No masking of invalid/nodata pixels is applied.
    """
    with rasterio.open(tiff_path) as src:
        # A list of indexes returns a (3, rows, cols) array; unpack per band.
        red, nir, swir = src.read([red_band, nir_band, swir_band]).astype(np.float32)
    return nir - (red + ((swir - red) * (859 - 665) / (1610 - 665)))
def compute_ndwi(tiff_path, green_band=3, nir_band=8, eps=1e-6):
    """Compute the Normalized Difference Water Index from a stacked GeoTIFF.

    NDWI = (Green - NIR) / (Green + NIR + eps)

    Parameters
    ----------
    tiff_path : str
        Path to the stacked patch GeoTIFF.
    green_band, nir_band : int, optional
        1-based rasterio band indices. The defaults (3, 8) are the indices used
        originally, labelled Band 3 (Green) and Band 8 (NIR).
        NOTE(review): confirm they match the stack order of the patches.
    eps : float, optional
        Added to the denominator to avoid division by zero.

    Returns
    -------
    numpy.ndarray
        float32 NDWI array.
    """
    with rasterio.open(tiff_path) as src:
        green, nir = src.read([green_band, nir_band]).astype(np.float32)
    return (green - nir) / (green + nir + eps)
def plot_fdi(fdi_array, ndwi, img_path, mask_path):
    """Show an RGB patch, its binary mask, FDI and NDWI side by side.

    Parameters
    ----------
    fdi_array : array-like
        2-D FDI map, e.g. from ``compute_fdi_from_tiff(img_path)``.
    ndwi : array-like
        2-D NDWI map, e.g. from ``compute_ndwi(img_path)``.
    img_path : str
        Multi-band patch GeoTIFF; bands 4, 3, 2 are displayed as R, G, B.
    mask_path : str
        Single-band mask GeoTIFF; any value > 0 is shown as foreground.
    """
    with rasterio.open(img_path) as src:
        rgb = np.transpose(src.read([4, 3, 2]), (1, 2, 0)).astype(np.float32)
    # Min-max stretch to [0, 1] for display. The epsilon avoids 0/0 (an all-NaN
    # image) on constant patches.
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)
    with rasterio.open(mask_path) as src:
        mask_binary = src.read(1) > 0

    fig, axs = plt.subplots(1, 4, figsize=(15, 5))
    # (image, title, whether to add a colorbar for the index value scale)
    panels = [
        (rgb, "RGB Patch", False),
        (mask_binary, "Binary Mask (._cl.tif)", False),
        (fdi_array, "FDI", True),
        (ndwi, "NDWI", True),
    ]
    for ax, (image, title, with_colorbar) in zip(axs, panels):
        im = ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
        if with_colorbar:
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

In [None]:
# Visualise RGB, mask, FDI and NDWI side by side for a single patch.
# TODO: loop over every patch in the folder instead of this hard-coded pair.
image_path, mask_path = (
    '/kaggle/working/patches/S2_20190531T100031_N0500/S2_20190531T100031_N0500_305120_5002720.tif',
    '/kaggle/working/patches/S2_20190531T100031_N0500/S2_20190531T100031_N0500_305120_5002720_cl.tif',
)
fdi = compute_fdi_from_tiff(image_path)
ndwi = compute_ndwi(image_path)
plot_fdi(fdi_array=fdi, ndwi=ndwi, img_path=image_path, mask_path=mask_path)

In [None]:
# Visualise each patch next to its binary mask and a red mask-on-patch overlay,
# then summarise the number of debris pixels per patch in a histogram.
# TODO: combine with the FDI/NDWI panels and loop over every patch folder.
patch_dir = "/kaggle/working/patches/S2_20190531T100031_N0500"
MASK_SUFFIX = "_cl.tif"

# Image patches are every *.tif that is not a mask. str.endswith does not
# expand wildcards, so the previous filter endswith('*_cl.tif') never matched
# and masks were mixed into the patch list.
patch_paths = sorted(
    p for p in glob(os.path.join(patch_dir, "*.tif")) if not p.endswith(MASK_SUFFIX)
)

# Pair each patch with the mask written next to it as "<patch>_cl.tif".
# Deriving the name keeps pairs aligned. Zipping two separately sorted lists
# does not: "x_10_cl.tif" sorts before "x_1_cl.tif", but "x_1.tif" sorts
# before "x_10.tif".
pairs = []
for patch_path in patch_paths:
    mask_path = os.path.splitext(patch_path)[0] + MASK_SUFFIX
    if os.path.exists(mask_path):
        pairs.append((patch_path, mask_path))
    else:
        print(f"No mask found for {os.path.basename(patch_path)}; skipping.")
print(f"Found {len(patch_paths)} RGB patches and {len(pairs)} matching masks.")

debris_pixel_counts = []
if not pairs:
    print(f"No patch/mask pairs to plot in {patch_dir}.")
else:
    # Plot Patch, Mask, and Overlay side-by-side, one row per pair.
    fig, axes = plt.subplots(len(pairs), 3, figsize=(15, 5 * len(pairs)), squeeze=False)
    for (ax_patch, ax_mask, ax_overlay), (patch_path, mask_path) in zip(axes, pairs):
        with rasterio.open(patch_path) as patch_src:
            rgb = np.transpose(patch_src.read([4, 3, 2]), (1, 2, 0))  # bands B4, B3, B2 as RGB
        rgb = (rgb - np.min(rgb)) / (np.max(rgb) - np.min(rgb) + 1e-6)
        with rasterio.open(mask_path) as mask_src:
            mask_binary = (mask_src.read(1) > 0).astype(np.uint8)
        n_debris = int(np.sum(mask_binary))
        debris_pixel_counts.append(n_debris)

        overlay = rgb.copy()
        overlay[mask_binary == 1] = [1.0, 0.0, 0.0]  # paint debris pixels red

        ax_patch.imshow(rgb)
        ax_patch.set_title(f"Patch\n{os.path.basename(patch_path)}")
        ax_mask.imshow(mask_binary, cmap='gray')
        ax_mask.set_title(f"Mask\nDebris pixels: {n_debris}")
        ax_overlay.imshow(overlay)
        ax_overlay.set_title("Overlay: Mask on Patch")
        for ax in (ax_patch, ax_mask, ax_overlay):
            ax.axis('off')
    fig.tight_layout()
    plt.show()

    # Histogram for debris pixels
    fig_hist, ax_hist = plt.subplots(figsize=(10, 4))
    ax_hist.hist(debris_pixel_counts, bins=10, color='teal', edgecolor='black')
    ax_hist.set_title("Histogram of Debris Pixels per Patch")
    ax_hist.set_xlabel("Debris Pixels")
    ax_hist.set_ylabel("Number of Patches")
    ax_hist.grid(True)
    plt.show()
