In [None]:
import torch
import os
import pandas as pd
import segmentation_models_pytorch as smp
from utils import BioMasstersDatasetS2S1, SentinelModel, inference_agb_2m
import rasterio as rio
import warnings
import json
# rasterio raises NotGeoreferencedWarning when a raster carries no georeferencing
# metadata; silence it so per-file warnings do not flood the cell output.
warnings.filterwarnings(action="ignore", category=rio.errors.NotGeoreferencedWarning)

## Inference
* For each *UNet++* model, the predictions of its two best checkpoints (lowest validation loss) were averaged, and that average was used for further ensembling.

In [None]:
# Base folder containing the preprocessed train_*/test_* feature directories, train_agbm,
# the models/ checkpoints and the dd_inf_models/ outputs. Change it if your data lives elsewhere.
root_dir = os.getcwd()

# Number of input channels per feature-set suffix, for the Sentinel-1 and Sentinel-2 stacks.
# model_inference adds the two counts to obtain the UNet++ `in_channels`.
S1_CHANNELS = {
    '2S': 8,
    '2SI': 12,
    '3S': 12,
    '4S': 16,
    '4SI': 24,
    '6S': 24,
}
S2_CHANNELS = {
    '2S': 20,
    '2SI': 38,
    '3S': 30,
    '4S': 40,
    '4SI': 48,
    '6S': 60,
}

Read the precomputed train/validation/test split assignments from a CSV file.

In [None]:
# Codes used in the "dataset" column of the split file.
TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = 0, 1, 2

# Read ids as strings so they are kept verbatim rather than coerced to numbers.
df = pd.read_csv('./data/train_val_split_96_0.csv', dtype={"id": str})

# Arrays of image ids belonging to each split.
X_train, X_val, X_test = (df.loc[df["dataset"] == code, "id"].values
                          for code in (TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT))

print(df["dataset"].value_counts())
print("Total Images: ", len(df))

Read the precomputed mean/std statistics used for standardization (for the input images and for the AGB target) from JSON files.

In [None]:
def _load_json(path):
    """Parse and return the JSON document at `path`, closing the file afterwards."""
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Image statistics, indexed by feature-set suffix in model_inference (mean[suffix], std[suffix]).
mean = _load_json('./data/mean.json')
std = _load_json('./data/std.json')
# Statistics for the AGB target, passed to the datasets and models.
mean_agb = _load_json('./data/mean_agb.json')
std_agb = _load_json('./data/std_agb.json')

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(use_cuda)

In [None]:
def model_inference(suffix, encoder_name, decoder_attention_type, 
                    checkpoint_path_1, checkpoint_path_2, output_val_dir, output_test_dir):
    """Run two-checkpoint UNet++ inference on the validation and test splits.

    A UNet++ network is restored from each of the two checkpoints, and both models
    are handed to `inference_agb_2m`. This runs first on the validation ids
    (X_val) and then on the test ids (X_test). Predictions are written to the
    given directories.

    Args:
        suffix: Feature-set key (e.g. '4S'). It selects the channel counts in
            S1_CHANNELS / S2_CHANNELS, the normalization stats mean[suffix] /
            std[suffix], and the `*_features_s1_{suffix}` / `*_features_s2_{suffix}`
            input folders under `root_dir`.
        encoder_name: segmentation_models_pytorch encoder name.
        decoder_attention_type: UNet++ decoder attention type (None or e.g. "scse").
        checkpoint_path_1: Path to the first checkpoint.
        checkpoint_path_2: Path to the second checkpoint.
        output_val_dir: Directory for validation predictions; created if missing.
        output_test_dir: Directory for test predictions; created if missing.
    """
    in_channels = S1_CHANNELS[suffix] + S2_CHANNELS[suffix]

    def _load_checkpoint(checkpoint_path):
        # Build a FRESH network per checkpoint. Lightning's load_from_checkpoint
        # constructs the module from the kwargs and then loads the state dict into
        # it. Presumably SentinelModel keeps `model` as a submodule, so sharing one
        # instance would let the second checkpoint overwrite the first one's weights.
        network = smp.UnetPlusPlus(encoder_name=encoder_name, in_channels=in_channels,
                                   decoder_attention_type=decoder_attention_type,
                                   classes=1, activation=None)
        return SentinelModel.load_from_checkpoint(model=network, checkpoint_path=checkpoint_path,
                                                  mean_agb=mean_agb, std_agb=std_agb)

    s2s1_model_1 = _load_checkpoint(checkpoint_path_1)
    s2s1_model_2 = _load_checkpoint(checkpoint_path_2)

    def _predict(folder_prefix, ids, agb_path, output_dir):
        # Build the dataset for one split and write both models' predictions to output_dir.
        dataset = BioMasstersDatasetS2S1(s2_path=f"{root_dir}/{folder_prefix}_features_s2_{suffix}",
                                         s1_path=f"{root_dir}/{folder_prefix}_features_s1_{suffix}",
                                         agb_path=agb_path, X=ids, mean=mean[suffix], std=std[suffix],
                                         mean_agb=mean_agb, std_agb=std_agb, transform=None)
        os.makedirs(output_dir, exist_ok=True)
        inference_agb_2m(s2s1_model_1, s2s1_model_2, dataset, dataset, device, mean_agb=mean_agb, std_agb=std_agb,
                         clamp_threshold=None, preds_agbm_dir=output_dir, save_ground_truth=False)

    # Validation images are a held-out subset of the *train* feature folders, so
    # they read from the train_* directories together with the train_agbm targets.
    _predict("train", X_val, f"{root_dir}/train_agbm", output_val_dir)
    # The test split has no targets.
    _predict("test", X_test, None, output_test_dir)


### Runtime
* On average, inference took ~15 min per model (validation and test sets combined).
* The cells below point to the checkpoints of the pre-trained models; replace these paths with your own checkpoints if you train from scratch.

### Model #1 inference

In [None]:
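# The model weights are downloaded from a server whose TLS certificate has expired. If the download fails with a certificate error, uncomment the lines below to disable HTTPS certificate verification (note: this affects every HTTPS request made from this kernel).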
# import ssl
# ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Model #1: UNet++ / se_resnext50_32x4d encoder, '4S' features, no decoder attention (run rxh62fu5).
model_inference(
    suffix='4S',
    encoder_name='se_resnext50_32x4d',
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext50_32x4d_4S_None/rxh62fu5/checkpoints/loss=0.07392562925815582.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext50_32x4d_4S_None/rxh62fu5/checkpoints/loss=0.07393575459718704.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_rxh62fu5_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_rxh62fu5_2m",
)

### Model #2 inference

In [None]:
# Model #2: UNet++ / se_resnext101_32x4d encoder, '4S' features, no decoder attention (run 36m55dbu).
model_inference(
    suffix='4S',
    encoder_name="se_resnext101_32x4d",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext101_32x4d_4S_None/36m55dbu/checkpoints/loss=0.07512035965919495.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext101_32x4d_4S_None/36m55dbu/checkpoints/loss=0.07513585686683655.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_36m55dbu_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_36m55dbu_2m",
)

### Model #3 inference

In [None]:
# Model #3: UNet++ / se_resnext50_32x4d encoder, '4S' features, scse decoder attention (run 16rpp87m).
model_inference(
    suffix='4S',
    encoder_name="se_resnext50_32x4d",
    decoder_attention_type="scse",
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext50_32x4d_4S_scse/16rpp87m/checkpoints/loss=0.07533500343561172.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext50_32x4d_4S_scse/16rpp87m/checkpoints/loss=0.07534071803092957.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_16rpp87m_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_16rpp87m_2m",
)

### Model #4 inference

In [None]:
# Model #4: UNet++ / se_resnext50_32x4d encoder, '3S' features, scse decoder attention (run 3cxdn692).
model_inference(
    suffix='3S',
    encoder_name="se_resnext50_32x4d",
    decoder_attention_type="scse",
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext50_32x4d_3S_scse/3cxdn692/checkpoints/loss=0.07581362128257751.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext50_32x4d_3S_scse/3cxdn692/checkpoints/loss=0.07581485062837601.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_3cxdn692_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_3cxdn692_2m",
)

### Model #5 inference

In [None]:
# Model #5: UNet++ / efficientnet-b6 encoder, '4S' features, no decoder attention (run 2ez2ckbq).
model_inference(
    suffix='4S',
    encoder_name="efficientnet-b6",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/efficientnet-b6_4S_None/2ez2ckbq/checkpoints/loss=0.07591798901557922.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/efficientnet-b6_4S_None/2ez2ckbq/checkpoints/loss=0.07595288753509521.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_2ez2ckbq_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_2ez2ckbq_2m",
)

### Model #6 inference

In [None]:
# Model #6: UNet++ / efficientnet-b5 encoder, '4SI' features, no decoder attention (run 3bfd03ru).
model_inference(
    suffix='4SI',
    encoder_name="efficientnet-b5",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/efficientnet-b5_4SI_None/3bfd03ru/checkpoints/loss=0.07613389194011688.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/efficientnet-b5_4SI_None/3bfd03ru/checkpoints/loss=0.07616450637578964.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_3bfd03ru_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_3bfd03ru_2m",
)

### Model #7 inference

In [None]:
# Model #7: UNet++ / xception encoder, '4S' features, no decoder attention (run 38t5o4ji).
model_inference(
    suffix='4S',
    encoder_name="xception",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/xception_4S_None/38t5o4ji/checkpoints/loss=0.07683292776346207.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/xception_4S_None/38t5o4ji/checkpoints/loss=0.07684794813394547.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_38t5o4ji_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_38t5o4ji_2m",
)

### Model #8 inference

In [None]:
# Model #8: UNet++ / se_resnext50_32x4d encoder, '2SI' features, no decoder attention (run 2gl7l10s).
model_inference(
    suffix='2SI',
    encoder_name="se_resnext50_32x4d",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext50_32x4d_2SI_None/2gl7l10s/checkpoints/loss=0.0770520269870758.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext50_32x4d_2SI_None/2gl7l10s/checkpoints/loss=0.07708850502967834.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_2gl7l10s_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_2gl7l10s_2m",
)

### Model #9 inference

In [None]:
# Model #9: UNet++ / se_resnext50_32x4d encoder, '2S' features, no decoder attention (run 1n6dphmx).
model_inference(
    suffix='2S',
    encoder_name="se_resnext50_32x4d",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext50_32x4d_2S_None/1n6dphmx/checkpoints/loss=0.07800901681184769.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext50_32x4d_2S_None/1n6dphmx/checkpoints/loss=0.07806649059057236.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_1n6dphmx_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_1n6dphmx_2m",
)

### Model #10 inference

In [None]:
# Model #10: UNet++ / timm-efficientnet-b7 encoder, '6S' features, no decoder attention (run k882zmf7).
model_inference(
    suffix='6S',
    encoder_name="timm-efficientnet-b7",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/timm-efficientnet-b7_6S_None/k882zmf7/checkpoints/loss=0.07287859916687012.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/timm-efficientnet-b7_6S_None/k882zmf7/checkpoints/loss=0.07291954010725021.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_k882zmf7_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_k882zmf7_2m",
)

### Model #11 inference

In [None]:
# Model #11: UNet++ / timm-efficientnet-b8 encoder, '6S' features, no decoder attention (run ksnkp86o).
model_inference(
    suffix='6S',
    encoder_name="timm-efficientnet-b8",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/timm-efficientnet-b8_6S_None/ksnkp86o/checkpoints/loss=0.07285305112600327.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/timm-efficientnet-b8_6S_None/ksnkp86o/checkpoints/loss=0.07287508249282837.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_ksnkp86o_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_ksnkp86o_2m",
)

### Model #12 inference

In [None]:
# Model #12: UNet++ / se_resnext50_32x4d encoder, '6S' features, no decoder attention (run dbdy005j).
model_inference(
    suffix='6S',
    encoder_name="se_resnext50_32x4d",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/se_resnext50_32x4d_6S_None/dbdy005j/checkpoints/loss=0.07437055557966232.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/se_resnext50_32x4d_6S_None/dbdy005j/checkpoints/loss=0.07442700117826462.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_dbdy005j_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_dbdy005j_2m",
)

### Model #13 inference

In [None]:
# Model #13: UNet++ / timm-efficientnet-b8 encoder, '4S' features, no decoder attention (run ypve39o8).
model_inference(
    suffix='4S',
    encoder_name="timm-efficientnet-b8",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/timm-efficientnet-b8_4S_None/ypve39o8/checkpoints/loss=0.07440810650587082.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/timm-efficientnet-b8_4S_None/ypve39o8/checkpoints/loss=0.07444397360086441.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_ypve39o8_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_ypve39o8_2m",
)

### Model #14 inference

In [None]:
# Model #14: UNet++ / efficientnet-b7 encoder, '4SI' features, no decoder attention (run d9jgazue).
model_inference(
    suffix='4SI',
    encoder_name="efficientnet-b7",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/efficientnet-b7_4SI_None/d9jgazue/checkpoints/loss=0.07570093870162964.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/efficientnet-b7_4SI_None/d9jgazue/checkpoints/loss=0.07571420073509216.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_d9jgazue_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_d9jgazue_2m",
)

### Model #15 inference

In [None]:
# Model #15: UNet++ / senet154 encoder, '4S' features, no decoder attention (run ep0ar9nv).
model_inference(
    suffix='4S',
    encoder_name="senet154",
    decoder_attention_type=None,
    checkpoint_path_1=os.path.join(root_dir, 'models/senet154_4S_None/ep0ar9nv/checkpoints/loss=0.07565396279096603.ckpt'),
    checkpoint_path_2=os.path.join(root_dir, 'models/senet154_4S_None/ep0ar9nv/checkpoints/loss=0.07569203525781631.ckpt'),
    output_val_dir=f"{root_dir}/dd_inf_models/val_ep0ar9nv_2m",
    output_test_dir=f"{root_dir}/dd_inf_models/test_ep0ar9nv_2m",
)