<a href="https://colab.research.google.com/github/coltongerth/degradation/blob/main/degradation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install affine==2.3.1
!pip install attrs==22.2.0
!pip install black==22.12.0
!pip install bounded-pool-executor==0.0.3
!pip install certifi==2022.12.7
!pip install click==8.1.3
!pip install click-plugins==1.1.1
!pip install cligj==0.7.2
!pip install mypy-extensions==0.4.3
!pip install numpy==1.24.1
!pip install packaging==23.0
!pip install pandas==1.5.2
!pip install pathspec==0.10.3
!pip install patsy==0.5.3
!pip install platformdirs==2.6.2
!pip install pyparsing==3.0.9
!pip install python-dateutil==2.8.2
!pip install pytz==2022.7
!pip install rasterio==1.3.4
!pip install scipy==1.10.0
!pip install six==1.16.0
!pip install snuggs==1.4.7
!pip install statsmodels==0.13.5
!pip install tomli==2.0.1
!pip install git+https://github.com/lankston-consulting/lcutils

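Definitions of the zonal-statistics classes (`Stat`, `StatAccumulator`, `ZonalStatistics`) and the `Degradation` class, plus Google Drive mounting and GCS credential setup.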

In [2]:
import concurrent.futures
import os
import numpy as np
import pickle
import rasterio
import asyncio
import warnings
from statsmodels.regression.linear_model import OLS, GLSAR
from scipy import stats as st
from datetime import datetime
from bounded_pool_executor import BoundedProcessPoolExecutor
from dotenv import load_dotenv
from lcutils import gcs, eet
from google.colab import drive

# Mount Google Drive so the service-account keyfile stored there can be read.
drive.mount('/content/gdrive')
# Keyfile location for users with access to the LCLLC shared drive (rpms):
# path_to_credentials = "gdrive/Shareddrives/LCLLC/fuelcast-storage-credentials.json"
# Default keyfile location in the user's own Drive. The path is relative, so it
# only resolves while the working directory is the parent of the mount point
# (presumably /content on Colab — confirm).
path_to_credentials = "gdrive/MyDrive/fuelcast-storage-credentials.json"
# GCS helper authenticated with the service-account keyfile above.
gch = gcs.GcsTools(use_service_account={"keyfile": path_to_credentials})
# NOTE(review): these filters silence every RuntimeWarning and UserWarning for
# the rest of the session, not only the ones raised by the statistics code.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Sentinel used to pre-fill the t-test output arrays and as the nodata value of
# the output raster profile.
nodata = -3.4e38


class Stat(object):
    """Collects value series for one zone and reduces them to mean/std/n.

    Each item passed to ``add_data`` is one series of values; ``calc_stats``
    stacks the collected series and reduces along axis 0, so ``mean``, ``std``
    and ``n`` hold one entry per position in the series. Negative values are
    treated as missing and excluded from every statistic.
    """

    def __init__(self, zone):
        self.zone = zone
        self.mean = 0
        self.std = 0
        self.n = 0
        # Raw series waiting to be reduced; removed once calc_stats has run.
        self.data = list()

    def add_data(self, data):
        """Append one value series to the pending raw data."""
        self.data.append(data)

    def calc_stats(self):
        """Reduce the pending raw data to ``mean``, ``std`` and ``n``.

        Does nothing when no data has been added. After a successful
        reduction the ``data`` attribute is deleted to free memory; calling
        this method again afterwards is a no-op (previously it raised
        ``AttributeError`` because ``self.data`` no longer existed).
        """
        pending = getattr(self, "data", None)
        if pending is None or len(pending) == 0:
            return

        stacked = np.array(pending)
        # Negative values mark missing observations.
        valid = np.ma.masked_where(stacked < 0, stacked)

        self.mean = np.ma.mean(valid, axis=0)
        self.std = np.ma.std(valid, axis=0)
        counts = np.ma.count(valid, axis=0)
        # n of 0 screws up combining groups later, so mask those positions.
        self.n = np.ma.masked_where(counts <= 0, counts)

        # Clean up the object to reduce memory footprint
        del self.data


class StatAccumulator(object):
    """Accumulates per-zone statistics objects and merges them per zone.

    While collecting, ``statistics`` maps each zone to a list of Stat-like
    objects. Once the newest object of a zone holds more than ``update_size``
    series it is reduced (``calc_stats``) and a fresh object is started, which
    keeps the raw data held in memory bounded. ``merge`` then collapses every
    list into a single object per zone.
    """

    def __init__(self, update_size=500):
        # .statistics will be a collection keyed by zone that references a list of Stat objects. As data is collected
        # and the Stat object gets to a determined size, statistics will be calculated and the data will be deleted.
        # New data will go in the accumulator as a new stat object. Rinse and repeat.
        self.statistics = dict()
        self._update_size = update_size
        self.merged_stats = dict()

    def update(self, zone, new_stats, force=False):
        """Add the raw data of ``new_stats`` to the collection for ``zone``.

        :param zone: zone key
        :param new_stats: object exposing a ``data`` list and ``calc_stats()``
        :param force: when true, always close the current object and start a
            new one with ``new_stats``
        :return: None
        """
        if zone not in self.statistics:
            self.statistics[zone] = [new_stats]
            return

        latest = self.statistics[zone][-1]  # Get the latest stat collection

        # If the latest record has over x records, create a new object
        if len(latest.data) > self._update_size or force:
            latest.calc_stats()  # Clean up the memory footprint
            self.statistics[zone].append(new_stats)
        else:
            latest.data.extend(new_stats.data)
        return

    def update_cochrane(self, zone, new_stats):
        """Combine already-reduced statistics of ``new_stats`` into ``zone``.

        Unlike ``update`` this stores a single object per zone. The combined
        standard deviation follows the formula kept in the comment below.
        ``new_stats`` is modified in place and becomes the stored object.

        :param zone: zone key
        :param new_stats: object exposing ``n``, ``mean`` and ``std``
        :return: None
        """
        if zone not in self.statistics:
            self.statistics[zone] = new_stats
            return

        old_stats = self.statistics[zone]
        n1, m1, s1 = old_stats.n, old_stats.mean, old_stats.std
        n2, m2, s2 = new_stats.n, new_stats.mean, new_stats.std

        tn = n1 + n2
        tmean = np.ma.add(n1 * m1, n2 * m2) / tn

        # tsd = sqrt(((N1-1)*SD1^2 + (N2-1)*SD2^2
        #             + N1*N2/(N1+N2) * (M1^2 + M2^2 - 2*M1*M2)) / (N1+N2-1))

        # (N1 - 1) * SD1^2
        ss_old = np.ma.multiply(np.ma.add(n1, -1), np.ma.power(s1, 2))

        # (N2 - 1) * SD2^2 -- this term used to be *added* (N2 - 1 + SD2^2),
        # which contradicted the reference formula above.
        ss_new = np.ma.multiply(np.ma.add(n2, -1), np.ma.power(s2, 2))

        # (N1*N2)/(N1+N2)
        weight = np.ma.divide(np.ma.multiply(n1, n2), np.ma.add(n1, n2))

        # (M1^2 + M2^2) - 2*M1*M2
        mean_gap_sq = np.ma.subtract(
            np.ma.add(np.ma.power(m1, 2), np.ma.power(m2, 2)),
            np.ma.multiply(np.ma.multiply(m1, m2), 2),
        )

        # N1 + N2 - 1
        dof = np.ma.add(np.ma.add(n1, n2), -1)

        numerator = np.ma.add(
            np.ma.add(ss_old, ss_new), np.ma.multiply(weight, mean_gap_sq)
        )
        tsd = np.ma.sqrt(np.ma.divide(numerator, dof))

        new_stats.n = tn
        new_stats.mean = tmean
        new_stats.std = tsd
        self.statistics[zone] = new_stats

    def update_multiple(self, zone):
        """
        Iterates over the statistics collection of one zone, merging mean, std, and n
        into a single object that is stored in ``merged_stats``.

        The first object of the zone's list is reused (and overwritten) to hold
        the merged values.

        :param zone: zone key whose list of statistics objects is merged
        :return: None
        """

        # Update the last of the stat objects
        self.statistics[zone][-1].calc_stats()

        def collect(data, tn, tx, txx):
            # Rebuild running sums (count, sum, sum of squares) from n/mean/std.
            # NOTE(review): the (n - 1) factor assumes `std` is a sample
            # standard deviation, while Stat.calc_stats calls np.ma.std with
            # its default ddof -- confirm which convention is intended.
            n = data.n
            mean = data.mean
            sd = data.std
            x = n * mean
            xx = sd**2 * (n - 1) + x**2 / n
            tn = tn + n
            tx = tx + x
            txx = txx + xx

            return tn, tx, txx

        tn, tx, txx = 0, 0, 0
        for data in self.statistics[zone]:
            tn, tx, txx = collect(data, tn, tx, txx)

        tmean = tx / tn
        tsd = np.ma.sqrt(
            np.ma.divide(
                np.ma.subtract(txx, np.ma.divide(np.ma.power(tx, 2), tn)),
                np.ma.add(tn, -1),
            )
        )

        old_stats = self.statistics[zone][0]
        old_stats.mean = tmean
        old_stats.std = tsd
        old_stats.n = tn

        self.merged_stats[zone] = old_stats

        return

    def merge(self):
        """Collapse every zone's list of statistics into one object per zone.

        Afterwards ``statistics`` maps zone -> single merged object and the
        ``merged_stats`` attribute is removed.
        """
        self.merged_stats = dict()

        # Zones are merged sequentially; an earlier process-pool variant of
        # this step was left disabled in the original code.
        for zone in self.statistics:
            self.update_multiple(zone)

        self.statistics = self.merged_stats
        del self.merged_stats
        return

    def write(self, path="./output/zone_stats.csv"):
        """Write the merged statistics as CSV, one row per zone and series index.

        Expects ``merge`` to have run, so that each value in ``statistics`` is
        a single object with indexable ``mean``, ``std`` and ``n``.

        :param path: destination file; its directory must already exist
        """
        with open(path, "w") as f:
            header = "zone, year, mean, std, n\n"
            f.write(header)

            for z in self.statistics:
                data = self.statistics[z]
                for i in range(len(data.mean)):
                    line = "{0}, {1}, {2}, {3}, {4}\n".format(
                        z, i, data.mean[i], data.std[i], data.n[i]
                    )
                    f.write(line)


class ZonalStatistics(object):
    """Per-window worker functions for zonal statistics and degradation tests.

    The public methods take a single dict as first positional argument with
    the keys ``"zone_data"`` (indexed ``[i, j, k]``) and ``"val_data"``
    (indexed ``[:, j, k]``); ``t_test`` additionally needs ``"statistics"``,
    an object whose ``.statistics`` maps zone -> merged statistics.
    """

    def __init__(self):

        return

    def data_collector(self, *args, **kwargs):
        """Group the value series of every in-zone pixel by zone.

        :param args: ``args[0]`` is the dict described on the class
        :return: dict mapping zone -> ``Stat`` holding that zone's raw series
        """
        zone_data = args[0]["zone_data"]
        val_data = args[0]["val_data"]

        I = zone_data.shape[0]
        J = zone_data.shape[1]
        K = zone_data.shape[2]

        stats = dict()

        for i in range(I):
            for j in range(J):
                for k in range(K):
                    zone = zone_data[i, j, k]

                    # Only positive zone ids are collected.
                    if zone > 0:
                        if zone not in stats:
                            stats[zone] = Stat(zone)
                        # Slice the series only for pixels that are kept.
                        stats[zone].add_data(val_data[:, j, k])

        return stats

    def t_test(self, *args, **kwargs):
        """Run the per-pixel tests for one window.

        :param args: ``args[0]`` is the dict described on the class
        :return: float32 array of shape ``(4, J, K)`` holding, per pixel,
            mean t, mean p, slope t and slope p; pixels outside any zone, or
            for which the test yields nothing or fails, keep ``nodata``
        """
        params = args[0]
        zone_data = params["zone_data"]
        val_data = params["val_data"]

        I = zone_data.shape[0]
        J = zone_data.shape[1]
        K = zone_data.shape[2]

        output = np.empty((4, J, K), dtype="float32")
        output.fill(nodata)

        for i in range(I):
            for j in range(J):
                for k in range(K):
                    zone = zone_data[i, j, k]

                    if zone > 0:
                        stat = params["statistics"].statistics[zone]
                        try:
                            vals = self._t_test_strict_r_logic(stat, val_data[:, j, k])
                        except Exception:
                            # Deliberate best effort: a pixel whose test fails
                            # is left as nodata instead of aborting the window.
                            continue
                        if vals is not None:
                            output[:, j, k] = vals

        # FDR adjustment of the p values is intentionally not done here: it
        # uses relative magnitudes of p, so it has to run once over the whole
        # result rather than per window.

        return output

    def _t_test_orig_logic(self, stat, data):
        """Earlier, unfinished variant of the per-pixel test.

        Not called by ``t_test``. It computes intermediate quantities and fits
        two trend models but never assembles a result, so it always returns
        ``None``.

        :param stat: zone statistics exposing ``n`` and ``mean``
        :param data: value series of one pixel; negative values are missing
        :return: None
        """
        # Mask out NoData pixels
        data = np.ma.masked_where(data < 0, data)
        # Get the individual stats
        i_mean = np.ma.mean(data)  # This is a single value (mean over time)

        # Bail early if there's not data
        if np.ma.is_masked(i_mean):
            return None

        i_std = np.ma.std(data)
        i_n = np.ma.count(data)
        i_se = i_std / np.ma.sqrt(i_n)

        # Adjust population stats to remove point
        #### NEED TO ADJUST FOR WEIGHTED TEMPORAL MEAN
        # The -1 operations are because we removed a datapoint... maybe make this a variable
        # NOTE(review): a leave-one-out mean would be (mean * n - x) / (n - 1);
        # here the mean is multiplied by (n - 1) -- confirm against the reference.
        p_n_adj = stat.n - 1
        p_mean_list = ((stat.mean * p_n_adj) - data) / p_n_adj  # This value is a list

        p_n_sum = np.ma.sum(p_n_adj)
        p_weights = np.ma.divide(stat.n, p_n_sum)
        p_weighted_mean = p_mean_list * p_weights

        p_mean = np.ma.sum(p_weighted_mean)
        p_std = np.ma.std(p_weighted_mean)
        p_n = np.ma.sum(p_n_adj)
        p_se = p_std / np.ma.sqrt(p_n)

        # Get the mean difference
        mean_diff = i_mean - p_mean
        # Standard error difference. The two squared errors must be summed;
        # passing them as two arguments made the second one the ufunc's `out`.
        se = np.ma.sqrt(i_se**2 + p_se**2)

        # t test
        t = mean_diff / se

        # degrees of freedom
        df = i_n + p_n - 2

        # p value
        # p = stats.t.cdf(np.abs(t), df=df) * 2

        # Adjust p for FDR

        # One regressor value per observation (1-based); the previous range
        # stopped one short and did not match the length of the series.
        years = list(range(1, len(p_mean_list) + 1))
        pop_trend_model = OLS(p_mean_list, years)
        pop_trend_result = pop_trend_model.fit()

        ind_trend_model = GLSAR(data, years)
        ind_trend_result = ind_trend_model.fit()

        return None

    def _t_test_strict_r_logic(self, stat, data):
        """Compare one pixel's series with its zone, in level and in trend.

        :param stat: zone statistics exposing ``n`` and ``mean``
        :param data: value series of one pixel; negative values are missing
        :return: ``np.array([t, p, slope_t, slope_p])`` or ``None`` when the
            pixel has no valid values or the paired test result is masked
        """
        # Mask out NoData pixels
        data = np.ma.masked_where(data < 0, data)
        # Get the individual stats
        i_mean = np.ma.mean(data)  # This is a single value (mean over time)

        # Bail early if there's not data
        if np.ma.is_masked(i_mean):
            return None

        # Nan the missing values (statsmodels doesn't seem to acknowledge masked arrays)
        nan_data = data.astype(float).filled(np.nan)

        # Adjust population stats to remove point
        # The -1 operations are because we removed a data point
        # NOTE(review): a leave-one-out mean would be (mean * n - x) / (n - 1);
        # here the mean is multiplied by (n - 1) -- confirm against the R reference.
        p_n_adj = stat.n - 1
        p_mean_list = ((stat.mean * p_n_adj) - data) / p_n_adj  # This value is a list
        p_mean_list = np.ma.masked_where(p_mean_list < 0, p_mean_list)

        # Skip doing the calculations manually, just do a basic t test
        t, p = st.ttest_rel(nan_data, p_mean_list, nan_policy="omit")

        if np.ma.is_masked(t):
            return None

        # Make years list for regressions
        years = np.array(list(range(len(p_mean_list))))

        # Nan years where there's missing data, so missing="drop" removes
        # exactly the observations that are masked in the series.
        mask_years = np.ma.array(years.astype(float), mask=p_mean_list.mask)
        nan_years = mask_years.filled(np.nan)
        pop_trend_model = GLSAR(p_mean_list, nan_years, missing="drop")
        pop_trend_result = pop_trend_model.fit()

        mask_years = np.ma.array(years.astype(float), mask=data.mask)
        nan_years = mask_years.filled(np.nan)
        ind_trend_model = GLSAR(data, nan_years, missing="drop")
        ind_trend_result = ind_trend_model.fit()

        pop_slope = pop_trend_result.params[0]
        ind_slope = ind_trend_result.params[0]

        # Slope difference, scaled by the standard error of the pixel's own fit.
        slope_diff = ind_slope - pop_slope
        slope_se = ind_trend_result.bse
        slope_t = slope_diff / slope_se

        df = ind_trend_result.nobs - len(ind_trend_result.params)

        # Two-sided p value
        slope_p = st.t.sf(np.abs(slope_t), df=df) * 2

        vals = np.array([t, p, slope_t[0], slope_p[0]])

        return vals

class Degradation(object):
    """Placeholder degradation step that currently passes its input through."""

    def __init__(self, *args, **kwargs):
        # self.zone_raster_path = kwargs['zone_raster']

        # TODO check for banded raster vs list of rasters
        # self.data_raster_path = kwargs['data_raster']
        return

    def degradation(self, data):
        """Return a new float array holding a copy of the 3-D array ``data``.

        :param data: array with three dimensions
        :return: newly allocated array of the same shape with numpy's default
            float dtype, containing the values of ``data``
        """
        I = data.shape[0]
        J = data.shape[1]
        K = data.shape[2]
        output = np.empty((I, J, K))

        # One vectorised assignment instead of visiting every element in
        # nested Python loops; the copied values are the same.
        output[...] = data
        return output


Mounted at /content/gdrive


Inputs: the zone raster name and the GCS / local paths used by the pipeline.

In [3]:
# Name of the zone raster; also used for the per-zone data and output paths.
zone_name = "BpsZonRobGb_wgs84_nc"
# Bucket prefixes for the degradation inputs and the RPMS rasters.
gcs_degradation_path = "gs://fuelcast-data/degradation/"
gcs_rpms_path = "gs://fuelcast-data/rpms/"

# Zone raster, read from GCS.
zone_raster_path = f"{gcs_degradation_path}{zone_name}/{zone_name}.tif"
# Local multi-band stack of the RPMS rasters (written by main_run).
data_raster_path = f"./data/{zone_name}/rpms_stack.tif"
# Scratch raster that main_statistics creates to obtain its block windows.
dummy_path = "./test.tif"

In [None]:
# Output rasters, in the order main_statistics writes them:
# mean t, mean p, slope t, slope p.
out_path = [
    f"./output/{zone_name}_mean_t.tif",
    f"./output/{zone_name}_mean_p_adj.tif",
    f"./output/{zone_name}_slope_t.tif",
    f"./output/{zone_name}_slope_p_adj.tif",
]

os.makedirs("./output/", exist_ok=True)

# stats_pickle_path = f"./output/{zone_name}_zs.pkl"


# Edge length (pixels) of the tiles used for the output rasters and for the
# processing windows. Tiled GeoTIFF requires tile dimensions that are multiples
# of 16; the previous value 8196 is not (8196 / 16 = 512.25), 8192 is.
BLOCKSIZE = 8192

# Sentinel for "no result"; also set as the nodata value of the output rasters.
nodata = -3.4e38


def main_process(task, zone_file, data_file, out_file, queue_size=1):
    """
    Stream the zone raster window by window through ``Degradation.degradation``
    in a bounded process pool and write each result to ``out_file``.

    This function constantly tries to write to the destination raster, which makes it suitable for tasks that are
    done in one pass of the data. For things requiring multiple passes, use a function that doesn't write until
    the data processing is done.

    NOTE(review): each job is submitted with two arrays (zone window and data
    window) while ``Degradation.degradation`` is declared with a single
    ``data`` parameter -- confirm the intended worker signature before use.

    :param task: not used by this function
    :param zone_file: path of the zone raster; its profile is reused for the output
    :param data_file: path of the data raster, read over the bounds of each zone window
    :param out_file: path of the raster to create
    :param queue_size: number of worker processes / jobs kept in flight
    :return: None
    """
    deg = Degradation()

    func = deg.degradation

    # Start rasterio and set environment variables as needed
    with rasterio.Env():
        with rasterio.open(zone_file) as zone_src:
            # Copy the zone_src dataset parameters for the output dataset, and set up tiles
            profile = zone_src.profile
            profile.update(blockxsize=BLOCKSIZE, blockysize=BLOCKSIZE, tiled=True)

            with rasterio.open(out_file, "w", **profile) as dst:
                with rasterio.open(data_file) as data_src:

                    zone_windows = [window for ij, window in dst.block_windows()]
                    # Tracks how many windows are still to be written
                    window_count = len(zone_windows)
                    print("Window Count:", window_count)

                    # BoundedProcessPoolExecutor expands concurrent.futures.ProcessPoolExecutor to include a semaphore
                    # that blocks process creation when max_workers are active. This keeps memory footprint low.
                    with BoundedProcessPoolExecutor(max_workers=queue_size) as executor:

                        # Create a streaming iterator for zone_windows
                        def stream():
                            yield from iter(zone_windows)

                        # Just a compact function for reading the zone_src dataset
                        def read_zone_ds(window):
                            return zone_src.read(window=window)

                        # Read the data raster over the geographic bounds of a zone window
                        def read_data(z_window):
                            bounds = zone_src.window_bounds(z_window)
                            data_window = data_src.window(*bounds)
                            return data_src.read(window=data_window)

                        streamer = stream()

                        # Rebound on every wait() call to the set of not-yet-done futures
                        futures = set()

                        # Remembers which window each future was submitted for. Entries are
                        # deleted once the result is written to keep memory low.
                        futures_and_windows = dict()

                        # Process each window
                        for w in streamer:
                            # Multiple zone_windows can finish simultaneously (see below), so attempt to fill the
                            # semaphore every time
                            for i in range(queue_size - len(futures)):
                                try:
                                    window = next(streamer)
                                except StopIteration:
                                    break
                                ex = executor.submit(
                                    func, read_zone_ds(window), read_data(window)
                                )
                                futures_and_windows[ex] = window
                                futures.add(ex)

                            # Add the window from the original streamer generator. The data must be
                            # read for `w` itself; it used to be read for `window`, which is either
                            # a different window or not defined yet.
                            ex = executor.submit(func, read_zone_ds(w), read_data(w))
                            futures_and_windows[ex] = w
                            futures.add(ex)

                            # When at least one future finishes, write its result
                            done, futures = concurrent.futures.wait(
                                futures, return_when=concurrent.futures.FIRST_COMPLETED
                            )

                            for future in done:
                                data = future.result()
                                window = futures_and_windows[future]

                                window_count -= 1
                                print(
                                    f"Remaining: {window_count} || {window} || Size of futures: {len(futures)}"
                                )

                                dst.write(data, window=window)
                                del futures_and_windows[future]

                        # Finish remaining tasks after all zone_windows have been assigned
                        done, futures = concurrent.futures.wait(
                            futures, return_when=concurrent.futures.ALL_COMPLETED
                        )

                        for future in done:
                            data = future.result()
                            window = futures_and_windows[future]
                            print(f"Writing data: window={window}")
                            dst.write(data, window=window)

                            # Drop the entry of the future just written (this used to delete
                            # the entry of the last *submitted* future on every pass).
                            del futures_and_windows[future]
    return


def main_statistics(
    task, zone_file, data_file, out_files, queue_size=10, *args, **kwargs
):
    """
    Run one pass over the zone raster, window by window, in a bounded process pool.

    ``task == "collect"`` gathers the raw per-zone data with
    ``ZonalStatistics.data_collector`` into a new ``StatAccumulator``, merges
    it and writes it with ``accumulator.write()``.
    ``task == "degradation"`` runs ``ZonalStatistics.t_test`` against the
    accumulator passed as keyword ``acc`` and writes the four result layers to
    ``out_files[0..3]``.

    A scratch raster at ``dummy_path`` is created with the tiled profile only
    to obtain the block windows; it and every raster opened here are closed
    (and the scratch file removed) when the function exits, also on error.

    :param task: "collect" or "degradation"
    :param zone_file: path of the zone raster
    :param data_file: path of the data raster, read over the bounds of each zone window
    :param out_files: four output paths (mean t, mean p, slope t, slope p); only used for "degradation"
    :param queue_size: number of worker processes / jobs kept in flight
    :param kwargs: ``acc`` -- merged StatAccumulator, required for "degradation"
    :return: the accumulator (new for "collect", the given one for "degradation")
    :raises ValueError: for an unknown task, or "degradation" without ``acc``
    """
    zs = ZonalStatistics()

    if task == "collect":
        accumulator = StatAccumulator()
        func = zs.data_collector
    elif task == "degradation":
        if "acc" not in kwargs:
            raise ValueError(
                "task 'degradation' requires the merged statistics as keyword argument 'acc'"
            )
        accumulator = kwargs["acc"]
        func = zs.t_test
    else:
        # Previously an unknown task only failed later with an unbound `func`.
        raise ValueError(
            f"unknown task {task!r}; expected 'collect' or 'degradation'"
        )

    opened = []  # every dataset this call opened without a `with` block
    out_rasters = []  # the four output rasters, in out_files order

    # Start rasterio and set environment variables as needed
    with rasterio.Env():
        with rasterio.open(zone_file) as zone_src:
            # Copy the zone_src dataset parameters for the output dataset, and set up tiles
            profile = zone_src.profile
            profile.update(
                blockxsize=BLOCKSIZE,
                blockysize=BLOCKSIZE,
                tiled=True,
                dtype="float32",
                compress="DEFLATE",
                nodata=nodata,
            )

            try:
                if task == "degradation":
                    for band in range(4):
                        raster = rasterio.open(out_files[band], "w", **profile)
                        opened.append(raster)
                        out_rasters.append(raster)

                dummy = rasterio.open(dummy_path, "w", **profile)
                opened.append(dummy)

                with rasterio.open(data_file) as data_src:

                    zone_windows = [window for ij, window in dummy.block_windows()]
                    # Tracks how many windows are still to be handled
                    window_count = len(zone_windows)
                    print("Window Count:", window_count)

                    # BoundedProcessPoolExecutor expands concurrent.futures.ProcessPoolExecutor to include a semaphore
                    # that blocks process creation when max_workers are active. This keeps memory footprint low.
                    with BoundedProcessPoolExecutor(max_workers=queue_size) as executor:

                        # Just a compact function for reading the zone_src dataset
                        def read_zone_ds(z_window):
                            return zone_src.read(window=z_window)

                        # For handling different extents (but not projections!), pass the zone window, convert to
                        # its bounds, and get the appropriate window for the data
                        def read_data(z_window):
                            bounds = zone_src.window_bounds(z_window)
                            d_window = data_src.window(*bounds)
                            return data_src.read(window=d_window), d_window

                        # Remembers which window each future was submitted for. Entries are
                        # removed once the result is handled to keep memory low.
                        futures_and_windows = dict()

                        # Read one window and queue the worker call for it
                        def submit(z_window):
                            zone = read_zone_ds(z_window)
                            data, data_window = read_data(z_window)

                            func_args = {"zone_data": zone, "val_data": data}
                            if task == "degradation":
                                func_args["statistics"] = accumulator

                            future = executor.submit(func, func_args)
                            futures_and_windows[future] = z_window
                            return future

                        # Store or write the result of one finished future
                        def consume(future, remaining, pending):
                            result = future.result()
                            window = futures_and_windows[future]

                            if task == "collect":
                                for zone_id in result:
                                    accumulator.update(zone_id, result[zone_id])
                            else:
                                for band, raster in enumerate(out_rasters):
                                    layer = result[band, :, :].reshape(
                                        1, result.shape[1], result.shape[2]
                                    )
                                    raster.write(layer, window=window)

                            print(
                                f"Remaining: {remaining} || {window} || Size of futures: {pending}"
                            )

                            del futures_and_windows[future]

                        # Set up a window iterator
                        streamer = iter(zone_windows)

                        # Rebound on every wait() call to the set of not-yet-done futures
                        futures = set()

                        # Process each window
                        for w in streamer:
                            # Multiple zone_windows can finish simultaneously (see below), so attempt to fill the
                            # semaphore every time
                            for _ in range(queue_size - len(futures)):
                                try:
                                    stream_window = next(streamer)
                                except StopIteration:
                                    break
                                futures.add(submit(stream_window))

                            # Add the window from the outer iteration
                            futures.add(submit(w))

                            # When at least one future finishes, handle its result
                            done, futures = concurrent.futures.wait(
                                futures, return_when=concurrent.futures.FIRST_COMPLETED
                            )

                            for future in done:
                                window_count -= 1
                                consume(future, window_count, len(futures))

                        # Finish remaining tasks after all zone_windows have been assigned
                        done, futures = concurrent.futures.wait(
                            futures, return_when=concurrent.futures.ALL_COMPLETED
                        )

                        for future in done:
                            window_count -= 1
                            consume(future, window_count, len(futures))

                    if task == "collect":
                        # Merge the collected statistic objects
                        print("Merging collected statistics")
                        accumulator.merge()

                        print("Writing statistics to file")
                        accumulator.write()

                        # with open(stats_pickle_path, "wb") as f:
                        #     pickle.dump(accumulator, f)

                    else:
                        print("uploading")
                        # Upload of the four rasters to GCS is currently disabled:
                        # gch.upload_blob("fuelcast-data",mean_t_raster,"degradation/BpsZonRobGb_wgs84_nc/mean_t.tif")
                        # gch.upload_blob("fuelcast-data",mean_p_raster,"degradation/BpsZonRobGb_wgs84_nc/mean_p_adj.tif")
                        # gch.upload_blob("fuelcast-data",slope_t_raster,"degradation/BpsZonRobGb_wgs84_nc/slope_t.tif")
                        # gch.upload_blob("fuelcast-data",slope_p_raster,"degradation/BpsZonRobGb_wgs84_nc/slope_p_adj.tif")
            finally:
                # Close everything for both tasks and on errors; the scratch
                # raster used to stay open (and on disk) after "collect".
                for dataset in opened:
                    dataset.close()
                if os.path.exists(dummy_path):
                    os.remove(dummy_path)

        return accumulator

# async def raster_stacker(in_ds, out_ds, bounds):
def raster_stacker(id, in_ds, out_ds, bounds):
    """Copy band 1 of ``in_ds``, cut to ``bounds``, into band ``id`` of ``out_ds``.

    :param id: 1-based band index to write in ``out_ds``
    :param in_ds: path or URL of the source raster
    :param out_ds: dataset opened for writing
    :param bounds: object with ``left``, ``bottom``, ``right`` and ``top``

    NOTE(review): the whole window is read into memory with a single call;
    for very large extents a block-wise copy may be needed.
    NOTE(review): ``chunks`` and ``lock`` look like xarray/rioxarray-style
    keywords -- confirm they have any effect when passed to ``rasterio.open``.
    """
    with rasterio.open(in_ds, chunks=(1, 1024, 1024), lock=False) as src_ds:
        # Window of the source raster that covers the requested bounds
        win = src_ds.window(
            left=bounds.left,
            bottom=bounds.bottom,
            right=bounds.right,
            top=bounds.top,
        )
        print(f"in: {in_ds} || {win}")
        band_values = src_ds.read(1, window=win)
        out_ds.write_band(id, band_values)


async def main_run():
    """Build the RPMS stack if needed, then run both statistics passes.

    Steps:
    1. Read bounds and profile of the zone raster.
    2. If the local stack does not exist yet, write the yearly RPMS rasters
       (1985-2021 without 2012) as bands of one tiled, compressed GeoTIFF.
       The stack is written under a temporary name and renamed only once it
       is complete, so an interrupted run cannot leave a partial file that a
       later run would mistake for a finished stack.
    3. Run ``main_statistics`` for "collect" and then for "degradation".
    """
    with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS", verbose=2, GOOGLE_APPLICATION_CREDENTIALS=os.getenv("GOOGLE_APPLICATION_CREDENTIALS", path_to_credentials)):
        # Only bounds and profile are needed, so close the zone raster right
        # away instead of leaving the handle open for the whole run.
        # NOTE(review): `chunks` looks like an xarray/rioxarray-style keyword --
        # confirm it has any effect when passed to rasterio.open.
        with rasterio.open(zone_raster_path, chunks=(1024, 1024)) as zone_ds:
            bounds = zone_ds.bounds
            profile = zone_ds.profile

        profile.update(
            blockxsize=1024,
            blockysize=1024,
            tiled=True,
            compress="DEFLATE",
            predictor=2,
            BIGTIFF="Yes",
        )

        od = f"./data/{zone_name}"
        os.makedirs(od, exist_ok=True)

        # One source raster per year; 2012 is skipped (presumably no layer is
        # available for that year -- confirm).
        files = list()
        for y in range(1985, 2022):
            if y == 2012:
                continue
            # f = f"./data/{zone_name}/rpms_{y}_mean.tif"
            f = f"gs://fuelcast-data/rpms/{y}/rpms_{y}.tif"
            files.append(f)

        # The stack has one band per source raster
        profile.update(count=len(files))

        print("Stacking raster")

        stack_path = f"./data/{zone_name}/rpms_stack.tif"

        if os.path.exists(stack_path):
            print(f"Stacked raster {stack_path} already exists.")
        else:
            partial_path = f"./data/{zone_name}/rpms_stack.partial.tif"
            try:
                with rasterio.open(partial_path, "w", **profile) as dst:
                    print(f"out: {dst} || {dst.bounds}")

                    for id, layer in enumerate(files, start=1):
                        print(f"in: {layer}")
                        raster_stacker(id, layer, dst, bounds)
            except BaseException:
                # Also covers a notebook interrupt: never keep a half-written stack.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                raise
            os.replace(partial_path, stack_path)

        print("Calculating zonal statistics")

        # if os.path.exists(stats_pickle_path):
        #     print("Found existing statistics file. Loading.")
        # else:
        acc = main_statistics(
            "collect", zone_raster_path, data_raster_path, out_path, 60
        )

        # with open(stats_pickle_path, "rb") as f:
        #     acc = pickle.load(f)

        print("Running degradation")
        start = datetime.now()
        main_statistics(
            "degradation", zone_raster_path, data_raster_path, out_path, 60, acc=acc
        )
        stop = datetime.now()
        # total_seconds() keeps whole days; `.seconds` silently drops them.
        print("Total runtime:", (stop - start).total_seconds() / 60, "minutes")

        print("Finished")

# Run the whole pipeline. Top-level `await` relies on the notebook kernel
# already running an event loop.
await main_run()



Stacking raster
out: <open DatasetWriter name='./data/BpsZonRobGb_wgs84_nc/rpms_stack.tif' mode='w'> || BoundingBox(left=-123.00035763899075, bottom=32.999762673956695, right=-99.99953377828074, top=50.000289594390246)
in: gs://fuelcast-data/rpms/1985/rpms_1985.tif




in: gs://fuelcast-data/rpms/1985/rpms_1985.tif || Window(col_off=6859.760674567195, row_off=-2284.559238791233, width=85348.00000000047, height=63083.00000000006)
