## Import necessary modules

In [None]:
import math
import os
import shutil
import time
from os import listdir
from os.path import isfile, join, exists

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Accuracy
from torchmetrics import ConfusionMatrix
from torchmetrics import F1Score
from torchmetrics import JaccardIndex
from torchmetrics import Precision
from torchvision import transforms

## Utilities

In [None]:
def change_labels(y):
    '''Map the raw label encoding onto binary class indices.

    Every element equal to -1 becomes 0; all other values (including the
    existing 0 and 1 classes) are kept as they are. Singleton dimensions are
    then removed with ``squeeze``.

    Parameters
    ----------
    y : torch.Tensor
        Label tensor.

    Returns
    -------
    torch.Tensor
        Tensor of the same dtype as ``y`` with -1 replaced by 0 and
        singleton axes squeezed away.
    '''
    remapped = y.masked_fill(y == -1, 0)
    return remapped.squeeze()


def open_raster(path):
    '''Read every band of a raster file into one numpy array.

    Parameters
    ----------
    path : str
        Path of the raster file to read.

    Returns
    -------
    numpy.ndarray
        Array of shape (bands, height, width); rasterio band ``b`` (1-based)
        is stored at index ``b - 1``.

    Raises
    ------
    ValueError
        If the raster has no bands (``np.stack`` of an empty list).
    Errors raised by ``rasterio.open`` (e.g. for a missing file) propagate
    to the caller; ``rasterio.open`` does not signal failure by returning None.
    '''
    # The context manager closes the dataset even if a band read fails.
    with rasterio.open(path) as raster:
        bands = [raster.read(b) for b in range(1, raster.count + 1)]
    return np.stack(bands)

def save_to_raster(x, src_data_path,  path):
    '''Save a numpy array (x) as a GeoTIFF at ``path``, copying the CRS and
    geotransform of the raster at ``src_data_path``.

    Parameters
    ----------
    x : numpy.ndarray
        Pixel data. Singleton dimensions are squeezed away; what remains must
        be (bands, height, width) or a single band (height, width).
    src_data_path : str
        Raster whose ``crs`` and ``transform`` are copied to the output.
    path : str
        Destination file, written with the GTiff driver.

    Raises
    ------
    ValueError
        If ``x`` is not 2-D or 3-D after squeezing.
    '''
    x = np.squeeze(x)
    if x.ndim == 2:
        # A single-band array loses its band axis in the squeeze above;
        # restore it so rasterio receives (bands, height, width).
        x = x[np.newaxis, ...]
    if x.ndim != 3:
        raise ValueError(
            'Expected a (bands, height, width) array after squeezing, got shape {}'.format(x.shape))

    # Read only the georeferencing and close the source immediately.
    with rasterio.open(src_data_path) as src_data:
        crs = src_data.crs
        transform = src_data.transform

    with rasterio.open(
        path,
        'w',
        driver='GTiff',
        height=x.shape[1],
        width=x.shape[2],
        count=x.shape[0],
        dtype=x.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(x)
    
# Standardize every band, then squash the values into (0, 1) with a sigmoid.
def normalize(image):
    '''Standardize each band of ``image`` and map the result into (0, 1).

    For every index along axis 0 the mean and standard deviation are
    computed; a standard deviation of exactly 0 is replaced by 0.1 so the
    division is defined. The tensor is standardized with
    ``transforms.Normalize`` and then passed through a sigmoid.

    Parameters
    ----------
    image : numpy.ndarray
        Array whose first axis indexes bands — presumably (bands, H, W);
        TODO confirm for all callers.

    Returns
    -------
    torch.Tensor
        float32 tensor with the same shape as ``image`` and values in (0, 1).
    '''
    n_bands = image.shape[0]
    band_means = [image[b].mean() for b in range(n_bands)]
    band_stds = []
    for b in range(n_bands):
        s = image[b].std()
        band_stds.append(s + 0.1 if s == 0 else s)

    stat_shape = (n_bands, 1, 1)
    mean_t = torch.from_numpy(np.array(band_means).reshape(stat_shape))
    std_t = torch.from_numpy(np.array(band_stds).reshape(stat_shape))
    standardize = transforms.Normalize(mean_t, std_t)

    standardized = standardize(torch.from_numpy(image).float())
    with torch.no_grad():
        return torch.sigmoid(standardized)

# Apply the image or label preprocessing selected by ``mode``.
def apply_transforms(image, mode):
    '''Preprocess an image or a label according to ``mode``.

    Parameters
    ----------
    image : numpy.ndarray or torch.Tensor
        Raw raster data (``mode='image'``) or a label tensor (``mode='label'``).
    mode : str
        ``'image'`` applies ``normalize``; ``'label'`` applies ``change_labels``.

    Returns
    -------
    torch.Tensor
        The transformed data.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'image'`` nor ``'label'``.
    '''
    if mode == 'image':
        return normalize(image)
    if mode == 'label':
        return change_labels(image)
    raise ValueError("mode must be 'image' or 'label', got {!r}".format(mode))

def calc_entropy(prediction, class_dim=-3):
    '''Return the mean per-pixel Shannon entropy (in bits) of ``prediction``.

    For every pixel the entropy
            H(X) = -Σ P(xi) log2(P(xi))
    is summed over the class axis, and the result is averaged over all
    remaining axes (batch and pixels).

    Parameters
    ----------
    prediction : torch.Tensor
        Per-class probabilities. The class axis is ``class_dim``; the default
        -3 matches both (classes, H, W) and (batch, classes, H, W) layouts.
        NOTE(review): if the values come from independent per-channel sigmoids
        rather than a softmax, they need not sum to 1 per pixel — confirm.
    class_dim : int, optional
        Axis holding the classes (default -3).

    Returns
    -------
    float
        Non-negative mean entropy; larger values mean a less certain prediction.
    '''
    # xlogy(p, p) is defined as 0 where p == 0, so fully confident pixels do
    # not turn the whole result into NaN (0 * log2(0)).
    plogp = torch.xlogy(prediction, prediction) / math.log(2)
    entropy = -torch.sum(plogp, dim=class_dim)
    return torch.mean(entropy).item()

def maxelements(df, N):
    '''Return the N rows of ``df`` with the largest 'Entropy' values.

    Rows are ordered from largest to smallest entropy; ties keep their
    original order. If ``df`` has fewer than N rows, all rows are returned.
    ``df`` itself is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a numeric 'Entropy' column.
    N : int
        Number of rows to select.

    Returns
    -------
    pandas.DataFrame
        The selected rows with a fresh 0..k-1 index.
    '''
    return df.nlargest(N, 'Entropy', keep='first').reset_index(drop=True)

def extract_highest_entropy(unlabeled_df, path_unlabeled, N):
    '''Score unlabeled rasters with ``calc_entropy`` and return the top N.

    Every file in ``path_unlabeled`` whose name appears exactly in
    ``unlabeled_df['Filenames']`` is read, preprocessed with
    ``apply_transforms(..., mode='image')`` and passed through the global
    ``model`` on the global ``device``.

    Parameters
    ----------
    unlabeled_df : pandas.DataFrame
        Pool with a 'Filenames' column.
    path_unlabeled : str
        Directory containing the unlabeled rasters.
    N : int
        Number of files to return.

    Returns
    -------
    pandas.DataFrame
        ``maxelements`` applied to a frame of 'Filenames' and 'Entropy'.
    '''
    # Exact name lookup; str.contains would treat each name as a regex and
    # also match it as a substring of other names.
    wanted = set(unlabeled_df['Filenames'])
    files = len(unlabeled_df)
    entropy_ls = []
    filenames_ls = []
    i = 1

    # Inference: BatchNorm must use running statistics and not update them.
    model.eval()
    for f in listdir(path_unlabeled):
        if f not in wanted:
            continue
        print('Calculating entropy...Progress {}/{}'.format(i, files), end="\r", flush=True)
        with torch.no_grad():
            X = open_raster(join(path_unlabeled, f))
            X = X.astype('float')
            X = apply_transforms(X, mode = 'image')
            X = X.to(device)
            X = X[None, :, :]  # add a leading batch axis
            pred = model(X)
            entropy_ls.append(calc_entropy(pred))
            filenames_ls.append(f)
            i += 1

    df = pd.DataFrame((list(zip(filenames_ls, entropy_ls))), columns =['Filenames', 'Entropy'])
    max_entropy = maxelements(df, N)
    print('')

    return max_entropy

def _write_label_raster(labels, src_path, save_path):
    '''Write ``labels`` (bands, H, W) as a GeoTIFF georeferenced like ``src_path``.'''
    with rasterio.open(src_path) as src:
        crs, transform = src.crs, src.transform
    with rasterio.open(
        save_path,
        'w',
        driver='GTiff',
        height=labels.shape[1],
        width=labels.shape[2],
        count=labels.shape[0],
        dtype=labels.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(labels)


def create_labels(max_entropy, model, path_unlabeled, new_labels_dir):
    '''Write model-predicted label rasters for the selected files.

    For every file in ``path_unlabeled`` whose name appears exactly in
    ``max_entropy['Filenames']``, ``model`` predicts a class per pixel and the
    result is saved under the same file name in ``new_labels_dir``.

    Parameters
    ----------
    max_entropy : pandas.DataFrame
        Any frame with a 'Filenames' column selecting the files to label.
    model : torch.nn.Module
        Network used for the predictions (run on the global ``device``).
    path_unlabeled : str
        Directory containing the source rasters.
    new_labels_dir : str
        Output directory for the label rasters.
    '''
    print('Creating new labels...')
    # Exact name lookup (str.contains would do a regex substring match).
    wanted = set(max_entropy['Filenames'])
    files = len(max_entropy)
    i = 1
    model.eval()  # inference: no BatchNorm batch statistics / updates
    for f in listdir(path_unlabeled):
        if f not in wanted:
            continue
        print('File: {}/{}'.format(i, files), end = '\r', flush = True)
        with torch.no_grad():
            X = open_raster(join(path_unlabeled, f))
            X = X.astype('float')
            X = apply_transforms(X, mode = 'image')
            X = X.to(device)
            X = X[None, :, :]  # add a leading batch axis
            pred = model(X)
            # One class index per pixel, using the same argmax(dim=1) rule as
            # train_model/eval_model. train_model compares argmax(pred, 1)
            # element-wise with the label tensor, so labels must be a single
            # band rather than one band per output channel.
            labels = torch.argmax(pred, dim=1).to(torch.uint8).cpu().numpy()
        _write_label_raster(labels, join(path_unlabeled, f), join(new_labels_dir, f))
        i = i + 1
    print('')
                
def delete_dir(dirpath):
    '''Remove every entry inside ``dirpath`` while keeping the directory itself.

    Sub-directories are removed recursively. Files and symbolic links are
    unlinked; a link to a directory is removed without touching its target.

    Parameters
    ----------
    dirpath : str
        Directory to empty.
    '''
    print('Deleting old labels...')
    for filename in os.listdir(dirpath):
        filepath = os.path.join(dirpath, filename)
        if os.path.isdir(filepath) and not os.path.islink(filepath):
            shutil.rmtree(filepath)
        else:
            os.remove(filepath)

def visualise_training(H, A, save_path, i):
    '''Plot per-epoch loss and accuracy curves and save the figure.

    Parameters
    ----------
    H : dict
        Loss history with 'train_loss' and 'valid_loss' lists (as returned
        by ``train_model``).
    A : dict
        Accuracy history with 'train_accuracy' and 'valid_accuracy' lists.
    save_path : str
        Root output directory; the figure is written to
        ``<save_path>/training/training<i + 1>.png`` (the sub-directory is
        created if missing).
    i : int
        Cycle index; the file number is ``i + 1``.
    '''
    f, axarr = plt.subplots(2,1)
    f.set_figheight(10)
    f.set_figwidth(10)
    axarr[0].set_title("Loss")
    axarr[1].set_title("Acurracy")

    axarr[0].plot(H["train_loss"], label="train loss")
    axarr[0].plot(H["valid_loss"], label="valid loss")
    axarr[1].plot(A["train_accuracy"], label="training accuracy")
    axarr[1].plot(A["valid_accuracy"], label="validation accuracy")
    # Without a legend the curve labels above are never shown.
    for ax in axarr:
        ax.set_xlabel("Epoch")
        ax.legend()

    out_file = join(save_path, "training/training{}.png".format(i + 1))
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    f.savefig(out_file)

def visualise_evaluation(pred_ls, y_ls, max_samples=15):
    '''Show predictions next to their labels, one pair per row.

    Parameters
    ----------
    pred_ls, y_ls : sequence of torch.Tensor
        Predicted masks and labels in matching order (as collected by
        ``eval_model``).
    max_samples : int, optional
        Maximum number of rows to draw (default 15).

    If either list is empty, nothing is drawn and no figure is created.
    '''
    n_rows = min(max_samples, len(pred_ls), len(y_ls))
    if n_rows <= 0:
        return
    # squeeze=False keeps axs 2-D even for a single row.
    fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, 10 * n_rows / 3),
                            squeeze=False)

    for row in range(n_rows):
        axs[row, 0].imshow(pred_ls[row].to('cpu').squeeze(), cmap = 'gray')
        axs[row, 1].imshow(y_ls[row].to('cpu').squeeze(), cmap = 'gray')

    fig.tight_layout(pad=0)
    
def save_model(save_path, i):
    '''Save the global ``model`` and ``optimizer`` state dicts for cycle ``i``.

    Files are written to ``<save_path>/models/model<i + 1>.pt`` and
    ``<save_path>/optimizers/optimizer<i + 1>.pt``; missing sub-directories
    are created.
    '''
    model_file = join(save_path, 'models/model{}.pt'.format(i+1))
    optimizer_file = join(save_path, "optimizers/optimizer{}.pt".format(i+1))
    for out_file in (model_file, optimizer_file):
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
    torch.save(model.state_dict(), model_file)
    torch.save(optimizer.state_dict(), optimizer_file)

def update_pools(labeled_pool, unlabeled_pool, max_entropy):
    '''Move the selected files from the unlabeled pool to the labeled pool.

    Parameters
    ----------
    labeled_pool, unlabeled_pool : pandas.DataFrame
        Pools with a 'Filenames' column.
    max_entropy : pandas.DataFrame
        Selected rows; its 'Entropy' column is dropped before appending.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(labeled_pool, unlabeled_pool)``: the labeled pool with the
        selected rows appended (index renumbered) and the unlabeled pool
        without them (original index kept). The inputs are not modified.
    '''
    print('Updating pools...')
    selected = max_entropy.drop(columns=['Entropy'])
    grown = pd.concat([labeled_pool, selected], ignore_index=True)
    is_selected = unlabeled_pool['Filenames'].isin(selected['Filenames'])
    shrunk = unlabeled_pool.loc[~is_selected]
    print('Updating pools...Done', end = "\r", flush = True)
    return grown, shrunk

## Dataset class

In [None]:
class Labeled(TensorDataset):
    '''Dataset of image/label raster pairs listed in a CSV file.

    The CSV must contain a 'Filenames' column; each name is looked up in both
    ``images_path`` and ``labels_path``. ``TensorDataset.__init__`` is not
    called: this class overrides ``__len__``/``__getitem__`` and does not use
    the ``tensors`` attribute.
    '''

    def __init__(self, csv_path, images_path, labels_path, preprocess = True):
        self.csv_path = csv_path
        self.images_path = images_path
        self.labels_path = labels_path
        self.preprocess = preprocess
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        '''Number of rows in the CSV.'''
        return self.df.shape[0]

    def __getitem__(self, i):
        '''Return ``(image, label)`` for row ``i`` of the CSV.

        Without preprocessing, ``image`` is a float64 numpy array from
        ``open_raster`` and ``label`` an int64 tensor with an extra leading
        axis. With preprocessing both go through ``apply_transforms``.
        '''
        name = self.df['Filenames'][i]
        raw_image = open_raster("/".join((self.images_path, name)))
        raw_label = open_raster("/".join((self.labels_path, name)))
        image = raw_image.astype('float')
        label = torch.from_numpy(raw_label)[None, :, :].long()

        if not self.preprocess:
            return (image, label)
        return (apply_transforms(image, mode = 'image'),
                apply_transforms(label, mode = 'label'))

In [None]:
class Unlabeled(TensorDataset):
    '''Training dataset over the current labeled pool.

    Each name in ``labeled_pool['Filenames']`` is first looked for in
    ``labeled_dir`` (label from ``labels_dir``); if the image is not there it
    is read from ``unlabeled_dir`` with its label from ``new_labels_dir``.
    ``TensorDataset.__init__`` is not called; this class overrides
    ``__len__``/``__getitem__``.
    '''

    def __init__(self, labeled_pool, labeled_dir, labels_dir, unlabeled_dir,new_labels_dir, model):
        self.labeled_pool = labeled_pool
        self.labeled_dir = labeled_dir
        self.labels_dir = labels_dir
        self.unlabeled_dir = unlabeled_dir
        self.new_labels_dir = new_labels_dir
        # Stored for callers; no method of this class reads it.
        self.model = model

    def __len__(self):
        '''Number of files in the labeled pool.'''
        return self.labeled_pool.shape[0]

    def __getitem__(self, i):
        '''Return the preprocessed ``(image, label)`` pair for pool row ``i``.'''
        name = self.labeled_pool['Filenames'][i]

        if os.path.exists("/".join((self.labeled_dir, name))):
            image_dir, label_dir = self.labeled_dir, self.labels_dir
        else:
            image_dir, label_dir = self.unlabeled_dir, self.new_labels_dir

        raw_image = open_raster("/".join((image_dir, name)))
        raw_label = open_raster("/".join((label_dir, name)))
        image = raw_image.astype('float')
        label = torch.from_numpy(raw_label)[None, :, :].long()

        return (apply_transforms(image, mode = 'image'),
                apply_transforms(label, mode = 'label'))

## Model class

In [None]:
class conv(nn.Module):
    '''3x3 convolution (stride 1, padding 1) -> BatchNorm -> in-place ReLU.

    Height and width are preserved; channels go from ``in_channels`` to
    ``out_channels``.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The attribute name and layer order define the state_dict keys
        # (conv.0.*, conv.1.*) used by saved checkpoints.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        block = self.conv
        return block(x)

class up_conv(nn.Module):
    '''2x2 transposed convolution with stride 2, followed by an in-place ReLU.

    Doubles height and width; channels go from ``in_channels`` to
    ``out_channels``.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # Keys: up_conv.0.* (matches existing checkpoints).
        self.up_conv = nn.Sequential(upsample, nn.ReLU(inplace=True))

    def forward(self, x):
        block = self.up_conv
        return block(x)

class double_conv(nn.Module):
    '''Two stacked (3x3 conv -> BatchNorm -> in-place ReLU) stages.

    Height and width are preserved; the first convolution maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    '''

    def __init__(self,in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Keys: double_conv.{0,1,3,4}.* (matches existing checkpoints).
        self.double_conv = nn.Sequential(*stages)

    def forward(self,x):
        block = self.double_conv
        return block(x)

In [None]:
class UNet(nn.Module):
    '''U-Net with four down-sampling stages and a per-channel sigmoid head.

    Parameters
    ----------
    in_channels : int, optional
        Number of input bands (default 13).
    out_channels : int, optional
        Number of output channels (default 2).

    Input height and width must be divisible by 16. The encoder halves them
    four times with 2x2 max-pooling, and the decoder concatenates each
    up-sampled map with the matching encoder map.

    NOTE(review): the head ends in ``nn.Sigmoid`` while the training code
    feeds this output to ``nn.CrossEntropyLoss``, which expects unnormalized
    logits. Left unchanged here — confirm intent.
    '''

    def __init__(self, in_channels=13, out_channels=2):
        super().__init__()
        # Encoder (construction order fixes the seeded initial weights).
        self.conv1 = double_conv(in_channels, 64)
        self.conv2 = double_conv(64, 128)
        self.conv3 = double_conv(128, 256)
        self.conv4 = double_conv(256, 512)
        self.conv5 = double_conv(512, 1024)
        # Decoder: each up_conv doubles H x W; the following conv takes the
        # concatenation with the skip map, hence twice the output channels.
        self.up_conv6 = up_conv(1024, 512)
        self.conv6 = conv(1024, 512)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(512, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(256, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(128, 64)
        self.conv10 = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        block1 = self.conv1(x)
        x = F.max_pool2d(block1, kernel_size=2)
        block2 = self.conv2(x)
        x = F.max_pool2d(block2,kernel_size=2)
        block3 = self.conv3(x)
        x = F.max_pool2d(block3,kernel_size=2)
        block4 = self.conv4(x)
        x = F.max_pool2d(block4,kernel_size=2)
        x = self.conv5(x)
        # Decoder: upsample, concatenate the skip map along channels, fuse.
        x = self.up_conv6(x)
        x = torch.cat([x, block4], dim=1)
        x = self.conv6(x)
        x = self.up_conv7(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv7(x)
        x = self.up_conv8(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv8(x)
        x = self.up_conv9(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv9(x)
        x = self.conv10(x)
        return x

# Training and evaluation functions

In [None]:
def train_model(train_loader, valid_loader):
    '''Train the global ``model`` for ``epochs`` epochs and track loss/accuracy.

    Uses the globals ``model``, ``criterion``, ``optimizer``, ``epochs`` and
    ``device``. After each training epoch the model is evaluated on
    ``valid_loader`` in eval mode without gradient tracking. The model is
    left in training mode on return.

    Returns
    -------
    tuple of dict
        ``(H, A)``. H holds 'train_loss'/'valid_loss' lists and A holds
        'train_accuracy'/'valid_accuracy' lists, one batch-averaged entry
        per epoch.
    '''
    print('\033[1mTraining model...\033[0m')
    print('\033[1m------------------------\033[0m')
    # NOTE(review): recent torchmetrics releases require a ``task`` argument
    # for Accuracy — confirm against the installed version.
    accuracy = Accuracy().to(device)

    train_steps = len(train_loader)
    valid_steps = len(valid_loader)

    H = {"train_loss": [], "valid_loss": []}
    A = {"train_accuracy": [], "valid_accuracy": []}

    for i in range(epochs):
        totalTrainLoss = 0
        totalValidLoss = 0
        train_accuracy = 0
        valid_accuracy = 0

        # The validation pass switches to eval mode, so re-enter training
        # mode at the start of every epoch.
        model.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)  # forward pass
            loss = criterion(pred, y)
            pred = torch.argmax(pred, dim=1)
            accur = accuracy(pred.view(-1), y.view(-1)).item()
            train_accuracy += float(accur)
            totalTrainLoss += float(loss)
            optimizer.zero_grad()  # do not accumulate gradients across batches
            loss.backward()
            optimizer.step()

        # Eval mode: BatchNorm normalizes with its running statistics and does
        # not update them from validation batches.
        model.eval()
        with torch.no_grad():
            for X, y in valid_loader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                loss = criterion(pred, y)

                pred = torch.argmax(pred, dim=1)
                accur = accuracy(pred.view(-1), y.view(-1)).item()
                valid_accuracy += float(accur)
                totalValidLoss += float(loss)

        avgTrainLoss = totalTrainLoss / (train_steps)
        avgValidLoss = totalValidLoss / (valid_steps)
        avg_train_accuracy = train_accuracy / train_steps
        avg_valid_accuracy = valid_accuracy / valid_steps

        H["train_loss"].append(avgTrainLoss)
        H["valid_loss"].append(avgValidLoss)
        A["train_accuracy"].append(avg_train_accuracy)
        A["valid_accuracy"].append(avg_valid_accuracy)

        print("\033[1mEpoch\033[0m: {}/{}...Train loss: {:.6f}, Valid loss: {:.4f}, Train accuracy: {:.2f}, Valid accuracy: {:.2f}".format(
            i + 1, epochs,avgTrainLoss, avgValidLoss,avg_train_accuracy, avg_valid_accuracy), end="\r", flush=True)

    # Leave the model in training mode on return.
    model.train()
    print()
    print("Training done!")
    return H,A

In [None]:
def eval_model(save_path, i):
    '''Evaluate the global ``model`` on ``eval_loader`` and write a report.

    Per-batch accuracy, precision, IoU (mean and per class) and F1 are
    averaged over batches, and the confusion matrices are summed. Results are
    printed and written to ``<save_path>/evaluations/evaluation<i + 1>.txt``
    (the sub-directory is created if missing).

    Uses the globals ``model``, ``eval_loader``, ``train_loader``,
    ``optimizer``, ``epochs`` and ``device``.

    Returns
    -------
    tuple of list
        ``(pred_ls, y_ls)``: per-batch predicted class maps and labels.
    '''
    print()
    print('\033[1mEvaluating model...\033[0m')
    print('\033[1m-------------------------\033[0m')
    model.eval()

    # Per-batch metric values, averaged after the loop.
    accuracy_ls = []
    precision_ls = []
    jaccard_ls = []
    jaccard_sp = 0
    f1_ls = []
    pred_ls = []
    y_ls = []
    confmat_total = torch.zeros((2,2)).to(device)

    # Metric objects are built once instead of once per batch; calling one
    # returns the value for the batch passed in.
    confmat = ConfusionMatrix(num_classes=2, is_multilabel = True).to(device)
    precision = Precision().to(device)
    accuracy = Accuracy().to(device)
    jaccard = JaccardIndex(num_classes=2).to(device)
    jaccard2 = JaccardIndex(num_classes=2, average = 'none').to(device)
    f1 = F1Score(num_classes=2).to(device)

    with torch.no_grad():
        for X, y in eval_loader:
            X, y = X.to(device), y.to(device)

            pred = model(X)
            pred = torch.argmax(pred, dim=1)
            pred_ls.append(pred)

            y = y.int()
            y_ls.append(y)
            y = torch.where(y == -1, 0, y)  # same -1 -> 0 mapping as change_labels

            flat_pred, flat_y = pred.view(-1), y.view(-1)
            confmat_total = confmat_total + confmat(flat_pred, flat_y)
            precision_ls.append(precision(flat_pred, flat_y).item())
            accuracy_ls.append(accuracy(flat_pred, flat_y).item())
            jaccard_ls.append(jaccard(pred, y).item())
            jaccard_sp += jaccard2(pred, y)
            f1_ls.append(f1(flat_pred, flat_y).item())

    jaccard_av = sum(jaccard_ls) / len(jaccard_ls)
    jaccard_sp_av = jaccard_sp / len(eval_loader)
    f1_av = sum(f1_ls) / len(f1_ls)
    accuracy_av = sum(accuracy_ls) / len(accuracy_ls)
    precision_av = sum(precision_ls) / len(precision_ls)

    print("Average IoU (Intercection over Union): {} \nAverage F1 Score: {}".format(jaccard_av, f1_av))
    print()
    print("Average IoU for every catecory:")
    print("No flood:", jaccard_sp_av[0])
    print("Flood:", jaccard_sp_av[1])
    print()
    print("Average accuracy: {} \nAverage precision: {}".format(accuracy_av, precision_av))
    print()
    print("Confusion matrix:")
    print(confmat_total)

    report_path = join(save_path,"evaluations/evaluation{}.txt".format(i+1))
    os.makedirs(os.path.dirname(report_path), exist_ok=True)
    # One file handle for the whole report; every entry starts on its own line.
    with open(report_path, "w") as f:
        f.write("Epochs:{}\nBatch:{}\nOptimizer:{}\nLoss function:{}".format(epochs, train_loader.batch_size,optimizer ,'CrossEntropyLoss' ))
        f.write("\nAverage IoU (Intercection over Union): {} \nAverage F1 Score: {}".format(jaccard_av, f1_av))
        f.write("\n")
        f.write("\nAverage IoU for every category:")
        f.write("\nNo flood:{}".format(jaccard_sp_av[0]))
        f.write("\nFlood:{}".format(jaccard_sp_av[1]))
        f.write("\n")
        f.write("\nAverage accuracy: {} \nAverage precision: {}".format(accuracy_av, precision_av))
        f.write("\nConfusion matrix:\n")
        np.savetxt(f, confmat_total.cpu().numpy(), fmt='%s')

    return pred_ls, y_ls

# Training loop

In [None]:
#Train Parameters
# Input / output locations — fill these in before running.
labeled_dir = ''
labels_dir = ''
evaluation_dir = ''
evaluation_labels_dir = ''
unlabeled_dir = ''
new_labels_dir = ''
csv_train = ''
csv_valid = ''
csv_eval = ''
csv_unlabeled = ''
save_path = ''

# Loop hyperparameters.
epochs = 1
batch_size = 2
N = 100  # files moved from the unlabeled to the labeled pool per cycle
training_cycles = 20

# Seed before building the model so its initial weights are reproducible.
torch.manual_seed(20)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay = 1e-6)

In [None]:
#Datasets and Dataloaders
# The pools are DataFrames of file names; the active-learning loop moves rows
# from unlabeled_pool into labeled_pool.
labeled_pool = pd.read_csv(csv_train)
unlabeled_pool = pd.read_csv(csv_unlabeled)

train_dataset = Labeled(csv_train, labeled_dir, labels_dir)
valid_dataset = Labeled(csv_valid, labeled_dir, labels_dir)
eval_dataset = Labeled(csv_eval, evaluation_dir, evaluation_labels_dir)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
# Evaluation runs one image at a time.
eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=True, pin_memory=True)

In [None]:
#Training
# Initial model trained on the labeled pool only. With i = -1 every saved
# artifact is numbered i + 1 = 0.
start_time = time.time()  # wall-clock start of the whole run
i = -1
H, A = train_model(train_loader, valid_loader)
visualise_training(H, A, save_path, i)
save_model(save_path, i)
pred_ls, y_ls = eval_model(save_path, i)
visualise_evaluation(pred_ls, y_ls)
# Drop the references to the stored predictions and labels.
pred_ls, y_ls = [], []

In [None]:
for i in range (training_cycles):
    print('\033[1mTraining cycle:\033[0m', i+1)
    print('\033[1m-----------------------\033[0m')
    # Rank unlabeled files, write model-predicted labels for the top N and
    # move them into the labeled pool.
    max_entropy = extract_highest_entropy(unlabeled_pool, unlabeled_dir, N)
    create_labels(max_entropy, model, unlabeled_dir, new_labels_dir)
    labeled_pool, unlabeled_pool = update_pools(labeled_pool, unlabeled_pool, max_entropy)
    unlabeled_dataset = Unlabeled(labeled_pool, labeled_dir, labels_dir, unlabeled_dir,new_labels_dir, model)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size, shuffle = True, pin_memory = True)

    # Retrain from scratch on the enlarged pool. The optimizer must be rebuilt
    # for the new network: the old one still holds the previous model's
    # parameters, so optimizer.step() would never update the new model.
    # Reusing its class and defaults keeps the configured hyperparameters.
    model = UNet().to(device)
    optimizer = type(optimizer)(model.parameters(), **optimizer.defaults)
    H, A = train_model(unlabeled_loader, valid_loader)
    visualise_training(H, A, save_path, i+1)
    save_model(save_path, i+1)
    if i%5 == 0:
        # Keep the latest evaluation for the final visualisation below.
        pred_ls, y_ls = eval_model(save_path, i+1)
    # Regenerate labels with the new model for every pool file that comes
    # from unlabeled_dir (create_labels only matches files in that
    # directory); Unlabeled reads these from new_labels_dir next cycle.
    delete_dir(new_labels_dir)
    create_labels(labeled_pool, model, unlabeled_dir, new_labels_dir)

if pred_ls:
    visualise_evaluation(pred_ls, y_ls)
print('Training complete!!')
print(f'\nDuration: {(time.time() - start_time)/60:.0f} minutes') #print the time elapsed