In [1]:
import os
import rasterio
import numpy as np
import pandas as pd
from pyproj import Transformer
import matplotlib.pyplot as plt

In [2]:
class GeoTIFFImage:
    """
    In-memory representation of a multi-band GeoTIFF raster.

    Attributes (all assigned in ``__init__``):
      data: Array returned by ``src.read()``; ``data[i]`` is the i-th band.
      crs, transform, width, height, count, nodata: Copied from the
        rasterio dataset.
      band_names: Per-band descriptions (``src.descriptions``).
      week_number: Integer parsed from the filename.
      date: Date string parsed from the filename.
    """

    def __init__(self, filepath):
        """
        Loads the raster and parses week/date metadata from the filename.

        The basename is split on ``_``; the third token is read as the week
        number and the fourth (up to its first ``.``) as the date, e.g.
        ``illinois_1kmx1km_0046_2000-12-29.tif`` -> week 46, date
        ``2000-12-29``.

        Args:
          filepath: Path to the GeoTIFF file.

        Raises:
          IndexError: If the basename has fewer than four ``_``-separated
            tokens.
          ValueError: If the week token is not an integer.
        """

        # Everything needed is copied out so the file handle can be closed
        # as soon as the ``with`` block exits.
        with rasterio.open(filepath) as src:
            self.data = src.read()
            self.crs = src.crs
            self.transform = src.transform
            self.width = src.width
            self.height = src.height
            self.count = src.count
            self.band_names = src.descriptions
            self.nodata = src.nodata

        # Extract week number and date metadata directly from filename
        basename = os.path.basename(filepath)
        parts = basename.split('_')
        self.week_number = int(parts[2])
        self.date = parts[3].split('.')[0]  # Remove .tif extension

    def print_metadata(self):
        """
        Prints every instance attribute except the (large) ``data`` array.
        """
        for attr, value in self.__dict__.items():
            if attr != 'data':
                print(f"{attr}: {value}")

    def print_band_info(self):
        """
        Prints a per-band summary: name, dtype, shape, and the total / NaN /
        non-NaN pixel counts.
        """
        print(f"# of Available bands: {self.count}\n")
        for i in range(self.count):
            band_data = self.data[i]
            total_pixels = band_data.size
            nan_pixels = np.count_nonzero(np.isnan(band_data))
            non_nan_pixels = total_pixels - nan_pixels
            print(f"Band {i + 1}: {self.band_names[i]}")  # bands shown 1-based
            print(f"  - Data type: {band_data.dtype}")
            print(f"  - Shape: {band_data.shape}")
            print(f"  - Total pixels: {total_pixels}")
            print(f"  - NaN pixels: {nan_pixels}")
            print(f"  - Non-NaN pixels: {non_nan_pixels}\n")

    def make_geotiff_map(self, band_index=0, cmap='terrain'):
        """
        Creates a matplotlib plot of a single band of the GeoTIFFImage.

        Args:
          band_index: Zero-based index of the band to display (default: 0).
          cmap: The colormap to use (default: 'terrain').

        Returns:
          The matplotlib Figure containing the plot. The figure is left open;
          the caller is responsible for closing it.
        """
        band_data = self.data[band_index]

        # Explicit figure/axes objects instead of the implicit pyplot state.
        fig, ax = plt.subplots(figsize=(12, 15))
        image = ax.imshow(band_data, cmap=cmap)
        fig.colorbar(image, ax=ax)  # colorbar shows the data range

        # The title uses the 1-based band number, matching print_band_info.
        ax.set_title(f"{self.date}: Band {band_index + 1} - {self.band_names[band_index]}")

        # Pixel indices are not meaningful to the reader; hide both axes' ticks.
        ax.set_xticks([])
        ax.set_yticks([])

        return fig

    def save_band_images(self, output_folder):
        """
        Saves one PNG per band into the specified output folder.

        Files are named ``<date>_band_<n>_<band name>.png`` with ``n``
        starting at 1.

        Args:
          output_folder: The path to the folder where images will be saved.
            It is created if it does not exist.
        """

        os.makedirs(output_folder, exist_ok=True)

        for i in range(self.count):
            fig = self.make_geotiff_map(band_index=i)
            output_filename = os.path.join(output_folder, f"{self.date}_band_{i+1}_{self.band_names[i]}.png")
            fig.savefig(output_filename)
            plt.close(fig)  # Close the figure to free up memory

    def flatten(self):
        """
        Flattens the GeoTIFF data into a pandas DataFrame with one row per pixel.

        Returns:
          A pandas DataFrame with columns week, date, lat, long, followed by
          one column per entry in ``band_names`` (for the Illinois exports:
          crop, dayl, prcp, srad, swe, tmax, tmin, vp, EVI, NDVI, ET, LE,
          PET, PLE, FPAR, LAI). Rows are in row-major pixel order.
        """

        # Per-pixel (row, col) indices. indexing='ij' yields arrays of shape
        # (height, width), so raveling them walks the pixels in the same
        # row-major order as ``self.data[i].flatten()`` below. The default
        # 'xy' indexing would produce (width, height) arrays and pair each
        # coordinate with the wrong pixel value.
        rows, cols = np.meshgrid(
            np.arange(self.height), np.arange(self.width), indexing='ij'
        )

        # 1-D inputs plus np.asarray keep this independent of whether
        # rasterio returns lists or arrays.
        xs, ys = rasterio.transform.xy(self.transform, rows.ravel(), cols.ravel())
        xs = np.asarray(xs)
        ys = np.asarray(ys)

        # Reproject coordinates to lat/lon in degrees (epsg:4326). Without
        # always_xy, pyproj uses the CRS-defined axis order, hence (lat, lon).
        transformer = Transformer.from_crs(self.crs, "epsg:4326")
        lats, lons = transformer.transform(xs, ys)
        lats = np.asarray(lats).ravel()
        lons = np.asarray(lons).ravel()

        data = {
            'week': np.full(lons.size, self.week_number),
            'date': np.full(lons.size, self.date),
            'lat': lats,
            'long': lons,
        }

        # One column per band, named after the band description.
        for i, band_name in enumerate(self.band_names):
            data[band_name] = self.data[i].flatten()

        return pd.DataFrame(data)

    def flatten_reduced(self):
        """
        Same as ``flatten()``, but drops rows where every band value is NaN.

        These all-NaN rows come from the full rectangular bounding box that a
        GeoTIFF has to store; they lie outside the clipping mask applied to
        the data (i.e., restricting it just to the state of Illinois).

        Returns:
          A pandas DataFrame with the same columns as ``flatten()``, excluding
          rows whose band columns are all NaN. The original row index is kept.
        """

        df = self.flatten()

        # Use the image's own band names rather than a hard-coded list so the
        # filter keeps working for rasters with a different band set.
        band_columns = list(self.band_names)

        return df.dropna(subset=band_columns, how='all')


def read_geotiff(filepath):
    """
    Convenience loader: builds a GeoTIFFImage from a file on disk.

    Args:
      filepath: Path to the GeoTIFF file.

    Returns:
      The GeoTIFFImage constructed from ``filepath``.
    """
    return GeoTIFFImage(filepath)

In [3]:
# Smoke test on a single export: show its metadata, per-band statistics and a
# small sample of the flattened table.
geotiff_file = 'EarthEngineExports/illinois_1kmx1km_0046_2000-12-29.tif'
test = read_geotiff(geotiff_file)
test.print_metadata()
test.print_band_info()
# Fixed seed so the sampled preview is identical on every Restart & Run All.
print(test.flatten().sample(n=10, random_state=42))

crs: EPSG:3347
transform: | 1000.00, 0.00, 6229000.00|
| 0.00,-1000.00, 713000.00|
| 0.00, 0.00, 1.00|
width: 431
height: 660
count: 16
band_names: ('crop', 'dayl', 'prcp', 'srad', 'swe', 'tmax', 'tmin', 'vp', 'EVI', 'NDVI', 'ET', 'LE', 'PET', 'PLE', 'FPAR', 'LAI')
nodata: None
week_number: 46
date: 2000-12-29
# of Available bands: 16

Band 1: crop
  - Data type: float32
  - Shape: (660, 431)
  - Total pixels: 284460
  - NaN pixels: 119547
  - Non-NaN pixels: 164913

Band 2: dayl
  - Data type: float32
  - Shape: (660, 431)
  - Total pixels: 284460
  - NaN pixels: 119547
  - Non-NaN pixels: 164913

Band 3: prcp
  - Data type: float32
  - Shape: (660, 431)
  - Total pixels: 284460
  - NaN pixels: 119547
  - Non-NaN pixels: 164913

Band 4: srad
  - Data type: float32
  - Shape: (660, 431)
  - Total pixels: 284460
  - NaN pixels: 119547
  - Non-NaN pixels: 164913

Band 5: swe
  - Data type: float32
  - Shape: (660, 431)
  - Total pixels: 284460
  - NaN pixels: 119547
  - Non-NaN pixels: 1

In [4]:
def merge_flattened_by_year(geotiffs, year):
    """
    Merges the flatten_reduced() DataFrames of all GeoTIFFImage objects whose
    date falls in the given year.

    The year of each image is the integer before the first ``-`` in its
    ``date`` attribute.

    Args:
      geotiffs: A list of GeoTIFFImage objects.
      year: The year (int) to merge DataFrames for.

    Returns:
      A pandas DataFrame containing the merged data for the specified year, or
      None if no image matches. When several images match, the result has a
      fresh 0..n-1 index; when exactly one matches, its flatten_reduced()
      frame is returned as-is (index not reset).
    """

    # Collect first and concatenate once: calling pd.concat inside the loop
    # re-copies all previously merged rows on every iteration.
    frames = [
        geotiff.flatten_reduced()
        for geotiff in geotiffs
        if int(geotiff.date.split('-')[0]) == year
    ]

    if not frames:
        return None
    if len(frames) == 1:
        return frames[0]
    return pd.concat(frames, ignore_index=True)

In [5]:
# Load all geotiffs into memory
tiff_dir = 'EarthEngineExports'
# Only pick up GeoTIFF files: other directory entries (hidden files, sidecar
# metadata, ...) would fail in the loader's filename parsing. Sorting by
# filename keeps the order deterministic for index-based access.
geotiff_paths = [
    os.path.join(tiff_dir, fname)
    for fname in sorted(os.listdir(tiff_dir))
    if fname.lower().endswith(('.tif', '.tiff'))
]
geotiffs = [read_geotiff(path) for path in geotiff_paths]

In [6]:
#total_row_count = 0
#for year in range(2000, 2024):
#    curr_year_df = merge_flattened_by_year(geotiffs, year)
#    curr_year_df.to_parquet(f"transformed_data/illinois_{year}.parquet", compression="snappy", engine="pyarrow")
#    total_row_count += curr_year_df.shape[0]
#    print(f"{year}: {curr_year_df.shape[0]} x {curr_year_df.shape[1]}")

#print(total_row_count)

In [7]:
# Export per-band PNGs for two example images, selected by their position in
# the filename-sorted ``geotiffs`` list.
# NOTE(review): indices 0 and 20 are presumably a winter and a summer week of
# 2000, going by the folder names; confirm against the file listing.
for image_index, output_folder in ((0, 'example/winter_2000'),
                                   (20, 'example/summer_2000')):
    geotiffs[image_index].save_band_images(output_folder)