In [None]:
import xarray as xr
import pathlib
import rasterio
import tqdm
import numpy as np
from concurrent.futures import ProcessPoolExecutor
from dask.distributed import Client
import matplotlib.pyplot as plt

from data import KelpDataset
from data import Channel as Ch
import trafos

In [None]:
def normalize(x, xmin, xmax):
    """Linearly rescale ``x`` so that ``xmin`` maps to 0 and ``xmax`` maps to 1.

    Works element-wise on scalars, numpy arrays and xarray objects alike.
    Values outside ``[xmin, xmax]`` fall outside ``[0, 1]``; no clipping is done.
    """
    span = xmax - xmin
    offset = x - xmin
    return offset / span

# Convert TIFF tiles to netCDF

In [None]:
import warnings

# rasterio emits NotGeoreferencedWarning for rasters without georeferencing;
# presumably the tiles carry none, so silence it instead of flooding the output.
warnings.simplefilter("ignore", category=rasterio.errors.NotGeoreferencedWarning)

In [None]:
# Tile datasets over the inpainted satellite TIFFs; only the train split has
# kelp masks (the test split is built with mask_dir=None).
ds_tif_train = KelpDataset(
    img_dir="data_inpainted/train_satellite/",
    mask_dir="data/train_kelp/",
)
ds_tif_test = KelpDataset(
    img_dir="data_inpainted/test_satellite/",
    mask_dir=None,
)

In [None]:
# Which split to read: "train" or "test". The Train / Test cells below must be
# run for the matching split.
SPLIT = "test"
ds = {"train": ds_tif_train, "test": ds_tif_test}[SPLIT]

# Read every tile in parallel. Executor.map yields results in input order, so
# res[k] is item k of ds (presumably aligned with ds.tile_ids, which is used as
# the sample coordinate below - TODO confirm in KelpDataset). A chunksize > 1
# batches tiles per inter-process round-trip instead of sending them one by one.
with ProcessPoolExecutor(max_workers=32) as p:
    res = list(tqdm.tqdm(
        p.map(ds.__getitem__, range(len(ds)), chunksize=16),
        total=len(ds),
    ))

## Train

In [None]:
# Split the (image, mask) pairs and stack each kind into one labelled array,
# indexed by tile id along "sample".
tile_images, tile_masks = zip(*res)
sample_coords = {"sample": ds.tile_ids}
imgs = xr.DataArray(np.array(tile_images), dims=["sample", "i", "j", "ch"], coords=sample_coords)
masks = xr.DataArray(np.array(tile_masks), dims=["sample", "i", "j"], coords=sample_coords)

In [None]:
# imgs/masks are plain globals that the Test cell below reassigns, and res may
# come from the test split. Check that they really hold the training tiles
# before overwriting the train files.
train_ids = np.asarray(ds_tif_train.tile_ids)
for arr, label in ((imgs, "imgs"), (masks, "masks")):
    if not np.array_equal(arr["sample"].values, train_ids):
        raise ValueError(
            f"{label} does not hold the train split; "
            "re-run the extraction with the train dataset first"
        )
imgs.to_netcdf("data_ncf/train_imgs.ncf")
masks.to_netcdf("data_ncf/train_masks.ncf")

## Test

In [None]:
# This cell labels the data as the test split. res holds whatever dataset ds
# pointed to during extraction, so refuse to run on anything else.
if ds is not ds_tif_test or len(res) != len(ds_tif_test):
    raise ValueError(
        "res does not hold the test split; re-run the extraction with the test dataset"
    )
# Test tiles have no masks (mask_dir=None); discard that half of each pair so
# the train `masks` global is not overwritten.
imgs, _ = zip(*res)
imgs = xr.DataArray(np.array(imgs), dims=["sample", "i", "j", "ch"], coords={"sample": ds.tile_ids})
imgs.to_netcdf("data_ncf/test_imgs.ncf")

# Load the netCDF back and engineer features

In [None]:
# Local dask cluster for the chunked xarray computations below; presumably picked
# up as the default scheduler by later .compute() calls (dask.distributed default).
# Displaying the client shows its dashboard link.
N_DASK_WORKERS = 8
client = Client(n_workers=N_DASK_WORKERS)
client

In [None]:
# Which split to feature-engineer: "train" or "test".
FE_SPLIT = "test"
# Dev mode: set to an int (e.g. 1000) to work on a reproducible random subset of
# tiles while tuning normalization bounds. None processes the whole split.
DEV_SAMPLE_SIZE = None
DEV_SEED = 0

# Lazily open the split: 500 samples and a single channel per chunk.
imgs = xr.open_dataarray(
    f"data_ncf/{FE_SPLIT}_imgs.ncf",
    engine="netcdf4",
    chunks={"sample": 500, "i": None, "j": None, "ch": 1},
)
if DEV_SAMPLE_SIZE is not None:
    rng = np.random.default_rng(DEV_SEED)
    picked = np.sort(rng.choice(imgs.sizes["sample"], size=DEV_SAMPLE_SIZE, replace=False))
    # quantile() on dask-backed data needs the reduced dimensions in a single
    # chunk, so keep the (small) subset as one sample chunk.
    imgs = imgs.isel(sample=picked).chunk({"sample": -1})
imgs

In [None]:
# Raw bands, each a (sample, i, j) array selected from the channel axis.
swir, nir, r, g, b = (imgs.isel(ch=band) for band in (Ch.SWIR, Ch.NIR, Ch.R, Ch.G, Ch.B))


def _normalized_difference(x, y):
    """Generic normalized-difference index (x - y) / (x + y)."""
    return (x - y) / (x + y)


ndwi_1 = _normalized_difference(g, nir)     # NDWI (McFeeters): green vs NIR
ndwi_2 = _normalized_difference(nir, swir)  # NDWI (Gao): NIR vs SWIR
ndvi = _normalized_difference(nir, r)
gndvi = _normalized_difference(nir, g)
ndti = _normalized_difference(r, g)
# EVI with the usual coefficients G=2.5, C1=6, C2=7.5, L=1.
evi = 2.5 * (nir - r) / (nir + 6 * r - 7.5 * b + 1)
# Despite the name, this feature is simply NDVI - GNDVI.
cari = ndvi - gndvi

## Dev: CDF plots for choosing normalization bounds

In [None]:
def plot_cdf(x, vmin=0, vmax=1, q=None):
    """Plot the empirical CDF of all values in ``x`` and return its quantiles.

    Parameters
    ----------
    x : xarray.DataArray
        Values to summarise. ``quantile`` reduces over all of its dimensions.
        If dask-backed, the reduced dimensions presumably must each be a single
        chunk for ``quantile`` to work - TODO confirm with the loading chunks.
    vmin, vmax : float
        x-axis limits of the plot.
    q : array-like of float in [0, 1], optional
        Quantile levels to evaluate. Defaults to 101 evenly spaced levels
        (0, 0.01, ..., 1).

    Returns
    -------
    xarray.DataArray
        The computed quantiles, indexed by a ``quantile`` coordinate.
    """
    # Quantile levels are local (not a kernel global), so the function works in
    # a fresh kernel.
    if q is None:
        q = np.linspace(0, 1, 101)
    x_q = x.quantile(q=q).compute()

    fig, ax = plt.subplots()
    # CDF: value on x, cumulative fraction (the quantile level) on y.
    ax.plot(x_q, x_q["quantile"])
    ax.set_ylim(0, 1)
    ax.set_xlim(vmin, vmax)
    ax.set_xlabel("value")
    ax.set_ylabel("cumulative fraction")

    return x_q

In [None]:
# Feature -> (array, lo, hi): each feature is mapped so that [lo, hi] -> [0, 1]
# and its CDF is plotted. These are the same bounds the prod normalization uses.
cdf_checks = {
    "swir": (swir, .1, .3),
    "nir": (nir, .1, .35),
    "r": (r, .1, .2),
    "g": (g, .1, .2),
    "b": (b, .1, .2),
    "ndwi_1": (ndwi_1, -.4, .1),
    "ndwi_2": (ndwi_2, -.1, .2),
    "ndvi": (ndvi, -.05, .4),
    "gndvi": (gndvi, -.1, .5),
    "ndti": (ndti, -.075, .075),
    "evi": (evi, -.075, .4),
    "cari": (cari, -.06, .06),
}

# One figure per feature; the returned quantiles are kept for inspection.
cdf_quantiles = {}
for name, (feature, lo, hi) in cdf_checks.items():
    cdf_quantiles[name] = plot_cdf(normalize(feature, lo, hi))
    plt.gca().set_title(f"{name} (bounds {lo} to {hi})")

## For prod: normalize

In [None]:
# Normalize each feature linearly so that its (lo, hi) bound maps to (0, 1);
# these are the bounds inspected in the dev CDF plots. Building new arrays
# instead of reassigning nir, swir, ... keeps this cell safe to re-run: the raw
# features are never overwritten, so nothing is normalized twice.
#
# Insertion order defines the output "ch" order. Stick to Channel order
# (presumably mirrors data.Channel for the first eight entries - TODO confirm).
fe_layers = {
    "swir": normalize(swir, .1, .3),
    "nir": normalize(nir, .1, .35),
    "r": normalize(r, .1, .2),
    "g": normalize(g, .1, .2),
    "b": normalize(b, .1, .2),
    "is_cloud": imgs.isel(ch=Ch.IS_CLOUD),
    "is_land": imgs.isel(ch=Ch.IS_LAND),
    "not_cloud_land": imgs.isel(ch=Ch.NOT_CLOUD_LAND),
    "ndwi_1": normalize(ndwi_1, -.4, .1),
    "ndwi_2": normalize(ndwi_2, -.1, .2),
    "ndvi": normalize(ndvi, -.05, .4),
    "gndvi": normalize(gndvi, -.1, .5),
    "ndti": normalize(ndti, -.075, .075),
    "evi": normalize(evi, -.075, .4),
    "cari": normalize(cari, -.06, .06),
}

# Stack along a new "ch" axis labelled by the dict keys, move channels last and
# clip everything into [0, 1].
imgs_fe = (
    xr.concat(list(fe_layers.values()), dim="ch")
    .assign_coords(ch=list(fe_layers))
    .transpose("sample", "i", "j", "ch")
    .clip(0, 1)
)
imgs_fe

In [None]:
# Materialize the lazy cube, then fill remaining NaNs (e.g. 0/0 in the band
# ratios) by linear interpolation along i, extrapolating at the tile edges.
# Extrapolated values can leave [0, 1], so clip again afterwards to keep the
# same value range established when the cube was built.
imgs_fe = imgs_fe.compute()
imgs_fe = imgs_fe.interpolate_na(dim="i", fill_value="extrapolate").clip(0, 1)
imgs_fe

In [None]:
# (sample, i, j, ch) indices of values still missing after interpolation;
# all-empty arrays mean imgs_fe is NaN-free.
np.nonzero(imgs_fe.isnull().values)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Visual spot-check of one tile's ndwi_1 layer. Selecting the channel by label
# (not by position) keeps this correct if the channel order changes.
fig, ax = plt.subplots()
im = ax.imshow(imgs_fe.isel(sample=7).sel(ch="ndwi_1"))
fig.colorbar(im, ax=ax)
ax.set_title(f"tile {imgs_fe['sample'].values[7]}: ndwi_1");

In [None]:
# Choose the output file from the tiles imgs_fe actually contains, rather than
# from a hand-edited path that must be kept in sync with the input file.
# A dev subsample or any other mismatch is refused instead of written.
fe_sample_ids = imgs_fe["sample"].values
if np.array_equal(fe_sample_ids, np.asarray(ds_tif_test.tile_ids)):
    fe_split = "test"
elif np.array_equal(fe_sample_ids, np.asarray(ds_tif_train.tile_ids)):
    fe_split = "train"
else:
    raise ValueError(
        "imgs_fe does not contain exactly the train or test tiles "
        "(dev subsample?); refusing to write"
    )
imgs_fe.to_netcdf(f"data_ncf/{fe_split}_imgs_fe.nc", engine="netcdf4")