# Explore [Biomass Estimation](https://www.drivendata.org/competitions/99/biomass-estimation/) contest data

In [85]:
from os.path import join, basename, isfile
import random
from multiprocessing.pool import ThreadPool
import os

import s3fs
from tqdm.autonotebook import tqdm
import numpy as np
import pandas
import rasterio
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import torch

In [86]:
# Root directory for all local data. Set the BIOMASS_ROOT_DIR environment
# variable to relocate it; the default is the original hardcoded location.
root_dir = os.environ.get('BIOMASS_ROOT_DIR', '/Users/lewfish/data/biomass/')
# Downloaded contest files go under dataset_dir, generated plots under output_dir.
dataset_dir = join(root_dir, 'dataset')
output_dir = join(root_dir, 'output')
os.makedirs(dataset_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
# Metadata CSVs are expected directly inside dataset_dir.
features_csv_path = join(dataset_dir, 'features_metadata.csv')
labels_csv_path = join(dataset_dir, 'train_agbm_metadata.csv')
# Local directories holding the feature rasters and the AGBM label rasters.
image_dir = join(dataset_dir, 'train_features')
label_dir = join(dataset_dir, 'train_agbm')

In [90]:
# Metadata table for the feature rasters; preview the first rows.
features_df = pandas.read_csv(filepath_or_buffer=features_csv_path)
features_df.head(n=5)

Unnamed: 0,filename,chip_id,satellite,split,month,size,cksum,s3path_us,s3path_eu,s3path_as,corresponding_agbm
0,0003d2eb_S1_00.tif,0003d2eb,S1,train,September,1049524,3953454613,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,0003d2eb_agbm.tif
1,0003d2eb_S1_01.tif,0003d2eb,S1,train,October,1049524,3531005382,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,0003d2eb_agbm.tif
2,0003d2eb_S1_02.tif,0003d2eb,S1,train,November,1049524,1401197002,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,0003d2eb_agbm.tif
3,0003d2eb_S1_03.tif,0003d2eb,S1,train,December,1049524,3253084255,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,0003d2eb_agbm.tif
4,0003d2eb_S1_04.tif,0003d2eb,S1,train,January,1049524,2467836265,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,0003d2eb_agbm.tif


In [91]:
# Metadata table for the AGBM label rasters; preview the first rows.
labels_df = pandas.read_csv(filepath_or_buffer=labels_csv_path)
labels_df.head(n=5)

Unnamed: 0,filename,chip_id,size,cksum,s3path_us,s3path_eu,s3path_as
0,0003d2eb_agbm.tif,0003d2eb,262482,2036246549,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...
1,000aa810_agbm.tif,000aa810,262482,2858468457,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...
2,000d7e33_agbm.tif,000d7e33,262482,277850822,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...
3,00184691_agbm.tif,00184691,262482,3502312579,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...
4,001b0634_agbm.tif,001b0634,262482,2397957274,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...,s3://drivendata-competition-biomassters-public...


## Compute dataset statistics

The number of `chip_id`s per split:

In [None]:
# Take one split value per chip_id, then count chips in each split.
features_df.groupby('chip_id')['split'].first().value_counts()

train    8689
test     2773
Name: split, dtype: int64

The number of raster files (one per chip, satellite, and month) per split:

In [None]:
# Every metadata row is one raster file, so counting rows counts files.
features_df['split'].value_counts()

train    189078
test      63348
Name: split, dtype: int64

## Download subset of data

Shuffle the training `chip_id`s with a fixed seed so that a small, reproducible subset can be taken from the front.

In [None]:
# Distinct chip ids belonging to the training split.
is_train = features_df['split'] == 'train'
train_chip_ids = features_df.loc[is_train, 'chip_id'].unique()
# Seed before shuffling so the order (and any subset taken from it) is repeatable.
random.seed(1234)
random.shuffle(train_chip_ids)
len(train_chip_ids)

8689

In [None]:
# Number of (already shuffled) training chips to work with locally.
subset_sz = 10
_train_chip_ids = train_chip_ids[:subset_sz]

Download a subset of the data locally

In [None]:
def get_image_uri(split, chip_id, satellite, month, 
                  root_uri='s3://drivendata-competition-biomassters-public-us'):
    """Build the location of one feature raster.

    Args:
        split: name of the split; used as the prefix of the `*_features`
            directory.
        chip_id: id of the chip.
        satellite: satellite code that appears in the file name (the notebook
            uses 'S1' and 'S2').
        month: integer month index, zero-padded to two digits in the file name.
        root_uri: bucket URI or local directory that the path is joined onto.

    Returns:
        `root_uri` joined with
        `{split}_features/{chip_id}_{satellite}_{month:02}.tif`.
    """
    relative_path = f'{split}_features/{chip_id}_{satellite}_{month:02}.tif'
    return join(root_uri, relative_path)

def get_label_uri(split, chip_id, 
                  root_uri='s3://drivendata-competition-biomassters-public-us'):
    """Build the location of the AGBM label raster for a chip.

    Args:
        split: name of the split; used as the prefix of the `*_agbm` directory.
        chip_id: id of the chip.
        root_uri: bucket URI or local directory that the path is joined onto.

    Returns:
        `root_uri` joined with `{split}_agbm/{chip_id}_agbm.tif`.
    """
    relative_path = f'{split}_agbm/{chip_id}_agbm.tif'
    return join(root_uri, relative_path)

In [None]:
# Each task is a (remote URI, local path) pair. For every chip the tasks are
# ordered as: all S1 months, all S2 months, then the label raster.
download_tasks = []
for chip_id in _train_chip_ids:
    chip_image_uris = [
        get_image_uri('train', chip_id, satellite, month)
        for satellite in ('S1', 'S2')
        for month in range(12)
    ]
    download_tasks.extend(
        (uri, join(image_dir, basename(uri))) for uri in chip_image_uris)

    label_uri = get_label_uri('train', chip_id)
    download_tasks.append((label_uri, join(label_dir, basename(label_uri))))
print(download_tasks[0:2])

[('s3://drivendata-competition-biomassters-public-us/train_features/d8e45923_S1_00.tif', '/Users/lewfish/data/biomass/dataset/train_features/d8e45923_S1_00.tif'), ('s3://drivendata-competition-biomassters-public-us/train_features/d8e45923_S1_01.tif', '/Users/lewfish/data/biomass/dataset/train_features/d8e45923_S1_01.tif')]


In [None]:
# Anonymous (unauthenticated) S3 access.
fs = s3fs.S3FileSystem(anon=True)

def download_file(x):
    """Download one `(from_uri, to_path)` task.

    Tasks whose remote URI does not exist are skipped silently; the task list
    is generated for every month, and presumably not every chip has a raster
    for each one.
    """
    from_uri, to_path = x
    if fs.exists(from_uri):
        fs.download(from_uri, to_path)

# The context manager tears the worker threads down once every task has been
# consumed, instead of leaving an open pool behind in the kernel.
with ThreadPool(8) as pool:
    for _ in tqdm(pool.imap_unordered(download_file, download_tasks), 
                    total=len(download_tasks)):
        pass

A Jupyter Widget

## Load and visualize data

Define a PyTorch `Dataset` class that reads the downloaded chips from disk.

In [None]:
class BiomassDataset(Dataset):
    """Reads biomass chips (monthly S1/S2 rasters plus a label raster) from disk.

    Each item corresponds to one chip id. Only months for which BOTH the S1 and
    the S2 raster exist locally are included in the returned image tensor.

    Args:
        root_uri: local directory containing the `{split}_features` and
            `{split}_agbm` sub-directories.
        split: split name used to build the file paths.
        chip_ids: indexable collection of chip ids.
    """

    def __init__(self, root_uri, split, chip_ids):
        self.root_uri = root_uri
        self.split = split
        self.chip_ids = chip_ids
    
    def __len__(self):
        """Number of chips in the dataset."""
        return len(self.chip_ids)

    @staticmethod
    def _read_image(image_path, satellite):
        """Read one raster as a float tensor, rescaling S2 values."""
        with rasterio.open(image_path) as img:
            if satellite == 'S1':
                # S1 is float32 and -9999 means missing data
                return torch.from_numpy(img.read())
            # S2 is uint16 and the last band is cloud probability 
            # (ranges 0-100, or 255 for noise)
            sat_arr = torch.from_numpy(img.read().astype(np.float32))
            sat_arr[0:-1] = sat_arr[0:-1] / np.iinfo(np.uint16).max
            sat_arr[-1] = sat_arr[-1] / 100
            return sat_arr

    def __getitem__(self, ind):
        """Load the chip at position `ind`.

        Returns:
            A tuple `(img_arr, label_arr, metadata)` where `img_arr` stacks the
            available months along dim 0 (each month is the S1 bands followed
            by the S2 bands), `label_arr` is the squeezed label raster, and
            `metadata` holds `'avail_months'` (the month indices present in
            `img_arr`, in order) and `'chip_id'`.

        Raises:
            RuntimeError: if no month has both an S1 and an S2 raster on disk.
        """
        chip_id = self.chip_ids[ind]
        img_arrs = []
        avail_months = []
        
        for month in range(0, 12):
            sat_arrs = []
            for satellite in ['S1', 'S2']:                
                image_path = get_image_uri(
                    self.split, chip_id, satellite, month, root_uri=self.root_uri)
                if isfile(image_path):
                    sat_arrs.append(self._read_image(image_path, satellite))
            # Keep the month only when both satellites are present, so every
            # month has the same number of channels.
            if len(sat_arrs) == 2:
                img_arr = torch.cat(sat_arrs, dim=0)
                img_arrs.append(img_arr.unsqueeze(0))
                avail_months.append(month)
        
        if not img_arrs:
            # torch.cat on an empty list fails with a message that does not
            # say which chip is at fault, so fail early with the details.
            raise RuntimeError(
                f'No month with both S1 and S2 rasters found for chip '
                f'{chip_id} under {self.root_uri}')

        img_arr = torch.cat(img_arrs, dim=0)
        label_path = get_label_uri(self.split, chip_id, root_uri=self.root_uri)
        with rasterio.open(label_path) as label:
            label_arr = torch.from_numpy(label.read()).squeeze()
        
        metadata = {
            'avail_months': avail_months,
            'chip_id': chip_id
        }
        return img_arr, label_arr, metadata

Read a single chip

In [None]:
# Dataset over the locally downloaded subset of training chips.
biomass_ds = BiomassDataset(dataset_dir, 'train', _train_chip_ids)

# Load the first chip and report the tensor shapes and its metadata.
x, y, metadata = biomass_ds[0]
for summary in (x.shape, y.shape, metadata):
    print(summary)


torch.Size([11, 15, 256, 256])
torch.Size([256, 256])
{'avail_months': [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11], 'chip_id': 'd8e45923'}


  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


In [None]:
def plot_xy(x, y, metadata, out_path=None):
    """Plot every month/band of a chip as a grid, with the label in the last row.

    Args:
        x: image tensor indexed as `x[month_row, band, :, :]`; the first 15
            bands are plotted.
        y: 2D label array shown in the first cell of the last row.
        metadata: dict whose `'avail_months'` list maps each row of `x` to an
            index into the September..August month names.
        out_path: if given, the figure is written to this path and then
            closed; otherwise it is shown interactively.
    """
    nrows = x.shape[0] + 1
    ncols = 15
    # I'm not totally sure this is the right order for S1
    s1_bands = ['VV Asc', 'VH Asc', 'VV Desc', 'VH Desc']
    s2_bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12', 'CLP']
    bands = s1_bands + s2_bands
    months = [
        'September', 'October', 'November', 'December', 'January', 'February', 'March', 
        'April', 'May', 'June', 'July', 'August']

    fig, axs = plt.subplots(
        nrows, ncols, constrained_layout=True, figsize=(1.5 * ncols, 1.5 * nrows),
        squeeze=False)

    # No cell in the grid shows tick marks.
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    # One row per available month, one column per band.
    for row_ind in range(nrows - 1):
        for col_ind in range(ncols):
            ax = axs[row_ind, col_ind]
            ax.imshow(x[row_ind, col_ind, :, :])
            if col_ind == 0:
                ax.set_ylabel(months[metadata['avail_months'][row_ind]])
            if row_ind == 0:
                ax.set_title(bands[col_ind])

    # The label goes in the first cell of the last row; the rest stay empty.
    label_ax = axs[nrows - 1, 0]
    label_ax.imshow(y)
    label_ax.set_title('Biomass Estimate')

    if out_path:
        fig.savefig(
            out_path, bbox_inches='tight', pad_inches=0.2, transparent=False, dpi=300)
        # Release the figure; otherwise each saved plot stays open in pyplot's
        # figure registry and memory grows with every chip.
        plt.close(fig)
    else:
        plt.show()

Save a plot for each chip

In [89]:
# Use `multiprocess` (a fork of `multiprocessing`) because it works in Jupyter notebooks
from multiprocess import Pool

def plot_item(item):
    """Save the plot for one dataset item to `<output_dir>/<chip_id>.png`."""
    x, y, metadata = item
    chip_id = metadata['chip_id']
    plot_path = join(output_dir, f'{chip_id}.png')
    plot_xy(x, y, metadata, plot_path)

# The context manager shuts the worker processes down once every item has been
# plotted, instead of leaving an open pool behind in the kernel.
with Pool(8) as pool:
    for _ in tqdm(pool.imap_unordered(plot_item, biomass_ds), total=len(biomass_ds)):
        pass

A Jupyter Widget