This notebook creates training patches for the landcover transfer learning task. We assume the raw landcover rasters have already been downloaded using the `landcover.ipynb` notebook. The overall process has the following form:

1. Build a VRT from all the landcover tiles

2. For each input x (LE7 imagery) tile...

  a. Read a window of the y (landcover) VRT with the same lat / long bounds as the x tile
  
  b. Extract a patch of size 512 x 512 from x, which is the dimension used for model training

3. For each patch of size 512 x 512 from x ...

  a. Determine the corresponding pixel coordinates in y. The alignment is not exact, because x and y have different spatial resolutions

  b. Linearly interpolate y onto the dimension of x
  
  c. Write both x and y to file
  
4. Shuffle those into training, development, and test datasets. These will be preprocessed and then used for training.

In [None]:
import rasterio
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy import interpolate
from rasterio.windows import from_bounds
import pandas as pd
%matplotlib inline

# Directory holding the landcover rasters; the VRT mosaic built below is
# written into this same directory.
landcover_dir = Path("/datadrive/glaciers/landcover")

In [None]:
import subprocess

def vrt_from_dir(input_dir, output_path="./output.vrt", **kwargs):
    """
    Build a GDAL virtual raster (VRT) from the GeoTIFFs in a directory

    Collects every file in `input_dir` whose name matches "*.tif*" and hands
    them to the `gdalbuildvrt` command line tool, which must be on the PATH.

    Args:
        input_dir: Directory (str or Path) containing the source rasters.
        output_path: Path (str or Path) of the VRT file to write.
        **kwargs: Accepted for backwards compatibility; currently ignored.

    Raises:
        FileNotFoundError: If no "*.tif*" file is found in `input_dir`.
        subprocess.CalledProcessError: If `gdalbuildvrt` exits with a
            non-zero status.
    """
    # Sorted so the order of sources in the mosaic does not depend on the
    # order in which the filesystem happens to list the files.
    inputs = sorted(str(f) for f in Path(input_dir).glob("*.tif*"))
    if not inputs:
        raise FileNotFoundError(f"No *.tif* files found in {input_dir}")

    # gdalbuildvrt takes the output dataset as its first positional argument
    # (it has no "-o" flag). check=True surfaces a failed build immediately
    # instead of letting a later rasterio.open() fail on a missing VRT.
    subprocess.run(["gdalbuildvrt", str(output_path)] + inputs, check=True)

# Mosaic all landcover tiles into a single VRT stored alongside them.
vrt_from_dir(landcover_dir, landcover_dir / "landcover.vrt")

In [None]:
# Load one warped imagery tile. The dataset handle `tilef` is kept open
# because its bounds are needed when reading the matching landcover window.
tiles_dir = Path("/datadrive/glaciers/unique_tiles/warped")
tilef = rasterio.open(tiles_dir / "LE07_134040_20070922-warped.tiff")

# Reorder the axes of the array returned by read() from (0, 1, 2) to
# (1, 2, 0), so that the first axis ends up last.
tile = tilef.read().transpose(1, 2, 0)

In [None]:
# Open the landcover mosaic and read only the window that has the same
# bounds as the imagery tile.
# NOTE(review): this assumes the tile and the VRT share a CRS -- confirm.
landcover = rasterio.open(landcover_dir / "landcover.vrt")
landcover_bounds = from_bounds(*tilef.bounds, landcover.transform)

# Same axis reordering as applied to the tile: (0, 1, 2) -> (1, 2, 0).
y = landcover.read(window=landcover_bounds).transpose(1, 2, 0)

In [None]:
def patch_fractions(x_shape, out_size=(512, 512)):
    """
    Cut points for fixed-size patches, as fractions of the image dimensions

    To get aligned x & y patches, we crop tiles according to fractions of the
    original image dimensions (this is assuming that, even though x and y have
    different dimensions, they share the same lat / long bounding boxes). This
    function returns the fractions at which an image of shape `x_shape` has to
    be cut to obtain non-overlapping patches of `out_size` pixels.

    Args:
        x_shape: (rows, cols) of the image the patches are cut from.
        out_size: (rows, cols) of each patch, in pixels of that image.

    Returns:
        A list of (row_start, row_end, col_start, col_end) tuples, each entry
        a fraction of the corresponding image dimension. Patches are listed
        row by row. A trailing strip narrower than a full patch is skipped,
        so the list is empty when the image is smaller than one patch.
    """
    # Count full patches with integer arithmetic. Deriving the count from a
    # floating point np.arange(0, 1, step) silently loses the last patch
    # whenever the image size is an exact multiple of the patch size.
    n_rows = int(x_shape[0]) // int(out_size[0])
    n_cols = int(x_shape[1]) // int(out_size[1])

    # Size of one patch as a fraction of each image dimension.
    a = out_size[0] / x_shape[0]
    b = out_size[1] / x_shape[1]

    return [
        (i * a, (i + 1) * a, j * b, (j + 1) * b)
        for i in range(n_rows)
        for j in range(n_cols)
    ]


def crop_fraction(u, pixel_fraction=(0.4, 0.5, 0.4, 0.5)):
    """
    Crop a 3-dimensional image array by fractional bounds

    This crops images according to the pixel fraction boundaries. E.g., if we
    want the pixels between 10% and 20% of the extent along each of the first
    two axes, we would use [0.1, 0.2, 0.1, 0.2] as the pixel fraction argument.

    Args:
        u: Array indexed as [axis 0, axis 1, channel].
        pixel_fraction: (start0, end0, start1, end1); the first pair is a
            fraction of u.shape[0], the second pair a fraction of u.shape[1].

    Returns:
        The slice u[start0:end0, start1:end1, :] with the fractions converted
        to pixel indices.
    """
    a, b, c, d = pixel_fraction
    size0, size1 = u.shape[0], u.shape[1]

    # Round to the nearest pixel instead of truncating: products such as
    # 0.29 * 100 evaluate to 28.999... in floating point, and truncation
    # would shift the boundary by one pixel and change the patch size.
    lo0, hi0 = int(round(a * size0)), int(round(b * size0))
    lo1, hi1 = int(round(c * size1)), int(round(d * size1))
    return u[lo0:hi0, lo1:hi1, :]


def resize_mask(x_shape, y):
    """
    Linearly interpolate Y to X shape

    This resizes the Y data so that it overlaps exactly with X. Both arrays
    are taken to cover the same extent, so every pixel is placed at its centre
    on a common [0, 1] scale along each axis and Y is bilinearly interpolated
    at the centres of the X pixels.

    Args:
        x_shape: Shape of the target image; only the first two entries
            (rows, cols) are used.
        y: Array of shape (rows, cols, channels) to resize.

    Returns:
        Float array of shape (x_shape[0], x_shape[1], channels).
    """
    def pixel_centers(n):
        # Centres of n equally sized pixels spanning [0, 1].
        return (np.arange(n) + 0.5) / n

    src_rows = pixel_centers(y.shape[0])
    src_cols = pixel_centers(y.shape[1])

    # Target centres that fall outside the outermost source centres are
    # clamped onto them, i.e. the border values of y are repeated rather than
    # extrapolated.
    dst_rows = np.clip(pixel_centers(x_shape[0]), src_rows[0], src_rows[-1])
    dst_cols = np.clip(pixel_centers(x_shape[1]), src_cols[0], src_cols[-1])

    # The trailing channel axis of y is carried through by the interpolator,
    # so all channels are resized in a single call.
    f_interp = interpolate.RegularGridInterpolator(
        (src_rows, src_cols), np.asarray(y, dtype=float), method="linear"
    )
    rows, cols = np.meshgrid(dst_rows, dst_cols, indexing="ij")
    return f_interp(np.stack([rows, cols], axis=-1))

In [None]:
# J: index of the patch to preview; K: index of the landcover channel to plot.
J, K = 120, 2

# Cut the same fractional window out of the imagery and the landcover, then
# bring the landcover patch onto the pixel grid of the imagery patch.
pixel_frac = patch_fractions(tile.shape[:2])
x = crop_fraction(tile, pixel_frac[J])
y_ = crop_fraction(y, pixel_frac[J])
y_ = resize_mask(x.shape, y_)

# Keep three channels for display and rescale so the maximum is 1.
x = x[:, :, [4, 3, 1]]
# Out-of-place division: `x /= x.max()` raises a casting error when the
# imagery has an integer dtype, because the float result cannot be stored
# back into the integer array.
x = x / x.max()
plt.imshow(x)
plt.show()
plt.imshow(y_[:, :, K], plt.get_cmap("GnBu"))
plt.show()

# Overlay: imagery underneath, semi-transparent landcover channel on top.
plt.imshow(x)
plt.imshow(y_[:, :, K], plt.get_cmap("GnBu"), alpha = 0.4)