In [9]:
import os

OUTPUT_PATH = "../plots/png"

# make output folder if needed
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
import h5py

with h5py.File("../data/dataset.h5", 'r') as f:
    print("\nKeys found in the HDF5 file:")
    for key in f.keys():
        data = f[key]
        print(f"- {key} : shape={data.shape}, dtype={data.dtype}")


Keys found in the HDF5 file:
- ID_Parcelles : shape=(14836, 24, 24), dtype=float32
- coords : shape=(14836, 2), dtype=float64
- data : shape=(14836, 12, 4, 24, 24), dtype=float32
- dates : shape=(14836,), dtype=object
- labels : shape=(14836, 24, 24), dtype=float32
- zones : shape=(14836,), dtype=object


In [None]:
with h5py.File("../data/dataset.h5", 'r') as f:
    data = f['data'][:]
    labels = f['labels'][:]
    coords = f['coords'][:]

labels

array([[[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        ...,
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [25.,  0.,  0., ...,  0.,  0.,  0.]],

       [[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        ...,
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.]],

       [[ 0.,  0.,  0., ..., 14., 14., 14.],
        [ 0.,  0.,  0., ..., 14., 14., 14.],
        [ 0.,  0.,  0., ..., 14., 14., 14.],
        ...,
        [ 0.,  0.,  0., ..., 14., 14., 14.],
        [ 0.,  0.,  0., ..., 14., 14., 14.],
        [ 0.,  0.,  0., ..., 14., 14., 14.]],

       ...,

       [[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# flatten all pixel labels
flat = labels.ravel()
uniq, counts = np.unique(flat, return_counts=True)

# sort by counts descending
sort_idx = np.argsort(counts)[::-1]
uniq_sorted  = uniq[sort_idx]
counts_sorted = counts[sort_idx]

# convert labels to strings for ticks
label_strs = uniq_sorted.astype(int).astype(str)

plt.figure(figsize=(8, 5))
plt.bar(label_strs, counts_sorted)
plt.xlabel('Label')
plt.ylabel('Pixel count')
plt.yscale('log') # Help to see low represented labels
plt.title('Global Pixel-Level Label Distribution (sorted)')
plt.xticks(rotation=45)
plt.savefig(OUTPUT_PATH+"/label_distribution.png")
plt.close()


In [13]:
n_examples=3

N, T, B, H, W = data.shape

picks = np.random.choice(N, size=n_examples, replace=False)
fig, axes = plt.subplots(2, n_examples, figsize=(n_examples*3, 6))
for i, idx in enumerate(picks):
    # RGB at first time
    patch = data[idx, 0]  # (B, H, W)
    rgb = np.stack([patch[2], patch[1], patch[0]], axis=-1)
    axes[0,i].imshow(rgb / np.percentile(rgb,98))
    axes[0,i].axis('off')
    axes[0,i].set_title(f'Patch {idx}')
    # label map
    axes[1,i].imshow(labels[idx], cmap='tab20')
    axes[1,i].axis('off')
fig.suptitle('Example Patches (Top: RGB t0; Bottom: Labels)')
plt.tight_layout()
plt.savefig(OUTPUT_PATH+"/patch_example.png")
plt.close()


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.1579763].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.1410799].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.4185686].


In [None]:


def plot_band_histograms(data, out_dir):
    """Plot a histogram of all pixel values for each band."""
    n_bands = data.shape[1]
    for b in range(n_bands):
        band_vals = data[:, b].ravel()
        plt.figure(figsize=(6, 4))
        plt.hist(band_vals, bins=100, log=True)
        plt.xlabel('Pixel value')
        plt.ylabel('Frequency (log scale)')
        plt.title(f'Histogram of Band {b+1}')
        plt.tight_layout()
        fn = os.path.join(out_dir, f'band_{b+1}_hist.png')
        plt.savefig(fn)
        plt.close()


In [None]:


def plot_coords_scatter(coords, labels, out_path):
    """Scatter plot of sample coordinates colored by label."""
    if coords.ndim != 2 or coords.shape[1] < 2:
        print("Coords array is not 2D; skipping scatter plot.")
        return
    x, y = coords[:, 0], coords[:, 1]
    plt.figure(figsize=(6, 6))
    scatter = plt.scatter(x, y, c=labels, s=5, cmap='tab20', alpha=0.6)
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')
    plt.title('Spatial Distribution of Samples')
    plt.colorbar(scatter, label='Label')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


In [None]:


def main():
    # Load everything
    data, labels, coords = load_data(H5_PATH)
    print(f"Loaded data: {data.shape[0]} samples, {data.shape[1]} bands, patch size {data.shape[2:]}")

    # 1) Label distribution
    plot_label_distribution(labels, os.path.join(OUTPUT_DIR, 'label_distribution.png'))
    print("→ Saved label distribution.")

    # 2) Example patches per class
    plot_examples_per_class(data, labels,
                            os.path.join(OUTPUT_DIR, 'examples_per_class.png'),
                            n_examples=NUM_EXAMPLES_PER_CLASS)
    print("→ Saved example patches grid.")

    # 3) Band histograms
    plot_band_histograms(data, OUTPUT_DIR)
    print("→ Saved band histograms.")

    # 4) Optional: coords scatter
    plot_coords_scatter(coords, labels, os.path.join(OUTPUT_DIR, 'coords_scatter.png'))
    print("→ Saved coordinate scatter (if coords were 2D).")


if __name__ == '__main__':
    main()
