In [None]:
# Load in necessary packages
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from osgeo import gdal
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import pandas as pd

from torchinfo import summary
import gc
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
from torchmetrics import JaccardIndex
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics import MeanAbsoluteError
from torchmetrics.functional import dice_score
from scipy.ndimage import convolve
from tqdm.auto import tqdm # progress bar
from timeit import default_timer as timer
def print_train_time(start: float, end: float, device: torch.device = None):
    """Print and return the time elapsed between two timer readings.

    Args:
        start: Timer reading taken before the work began.
        end: Timer reading taken after the work finished.
        device: Device label; it is only interpolated into the printed line.

    Returns:
        The difference ``end - start``.
    """
    total_time = end - start
    message = f"Train time on {device} : {total_time:.3f} seconds"
    print(message)
    return total_time



In [None]:
### Create Dataloaders using the file paths ###
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Define two Dataset classes (to be wrapped in DataLoaders), one for training and one for test
class allbands_dataset_train(Dataset):
    """Training dataset that reads one ``.npy`` array per sample.

    Every file is loaded with ``np.load`` and split along its first axis:
    entries 0-13 are the inputs (band values), entry 14 is the tree-height
    target and entry 15 is the tree-mask target.
    """

    def __init__(self, filelist_train, transform=None):
        """
        Args:
            filelist_train (list): File paths, one ``.npy`` file per sample.
            transform (callable, optional): Stored as ``self.transform``;
                ``__getitem__`` does not apply it.
        """
        self.transform = transform
        self.filelist = filelist_train

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``[X, [tree_height, tree_mask]]``."""
        # A tensor index is converted to plain Python before indexing the list.
        index = idx.tolist() if torch.is_tensor(idx) else idx

        stacked = np.load(self.filelist[index])

        # The leading 14 entries are the band values fed to the model.
        bands = stacked[:14]
        # The targets follow the inputs: canopy height, then tree / not-tree mask.
        targets = [stacked[14], stacked[15]]

        return [bands, targets]
    
class allbands_dataset_test(Dataset):
    """Test dataset that reads one ``.npy`` array per sample.

    A loaded array is split along its first axis into the 14 input bands
    (entries 0-13), the tree-height target (entry 14) and the tree-mask
    target (entry 15).
    """

    def __init__(self, filelist_test, transform=None):
        """
        Args:
            filelist_test (list): File paths, one ``.npy`` file per sample.
            transform (callable, optional): Kept on the instance as
                ``self.transform``; it is not applied in ``__getitem__``.
        """
        self.transform = transform
        self.filelist = filelist_test

    def __len__(self):
        """Return how many files the dataset holds."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Return ``[X, [tree_height, tree_mask]]`` for the file at ``idx``."""
        # Tensor indices are turned into plain Python values first.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        loaded = np.load(self.filelist[idx])

        # Inputs: the first 14 entries (band values).
        inputs = loaded[:14]

        # Targets: canopy height and the tree / not-tree mask.
        height_target = loaded[14]
        mask_target = loaded[15]

        return [inputs, [height_target, mask_target]]
    

# Multi-task Manual Loss - Not Shared Decoding

In [None]:
# Set up a version with a shared encoding path but separate decoding paths
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def defineUNetModel_partiallyshared():
    """Build a U-Net with one shared encoder and separate decoder branches.

    The returned module maps a 14-channel input to the list
    ``[out_tree_height, out_tree_mask]``; each entry has one channel. The
    height output is the raw result of a 1x1 convolution, the mask output is
    additionally passed through a sigmoid.

    The encoder pools by 2, 2 and 3 and each decoder upsamples by 3, 2 and 2,
    so the input height and width must be multiples of 12 for the skip
    connections to line up.

    Decoder layers are registered three times (suffixes ``a``, ``b``, ``c``).
    ``forward`` only evaluates branch ``a`` (height) and branch ``b`` (mask);
    the ``c`` layers are still created and therefore remain part of the
    module's ``state_dict``.
    """
    branch_tags = ("a", "b", "c")

    def conv_block(in_channels, out_channels, *, padding=1, instance_norm=False, dropout=0.1):
        """Two 3x3 convolutions, each followed by LeakyReLU, with dropout in between.

        With ``instance_norm=True`` an ``InstanceNorm2d`` is inserted after
        each convolution (used for the first encoder level only).
        """
        layers = [nn.Conv2d(in_channels, out_channels, 3, padding=padding)]
        if instance_norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))
        layers.append(nn.Dropout(p=dropout))
        layers.append(nn.Conv2d(out_channels, out_channels, 3, padding=padding))
        if instance_norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))
        return nn.Sequential(*layers)

    class UNet(nn.Module):
        """Shared-encoder U-Net with per-task decoder branches."""

        def __init__(self):
            super().__init__()

            # Layers are created in a fixed order: that order determines the
            # parameter names/ordering in the state_dict and which random
            # initial values each layer draws under a given seed.

            # Shared encoder.
            self.dconv_down1 = conv_block(14, 32, padding="same", instance_norm=True)
            self.dconv_down2 = conv_block(32, 64)
            self.dconv_down3 = conv_block(64, 128)
            self.dconv_down4 = conv_block(128, 256, dropout=0.2)

            self.maxpool = nn.MaxPool2d(2)
            self.maxpool3 = nn.MaxPool2d(3)

            # Transposed convolutions upsample1{a,b,c} (x3), then
            # upsample2{a,b,c} and upsample3{a,b,c} (x2 each).
            for stage, in_ch, out_ch, factor in ((1, 256, 128, 3), (2, 128, 64, 2), (3, 64, 32, 2)):
                for tag in branch_tags:
                    setattr(
                        self,
                        f"upsample{stage}{tag}",
                        nn.ConvTranspose2d(in_ch, out_ch, factor, stride=factor, padding=0),
                    )

            # Decoder blocks dconv_up3*, dconv_up2*, dconv_up1*: the input is
            # the upsampled map concatenated with the equally wide skip map.
            for level, channels in ((3, 128), (2, 64), (1, 32)):
                for tag in branch_tags:
                    setattr(self, f"dconv_up{level}{tag}", conv_block(channels + channels, channels))

            # 1x1 output heads conv_lasta / conv_lastb / conv_lastc.
            for tag in branch_tags:
                setattr(self, f"conv_last{tag}", nn.Conv2d(32, 1, 1))

            # Despite its name this attribute is a 2x2 max-pool; forward() never calls it.
            self.linear = nn.MaxPool2d(2)
            self.sigmoid = nn.Sigmoid()

        def _decode(self, tag, bottleneck, skips):
            """Run decoder branch ``tag`` and return its one-channel output map.

            ``skips`` holds the encoder feature maps from deepest to shallowest.
            """
            features = bottleneck
            for stage, skip in enumerate(skips, start=1):
                features = getattr(self, f"upsample{stage}{tag}")(features)
                features = torch.cat([features, skip], dim=1)
                # upsample1 pairs with dconv_up3, upsample2 with dconv_up2, ...
                features = getattr(self, f"dconv_up{4 - stage}{tag}")(features)
            return getattr(self, f"conv_last{tag}")(features)

        def forward(self, x):
            """Return ``[out_tree_height, out_tree_mask]`` for the batch ``x``."""
            # Shared encoder; the pre-pooling maps are kept as skip connections.
            skip1 = self.dconv_down1(x)
            skip2 = self.dconv_down2(self.maxpool(skip1))
            skip3 = self.dconv_down3(self.maxpool(skip2))
            encoder_end = self.dconv_down4(self.maxpool3(skip3))

            skips = (skip3, skip2, skip1)

            # Tree height: no activation after the final 1x1 convolution.
            out_tree_height = self._decode("a", encoder_end, skips)

            # Tree mask: sigmoid on top of the final 1x1 convolution.
            out_tree_mask = self.sigmoid(self._decode("b", encoder_end, skips))

            return [out_tree_height, out_tree_mask]

    return UNet()

# Seed the RNG so the freshly built model is initialised reproducibly, then
# overwrite its parameters with the trained checkpoint.
torch.manual_seed(42)
unet_1 = defineUNetModel_partiallyshared().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU still loads when only the CPU is available.
unet_1.load_state_dict(
    torch.load(
        "models/pytorch_paper_final/pytorch_mtloss_partshared_manual.pt",
        map_location=device,
    )
)

###############################################################################################################################

# Run Inference on 2021 data

In [None]:
# Collect the file paths of the inference images
inputPath = "2021_predictions"  # 12972

# Keep every ".tif" entry of the folder, in the order os.listdir returns them,
# and prefix each one with the folder name.
filelist = [
    inputPath + '/{0}'.format(entry)
    for entry in os.listdir(inputPath)
    if entry.endswith(".tif")
]

len(filelist)

In [None]:
%%time
# For every 2021 GeoTIFF tile: read the bands, run the trained network and
# write the predicted tree height and the binary tree mask as two
# georeferenced single-band GeoTIFFs (one output folder per product).
# NOTE(review): this absolute, machine-specific path differs from the relative
# "2021_predictions" folder listed in the previous cell -- confirm which one is intended.
inputPath="D:/final_data/2021_predictions/"

HEIGHT_OUTPUT_DIR = '2021_predictions_results_theight'
MASK_OUTPUT_DIR = '2021_predictions_results_tbinary'
N_INPUT_BANDS = 14    # the first 14 bands of a tile are the model inputs
MASK_THRESHOLD = .4   # sigmoid outputs >= this value are written as 1, the rest as 0
HEIGHT_SCALE = 650    # NOTE(review): presumably undoes a target scaling applied during training -- confirm


def _write_single_band_geotiff(gtiff_driver, path, array, geotransform, projection):
    """Write a 2-D array to ``path`` as a one-band Float32 GeoTIFF.

    The raster takes its size from ``array`` and its georeferencing from
    ``geotransform`` / ``projection``; NaN is registered as the nodata value.

    Raises:
        RuntimeError: if GDAL cannot create the output file.
    """
    n_rows, n_cols = array.shape
    # Driver.Create defaults to 8-bit (GDT_Byte) bands, which would clamp the
    # heights to 0-255 and cannot represent the NaN nodata value, so the band
    # type is requested explicitly.
    raster = gtiff_driver.Create(path, xsize=n_cols, ysize=n_rows, bands=1,
                                 eType=gdal.GDT_Float32)
    if raster is None:
        raise RuntimeError(f"GDAL could not create output raster: {path}")

    raster.SetGeoTransform(geotransform)
    raster.SetProjection(projection)

    band = raster.GetRasterBand(1)
    band.WriteArray(array.astype(np.float32))
    band.SetNoDataValue(np.nan)
    band.FlushCache()
    # Dropping the references closes the file and flushes it to disk.
    band = None
    raster = None


driver = gdal.GetDriverByName("GTiff")
driver.Register()

# Driver.Create fails when the target folder is missing.
for out_dir in (HEIGHT_OUTPUT_DIR, MASK_OUTPUT_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Evaluation mode only needs to be set once, not once per tile.
unet_1.eval()

tif_names = [name for name in os.listdir(inputPath) if name.endswith(".tif")]
for filepath in tqdm(tif_names, desc="2021 inference"):
    tile_path = os.path.join(inputPath, filepath)
    dataset = gdal.Open(tile_path)
    if dataset is None:
        raise RuntimeError(f"GDAL could not open input raster: {tile_path}")
    gt = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    image = dataset.ReadAsArray()  # array laid out as (bands, rows, cols)
    dataset = None  # release the GDAL file handle

    # Add a leading batch axis and keep only the input bands.
    X = image[np.newaxis, :N_INPUT_BANDS].copy()
    X[X < .0000001] = 0

    # Normalise every pixel by its largest band value (axis 1 is the band axis).
    # NOTE(review): a pixel whose bands are all zero divides 0/0 and becomes
    # NaN, exactly as before this edit -- confirm that is intended.
    X = X / X.max(axis=1, keepdims=True)

    # Move the batch to the device the model lives on; calling .cuda()
    # unconditionally would break the CPU fallback.
    X = torch.from_numpy(X).float().to(device)
    with torch.inference_mode():
        pred_tree_height, pred_tree_mask = unet_1(X)

    # Each output has shape (1, 1, rows, cols); index the single batch/channel entry.
    pred_tree_mask = pred_tree_mask[0, 0].cpu().numpy()
    # Binarise the sigmoid output; NaN pixels fail both comparisons and stay NaN.
    pred_tree_mask[pred_tree_mask >= MASK_THRESHOLD] = 1
    pred_tree_mask[pred_tree_mask < MASK_THRESHOLD] = 0

    pred_tree_height = pred_tree_height[0, 0].cpu().numpy() * HEIGHT_SCALE
    pred_tree_height[pred_tree_height < 0] = 0  # negative heights are clipped to zero

    # Save both products under the input tile's file name.
    _write_single_band_geotiff(driver, os.path.join(HEIGHT_OUTPUT_DIR, filepath),
                               pred_tree_height, gt, proj)
    _write_single_band_geotiff(driver, os.path.join(MASK_OUTPUT_DIR, filepath),
                               pred_tree_mask, gt, proj)

# CPU times: total: 15min 33s
# Wall time: 52min 3s