In [None]:
import os
import albumentations as A
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
import rasterio as rio
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
# from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import warnings
from torchsummary import summary
from torch.utils.data import DataLoader
import json
# import wandb
from utils import calcuate_mean_std, stratify_data, freeze_encoder, BioMasstersDatasetS2S1, SentinelModel

# Silence rasterio's warning for rasters without georeferencing metadata.
warnings.filterwarnings("ignore", category=rio.errors.NotGeoreferencedWarning)
# Print numbers in fixed-point notation (numpy / pandas with 2 decimals / torch).
np.set_printoptions(suppress=True)
pd.set_option('display.float_format', lambda x: '%.2f' % x)
torch.set_printoptions(sci_mode=False)

# Seed Python's `random`, NumPy and PyTorch RNGs (and DataLoader workers) so weight
# initialisation, shuffling and augmentations start from the same RNG state on every
# run. GPU kernels may still be non-deterministic.
SEED = 42
seed_everything(SEED, workers=True)

## Training
* 15 models were trained using a *UNet++* architecture in combination with various encoders (e.g. *se_resnext50_32x4d* and *efficientnet-b{7,8}*) and median cloud-free composites. In the experiments, *UNet++* performed better than other decoders (e.g. *UNet*, *MANet*, etc.)
* The models were pretrained with multiple augmentations (*HorizontalFlip*, *VerticalFlip*, *RandomRotate90*, *Transpose*, *ShiftScaleRotate*), batch size of 32, *AdamW* optimizer with 0.001 initial learning rate, weight decay of 0.0001, and a *ReduceLROnPlateau* scheduler
* *UNet++* models were optimized for 200 epochs using a *Huber* loss, to reduce the effect of outliers in the data
* To improve the performance of each *UNet++* model they were further fine-tuned (after freezing pre-trained encoder weights and removing augmentations) for another 100 epochs with batch size of 32, *AdamW* optimizer with 0.0005 initial learning rate, weight decay of 0.0001, and a *ReduceLROnPlateau* scheduler

In [None]:
# Folder holding the preprocessed data (train_features_s1_*/train_features_s2_*,
# test_features_s2_*, train_agbm). Defaults to the current working directory;
# set the BIOMASSTERS_ROOT environment variable to point somewhere else.
root_dir = os.environ.get("BIOMASSTERS_ROOT", os.getcwd())

# Number of input channels per composite suffix, for Sentinel-1 and Sentinel-2.
# A model trained on composite `k` takes S2_CHANNELS[k] + S1_CHANNELS[k] input channels.
S1_CHANNELS = {'2S': 8, '2SI': 12, '3S': 12, '4S': 16, '4SI': 24, '6S': 24}
S2_CHANNELS = {'2S': 20, '2SI': 38, '3S': 30, '4S': 40, '4SI': 48, '6S': 60}
assert S1_CHANNELS.keys() == S2_CHANNELS.keys(), "S1/S2 channel tables must cover the same composites"

### Stratify data
Here I split the train data into train (8593 samples) and validation (96 samples) datasets. This is done in a stratified manner based on the average and standard deviation of the AGB values, to ensure that both datasets have similar distributions.
NOTE: my original train/validation split was generated with an arbitrary <code>random_state</code>, so I included the resulting split for reproducibility of the results: <code>./data/train_val_split_96_0_original.csv</code>

In [None]:
# Build the train/validation/test assignment table (column "dataset": 0/1/2 as used below).
# test_size=96 presumably sets the size of the held-out validation split described above;
# verify in utils.stratify_data.
df = stratify_data(
    s2_path_train=f"{root_dir}/train_features_s2_4S",
    agb_path=f"{root_dir}/train_agbm",
    s2_path_test=f"{root_dir}/test_features_s2_4S",
    test_size=96,
    random_state=0
)
# Make sure the output folder exists before writing the split table.
os.makedirs('./data', exist_ok=True)
df.to_csv('./data/train_val_split_96_0.csv', index=False)

To reproduce the results, simply read the pre-computed train/validation/test splits from file.

In [None]:
# Read the split table; keep ids as strings so they are not parsed as numbers.
# To reproduce the author's original split exactly, read './data/train_val_split_96_0_original.csv' instead.
df = pd.read_csv('./data/train_val_split_96_0.csv', dtype={"id": str})

Here we define the train, validation and test sets.

In [None]:
# "dataset" codes in the split table: 0 -> train, 1 -> validation, 2 -> test.
X_train, X_val, X_test = (df.loc[df["dataset"] == code, "id"].tolist() for code in (0, 1, 2))
print(df["dataset"].value_counts())
print("Total Images: ", len(df))

### Calculate mean and std for image standardization
Here I calculate the mean and standard deviation of each composite over the train dataset; these are used for data standardization. I also standardized the target variable (i.e. AGB), as this was found to speed up model convergence.

In [None]:
# Target (AGB) statistics computed on the training ids (single channel, percent=100).
mean_agb, std_agb = calcuate_mean_std(image_dir=f"{root_dir}/train_agbm", train_set=X_train, percent=100, channels=1,
                                      nodata=None, data='agbm', log_scale=False)

# Per-channel statistics for every composite, stored as S2 values followed by S1 values.
# This mirrors in_channels = S2_CHANNELS + S1_CHANNELS used by the models; it is assumed to
# match the order in which BioMasstersDatasetS2S1 stacks the bands (verify in utils).
mean, std = {}, {}
for SUFFIX in S2_CHANNELS:  # same composites (and order) as the channel tables

    S2_PATH = os.path.join(root_dir, f'train_features_s2_{SUFFIX}')
    S1_PATH = os.path.join(root_dir, f'train_features_s1_{SUFFIX}')

    mean_s2, std_s2 = calcuate_mean_std(image_dir=S2_PATH, train_set=X_train, percent=5, channels=S2_CHANNELS[SUFFIX],
                                        nodata=0, data='S2', log_scale=False)
    mean_s1, std_s1 = calcuate_mean_std(image_dir=S1_PATH, train_set=X_train, percent=5, channels=S1_CHANNELS[SUFFIX],
                                        nodata=None, data='S1', log_scale=False)

    # `+` concatenates the per-channel statistics (assumes calcuate_mean_std returns
    # JSON-serialisable lists, as json.dump below requires).
    mean[SUFFIX] = mean_s2 + mean_s1
    # Stds must use exactly the same S2-then-S1 layout as the means above.
    std[SUFFIX] = std_s2 + std_s1

os.makedirs('./data', exist_ok=True)
for name, stats in (('mean', mean), ('std', std), ('mean_agb', mean_agb), ('std_agb', std_agb)):
    with open(f'./data/{name}.json', 'w') as f:
        json.dump(stats, f)

We can skip the previous step and read the pre-calculated values from the JSON files.

In [None]:
def _load_json(path):
    """Return the object stored in the JSON file at ``path``; the file is closed on exit."""
    with open(path) as fh:
        return json.load(fh)


# Standardization statistics, either written by the cell above or shipped pre-computed.
mean = _load_json('./data/mean.json')
std = _load_json('./data/std.json')
mean_agb = _load_json('./data/mean_agb.json')
std_agb = _load_json('./data/std_agb.json')

Make sure we can train on a GPU.

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(torch.cuda.is_available())         # True when a CUDA device is usable
torch.cuda.empty_cache()                 # release unoccupied memory held by PyTorch's CUDA caching allocator
print(torch.version.cuda)                # CUDA version PyTorch was built with (None for CPU-only builds)
torch.backends.cudnn.benchmark = False   # keep cuDNN autotuning switched off

### Training strategy
Overall, I trained 15 *UNet++* models with different encoders available in [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch). Each of the models was trained in two stages:
1) Base model training from *imagenet* or *advprop* pre-trained encoder weights for 200 epochs
2) Fine-tuning of the base model with frozen encoder weights for another 100 epochs. Below I provide paths to the weights of my pre-trained base models; replace them with your own checkpoints if you plan to train from scratch.

### Base model setup
batch_size: 32 (note: the `DataLoader`s in the code below use a batch size of 8)/
epochs: 200/
learning_rate: 0.001/
weight_decay: 0.0001/
augmentations: HorizontalFlip(), VerticalFlip(), RandomRotate90(), Transpose(), ShiftScaleRotate()/
scheduler: ReduceLROnPlateau()

In [None]:
def train_base_model(suffix, encoder_name, encoder_weights, decoder_attention_type,
                     batch_size=8, num_workers=8, max_epochs=200, lr=0.001, wd=0.0001):
    """Train a stage-1 UNet++ base model on one Sentinel-2 + Sentinel-1 composite.

    Uses the notebook-level ``root_dir``, ``X_train``/``X_val`` id lists and the
    ``mean``/``std``/``mean_agb``/``std_agb`` statistics defined in earlier cells.
    The training set is augmented with HorizontalFlip, VerticalFlip, RandomRotate90,
    Transpose and ShiftScaleRotate; the validation set is not augmented.

    Args:
        suffix: Composite key (e.g. '4S'). Selects the feature folders and the
            S2_CHANNELS/S1_CHANNELS entries whose sum is the model's input channel count.
        encoder_name: segmentation_models_pytorch encoder name, e.g. "se_resnext50_32x4d".
        encoder_weights: Encoder initialisation passed to smp ("imagenet", "advprop" or None).
        decoder_attention_type: UNet++ decoder attention (None or "scse").
        batch_size: DataLoader batch size. Defaults to 8 as in the original code
            (the setup notes list 32).
        num_workers: DataLoader worker processes.
        max_epochs: Number of training epochs.
        lr: Initial learning rate forwarded to SentinelModel.
        wd: Weight decay forwarded to SentinelModel.

    Returns:
        The fitted ``Trainer``. ``trainer.checkpoint_callback.best_model_path`` is the
        best checkpoint by ``valid/loss`` and can be passed to ``train_finetuned_model``.
    """
    # wandb.finish()

    s2_path = f"{root_dir}/train_features_s2_{suffix}"
    s1_path = f"{root_dir}/train_features_s1_{suffix}"
    agb_path = f"{root_dir}/train_agbm"

    train_set = BioMasstersDatasetS2S1(s2_path=s2_path, s1_path=s1_path, agb_path=agb_path, X=X_train,
                                       mean=mean[suffix], std=std[suffix],
                                       mean_agb=mean_agb, std_agb=std_agb,
                                       transform=A.Compose([A.HorizontalFlip(), A.VerticalFlip(),
                                                            A.RandomRotate90(), A.Transpose(), A.ShiftScaleRotate()]))

    val_set = BioMasstersDatasetS2S1(s2_path=s2_path, s1_path=s1_path, agb_path=agb_path, X=X_val,
                                     mean=mean[suffix], std=std[suffix],
                                     mean_agb=mean_agb, std_agb=std_agb, transform=None)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=True)

    model = smp.UnetPlusPlus(encoder_name=encoder_name, encoder_weights=encoder_weights,
                             decoder_attention_type=decoder_attention_type,
                             in_channels=S2_CHANNELS[suffix] + S1_CHANNELS[suffix], classes=1, activation=None)

    s2s1_model = SentinelModel(model, mean_agb=mean_agb, std_agb=std_agb, lr=lr, wd=wd)

    # summary(s2s1_model.cuda(), (S2_CHANNELS[suffix] + S1_CHANNELS[suffix], 256, 256))

    # wandb_logger = WandbLogger(save_dir=f'./models', name=f'{encoder_name}_{suffix}_{decoder_attention_type}',
    #                            project=f'{encoder_name}_{suffix}_{decoder_attention_type}')

    # Keep the 2 best checkpoints by validation loss and by validation RMSE.
    # NOTE: the metric names contain '/', so Lightning writes these files into an extra
    # sub-folder (e.g. 'epoch=3-valid/loss=0.07.ckpt'); see auto_insert_metric_name.
    on_best_valid_loss = ModelCheckpoint(filename="{epoch}-{valid/loss}", mode='min', save_last=True,
                                         monitor='valid/loss', save_top_k=2)
    on_best_valid_rmse = ModelCheckpoint(filename="{epoch}-{valid/rmse}", mode='min', save_last=True,
                                         monitor='valid/rmse', save_top_k=2)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer = Trainer(precision=16, accelerator="gpu", devices=1, max_epochs=max_epochs,
                      # logger=[wandb_logger],
                      callbacks=[on_best_valid_loss, on_best_valid_rmse, lr_monitor])
    trainer.fit(s2s1_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    return trainer

### Fine-tuned model setup
batch_size: 32 (note: the `DataLoader`s in the code below use a batch size of 8)/
epochs: 100/
learning_rate: 0.0005/
weight_decay: 0.0001/
augmentations: None/
scheduler: ReduceLROnPlateau()

In [None]:
def train_finetuned_model(checkpoint_path, suffix, encoder_name, decoder_attention_type,
                          batch_size=8, num_workers=8, max_epochs=100, lr=0.0005, wd=0.0001):
    """Fine-tune (stage 2) a base model from ``checkpoint_path`` with a frozen encoder.

    Uses the notebook-level ``root_dir``, ``X_train``/``X_val`` id lists and the
    ``mean``/``std``/``mean_agb``/``std_agb`` statistics. Neither split is augmented.

    Args:
        checkpoint_path: Lightning checkpoint produced by ``train_base_model`` for the
            same ``suffix``/``encoder_name``/``decoder_attention_type``.
        suffix: Composite key (e.g. '4S'); selects feature folders and the input channel count.
        encoder_name: segmentation_models_pytorch encoder name used in stage 1.
        decoder_attention_type: UNet++ decoder attention used in stage 1 (None or "scse").
        batch_size: DataLoader batch size. Defaults to 8 as in the original code
            (the setup notes list 32).
        num_workers: DataLoader worker processes.
        max_epochs: Number of fine-tuning epochs.
        lr: Initial learning rate forwarded to SentinelModel.
        wd: Weight decay forwarded to SentinelModel.

    Returns:
        The fitted ``Trainer``; ``trainer.checkpoint_callback.best_model_path`` is the
        best checkpoint by ``valid/loss``.
    """
    # wandb.finish()

    s2_path = f"{root_dir}/train_features_s2_{suffix}"
    s1_path = f"{root_dir}/train_features_s1_{suffix}"
    agb_path = f"{root_dir}/train_agbm"

    train_set = BioMasstersDatasetS2S1(s2_path=s2_path, s1_path=s1_path, agb_path=agb_path, X=X_train,
                                       mean=mean[suffix], std=std[suffix],
                                       mean_agb=mean_agb, std_agb=std_agb, transform=None)

    val_set = BioMasstersDatasetS2S1(s2_path=s2_path, s1_path=s1_path, agb_path=agb_path, X=X_val,
                                     mean=mean[suffix], std=std[suffix],
                                     mean_agb=mean_agb, std_agb=std_agb, transform=None)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=True)

    # The architecture must match stage 1, but every weight is replaced by the checkpoint's
    # state dict in load_from_checkpoint. Passing encoder_weights=None therefore skips a
    # pointless download of pre-trained weights (smp defaults to "imagenet").
    model = smp.UnetPlusPlus(encoder_name=encoder_name, encoder_weights=None,
                             decoder_attention_type=decoder_attention_type,
                             in_channels=S2_CHANNELS[suffix] + S1_CHANNELS[suffix], classes=1, activation=None)

    # Freezing before loading is fine assuming freeze_encoder toggles requires_grad
    # (verify in utils): load_state_dict copies values in place and keeps those flags.
    freeze_encoder(model)

    s2s1_model = SentinelModel.load_from_checkpoint(model=model, checkpoint_path=checkpoint_path,
                                                    mean_agb=mean_agb, std_agb=std_agb,
                                                    lr=lr, wd=wd)

    # summary(s2s1_model.cuda(), (S2_CHANNELS[suffix] + S1_CHANNELS[suffix], 256, 256))

    # wandb_logger = WandbLogger(save_dir=f'./models', name=f'{encoder_name}_{suffix}_{decoder_attention_type}',
    #                            project=f'{encoder_name}_{suffix}_{decoder_attention_type}')

    # Keep the 2 best checkpoints by validation loss and by validation RMSE.
    # NOTE: the metric names contain '/', so Lightning writes these files into an extra
    # sub-folder (e.g. 'epoch=3-valid/loss=0.07.ckpt'); see auto_insert_metric_name.
    on_best_valid_loss = ModelCheckpoint(filename="{epoch}-{valid/loss}", mode='min', save_last=True,
                                         monitor='valid/loss', save_top_k=2)
    on_best_valid_rmse = ModelCheckpoint(filename="{epoch}-{valid/rmse}", mode='min', save_last=True,
                                         monitor='valid/rmse', save_top_k=2)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer = Trainer(precision=16, accelerator="gpu", devices=1, max_epochs=max_epochs,
                      # logger=[wandb_logger],
                      callbacks=[on_best_valid_loss, on_best_valid_rmse, lr_monitor])
    trainer.fit(s2s1_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    return trainer

### Train and fine-tune model #1
Composite: 4S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: se_resnext50_32x4d/
Encoder pre-training: imagenet

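The TLS certificate of the site that hosts the pre-trained encoder weights may have expired. If the download fails, the code below works around it by installing an unverified default HTTPS context. Be aware of the security risk: this turns off certificate verification for Python's default HTTPS context for the rest of the session.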
```
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
```

In [None]:
import ssl

# WARNING: this replaces Python's default HTTPS context factory with one that skips
# certificate and hostname verification for the rest of this kernel session. Only run
# it if downloading pre-trained encoder weights fails because of an expired certificate.
# To turn verification back on afterwards, run:
#     ssl._create_default_https_context = ssl.create_default_context
ssl._create_default_https_context = ssl._create_unverified_context
warnings.warn("TLS certificate verification is disabled for Python's default HTTPS context.",
              RuntimeWarning)

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, SE-ResNeXt50 encoder (imagenet), no decoder attention.
    train_base_model(suffix='4S', encoder_name="se_resnext50_32x4d",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext50_32x4d_4S_None/1on9ti36/checkpoints/loss=0.07419885694980621.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="se_resnext50_32x4d", decoder_attention_type=None)

### Train and fine-tune model #2
Composite: 4S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: se_resnext101_32x4d/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, SE-ResNeXt101 encoder (imagenet), no decoder attention.
    train_base_model(suffix='4S', encoder_name="se_resnext101_32x4d",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext101_32x4d_4S_None/39jj4bmx/checkpoints/loss=0.07529886066913605.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="se_resnext101_32x4d", decoder_attention_type=None)

### Train and fine-tune model #3
Composite: 4S/
Decoder: UNet++/
Decoder attention type: scse/
Encoder: se_resnext50_32x4d/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, SE-ResNeXt50 encoder (imagenet), scSE decoder attention.
    train_base_model(suffix='4S', encoder_name="se_resnext50_32x4d",
                     encoder_weights="imagenet", decoder_attention_type="scse")

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext50_32x4d_4S_scse/v0pd76d5/checkpoints/rmse=31.418827056884766.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="se_resnext50_32x4d", decoder_attention_type="scse")

### Train and fine-tune model #4
Composite: 3S/
Decoder: UNet++/
Decoder attention type: scse/
Encoder: se_resnext50_32x4d/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 3S composite, SE-ResNeXt50 encoder (imagenet), scSE decoder attention.
    train_base_model(suffix='3S', encoder_name="se_resnext50_32x4d",
                     encoder_weights="imagenet", decoder_attention_type="scse")

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext50_32x4d_3S_scse/92gj2lnf/checkpoints/rmse=31.44633674621582.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='3S',
                          encoder_name="se_resnext50_32x4d", decoder_attention_type="scse")

### Train and fine-tune model #5
Composite: 4S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: efficientnet-b6/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, EfficientNet-B6 encoder (imagenet), no decoder attention.
    train_base_model(suffix='4S', encoder_name="efficientnet-b6",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/efficientnet-b6_4S_None/2h56bi5o/checkpoints/rmse=31.456979751586914.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="efficientnet-b6", decoder_attention_type=None)

### Train and fine-tune model #6
Composite: 4SI/
Decoder: UNet++/
Decoder attention type: None/
Encoder: efficientnet-b5/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4SI composite, EfficientNet-B5 encoder (imagenet), no decoder attention.
    train_base_model(suffix='4SI', encoder_name="efficientnet-b5",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/efficientnet-b5_4SI_None/2o168bz4/checkpoints/loss=0.07675273716449738.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4SI',
                          encoder_name="efficientnet-b5", decoder_attention_type=None)

### Train and fine-tune model #7
Composite: 4S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: xception/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, Xception encoder (imagenet), no decoder attention.
    train_base_model(suffix='4S', encoder_name="xception",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/xception_4S_None/2vupnzea/checkpoints/loss=0.07764090597629547.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="xception", decoder_attention_type=None)

### Train and fine-tune model #8
Composite: 2SI/
Decoder: UNet++/
Decoder attention type: None/
Encoder: se_resnext50_32x4d/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 2SI composite, SE-ResNeXt50 encoder (imagenet), no decoder attention.
    train_base_model(suffix='2SI', encoder_name="se_resnext50_32x4d",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext50_32x4d_2SI_None/2hd4nj2v/checkpoints/loss=0.07711037248373032.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='2SI',
                          encoder_name="se_resnext50_32x4d", decoder_attention_type=None)

### Train and fine-tune model #9
Composite: 2S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: se_resnext50_32x4d/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 2S composite, SE-ResNeXt50 encoder (imagenet), no decoder attention.
    train_base_model(suffix='2S', encoder_name="se_resnext50_32x4d",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext50_32x4d_2S_None/2xax1i19/checkpoints/rmse=31.860191345214844.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='2S',
                          encoder_name="se_resnext50_32x4d", decoder_attention_type=None)

### Train and fine-tune model #10
Composite: 6S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: timm-efficientnet-b7/
Encoder pre-training: advprop

In [None]:
if __name__ == '__main__':
    # Stage 1: 6S composite, timm EfficientNet-B7 encoder (advprop), no decoder attention.
    train_base_model(suffix='6S', encoder_name="timm-efficientnet-b7",
                     encoder_weights="advprop", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/timm-efficientnet-b7_6S_None/yitdzdeu/checkpoints/loss=0.07400769740343094.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='6S',
                          encoder_name="timm-efficientnet-b7", decoder_attention_type=None)

### Train and fine-tune model #11
Composite: 6S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: timm-efficientnet-b8/
Encoder pre-training: advprop

In [None]:
if __name__ == '__main__':
    # Stage 1: 6S composite, timm EfficientNet-B8 encoder (advprop), no decoder attention.
    train_base_model(suffix='6S', encoder_name="timm-efficientnet-b8",
                     encoder_weights="advprop", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/timm-efficientnet-b8_6S_None/vnyxfdjt/checkpoints/loss=0.07360904663801193.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='6S',
                          encoder_name="timm-efficientnet-b8", decoder_attention_type=None)

### Train and fine-tune model #12
Composite: 6S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: se_resnext50_32x4d/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 6S composite, SE-ResNeXt50 encoder (imagenet), no decoder attention.
    train_base_model(suffix='6S', encoder_name="se_resnext50_32x4d",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/se_resnext50_32x4d_6S_None/qji032p2/checkpoints/loss=0.07499314099550247.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='6S',
                          encoder_name="se_resnext50_32x4d", decoder_attention_type=None)

### Train and fine-tune model #13
Composite: 4S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: timm-efficientnet-b8/
Encoder pre-training: advprop

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, timm EfficientNet-B8 encoder (advprop), no decoder attention.
    train_base_model(suffix='4S', encoder_name="timm-efficientnet-b8",
                     encoder_weights="advprop", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/timm-efficientnet-b8_4S_None/66ucn90m/checkpoints/loss=0.0746900737285614.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="timm-efficientnet-b8", decoder_attention_type=None)

### Train and fine-tune model #14
Composite: 4SI/
Decoder: UNet++/
Decoder attention type: None/
Encoder: efficientnet-b7/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4SI composite, EfficientNet-B7 encoder (imagenet), no decoder attention.
    train_base_model(suffix='4SI', encoder_name="efficientnet-b7",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/efficientnet-b7_4SI_None/fkk9ny5j/checkpoints/loss=0.07656403630971909.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4SI',
                          encoder_name="efficientnet-b7", decoder_attention_type=None)

### Train and fine-tune model #15
Composite: 4S/
Decoder: UNet++/
Decoder attention type: None/
Encoder: senet154/
Encoder pre-training: imagenet

In [None]:
if __name__ == '__main__':
    # Stage 1: 4S composite, SENet-154 encoder (imagenet), no decoder attention.
    train_base_model(suffix='4S', encoder_name="senet154",
                     encoder_weights="imagenet", decoder_attention_type=None)

In [None]:
if __name__ == '__main__':
    # Stage 2: fine-tune from the author's stage-1 checkpoint; replace the path with
    # your own base-model checkpoint if you retrained stage 1.
    checkpoint_path = r'./models/senet154_4S_None/j5fkjqvs/checkpoints/loss=0.07581867277622223.ckpt'
    train_finetuned_model(checkpoint_path=checkpoint_path, suffix='4S',
                          encoder_name="senet154", decoder_attention_type=None)