In [None]:
# Load in necessary packages
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from osgeo import gdal
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import pandas as pd

from torchinfo import summary
import gc
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
from torchmetrics import JaccardIndex
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics import MeanAbsoluteError
from torchmetrics.functional import dice_score
from scipy.ndimage import convolve
from tqdm.auto import tqdm # progress bar
from timeit import default_timer as timer
def print_train_time(start:float,
                    end:float,
                    device: torch.device= None):
    """Print and return the elapsed time between two timer readings.

    Args:
        start: Timer value taken before training.
        end: Timer value taken after training.
        device: Device label included in the printed message.

    Returns:
        ``end - start``.
    """
    elapsed = end - start
    print(f"Train time on {device} : {elapsed:.3f} seconds")
    return elapsed



In [None]:
# Load in 2017 data for training
folder="2017_cleaned/" # 9533 files

# Collect the path of every .npy tile in the folder.
# os.listdir() returns entries in arbitrary, platform-dependent order, so the
# names are sorted to keep the seeded train/test split reproducible.
# os.path.join avoids the doubled "/" that plain concatenation produced.
filelist_new = [
    os.path.join(folder, filename)
    for filename in sorted(os.listdir(folder))
    if filename.endswith(".npy")
]

len(filelist_new) # 9533 files

In [None]:
# Early stopper function
class EarlyStopper:
    """Signal when a monitored loss has stopped improving.

    The lowest loss seen so far is remembered. Every call whose loss exceeds
    that minimum by more than ``min_delta`` increments a counter; a new
    minimum resets it. Once the counter reaches ``patience`` the stopper
    reports that training should stop.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True when training should stop."""
        best_so_far = self.min_validation_loss

        # A strictly better loss becomes the new reference and clears the counter.
        if validation_loss < best_so_far:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Losses inside the tolerance band leave the counter untouched.
        if not (validation_loss > best_so_far + self.min_delta):
            return False

        self.counter += 1
        return self.counter >= self.patience

# Hold out 20% of the tile paths for testing; the fixed seed keeps the
# split repeatable for a given ordering of filelist_new.
filelist_train, filelist_test = train_test_split(
    filelist_new,
    test_size=0.2,
    random_state=42,
)
len(filelist_train), len(filelist_test)  # (7626, 1907)

In [None]:
### Create Dataloaders using the file paths ###
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Two Dataset classes follow: one serving the training files, one the test files.
class allbands_dataset_train(Dataset):
    """Training dataset that lazily loads one ``.npy`` file per sample.

    Each file is indexed along its first axis: entries 0-13 are the model
    inputs, entry 14 is the canopy-height target and entry 15 is the
    tree / not-tree mask target.
    """
    def __init__(self,filelist_train, transform=None):
        """
        Args:
            filelist_train (list of str): Paths of the ``.npy`` files to serve.
            transform (callable, optional): Optional transform applied to the
                whole ``[X, preds]`` sample before it is returned.
        """
        self.transform = transform
        self.filelist = filelist_train

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        """Return ``[X, [tree_height, tree_mask]]`` for the file at ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Load the stacked array for this sample from disk.
        dataset = np.load(self.filelist[idx])

        # X: the first 14 entries are the input bands.
        X = dataset[:14]

        # Targets: canopy height, then tree / not-tree mask.
        out_tree_height = dataset[14]
        out_tree_mask = dataset[15]
        sample = [X, [out_tree_height, out_tree_mask]]

        # The transform used to be stored but silently ignored; apply it when given.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample
    
class allbands_dataset_test(Dataset):
    """Test dataset that lazily loads one ``.npy`` file per sample.

    Each file is indexed along its first axis: entries 0-13 are the model
    inputs, entry 14 is the canopy-height target and entry 15 is the
    tree / not-tree mask target.
    """
    def __init__(self,filelist_test, transform=None):
        """
        Args:
            filelist_test (list of str): Paths of the ``.npy`` files to serve.
            transform (callable, optional): Optional transform applied to the
                whole ``[X, preds]`` sample before it is returned.
        """
        self.transform = transform
        self.filelist = filelist_test

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        """Return ``[X, [tree_height, tree_mask]]`` for the file at ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Load the stacked array for this sample from disk.
        dataset = np.load(self.filelist[idx])

        # X: the first 14 entries are the input bands.
        X = dataset[:14]

        # Targets: canopy height, then tree / not-tree mask.
        out_tree_height = dataset[14]
        out_tree_mask = dataset[15]
        sample = [X, [out_tree_height, out_tree_mask]]

        # The transform used to be stored but silently ignored; apply it when given.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample
    

# Models

- Six models are run, following the paper: each task is trained on its own, and both tasks are trained together in a multi-task framework using either a manual or an automatic loss-weighting strategy, with either all layers shared or only the encoder layers shared.

1. Multi Task Auto Loss (All Shared)
2. Single Task Height
3. Single Task Tree Cover
4. Multi Task Manual Loss (All Shared)
5. Multi Task Manual Loss (Not Shared Decoder)
6. Multi Task Auto Loss Model (Not Shared Decoder)

# Multi Task Auto Loss (All Shared)

In [None]:
# DEFINE UNET MODEL
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def defineUNetModel_fullyshared():
    """Build the fully shared U-Net with a tree-height and a tree-mask head.

    The network takes a 14-channel image, runs three pooled encoder stages
    (2x, 2x, 3x) plus a bottleneck, and mirrors them with transposed-conv
    upsampling and skip connections. Both heads read the same final 32-channel
    decoder features.

    Returns:
        nn.Module whose forward pass returns ``[tree_height, tree_mask]``;
        the mask head is passed through a sigmoid, the height head is linear.
    """

    def conv_pair(in_channels, out_channels, dropout, norm=False):
        # conv -> [InstanceNorm] -> LeakyReLU -> Dropout -> conv -> [InstanceNorm] -> LeakyReLU
        layers = [nn.Conv2d(in_channels, out_channels, 3, padding="same")]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers += [
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
        ]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))
        return nn.Sequential(*layers)

    class UNet(nn.Module):
        def __init__(self):
            super().__init__()

            # Encoder: only the first stage uses instance normalisation.
            self.dconv_down1 = conv_pair(14, 32, dropout=0.1, norm=True)
            self.dconv_down2 = conv_pair(32, 64, dropout=0.1)
            self.dconv_down3 = conv_pair(64, 128, dropout=0.1)
            self.dconv_down4 = conv_pair(128, 256, dropout=0.2)

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)

            # Transposed convolutions undo the 3x, 2x and 2x poolings in turn.
            self.upsample1 = nn.ConvTranspose2d(256, 128, 2, stride=3, padding=0, output_padding=1)
            self.upsample2 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample3 = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)

            # Decoder stages consume [upsampled features, encoder skip].
            self.dconv_up3 = conv_pair(128 + 128, 128, dropout=0.2)
            self.dconv_up2 = conv_pair(64 + 64, 64, dropout=0.2)
            self.dconv_up1 = conv_pair(32 + 32, 32, dropout=0.2)

            # One 1x1 head per output so the heads do not share weights.
            # conv_last3 and linear are registered but not used in forward().
            self.conv_last1 = nn.Conv2d(32, 1, 1)
            self.conv_last2 = nn.Conv2d(32, 1, 1)
            self.conv_last3 = nn.Conv2d(32, 1, 1)
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            # Encoder path; keep each stage's output for the skip connections.
            skip1 = self.dconv_down1(x)
            skip2 = self.dconv_down2(self.maxpool(skip1))
            skip3 = self.dconv_down3(self.maxpool(skip2))
            x = self.dconv_down4(self.maxpool3(skip3))

            # Decoder path: upsample, concatenate the matching skip, refine.
            decoder_stages = (
                (self.upsample1, skip3, self.dconv_up3),
                (self.upsample2, skip2, self.dconv_up2),
                (self.upsample3, skip1, self.dconv_up1),
            )
            for upsample, skip, refine in decoder_stages:
                x = refine(torch.cat([upsample(x), skip], dim=1))

            # Height head has no activation; mask head is squashed to (0, 1).
            return [self.conv_last1(x), self.sigmoid(self.conv_last2(x))]

    return UNet()

# Seed before construction so the initial weights are reproducible.
torch.manual_seed(42)
unet_1 = defineUNetModel_fullyshared()
unet_1 = unet_1.to(device)
# Show the module tree together with a layer summary for one 14x240x240 input.
unet_1, summary(unet_1, input_size=(1, 14, 240, 240))

In [None]:
# Auto loss strategy
#https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
# Learnable log-variances, one per task (a: tree height, b: tree mask).
# They are created directly on `device` so that the tensors handed to the
# optimizer are the same leaf tensors used in the loss. Creating them on the
# CPU and calling `.to(device)` later yields a non-leaf copy whose value never
# changes when the optimizer updates the original.
log_var_a = torch.zeros((1,), requires_grad=True, device=device)
log_var_b = torch.zeros((1,), requires_grad=True, device=device)

mse = nn.MSELoss()       # regression loss for tree height
bce_loss = nn.BCELoss()  # binary cross-entropy for the sigmoid tree mask

# Remade my own version of the loss function
def loss_criterion(y_pred, y_true, log_vars):
    """Combine the per-task losses, weighting each by its learned log-variance.

    Task 0 is scored with ``mse`` and every other task with ``bce_loss``.
    Each task contributes ``exp(-log_var) * task_loss + log_var``.

    Args:
        y_pred: Sequence of prediction tensors, one per task.
        y_true: Sequence of target tensors, aligned with ``y_pred``.
        log_vars: Sequence of log-variance tensors, aligned with ``y_pred``.

    Returns:
        Scalar tensor holding the combined loss.
    """
    total = 0
    for task in range(len(y_pred)):
        weight = torch.exp(-log_vars[task])
        criterion = mse if task == 0 else bce_loss
        task_loss = criterion(y_pred[task], y_true[task])
        total = total + torch.sum(weight * task_loss + log_vars[task], -1)
    return torch.mean(total)

# Everything the optimizer should update: network weights plus both log-variances.
params_all = [*unet_1.parameters(), log_var_a, log_var_b]


In [None]:
# Set up loss/optimizer/metrics
# Optimiser: Adam over the network weights and the two log-variances.
optimizer = optim.Adam(params_all, lr=1e-3)

# Metrics, kept on the same device as the predictions.
mean_absolute_error = MeanAbsoluteError().to(device)
iou_score = BinaryJaccardIndex().to(device)

len(filelist_train), len(filelist_test)


In [None]:
class SaveBestModel:
    """
    Save the model's state dict whenever the validation loss improves.

    If the current epoch's validation loss is lower than the lowest loss seen
    so far, the model state is written to ``save_path``.

    Args:
        best_valid_loss: Initial value for the best loss seen so far.
        save_path: File the state dict is written to. Defaults to the
            checkpoint path used for the fully shared multi-task model.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='models/pytorch_paper_final/pytorch_mtloss_allshared.pt'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model
    ):
        """Save ``model`` if ``current_valid_loss`` beats the best so far."""
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # Create the target directory first so a missing folder does not
            # abort the run after an epoch of training.
            target_dir = os.path.dirname(self.save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            torch.save(model.state_dict(), self.save_path)

save_best_model = SaveBestModel()

In [None]:
### GOING TO TEST RUNNING MANY EPOCHS WITH A SMALLER PORTION OF MY DATA ###

# Run a Train/Test Loop now
# Build the training Loop (and a testing loop)
torch.manual_seed(42)
epochs = 100
batch_size = 16

# Instantiate datasets/loaders
my_dataset_train = allbands_dataset_train(filelist_train=filelist_train)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)

my_dataloader_train = DataLoader(my_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# track individual losses (one entry every 200 training batches)
height_loss= []
tmask_loss= []

# track individual metrics (one entry per epoch)
height_mae= []
tmask_iou= []

# gamma ~= exp(-0.1): the learning rate shrinks by about 10% per scheduler step
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90483741803)
# stop once the test loss has exceeded its best value 8 times since the last improvement
early_stopper = EarlyStopper(patience=8, min_delta=0)

train_time_start_on_cpu = timer()
# Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} out of {epochs}")
    train_loss = 0

    unet_1.train()
    # Loop through training batch data
    for i_batch, (X, Y) in enumerate(my_dataloader_train):
        optimizer.zero_grad()

        # Move the batch to the model's device (no hard-coded .cuda()).
        X = X.float().to(device)
        target_height = Y[0].float().to(device)
        target_mask = Y[1].float().to(device)

        # Take the log-variances fresh every batch. If they live on another
        # device this makes an up-to-date copy (gradients still reach the
        # original); if they are already on `device`, .to() returns them as is.
        log_vars = [log_var_a.to(device), log_var_b.to(device)]

        # Forward pass; drop only the channel axis so a batch of size 1 keeps its batch axis.
        pred_tree_height, pred_tree_mask = unet_1(X)
        pred_tree_height = pred_tree_height.squeeze(1)
        pred_tree_mask = pred_tree_mask.squeeze(1)

        # Calc loss (per batch)
        loss = loss_criterion([pred_tree_height, pred_tree_mask],
                              [target_height, target_mask],
                              log_vars)

        train_loss += loss.item() # accumulate train loss

        if i_batch % 200 == 0:
            # Per-task contributions to the combined loss, for logging only.
            with torch.no_grad():
                th_loss = torch.sum(torch.exp(-log_vars[0]) * mse(pred_tree_height, target_height) + log_vars[0], -1)
                tm_loss = torch.sum(torch.exp(-log_vars[1]) * bce_loss(pred_tree_mask, target_mask) + log_vars[1], -1)

            print(f"Batch {i_batch+1} out of {len(my_dataloader_train)} completed.", loss.item(),th_loss.item(),tm_loss.item())
            height_loss.append(th_loss.item())
            tmask_loss.append(tm_loss.item())

        # backpropagate, then take the optimizer step
        loss.backward()
        optimizer.step()

    # Average train loss per batch
    train_loss /= len(my_dataloader_train)

    ### Testing
    test_loss, test_height_mae, test_tree_iou = 0,0,0

    unet_1.eval()
    with torch.inference_mode():
        log_vars = [log_var_a.to(device), log_var_b.to(device)]
        for X, Y in my_dataloader_test:
            X = X.float().to(device)
            target_height = Y[0].float().to(device)
            target_mask = Y[1].float().to(device)

            # Forward pass
            pred_tree_height, pred_tree_mask = unet_1(X)
            pred_tree_height = pred_tree_height.squeeze(1)
            pred_tree_mask = pred_tree_mask.squeeze(1)

            # Accumulate the test loss only (the last training loss is no longer added here).
            test_loss += loss_criterion([pred_tree_height, pred_tree_mask],
                                        [target_height, target_mask],
                                        log_vars).item()

            #Tree Height MAE
            test_height_mae += mean_absolute_error(pred_tree_height, target_height)

            #Tree Mask IOU (the metric expects integer targets)
            test_tree_iou += iou_score(pred_tree_mask, target_mask.long())

        # Average loss and metrics per batch
        n_test_batches = len(my_dataloader_test)
        test_loss /= n_test_batches
        test_height_mae /= n_test_batches
        test_tree_iou /= n_test_batches

        # save metrics
        height_mae.append(test_height_mae.item())
        tmask_iou.append(test_tree_iou.item())

    lr=optimizer.param_groups[0]["lr"]
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Tree Height MAE: {test_height_mae:.5f} | Test Tree Mask IOU: {test_tree_iou:.5f} | Learning Rate: {lr:.10f}")
    save_best_model(test_loss, epoch, unet_1)
    if epoch>1:
        scheduler.step()  # decay the learning rate every epoch from the third epoch on

    if early_stopper.early_stop(test_loss):
        break

train_time_end_on_cpu = timer()
total_train_time_on_cpu= print_train_time(start=train_time_start_on_cpu,
                                          end=train_time_end_on_cpu,
                                          device=str(next(unet_1.parameters()).device))

In [None]:
# Load in model
unet_1 = defineUNetModel_fullyshared().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on whichever device is in use.
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_mtloss_allshared.pt", map_location=device)
)

In [None]:
%%time
# Calculate Tree Height MAE on the test set, restricted to predicted tree pixels
mean_absolute_error = MeanAbsoluteError(nan_strategy='ignore').to(device)
batch_size=16
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size,shuffle=False, num_workers=0)
test_height_mae = []  # one MAE value per batch

unet_1.eval()
with torch.inference_mode():
    for X, Y in my_dataloader_test:
        # Move the batch to the model's device (no hard-coded .cuda()).
        X = X.float().to(device)
        Y = [target.float().to(device) for target in Y]

        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)

        # mask predicted tree height with the tree mask
        # NOTE(review): custom_replace is not defined in the visible part of this
        # notebook; presumably it binarises the mask at the 0.4 threshold - confirm.
        pred_tree_mask = custom_replace(pred_tree_mask, .4)
        pred_tree_height[pred_tree_height  < 0 ] = 0  # clamp negative heights to zero

        pred_tree_height = torch.squeeze(pred_tree_height)*torch.squeeze(pred_tree_mask) #0s get rid of non tree pixels

        # Apply the same predicted mask to the reference heights.
        actual_tree_height= Y[0]*torch.squeeze(pred_tree_mask)

        #Height MAE
        test_height_mae.append(mean_absolute_error(pred_tree_height,actual_tree_height).item())

# Printed values: max, min, number of batches, mean of the per-batch MAE.
print("Tree Height MAE=",max(test_height_mae),min(test_height_mae),len(test_height_mae),sum(test_height_mae) / len(test_height_mae))

In [None]:
%%time
# Calculate Tree Mask IoU at thresholds 0.05, 0.10, ..., 0.95
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=16,shuffle=False, num_workers=0)

# One metric object and one running total per threshold, replacing the 19
# hand-copied variables.
iou_thresholds = [k * 5 / 100 for k in range(1, 20)]
iou_metrics = [BinaryJaccardIndex(threshold=t).to(device) for t in iou_thresholds]
iou_totals = [0] * len(iou_thresholds)

unet_1.eval()
with torch.inference_mode():
    for X, Y in my_dataloader_test:
        # Move the batch to the model's device (no hard-coded .cuda()).
        X = X.float().to(device)

        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)

        # Prepare prediction and integer target once, then reuse for every threshold.
        mask_prob = torch.squeeze(pred_tree_mask)
        mask_true = Y[1].to(device).float().long()

        #Tree Mask IOU
        for idx, metric in enumerate(iou_metrics):
            iou_totals[idx] += metric(mask_prob, mask_true)

# Mean of the per-batch IoU values; divide by the actual number of batches.
n_test_batches = len(my_dataloader_test)
iou_means = [total / n_test_batches for total in iou_totals]

# Keep the original per-threshold result names available to later cells.
(test_tree_iou05, test_tree_iou1, test_tree_iou15, test_tree_iou2, test_tree_iou25,
 test_tree_iou3, test_tree_iou35, test_tree_iou4, test_tree_iou45, test_tree_iou5,
 test_tree_iou55, test_tree_iou6, test_tree_iou65, test_tree_iou7, test_tree_iou75,
 test_tree_iou8, test_tree_iou85, test_tree_iou9, test_tree_iou95) = iou_means

for threshold, mean_iou in zip(iou_thresholds, iou_means):
    # 0.05 -> ".05", 0.1 -> ".1 " (padded to three characters, as before)
    label = f"{threshold:.2f}".lstrip("0").rstrip("0")
    print(f"Tree IoU {label:<3}=", mean_iou)

In [None]:
# Plot Loss, MAE, IoU
# The x positions come from the length of each history list, so the plots work
# however many epochs ran before early stopping. The previous hard-coded ranges
# only matched a run of exactly 29 epochs.
loss_steps = range(1, len(height_loss) + 1)  # one point per logged batch
eval_epochs = range(1, len(height_mae) + 1)  # one point per epoch

fig, (ax_height, ax_mask) = plt.subplots(1, 2)
ax_height.plot(loss_steps, height_loss, "rx-", label="tree height loss")
ax_height.set_xlabel("logged batch (every 200 batches)")
ax_height.set_ylabel("loss")
ax_height.legend(loc="upper right")

ax_mask.plot(loss_steps, tmask_loss, "bx-", label="tree mask loss")
ax_mask.set_xlabel("logged batch (every 200 batches)")
ax_mask.set_ylabel("loss")
ax_mask.legend(loc="upper right")
fig.tight_layout()
plt.show()

plt.plot(eval_epochs, height_mae, "m--", label="Tree height MAE")
plt.xlabel("epoch")
plt.legend(loc="upper right")
plt.show()

plt.plot(range(1, len(tmask_iou) + 1), tmask_iou, "m--", label="Tree Mask IoU")
plt.xlabel("epoch")
plt.legend(loc="upper right")
plt.show()

# Sum of the two logged per-task loss terms.
tota_loss = np.add(np.array(height_loss), np.array(tmask_loss))

plt.plot(loss_steps, tota_loss, "m--", label="total loss")
plt.xlabel("logged batch (every 200 batches)")
plt.ylabel("loss")
plt.legend(loc="upper right")
plt.show()



# Single Task Height

In [None]:
# DEFINE UNET MODEL
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def defineUNetModel_fullyshared():
    """Build the single-task U-Net that predicts tree height only.

    The network takes a 14-channel image, runs three pooled encoder stages
    (2x, 2x, 3x) plus a bottleneck, and mirrors them with transposed-conv
    upsampling and skip connections.

    Returns:
        nn.Module whose forward pass returns the one-element list
        ``[tree_height]`` (no activation on the output).
    """

    def conv_pair(in_channels, out_channels, dropout, norm=False):
        # conv -> [InstanceNorm] -> LeakyReLU -> Dropout -> conv -> [InstanceNorm] -> LeakyReLU
        layers = [nn.Conv2d(in_channels, out_channels, 3, padding="same")]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers += [
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
        ]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))
        return nn.Sequential(*layers)

    class UNet(nn.Module):
        def __init__(self):
            super().__init__()

            # Encoder: only the first stage uses instance normalisation.
            self.dconv_down1 = conv_pair(14, 32, dropout=0.1, norm=True)
            self.dconv_down2 = conv_pair(32, 64, dropout=0.1)
            self.dconv_down3 = conv_pair(64, 128, dropout=0.1)
            self.dconv_down4 = conv_pair(128, 256, dropout=0.2)

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)

            # Transposed convolutions undo the 3x, 2x and 2x poolings in turn.
            self.upsample1 = nn.ConvTranspose2d(256, 128, 2, stride=3, padding=0, output_padding=1)
            self.upsample2 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample3 = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)

            # Decoder stages consume [upsampled features, encoder skip].
            self.dconv_up3 = conv_pair(128 + 128, 128, dropout=0.2)
            self.dconv_up2 = conv_pair(64 + 64, 64, dropout=0.2)
            self.dconv_up1 = conv_pair(32 + 32, 32, dropout=0.2)

            # Single 1x1 head for tree height.
            # linear and sigmoid are registered but not used in forward().
            self.conv_last1 = nn.Conv2d(32, 1, 1)
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            # Encoder path; keep each stage's output for the skip connections.
            skip1 = self.dconv_down1(x)
            skip2 = self.dconv_down2(self.maxpool(skip1))
            skip3 = self.dconv_down3(self.maxpool(skip2))
            x = self.dconv_down4(self.maxpool3(skip3))

            # Decoder path: upsample, concatenate the matching skip, refine.
            decoder_stages = (
                (self.upsample1, skip3, self.dconv_up3),
                (self.upsample2, skip2, self.dconv_up2),
                (self.upsample3, skip1, self.dconv_up1),
            )
            for upsample, skip, refine in decoder_stages:
                x = refine(torch.cat([upsample(x), skip], dim=1))

            # Height is a plain regression output, so no activation is applied.
            return [self.conv_last1(x)]

    return UNet()

# Seed before construction so the initial weights are reproducible.
torch.manual_seed(42)
unet_1 = defineUNetModel_fullyshared()
unet_1 = unet_1.to(device)
# Show the module tree together with a layer summary for one 14x240x240 input.
unet_1, summary(unet_1, input_size=(1, 14, 240, 240))


In [None]:
# Set up loss/optimizer/metrics
# Loss: mean squared error on the predicted tree height.
mse = nn.MSELoss()
# Optimiser: Adam over the network weights only.
optimizer = optim.Adam(unet_1.parameters(), lr=1e-3)

# Metrics, kept on the same device as the predictions.
mean_absolute_error = MeanAbsoluteError().to(device)

In [None]:
class SaveBestModel:
    """
    Save the model's state dict whenever the validation loss improves.

    If the current epoch's validation loss is lower than the lowest loss seen
    so far, the model state is written to ``save_path``.

    Args:
        best_valid_loss: Initial value for the best loss seen so far.
        save_path: File the state dict is written to. Defaults to the
            checkpoint path used for the single-task tree-height model.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='models/pytorch_paper_final/pytorch_single_treeheight.pt'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model
    ):
        """Save ``model`` if ``current_valid_loss`` beats the best so far."""
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # Create the target directory first so a missing folder does not
            # abort the run after an epoch of training.
            target_dir = os.path.dirname(self.save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            torch.save(model.state_dict(), self.save_path)

save_best_model = SaveBestModel()

In [None]:
# Run a Train/Test Loop now
# Build the training Loop (and a testing loop)
torch.manual_seed(42)
epochs = 50
batch_size = 16

# Instantiate datasets/loaders
my_dataset_train = allbands_dataset_train(filelist_train=filelist_train)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)

my_dataloader_train = DataLoader(my_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# track individual losses
height_loss = []  # one entry every 200 training batches
height_mae = []   # one entry per epoch

# gamma ~= exp(-0.1): the learning rate shrinks by about 10% per scheduler step
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90483741803)
# stop once the test loss has exceeded its best value 7 times since the last improvement
early_stopper = EarlyStopper(patience=7, min_delta=0)

train_time_start_on_cpu = timer()
# Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} out of {epochs}")
    train_loss = 0

    unet_1.train()
    # Loop through training batch data
    for i_batch, (X, Y) in enumerate(my_dataloader_train):
        optimizer.zero_grad()

        # Move the batch to the model's device (no hard-coded .cuda()).
        X = X.float().to(device)
        target_height = Y[0].float().to(device)

        # Forward pass; the model returns a one-element list.
        pred_tree_height = unet_1(X)

        # Calc loss (per batch); drop only the channel axis of the prediction.
        loss = mse(pred_tree_height[0].squeeze(1), target_height)

        train_loss += loss.item() # accumulate train loss

        # backpropagate, then take the optimizer step
        loss.backward()
        optimizer.step()

        if i_batch % 200 == 0:
            print(f"Batch {i_batch+1} out of {len(my_dataloader_train)} completed.", loss.item())
            height_loss.append(loss.item())

    # Average train loss per batch
    train_loss /= len(my_dataloader_train)

    ### Testing
    test_loss, test_height_mae = 0,0

    unet_1.eval()
    with torch.inference_mode():
        for X, Y in my_dataloader_test:
            X = X.float().to(device)
            target_height = Y[0].float().to(device)

            # Forward pass
            pred_tree_height = unet_1(X)[0].squeeze(1)

            # Accumulate the test loss over all batches. It used to be
            # overwritten each batch and then had the last training loss added.
            test_loss += mse(pred_tree_height, target_height).item()

            #Tree Height MAE
            test_height_mae += mean_absolute_error(pred_tree_height, target_height)

        # Average loss and MAE per batch
        n_test_batches = len(my_dataloader_test)
        test_loss /= n_test_batches
        test_height_mae /= n_test_batches
        height_mae.append(test_height_mae.item())

    lr=optimizer.param_groups[0]["lr"]
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Tree Height MAE: {test_height_mae:.5f} | Learning Rate: {lr:.10f}")
    save_best_model(test_loss, epoch, unet_1)
    if epoch>1:
        scheduler.step()  # decay the learning rate every epoch from the third epoch on
    if early_stopper.early_stop(test_loss):
        break

train_time_end_on_cpu = timer()
total_train_time_on_cpu= print_train_time(start=train_time_start_on_cpu,
                                          end=train_time_end_on_cpu,
                                          device=str(next(unet_1.parameters()).device))

In [None]:
# load best model
unet_1 = defineUNetModel_fullyshared().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on whichever device is in use.
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_single_treeheight.pt", map_location=device)
)

In [None]:
# Plot Loss/MAE
# x positions are derived from the history lengths. The previous fixed ranges
# (50 and 75 points) did not match the number of recorded values, which makes
# plt.plot raise.

# height_loss holds one value per logged training batch (every 200 batches), not per epoch.
plt.plot(range(1, len(height_loss) + 1), height_loss, "rx-", label="tree height loss")
plt.xlabel("logged batch (every 200 batches)")
plt.ylabel("loss")
plt.legend(loc="upper right")
plt.show()

# height_mae holds one value per completed epoch.
plt.plot(range(1, len(height_mae) + 1), height_mae, "m--", label="mae")
plt.xlabel("epoch")
plt.ylabel("MAE")
plt.legend(loc="upper right")
plt.show()

# Single Task Tree Mask

In [None]:
# DEFINE UNET MODEL
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def defineUNetModel_fullyshared():
    def double_conv0(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True)
        )    
    def double_conv(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding="same"),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
            nn.LeakyReLU(inplace=True)
        )
    
    def double_conv2(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding="same"),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
            nn.LeakyReLU(inplace=True)
        )
    
    class UNet(nn.Module):
        def __init__(self):
            super().__init__()

            self.dconv_down1 = double_conv0(14, 32)
            self.dconv_down2 = double_conv(32, 64)
            self.dconv_down3 = double_conv(64, 128)
            self.dconv_down4 = nn.Sequential(
                nn.Conv2d(128, 256, 3, padding="same"),
                nn.LeakyReLU(inplace=True),
                nn.Dropout(p=0.2),
                nn.Conv2d(256, 256, 3, padding="same"),
                nn.LeakyReLU(inplace=True))

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)
            
            ## - looks to be the best solution, and appears to match keras code
            self.upsample1 = nn.ConvTranspose2d(256, 128, 2, stride=3, padding=0, output_padding=1)
            self.upsample2 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)            
            self.upsample3 = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            
            
            self.dconv_up3 = double_conv2(128+128, 128)
            self.dconv_up2 = double_conv2(64 + 64, 64)
            self.dconv_up1 = double_conv2(32 + 32, 32)

            # Need 3 separate layers, otherwise they are all based on the same weight
            self.conv_last2 = nn.Conv2d(32, 1, 1)
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()
            
        def forward(self, x):
            conv1 = self.dconv_down1(x)
            x = self.maxpool(conv1)

            conv2 = self.dconv_down2(x)
            x = self.maxpool(conv2)

            conv3 = self.dconv_down3(x)
            x = self.maxpool3(conv3)

            x = self.dconv_down4(x)
            #x = self.maxpool(conv4)

            x = self.upsample1(x)
            x = torch.cat([x, conv3], dim=1)
            x = self.dconv_up3(x)

            x = self.upsample2(x)
            x = torch.cat([x, conv2], dim=1)
            x = self.dconv_up2(x)

            x = self.upsample3(x)
            x = torch.cat([x, conv1], dim=1)
            x = self.dconv_up1(x)

            out_tree_mask = self.sigmoid(self.conv_last2(x))        
            
            return [out_tree_mask]
    model=UNet()
    return model

# Seed the RNG before building the model so the initial weights are reproducible.
torch.manual_seed(42)
unet_1 = defineUNetModel_fullyshared().to(device)
# The trailing tuple is the cell's displayed value: the module repr followed by a torchinfo
# summary for an input of size (16, 14, 240, 240).
unet_1, summary(unet_1,input_size= (16,14,240,240))
### NOTES
# Everything lines up with my keras model...

In [None]:
# Losses, optimizer and metrics for the single-task tree-mask experiment.

# Loss modules
mse, bce_loss = nn.MSELoss(), nn.BCELoss()

# Adam over all parameters of the freshly built network
optimizer = optim.Adam(unet_1.parameters(), lr=1e-3)

# Metrics are moved to the same device as the model
iou_score = BinaryJaccardIndex().to(device)
mean_absolute_error = MeanAbsoluteError().to(device)

In [None]:
class SaveBestModel:
    """
    Save the model's state dict whenever the validation loss improves.

    Args:
        best_valid_loss: loss that has to be beaten before the first checkpoint is written.
        save_path: file the state dict is written to. Defaults to the single-task
            tree-mask checkpoint used by this section of the notebook.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='models/pytorch_paper_final/pytorch_single_treemask.pt'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model
    ):
        """Checkpoint `model` if `current_valid_loss` is lower than the best seen so far.

        Args:
            current_valid_loss: loss of the current epoch (Python number or 0-dim tensor).
            epoch: zero-based epoch index; only used for the log message.
            model: module whose ``state_dict()`` is saved.
        """
        # Store a plain float so the tracker never keeps a (possibly GPU) tensor alive.
        current_valid_loss = float(current_valid_loss)
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # torch.save does not create missing directories, so do it here.
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), self.save_path)

save_best_model = SaveBestModel()

In [None]:
# Run a Train/Test Loop now
# Build the training Loop (and a testing loop)
torch.manual_seed(42)
epochs = 50
batch_size = 16

# Instantiate datasets/loaders
my_dataset_train = allbands_dataset_train(filelist_train=filelist_train)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)

my_dataloader_train = DataLoader(my_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# Histories used for plotting
tmask_loss = []  # training loss of every 200th batch
mask_iou = []    # IoU of every test batch, all epochs concatenated
# Loss function
bce_loss = nn.BCELoss()

# gamma ~= exp(-0.1): the learning rate shrinks by about 10% per scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90483741803)
# The early stopper is fed the per-epoch test loss (see the end of the epoch loop).
early_stopper = EarlyStopper(patience=10, min_delta=0)
best_test_loss = np.inf

train_time_start_on_cpu = timer()
# Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} out of {epochs}")
    train_loss = 0

    # Switch to training mode once per epoch (the evaluation pass below switches it off).
    unet_1.train()
    for i_batch, (X, Y) in enumerate(my_dataloader_train):
        optimizer.zero_grad()

        # Y[1] is the tree-mask target; inputs and targets follow `device` so the
        # CPU fallback works as well.
        X = X.float().to(device)
        target_mask = Y[1].float().to(device)

        # Forward pass: the model returns a one-element list
        pred_tree_mask = unet_1(X)

        # Loss of this batch
        loss = bce_loss(pred_tree_mask[0].squeeze(), target_mask)
        train_loss += loss.item()

        # Backpropagation and parameter update
        loss.backward()
        optimizer.step()

        if i_batch % 200 == 0:
            print(f"Batch {i_batch+1} out of {len(my_dataloader_train)} completed.", loss.item())
            tmask_loss.append(loss.item())

    # Mean over the batches actually processed. len(dataset)/batch_size is fractional
    # when the last batch is partial and would inflate the mean.
    train_loss /= len(my_dataloader_train)

    ### Testing
    test_loss, test_mask_iou = 0, 0

    unet_1.eval()
    with torch.inference_mode():
        for X, Y in my_dataloader_test:
            X = X.float().to(device)
            target_mask = Y[1].float().to(device)

            # Forward pass
            pred_tree_mask = unet_1(X)

            # Accumulate the loss of THIS test batch as a Python float. Previously the
            # accumulator was overwritten by the batch loss tensor and the stale loss of
            # the last training batch was added on top.
            batch_loss = bce_loss(pred_tree_mask[0].squeeze(), target_mask)
            test_loss += batch_loss.item()

            # Tree Mask IoU; .long() casts on the device, avoiding a CPU round trip.
            batch_iou = iou_score(pred_tree_mask[0].squeeze(), target_mask.long())
            test_mask_iou += batch_iou
            # Record the IoU of this batch (not the running sum).
            mask_iou.append(batch_iou.item())

        # Mean loss / IoU over the test batches
        test_loss /= len(my_dataloader_test)
        test_mask_iou /= len(my_dataloader_test)

    lr=optimizer.param_groups[0]["lr"]
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Tree Mask IOU: {test_mask_iou:.2f} | Learning Rate: {lr:.10f}")
    save_best_model(test_loss, epoch, unet_1)
    # The learning rate is held for epoch indices 0 and 1 and decayed after every later epoch.
    if epoch>1:
        scheduler.step()
    if early_stopper.early_stop(test_loss):
        break


train_time_end_on_cpu = timer()
total_train_time_on_cpu= print_train_time(start=train_time_start_on_cpu,
                                          end=train_time_end_on_cpu,
                                          device=str(next(unet_1.parameters()).device))

In [None]:
# Plot Loss/IoU
# x ranges follow the length of the recorded histories instead of hard-coded counts, so the
# cell also works when training stopped early.

# tmask_loss holds the training loss of every 200th batch, not one value per epoch.
plt.plot(range(1, len(tmask_loss) + 1), tmask_loss, "rx-", label="tree mask loss")
plt.xlabel("logged step (every 200th training batch)")
plt.ylabel("loss")
plt.legend(loc="upper right")
plt.show()

# mask_iou receives one entry per test batch, so a stride of one epoch's worth of test
# batches picks the entry recorded for the first test batch of each epoch.
batches_per_epoch = len(my_dataloader_test)
epoch_iou = mask_iou[0::batches_per_epoch]
plt.plot(range(1, len(epoch_iou) + 1), epoch_iou, "m--", label="tree mask iou")
plt.xlabel("epoch")
plt.ylabel("IoU")
plt.legend(loc="upper right")
plt.show()

In [None]:
# Load the best tree-mask checkpoint written during training into a fresh model.
# map_location remaps the stored tensors onto the active `device`, so a checkpoint
# saved from the GPU can also be restored when only the CPU is available.
unet_1 = defineUNetModel_fullyshared().to(device)
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_single_treemask.pt", map_location=device)
)

In [None]:
%%time
# Test out different thresholds to see which gives the best IoU
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=16,shuffle=False, num_workers=0)

iou_score05 = BinaryJaccardIndex(threshold=.05).to(device)
iou_score1 = BinaryJaccardIndex(threshold=.1).to(device)
iou_score15 = BinaryJaccardIndex(threshold=.15).to(device)
iou_score2 = BinaryJaccardIndex(threshold=.2).to(device)
iou_score25 = BinaryJaccardIndex(threshold=.25).to(device)
iou_score3 = BinaryJaccardIndex(threshold=.3).to(device)
iou_score35 = BinaryJaccardIndex(threshold=.35).to(device)
iou_score4 = BinaryJaccardIndex(threshold=.4).to(device)
iou_score45 = BinaryJaccardIndex(threshold=.45).to(device)
iou_score5 = BinaryJaccardIndex(threshold=.5).to(device)
iou_score55 = BinaryJaccardIndex(threshold=.55).to(device)
iou_score6 = BinaryJaccardIndex(threshold=.6).to(device)
iou_score65 = BinaryJaccardIndex(threshold=.65).to(device)
iou_score7 = BinaryJaccardIndex(threshold=.7).to(device)
iou_score75 = BinaryJaccardIndex(threshold=.75).to(device)
iou_score8 = BinaryJaccardIndex(threshold=.8).to(device)
iou_score85 = BinaryJaccardIndex(threshold=.85).to(device)
iou_score9 = BinaryJaccardIndex(threshold=.9).to(device)
iou_score95 = BinaryJaccardIndex(threshold=.95).to(device)
test_tree_iou05,test_tree_iou1,test_tree_iou15,test_tree_iou2,test_tree_iou25,test_tree_iou3,test_tree_iou35 = 0,0,0,0,0,0,0
test_tree_iou4,test_tree_iou45,test_tree_iou5,test_tree_iou55,test_tree_iou6,test_tree_iou65,test_tree_iou7 = 0,0,0,0,0,0,0
test_tree_iou75,test_tree_iou8,test_tree_iou85,test_tree_iou9 ,test_tree_iou95 = 0,0,0,0,0


unet_1.eval()
with torch.inference_mode():
    for i_batch, sample_batched in enumerate(my_dataloader_test):
        X,Y= sample_batched
        X= X.to(device)
        
        Y[0]=Y[0].to(device) # probably a better way to do this....
        Y[1]=Y[1].to(device)
        
        X = Variable(X.float().cuda())
        Y[0] = Variable(Y[0].float().cuda())
        Y[1] = Variable(Y[1].float().cuda())
        
        # Forward pass
        pred_tree_mask = unet_1(X)
        
        #Tree Mask IOU
        test_tree_iou05 += iou_score05(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou1  += iou_score1(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou15 += iou_score15(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou2  += iou_score2(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou25 += iou_score25(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou3  += iou_score3(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou35 += iou_score35(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou4  += iou_score4(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou45 += iou_score45(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou5  += iou_score5(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou55 += iou_score55(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou6  += iou_score6(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou65 += iou_score65(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou7  += iou_score7(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou75 += iou_score75(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou8  += iou_score8(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou85 += iou_score85(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou9  += iou_score9(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou95 += iou_score95(pred_tree_mask[0].squeeze(),Y[1].type(torch.LongTensor).to(device))
        
    
# get iou1 per batch
test_tree_iou05 = test_tree_iou05/(len(my_dataset_test)/batch_size)
test_tree_iou1  = test_tree_iou1/(len(my_dataset_test)/batch_size)
test_tree_iou15 = test_tree_iou15/(len(my_dataset_test)/batch_size)
test_tree_iou2  = test_tree_iou2/(len(my_dataset_test)/batch_size)
test_tree_iou25 = test_tree_iou25/(len(my_dataset_test)/batch_size)
test_tree_iou3  = test_tree_iou3/(len(my_dataset_test)/batch_size)
test_tree_iou35 = test_tree_iou35/(len(my_dataset_test)/batch_size)
test_tree_iou4  = test_tree_iou4/(len(my_dataset_test)/batch_size)
test_tree_iou45 = test_tree_iou45/(len(my_dataset_test)/batch_size)
test_tree_iou5  = test_tree_iou5/(len(my_dataset_test)/batch_size)
test_tree_iou55 = test_tree_iou55/(len(my_dataset_test)/batch_size)
test_tree_iou6  = test_tree_iou6/(len(my_dataset_test)/batch_size)
test_tree_iou65 = test_tree_iou65/(len(my_dataset_test)/batch_size)
test_tree_iou7  = test_tree_iou7/(len(my_dataset_test)/batch_size)
test_tree_iou75 = test_tree_iou75/(len(my_dataset_test)/batch_size)
test_tree_iou8  = test_tree_iou8/(len(my_dataset_test)/batch_size)
test_tree_iou85 = test_tree_iou85/(len(my_dataset_test)/batch_size)
test_tree_iou9  = test_tree_iou9/(len(my_dataset_test)/batch_size)
test_tree_iou95 = test_tree_iou95/(len(my_dataset_test)/batch_size)

    
print("Tree IoU .05=",test_tree_iou05)
print("Tree IoU .1 =",test_tree_iou1 )
print("Tree IoU .15=",test_tree_iou15)
print("Tree IoU .2 =",test_tree_iou2 )
print("Tree IoU .25=",test_tree_iou25)
print("Tree IoU .3 =",test_tree_iou3 )
print("Tree IoU .35=",test_tree_iou35)
print("Tree IoU .4 =",test_tree_iou4 )
print("Tree IoU .45=",test_tree_iou45)
print("Tree IoU .5 =",test_tree_iou5 )
print("Tree IoU .55=",test_tree_iou55)
print("Tree IoU .6 =",test_tree_iou6 )
print("Tree IoU .65=",test_tree_iou65)
print("Tree IoU .7 =",test_tree_iou7 )
print("Tree IoU .75=",test_tree_iou75)
print("Tree IoU .8 =",test_tree_iou8 )
print("Tree IoU .85=",test_tree_iou85)
print("Tree IoU .9 =",test_tree_iou9 )
print("Tree IoU .95=",test_tree_iou95)

# Multi-task Manual Loss - All Shared

In [None]:
# DEFINE UNET MODEL 
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def defineUNetModel_fullyshared():
    """Build the multi-task U-Net with a fully shared encoder and decoder.

    Returns:
        nn.Module: a network whose ``forward`` takes a 14-channel image batch and returns
        ``[out_tree_height, out_tree_mask]``: a linear 1-channel height map and a
        sigmoid-activated 1-channel mask map.
    """
    def _conv_pair(in_channels, out_channels, dropout, normalize=False):
        # Two 3x3 "same" convolutions, each followed by LeakyReLU (preceded by InstanceNorm
        # when `normalize` is set), with dropout in between. The layer order determines the
        # Sequential indices and therefore the parameter names stored in checkpoints.
        layers = [nn.Conv2d(in_channels, out_channels, 3, padding="same")]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))
        layers.append(nn.Dropout(p=dropout))
        layers.append(nn.Conv2d(out_channels, out_channels, 3, padding="same"))
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))
        return nn.Sequential(*layers)

    class UNet(nn.Module):
        def __init__(self):
            super().__init__()

            # Encoder: only the first stage is instance-normalised.
            self.dconv_down1 = _conv_pair(14, 32, dropout=0.1, normalize=True)
            self.dconv_down2 = _conv_pair(32, 64, dropout=0.1)
            self.dconv_down3 = _conv_pair(64, 128, dropout=0.1)
            self.dconv_down4 = _conv_pair(128, 256, dropout=0.2)

            # Down-sampling by 2, 2 and finally 3.
            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)

            # Transposed convolutions undo the pooling (x3 first, then x2 twice).
            self.upsample1 = nn.ConvTranspose2d(256, 128, 2, stride=3, padding=0, output_padding=1)
            self.upsample2 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample3 = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)

            # Decoder stages consume the up-sampled features concatenated with the skip.
            self.dconv_up3 = _conv_pair(128 + 128, 128, dropout=0.2)
            self.dconv_up2 = _conv_pair(64 + 64, 64, dropout=0.2)
            self.dconv_up1 = _conv_pair(32 + 32, 32, dropout=0.2)

            # Separate 1x1 heads so the outputs do not share weights. forward() uses
            # conv_last1 (height) and conv_last2 (mask); conv_last3 and `linear` are
            # registered but not used there.
            self.conv_last1 = nn.Conv2d(32, 1, 1)
            self.conv_last2 = nn.Conv2d(32, 1, 1)
            self.conv_last3 = nn.Conv2d(32, 1, 1)
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            # Contracting path, remembering each stage's output for the skip connections.
            skip1 = self.dconv_down1(x)
            skip2 = self.dconv_down2(self.maxpool(skip1))
            skip3 = self.dconv_down3(self.maxpool(skip2))
            features = self.dconv_down4(self.maxpool3(skip3))

            # Expanding path: up-sample, concatenate the matching skip, refine.
            features = self.dconv_up3(torch.cat([self.upsample1(features), skip3], dim=1))
            features = self.dconv_up2(torch.cat([self.upsample2(features), skip2], dim=1))
            features = self.dconv_up1(torch.cat([self.upsample3(features), skip1], dim=1))

            # Height head is left linear; mask head is squashed to a probability.
            out_tree_height = self.conv_last1(features)
            out_tree_mask = self.sigmoid(self.conv_last2(features))
            return [out_tree_height, out_tree_mask]

    return UNet()

# Seed the RNG before building the model so the initial weights are reproducible.
torch.manual_seed(42)
unet_1 = defineUNetModel_fullyshared().to(device)
# The trailing tuple is the cell's displayed value: the module repr followed by a torchinfo
# summary for an input of size (1, 14, 240, 240).
unet_1, summary(unet_1,input_size= (1,14,240,240))


In [None]:
# Losses, optimizer and metrics for the fully shared multi-task experiment.

# Loss modules: MSE for tree height, BCE for the tree mask
mse, bce_loss = nn.MSELoss(), nn.BCELoss()

# Adam over all parameters of the freshly built network
optimizer = optim.Adam(unet_1.parameters(), lr=1e-3)

# Metrics are moved to the same device as the model
mean_absolute_error = MeanAbsoluteError().to(device)
iou_score = BinaryJaccardIndex().to(device)

# Displayed as the cell output: size of the train / test file lists
len(filelist_train), len(filelist_test)

In [None]:
class SaveBestModel:
    """
    Save the model's state dict whenever the validation loss improves.

    Args:
        best_valid_loss: loss that has to be beaten before the first checkpoint is written.
        save_path: file the state dict is written to. Defaults to the fully shared
            multi-task checkpoint used by this section of the notebook.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='models/pytorch_paper_final/pytorch_mtloss_allshared_manual.pt'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model
    ):
        """Checkpoint `model` if `current_valid_loss` is lower than the best seen so far.

        Args:
            current_valid_loss: loss of the current epoch (Python number or 0-dim tensor).
            epoch: zero-based epoch index; only used for the log message.
            model: module whose ``state_dict()`` is saved.
        """
        # Store a plain float so the tracker never keeps a (possibly GPU) tensor alive.
        current_valid_loss = float(current_valid_loss)
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # torch.save does not create missing directories, so do it here.
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), self.save_path)

save_best_model = SaveBestModel()

In [None]:
# Run a Train/Test Loop now
# Build the training Loop (and a testing loop)
torch.manual_seed(42)
epochs = 50
batch_size = 16

# CCAI loss weighting to replicate findings: total = 0.6 * height MSE + 0.5 * mask BCE
HEIGHT_LOSS_WEIGHT = .6
MASK_LOSS_WEIGHT = .5

# Instantiate datasets/loaders
my_dataset_train = allbands_dataset_train(filelist_train=filelist_train)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)

my_dataloader_train = DataLoader(my_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# track individual losses (training loss of every 200th batch)
height_loss= []
tmask_loss= []

# track individual metrics (one value per epoch)
height_mae= []
tmask_iou= []

# gamma ~= exp(-0.1): the learning rate shrinks by about 10% per scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90483741803)
# The early stopper is fed the per-epoch test loss (see the end of the epoch loop).
early_stopper = EarlyStopper(patience=7, min_delta=0)

train_time_start_on_cpu = timer()
# Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} out of {epochs}")
    train_loss = 0

    # Switch to training mode once per epoch (the evaluation pass below switches it off).
    unet_1.train()
    for i_batch, (X, Y) in enumerate(my_dataloader_train):
        optimizer.zero_grad()

        # Y[0] is the tree-height target, Y[1] the tree-mask target. Everything follows
        # `device` so the CPU fallback works as well.
        X = X.float().to(device)
        target_height = Y[0].float().to(device)
        target_mask = Y[1].float().to(device)

        # Forward Pass
        pred_tree_height, pred_tree_mask = unet_1(X)

        # Weighted sum of the two task losses
        loss_1= mse(pred_tree_height.squeeze(), target_height)
        loss_2= bce_loss(pred_tree_mask.squeeze(), target_mask)
        loss = (loss_1*HEIGHT_LOSS_WEIGHT) + (loss_2*MASK_LOSS_WEIGHT)

        train_loss += loss.item() # accumulate train loss

        # Backpropagation and parameter update
        loss.backward()
        optimizer.step()

        if i_batch % 200 == 0:
            print(f"Batch {i_batch+1} out of {len(my_dataloader_train)} completed.", loss.item(),loss_1.item(),loss_2.item())
            height_loss.append(loss_1.item())
            tmask_loss.append(loss_2.item())

    # Mean over the batches actually processed. len(dataset)/batch_size is fractional
    # when the last batch is partial and would inflate the mean.
    train_loss /= len(my_dataloader_train)

    ### Testing
    test_loss, test_height_mae, test_tree_iou = 0,0,0

    unet_1.eval()
    with torch.inference_mode():
        for X, Y in my_dataloader_test:
            X = X.float().to(device)
            target_height = Y[0].float().to(device)
            target_mask = Y[1].float().to(device)

            # Forward pass
            pred_tree_height, pred_tree_mask = unet_1(X)

            # Same weighted loss as in training
            loss_1= mse(pred_tree_height.squeeze(), target_height)
            loss_2= bce_loss(pred_tree_mask.squeeze(), target_mask)
            loss = (loss_1*HEIGHT_LOSS_WEIGHT) + (loss_2*MASK_LOSS_WEIGHT)

            test_loss += loss.item() # accumulate test loss

            #Tree Height MAE
            test_height_mae += mean_absolute_error(torch.squeeze(pred_tree_height),target_height)

            #Tree Mask IOU; .long() casts on the device, avoiding a CPU round trip
            test_tree_iou += iou_score(torch.squeeze(pred_tree_mask),target_mask.long())

        # Mean loss / MAE / IoU over the test batches
        n_test_batches = len(my_dataloader_test)
        test_loss /= n_test_batches
        test_height_mae /= n_test_batches
        test_tree_iou /= n_test_batches

        # save metrics
        height_mae.append(test_height_mae.item())
        tmask_iou.append(test_tree_iou.item())

    lr=optimizer.param_groups[0]["lr"]
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Tree Height MAE: {test_height_mae:.5f} | Test Tree Mask IOU: {test_tree_iou:.5f} | Learning Rate: {lr:.10f}")
    save_best_model(test_loss, epoch, unet_1)
    # The learning rate is held for epoch indices 0 and 1 and decayed after every later epoch.
    if epoch>1:
        scheduler.step()
    if early_stopper.early_stop(test_loss):
        break


train_time_end_on_cpu = timer()
total_train_time_on_cpu= print_train_time(start=train_time_start_on_cpu,
                                          end=train_time_end_on_cpu,
                                          device=str(next(unet_1.parameters()).device))

In [None]:
# Load the best multi-task (fully shared) checkpoint written during training into a fresh model.
# map_location remaps the stored tensors onto the active `device`, so a checkpoint
# saved from the GPU can also be restored when only the CPU is available.
unet_1 = defineUNetModel_fullyshared().to(device)
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_mtloss_allshared_manual.pt", map_location=device)
)

In [None]:
%%time
# Calc MAE
mean_absolute_error = MeanAbsoluteError(nan_strategy='ignore').to(device)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=16,shuffle=False, num_workers=0)
test_height_mae = []

batch_size=16
unet_1.eval()
with torch.inference_mode():
    for i_batch, sample_batched in enumerate(my_dataloader_test):
        X,Y= sample_batched
        X= X.to(device)
        
        Y[0]=Y[0].to(device) # probably a better way to do this....
        Y[1]=Y[1].to(device)
        
        
        X = Variable(X.float().cuda())
        Y[0] = Variable(Y[0].float().cuda())
        Y[1] = Variable(Y[1].float().cuda())
        
        
        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)
        
        
        # mask predicted tree height with the tree mask
        pred_tree_mask = custom_replace(pred_tree_mask, .4)
        pred_tree_height[pred_tree_height  < 0 ] = 0
        
        pred_tree_height = torch.squeeze(pred_tree_height)*torch.squeeze(pred_tree_mask) #0s get rid of non tree pixels

        
        # now i only want to compare Y[0]
        actual_tree_height= Y[0]*torch.squeeze(pred_tree_mask)

        #Height MAE
        test_height_mae.append(mean_absolute_error(pred_tree_height,actual_tree_height).item())
        

# with torch.inference_mode():
#     # Get Average Height MAE
#     test_height_mae /= (len(my_dataset_test)/batch_size)
print("Tree Height MAE=",max(test_height_mae),min(test_height_mae),len(test_height_mae),sum(test_height_mae) / len(test_height_mae))

In [None]:
%%time
# Test various thresholds for IoU
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=16,shuffle=False, num_workers=0)

iou_score05 = BinaryJaccardIndex(threshold=.05).to(device)
iou_score1 = BinaryJaccardIndex(threshold=.1).to(device)
iou_score15 = BinaryJaccardIndex(threshold=.15).to(device)
iou_score2 = BinaryJaccardIndex(threshold=.2).to(device)
iou_score25 = BinaryJaccardIndex(threshold=.25).to(device)
iou_score3 = BinaryJaccardIndex(threshold=.3).to(device)
iou_score35 = BinaryJaccardIndex(threshold=.35).to(device)
iou_score4 = BinaryJaccardIndex(threshold=.4).to(device)
iou_score45 = BinaryJaccardIndex(threshold=.45).to(device)
iou_score5 = BinaryJaccardIndex(threshold=.5).to(device)
iou_score55 = BinaryJaccardIndex(threshold=.55).to(device)
iou_score6 = BinaryJaccardIndex(threshold=.6).to(device)
iou_score65 = BinaryJaccardIndex(threshold=.65).to(device)
iou_score7 = BinaryJaccardIndex(threshold=.7).to(device)
iou_score75 = BinaryJaccardIndex(threshold=.75).to(device)
iou_score8 = BinaryJaccardIndex(threshold=.8).to(device)
iou_score85 = BinaryJaccardIndex(threshold=.85).to(device)
iou_score9 = BinaryJaccardIndex(threshold=.9).to(device)
iou_score95 = BinaryJaccardIndex(threshold=.95).to(device)
test_tree_iou05,test_tree_iou1,test_tree_iou15,test_tree_iou2,test_tree_iou25,test_tree_iou3,test_tree_iou35 = 0,0,0,0,0,0,0
test_tree_iou4,test_tree_iou45,test_tree_iou5,test_tree_iou55,test_tree_iou6,test_tree_iou65,test_tree_iou7 = 0,0,0,0,0,0,0
test_tree_iou75,test_tree_iou8,test_tree_iou85,test_tree_iou9 ,test_tree_iou95 = 0,0,0,0,0


unet_1.eval()
with torch.inference_mode():
    for i_batch, sample_batched in enumerate(my_dataloader_test):
        X,Y= sample_batched
        X= X.to(device)
        
        Y[0]=Y[0].to(device) # probably a better way to do this....
        Y[1]=Y[1].to(device)
        
        X = Variable(X.float().cuda())
        Y[0] = Variable(Y[0].float().cuda())
        Y[1] = Variable(Y[1].float().cuda())
        
        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)
        
        #Tree Mask IOU
        test_tree_iou05 += iou_score05(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou1  += iou_score1(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou15 += iou_score15(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou2  += iou_score2(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou25 += iou_score25(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou3  += iou_score3(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou35 += iou_score35(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou4  += iou_score4(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou45 += iou_score45(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou5  += iou_score5(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou55 += iou_score55(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou6  += iou_score6(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou65 += iou_score65(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou7  += iou_score7(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou75 += iou_score75(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou8  += iou_score8(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou85 += iou_score85(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou9  += iou_score9(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou95 += iou_score95(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        
    
# get iou1 per batch
test_tree_iou05 = test_tree_iou05/(len(my_dataset_test)/batch_size)
test_tree_iou1  = test_tree_iou1/(len(my_dataset_test)/batch_size)
test_tree_iou15 = test_tree_iou15/(len(my_dataset_test)/batch_size)
test_tree_iou2  = test_tree_iou2/(len(my_dataset_test)/batch_size)
test_tree_iou25 = test_tree_iou25/(len(my_dataset_test)/batch_size)
test_tree_iou3  = test_tree_iou3/(len(my_dataset_test)/batch_size)
test_tree_iou35 = test_tree_iou35/(len(my_dataset_test)/batch_size)
test_tree_iou4  = test_tree_iou4/(len(my_dataset_test)/batch_size)
test_tree_iou45 = test_tree_iou45/(len(my_dataset_test)/batch_size)
test_tree_iou5  = test_tree_iou5/(len(my_dataset_test)/batch_size)
test_tree_iou55 = test_tree_iou55/(len(my_dataset_test)/batch_size)
test_tree_iou6  = test_tree_iou6/(len(my_dataset_test)/batch_size)
test_tree_iou65 = test_tree_iou65/(len(my_dataset_test)/batch_size)
test_tree_iou7  = test_tree_iou7/(len(my_dataset_test)/batch_size)
test_tree_iou75 = test_tree_iou75/(len(my_dataset_test)/batch_size)
test_tree_iou8  = test_tree_iou8/(len(my_dataset_test)/batch_size)
test_tree_iou85 = test_tree_iou85/(len(my_dataset_test)/batch_size)
test_tree_iou9  = test_tree_iou9/(len(my_dataset_test)/batch_size)
test_tree_iou95 = test_tree_iou95/(len(my_dataset_test)/batch_size)

    
print("Tree IoU .05=",test_tree_iou05)
print("Tree IoU .1 =",test_tree_iou1 )
print("Tree IoU .15=",test_tree_iou15)
print("Tree IoU .2 =",test_tree_iou2 )
print("Tree IoU .25=",test_tree_iou25)
print("Tree IoU .3 =",test_tree_iou3 )
print("Tree IoU .35=",test_tree_iou35)
print("Tree IoU .4 =",test_tree_iou4 )
print("Tree IoU .45=",test_tree_iou45)
print("Tree IoU .5 =",test_tree_iou5 )
print("Tree IoU .55=",test_tree_iou55)
print("Tree IoU .6 =",test_tree_iou6 )
print("Tree IoU .65=",test_tree_iou65)
print("Tree IoU .7 =",test_tree_iou7 )
print("Tree IoU .75=",test_tree_iou75)
print("Tree IoU .8 =",test_tree_iou8 )
print("Tree IoU .85=",test_tree_iou85)
print("Tree IoU .9 =",test_tree_iou9 )
print("Tree IoU .95=",test_tree_iou95)

In [None]:
# Plot all the things
# x ranges follow the length of each recorded history instead of hard-coded counts, so the
# cell also works when training stopped early.

# height_loss / tmask_loss hold the training loss of every 200th batch.
plt.subplot(1, 2, 1)
plt.plot(range(1, len(height_loss) + 1), height_loss, "rx-", label="tree height loss")
plt.xlabel("logged step (every 200th training batch)")
plt.ylabel("loss")
plt.legend(loc="upper right")

plt.subplot(1, 2, 2)
plt.plot(range(1, len(tmask_loss) + 1), tmask_loss, "bx-", label="tree mask loss")
plt.xlabel("logged step (every 200th training batch)")
plt.ylabel("loss")
plt.legend(loc="upper right")

plt.tight_layout()
plt.show()

# height_mae / tmask_iou hold one value per completed epoch.
plt.subplot(1, 2, 1)
plt.plot(range(1, len(height_mae) + 1), height_mae, "m--", label="Tree height MAE")
plt.xlabel("epoch")
plt.legend(loc="upper right")

plt.subplot(1, 2, 2)
plt.plot(range(1, len(tmask_iou) + 1), tmask_iou, "m--", label="Tree Mask IoU")
plt.xlabel("epoch")
plt.legend(loc="upper right")

plt.show()


# Multi-task Manual Loss - Not Shared Decoding

In [None]:
# Set up a version with a shared encoding path but separate decoding paths
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def defineUNetModel_partiallyshared():
    def double_conv0(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True)
        )  
    def double_conv(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(inplace=True)
        )
    
    class UNet(nn.Module):
        def __init__(self):
            super().__init__()

            self.dconv_down1 = double_conv0(14, 32)
            self.dconv_down2 = double_conv(32, 64)
            self.dconv_down3 = double_conv(64, 128)
            self.dconv_down4 = nn.Sequential(
                nn.Conv2d(128, 256, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Dropout(p=0.2),
                nn.Conv2d(256, 256, 3, padding=1),
                nn.LeakyReLU(inplace=True))

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)
            
            self.upsample1a = nn.ConvTranspose2d(256, 128, 3, stride=3, padding=0)
            self.upsample1b = nn.ConvTranspose2d(256, 128, 3, stride=3, padding=0)
            self.upsample1c = nn.ConvTranspose2d(256, 128, 3, stride=3, padding=0)
            
            self.upsample2a = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample2b = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample2c = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            
            self.upsample3a = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            self.upsample3b = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            self.upsample3c = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            
            
            self.dconv_up3a = double_conv(128+128, 128)
            self.dconv_up3b = double_conv(128+128, 128)
            self.dconv_up3c = double_conv(128+128, 128)
            
            self.dconv_up2a = double_conv(64 + 64, 64)
            self.dconv_up2b = double_conv(64 + 64, 64)
            self.dconv_up2c = double_conv(64 + 64, 64)
            
            self.dconv_up1a = double_conv(32 + 32, 32)
            self.dconv_up1b = double_conv(32 + 32, 32)
            self.dconv_up1c = double_conv(32 + 32, 32)
           # self.dconv_up1 = double_conv(128 + 64, 64)

            self.conv_lasta = nn.Conv2d(32, 1, 1)
            self.conv_lastb = nn.Conv2d(32, 1, 1)
            self.conv_lastc = nn.Conv2d(32, 1, 1)
                    
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()
        def forward(self, x):
            conv1 = self.dconv_down1(x)
            x = self.maxpool(conv1)

            conv2 = self.dconv_down2(x)
            x = self.maxpool(conv2)

            conv3 = self.dconv_down3(x)
            x = self.maxpool3(conv3)

            encoder_end = self.dconv_down4(x)
            #x = self.maxpool(conv4)
            
            # Now the model should split into three branches

            # Tree Height
            x1 = self.upsample1a(encoder_end)
            x1 = torch.cat([x1, conv3], dim=1)
            x1 = self.dconv_up3a(x1)

            x1 = self.upsample2a(x1)
            x1 = torch.cat([x1, conv2], dim=1)
            x1 = self.dconv_up2a(x1)

            x1 = self.upsample3a(x1)
            x1 = torch.cat([x1, conv1], dim=1)
            x1 = self.dconv_up1a(x1)
            out_tree_height = self.conv_lasta(x1) # looks like i don't need any additional activation here for linear

            # Tree Mask
            x2 = self.upsample1b(encoder_end)
            x2 = torch.cat([x2, conv3], dim=1)
            x2 = self.dconv_up3b(x2)

            x2 = self.upsample2b(x2)
            x2 = torch.cat([x2, conv2], dim=1)
            x2 = self.dconv_up2b(x2)

            x2 = self.upsample3b(x2)
            x2 = torch.cat([x2, conv1], dim=1)
            x2 = self.dconv_up1b(x2)
            out_tree_mask = self.sigmoid(self.conv_lastb(x2))            
     
            
            return [out_tree_height, out_tree_mask]
    model=UNet()
    return model

torch.manual_seed(42)
unet_1 = defineUNetModel_partiallyshared().to(device)
unet_1, summary(unet_1,input_size= (1,14,240,240))

In [None]:
# Losses, metrics and optimizer for the manually weighted multi-task run

# Task criteria: BCE for the tree-mask head, MSE for the tree-height head
bce_loss = nn.BCELoss()
mse = nn.MSELoss()

# Evaluation metrics, kept on the same device as the model outputs
iou_score = BinaryJaccardIndex().to(device)
mean_absolute_error = MeanAbsoluteError().to(device)

# Adam over all network parameters
optimizer = optim.Adam(unet_1.parameters(), lr=.001)

# Display the train/test split sizes as the cell output
(len(filelist_train), len(filelist_test))

In [None]:
class SaveBestModel:
    """
    Save the model weights whenever the validation loss reaches a new minimum.

    Args:
        best_valid_loss: initial "best" loss; the first call with a lower
            value triggers a save.
        save_path: file the model's ``state_dict`` is written to. Defaults to
            the checkpoint used by the manually weighted, partially shared model.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='models/pytorch_paper_final/pytorch_mtloss_partshared_manual.pt'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path
        
    def __call__(
        self, current_valid_loss, 
        epoch, model
    ):
        """Save ``model.state_dict()`` if ``current_valid_loss`` beats the best so far.

        ``epoch`` is zero-based and only used for the log message.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # Create the target directory on demand so a missing folder does not
            # abort training after the first epoch has already been computed.
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), self.save_path)

save_best_model = SaveBestModel()

In [None]:
# Train / evaluate the partially shared U-Net with fixed (manual) task-loss weights.
torch.manual_seed(42)
epochs = 75
batch_size = 16

# Instantiate datasets/loaders
my_dataset_train = allbands_dataset_train(filelist_train=filelist_train)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)

my_dataloader_train = DataLoader(my_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# Per-task training losses, logged every 200 batches
height_loss = []
tmask_loss = []

# Per-epoch test metrics
height_mae = []
tmask_iou = []

# Every scheduler step multiplies the learning rate by gamma (~exp(-0.1)).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90483741803)
# Early stopping is driven by the test loss, with a patience of 8 epochs.
early_stopper = EarlyStopper(patience=8, min_delta=0)

train_time_start_on_cpu = timer()
# Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} out of {epochs}")
    train_loss = 0

    # Switch back to training mode (the evaluation pass below sets eval mode).
    unet_1.train()

    # Loop through training batch data
    for i_batch, sample_batched in enumerate(my_dataloader_train):
        optimizer.zero_grad()

        # Move the batch to the model's device; `.to(device)` also works on
        # CPU-only machines, unlike a hard-coded `.cuda()`.
        X, Y = sample_batched
        X = X.float().to(device)
        Y[0] = Y[0].float().to(device)  # tree height target
        Y[1] = Y[1].float().to(device)  # tree mask target

        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)

        # Fixed-weight multi-task loss (CCAI loss weighting to replicate findings).
        # squeeze(1) drops only the channel dimension, so a final batch holding
        # a single sample keeps its batch dimension.
        loss_1 = mse(pred_tree_height.squeeze(1), Y[0])
        loss_2 = bce_loss(pred_tree_mask.squeeze(1), Y[1])
        loss = (loss_1*.6) + (loss_2*.5)

        train_loss += loss.item()  # accumulate train loss

        # Backpropagation and parameter update
        loss.backward()
        optimizer.step()

        if i_batch % 200 == 0:
            print(f"Batch {i_batch+1} out of {len(my_dataloader_train)} completed.", loss.item(),loss_1.item(),loss_2.item())
            height_loss.append(loss_1.item())
            tmask_loss.append(loss_2.item())

    # Mean loss per batch: divide by the real number of batches, which includes
    # the final partial batch (len(dataset) / batch_size is fractional).
    train_loss /= len(my_dataloader_train)

    ### Testing
    test_loss, test_height_mae, test_tree_iou = 0, 0, 0

    unet_1.eval()
    with torch.inference_mode():
        for i_batch, sample_batched in enumerate(my_dataloader_test):
            X, Y = sample_batched
            X = X.float().to(device)
            Y[0] = Y[0].float().to(device)
            Y[1] = Y[1].float().to(device)

            # Forward pass
            pred_tree_height, pred_tree_mask = unet_1(X)

            # Same fixed-weight loss as in training
            loss_1 = mse(pred_tree_height.squeeze(1), Y[0])
            loss_2 = bce_loss(pred_tree_mask.squeeze(1), Y[1])
            loss = (loss_1*.6) + (loss_2*.5)

            test_loss += loss.item()  # accumulate test loss

            # Tree Height MAE
            test_height_mae += mean_absolute_error(pred_tree_height.squeeze(1), Y[0])

            # Tree Mask IoU (integer labels, converted directly on the device)
            test_tree_iou += iou_score(pred_tree_mask.squeeze(1), Y[1].long())

        # Per-batch averages over the test loader
        num_test_batches = len(my_dataloader_test)
        test_loss /= num_test_batches
        test_height_mae /= num_test_batches
        test_tree_iou /= num_test_batches

        # save metrics
        height_mae.append(test_height_mae.item())
        tmask_iou.append(test_tree_iou.item())

    lr = optimizer.param_groups[0]["lr"]
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Tree Height MAE: {test_height_mae:.5f} | Test Tree Mask IOU: {test_tree_iou:.5f} | Learning Rate: {lr:.10f}")
    save_best_model(test_loss, epoch, unet_1)
    if epoch > 1:
        scheduler.step()  # decay the learning rate once per epoch, starting with the third epoch
    if early_stopper.early_stop(test_loss):
        break


train_time_end_on_cpu = timer()
total_train_time_on_cpu = print_train_time(start=train_time_start_on_cpu,
                                           end=train_time_end_on_cpu,
                                           device=str(next(unet_1.parameters()).device))

In [None]:
# Load in the best model.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
unet_1 = defineUNetModel_partiallyshared().to(device)
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_mtloss_partshared_manual.pt",
               map_location=device)
)

In [None]:
%%time
# Sweep the binarisation threshold of the tree-mask output and report the
# mean per-batch IoU on the test set for every threshold.
batch_size = 16
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# Thresholds 0.05, 0.10, ..., 0.95, with one metric object per threshold
iou_thresholds = [k / 100 for k in range(5, 100, 5)]
iou_metrics = [BinaryJaccardIndex(threshold=t).to(device) for t in iou_thresholds]
iou_sums = [0] * len(iou_thresholds)

unet_1.eval()
with torch.inference_mode():
    for i_batch, sample_batched in enumerate(my_dataloader_test):
        # `.to(device)` keeps the cell usable on CPU-only machines
        X, Y = sample_batched
        X = X.float().to(device)
        Y[0] = Y[0].float().to(device)
        Y[1] = Y[1].float().to(device)

        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)

        # squeeze(1) removes only the channel dimension, so a final batch with
        # one sample keeps its batch dimension. Labels are converted to integers
        # once, directly on the device.
        mask_prob = pred_tree_mask.squeeze(1)
        mask_true = Y[1].long()

        # Tree Mask IoU at every threshold
        for idx, metric in enumerate(iou_metrics):
            iou_sums[idx] = iou_sums[idx] + metric(mask_prob, mask_true)

# Mean IoU per batch: divide by the real number of batches, which includes the
# final partial batch (len(dataset) / batch_size is fractional).
num_test_batches = len(my_dataloader_test)
mean_ious = [total / num_test_batches for total in iou_sums]

# Keep the individual result names available for later cells.
(test_tree_iou05, test_tree_iou1, test_tree_iou15, test_tree_iou2, test_tree_iou25,
 test_tree_iou3, test_tree_iou35, test_tree_iou4, test_tree_iou45, test_tree_iou5,
 test_tree_iou55, test_tree_iou6, test_tree_iou65, test_tree_iou7, test_tree_iou75,
 test_tree_iou8, test_tree_iou85, test_tree_iou9, test_tree_iou95) = mean_ious

for threshold, mean_iou in zip(iou_thresholds, mean_ious):
    # ".05", ".1 ", ".15", ... : threshold without the leading zero, padded to 3 chars
    label = f"{threshold:g}".lstrip("0").ljust(3)
    print(f"Tree IoU {label}=", mean_iou)

In [None]:
# Plot Losses/Metrics
#
# The x-axes are derived from the lengths of the logged series instead of
# being hard-coded, so the cell still runs when early stopping ends training
# at a different epoch. With 148 logged losses and 49 epochs the axes are
# exactly the former range(1, 30001, 203) and range(1, 50).
loss_log_stride = 203  # nominal iteration spacing between two logged loss values
loss_iterations = [1 + loss_log_stride * k for k in range(len(height_loss))]

plt.subplot(1, 2, 1)
plt.plot(loss_iterations, height_loss, "rx-", label="tree height loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.tight_layout()
plt.legend(loc="upper right")

plt.subplot(1, 2, 2)
plt.plot([1 + loss_log_stride * k for k in range(len(tmask_loss))],
         tmask_loss, "bx-", label="tree mask loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend(loc="upper right")

plt.show()

# One metric value is stored per completed epoch
plt.plot(range(1, len(height_mae) + 1), height_mae, "m--", label="Tree height MAE")
plt.xlabel("epoch")
plt.legend(loc="upper right")

plt.show()

plt.plot(range(1, len(tmask_iou) + 1), tmask_iou, "m--", label="Tree Mask IoU")
plt.xlabel("epoch")
plt.legend(loc="upper left")

plt.show()

# Element-wise sum of the two logged task losses (unweighted)
tota_loss = np.add(np.array(height_loss), np.array(tmask_loss))

plt.plot(loss_iterations, tota_loss, "m--", label="total loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend(loc="upper right")
plt.show()



# Multi-task Auto Loss - Not Shared Decoding

In [None]:
# Set up a version with a shared encoder path but separate decoder paths
# NOTE(review): this re-defines defineUNetModel_partiallyshared, which is already
# defined with the same layers in the "Manual Loss" section above; the later
# definition silently shadows the earlier one.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def defineUNetModel_partiallyshared():
    """Build a U-Net with one shared encoder and one decoder branch per task.

    Returns:
        A freshly initialised ``UNet`` module (not yet moved to any device).
    """
    def double_conv0(in_channels, out_channels):
        """Input block: two 3x3 "same" convs, each followed by InstanceNorm2d and
        LeakyReLU, with Dropout(0.1) between them."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True)
        )  
    def double_conv(in_channels, out_channels):
        """Standard block: two 3x3 convs (padding=1) with LeakyReLU, Dropout(0.1)
        between them and no normalisation."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(inplace=True)
        )
    
    class UNet(nn.Module):
        """U-Net taking 14 input channels and producing two single-channel maps.

        The encoder downsamples by 2, 2 and 3, so the spatial size should be
        divisible by 12 for the upsampled maps to line up with the skip tensors
        (240x240 is used in the summary call below).

        NOTE(review): the ``*c`` layers and ``self.linear`` are created but never
        used in ``forward``. They are still registered submodules, so the ``*c``
        weights appear in ``state_dict()``; removing them would change the keys of
        saved checkpoints and the random initialisation of later layers.
        """
        def __init__(self):
            super().__init__()

            # Shared encoder
            self.dconv_down1 = double_conv0(14, 32)
            self.dconv_down2 = double_conv(32, 64)
            self.dconv_down3 = double_conv(64, 128)
            # Bottleneck: same layout as double_conv but with a higher dropout rate (0.2)
            self.dconv_down4 = nn.Sequential(
                nn.Conv2d(128, 256, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Dropout(p=0.2),
                nn.Conv2d(256, 256, 3, padding=1),
                nn.LeakyReLU(inplace=True))

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)
            
            # Per-branch decoders: suffix a = tree height, b = tree mask, c = unused in forward.
            # The first upsampling uses stride 3 to undo maxpool3.
            self.upsample1a = nn.ConvTranspose2d(256, 128, 3, stride=3, padding=0)
            self.upsample1b = nn.ConvTranspose2d(256, 128, 3, stride=3, padding=0)
            self.upsample1c = nn.ConvTranspose2d(256, 128, 3, stride=3, padding=0)
            
            self.upsample2a = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample2b = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            self.upsample2c = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
            
            self.upsample3a = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            self.upsample3b = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            self.upsample3c = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
            
            
            # Input channels = upsampled channels + skip-connection channels
            self.dconv_up3a = double_conv(128+128, 128)
            self.dconv_up3b = double_conv(128+128, 128)
            self.dconv_up3c = double_conv(128+128, 128)
            
            self.dconv_up2a = double_conv(64 + 64, 64)
            self.dconv_up2b = double_conv(64 + 64, 64)
            self.dconv_up2c = double_conv(64 + 64, 64)
            
            self.dconv_up1a = double_conv(32 + 32, 32)
            self.dconv_up1b = double_conv(32 + 32, 32)
            self.dconv_up1c = double_conv(32 + 32, 32)
           # self.dconv_up1 = double_conv(128 + 64, 64)

            # 1x1 convs mapping 32 features to one output channel per task
            self.conv_lasta = nn.Conv2d(32, 1, 1)
            self.conv_lastb = nn.Conv2d(32, 1, 1)
            self.conv_lastc = nn.Conv2d(32, 1, 1)
                    
            # NOTE(review): named "linear" but is a MaxPool2d; not used in forward
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()
        def forward(self, x):
            """Run the shared encoder and both decoder branches.

            Returns:
                ``[out_tree_height, out_tree_mask]``: the raw 1-channel output of
                the height branch and the sigmoid-activated 1-channel output of
                the mask branch.
            """
            conv1 = self.dconv_down1(x)
            x = self.maxpool(conv1)

            conv2 = self.dconv_down2(x)
            x = self.maxpool(conv2)

            conv3 = self.dconv_down3(x)
            x = self.maxpool3(conv3)

            encoder_end = self.dconv_down4(x)
            
            # The model splits into two decoder branches here; both reuse the
            # encoder skip tensors conv1..conv3.

            # Tree Height
            x1 = self.upsample1a(encoder_end)
            x1 = torch.cat([x1, conv3], dim=1)
            x1 = self.dconv_up3a(x1)

            x1 = self.upsample2a(x1)
            x1 = torch.cat([x1, conv2], dim=1)
            x1 = self.dconv_up2a(x1)

            x1 = self.upsample3a(x1)
            x1 = torch.cat([x1, conv1], dim=1)
            x1 = self.dconv_up1a(x1)
            out_tree_height = self.conv_lasta(x1) # regression head: linear output, no activation

            # Tree Mask
            x2 = self.upsample1b(encoder_end)
            x2 = torch.cat([x2, conv3], dim=1)
            x2 = self.dconv_up3b(x2)

            x2 = self.upsample2b(x2)
            x2 = torch.cat([x2, conv2], dim=1)
            x2 = self.dconv_up2b(x2)

            x2 = self.upsample3b(x2)
            x2 = torch.cat([x2, conv1], dim=1)
            x2 = self.dconv_up1b(x2)
            out_tree_mask = self.sigmoid(self.conv_lastb(x2))                   
            
            return [out_tree_height, out_tree_mask]
    model=UNet()
    return model

# Seed before construction so the weight initialisation is reproducible
torch.manual_seed(42)
unet_1 = defineUNetModel_partiallyshared().to(device)
unet_1, summary(unet_1,input_size= (1,14,240,240))

In [None]:
#https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
# Learnable log-variances, one per task (a = tree height, b = tree mask).
# They are created directly on the target device so they stay leaf tensors:
# a later `log_var.to(device)` is then a no-op that returns the same tensor,
# and the values used in the loss are the ones the optimizer updates.
# (Creating them on the CPU and moving them afterwards yields detached-looking
# copies whose values never change during training.)
log_var_a = torch.zeros((1,), requires_grad=True, device=device)
log_var_b = torch.zeros((1,), requires_grad=True, device=device)
# Task criteria: BCE for the tree-mask head, MSE for the tree-height head
bce_loss = nn.BCELoss()
mse = nn.MSELoss()

# Uncertainty-weighted multi-task loss (own re-implementation)
def loss_criterion(y_pred, y_true, log_vars):
    """Combine the per-task losses using learnable log-variances.

    Task 0 is scored with MSE, every further task with BCE. Each task
    contributes ``exp(-log_var) * task_loss + log_var``.

    Args:
        y_pred: sequence of prediction tensors, one per task.
        y_true: sequence of target tensors, aligned with ``y_pred``.
        log_vars: sequence of log-variance tensors, aligned with ``y_pred``.

    Returns:
        Scalar tensor holding the combined loss.
    """
    total = 0
    for task in range(len(y_pred)):
        criterion = mse if task == 0 else bce_loss
        task_log_var = log_vars[task]
        weighted = torch.exp(-task_log_var) * criterion(y_pred[task], y_true[task])
        total = total + torch.sum(weighted + task_log_var, -1)
    return torch.mean(total)

# Optimise the network weights together with the two task log-variances
params_all = list(unet_1.parameters()) + [log_var_a, log_var_b]

# Optimizer: Adam over the combined parameter list
optimizer = optim.Adam(params_all, lr=.001)

# Evaluation metrics, kept on the same device as the model outputs
iou_score = BinaryJaccardIndex().to(device)
mean_absolute_error = MeanAbsoluteError().to(device)


class SaveBestModel:
    """
    Save the model weights whenever the validation loss reaches a new minimum.

    Args:
        best_valid_loss: initial "best" loss; the first call with a lower
            value triggers a save.
        save_path: file the model's ``state_dict`` is written to. Defaults to
            the checkpoint used by the automatically weighted, partially
            shared model.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='models/pytorch_paper_final/pytorch_mtloss_partshared.pt'
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path
        
    def __call__(
        self, current_valid_loss, 
        epoch, model
    ):
        """Save ``model.state_dict()`` if ``current_valid_loss`` beats the best so far.

        ``epoch`` is zero-based and only used for the log message.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # Create the target directory on demand so a missing folder does not
            # abort training after the first epoch has already been computed.
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), self.save_path)

save_best_model = SaveBestModel()

In [None]:
# Train / evaluate the partially shared U-Net with learned (uncertainty-based)
# task-loss weights.
torch.manual_seed(42)
epochs = 50
batch_size = 16

# Instantiate datasets/loaders
my_dataset_train = allbands_dataset_train(filelist_train=filelist_train)
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)

my_dataloader_train = DataLoader(my_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

# Per-task training losses, logged every 200 batches
height_loss = []
tmask_loss = []

# Per-epoch test metrics
height_mae = []
tmask_iou = []

# Every scheduler step multiplies the learning rate by gamma (~exp(-0.1)).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.90483741803)
# Early stopping is driven by the test loss, with a patience of 10 epochs.
early_stopper = EarlyStopper(patience=10, min_delta=0)

train_time_start_on_cpu = timer()
# Training Loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} out of {epochs}")
    train_loss = 0

    # Switch back to training mode (the evaluation pass below sets eval mode).
    unet_1.train()

    # Loop through training batch data
    for i_batch, sample_batched in enumerate(my_dataloader_train):
        optimizer.zero_grad()

        # Move the batch to the model's device; `.to(device)` also works on
        # CPU-only machines, unlike a hard-coded `.cuda()`.
        X, Y = sample_batched
        X = X.float().to(device)
        Y[0] = Y[0].float().to(device)  # tree height target
        Y[1] = Y[1].float().to(device)  # tree mask target

        # Read the log-variances from the tensors held by the optimizer on
        # every batch. Rebinding the names once to a `.to(device)` copy would
        # freeze the values used in the loss while the optimizer keeps
        # updating the originals.
        log_vars = [log_var_a.to(device), log_var_b.to(device)]

        # Forward pass; squeeze(1) drops only the channel dimension, so a final
        # batch holding a single sample keeps its batch dimension.
        pred_tree_height, pred_tree_mask = unet_1(X)
        pred_height_map = pred_tree_height.squeeze(1)
        pred_mask_map = pred_tree_mask.squeeze(1)

        # Calc loss (per batch)
        loss = loss_criterion([pred_height_map, pred_mask_map],
                              [Y[0], Y[1]],
                              log_vars)

        train_loss += loss.item()  # accumulate train loss

        # Backpropagation and parameter update
        loss.backward()
        optimizer.step()

        if i_batch % 200 == 0:
            # Per-task terms of the combined loss, for logging only
            with torch.no_grad():
                precision1 = torch.exp(-log_vars[0])
                diff1 = mse(pred_height_map, Y[0])
                th_loss = torch.sum(precision1 * diff1 + log_vars[0], -1)

                precision2 = torch.exp(-log_vars[1])
                diff2 = bce_loss(pred_mask_map, Y[1])
                tm_loss = torch.sum(precision2 * diff2 + log_vars[1], -1)

            print(f"Batch {i_batch+1} out of {len(my_dataloader_train)} completed.", loss.item(),th_loss.item(),tm_loss.item())
            height_loss.append(th_loss.item())
            tmask_loss.append(tm_loss.item())

    # Mean loss per batch: divide by the real number of batches, which includes
    # the final partial batch (len(dataset) / batch_size is fractional).
    train_loss /= len(my_dataloader_train)

    ### Testing
    test_loss, test_height_mae, test_tree_iou = 0, 0, 0

    unet_1.eval()
    with torch.inference_mode():
        log_vars = [log_var_a.to(device), log_var_b.to(device)]
        for i_batch, sample_batched in enumerate(my_dataloader_test):
            X, Y = sample_batched
            X = X.float().to(device)
            Y[0] = Y[0].float().to(device)
            Y[1] = Y[1].float().to(device)

            # Forward pass
            pred_tree_height, pred_tree_mask = unet_1(X)
            pred_height_map = pred_tree_height.squeeze(1)
            pred_mask_map = pred_tree_mask.squeeze(1)

            # Accumulate the test loss only; the last training-batch loss must
            # not be added here.
            test_loss += loss_criterion([pred_height_map, pred_mask_map],
                                        [Y[0], Y[1]],
                                        log_vars).item()

            # Tree Height MAE
            test_height_mae += mean_absolute_error(pred_height_map, Y[0])

            # Tree Mask IoU (integer labels, converted directly on the device)
            test_tree_iou += iou_score(pred_mask_map, Y[1].long())

        # Per-batch averages over the test loader
        num_test_batches = len(my_dataloader_test)
        test_loss /= num_test_batches
        test_height_mae /= num_test_batches
        test_tree_iou /= num_test_batches

        # save metrics
        height_mae.append(test_height_mae.item())
        tmask_iou.append(test_tree_iou.item())

    lr = optimizer.param_groups[0]["lr"]
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Tree Height MAE: {test_height_mae:.5f} | Test Tree Mask IOU: {test_tree_iou:.5f} | Learning Rate: {lr:.10f}")
    save_best_model(test_loss, epoch, unet_1)
    if epoch > 1:
        scheduler.step()  # decay the learning rate once per epoch, starting with the third epoch
    if early_stopper.early_stop(test_loss):
        break

train_time_end_on_cpu = timer()
total_train_time_on_cpu = print_train_time(start=train_time_start_on_cpu,
                                           end=train_time_end_on_cpu,
                                           device=str(next(unet_1.parameters()).device))

In [None]:
# Load in best model.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
unet_1 = defineUNetModel_partiallyshared().to(device)
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_mtloss_partshared.pt",
               map_location=device)
)

In [None]:
%%time
# Tree-height MAE on the test set, with predictions and targets both
# multiplied by the model's own thresholded tree mask.
mean_absolute_error = MeanAbsoluteError(nan_strategy='ignore').to(device)
batch_size = 16
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)
test_height_mae = []  # one MAE value per test batch

unet_1.eval()
with torch.inference_mode():
    for i_batch, sample_batched in enumerate(my_dataloader_test):
        # `.to(device)` keeps the cell usable on CPU-only machines
        X, Y = sample_batched
        X = X.float().to(device)
        Y[0] = Y[0].float().to(device)
        Y[1] = Y[1].float().to(device)

        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)

        # custom_replace is defined elsewhere in the notebook; presumably it
        # binarises the mask at the given threshold (0.4) - TODO confirm.
        pred_tree_mask = custom_replace(pred_tree_mask, .4)

        # Negative heights are set to zero
        pred_tree_height = torch.clamp(pred_tree_height, min=0)

        # Multiply by the predicted mask so non-tree pixels become 0 in both
        # the prediction and the reference height.
        pred_tree_height = torch.squeeze(pred_tree_height)*torch.squeeze(pred_tree_mask)
        actual_tree_height = Y[0]*torch.squeeze(pred_tree_mask)

        # Height MAE of this batch
        test_height_mae.append(mean_absolute_error(pred_tree_height,actual_tree_height).item())

# Report max, min, number of batches and the mean of the per-batch MAE values
print("Tree Height MAE=",max(test_height_mae),min(test_height_mae),len(test_height_mae),sum(test_height_mae) / len(test_height_mae))

In [None]:
%%time
# test IoU thresholds
my_dataset_test = allbands_dataset_test(filelist_test=filelist_test)
my_dataloader_test = DataLoader(my_dataset_test, batch_size=16,shuffle=False, num_workers=0)

iou_score05 = BinaryJaccardIndex(threshold=.05).to(device)
iou_score1 = BinaryJaccardIndex(threshold=.1).to(device)
iou_score15 = BinaryJaccardIndex(threshold=.15).to(device)
iou_score2 = BinaryJaccardIndex(threshold=.2).to(device)
iou_score25 = BinaryJaccardIndex(threshold=.25).to(device)
iou_score3 = BinaryJaccardIndex(threshold=.3).to(device)
iou_score35 = BinaryJaccardIndex(threshold=.35).to(device)
iou_score4 = BinaryJaccardIndex(threshold=.4).to(device)
iou_score45 = BinaryJaccardIndex(threshold=.45).to(device)
iou_score5 = BinaryJaccardIndex(threshold=.5).to(device)
iou_score55 = BinaryJaccardIndex(threshold=.55).to(device)
iou_score6 = BinaryJaccardIndex(threshold=.6).to(device)
iou_score65 = BinaryJaccardIndex(threshold=.65).to(device)
iou_score7 = BinaryJaccardIndex(threshold=.7).to(device)
iou_score75 = BinaryJaccardIndex(threshold=.75).to(device)
iou_score8 = BinaryJaccardIndex(threshold=.8).to(device)
iou_score85 = BinaryJaccardIndex(threshold=.85).to(device)
iou_score9 = BinaryJaccardIndex(threshold=.9).to(device)
iou_score95 = BinaryJaccardIndex(threshold=.95).to(device)
test_tree_iou05,test_tree_iou1,test_tree_iou15,test_tree_iou2,test_tree_iou25,test_tree_iou3,test_tree_iou35 = 0,0,0,0,0,0,0
test_tree_iou4,test_tree_iou45,test_tree_iou5,test_tree_iou55,test_tree_iou6,test_tree_iou65,test_tree_iou7 = 0,0,0,0,0,0,0
test_tree_iou75,test_tree_iou8,test_tree_iou85,test_tree_iou9 ,test_tree_iou95 = 0,0,0,0,0


unet_1.eval()
with torch.inference_mode():
    for i_batch, sample_batched in enumerate(my_dataloader_test):
        X,Y= sample_batched
        X= X.to(device)
        
        Y[0]=Y[0].to(device) # probably a better way to do this....
        Y[1]=Y[1].to(device)
        
        X = Variable(X.float().cuda())
        Y[0] = Variable(Y[0].float().cuda())
        Y[1] = Variable(Y[1].float().cuda())
        
        # Forward pass
        pred_tree_height, pred_tree_mask = unet_1(X)
        
        #Tree Mask IOU
        test_tree_iou05 += iou_score05(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou1  += iou_score1(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou15 += iou_score15(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou2  += iou_score2(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou25 += iou_score25(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou3  += iou_score3(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou35 += iou_score35(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou4  += iou_score4(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou45 += iou_score45(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou5  += iou_score5(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou55 += iou_score55(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou6  += iou_score6(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou65 += iou_score65(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou7  += iou_score7(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou75 += iou_score75(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou8  += iou_score8(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou85 += iou_score85(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou9  += iou_score9(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        test_tree_iou95 += iou_score95(torch.squeeze(pred_tree_mask),Y[1].type(torch.LongTensor).to(device))
        
    
# get iou1 per batch
test_tree_iou05 = test_tree_iou05/(len(my_dataset_test)/batch_size)
test_tree_iou1  = test_tree_iou1/(len(my_dataset_test)/batch_size)
test_tree_iou15 = test_tree_iou15/(len(my_dataset_test)/batch_size)
test_tree_iou2  = test_tree_iou2/(len(my_dataset_test)/batch_size)
test_tree_iou25 = test_tree_iou25/(len(my_dataset_test)/batch_size)
test_tree_iou3  = test_tree_iou3/(len(my_dataset_test)/batch_size)
test_tree_iou35 = test_tree_iou35/(len(my_dataset_test)/batch_size)
test_tree_iou4  = test_tree_iou4/(len(my_dataset_test)/batch_size)
test_tree_iou45 = test_tree_iou45/(len(my_dataset_test)/batch_size)
test_tree_iou5  = test_tree_iou5/(len(my_dataset_test)/batch_size)
test_tree_iou55 = test_tree_iou55/(len(my_dataset_test)/batch_size)
test_tree_iou6  = test_tree_iou6/(len(my_dataset_test)/batch_size)
test_tree_iou65 = test_tree_iou65/(len(my_dataset_test)/batch_size)
test_tree_iou7  = test_tree_iou7/(len(my_dataset_test)/batch_size)
test_tree_iou75 = test_tree_iou75/(len(my_dataset_test)/batch_size)
test_tree_iou8  = test_tree_iou8/(len(my_dataset_test)/batch_size)
test_tree_iou85 = test_tree_iou85/(len(my_dataset_test)/batch_size)
test_tree_iou9  = test_tree_iou9/(len(my_dataset_test)/batch_size)
test_tree_iou95 = test_tree_iou95/(len(my_dataset_test)/batch_size)

    
print("Tree IoU .05=",test_tree_iou05)
print("Tree IoU .1 =",test_tree_iou1 )
print("Tree IoU .15=",test_tree_iou15)
print("Tree IoU .2 =",test_tree_iou2 )
print("Tree IoU .25=",test_tree_iou25)
print("Tree IoU .3 =",test_tree_iou3 )
print("Tree IoU .35=",test_tree_iou35)
print("Tree IoU .4 =",test_tree_iou4 )
print("Tree IoU .45=",test_tree_iou45)
print("Tree IoU .5 =",test_tree_iou5 )
print("Tree IoU .55=",test_tree_iou55)
print("Tree IoU .6 =",test_tree_iou6 )
print("Tree IoU .65=",test_tree_iou65)
print("Tree IoU .7 =",test_tree_iou7 )
print("Tree IoU .75=",test_tree_iou75)
print("Tree IoU .8 =",test_tree_iou8 )
print("Tree IoU .85=",test_tree_iou85)
print("Tree IoU .9 =",test_tree_iou9 )
print("Tree IoU .95=",test_tree_iou95)

In [None]:
# Plot Losses/Metrics
# NOTE(review): both x axes are hard-coded and must have the same length as the series
# plotted against them; the 203-step spacing presumably matches how often the losses
# were logged - confirm.
iteration_axis = range(1, 30001, 203)
epoch_axis = range(1, 50)

# Left panel: tree height loss per logged iteration
height_ax = plt.subplot(1, 2, 1)
height_ax.plot(iteration_axis, height_loss, "rx-", label="tree height loss")
height_ax.set_xlabel("iteration")
height_ax.set_ylabel("loss")
plt.tight_layout()
height_ax.legend(loc="upper right")

# Right panel: tree mask loss per logged iteration
mask_ax = plt.subplot(1, 2, 2)
mask_ax.plot(iteration_axis, tmask_loss, "bx-", label="tree mask loss")
mask_ax.set_xlabel("iteration")
mask_ax.set_ylabel("loss")
mask_ax.legend(loc="upper right")

plt.show()

# One figure per epoch-level metric
for metric_values, metric_label, legend_loc in (
    (height_mae, "Tree height MAE", "upper right"),
    (tmask_iou, "Tree Mask IoU", "upper left"),
):
    plt.plot(epoch_axis, metric_values, "m--", label=metric_label)
    plt.xlabel("epoch")
    plt.legend(loc=legend_loc)
    plt.show()

# Element-wise sum of the two task losses
tota_loss = np.array(height_loss) + np.array(tmask_loss)

plt.plot(iteration_axis, tota_loss, "m--", label="total loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend(loc="upper right")
plt.show()



###############################################################################################################################

# Run Inference on 2021 data

In [None]:
# Load in the final chosen model

# Set up a version with a shared encoding path but different decoding paths
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def defineUNetModel_partiallyshared():
    """Build a U-Net with one shared encoder and one decoder branch per task.

    The encoder takes a 14-channel input and downsamples it by 2, 2 and 3, so
    spatial sizes that are multiples of 12 keep the skip connections aligned.
    Decoder branch "a" produces the tree-height map (no final activation) and
    branch "b" produces the tree mask (sigmoid). A third branch "c" and the
    ``linear`` pooling layer are registered but never used in ``forward``;
    branch "c" still contributes entries to ``state_dict()``.

    Returns:
        nn.Module whose ``forward(x)`` returns ``[out_tree_height, out_tree_mask]``,
        each with a single output channel.
    """
    # Architecture adapted from
    # https://github.com/usuyama/pytorch-unet/blob/master/pytorch_unet.py

    def double_conv0(in_channels, out_channels):
        # First encoder block: two 3x3 convolutions, each followed by instance norm.
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding="same"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def double_conv(in_channels, out_channels, dropout=0.1):
        # Plain block: conv -> LeakyReLU -> dropout -> conv -> LeakyReLU.
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    class UNet(nn.Module):
        def __init__(self):
            super().__init__()

            # Shared encoder
            self.dconv_down1 = double_conv0(14, 32)
            self.dconv_down2 = double_conv(32, 64)
            self.dconv_down3 = double_conv(64, 128)
            self.dconv_down4 = double_conv(128, 256, dropout=0.2)

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)

            # Submodules are registered in a fixed order under fixed names
            # (upsample1a, upsample1b, ..., conv_lastc) because both determine
            # the state_dict layout.
            branch_tags = ("a", "b", "c")

            # Transposed convolutions; kernel size == stride undoes the matching pooling step.
            for stage, c_in, c_out, factor in ((1, 256, 128, 3), (2, 128, 64, 2), (3, 64, 32, 2)):
                for tag in branch_tags:
                    setattr(self, f"upsample{stage}{tag}",
                            nn.ConvTranspose2d(c_in, c_out, factor, stride=factor, padding=0))

            # Decoder blocks; the input is the upsampled features concatenated with the skip.
            for stage, width in ((3, 128), (2, 64), (1, 32)):
                for tag in branch_tags:
                    setattr(self, f"dconv_up{stage}{tag}", double_conv(width + width, width))

            # 1x1 output heads
            for tag in branch_tags:
                setattr(self, f"conv_last{tag}", nn.Conv2d(32, 1, 1))

            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()

        def _decode(self, tag, bottleneck, skips):
            # Run decoder branch `tag`; `skips` is ordered deepest first (conv3, conv2, conv1).
            feats = bottleneck
            for stage, skip in zip((1, 2, 3), skips):
                feats = getattr(self, f"upsample{stage}{tag}")(feats)
                feats = torch.cat([feats, skip], dim=1)
                feats = getattr(self, f"dconv_up{4 - stage}{tag}")(feats)
            return getattr(self, f"conv_last{tag}")(feats)

        def forward(self, x):
            # Shared encoder
            conv1 = self.dconv_down1(x)
            conv2 = self.dconv_down2(self.maxpool(conv1))
            conv3 = self.dconv_down3(self.maxpool(conv2))
            encoder_end = self.dconv_down4(self.maxpool3(conv3))
            skips = (conv3, conv2, conv1)

            # Tree Height: linear output, no extra activation
            out_tree_height = self._decode("a", encoder_end, skips)

            # Tree Mask: sigmoid squashes the output into [0, 1]
            out_tree_mask = self.sigmoid(self._decode("b", encoder_end, skips))

            return [out_tree_height, out_tree_mask]

    return UNet()


# Instantiate the network on the selected device and restore the trained weights.
# map_location remaps tensors that were saved from a GPU run, so the checkpoint also
# loads when `device` has fallen back to the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
unet_1 = defineUNetModel_partiallyshared().to(device)
unet_1.load_state_dict(
    torch.load("models/pytorch_paper_final/pytorch_mtloss_partshared_manual.pt", map_location=device)
)

In [None]:
# Get file paths for the inference images
inputPath = "2021_predictions"  # 12972

# Keep every .tif entry of the folder, prefixed with the folder name.
filelist = [
    inputPath + '/{0}'.format(filepath)
    for filepath in os.listdir(inputPath)
    if filepath.endswith(".tif")
]

len(filelist)

In [None]:
def custom_replace(tensor, cutpoint):
    """Binarise a tensor at `cutpoint` without modifying the input.

    Returns a copy of `tensor` in which values >= cutpoint are set to 1 and
    values < cutpoint are set to 0. Elements that satisfy neither comparison
    (NaN) are left as they are.
    """
    at_or_above = tensor >= cutpoint
    below = tensor < cutpoint
    res = tensor.clone()
    res.masked_fill_(at_or_above, 1)
    res.masked_fill_(below, 0)
    return res

In [None]:
%%time
# i need to get each tif, append the new raster layers, remove the bands i dont need anymore (maybe keep rgb), and then save
inputPath="D:/final_data/2021_predictions/"

driver = gdal.GetDriverByName("GTiff")
driver.Register()
i=0
# Load the images, and append them to a list.
for filepath in os.listdir(inputPath):
    if filepath.endswith((".tif")):
        print(filepath)
        i=i+1
        print(i)
        images = []
        dataset = gdal.Open(inputPath+'/{0}'.format(filepath))
        gt = dataset.GetGeoTransform()
        proj = dataset.GetProjection()
        
        image = dataset.ReadAsArray()  # Returned image is a NumPy array with shape (16, 60, 60) for example.
        # predict values based on the two different models
        images.append(image)
        image = np.stack(images, axis= 0)
        X=image[:,:14,:,:].copy() # separate out the band values
        X[X  < .0000001] = 0
        X = np.transpose(X, axes=[0, 2, 3, 1])
        
        # normalize values of the input data to 0,1
        X = X/X.max(axis=(3),keepdims=True)
        
        X = np.transpose(X, axes=[0,3,1,2])
        X= torch.from_numpy(X)
        X= X.to(device)
        X = Variable(X.float().cuda())
        unet_1.eval()
        with torch.inference_mode():
            pred_tree_height, pred_tree_mask = unet_1(X)

        pred_tree_mask = np.asarray(pred_tree_mask.squeeze().cpu())
        pred_tree_mask[pred_tree_mask  >= .4] = 1
        pred_tree_mask[pred_tree_mask  < .4 ] = 0
        
        pred_tree_height = np.asarray(pred_tree_height.squeeze().cpu())*650
        pred_tree_height[pred_tree_height  < 0 ] = 0
        
        outcome_data=np.stack([pred_tree_height,pred_tree_mask])

        
        # save solution
        save_raster1 = driver.Create('2021_predictions_results_theight'+'/{0}'.format(filepath), 
                                    xsize=240, ysize=240, bands = 1)

        save_raster1.SetGeoTransform(gt)
        save_raster1.SetProjection(proj) 
        
        outband_1 = save_raster1.GetRasterBand(1)
        outband_1.WriteArray(outcome_data[0].astype(np.float32))
        outband_1.SetNoDataValue(np.nan)
        outband_1.FlushCache()
        #outband_1 = None
        save_raster1 = None
        
        # save solution
        save_raster2 = driver.Create('2021_predictions_results_tbinary'+'/{0}'.format(filepath), 
                                    xsize=240, ysize=240, bands = 1)

        save_raster2.SetGeoTransform(gt)
        save_raster2.SetProjection(proj) 
        
        outband_1 = save_raster2.GetRasterBand(1)
        outband_1.WriteArray(outcome_data[1].astype(np.float32))
        outband_1.SetNoDataValue(np.nan)
        outband_1.FlushCache()        

        #outband_5 = None
        
        save_raster2 = None

# CPU times: total: 15min 33s
# Wall time: 52min 3s