In [40]:
# Data folder handed to BirdRun below; the per-bird input/output file names used in this
# notebook come from that BirdRun (presumably located under this folder — confirm).
DATA_PATH = "data/CA-Final"

In [41]:
import numpy as np
import ecoscape_connectivity
from ecoscape_connectivity.util import dict_translate
from ecoscape_utilities.bird_runs import BirdRun
from functools import reduce
import scgt
import torch

In [42]:
# Shared BirdRun for DATA_PATH; create_bird_runs uses it to look up each bird's files
# and to create their output folders.
bird_run = BirdRun(DATA_PATH)

def create_bird_runs(target):
    """Returns the bird runs for every species, for the run named `target`.

    Each entry is what `bird_run.get_bird_run` returns for that species, in the
    order Acorn Woodpecker, Steller's Jay.  Before returning, the parent folders
    of each bird's repopulation and gradient output files are created if missing.
    """
    species = (
        ("acowoo", "Acorn Woodpecker"),
        ("stejay", "Steller's Jay"),
    )
    runs = [bird_run.get_bird_run(code, common_name, run_name=target)
            for code, common_name in species]
    # Make sure every output folder exists before anything writes into it.
    for run in runs:
        for attr in ("repopulation_fn", "gradient_fn"):
            bird_run.createdir_for_file(getattr(run, attr))
    return runs


The following code computes the habitat patches (connected regions of habitat pixels) and labels each pixel with the size of the patch it belongs to.
We also compute the largest patch size, which is used for renormalization.

In [44]:
def shift(m, h=0, v=0):
    """Shift a matrix m, filling the border with 0, in the horizontal and vertical directions.

    Positive h moves the content right and negative h moves it left; positive v moves
    it down and negative v moves it up.  The result always has m's shape, dtype and
    device (so CUDA tensors work), and is a new tensor even when h == v == 0.
    Shifting by at least the matrix width (or height) yields all zeros.
    """
    sy, sx = m.shape
    out = torch.zeros_like(m)
    if abs(h) >= sx or abs(v) >= sy:
        # Everything has been shifted out of the frame.
        return out
    # out[dst] = m[src], where dst is src displaced by (v, h); the uncovered border stays 0.
    dst_x = slice(max(h, 0), sx + min(h, 0))
    src_x = slice(max(-h, 0), sx + min(-h, 0))
    dst_y = slice(max(v, 0), sy + min(v, 0))
    src_y = slice(max(-v, 0), sy + min(-v, 0))
    out[dst_y, dst_x] = m[src_y, src_x]
    return out

In [45]:
def connected_regions(m, also_corners=True):
    """Computes the connected regions.  m must be a 0/1 matrix (values are clamped to [0, 1]).

    The output is an int64 matrix of the same size as m, on the same device as m, in
    which each connected region is assigned an integer, with all pixels of that region
    having that value; pixels where m is 0 get 0.  The labels are not consecutive:
    each region ends up with 1 + the largest row-major index among its pixels.
    If also_corners is True, touching by a corner counts as being in the same patch.
    """
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    if also_corners:
        directions.extend([(1, 1), (1, -1), (-1, 1), (-1, -1)])
    # An int64 mask keeps all label arithmetic below exact integer math
    # (a float mask would turn the labels into floats).
    m = torch.clamp(m, 0, 1).to(torch.int64)
    size_y, size_x = m.shape
    # Initially every habitat pixel has a distinct label (1 + its row-major index).
    # Built with torch on m's device: going through m.numpy() fails for CUDA tensors.
    a = torch.arange(1, size_y * size_x + 1, device=m.device).reshape(size_y, size_x) * m
    # Repeatedly spread the largest label to neighbouring habitat pixels until
    # nothing changes; every region then carries its maximum label.
    while True:
        na = a
        for dx, dy in directions:
            na = torch.maximum(na, m * shift(na, h=dx, v=dy))
        if torch.equal(na, a):
            return a
        a = na


In [47]:
def regions_by_size(m):
    """Takes as input a 0-1 matrix m (a torch tensor, on CPU or CUDA).

    Returns a matrix of the same shape as m, where the pixels of each connected region
    are labeled with the size (pixel count) of the connected region.  The mapping passed
    to dict_translate has no entry for background (label 0); how those pixels come out
    is up to dict_translate (NOTE(review): presumably left as 0 — confirm).
    """
    a = connected_regions(m)
    # One pass over the labels gives every region's size; a per-label `a == i`
    # scan would cost O(#regions * #pixels).
    labels, counts = torch.unique(a, return_counts=True)
    sizes = dict(zip(labels.tolist(), counts.tolist()))
    # 0 is not a region.  It is absent when m is all ones, so don't use `del`.
    sizes.pop(0, None)
    # dict_translate operates on a numpy array, which requires a CPU tensor.
    return dict_translate(a.cpu().numpy(), sizes)

In [48]:
# Bird runs for the "patch_sizes_torch" run; create_bird_runs also creates their output folders.
bird_runs = create_bird_runs("patch_sizes_torch")

In [52]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [55]:
# For each bird, label every habitat pixel with the size of its habitat patch and write
# the result to the bird's repopulation file, cloned from the habitat raster's shape.
with torch.no_grad():
    for bird in bird_runs:
        # Read the whole habitat raster as a single tile.
        habitat_gt = scgt.GeoTiff.from_file(bird.habitat_fn)
        habitat_tile = habitat_gt.get_all_as_tile()
        # squeeze(0) drops the leading axis (presumably the band axis); int8 keeps the mask compact.
        habitat = torch.tensor(habitat_tile.m.squeeze(0).astype("int8"), device=device)
        patch_sizes = regions_by_size(habitat)
        with habitat_gt.clone_shape(bird.repopulation_fn, dtype='float32') as out_file:
            # Same tile geometry as the input; [None, :] restores the leading axis.
            size_tile = scgt.Tile(habitat_tile.w, habitat_tile.h, habitat_tile.b,
                                  habitat_tile.c, habitat_tile.x, habitat_tile.y,
                                  patch_sizes[None, :])
            out_file.set_tile(size_tile)
        print("Done", bird.name)

KeyboardInterrupt: 