In [3]:
import os

# Create the PNG output directory (and any missing parents); an already
# existing directory is not an error.
os.makedirs(name="../plots/png", exist_ok=True)

In [4]:
import h5py

# Read the three datasets fully into memory (`[:]` pulls each one out of the
# file as an array), so they remain usable after the `with` block has closed
# the HDF5 handle.
with h5py.File("data/Dataset.h5", 'r') as f:
    data = f['data'][:]      # expected shape: (N, bands, H, W)
    labels = f['labels'][:]  # expected shape: (N,)
    coords = f['coords'][:]  # expected shape: (N, 2) or (N, something)

# Display a compact shape summary rather than echoing the whole array.
data.shape, labels.shape, coords.shape

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'data/Dataset.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_label_distribution(labels, out_path):
    """Save a bar chart showing how many samples carry each distinct label.

    The distinct values of ``labels`` (as found by ``np.unique``) become the
    bar categories, rendered as strings. The figure is written to
    ``out_path`` and then closed.
    """
    classes, n_per_class = np.unique(labels, return_counts=True)
    # String tick names make the bars categorical instead of numeric.
    tick_names = classes.astype(str)

    plt.figure(figsize=(8, 5))
    plt.bar(tick_names, n_per_class)
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.title('Dataset Label Distribution')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()



In [None]:


def plot_examples_per_class(data, labels, out_path, n_examples=3, rng=None):
    """Save a grid of randomly chosen RGB patches, one row per class.

    Each row corresponds to one distinct value of ``labels``. Up to
    ``n_examples`` samples of that class are drawn without replacement and
    shown as an RGB composite built from band indices 3, 2, 1 (red, green,
    blue). Each composite is divided by its own 98th percentile and clipped
    to [0, 1] before display.

    Parameters
    ----------
    data : array-like
        Indexed as ``data[sample][band]``; band indices 1, 2 and 3 must exist.
    labels : array-like
        One label per sample.
    out_path : str
        Destination of the saved figure.
    n_examples : int, default 3
        Number of columns, i.e. the maximum number of patches per class.
    rng : numpy.random.Generator, optional
        Random source used to pick the samples. When omitted, the global
        ``np.random`` state is used, exactly as before.

    Raises
    ------
    ValueError
        If ``labels`` is empty or ``n_examples`` is smaller than 1.
    """
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)
    n_classes = len(unique_labels)
    if n_classes == 0:
        raise ValueError("labels is empty; there are no classes to plot")
    if n_examples < 1:
        raise ValueError("n_examples must be at least 1")
    chooser = np.random if rng is None else rng

    # squeeze=False keeps `axes` two-dimensional even for a single class or a
    # single column, so axes[i, j] is always valid.
    fig, axes = plt.subplots(n_classes, n_examples,
                             figsize=(n_examples * 2, n_classes * 2),
                             squeeze=False)
    try:
        for i, lab in enumerate(unique_labels):
            idxs = np.where(labels == lab)[0]
            chosen = chooser.choice(idxs, size=min(n_examples, len(idxs)), replace=False)
            for j in range(n_examples):
                ax = axes[i, j]
                if j >= len(chosen):
                    # Class has fewer samples than columns: leave the cell blank.
                    ax.axis('off')
                    continue
                patch = data[chosen[j]]  # indexed by band
                # RGB composite: index 3 -> red (B4), 2 -> green (B3), 1 -> blue (B2)
                rgb = np.stack([patch[3], patch[2], patch[1]], axis=-1).astype(float)
                # Simple stretch; skip the division for an all-zero or non-finite
                # scale so the patch does not turn into NaN/inf.
                scale = np.percentile(rgb, 98)
                if np.isfinite(scale) and scale > 0:
                    rgb = rgb / scale
                ax.imshow(np.clip(rgb, 0.0, 1.0))
                # Hide ticks and frame individually: axis('off') would also hide
                # the row label set below.
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)
                if j == 0:
                    ax.set_ylabel(f'Label {lab}', rotation=0, labelpad=40, va='center')
        fig.suptitle('Example Patches per Class', y=0.92)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


In [None]:


def plot_band_histograms(data, out_dir):
    """Write one log-scaled pixel-value histogram per band into ``out_dir``.

    Bands are taken along axis 1 of ``data``. The file for the band at index
    ``k`` is named ``band_<k+1>_hist.png``.
    """
    band_count = data.shape[1]
    for band_no in range(1, band_count + 1):
        # Flatten every pixel of this band across all samples.
        pixels = data[:, band_no - 1].ravel()
        target = os.path.join(out_dir, f'band_{band_no}_hist.png')

        plt.figure(figsize=(6, 4))
        plt.hist(pixels, bins=100, log=True)
        plt.xlabel('Pixel value')
        plt.ylabel('Frequency (log scale)')
        plt.title(f'Histogram of Band {band_no}')
        plt.tight_layout()
        plt.savefig(target)
        plt.close()


In [None]:


def plot_coords_scatter(coords, labels, out_path):
    """Save a scatter plot of the first two coordinate columns, coloured by label.

    Nothing is drawn (a notice is printed instead) unless ``coords`` is a
    two-dimensional array with at least two columns.
    """
    usable = coords.ndim == 2 and coords.shape[1] >= 2
    if not usable:
        print("Coords array is not 2D; skipping scatter plot.")
        return

    xs = coords[:, 0]
    ys = coords[:, 1]

    plt.figure(figsize=(6, 6))
    points = plt.scatter(xs, ys, c=labels, s=5, cmap='tab20', alpha=0.6)
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')
    plt.title('Spatial Distribution of Samples')
    # The colour bar is tied to the scatter artist so it reflects the label colours.
    plt.colorbar(points, label='Label')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


In [None]:


# Configuration for main(). main() reads these module-level names and no other
# visible cell defines them, so they are declared here. The values mirror the
# literals already used elsewhere in this notebook.
H5_PATH = "data/Dataset.h5"
OUTPUT_DIR = "../plots/png"
NUM_EXAMPLES_PER_CLASS = 3


def load_data(h5_path):
    """Read the 'data', 'labels' and 'coords' datasets of an HDF5 file into memory.

    Parameters
    ----------
    h5_path : str
        Path of the HDF5 file to open read-only.

    Returns
    -------
    tuple
        ``(data, labels, coords)`` as in-memory arrays; the file is closed
        before the function returns.
    """
    with h5py.File(h5_path, 'r') as f:
        return f['data'][:], f['labels'][:], f['coords'][:]


def main():
    """Load the dataset from H5_PATH and write every exploratory plot to OUTPUT_DIR."""
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Load everything
    data, labels, coords = load_data(H5_PATH)
    print(f"Loaded data: {data.shape[0]} samples, {data.shape[1]} bands, patch size {data.shape[2:]}")

    # 1) Label distribution
    plot_label_distribution(labels, os.path.join(OUTPUT_DIR, 'label_distribution.png'))
    print("→ Saved label distribution.")

    # 2) Example patches per class
    plot_examples_per_class(data, labels,
                            os.path.join(OUTPUT_DIR, 'examples_per_class.png'),
                            n_examples=NUM_EXAMPLES_PER_CLASS)
    print("→ Saved example patches grid.")

    # 3) Band histograms (one PNG per band, written directly into OUTPUT_DIR)
    plot_band_histograms(data, OUTPUT_DIR)
    print("→ Saved band histograms.")

    # 4) Optional: coords scatter (the plotting function skips itself when
    #    coords does not have at least two columns)
    plot_coords_scatter(coords, labels, os.path.join(OUTPUT_DIR, 'coords_scatter.png'))
    print("→ Saved coordinate scatter (if coords were 2D).")


# Run the full plotting pipeline when this code is executed directly (as a
# script, or as a notebook cell, where __name__ is also '__main__'), but not
# when it is imported as a module.
if __name__ == '__main__':
    main()
