In [1]:
import numpy as np
import glob
from osgeo import gdal, osr
import pyproj
from shutil import copyfile

import PIL
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import os

PIL.Image.MAX_IMAGE_PIXELS = 933120000

In [2]:
class GeorectDataset(Dataset):
    def __init__(self, base_raster, image_paths,size_x = 3000, size_y=2000, scaled=True):
                
        self.toTensor = transforms.Compose([
            transforms.PILToTensor(),
            transforms.Resize((size_x, size_y))
        ])
        
        self.base_raster = base_raster
        self.base = Image.open(self.base_raster)
        shape = np.asarray(self.base).shape
        self.base = self.toTensor(self.base)[0, :, :].float().unsqueeze(0)
        
        scaling = np.array([
                [size_x / shape[0],                   0],
                [                0,   size_y / shape[1]]])
        
        self.image_paths = image_paths
        raw_affine = list()
        scaled_affine = list()
        scales = list()
        
        for path in tqdm(self.image_paths):
            affine = np.load(path[:-4]+"_affine.npy").reshape(2, 3)
            
            sc_affine = scaling @ affine
            
            raw_affine.append(affine.flatten())
            scaled_affine.append(sc_affine.flatten())
            scales.append(scaling.flatten())
        
        if scaled:
            self.affine = np.vstack(scaled_affine)
        else:
            self.affine = np.vstack(raw_affine)
            
        self.raw_affine = raw_affine
        self.scaled_affine = scaled_affine
        self.scales = scales
        
    def __getitem__(self, index):
        img1_path = self.image_paths[index]
        img1 = self.toTensor(Image.open(img1_path))
        img1 = torch.where(img1, 1, 255).float()
        affine = self.affine[index, :]
        return self.base, img1, affine

    def __len__(self):
        return len(self.image_paths)

In [3]:
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)
        
    def forward(self, x):
        batch_size = x.shape[0]
        channels = x.shape[1]
        y = self.avg_pool(x).view(batch_size, channels)
        y = F.relu(self.fc1(y))
        y = self.fc2(y).sigmoid()
        if x.dim() == 4:
            y = y.view(batch_size, channels, 1, 1)
        elif x.dim() == 3:
            y = y.view(batch_size, channels, 1)
        
        return x * y
        
        # Reshape to (batch_size, channels, 1, 1)
        out = out.unsqueeze(2)
        out = out.unsqueeze(3)
        # Multiply by input tensor to compute output
        out = x * out.expand_as(x)
        return out


In [4]:
class AttentionBlock(nn.Module):
    """
    Attention block module that takes in a 3D tensor and applies an attention mechanism to the channels dimension.
    """
    def __init__(self, in_channels):
        super(AttentionBlock, self).__init__()
        
        self.in_channels = in_channels
        
        # Define the attention mechanism layers
        self.query_conv = nn.Conv2d(in_channels, in_channels//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        # Compute the query, key, and value tensors
        batch_size, _, width, height = x.size()
        query = self.query_conv(x).view(batch_size, -1, width*height).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, width*height)
        value = self.value_conv(x).view(batch_size, -1, width*height)
        
        # Compute the attention map and attention output
        attention = torch.bmm(query, key)
        attention = self.softmax(attention)
        attention_output = torch.bmm(value, attention.permute(0, 2, 1))
        attention_output = attention_output.view(batch_size, self.in_channels, width, height)
        
        # Apply the attention output to the input tensor
        out = self.gamma * attention_output + x
        
        return out

In [5]:
class GeorectNet(nn.Module):
    def __init__(self, num_scales = 3, size=(600, 400), channels=8, output_size=4):
        super(GeorectNet, self).__init__()
        
        # PARAMS
        self.size = size
        self.num_scales = num_scales        
        self.channels = channels
        
        # UPSAMPLING FOR LOWER-RES PYRAMID
        self.upsample = nn.Upsample(size=size)
        
        # ATTENTION BLOCK
        self.AttentionBlock = AttentionBlock(self.channels)
        self.AttentionBlock_pyramid = AttentionBlock(self.channels * self.num_scales * 2)
        
        # SE BLOCK
        self.SEBlock = SEBlock(self.channels)
        self.SE_pyramid = SEBlock(self.channels * 2)
        
        # CONVOLUTIONAL LAYERS
        self.conv = nn.Sequential(
            nn.Conv2d(1, self.channels // 2, kernel_size=5, stride=3, padding=11),
            nn.MaxPool2d(kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.channels // 2, self.channels, kernel_size=5, stride=3, padding=11),
            nn.MaxPool2d(kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
        )
            
        fin_conv = nn.Conv2d(
            in_channels=self.channels, 
            out_channels=output_size, 
            kernel_size=size
        )
        
        class Flatten(nn.Module):
            def forward(self, x):
                return x.view(x.size(0), -1)
        
        self.fin_conv = nn.Sequential(
            fin_conv,
            Flatten()
        )

    def pyramidStep(self, x1, x2):
        x1 = self.AttentionBlock(self.SEBlock(self.conv(x1)))
        x2 = self.AttentionBlock(self.SEBlock(self.conv(x2)))
        x = torch.cat((x1, x2), dim=1)
        return x
    
    def forward(self, x1, x2):
        
        outputs = list()
        
        # FOR EACH SCALE IN PYRAMID
        for i in range(self.num_scales):
            
            # DOWNSAMPLE
            scale = 1.0 / 2**i
            size = (int(x1.size()[-2] * scale), int(x1.size()[-1] * scale))
            x1_scale = F.interpolate(x1, size=size, mode='bilinear')
            x2_scale = F.interpolate(x2, size=size, mode='bilinear')
            
            # RUN SINGLE STEP 
            output_transformation = self.pyramidStep(x1_scale, x2_scale)
            
            # UPSAMPLE AND APPEND
            output_transformation = self.upsample(output_transformation)
            outputs.append(output_transformation)
            
        # OUTPUT ATTENTION AND SE BLOCKS
        outputs = torch.cat(outputs, 1)
        print(outputs.shape)
        outputs = self.AttentionBlock_pyramid(outputs)
        outputs = self.SE_pyramid(outputs)
        
        # FINAL CONVOLUTION
        current_affine = self.fin_conv(outputs)
        
        cuda0 = torch.device('cuda:0')
        zeros = torch.zeros(current_affine.shape[0], device=cuda0)
        
        affine = torch.vstack((current_affine[:, 0], 
                               zeros, 
                               current_affine[:, 1], 
                               zeros, 
                               current_affine[:, 2], 
                               current_affine[:, 3])).T
        
        return affine

In [6]:
def get_image_corners(image):
    """
    Computes the corner coordinates of an image as a tensor.
    
    Args:
        image (torch.Tensor): A 2D or 3D tensor representing an image.
    
    Returns:
        corners (torch.Tensor): A tensor of shape (4, 2) containing the corner coordinates of the image,
            in the order (top-left, top-right, bottom-left, bottom-right).
    """
    # Convert the input tensor to float and normalize it
    image = image.float() / 255.0
    
    # Compute the corner coordinates of the image
    height, width = image.shape[-2], image.shape[-1]
    corners = torch.tensor([[0, 0, 1], [width, 0, 1], [0, height, 1], [width, height, 1]], dtype=torch.float32, device=image.device)
    
    return corners

def corner_loss(src_affine_params, tgt_affine_params, image):
    """
    Calculates the loss based on the difference between the corners of the source and target images,
    after applying the affine transformations defined by the input parameters to the original image.
    
    Args:
        src_affine_params (torch.Tensor): A 2D tensor of shape (batch_size, 6) containing the affine transformation
            parameters for the source image, in the order (a11, a12, tx, a21, a22, ty).
        tgt_affine_params (torch.Tensor): A 2D tensor of shape (batch_size, 6) containing the affine transformation
            parameters for the target image, in the same order as src_affine_params.
        image (torch.Tensor): A 4D tensor of shape (batch_size, channels, height, width) containing the original
            input images.
    
    Returns:
        corner_loss (torch.Tensor): A scalar tensor representing the loss.
    """
    
    grid = get_image_corners(image).T
    
    corner_loss = 0
    
    for i in range(image.shape[0]):
        
        
        src_affine_matrix = torch.vstack((torch.squeeze(src_affine_params[i]).view(2, 3), 
                                          torch.cuda.FloatTensor([0,0,1])))
        tgt_affine_matrix = torch.vstack((torch.squeeze(tgt_affine_params[i]).view(2, 3), 
                                          torch.cuda.FloatTensor([0,0,1])))

        src_corners = torch.matmul(src_affine_matrix, grid)
        tgt_corners = torch.matmul(tgt_affine_matrix, grid)
        curr_loss =  F.l1_loss(src_corners, tgt_corners)
        corner_loss = corner_loss + curr_loss
        
    corner_loss = corner_loss / image.shape[0]
    
    return corner_loss

In [7]:
def savetextfile(epoch, srcaffine, tgtaffine, intermediate_dir ):
    tempt = np.dstack((srcaffine.detach().cpu(), tgtaffine.detach().cpu()))
    
    with open(f"{intermediate_dir}/{epoch}.txt", 'w') as f:
       for row in tempt:
           np.savetxt(f, row)

In [8]:
def train_Georect_net(net,
                      Georect_dataset, 
                      batch_size=1, 
                      num_epochs=300, 
                      learning_rate=0.0001, 
                      validation_split=0.2, 
                      device='cuda'):
    
    intermediate_dir = r"C:\Users\fhacesga\Desktop\FIRMsDigitizing\RECTDNN\intermediate_outputs"
    
    if not os.path.isdir(intermediate_dir):
        os.makedirs(intermediate_dir)
    
    # Split dataset into training and validation sets
    dataset_size = len(Georect_dataset)
    val_size = int(dataset_size * validation_split)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(Georect_dataset, [train_size, val_size])

    # Create data loaders for training and validation sets
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    
    # Define loss function and optimizer
    # criterion = nn.BCEWithLogitsLoss()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    
    bestval = torch.tensor(float('inf'))
    
    # Training loop
    for epoch in range(num_epochs):
        # Train the network
        net.train()
        train_loss = 0.0
        for batch_idx, (img1, img2, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            optimizer.zero_grad()
            
            output = net(img1, img2)
            
            loss = criterion(output.float(), label.squeeze().float())
            
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Evaluate the network on the validation set
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_idx, (img1, img2, label) in enumerate(val_loader):
                img1, img2, label = img1.to(device), img2.to(device), label.to(device)
                output = net(img1, img2)
                savetextfile(epoch, label, output, intermediate_dir )
                corloss = corner_loss(output.float(), label.float(), img2)
                loss = criterion(output.float(), label.squeeze().float())
                val_loss += corloss
                
            if bestval > val_loss:
                bestval = val_loss
                checkpoint = {
                    'epoch': epoch, 
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }
                torch.save(checkpoint, "checkpoint.pth")

        # Print training and validation loss for the epoch
        print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Corners: {:.4f}'.format(epoch+1, num_epochs, train_loss/len(train_loader), val_loss/len(val_loader), corloss))

In [9]:
trainloc = r"C:\Users\fhacesga\Desktop\FIRMsDigitizing\RECTDNN\TrainDataset\\"
base_loc = r"D:\FloodChange\BaseRaster\BaseTest.tif"

files = glob.glob(f"{trainloc}*.tif")
train_dataset = GeorectDataset(base_loc, files)

  0%|          | 0/110 [00:00<?, ?it/s]

In [10]:
# Define Georect network and move it to device
net = GeorectNet().to('cuda')
train_Georect_net(net, train_dataset)



  0%|          | 0/88 [00:00<?, ?it/s]

  return F.linear(input, self.weight, self.bias)


torch.Size([1, 48, 600, 400])


OutOfMemoryError: CUDA out of memory. Tried to allocate 214.58 GiB (GPU 0; 8.00 GiB total capacity; 506.41 MiB already allocated; 6.58 GiB free; 552.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF