This notebook uses [this blog post](https://medium.com/cloud-to-street/jumpstart-your-machine-learning-satellite-competition-submission-2443b40d0a5a) and [this video](https://www.youtube.com/watch?v=SsnWM1xWDu4) as references.

## Imports

In [None]:
from tqdm.notebook import tqdm
from glob import glob

import os
import sys
import cv2
import subprocess
import numpy as np
import pandas as pd

from torch.utils.data import Dataset, DataLoader
import torch
import ttach as tta
import torch.nn as nn
import segmentation_models_pytorch as smp

import warnings
warnings.filterwarnings("ignore")

## Set up paths

In [None]:
# Root directory of the dataset; the test split lives in `test_internal`.
dset_root = '/dli/task/'
test_dir = os.path.join(dset_root, 'test_internal')

# Every immediate sub-directory of the test split counts as one region.
region_dirs = glob(test_dir+'/*/')
n_test_regions = len(region_dirs)
print('Number of test temporal-regions: {}'.format(n_test_regions))

## Helper functions

In [None]:
def get_test_id(path):
    """Return the first two underscore-separated fields of `path`, rejoined with "_"."""
    fields = path.split("_")
    first, second = fields[0], fields[1]
    return first + "_" + second

def make_im_name(id, suffix):
    """Return `<stem>_<suffix>.png`, where the stem is `id` up to its first dot."""
    stem, _, _ = id.partition(".")
    return "{}_{}.png".format(stem, suffix)

def s1_to_rgb(vv_image, vh_image):
    """Combine VV and VH arrays into a 3-channel image.

    The channels, stacked along axis 2, are ``(vv, vh, 1 - ratio)`` where
    ``ratio = vh / vv`` clipped to ``[0, 1]``. Pixels where the ratio is
    undefined (``0 / 0``) get ``ratio = 0``; a division of a non-zero value
    by zero is clipped like any other out-of-range ratio.

    Args:
        vv_image: 2-D array of VV values.
        vh_image: 2-D array of VH values, same shape as ``vv_image``.

    Returns:
        Array with the two inputs and the inverted ratio stacked on axis 2.
    """
    # Division by zero is expected here and handled right below, so silence
    # the corresponding floating-point warnings for this operation only.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = vh_image / vv_image

    # `nan` must be passed by keyword: the second positional parameter of
    # `np.nan_to_num` is `copy`, not the NaN replacement value.
    ratio_image = np.clip(np.nan_to_num(ratio, nan=0.0), 0, 1)
    return np.stack((vv_image, vh_image, 1 - ratio_image), axis=2)

## Create dataframe

As per [the competition website](https://nasa-impact.github.io/etci2021/), the submission file needs to be generated following a particular sequence. 

In [None]:
# Quietly download the CSV that is read below as the test file sequence and save it as test_sentinel.csv.
!wget -q https://git.io/JsRTE -O test_sentinel.csv

In [None]:
# Tile names in the order given by the downloaded CSV, flattened to a plain list.
test_file_sequence = (
    pd.read_csv("test_sentinel.csv", header=None)
    .values
    .squeeze()
    .tolist()
)


def _tile_paths(band):
    """Return the path of every listed tile for one band ("vv" or "vh")."""
    return [
        os.path.join(test_dir, get_test_id(tile), "tiles", band, make_im_name(tile, band))
        for tile in test_file_sequence
    ]


all_test_vv = _tile_paths("vv")
all_test_vh = _tile_paths("vh")

# One row per tile, keeping the order of `test_file_sequence`.
paths = {
    'vv_image_path': all_test_vv,
    'vh_image_path': all_test_vh,
}

test_df = pd.DataFrame(paths)
print(test_df.shape)

## Dataset

In [None]:
class ETCIDataset(Dataset):
    """Dataset of VV/VH tile pairs listed in a dataframe.

    Every row must provide ``vv_image_path`` and ``vh_image_path``. For any
    split other than ``'test'`` it must also provide ``flood_label_path``.

    Args:
        dataframe: table with one row per tile.
        split: ``'test'`` returns the image and its source paths only; any
            other value additionally loads the flood mask.
        transform: optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries. It is only applied when
            ``split`` is not ``'test'``.
    """

    def __init__(self, dataframe, split, transform=None):
        self.split = split
        self.dataset = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying dataframe."""
        return self.dataset.shape[0]

    @staticmethod
    def _read_grayscale(path):
        """Read `path` as a single-channel image scaled by 1/255.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on the division.
            raise FileNotFoundError(f"Could not read image: {path}")
        return image / 255.0

    def __getitem__(self, index):
        """Return the example at `index` as a dict.

        The dict always has ``'image'`` (float32, channel-first). For the
        test split it also carries ``'vv_image_path'`` and
        ``'vh_image_path'``; otherwise it carries ``'mask'`` (int64).
        """
        example = {}

        df_row = self.dataset.iloc[index]

        # load vv and vh images
        vv_image = self._read_grayscale(df_row['vv_image_path'])
        vh_image = self._read_grayscale(df_row['vh_image_path'])

        # convert vv and vh images to rgb
        rgb_image = s1_to_rgb(vv_image, vh_image)

        if self.split == 'test':
            # no flood mask should be available
            example['image'] = rgb_image.transpose((2, 0, 1)).astype('float32')
            example['vv_image_path'] = df_row['vv_image_path']
            example['vh_image_path'] = df_row['vh_image_path']
        else:
            # load ground truth flood mask
            flood_mask = self._read_grayscale(df_row['flood_label_path'])

            # compute transformations
            if self.transform:
                augmented = self.transform(image=rgb_image, mask=flood_mask)
                rgb_image = augmented['image']
                flood_mask = augmented['mask']

            example['image'] = rgb_image.transpose((2, 0, 1)).astype('float32')
            example['mask'] = flood_mask.astype('int64')

        return example

In [None]:
test_dataset = ETCIDataset(test_df, split='test', transform=None)

# 96 samples per visible GPU. On a CPU-only machine device_count() is 0,
# which would give batch_size=0 and make DataLoader raise, so use one
# multiple of 96 in that case.
batch_size = 96 * max(1, torch.cuda.device_count())

# Keep the dataframe order (shuffle=False) so predictions line up with the
# rows of `test_df`. os.cpu_count() may return None, hence the `or 0`.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=os.cpu_count() or 0,
                         pin_memory=torch.cuda.is_available())

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Ensembling and pseudo labeling

We start by defining the model classes and the paths to their pre-trained weights.

In [None]:
# Both architectures are built with the same settings: a mobilenet_v2
# encoder without encoder weights requested here, 3 input channels and
# 2 output classes.
_shared_model_kwargs = dict(
    encoder_name="mobilenet_v2",
    encoder_weights=None,
    in_channels=3,
    classes=2,
)

unet_mobilenet = smp.Unet(**_shared_model_kwargs)
upp_mobilenet = smp.UnetPlusPlus(**_shared_model_kwargs)

# Parallel lists: model_paths[i] is the checkpoint for model_defs[i].
model_defs = [unet_mobilenet, upp_mobilenet]
model_paths = [
    "Best_IoU/unet_mobilenet_v2_0.pth",
    "Best_IoU/upp_mobilenetv2_0.pth",
]

**Note**: After a round of pseudo-labeling, one should also incorporate the latest fine-tuned model (obtained by running the `src/train_pseudo_label.py` script) in the ensemble for better results. 

In [None]:
def get_predictions_single(model_def, weights, dir_path, conf_thres=0.95, pixel_thres=0.9):
    """Write pseudo-label masks for the test tiles the ensemble is confident about.

    Each model is loaded from its checkpoint, wrapped in D4 test-time
    augmentation, and run over the module-level ``test_loader`` on the
    module-level ``device``. The raw outputs of all models are averaged, a
    softmax over the class axis gives per-pixel confidences, and a tile is
    kept when more than ``pixel_thres`` of its pixels have a confidence
    above ``conf_thres``. For every kept tile the predicted mask is saved
    as a PNG (0 for class 0, 255 otherwise) inside ``dir_path``.

    Args:
        model_def: a model or a sequence of models forming the ensemble.
        weights: a checkpoint path or a sequence of paths, one per model.
        dir_path: directory that receives the mask PNGs; created if missing.
        conf_thres: minimum per-pixel confidence for a pixel to count.
        pixel_thres: fraction of pixels that must exceed ``conf_thres``.

    Returns:
        Three parallel lists ``(vv_s, vh_s, masks)`` with the VV path, the
        VH path and the written mask path of every kept tile.

    Raises:
        ValueError: if the number of models and checkpoints differ.
    """
    # Accept a single model / checkpoint as well as sequences of them.
    model_list = [model_def] if isinstance(model_def, nn.Module) else list(model_def)
    weight_list = [weights] if isinstance(weights, (str, os.PathLike)) else list(weights)
    if len(model_list) != len(weight_list):
        raise ValueError(
            f"Got {len(model_list)} model(s) but {len(weight_list)} weight file(s)"
        )

    models = []
    # Iterate over the models that were passed in, not over a global list.
    for net, weight in zip(model_list, weight_list):
        # map_location lets checkpoints saved on GPU load on a CPU-only host.
        net.load_state_dict(torch.load(weight, map_location=device))
        model = tta.SegmentationTTAWrapper(net, tta.aliases.d4_transform(), merge_mode="mean")
        model.to(device)
        model.eval()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        models.append(model)

    os.makedirs(dir_path, exist_ok=True)

    vv_s = []
    vh_s = []
    masks = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            image = batch['image'].to(device)

            # Mean of the raw model outputs over the ensemble.
            logits = torch.stack([model(image) for model in models]).mean(dim=0)

            # Softmax over the class axis; `max` yields both the per-pixel
            # confidence and the predicted class index in one pass.
            confidence, labels = torch.softmax(logits, dim=1).max(dim=1)
            confidence = confidence.cpu().numpy()
            labels = labels.cpu().numpy()

            # Do more than `pixel_thres` of a tile's pixels exceed `conf_thres`?
            # The pixel count comes from the prediction itself, so tiles do
            # not have to be 256x256.
            n_pixels = confidence.shape[1] * confidence.shape[2]
            keep = np.sum(confidence > conf_thres, axis=(1, 2)) > pixel_thres * n_pixels

            for idx, is_confident in enumerate(keep):
                if not is_confident:
                    continue

                vv_path = batch['vv_image_path'][idx]
                vv_s.append(vv_path)
                vh_s.append(batch['vh_image_path'][idx])

                # PNG stores integer samples, so build the 0/255 mask as
                # uint8 explicitly instead of handing OpenCV a float array.
                mask = np.where(labels[idx] > 0, 255, 0).astype("uint8")

                # Mask file name: the VV file name without its last
                # underscore-separated field, with a .png extension.
                pseudo_name = "_".join(vv_path.split("/")[-1].split("_")[:-1]) + ".png"
                pseudo_path = os.path.join(dir_path, pseudo_name)
                masks.append(pseudo_path)
                cv2.imwrite(pseudo_path, mask)

    return vv_s, vh_s, masks

In [None]:
# Run the models over the test set; masks are written under "pseudo_labels".
vv_s, vh_s, masks = get_predictions_single(model_defs, model_paths, "pseudo_labels")

# The three lists must be parallel: exactly one distinct length among them.
assert len({len(vv_s), len(vh_s), len(masks)}) == 1

paths = {}
paths['vv_image_path'] = vv_s
paths['vh_image_path'] = vh_s
paths['flood_label_path'] = masks

pseudo_df = pd.DataFrame(paths)
print(pseudo_df.shape)

In [None]:
# Save the pseudo-label table to disk without the index column.
pseudo_df.to_csv(path_or_buf="pseudo_df.csv", index=False)

This dataframe is then used to retrain any of the models used in the ensemble here (refer to `src/train_pseudo_label.py`). 