# Biomassters PyTorch Training & Prediction Workflow

### Imports

In [None]:
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [None]:
%load_ext tensorboard

In [None]:
import pandas as pd
import os
import rasterio
import numpy as np
import torch
from torch.nn import Sequential
from torch.utils.data import DataLoader, random_split
from torchgeo.transforms import indices
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from tqdm.notebook import tqdm
from PIL import Image
from models import Model

import warnings
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning) # biomassters rasters are not georeferenced
warnings.filterwarnings('ignore', r'All-NaN (slice|axis) encountered')

#### Local imports from transforms.py, dataloading.py, and utils

In [None]:
import transforms as tf
import dataloading as dl
from utils import get_tile_image

### Setup GPU

In [None]:
# Pick the best available accelerator: Apple MPS first, then CUDA, then CPU.
# `torch.backends.mps` does not exist on older PyTorch builds, so it is looked
# up with getattr instead of being wrapped in a try/except AttributeError.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available(): # Mac M1/M2
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
    # force=True keeps this cell re-runnable: without it a second call raises
    # "RuntimeError: context has already been set".
    torch.multiprocessing.set_start_method('spawn', force=True)
else:
    device = torch.device('cpu')

print(f'training device: {device}')

### Set directories for local environment

In [None]:
# Data locations, given relative to the notebook's working directory.
dir_tiles = '../data/train_features'   # passed to SentinelDataset as dir_tiles for training
dir_target = '../data/train_agbm'      # passed to SentinelDataset as dir_target for training
dir_test = '../data/test_features'     # passed to SentinelDataset as dir_tiles for the test set
dir_saved_models = './trained_models'  # not referenced by the other cells shown in this notebook

# Forwarded to SentinelDataset as gcp_bucket_name; left as None here.
bucket_name = None

In [None]:
# Human-readable label for every channel index; useful for choosing which
# bands to keep. Indices 0-14 are the raw Sentinel bands, 15-27 are the
# channels appended by the transforms.
band_map = {
    # S2 bands
    0: 'S2-B2: Blue-10m',
    1: 'S2-B3: Green-10m',
    2: 'S2-B4: Red-10m',
    3: 'S2-B5: VegRed-704nm-20m',
    4: 'S2-B6: VegRed-740nm-20m',
    5: 'S2-B7: VegRed-780nm-20m',
    6: 'S2-B8: NIR-833nm-10m',
    7: 'S2-B8A: NarrowNIR-864nm-20m',
    8: 'S2-B11: SWIR-1610nm-20m',
    9: 'S2-B12: SWIR-2200nm-20m',
    10: 'S2-CLP: CloudProb-160m',
    # S1 bands
    11: 'S1-VV-Asc: Cband-10m',
    12: 'S1-VH-Asc: Cband-10m',
    13: 'S1-VV-Desc: Cband-10m',
    14: 'S1-VH-Desc: Cband-10m',
    # Bands derived by transforms
    15: 'S2-NDVI: (NIR-Red)/(NIR+Red) 10m',
    16: 'S1-NDVVVH-Asc: Norm Diff VV & VH, 10m',
    17: 'S2-NDBI: Difference Built-up Index, 20m',
    18: 'S2-NDRE: Red Edge Vegetation Index, 20m',
    19: 'S2-NDSI: Snow Index, 20m',
    20: 'S2-NDWI: Water Index, 10m',
    21: 'S2-SWI: Standardized Water-Level Index, 20m',  # label typo fixed ("Sandardized")
    22: 'S1-VV/VH-Asc: Cband-10m',
    23: 'S2-GNDVI',
    24: 'S2-GBNDVI',
    25: 'S2-EVI',
    26: 'S2-SSAVI',
    27: 'S1-DPRVI',  # computed from the S1 VV/VH channels (11, 12), so labelled S1
}

# Position of each month along the time axis: index 0 is September and
# index 11 is the following August.
month_map = {
    0: 'Sep', 1: 'Oct', 2: 'Nov', 3: 'Dec',
    4: 'Jan', 5: 'Feb', 6: 'Mar', 7: 'Apr',
    8: 'May', 9: 'Jun', 10: 'Jul', 11: 'Aug'
}

### Transforms
See https://torchgeo.readthedocs.io/en/latest/tutorials/indices.html and https://torchgeo.readthedocs.io/en/latest/api/transforms.html

In [None]:
# Transform pipeline applied to every sample. The "index N" notes give the
# channel each appended band ends up at according to band_map, assuming every
# Append* transform adds exactly one channel after the 15 raw bands (0-14).
transforms = Sequential(
    tf.ClampAGBM(vmin=0., vmax=600.),               # exclude AGBM outliers
    tf.AppendNDVI(index_nir=6, index_red=2),        # NDVI, index 15
    tf.AppendNormalizedDifferenceIndex(index_a=11, index_b=12), # Radar Vegetation Index (VV-VH)/(VV+VH), index 16
    tf.AppendNDBI(index_swir=8, index_nir=6),   # Difference Built-up Index for development detection, index 17
    tf.AppendNDRE(index_nir=6, index_vre1=3),   # Red Edge Vegetation Index for canopy detection, index 18
    tf.AppendNDSI(index_green=1, index_swir=8), # Snow Index, index 19
    tf.AppendNDWI(index_green=1, index_nir=6),  # Difference Water Index for water detection, index 20
    # NOTE(review): index_swir2 receives 8, which band_map labels S2-B11 (SWIR-1610nm);
    # S2-B12 (SWIR-2200nm) is index 9 - confirm which band is intended.
    tf.AppendSWI(index_vre1=3, index_swir2=8),  # Standardized Water-Level Index for water detection, index 21
    tf.AppendRatioAB(index_a=11, index_b=12),        # VV/VH Ascending, index 22
    tf.AppendGNDVI(index_nir=6, index_green=1),  # GNDVI, index 23
    tf.AppendGBNDVI(index_nir=6, index_green=1, index_blue=0),  # GBNDVI, index 24
    tf.AppendEVI(index_nir=6, index_red=2, index_blue=0), # Enhanced Vegetation Index, index 25
    tf.AppendSAVI(index_nir=6, index_red=2),         # Soil Adjusted Vegetation Index, index 26
    tf.AppendDPRVI(index_vh=12, index_vv=11),        # Dual Polarization vegetation index, index 27
)

### SentinelDataset - set `max_chips` 

In [None]:
# This file specifies which month of data to use for training for each chipid.
# See the preprocessing notebook for an example of producing it.
tile_file = 'data/TILE_LIST_BEST_MONTHS.csv'

In [None]:
# Cap on the number of chips taken from the training set; None = use all.
max_chips = None

# Custom dataset for Sentinel data, built from the tile list and the
# feature/target directories configured above.
dataset = dl.SentinelDataset(
    tile_file=tile_file,
    dir_tiles=dir_tiles,
    dir_target=dir_target,
    max_chips=max_chips,
    transform=transforms,
    device=device,
    gcp_bucket_name=bucket_name,
    scale=False,
)

### Split Train/Valid ---Note: manual seed set for reproducibility---

In [None]:
# Seed the global RNG so the random split below is reproducible.
torch.manual_seed(0)

train_frac = 0.8
# Derive the validation share from train_frac so the two cannot drift apart.
# round() strips float noise (1 - 0.8 == 0.19999999999999996), which would
# otherwise change how random_split floors the per-split lengths.
val_frac = round(1.0 - train_frac, 10)
train_dataset, val_dataset = random_split(dataset, [train_frac, val_frac])
print(f'N training samples: {len(train_dataset)}')
print(f'N validation samples: {len(val_dataset)}')

### Dataloaders - set `batch_size` and `num_workers`

In [None]:
# Note: training speed is sensitive to memory usage; set batch_size as high
# as you can without significantly slowing down training time.
batch_size = 12
num_workers = 0

# Settings shared by both loaders; only the shuffling differs between them.
_loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=False)

train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_options)

### Define some utilities for training
* Logger
* checkpoints
* Early stopping

In [None]:
# Root folder for TensorBoard logs and checkpoints. makedirs(exist_ok=True)
# is safe to re-run and avoids the check-then-create race of exists()/mkdir().
logs_folder = 'training_logs'
os.makedirs(logs_folder, exist_ok=True)

logger = TensorBoardLogger(logs_folder, name='')
# Folder of the current experiment: <save_dir>/<name>/version_<N>
logs_filepath = os.path.join(logger.save_dir, logger.name, 'version_'+str(logger.version))

# Keep only the single checkpoint with the lowest val_loss, stored under the
# file name 'best_model' inside the experiment folder.
checkpoint_callback = ModelCheckpoint(
    dirpath=logs_filepath,
    save_top_k=1,
    monitor="val_loss",
    mode='min',
    filename='best_model'
)

# Monitors val_loss (mode="min") with a patience of 6.
early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=6)

### Initialize Model

In [None]:
# The number of input channels is read off the first training sample:
# it is the size of the leading dimension of its 'image' tensor.
_first_image = train_dataset[0]['image']
in_channels = _first_image.shape[0]
model = Model(in_channels=in_channels)
print(f'# input channels: {in_channels}')

### ... Or load a saved model

In [None]:
previous_version = 18 # Choose the experiment version from which to load the model
# Folder of that experiment: <save_dir>/<name>/version_<previous_version>
checkpoint_filepath = os.path.join(logger.save_dir, logger.name, 'version_'+str(previous_version))
# load_from_checkpoint is a classmethod: call it on the Model class, since
# recent Lightning releases raise when it is invoked on an instance.
model = Model.load_from_checkpoint(os.path.join(checkpoint_filepath,'best_model.ckpt'), in_channels=in_channels)
model.train();

# Run Training

In [None]:
# Upper bound on training epochs; the early-stopping callback may end the run sooner.
n_epochs = 40

# Single-device trainer on whichever accelerator was selected during GPU setup.
trainer = Trainer(
    accelerator=device.type,
    devices=1,
    max_epochs=n_epochs,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
)

In [None]:
import warnings

# Silence the warning whose message mentions that the dataloader
# "does not have many workers".
warnings.filterwarnings(action="ignore", message=".*does not have many workers.*")

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

### Plot RMSE for Training & Validation 

In [None]:
%tensorboard --logdir $logs_filepath

### Load Best model from current experiment

In [None]:
# Reload the best checkpoint written by ModelCheckpoint for this experiment.
# load_from_checkpoint is a classmethod: call it on the Model class, since
# recent Lightning releases raise when it is invoked on an instance.
model = Model.load_from_checkpoint(os.path.join(logs_filepath,'best_model.ckpt'), in_channels=in_channels)
model.eval();

## Visualize Sample Predictions 

### True AGBM

In [None]:
# Fixed seed so the same arbitrary tile is drawn on every run.
np.random.seed(123)
tile_idx = np.random.choice(len(dataset))  # arbitrary tile
sample = dataset[tile_idx]

# Show the label on a log colour scale (values outside the norm range are clipped).
_label_image = get_tile_image(sample['label'].detach().cpu())
_fig, _ax = plt.subplots()
_mappable = _ax.imshow(_label_image, interpolation=None, norm=LogNorm(clip=True))
_fig.colorbar(_mappable, ax=_ax)
plt.show()

### Predicted AGBM

In [None]:
# Move the model to the selected device before running predictions on it.
model.to(device)

def predict_agbm(inputs, model):
    """Run `model` on `inputs` with gradient tracking disabled.

    A 3-dimensional `inputs` gets a leading dimension of size 1 added before
    the forward pass; input of any other rank is passed to the model as is.

    Returns the model output detached, with all size-1 dimensions squeezed
    out, and moved to the CPU.
    """
    batch = inputs.unsqueeze(0) if inputs.dim() == 3 else inputs
    with torch.no_grad():
        output = model(batch)
    return output.detach().squeeze().cpu()

# Predict for the sample shown above and draw it with the same log colour scale.
_predicted = predict_agbm(sample['image'].to(device), model)
_fig, _ax = plt.subplots()
_mappable = _ax.imshow(_predicted, interpolation=None, norm=LogNorm(clip=True))
_fig.colorbar(_mappable, ax=_ax)
plt.show()

# Process Predictions on Test Holdout

In [None]:
# CSV listing the best tile (per chip) to use for the test data
tile_file_test = 'data/TILE_LIST_BEST_MONTHS_TEST.csv'

# Directory in which the predicted AGBM rasters are saved
dir_save_preds = 'data/test_predictions'

### Define Test Dataset 

In [None]:
# Cap on the number of test chips; None = use all.
max_chips = None

# Test dataset: same transforms as training, but no AGBM targets exist.
dataset_test = dl.SentinelDataset(
    tile_file=tile_file_test,   # specifies best months of test data
    dir_tiles=dir_test,         # test data dir
    dir_target=None,            # no AGBM targets for test data
    max_chips=max_chips,
    transform=transforms,       # same transforms as training
    device=device,
    gcp_bucket_name=bucket_name,
    scale=False,
)

### Sanity Check: Example Prediction on Test Data

In [None]:
# Predict one arbitrary test tile as a sanity check.
tile_idx = 0

chipid = dataset_test.df_tile_list.iloc[tile_idx]['chipid']
_test_sample = dataset_test[tile_idx]
inputs = _test_sample['image'].to(device)
agbm = predict_agbm(inputs, model)

In [None]:
# Draw the sanity-check prediction with a colour bar.
_fig, _ax = plt.subplots()
_fig.colorbar(_ax.imshow(agbm), ax=_ax)
plt.show()

## Loop through and save all AGBM predictions

In [None]:
def save_agbm(agbm_pred, chipid):
    """Write one AGBM prediction to `<dir_save_preds>/<chipid>_agbm.tif`.

    Args:
        agbm_pred: array holding the prediction; it is converted to a PIL
            image with `Image.fromarray`.
        chipid: identifier used to build the output file name.
    """
    # Create the output directory on first use; without this the first save
    # fails on a fresh checkout where the folder does not exist yet.
    os.makedirs(dir_save_preds, exist_ok=True)
    save_path = os.path.join(dir_save_preds, f'{chipid}_agbm.tif')
    im = Image.fromarray(agbm_pred)
    im.save(save_path, format='TIFF', save_all=True)

In [None]:
model.to(device)

# Index the dataset explicitly. Iterating a map-style dataset object directly
# only terminates if its __getitem__ raises IndexError past the end, which
# this loop should not have to rely on.
n_tiles = len(dataset_test)
for ix in tqdm(range(n_tiles), total=n_tiles):
    chipid = dataset_test.df_tile_list.iloc[ix]['chipid']
    inputs = dataset_test[ix]['image'].to(device)
    agbm = predict_agbm(inputs, model).numpy()
    save_agbm(agbm, chipid)

In [None]:
#create compressed file for submission
!tar -cvzf data/test_predictions.tar.gz data/test_predictions

### Quick Check of Generated Predictions

In [None]:
# Read back the prediction saved for the last processed chip.
file_path = os.path.join(dir_save_preds, f'{chipid}_agbm.tif')
# Open inside a context manager so the raster file handle is closed again.
with rasterio.open(file_path) as src:
    test_pred = src.read().astype(np.float32)[0]  # first band

In [None]:
# Draw the prediction that was read back from disk, with a colour bar.
_fig, _ax = plt.subplots()
_fig.colorbar(_ax.imshow(test_pred), ax=_ax)
plt.show()