In [None]:
%load_ext autoreload
%autoreload 2

from s3fs import S3FileSystem

fs = S3FileSystem()

import random
import rasterio as rio
import numpy as np
import xarray as xr
import dask
import copy
import scipy
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
import time

from carbonplan_trace.v1 import emissions_workflow

In [None]:
import warnings

# Silence every warning category for the rest of the session so that plot
# output is not interleaved with library warnings.
# NOTE(review): this blanket filter also hides deprecation and runtime
# warnings raised by later cells; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

In [None]:
import seaborn as sns

# Apply seaborn's default styling to the matplotlib figures drawn below.
sns.set()

## functions


In [None]:
def plot_diff(ds, year0, year1):
    """Plot the percent change in ``ds`` from ``year0`` to ``year1``.

    The plotted quantity is ``(ds[year1] - ds[year0]) / ds[year0] * 100``,
    drawn with the diverging ``RdBu`` colormap clipped to [-100, 100].

    Parameters
    ----------
    ds
        xarray object with a ``time`` coordinate and a ``.plot`` accessor.
        Presumably an AGB DataArray (the colorbar is labelled as AGB) --
        TODO confirm against callers.
    year0
        ``time`` label used as the baseline (denominator).
    year1
        ``time`` label compared against the baseline.
    """
    # Select each time step once and drop its scalar ``time`` coordinate so
    # the two slices combine without carrying different time labels.
    # ``drop_vars`` replaces the deprecated generic ``drop``.
    baseline = ds.sel(time=year0).drop_vars("time")
    target = ds.sel(time=year1).drop_vars("time")

    pct_change = (target - baseline) / baseline * 100
    pct_change.plot(
        cmap="RdBu", vmin=-100, vmax=100, cbar_kwargs={"label": "%-age change in AGB"}
    )


def plot_ts_diff(ds):
    """Facet-plot the step-to-step change of ``ds`` along ``time``.

    Each panel shows ``ds[t] - ds[t - 1]``. Position 0 of the differenced
    series has no predecessor, so the panels are positions 1 through 6,
    laid out three per row on a diverging colormap clipped to [-250, 250].
    """
    previous = ds.shift(time=1)
    change = ds - previous
    panels = change.isel(time=slice(1, 7))
    panels.plot(col="time", col_wrap=3, cmap="RdBu", vmax=250, vmin=-250)

## simple plots


In [None]:
# Open the v1 result tile 50N_120W and preview the AGB layer at time index 3,
# keeping only every 100th pixel along each remaining axis so the quick-look
# plot stays cheap.
ds = xr.open_zarr("s3://carbonplan-climatetrace/v1/results/tiles/50N_120W.zarr")
ds.AGB.isel(time=3)[::100, ::100].plot()

In [None]:
# Centre of the window to inspect.
lat = 48.020428
lon = -117.868181  # alternative longitude tried: -117.861472
# The window extends ``buffer * pixels`` degrees either side of the centre.
# NOTE(review): buffer looks like half of a 0.00025-degree pixel -- confirm
# against the tile's grid spacing.
buffer = 0.000125
pixels = 70

# AGB within the window, for every time step.
sub = ds.AGB.sel(
    lat=slice(lat - buffer * pixels, lat + buffer * pixels),
    lon=slice(lon - buffer * pixels, lon + buffer * pixels),
)

# One panel per time step on a shared 0-400 colour scale.
sub.plot(col="time", col_wrap=3, cmap="Greens", vmin=0, vmax=400)

In [None]:
# Step-to-step AGB differences for the window selected in the previous cell.
plot_ts_diff(sub)

In [None]:
# Re-centre the latitude and shrink the window to ``buffer * 1`` degrees either
# side of the centre; ``lon`` and ``buffer`` are reused from the earlier cell.
lat = 48.026050
pixels = 1
# NOTE(review): if this selection spans more than one pixel the result is
# still 3-D (time, lat, lon), and xarray's default ``.plot()`` then draws a
# histogram rather than a time series -- confirm which was intended.
ds.AGB.sel(
    lat=slice(lat - buffer * pixels, lat + buffer * pixels),
    lon=slice(lon - buffer * pixels, lon + buffer * pixels),
).plot()
plt.show()

## effect of spatial aggregation on temporal variability


In [None]:
# pick a 2x2 degree tile to work with
sub = ds.sel(lat=slice(48, 50), lon=slice(-118, -116))
sub

In [None]:
# Sources and sinks shrink as the spatial aggregation increases, but it is
# still unclear what the threshold should be.

plt.figure()
ax = plt.gca()

for res in (1, 5, 10, 50, 100, 500, 1000):
    # res == 1 keeps the native grid; otherwise sum res x res pixel blocks.
    coarsened = sub if res == 1 else sub.coarsen(lat=res, lon=res).sum()

    flux = emissions_workflow.calc_biomass_change(ds=coarsened)
    # Keep only the positive part of the flux and total it over space.
    # (The sink counterpart would be flux.clip(max=0) summed the same way.)
    sources = flux.clip(min=0).sum(dim=["lat", "lon"])
    sources.AGB.plot(ax=ax, label=f"res = {int(30*res)}m")

# Rounded extent of the window, used only for the title.
min_lat, max_lat = (
    edge.round().values.astype(int) for edge in (sub.lat.min(), sub.lat.max())
)
min_lon, max_lon = (
    edge.round().values.astype(int) for edge in (sub.lon.min(), sub.lon.max())
)

plt.title(f"lat {min_lat}-{max_lat}, lon {min_lon}-{max_lon}")
plt.ylabel("Total AGB Reduction (Mg/ha)")
plt.legend()
plt.show()
plt.close()

In [None]:
# Idea: without noise, the changes should go to 0 when we look at a smaller region.
# Judging by these figures we should pick a ~1 km resolution, but that will
# not necessarily address the temporal variability problem.

# Side length (in pixels) of each randomly placed window.
window = 100
# Dedicated, seeded generator so the same windows are drawn on every run.
window_rng = random.Random(0)
# Largest valid start index along each axis, derived from the data instead of
# assuming the selection is exactly 8000 pixels wide.
max_i = max(sub.sizes["lat"] - window, 0)
max_j = max(sub.sizes["lon"] - window, 0)

for _ in range(10):
    i = window_rng.randint(0, max_i)
    j = window_rng.randint(0, max_j)
    ss = sub.isel(lat=slice(i, i + window), lon=slice(j, j + window))

    plt.figure()
    ax = plt.gca()
    for res in [1, 5, 10, 25, 50, 100]:
        # res == 1 keeps the native grid; otherwise sum res x res pixel blocks.
        if res == 1:
            coarsened = ss
        else:
            coarsened = ss.coarsen(lat=res, lon=res).sum()

        flux = emissions_workflow.calc_biomass_change(ds=coarsened)
        # Positive part of the flux, totalled over the window.
        sources = flux.clip(min=0).sum(dim=["lat", "lon"])
        sources.AGB.plot(ax=ax, label=f"res = {int(30*res)}m")

    # Title with the centre of the window.
    plt.title(
        f"lat {(ss.lat.min().values + ss.lat.max().values) / 2}, lon {(ss.lon.min().values + ss.lon.max().values) / 2}"
    )
    # NOTE(review): the same quantity (flux.clip(min=0) summed over space) is
    # labelled "Total AGB Reduction" in the aggregation sweep above -- confirm
    # the sign convention of calc_biomass_change and align the two labels.
    plt.ylabel("Total AGB Gain (Mg/ha)")
    plt.legend()
    plt.show()
    plt.close()

## histograms of abs change, pct change, and z score of change


In [None]:
# Biomass change for the 2x2 degree window, plus its per-pixel mean and
# standard deviation over time.
flux = emissions_workflow.calc_biomass_change(ds=sub)
flux_mean = flux.mean(dim=["time"])  # .compute()
flux_std = flux.std(dim=["time"])  # .compute()

# Per-pixel standardised flux: how unusual each step is for that pixel.
zscore_flux = (flux - flux_mean) / flux_std

# Flux relative to the biomass from the second time step onward.
# NOTE(review): xarray aligns the two operands on their ``time`` labels, so
# this assumes calc_biomass_change labels each flux with the later of the two
# time steps it spans -- confirm.
biomass = sub.isel(time=slice(1, None))
pct_flux = flux / biomass

In [None]:
# Histogram of absolute AGB change over all pixels and time steps, NaNs
# removed, using unit-wide bins with edges from -40 to 39.
v = flux.AGB.values
v = v[np.logical_not(np.isnan(v))]

bins = np.arange(-40, 40).tolist()
plt.hist(v, bins=bins)
plt.xlabel("AGB change (Mg/ha)")
plt.show()
plt.close()

In [None]:
# Histogram of relative AGB change in percent, NaNs removed, using unit-wide
# bins with edges from -50 to 49.
v = pct_flux.AGB.values * 100
v = v[np.logical_not(np.isnan(v))]

bins = np.arange(-50, 50).tolist()
plt.hist(v, bins=bins)
plt.xlabel("% AGB change")
plt.show()
plt.close()

In [None]:
# Histogram of the per-pixel z score of AGB change, NaNs removed, using 80
# bins of width 0.1 between -4 and 4.
v = zscore_flux.AGB.values
v = v[np.logical_not(np.isnan(v))]
bins = np.linspace(-4, 4, 81).tolist()
plt.hist(v, bins=bins)
plt.xlabel("Z score AGB change")
plt.show()
plt.close()

## change detection using zscore


In [None]:
# For every pixel, count the time steps whose flux z score exceeds 1.96, then
# tabulate how many pixels have each count.
# NOTE(review): the test is one-sided (only z > 1.96), so large negative
# excursions are not counted -- confirm that is intended.
v = (zscore_flux > 1.96).sum(dim=["time"]).AGB.values
np.unique(v, return_counts=True)

In [None]:
# For 20 pixels along the diagonal starting at index (100, 100), overlay the
# AGB series (left axis, blue line) with the flux z score (right axis, red dots).
for i in range(20):
    print(i)
    pixel = {"lat": 100 + i, "lon": 100 + i}

    plt.figure()
    ax1 = plt.gca()
    sub.AGB.isel(**pixel).plot.line("b", ax=ax1)
    ax1.set_ylabel("AGB (Mg/ha)", color="b")

    # Second y axis sharing the same x axis.
    ax2 = ax1.twinx()
    zscore_flux.AGB.isel(**pixel).plot.line("ro", ax=ax2)
    ax2.set_ylabel("Z Score Flux", color="r")
    ax2.set_yticks([-2.5, -1.96, -1, 0, 1, 1.96, 2.5])

    plt.show()
    plt.close()

## change detection using chow test


In [None]:
# Tile and 2x2 degree window analysed with the change detection below.
# From here on ``ds`` and ``sub`` are AGB DataArrays, not Datasets.
ds = xr.open_zarr("s3://carbonplan-climatetrace/v1/results/tiles/50N_120W.zarr").AGB
sub = ds.sel(lat=slice(48, 50), lon=slice(-118, -116))

# Small AGB window used by the per-pixel sup-F cells in this section. Those
# cells reference ``subset``, which was otherwise first assigned only in the
# components-analysis section, so a top-to-bottom run failed with NameError.
subset = xr.open_zarr("s3://carbonplan-climatetrace/v1/results/tiles/50N_130W.zarr").AGB.sel(
    lat=slice(47.64, 47.68), lon=slice(-121.7, -121.61)
)
# Alternative windows tried previously:
#   50N_130W tile: lat=slice(47.74, 47.78), lon=slice(-121.8, -121.71)
#   40N_130W tile: lat=slice(39.0, 39.1), lon=slice(-123.1, -123)

In [None]:
# Recompute the flux and its per-pixel z score for the AGB DataArray ``sub``;
# this rebinds the names defined in the histogram section above.
flux = emissions_workflow.calc_biomass_change(ds=sub)
flux_mean = flux.mean(dim=["time"])
flux_std = flux.std(dim=["time"])
zscore_flux = (flux - flux_mean) / flux_std

In [None]:
# x = xr.DataArray(
#     1,
#     dims=sub.dims,
#     coords=sub.coords
# ).cumsum(dim='time')

In [None]:
# import time

In [None]:
# Import only the helpers this notebook calls, so it is clear where each name
# comes from and nothing else in the namespace is silently shadowed.
from carbonplan_trace.v1.change_point_detection import (
    perform_change_detection,
    perform_sup_f_test,
)

In [None]:
# x = x.astype('int8')

In [None]:
# t1 = time.time()
# slope_3d, intercept_3d, rss_3d, pvalue_3d = linear_regression_3D(x=x.astype('int8'), y=sub)
# t2 = time.time()

# print((t2 - t1) / 60.)

In [None]:
# Run the change detection on the 2x2 degree window ``sub`` rather than the
# whole tile: the comparison cell below indexes ``pred`` and ``sub`` with the
# same positional (lat, lon) indices, which only refer to the same pixel when
# both arrays share one grid.
pred, pvalue, has_breakpoint = perform_change_detection(sub)

In [None]:
import random

In [None]:
# Compare, pixel by pixel, the raw AGB series with the per-pixel sup-F fit
# ("2d pred") and the array-wide change detection result ("3d pred").
for k in range(20):
    # Walk along the diagonal of ``sub`` starting at index (100, 100).
    # For random pixels use random.randrange(len(sub.lat)) and
    # random.randrange(len(sub.lon)); randint's inclusive upper bound could
    # index one past the end.
    i = 100 + k
    j = i
    print(i, j)

    ts = sub.isel(lat=i, lon=j).values
    result = perform_sup_f_test(ts)

    plt.figure()
    ax1 = plt.gca()
    ax1.plot(sub.time.values, ts, "b", label="raw")
    ax1.plot(sub.time.values, result, "k", label="2d pred")
    ax1.plot(sub.time.values, pred.isel(lat=i, lon=j).values, "r", label="3d pred")

    ax1.set_ylabel("AGB (Mg/ha)", color="b")

    plt.legend()
    plt.show()
    plt.close()

In [None]:
# caveats

# only allowing for 1 break point, thus fitting 2 "discontinuous" linear regression functions
# need to figure out a way to filter out positive changes (?)

# todo
# look for zero biomass maps
# identify deforestation examples and verify ~100% drop
# use v0 to identify the stand replacement clearings

# 100 random pixels for validation/accuracy tests for break point detection, oversampling positive break point detections

In [None]:
# Pixels inspected: the diagonal of subset.isel(lat=slice(60, 75), lon=slice(120, 135)).

# Z score of the flux computed on ``subset`` itself, so the overlay below
# belongs to the pixel being plotted. The notebook-level ``zscore_flux`` is
# derived from ``sub``, which covers a different region.
subset_flux = emissions_workflow.calc_biomass_change(ds=subset)
subset_zscore = (subset_flux - subset_flux.mean(dim=["time"])) / subset_flux.std(dim=["time"])

for i in range(15):
    print(f"looking at the pixel {i}")
    pixel = subset.isel(lat=60 + i, lon=120 + i)
    ts = pixel.values
    # NOTE(review): other cells use the return value of perform_sup_f_test
    # directly as the fitted series, whereas this cell indexes it with "pred"
    # and "pvalue" -- confirm which return type the current helper has.
    result = perform_sup_f_test(ts)

    plt.figure()
    ax1 = plt.gca()
    pixel.plot.line("b", ax=ax1)
    ax1.plot(subset.time.values, result["pred"], "k")
    # Second y axis for the z score of the flux at the same pixel.
    ax2 = ax1.twinx()
    subset_zscore.isel(lat=60 + i, lon=120 + i).plot.line("ro", ax=ax2)

    ax1.set_ylabel("AGB (Mg/ha)", color="b")
    ax2.set_ylabel("Z Score Flux", color="r")
    ax2.set_yticks([-2.5, -1.96, -1, 0, 1, 1.96, 2.5])
    print("overall pvalue =", result["pvalue"])
    plt.show()
    plt.close()

In [None]:
# Apply perform_sup_f_test to the ``time`` series of every pixel in ``subset``
# and load the result into memory. ``time`` is declared as both the input and
# the output core dimension, so the helper is expected to map one series to a
# series of the same dimension.
subset_result = xr.apply_ufunc(
    perform_sup_f_test,
    subset,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    output_dtypes=["float"],
    dask="parallelized",
    # Loop the 1-D helper over all non-core dimensions.
    vectorize=True,
    # Permit rechunking so the core dimension need not sit in a single chunk.
    dask_gufunc_kwargs={"allow_rechunk": 1},
).load()

In [None]:
# Raw AGB for the window, one panel per time step on a shared 0-400 scale.
subset.plot(col="time", col_wrap=3, cmap="Greens", vmin=0, vmax=400)

In [None]:
# Output of the sup-F step for the same window, on the same colour scale as
# the raw AGB panels so the two can be compared directly.
subset_result.plot(col="time", col_wrap=3, cmap="Greens", vmin=0, vmax=400)

In [None]:
# Step-to-step differences of the raw series, saved next to the notebook.
plot_ts_diff(subset)
plt.suptitle("Original Flux", y=1)
plt.savefig("original_flux.png")
plt.show()

In [None]:
# 95%
# Step-to-step differences of the sup-F output, saved next to the notebook.
plot_ts_diff(subset_result)
plt.suptitle("Smoothed Flux", y=1)
plt.savefig("smoothed_flux.png")
plt.show()

In [None]:
# 90%
# NOTE(review): this cell is identical to the "95%" cell above. It plots the
# same ``subset_result`` and overwrites the same smoothed_flux.png. Presumably
# ``subset_result`` was recomputed with a different significance level between
# the two runs; that step is not in the notebook -- confirm, then either
# parameterise the level and file name or remove one of the cells.
plot_ts_diff(subset_result)
plt.suptitle(f"Smoothed Flux", y=1)
plt.savefig(f"smoothed_flux.png")
plt.show()

## Components analysis


In [None]:
from sklearn import decomposition
from numpy.random import RandomState

In [None]:
# Tile and AGB window used for the components analysis. ``ds`` is rebound to
# the 50N_130W dataset here.
ds = xr.open_zarr("s3://carbonplan-climatetrace/v1/results/tiles/50N_130W.zarr")
subset = ds.AGB.sel(lat=slice(47.64, 47.68), lon=slice(-121.7, -121.61))
# Alternative windows tried previously:
#   same tile: lat=slice(47.74, 47.78), lon=slice(-121.8, -121.71)
#   40N_130W tile: lat=slice(39.0, 39.1), lon=slice(-123.1, -123)

In [None]:
# Seed passed as ``random_state`` to the estimators built in the next cell.
rng = 0
# Shape of a single time slice; used to turn flat component vectors back into images.
image_size = subset.isel(time=0).shape
n_components = 6

# Step-to-step AGB difference. The first step has no predecessor and is
# dropped; remaining NaNs are replaced with 0 so the estimators can be fitted.
diff = (subset - subset.shift(time=1)).isel(time=slice(1, None)).fillna(0)
sample_size = len(diff.time)
# One row per time step, one column per pixel.
data = diff.values.reshape(sample_size, -1)

# global centering: subtract each pixel's mean over time
global_mean = data.mean(axis=0)
data_centered = data - global_mean

In [None]:
# (name, estimator, center) triples. ``center`` tells the driver loop whether
# to fit on the globally centred matrix. Every estimator with a stochastic
# solver is given ``random_state=rng`` so that repeated runs produce the same
# components.
# NOTE(review): ``n_iter`` on the two MiniBatch estimators is deprecated in
# recent scikit-learn releases in favour of ``max_iter`` -- confirm the pinned
# version before changing it.
estimators = [
    (
        "MiniBatchSparsePCA",
        decomposition.MiniBatchSparsePCA(
            n_components=n_components, n_iter=100, batch_size=3, random_state=rng
        ),
        True,
    ),
    (
        "PCA",
        decomposition.PCA(
            n_components=n_components, svd_solver="randomized", whiten=True, random_state=rng
        ),
        True,
    ),
    (
        "ICA",
        decomposition.FastICA(n_components=n_components, whiten=True, random_state=rng),
        True,
    ),
    (
        # Fits 15 atoms; the driver loop only plots the first ``n_components``.
        "MiniBatchDictionaryLearning",
        decomposition.MiniBatchDictionaryLearning(
            n_components=15, alpha=0.1, n_iter=50, batch_size=3, random_state=rng
        ),
        True,
    ),
    (
        "FA",
        decomposition.FactorAnalysis(n_components=n_components, max_iter=20, random_state=rng),
        True,
    ),
]

In [None]:
def plot_components(name, components, n_components=n_components, image_size=image_size):
    """Draw each component as a lat/lon image and save the figure.

    Parameters
    ----------
    name
        Estimator label, used in the figure title and in the output file
        name ``{name}_components.png``.
    components
        Array that can be reshaped to ``(n_components, *image_size)``.
    n_components
        Number of components to show. The default is the value of the
        notebook-level ``n_components`` at the time this cell was run.
    image_size
        ``(n_lat, n_lon)`` of one image. The default is bound the same way.

    The lat/lon coordinates are read from the notebook-level ``subset``.
    """
    n_lat, n_lon = image_size[0], image_size[1]
    da = xr.DataArray(
        data=components.reshape(n_components, n_lat, n_lon),
        dims=["component", "lat", "lon"],
        coords=[np.arange(n_components), subset.lat.values, subset.lon.values],
    )

    # Facet over components only when there is more than one; a single
    # component is drawn as one plain 2-D plot.
    facet_kwargs = {"col": "component", "col_wrap": 3} if n_components > 1 else {}
    da.plot(cmap="RdBu", **facet_kwargs)

    plt.suptitle(f"{name} components", y=1)
    plt.savefig(f"{name}_components.png")
    plt.show()


def inverse_transform(estimator, d, center):
    """Map component scores ``d`` back to feature space.

    Computes ``d @ estimator.components_ + estimator.mean_``. When ``center``
    is truthy, the notebook-level ``global_mean`` is added on top.
    """
    reconstructed = np.dot(d, estimator.components_) + estimator.mean_
    if center:
        reconstructed = reconstructed + global_mean
    return reconstructed


def inverse_transform2(estimator, d, i, center):
    """Remove the projections of ``d`` onto components 0 through ``i``.

    The projections onto ``estimator.components_[0]`` ... ``[i]`` are summed
    and subtracted from ``d``. When ``center`` is truthy, the notebook-level
    ``global_mean`` is added to ``d`` before the subtraction.
    """
    components = estimator.components_
    projected = project(d, components[0])
    for k in range(1, i + 1):
        projected += project(d, components[k])

    base = d + global_mean if center else d
    return base - projected


def project(x, y):
    """Project ``x`` onto the vector ``y``.

    Returns an array with one row per row of ``x``, each equal to
    ``y * (row . y) / (y . y)``. A 1-D ``x`` gives a single-row result.
    """
    # One scalar coefficient per row of x, as a column so it broadcasts over y.
    coefficients = np.dot(x, y).reshape(-1, 1)
    return y * coefficients / np.dot(y, y)


def plot_denoised_flux(name, estimator, transformed, d, center):
    """Plot and save the flux rebuilt with the leading components zeroed.

    Works on a deep copy of ``transformed``. Columns are zeroed cumulatively:
    the first figure has column 0 zeroed and the second has columns 0 and 1
    zeroed, because the copy is not reset between iterations. Each figure is
    saved as ``{name}_component_{i}_removed.png``.

    Parameters
    ----------
    name
        Estimator label; selects the reconstruction path and names the files.
    estimator
        Fitted estimator.
    transformed
        Component scores with one row per time step.
    d
        Not used by this function.
    center
        When truthy, the notebook-level ``global_mean`` is added back.

    The time axis is labelled with the ``sample_size`` years ending in 2020.
    The image shape and lat/lon values come from notebook-level names.
    """
    t = copy.deepcopy(transformed)
    years = np.arange(2021 - sample_size, 2021)

    for i in range(2):
        t[:, i] = 0

        if name in ("MiniBatchSparsePCA", "FA"):
            # These two use the notebook's own inverse_transform helper
            # rather than an estimator method.
            reconstructed = inverse_transform(estimator, t, center)
        else:
            if name == "MiniBatchDictionaryLearning":
                reconstructed = np.dot(t, estimator.components_)
            else:
                reconstructed = estimator.inverse_transform(t)
            if center:
                reconstructed = reconstructed + global_mean

        da = xr.DataArray(
            data=reconstructed.reshape(sample_size, image_size[0], image_size[1]),
            dims=["time", "lat", "lon"],
            coords=[years, subset.lat.values, subset.lon.values],
        )
        da.plot(col="time", col_wrap=3, cmap="RdBu", vmax=250, vmin=-250)
        plt.suptitle(f"{name} component {i} removed", y=1)
        plt.savefig(f"{name}_component_{i}_removed.png")
        plt.show()


def plot_component_by_time(name, n_components, estimator, transformed, d):
    """Plot and save the contribution of each single component over time.

    For component ``i`` in ``range(n_components)``, every other score column
    is set to zero, the data are rebuilt, and one panel per time step is
    drawn and saved as ``{name}_component_{i}_only.png``.

    Parameters
    ----------
    name
        Estimator label; selects the reconstruction path and names the files.
    n_components
        Number of leading components to plot.
    estimator
        Fitted estimator.
    transformed
        Component scores with one row per time step.
    d
        Not used by this function.
    """
    years = np.arange(2021 - sample_size, 2021)

    for i in range(n_components):
        # Keep only the scores of component i.
        t = np.zeros(transformed.shape)
        t[:, i] = transformed[:, i]

        if name in ("MiniBatchSparsePCA", "FA"):
            # These two use the notebook's own inverse_transform helper
            # rather than an estimator method.
            reconstructed = inverse_transform(estimator, t, center=False)
        elif name == "MiniBatchDictionaryLearning":
            reconstructed = np.dot(t, estimator.components_)
        else:
            # Subtract the estimator mean from the estimator's own reconstruction.
            reconstructed = estimator.inverse_transform(t) - estimator.mean_

        da = xr.DataArray(
            data=reconstructed.reshape(sample_size, image_size[0], image_size[1]),
            dims=["time", "lat", "lon"],
            coords=[years, subset.lat.values, subset.lon.values],
        )

        da.plot(col="time", col_wrap=3, cmap="RdBu", vmax=250, vmin=-250)
        plt.suptitle(f"{name} component {i} only", y=1)
        plt.savefig(f"{name}_component_{i}_only.png")
        plt.show()

In [None]:
# Fit every estimator and produce its three sets of figures.
for name, estimator, center in estimators:
    print(name)
    # Fit on the globally centred matrix when requested, otherwise on the raw one.
    fit_input = data_centered if center else data
    transformed = estimator.fit_transform(fit_input)

    # Locate the fitted components; the attribute differs between estimator types.
    if hasattr(estimator, "cluster_centers_"):
        components_ = estimator.cluster_centers_
    elif name == "ICA":
        components_ = estimator.mixing_.T
    else:
        components_ = estimator.components_

    # Spatial pattern of each of the first n_components components.
    print("plotting components")
    plot_components(
        name, components_[:n_components], n_components=n_components, image_size=image_size
    )

    # Data with the first, then the first two, components removed.
    print("plotting denoised fluxes")
    plot_denoised_flux(name, estimator, transformed, fit_input, center)

    # Contribution of each single component to every time slice.
    print("plotting components by time")
    plot_component_by_time(name, n_components, estimator, transformed, fit_input)