In [1]:
from src import *

# 🗺️ Major-TOM Filtering
[![HF](https://img.shields.io/badge/%F0%9F%A4%97-Datasets-yellow)](https://www.huggingface.co/Major-TOM) [![paper](https://img.shields.io/badge/arXiv-2402.12095-D12424)](https://www.arxiv.org/abs/2402.12095) [![GitHub stars](https://img.shields.io/github/stars/ESA-PhiLab/Major-TOM?style=social&label=Star&maxAge=2592000)](https://github.com/ESA-PhiLab/Major-TOM/)

This notebook demonstrates how to access MajorTOM-Core-S2L2A data quickly and filter a subset of interest.

Examples:
1. Filtering based on location, time, and cloud cover
2. Downloading a filtered subset of the dataset
3. PyTorch Dataset with a local copy
4. HuggingFace `datasets` fast access via streaming

### 1. 📅 Filtering based on location, time, and cloud cover
First we will download a local copy of the dataset metadata, in this case from `Major-TOM/Core-S2L1C` (the value of `SOURCE_DATASET` in the cell below; change it to use another Major-TOM dataset).

In [None]:
from pathlib import Path
import urllib.request

# Hugging Face dataset whose metadata table is fetched
SOURCE_DATASET = 'Major-TOM/Core-S2L1C'

# Local folder that holds the metadata file
DATASET_DIR = Path('./data/Major-TOM/')
DATASET_DIR.mkdir(parents=True, exist_ok=True)

ACCESS_URL = f'https://huggingface.co/datasets/{SOURCE_DATASET}/resolve/main/metadata.parquet?download=true'

# Name the local file after the remote one: the last path component before '.parquet'
remote_stem = ACCESS_URL.split('.parquet')[0].rsplit('/', 1)[-1]
LOCAL_URL = DATASET_DIR / f'{remote_stem}.parquet'

# download from server to local url
gdf = metadata_from_url(ACCESS_URL, LOCAL_URL)

gdf.head()

In [None]:
# Number of metadata rows and the type of the object returned by metadata_from_url
len(gdf), type(gdf)

In [9]:
import folium
from folium.plugins import MarkerCluster
from folium.plugins import HeatMap

def create_map(gdf, r=3):
    """Draw one circle marker per row of ``gdf`` on a Folium map.

    Args:
        gdf: Table providing the ``centre_lat``, ``centre_lon``,
            ``cloud_cover`` and ``product_id`` columns read below.
        r: Marker radius handed to ``folium.CircleMarker``.

    Returns:
        The ``folium.Map`` with all markers attached.
    """
    # World view centred on (0, 0)
    world_map = folium.Map(location=[0, 0], zoom_start=3, control_scale=True)

    # Styling shared by every marker
    marker_style = dict(
        radius=r,
        color='blue',
        fill=True,
        fill_color='blue',
        fill_opacity=0.7,
    )

    for _, record in gdf.iterrows():
        marker = folium.CircleMarker(
            location=[record['centre_lat'], record['centre_lon']],
            popup=f"Cloud Cover: {record['cloud_cover']}%",
            tooltip=record['product_id'],
            **marker_style,
        )
        marker.add_to(world_map)

    return world_map

def create_heatmap(gdf, r=15, save_html=False, filename='heatmap.html'):
    """Render the tile centres of ``gdf`` as a Folium heat map.

    Args:
        gdf: Table with ``centre_lat`` and ``centre_lon`` columns.
        r: Radius passed to the ``HeatMap`` layer.
        save_html: When true, the map is also written to ``filename``.
        filename: Output path, used only if ``save_html`` is set.

    Returns:
        The ``folium.Map`` carrying the heat-map layer.
    """
    # World view centred on (0, 0)
    heat_map = folium.Map(location=[0, 0], zoom_start=3, control_scale=True)

    centre_points = gdf[['centre_lat', 'centre_lon']]
    heat_layer = HeatMap(data=centre_points, radius=r)
    heat_layer.add_to(heat_map)

    if not save_html:
        return heat_map

    heat_map.save(filename)
    return heat_map


# create_map(gdf.sample(1000))

Then, we can specify a few regions using shapely geometry, for example:

In [9]:
from shapely.geometry import box

# Example bounding boxes used for filtering, spelled out as x/y (longitude/latitude) bounds
switzerland = box(minx=5.9559111595, miny=45.8179931641, maxx=10.4920501709, maxy=47.808380127)
gabon = box(minx=8.1283659854, miny=-4.9213919841, maxx=15.1618722208, maxy=2.7923006325)
napoli = box(minx=14.091710578, miny=40.7915558593, maxx=14.3723765416, maxy=40.9819258062)
# a remote patch over pacific - no data
pacific = box(minx=-153.3922893485, miny=39.6170415622, maxx=-152.0423077748, maxy=40.7090892316)

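and then use it via our `filter_metadata` function - let's try to get some recent images around ⚽🍕 Napoli! (Pass `region=napoli` to restrict the search to that bounding box; the cell below keeps the `region` argument commented out, so it filters by cloud cover, date range and no-data fraction only.)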

In [None]:
# Criteria forwarded to filter_metadata
filter_kwargs = dict(
    cloud_cover=(0, 10),  # cloud cover between 0% and 10%
    # region=switzerland,  # you can try with different bounding boxes, like in the cell above
    daterange=('2020-01-01', '2025-01-01'),  # temporal range
    nodata=(0.0, 0.0),  # only 0% of no data allowed
)
filtered_df = filter_metadata(gdf, **filter_kwargs)

# Keep every 12th row of the filtered table
filtered_df = filtered_df[::12]

display(filtered_df.head())
print(f'Number of images: {len(filtered_df)}')

Any row from the metadata can be very easily read into a `dict` of numpy arrays using our `read_row` function:

In [7]:
# out = read_row(filtered_df.iloc[0], columns = ['B04', 'thumbnail'])

### 📩 Downloading a filtered subset of the dataset

Use the `filter_download` function to download all files to the local directory at `local_dir`. Your new dataset will be named using `source_name`.

More importantly, the `by_row` option allows you to download specific rows from the archives. Set it to `True` if you think you will take only a few files from each parquet file (most parquet files contain samples that are close to each other in space).

If you expect to take most of the samples from the parquet file, setting `by_row` to `False` will probably be quicker (you then download the data as the entire file, before you rearrange it onto folders with only the files from your dataframe).

In [None]:
import warnings
# Silence warnings whose message mentions GeoDataFrame.swapaxes
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*")

# Cut the filtered table into n_splits chunks
n_splits = 10
df_list = np.array_split(filtered_df, n_splits)

# Sanity check: the chunks together hold every row of filtered_df
np.sum([len(chunk) for chunk in df_list]) == len(filtered_df)

In [9]:
# filter_download(df_list[0], local_dir='/home/ccollado/phileo_phisat2/MajorTOM/', source_name='L1C', by_row=True)

You can now check your local directory for the local version of your dataset!

### 🔥 PyTorch Dataset with a local copy
We can use it directly with our `PyTorch` definition of the Dataset `MajorTOM`; just supply the metadata file and the directory of the files:

# -----------------------------------------------------
# EXPERIMENT WITH DOWNLOADED DATA
# -----------------------------------------------------

In [1]:
from src import *

In [2]:
# Band names handed to the MajorTOM dataset through its `tif_bands` argument
tif_bands = ['B02', 'B03', 'B04', 'B08', 'B05', 'B06', 'B07', 'B8A', 'B11', 'B12', 'cloud_mask']

In [None]:
# Load the filtered metadata from 'filtered_df.csv' and cut it into 100 chunks
filtered_df = pd.read_csv('filtered_df.csv')
dfs = np.array_split(filtered_df, 100)

# Chunk selected for download
index_download = 1
df_to_download = dfs[index_download]

print('Subset: {} out of {}'.format(len(df_to_download), len(filtered_df)))

In [None]:
# Build the dataset from the metadata table and the local 'L1C' folder.
# NOTE(review): the absolute path is machine-specific — adjust it before running elsewhere.
ds = MajorTOM(filtered_df, '/home/ccollado/phileo_phisat2/MajorTOM/L1C', tif_bands = tif_bands, combine_bands=False)

# Inspect the 'meta' entry of the first sample
ds[0]['meta']

In [None]:
# NOTE(review): presumably separates metadata rows whose files are present locally from the missing ones — confirm against MajorTOM.check_file_existence
existing_indices, missing_indices, df_existing, df_missing = ds.check_file_existence()

In [None]:
# Row counts of the two tables returned by check_file_existence
len(df_existing), len(df_missing)

In [7]:
# df_existing.to_csv('df_existing.csv', index=False)
# df_missing.to_csv('df_missing.csv', index=False)

In [8]:
# filter_download(df_missing, local_dir='/home/ccollado/phileo_phisat2/MajorTOM/', source_name='L1C', by_row=True)

In [None]:
# Heat map of the filtered tiles; save_html=True also writes it to the default 'heatmap.html'
create_heatmap(filtered_df,r=10, save_html=True)

# Check Completed Simulations

In [None]:
import os
import statistics
from tabulate import tabulate

def calculate_file_stats(directory, threshold_mb=None):
    """Print size statistics for the regular files directly inside ``directory``.

    Zero-byte files are counted separately and excluded from the size
    statistics and from the below-threshold list.

    Args:
        directory: Folder to inspect (not recursive).
        threshold_mb: Optional size limit in MB; non-empty files smaller than
            this are listed and returned.

    Returns:
        Names of the non-empty files below ``threshold_mb``. An empty list when
        no threshold is given, when the folder is missing or unreadable, or
        when it holds no files.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        print(f"Error: The directory '{directory}' does not exist.")
        return []
    except PermissionError:
        print(f"Error: Permission denied to access '{directory}'.")
        return []

    bytes_per_mb = 1024 ** 2
    check_threshold = threshold_mb is not None

    sizes_mb = []      # sizes of the non-empty files
    empty_files = 0    # number of zero-byte files
    small_files = []   # (name, size) of files under the threshold

    for name in entries:
        full_path = os.path.join(directory, name)
        if not os.path.isfile(full_path):
            continue

        try:
            n_bytes = os.path.getsize(full_path)
        except OSError as e:
            print(f"Warning: Could not access '{full_path}'. Reason: {e}")
            continue

        if n_bytes == 0:
            empty_files += 1
            continue

        mb = n_bytes / bytes_per_mb
        sizes_mb.append(mb)
        if check_threshold and mb < threshold_mb:
            small_files.append((name, mb))

    if not sizes_mb and empty_files == 0:
        print(f"No files found in directory '{directory}'.")
        return []

    # Statistics fall back to zero when every file is empty
    have_sizes = bool(sizes_mb)
    total_size_gb = sum(sizes_mb) / 1024
    mean_size = statistics.mean(sizes_mb) if have_sizes else 0
    std_dev = statistics.stdev(sizes_mb) if len(sizes_mb) > 1 else 0.0
    min_size = min(sizes_mb) if have_sizes else 0
    max_size = max(sizes_mb) if have_sizes else 0

    rows = [
        ["Number of files", len(sizes_mb) + empty_files],
        ["Number of zero-size files", empty_files],
        ["Total size (GB)", f"{total_size_gb:.2f}"],
        ["Mean file size (MB)", f"{mean_size:.2f}"],
        ["Standard deviation (MB)", f"{std_dev:.2f}"],
        ["Minimum file size (MB)", f"{min_size:.2f}"],
        ["Maximum file size (MB)", f"{max_size:.2f}"],
    ]

    print(f"\nStatistics for directory: {directory}\n")
    print(tabulate(rows, headers=["Statistic", "Value"], tablefmt="grid"))

    if not check_threshold:
        return []

    if small_files:
        print(f"\nFiles below {threshold_mb} MB:")
        listing = [[fname, f"{mb:.2f} MB"] for fname, mb in small_files]
        print(tabulate(listing, headers=["Filename", "Size"], tablefmt="grid"))
    else:
        print(f"\nNo files are below {threshold_mb} MB.")

    return [fname for fname, _ in small_files]


directory_path = '/home/ccollado/phileo_phisat2/MajorTOM/tiff_files'
size_threshold = 20.0  # Threshold in MB; non-empty files smaller than this are listed

# Names of the non-empty files below the threshold
files_below = calculate_file_stats(directory_path, threshold_mb=size_threshold)


In [11]:
df_existing = pd.read_csv('df_existing.csv')

# Every 6th row, as an independent table
df_simulation = df_existing.iloc[::6].copy()

# Identifier of the form '<product_id>__<grid_row_u>_<grid_col_r>'
id_parts = df_simulation[['product_id', 'grid_row_u', 'grid_col_r']].astype(str)
df_simulation['unique_identifier'] = (
    id_parts['product_id'] + '__' + id_parts['grid_row_u'] + '_' + id_parts['grid_col_r']
)

df_simulation.to_csv('df_simulation.csv', index=False)

In [None]:
# Heat map of every tile in the full metadata table.
# NOTE(review): relies on `gdf` from the metadata download cell near the top — confirm it is still defined in this session.
create_heatmap(gdf, r=10)

In [None]:
import pandas as pd
from pathlib import Path
import string

# Function to sanitize filenames
def sanitize_filename(name):
    """Drop every character of ``name`` that is not allowed in a filename.

    Kept characters: ASCII letters, digits, the space and ``-_.()``.

    Args:
        name: Candidate filename.

    Returns:
        ``name`` with all other characters removed; order is preserved.
    """
    allowed = frozenset("-_.() " + string.ascii_letters + string.digits)
    kept = [ch for ch in name if ch in allowed]
    return ''.join(kept)

# Ensure required columns exist
required_columns = {'product_id', 'unique_identifier'}
if not required_columns.issubset(df_simulation.columns):
    raise ValueError("DataFrame must contain 'product_id' and 'unique_identifier' columns.")

# Define the folder path
folder_path = Path('/home/ccollado/phileo_phisat2/MajorTOM/tiff_files')  # Replace with your actual folder path

if not folder_path.exists():
    raise FileNotFoundError(f"The folder path {folder_path} does not exist.")

# Map product_id -> unique_identifier (for a repeated product_id the last row wins)
id_mapping = pd.Series(df_simulation.unique_identifier.values,
                       index=df_simulation.product_id).to_dict()

# Snapshot the directory listing before renaming anything: the loop renames
# files inside the folder it iterates over, so a lazily evaluated listing
# could otherwise be affected by the renames. Sorting also fixes the order.
tif_files = sorted(folder_path.glob('*.tif'))

# Iterate and rename files
for file in tif_files:
    product_id = file.stem
    unique_id = id_mapping.get(product_id)

    if not unique_id:
        print(f"No unique_identifier found for product_id '{product_id}'. Skipping file '{file.name}'.")
        continue

    sanitized_unique_id = sanitize_filename(unique_id)
    new_filename = f"{sanitized_unique_id}.tif"
    new_file_path = folder_path / new_filename

    # Never overwrite an existing file
    if new_file_path.exists():
        print(f"Warning: {new_filename} already exists. Skipping renaming of {file.name}.")
        continue

    try:
        file.rename(new_file_path)
        print(f"Renamed '{file.name}' to '{new_filename}'.")
    except OSError as e:
        # Report the failure and carry on with the remaining files
        print(f"Error renaming {file.name} to {new_filename}: {e}")


In [None]:

# Step 1: Calculate the counts of each product_id
product_counts = df_existing['product_id'].value_counts()

# Step 2: Keep the product_ids that occur more than once
product_ids_gt1 = product_counts[product_counts > 1].index

# Step 3: Names (without the '.tif' suffix) of the TIFF files on disk
files = os.listdir('/home/ccollado/phileo_phisat2/MajorTOM/tiff_files')
files = [f[:-4] for f in files if f.endswith('.tif')]

# Step 4: Files named after a duplicated product_id, found with one set
# intersection (the list-comprehension variant computed the same matches and
# was overwritten immediately, so it is dropped)
files_set = set(files)
product_ids_set = set(product_ids_gt1)
matching_files = list(files_set & product_ids_set)
len(matching_files)

In [None]:
# Delete the files
total_deleted = 0
for stem in matching_files:
    # matching_files holds names without the extension
    file_path_tif = os.path.join(directory_path, stem + '.tif')
    try:
        os.remove(file_path_tif)
    except FileNotFoundError:
        print(f"File not found: {file_path_tif}")
    except PermissionError:
        print(f"Permission denied: {file_path_tif}")
    except Exception as e:
        print(f"Error deleting {file_path_tif}: {e}")
    else:
        print(f"Deleted: {file_path_tif}")
        total_deleted += 1

print(f"Total files deleted: {total_deleted}")

In [None]:
import os
import matplotlib.pyplot as plt

def get_file_sizes(directory):
    """Collect the sizes of the regular files directly inside ``directory``.

    Args:
        directory: Folder to inspect (not recursive).

    Returns:
        A pair ``(sizes_mb, n_empty)``: the sizes in MB of all non-empty
        files, and the number of zero-byte files. If listing or reading fails,
        the error is printed and whatever was collected so far is returned.
    """
    bytes_per_mb = 1024 ** 2
    sizes_mb = []
    n_empty = 0

    try:
        for name in os.listdir(directory):
            path = os.path.join(directory, name)
            if not os.path.isfile(path):
                continue

            n_bytes = os.path.getsize(path)
            if n_bytes:
                sizes_mb.append(n_bytes / bytes_per_mb)
            else:
                n_empty += 1
    except Exception as e:
        print(f"Error while accessing directory: {e}")

    return sizes_mb, n_empty

def plot_histogram(file_sizes, plt_title='Distribution of File Sizes'):
    """Show a 30-bin histogram of ``file_sizes``.

    Args:
        file_sizes: File sizes in MB. When empty, a message is printed and
            nothing is drawn.
        plt_title: Title placed above the plot.
    """
    if not file_sizes:
        print("No non-zero files to plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(file_sizes, bins=30, edgecolor='black')
    ax.set_title(plt_title)
    ax.set_xlabel('File Size (MB)')
    ax.set_ylabel('Frequency')
    # Horizontal guide lines only
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

# Histogram of the sizes of all non-empty files in the TIFF folder
directory_path = '/home/ccollado/phileo_phisat2/MajorTOM/tiff_files'
file_sizes, zero_size_count = get_file_sizes(directory_path)
plot_histogram(file_sizes)


In [None]:
import time

def get_recent_file_sizes(directory, hours=2):
    """Collect the sizes of files whose ctime lies within the last ``hours`` hours.

    Only regular files directly inside ``directory`` are considered.

    Note: the timestamp comes from ``os.path.getctime``, which is the creation
    time on Windows but the time of the last metadata change on Unix.

    Args:
        directory: Folder to inspect (not recursive).
        hours: Size of the look-back window in hours.

    Returns:
        A pair ``(sizes_mb, n_empty)``: the sizes in MB of the recent non-empty
        files, and the number of recent zero-byte files. If listing or reading
        fails, the error is printed and whatever was collected so far is
        returned.
    """
    cutoff = time.time() - (hours * 3600)  # oldest timestamp still accepted
    bytes_per_mb = 1024 ** 2
    sizes_mb = []
    n_empty = 0

    try:
        for name in os.listdir(directory):
            path = os.path.join(directory, name)
            if not os.path.isfile(path):
                continue
            if os.path.getctime(path) < cutoff:
                continue

            n_bytes = os.path.getsize(path)
            if n_bytes:
                sizes_mb.append(n_bytes / bytes_per_mb)
            else:
                n_empty += 1
    except Exception as e:
        print(f"Error while accessing directory: {e}")

    return sizes_mb, n_empty

# Look-back window shared by the query, the plot title and the message below
hours = 1
recent_file_sizes, recent_zero_size_count = get_recent_file_sizes(directory_path, hours=hours)
plot_histogram(recent_file_sizes, plt_title=f'Distribution of File Sizes in the Last {hours} Hours')

if recent_zero_size_count > 0:
    print(f"Number of zero-size files in the last {hours} hours: {recent_zero_size_count}")


### Ecco!

In [None]:
# Fetch the first sample once: indexing `ds` again for every key would call the
# dataset's __getitem__ repeatedly for the same item.
first_sample = ds[0]
for key in first_sample.keys():
    value = first_sample[key]
    if isinstance(value, torch.Tensor):
        print(key, value.mean())

In [None]:
# Fetch the first sample once instead of indexing `ds` for every lookup
first_sample = ds[0]
sample_keys = first_sample.keys()

print(f'Keys: {sample_keys}')
# A 'B04' entry is reported if present, otherwise a 'bands' entry
if "B04" in sample_keys:
    print(f'Shape: {first_sample["B04"].shape}')
    print(f'dtype: {first_sample["B04"].dtype}')
elif "bands" in sample_keys:
    print(f'Shape: {first_sample["bands"].shape}')
    print(f'dtype: {first_sample["bands"].dtype}')

In [None]:
class MultiArrayDataset(Dataset):
    """Dataset of image arrays with coordinate/climate labels and optional self-supervised targets.

    Each item is a pair ``(x, y)``. ``x`` is the (optionally transformed,
    zoomed and masked) image clipped to ``clip_values``. ``y`` is a dict with
    the keys ``'coords'``, ``'climate'`` (one-hot), ``'zoom_factor'`` and
    ``'reconstruction'``.
    """

    def __init__(
        self, 
        x_data, 
        y_data, 
        transform_x=None, 
        transform_y=None,
        apply_zoom_task=True,
        apply_reconstruction_task=True,
        zoom_range=(1.0, 2.0),
        augment_drop=None,
        device='cpu',
        clip_values=(-3.0, 3.0),
        num_classes=31
    ):
        """
        Args:
            x_data (MultiArray): Input features (H, W, C).
            y_data (dict): Dictionary with keys like 'coords', 'climate'.
            transform_x (callable, optional): Called as ``transform_x(x, y_climate)``;
                must return the transformed pair.
            transform_y (callable, optional): Applied to the coordinates and to the climate labels.
            apply_zoom_task (bool): Whether to apply the zoom-level prediction task.
            apply_reconstruction_task (bool): Whether to apply the masking reconstruction task.
            zoom_range (tuple): Range (min_zoom, max_zoom) for zoom factor.
            augment_drop (callable): Transformations (RandomErasing) for masking rectangular areas.
            device (str): 'cpu' or 'cuda'.
            clip_values (tuple): (min, max) bounds the returned image is clipped to.
            num_classes (int): Number of climate classes used for the one-hot encoding.

        Raises:
            ValueError: If the features and the two label arrays differ in length.
        """
        self.x_data = x_data
        self.y_coords = y_data['coords']
        self.y_climate = y_data['climate']
        self.transform_x = transform_x
        self.transform_y = transform_y
        self.apply_zoom_task = apply_zoom_task
        self.apply_reconstruction_task = apply_reconstruction_task
        self.zoom_range = zoom_range
        self.augment_drop = augment_drop
        self.device = device
        self.clip_values = clip_values
        self.num_classes = num_classes

        if not (len(self.x_data) == len(self.y_coords) == len(self.y_climate)):
            raise ValueError("x_data, y_coords, and y_climate must have the same length.")

    def __len__(self):
        """Number of samples."""
        return len(self.x_data)

    def zero_to_noise(self, image):
        """
        Applies random erasing transformations to the combined image and a white image to identify erased areas.
        Then replaces the erased areas in the original image with noise.

        The noise is drawn from a normal distribution with the mean and standard
        deviation of ``image`` and clamped to the value range of ``image``.

        image: (C, H, W) torch tensor
        """
        mean_val = torch.mean(image)
        std_val = torch.std(image) + 1e-6  # small offset keeps the std strictly positive

        noise = torch.normal(mean=mean_val, std=std_val, size=image.size(), device=self.device)
        noise = torch.clamp(noise, torch.min(image), torch.max(image))

        # Create a white image
        white = torch.ones_like(image, device=self.device)

        # Concatenate original and white along the channel dimension
        # merged: (2*C, H, W)
        merged = torch.cat([image, white], dim=0)

        # Apply random erasing transforms
        dropped = self.augment_drop(merged)  # Should erase areas in both parts

        # The second half of channels correspond to white image; it is 1 everywhere
        # except where it was erased.
        # NOTE(review): assumes augment_drop fills erased areas with 0 — confirm.
        C, H, W = image.shape
        erased_mask = (dropped[C:2*C, :, :] == 0)

        # Replace erased areas in the original with noise
        reconstructed = torch.where(erased_mask, noise, dropped[:C, :, :])

        return reconstructed

    def augment_drop_fn(self, image):
        """Mask ``image``: erased areas are replaced with noise (see ``zero_to_noise``)."""
        # Just apply zero_to_noise to the entire multispectral image at once
        return self.zero_to_noise(image)

    def _one_hot_climate(self, climate):
        """One-hot encode a climate label map of shape (C', H, W).

        The encoded tensor (C', H, W, K) is rearranged to (K, H, W, C'); the
        trailing axis is dropped when it has size 1.
        """
        one_hot = Ftorch.one_hot(climate.to(torch.int64), num_classes=self.num_classes)
        return one_hot.permute(3, 1, 2, 0).squeeze(3).to(torch.float32)

    def _apply_zoom(self, x, y_climate, zoom_factor):
        """Resize image and labels by ``zoom_factor`` and restore the original height and width.

        Factors >= 1 are centre-cropped back to (H, W); smaller factors are
        centred inside a zero-filled (H, W) canvas.
        """
        C, H, W = x.shape
        new_H, new_W = int(H * zoom_factor), int(W * zoom_factor)

        # Resize x
        zoomed = F.resize(x, (new_H, new_W), antialias=True)

        # Resize y_climate similarly
        # NOTE(review): the labels are resized with the same call as the image and
        # are truncated to int64 before one-hot encoding; if this resize interpolates
        # between neighbouring values it can produce class ids that are not in the
        # source labels — confirm whether nearest-neighbour resizing is needed here.
        zoomed_climate = F.resize(y_climate, (new_H, new_W), antialias=True)

        if zoom_factor >= 1.0:
            # Center crop to original size
            top = (new_H - H) // 2
            left = (new_W - W) // 2
            x_zoomed = zoomed[:, top:top+H, left:left+W]
            y_climate_zoomed = zoomed_climate[:, top:top+H, left:left+W]
        else:
            # If zoom_factor < 1.0, pad instead
            x_zoomed = torch.zeros(C, H, W, device=self.device)
            y_climate_zoomed = torch.zeros_like(y_climate)

            pad_h = (H - new_H) // 2
            pad_w = (W - new_W) // 2
            x_zoomed[:, pad_h:pad_h+new_H, pad_w:pad_w+new_W] = zoomed
            y_climate_zoomed[:, pad_h:pad_h+new_H, pad_w:pad_w+new_W] = zoomed_climate

        return x_zoomed, y_climate_zoomed

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx``.

        ``y['reconstruction']`` is the unmasked image when the reconstruction
        task is active and ``augment_drop`` is set, otherwise ``None``.
        """
        # Load raw data
        x = self.x_data[idx]           # NumPy array (H, W, C)
        y_coords = self.y_coords[idx]  # NumPy array
        y_climate = self.y_climate[idx]# NumPy array

        # Convert to torch and permute to (C, H, W)
        x = torch.tensor(x, dtype=torch.float32, device=self.device).permute(2, 0, 1)
        y_coords = torch.tensor(y_coords, dtype=torch.float32, device=self.device)
        y_climate = torch.tensor(y_climate, dtype=torch.float32, device=self.device).permute(2, 0, 1)

        if self.transform_x is not None:
            x, y_climate = self.transform_x(x, y_climate)  # Modified transform that handles both x and y_climate

        if self.transform_y is not None:
            y_coords = self.transform_y(y_coords)
            y_climate = self.transform_y(y_climate)

        # Label dictionary
        y = {'coords': y_coords}

        # ---------------------------
        # Self-Supervised: Zoom Task
        # ---------------------------
        if self.apply_zoom_task:
            zoom_factor = random.uniform(*self.zoom_range)
            x, y_climate = self._apply_zoom(x, y_climate, zoom_factor)
        else:
            # No zoom applied: report a factor of 1
            zoom_factor = 1.0

        y['climate'] = self._one_hot_climate(y_climate)
        y['zoom_factor'] = torch.tensor(zoom_factor, dtype=torch.float32, device=self.device).unsqueeze(0)

        # ---------------------------
        # Self-Supervised: Reconstruction Task
        # ---------------------------
        if self.apply_reconstruction_task and self.augment_drop is not None:
            x_original = x.clone()
            # Mask out areas in the image using RandomErasing transformations
            x_masked = self.augment_drop_fn(x)
            y['reconstruction'] = x_original
            x = torch.clip(x_masked, self.clip_values[0], self.clip_values[1])
        else:
            x = torch.clip(x, self.clip_values[0], self.clip_values[1])
            y['reconstruction'] = None

        return x, y



class TransformX(nn.Module):
    """Joint image/label transform: optional random augmentation, normalisation and clipping.

    Rotations and flips are applied to both the image and the climate labels;
    scaling, normalisation, noise and clipping touch the image only.
    """

    def __init__(self, device, means, stds, augmentations, clip_values=(-3.0, 3.0), rot_prob=0.2, flip_prob=0.2, noise_prob=0.2, noise_std_range=(0.005, 0.01)):
        """
        Args:
            device: Device the statistics are moved to.
            means: Per-channel means; reshaped to (C, 1, 1).
            stds: Per-channel standard deviations; reshaped to (C, 1, 1).
            augmentations: Enables the random rotation, flips and noise.
            clip_values: (min, max) bounds applied to the output image.
            rot_prob: Probability of applying a rotation.
            flip_prob: Probability of each of the two flips.
            noise_prob: Probability of adding Gaussian noise.
            noise_std_range: (low, high) range the noise std is drawn from.
        """
        super().__init__()
        self.device = device
        self.means = means.view(-1, 1, 1).to(device)
        self.stds = stds.view(-1, 1, 1).to(device)
        self.augmentations = augmentations
        self.clip_min, self.clip_max = clip_values
        self.rot_prob = rot_prob
        self.flip_prob = flip_prob
        self.noise_prob = noise_prob
        self.noise_std_low, self.noise_std_high = noise_std_range
        self.rotations = 4  # 0, 90, 180, 270 degrees

    def _spatial_augment(self, x, y_climate):
        """Randomly rotate/flip the image and, when given, the labels in the same way."""
        has_labels = y_climate is not None

        # Height/width axes of the labels: (0, 1) for a 2-D map, (1, 2) otherwise
        if has_labels and y_climate.dim() == 2:
            label_h, label_w = 0, 1
        else:
            label_h, label_w = 1, 2

        # Rotation by a random multiple of 90 degrees
        if torch.rand(1).item() < self.rot_prob:
            quarter_turns = random.randint(0, self.rotations - 1)
            x = torch.rot90(x, quarter_turns, [1, 2])
            if has_labels:
                y_climate = torch.rot90(y_climate, quarter_turns, [label_h, label_w])

        # Horizontal flip (width axis)
        if torch.rand(1).item() < self.flip_prob:
            x = torch.flip(x, [2])
            if has_labels:
                y_climate = torch.flip(y_climate, [label_w])

        # Vertical flip (height axis)
        if torch.rand(1).item() < self.flip_prob:
            x = torch.flip(x, [1])
            if has_labels:
                y_climate = torch.flip(y_climate, [label_h])

        return x, y_climate

    def forward(self, x, y_climate):
        """Transform one sample.

        Args:
            x: Image tensor (C, H, W).
            y_climate: Label tensor, (H, W) or (C', H, W), or ``None``.

        Returns:
            The normalised, clipped image and the (possibly rotated/flipped) labels.
        """
        if self.augmentations:
            x, y_climate = self._spatial_augment(x, y_climate)

        # Scale, then standardise with the per-channel statistics (image only)
        x = x / 10000.0
        x = (x - self.means) / self.stds

        # Optional additive Gaussian noise
        if self.augmentations and torch.rand(1).item() < self.noise_prob:
            sigma = random.uniform(self.noise_std_low, self.noise_std_high)
            x = x + torch.randn_like(x) * sigma

        return torch.clamp(x, self.clip_min, self.clip_max), y_climate

