In [1]:
"""
Plot example image chunks from reference data
"""

import os
import sys
import time
import glob
import geopandas as gpd
import pandas as pd
import rioxarray as rxr
import xarray as xr
import numpy as np
import rasterio as rio
import matplotlib.pyplot as plt

from fiona.crs import from_epsg
from shapely.geometry import box

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import transforms, utils
from torchsat.models.classification import resnet18

# Custom functions
sys.path.append(os.path.join(os.getcwd(),'code/'))
from __functions import *

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Projection information
wgs = from_epsg(4326)    # geographic WGS 84
proj = from_epsg(32613)  # WGS 84 / UTM zone 13N; the window filtering below works in this CRS
print(f'Projected CRS: {proj}')

# Project root. Set the OPP_ROOFTOP_MAINDIR environment variable to run the notebook on a
# different machine; otherwise this falls back to the original author's path.
maindir = os.environ.get(
    'OPP_ROOFTOP_MAINDIR',
    '/Users/max/Library/CloudStorage/OneDrive-Personal/mcook/earth-lab/opp-rooftop-mapping',
)

print("Successfully imported all packages!")

Projected CRS: EPSG:32613
Successfully imported all packages!


In [3]:
def _collect_class_samples(loader, classes):
    """Return the first image found for each class, plus the band count of the data.

    Iterates ``loader`` once, batch by batch, and stops as soon as every class has a sample.

    Args:
    - loader: Iterable of batches; each batch is a mapping with an 'image' stack and a
      'code' array of numeric class codes. Assumes images are batched as
      (N, bands, H, W) -- TODO confirm against the dataset class.
    - classes: Class codes to look for.

    Returns:
    - (samples, n_bands): ``samples`` maps each class code to a NumPy array holding one
      image (or None if no sample was found); ``n_bands`` is ``images.shape[1]`` of the
      first batch, or None if ``loader`` yielded no batches.
    """
    samples = {cls: None for cls in classes}
    n_bands = None
    for batch in loader:
        images, labels = batch['image'], batch['code']
        if n_bands is None:
            n_bands = images.shape[1]
        labels = np.asarray(labels)
        for cls in classes:
            if samples[cls] is not None:  # already have an example for this class
                continue
            matches = np.flatnonzero(labels == cls)
            if matches.size:
                samples[cls] = np.asarray(images[int(matches[0])])  # first available sample
        # Stop early once every class is covered
        if all(sample is not None for sample in samples.values()):
            break
    return samples, n_bands


def _hide_frame(ax):
    """Strip ticks and spines from ``ax`` while keeping its axis labels visible.

    ``ax.axis('off')`` would also hide the y-label, which is where the class name goes.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def plot_image_chunks(dataset, class_mapping, bands_to_plot, out_file, figsize=(10,10), labelsize=10):
    """
    Plots a single example image chunk for each class with columns representing different bands.

    One row per class (in ``class_mapping`` order), one column per band. The figure is
    saved to ``out_file`` at 300 dpi and then shown.

    Args:
    - dataset: Iterable of batches (e.g. a torch DataLoader); each batch is a mapping with an
      'image' stack and a 'code' array of numeric class codes. Assumes images are batched as
      (N, bands, H, W) -- TODO confirm against the dataset class.
    - class_mapping: Dictionary mapping class codes to class labels.
    - bands_to_plot: List of 0-based band indices to plot, or 'all' to plot every band in the dataset.
    - out_file: Path the figure is saved to; its parent folder must already exist.
    - figsize: Figure size in inches; pass None to size it from the grid (3 in per panel).
    - labelsize: Font size of the class labels shown to the left of each row.

    Raises:
    - ValueError: if ``class_mapping`` is empty, ``dataset`` yields no batches,
      ``bands_to_plot`` is empty, or ``bands_to_plot`` is a string other than 'all'.
    """
    classes = list(class_mapping.keys())
    if not classes:
        raise ValueError("class_mapping is empty; nothing to plot.")

    # Single pass over the data: first sample per class plus the band count of the stack
    samples_dict, n_bands_in_data = _collect_class_samples(dataset, classes)
    if n_bands_in_data is None:
        raise ValueError("dataset yielded no batches; nothing to plot.")

    # isinstance check: `bands_to_plot == 'all'` is ambiguous when a NumPy array is passed
    if isinstance(bands_to_plot, str):
        if bands_to_plot != 'all':
            raise ValueError(
                f"bands_to_plot must be 'all' or a sequence of band indices, got {bands_to_plot!r}"
            )
        bands_to_plot = list(range(n_bands_in_data))
    else:
        bands_to_plot = list(bands_to_plot)
    num_bands = len(bands_to_plot)
    if num_bands == 0:
        raise ValueError("bands_to_plot is empty; nothing to plot.")

    if figsize is None:
        figsize = (num_bands * 3, len(classes) * 3)

    # squeeze=False keeps `axes` 2-D even with a single class or a single band
    fig, axes = plt.subplots(len(classes), num_bands, figsize=figsize, squeeze=False)

    for i, cls in enumerate(classes):
        image = samples_dict[cls]
        if image is None:
            print(f"No samples available for class {cls}.")

        for b, band_idx in enumerate(bands_to_plot):
            ax = axes[i, b]
            if image is not None:
                ax.imshow(image[band_idx, :, :], cmap='gray')  # grayscale for single-band images
            _hide_frame(ax)

            # Band title across the top row
            if i == 0:
                ax.set_title(f"Band {band_idx + 1}", fontsize=8)

            # Class label down the first column (visible because the axis is not turned off)
            if b == 0:
                ax.set_ylabel(class_mapping[cls], fontsize=labelsize)

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)

    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.show()

print("Plotting function ready !!!")

Plotting function ready !!!


In [4]:
# Reference roof footprints (GeoPackage), located relative to the project root
ref_relpath = 'data/spatial/mod/denver_data/training/denver_data_reference_footprints.gpkg'
ref_fp = os.path.join(maindir, ref_relpath)
ref = gpd.read_file(ref_fp)
ref.head()

Unnamed: 0,uid,class_code,description,areaUTMsqft,lotSizeSqft,geometry
0,78TL,TL,Tile,271.028875,4710.0,"MULTIPOLYGON (((502162.154 4397355.647, 502162..."
1,269TL,TL,Tile,3885.053236,22307.0,"MULTIPOLYGON (((503091.622 4397021.987, 503101..."
2,490TL,TL,Tile,2018.268605,6250.0,"MULTIPOLYGON (((501990.912 4396754.28, 502007...."
3,497TL,TL,Tile,273.843801,6370.0,"MULTIPOLYGON (((502773.275 4396965.742, 502773..."
4,537TL,TL,Tile,281.649002,6000.0,"MULTIPOLYGON (((502162.107 4396885.437, 502168..."


In [5]:
# Integer-encode each roof class (codes follow order of first appearance)
ref['code'] = pd.factorize(ref['class_code'])[0]
# Lookup from numeric code back to its class_code label
class_mapping = {code: label for code, label in zip(ref['code'], ref['class_code'])}
print(class_mapping)

{0: 'TL', 1: 'WS', 2: 'CN', 3: 'AP', 4: 'SL', 5: 'TG', 6: 'CS'}


In [None]:
# Identify areas with "pure" material types for examples

In [7]:
# Keep footprints that sit in locally "pure" neighbourhoods: a square window is centred on
# each footprint's centroid, and the footprint is kept when more than 50% of the centroids
# falling inside that window share its roof type.

# Create centroids in the projected CRS so the window is measured in `proj` units
ref_pt = ref.copy().to_crs(proj)
ref_pt['geometry'] = ref_pt['geometry'].centroid

# Define the window size and half window (for boxes), in units of `proj`
window_size = 144 # 4 * average side length (relative neighbors (?))
half_window = window_size / 2

training_windows = []  # footprints whose window has >50% of their roof type
training_roof_types = []  # roof type codes for those footprints

# Spatial-index lookups instead of testing every centroid against every window (O(n^2))
pt_index = ref_pt.sindex
pt_classes = ref_pt['class_code'].to_numpy()

# Loop through each footprint individually. The window is built from the *projected*
# centroid so it is in the same CRS as `ref_pt`; the stored geometry is the original footprint.
for geom, centroid, roof_type in zip(ref.geometry, ref_pt.geometry, ref['class_code']):
    window = box(centroid.x - half_window, centroid.y - half_window,
                 centroid.x + half_window, centroid.y + half_window)

    # Integer positions of the centroids inside the window -> counts within the window
    hits = pt_index.query(window, predicate='intersects')
    total_count = len(hits)
    class_count = int((pt_classes[hits] == roof_type).sum())

    # Keep the footprint if MORE than 50% of the window's centroids are this roof type
    if total_count > 0 and (class_count / total_count) > 0.50:
        training_windows.append(geom)
        training_roof_types.append(roof_type)

# Create a GeoDataFrame for the training windows with roof types
ref_windows = gpd.GeoDataFrame({
    'geometry': training_windows,
    'class_code': training_roof_types
}, crs=ref.crs)

# Reuse the codes from `class_mapping` (built on the full reference set) rather than
# re-factorizing this subset: re-factorizing can give a class a different integer
# (e.g. when a class is filtered out entirely), which would mislabel the plotted examples.
code_lookup = {label: code for code, label in class_mapping.items()}
ref_windows['code'] = ref_windows['class_code'].map(code_lookup)
assert ref_windows['code'].notna().all(), "ref_windows has class codes missing from class_mapping"
ref_windows['code'] = ref_windows['code'].astype(int)
print("Spatial filtering complete.")

Spatial filtering complete.


In [None]:
ref_windows['class_code'].value_counts()  # samples per roof type after spatial filtering

In [None]:
# Perform balanced sampling (random undersampling)
# Seed NumPy's global RNG so the undersample (and hence the figure) is reproducible.
# NOTE(review): this only takes effect if balance_sampling draws from NumPy's global RNG
# (e.g. pandas .sample() without random_state) -- confirm in code/__functions.py and pass a
# seed to balance_sampling directly if it supports one.
np.random.seed(42)
ref_bal = balance_sampling(ref_windows, ratio=50, strategy='undersample')
ref_bal['class_code'].value_counts()

In [None]:
# Load the image data

In [None]:
# Load the image stack (only its band metadata is used in this cell)
# FIXME(review): this is the DC Planet scene, but the reference footprints loaded above are the
# Denver set and the figure is saved as a Denver figure -- confirm the intended image.
stack_da_fp = os.path.join(maindir, 'data/spatial/mod/dc_data/planet-data/dc_0623_psscene8b_final_norm.tif')
# `masked=True` is rioxarray's nodata-masking keyword; the previous `mask=True` is not an
# open_rasterio parameter and did not mask anything.
stack_da = rxr.open_rasterio(stack_da_fp, masked=True, cache=False).squeeze()

# Flag an image/footprint CRS mismatch (warnings are filtered out in the import cell, so print)
img_crs = stack_da.rio.crs
if img_crs is None or ref.crs is None or not ref.crs.equals(img_crs):
    print(f"WARNING: image CRS ({img_crs}) does not match reference CRS ({ref.crs})")

# Extract band names; rioxarray may store a single string instead of a tuple (e.g. one band)
band_desc = stack_da.attrs['long_name']
if isinstance(band_desc, str):
    band_desc = (band_desc,)
n_bands = len(band_desc)

# Create the dictionary mapping long names to 1-based band numbers
band_dict = {name: band_no for band_no, name in enumerate(band_desc, start=1)}
print(band_dict)

In [None]:
# # Select the red, green, and blue bands
# red_band = stack_da[band_dict['red'] - 1]  # subtract 1 because xarray is 0-indexed
# green_band = stack_da[band_dict['green'] - 1]
# blue_band = stack_da[band_dict['blue'] - 1]

# # Stack the selected bands into an RGB image
# rgb_image = np.stack([red_band, green_band, blue_band], axis=0)

# # Save the RGB image
# rgb_image_path = os.path.join(maindir, 'data/spatial/mod/dc_data/planet-data/dc_data_psscene_rgb.tif')

# # Define the metadata including the long_name attribute
# metadata = {
#     'driver': 'GTiff',
#     'height': rgb_image.shape[1],
#     'width': rgb_image.shape[2],
#     'count': 3,
#     'dtype': rgb_image.dtype,
#     'crs': stack_da.rio.crs,
#     'transform': stack_da.rio.transform(),
#     'compress': 'lzw'
# }

# with rio.open(rgb_image_path, 'w', **metadata) as dst:
#     dst.write(rgb_image)
#     dst.set_band_description(1, 'red')
#     dst.set_band_description(2, 'green')
#     dst.set_band_description(3, 'blue')
    
# print("RGB image saved successfully!")

In [None]:
# # Load our image data to check on the format
# rgb = rxr.open_rasterio(rgb_image_path, masked=True, cache=False).squeeze()
# print_raster(rgb, open_file=False)
# band_names = rgb.long_name
# del stack_da

In [None]:
# Create a DataLoader

# Extract image chunks and labels
image_ds = RoofImageDatasetPlanet(
    ref_bal[['geometry', 'code']], img_path=stack_da_fp, img_dim=144, n_bands=n_bands
)
assert len(image_ds) > 0, "No samples left after filtering/balancing; nothing to load."

# plot_image_chunks scans batches until it has one sample per class, so the whole dataset
# does not need to be read into a single batch; a bounded batch size avoids loading every chunk.
bs = min(64, len(image_ds))

# Seeded generator so the shuffle order (and therefore which example is plotted for each
# class) is reproducible across runs.
loader_rng = torch.Generator().manual_seed(42)

# Load the Dataset
da_loader = DataLoader(
    image_ds,
    batch_size=bs,
    shuffle=True,
    generator=loader_rng,
)
print("Data loaded!")

In [None]:
# Create a plot of image chunks
out_png = os.path.join(maindir,'figures/FigX_denver_image_chunks.png')
# savefig does not create missing folders, so make sure figures/ exists first
os.makedirs(os.path.dirname(out_png), exist_ok=True)
plot_image_chunks(da_loader, class_mapping, bands_to_plot='all', out_file=out_png)

In [None]:
# # Create a plot of example image chunks (RGB)

# # Ensure that train_df and other DataFrames have the numeric codes
# ref['code'], class_mapping = pd.factorize(ref['class_code'])
# class_mapping = dict(enumerate(class_mapping))
# print("Class mapping:", class_mapping)

# # Ensure train_df has the numeric codes
# train_df = train_df.dropna(subset=['class_code'])
# train_df['code'] = train_df['class_code'].map(lambda x: list(class_mapping.keys())[list(class_mapping.values()).index(x)])
# val_df['code'] = val_df['class_code'].map(lambda x: list(class_mapping.keys())[list(class_mapping.values()).index(x)])
# test_df['code'] = test_df['class_code'].map(lambda x: list(class_mapping.keys())[list(class_mapping.values()).index(x)])

# # Verify the mapping
# print("Train DataFrame head:")
# print(train_df.head())
# print("Validation DataFrame head:")
# print(val_df.head())
# print("Test DataFrame head:")
# print(test_df.head())

# # Plot example chunks
# plot_example_chunks(train_ds, class_mapping, num_examples=3)