In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'  # default is ‘last_expr’

%load_ext autoreload
%autoreload 2

In [24]:
import os
import json
from random import sample
from collections import defaultdict

from tqdm import tqdm

# Make train and val splits

These were made by first placing tiles in the val split so that every class is present there.

## Filter out all-empty tiles

Some tiles are empty (satellite tile is empty but the labels are valid)... So filtering out files that are below 1MB in size (they are typically 60+MB).

Some tiles may not have a valid mask file (because the geom geojson file corresponding to that tile is empty of features - not sure if a Solaris bug), so filtering out those as well.

In [3]:
data_dir = '/wcs/mnt/wcs-orinoquia/tiles/full_sr_median_2013_2014'
assert 'tiles' in os.listdir(data_dir)
assert 'tiles_masks' in os.listdir(data_dir)

tiles_dir = os.path.join(data_dir, 'tiles')
labels_dir = os.path.join(data_dir, 'tiles_labels')
masks_dir = os.path.join(data_dir, 'tiles_masks')

all_tiles = os.listdir(tiles_dir)
len(all_tiles)

all_labels = os.listdir(labels_dir)
len(all_labels)

all_masks = os.listdir(masks_dir)
len(all_masks)

151

151

148

In [4]:
def get_lon_lat_from_tile_name(tile_name):
    parts = tile_name.split('_')
    lon_lat = f'_{parts[-2]}_{parts[-1].split(".tif")[0]}'
    return lon_lat

In [27]:
valid_tiles_to_lon_lat = {}

mask_lon_lat = set([mask_name.split('mask')[1].split('.png')[0] for mask_name in all_masks])
print(f'Number of masks: {len(mask_lon_lat)}')

for tile_name in all_tiles:
    tile_path = os.path.join(tiles_dir, tile_name)
    file_size = os.path.getsize(tile_path)  # in bytes
    tile_valid = True if file_size > 1000000 else False
    
    lon_lat = get_lon_lat_from_tile_name(tile_name)
    mask_valid = True if lon_lat in mask_lon_lat else False
        
    if tile_valid and mask_valid:
        valid_tiles_to_lon_lat[tile_name] = lon_lat

print(f'Number of valid tiles: {len(valid_tiles_to_lon_lat)}')

Number of masks: 148
Number of valid tiles: 148


In [31]:
valid_lon_lat_to_tiles = {v: k for k, v in valid_tiles_to_lon_lat.items()}

## Understand the class composition of tiles

So we can evaluate if all classes are present in both train and val.

In [6]:
valid_lon_lat[10]

'_-72.425_6.054'

In [7]:
tile_class_area = defaultdict(lambda: defaultdict(int))

for lon_lat in tqdm(valid_lon_lat):

    label_fp = os.path.join(labels_dir, f'geoms{lon_lat}.geojson')
    with open(label_fp) as f:
        label_geojson = json.load(f)
        
    for feature in label_geojson['features']:
        props = feature['properties']
        tile_class_area[lon_lat][props['Landuse_WC']] += props['AREA_HA']

100%|██████████| 148/148 [04:09<00:00,  1.68s/it]


In [9]:
tile_class_area['_-72.425_6.054']

defaultdict(int,
            {18: 770.1544921381,
             11: 72487.11709933847,
             14: 104514.49880733789,
             17: 170327314.93178856,
             10: 1014353.3464021531,
             8: 18270269.314165566,
             12: 1828028.758876238,
             19: 47803.63583881964,
             4: 140318.5281077013,
             27: 54.7289338492,
             9: 1021.719847316,
             26: 2100956.364798669,
             15: 24782.904267798902,
             13: 417.98734431689996,
             1: 594.1644210674,
             33: 8684.2665485058,
             32: 11395.513675939,
             30: 3547.8547573881997,
             20: 99.0079791595})

### Is every class present in more than 1 tile?

In [20]:
class_on_n_tiles = defaultdict(int)
class_to_tiles = defaultdict(set)

for tile, clss in tile_class_area.items():
    for c in clss:
        class_on_n_tiles[c] += 1  # present on this tile
        class_to_tiles[c].add(tile)

In [21]:
sorted(class_on_n_tiles.items(), key=lambda x: x[1])

[(24, 1),
 (5, 2),
 (29, 4),
 (6, 5),
 (28, 7),
 (21, 12),
 (3, 15),
 (30, 15),
 (31, 21),
 (16, 22),
 (7, 22),
 (4, 26),
 (2, 30),
 (9, 34),
 (13, 35),
 (22, 40),
 (1, 68),
 (20, 70),
 (25, 73),
 (10, 74),
 (23, 77),
 (18, 96),
 (15, 108),
 (27, 108),
 (33, 115),
 (32, 116),
 (14, 128),
 (8, 128),
 (19, 134),
 (26, 135),
 (11, 139),
 (17, 142),
 (12, 143)]

```
"24": "Glaciers and snow zones"
"5": "Permanent herbaceous crops"
```

## Train/val splits

First round, place some tiles in val set first so that every class is present there.

We won't have a test set right now since the labels are not of the same period of time and is quite noisy.

In [12]:
train_ratio = 0.9
val_ratio = 0.1

num_train = round(train_ratio * len(valid_tiles))
num_train
num_val = len(valid_tiles) - num_train
num_val

133

15

In [26]:
valid_tiles[0]

'wcs_orinoquia_sr_median_2013_2014-0000007424-0000000000_-73.559_4.593.tif'

In [67]:
val_tiles = set()

sampled_lon_lat = list(class_to_tiles[24])[0]
sampled_lon_lat
val_tiles.add(valid_lon_lat_to_tiles[sampled_lon_lat])

sampled_lon_lat = sample(class_to_tiles[5], 1)[0]
sampled_lon_lat
val_tiles.add(valid_lon_lat_to_tiles[sampled_lon_lat])

sampled_lon_lat = sample(class_to_tiles[29], 1)[0]
sampled_lon_lat
val_tiles.add(valid_lon_lat_to_tiles[sampled_lon_lat])

sampled_lon_lat = sample(class_to_tiles[6], 1)[0]
sampled_lon_lat
val_tiles.add(valid_lon_lat_to_tiles[sampled_lon_lat])

'_-72.425_6.593'

'_-72.964_4.593'

'_-73.559_5.132'

'_-73.559_5.132'

In [68]:
val_tiles

{'wcs_orinoquia_sr_median_2013_2014-0000000000-0000007424_-72.425_6.593.tif',
 'wcs_orinoquia_sr_median_2013_2014-0000007424-0000000000_-73.559_5.132.tif',
 'wcs_orinoquia_sr_median_2013_2014-0000007424-0000007424_-72.964_4.593.tif'}

In [69]:
while len(val_tiles) < num_val:
    val_tiles.add(sample(valid_tiles, 1)[0])

In [72]:
train_tiles = [i for i in valid_tiles if i not in val_tiles]

In [73]:
len(train_tiles)
len(val_tiles)

133

15

In [74]:
train_class_dist = defaultdict(int)
val_class_dist = defaultdict(int)

for class_dist, tiles_list in ((train_class_dist, train_tiles), (val_class_dist, val_tiles)):
    for tile_name in tiles_list:
        lon_lat = get_lon_lat_from_tile_name(tile_name)
        dist = tile_class_area[lon_lat]
        for clss, area in dist.items():
            class_dist[clss] += area

In [75]:
len(train_class_dist)
len(val_class_dist)

# re-run sampling cell to re-sample so that 32 classes are in the training set and 33 in the val set

32

33

Class `0 - Empty of data` and `29 - Label unavailable` are not necessary 

In [76]:
for i in range(34):
    print(f'class {i}, area in train {train_class_dist.get(i, 0)}, area in val {val_class_dist.get(i, 0)}')

class 0, area in train 0, area in val 0
class 1, area in train 24927.542426916276, area in val 5130.82719514284
class 2, area in train 39291.47257538113, area in val 7738.48755732441
class 3, area in train 3574.8413460291895, area in val 188.2659720052
class 4, area in train 2543248.8607051307, area in val 291586.11037356406
class 5, area in train 249.79796012100002, area in val 249.79796012100002
class 6, area in train 674531.8307268098, area in val 36344.288837693246
class 7, area in train 20629826.70601472, area in val 808753.1798519307
class 8, area in train 1950658895.3429046, area in val 63419471.60730979
class 9, area in train 1876129.0947182847, area in val 430692.6463152589
class 10, area in train 6698842.207297214, area in val 918199.4227899172
class 11, area in train 4366431.483713221, area in val 1409682.1391754732
class 12, area in train 624527103.5586175, area in val 14625673.287193296
class 13, area in train 76932.1231358848, area in val 58503.71310796899
class 14, area 

## Save the splits

In [None]:
train_tiles

In [None]:
val_tiles

In [81]:
splits = {
    'train': train_tiles,
    'val': list(val_tiles)
}

In [82]:
with open('/wcs/pycharm/wildlife-conservation-society.orinoquia-land-use/constants/splits/full_sr_median_2013_2014_splits.json', 'w') as f:
    json.dump(splits, f, indent=2)